@@ -781,6 +781,7 @@ def sample_dpmpp_2m(model, x, sigmas, extra_args=None, callback=None, disable=No
781781 old_denoised = denoised
782782 return x
783783
784+
784785@torch .no_grad ()
785786def sample_dpmpp_2m_sde (model , x , sigmas , extra_args = None , callback = None , disable = None , eta = 1. , s_noise = 1. , noise_sampler = None , solver_type = 'midpoint' ):
786787 """DPM-Solver++(2M) SDE."""
@@ -796,9 +797,12 @@ def sample_dpmpp_2m_sde(model, x, sigmas, extra_args=None, callback=None, disabl
796797 noise_sampler = BrownianTreeNoiseSampler (x , sigma_min , sigma_max , seed = seed , cpu = True ) if noise_sampler is None else noise_sampler
797798 s_in = x .new_ones ([x .shape [0 ]])
798799
800+ model_sampling = model .inner_model .model_patcher .get_model_object ('model_sampling' )
801+ lambda_fn = partial (sigma_to_half_log_snr , model_sampling = model_sampling )
802+ sigmas = offset_first_sigma_for_snr (sigmas , model_sampling )
803+
799804 old_denoised = None
800- h_last = None
801- h = None
805+ h , h_last = None , None
802806
803807 for i in trange (len (sigmas ) - 1 , disable = disable ):
804808 denoised = model (x , sigmas [i ] * s_in , ** extra_args )
@@ -809,26 +813,29 @@ def sample_dpmpp_2m_sde(model, x, sigmas, extra_args=None, callback=None, disabl
809813 x = denoised
810814 else :
811815 # DPM-Solver++(2M) SDE
812- t , s = - sigmas [i ].log (), - sigmas [i + 1 ].log ()
813- h = s - t
814- eta_h = eta * h
816+ lambda_s , lambda_t = lambda_fn (sigmas [i ]), lambda_fn (sigmas [i + 1 ])
817+ h = lambda_t - lambda_s
818+ h_eta = h * (eta + 1 )
819+
820+ alpha_t = sigmas [i + 1 ] * lambda_t .exp ()
815821
816- x = sigmas [i + 1 ] / sigmas [i ] * (- eta_h ).exp () * x + ( - h - eta_h ).expm1 ().neg () * denoised
822+ x = sigmas [i + 1 ] / sigmas [i ] * (- h * eta ).exp () * x + alpha_t * ( - h_eta ).expm1 ().neg () * denoised
817823
818824 if old_denoised is not None :
819825 r = h_last / h
820826 if solver_type == 'heun' :
821- x = x + ((- h - eta_h ).expm1 ().neg () / (- h - eta_h ) + 1 ) * (1 / r ) * (denoised - old_denoised )
827+ x = x + alpha_t * ((- h_eta ).expm1 ().neg () / (- h_eta ) + 1 ) * (1 / r ) * (denoised - old_denoised )
822828 elif solver_type == 'midpoint' :
823- x = x + 0.5 * ( - h - eta_h ).expm1 ().neg () * (1 / r ) * (denoised - old_denoised )
829+ x = x + 0.5 * alpha_t * ( - h_eta ).expm1 ().neg () * (1 / r ) * (denoised - old_denoised )
824830
825- if eta :
826- x = x + noise_sampler (sigmas [i ], sigmas [i + 1 ]) * sigmas [i + 1 ] * (- 2 * eta_h ).expm1 ().neg ().sqrt () * s_noise
831+ if eta > 0 and s_noise > 0 :
832+ x = x + noise_sampler (sigmas [i ], sigmas [i + 1 ]) * sigmas [i + 1 ] * (- 2 * h * eta ).expm1 ().neg ().sqrt () * s_noise
827833
828834 old_denoised = denoised
829835 h_last = h
830836 return x
831837
838+
832839@torch .no_grad ()
833840def sample_dpmpp_3m_sde (model , x , sigmas , extra_args = None , callback = None , disable = None , eta = 1. , s_noise = 1. , noise_sampler = None ):
834841 """DPM-Solver++(3M) SDE."""
@@ -842,6 +849,10 @@ def sample_dpmpp_3m_sde(model, x, sigmas, extra_args=None, callback=None, disabl
842849 noise_sampler = BrownianTreeNoiseSampler (x , sigma_min , sigma_max , seed = seed , cpu = True ) if noise_sampler is None else noise_sampler
843850 s_in = x .new_ones ([x .shape [0 ]])
844851
852+ model_sampling = model .inner_model .model_patcher .get_model_object ('model_sampling' )
853+ lambda_fn = partial (sigma_to_half_log_snr , model_sampling = model_sampling )
854+ sigmas = offset_first_sigma_for_snr (sigmas , model_sampling )
855+
845856 denoised_1 , denoised_2 = None , None
846857 h , h_1 , h_2 = None , None , None
847858
@@ -853,13 +864,16 @@ def sample_dpmpp_3m_sde(model, x, sigmas, extra_args=None, callback=None, disabl
853864 # Denoising step
854865 x = denoised
855866 else :
856- t , s = - sigmas [i ]. log ( ), - sigmas [i + 1 ]. log ( )
857- h = s - t
867+ lambda_s , lambda_t = lambda_fn ( sigmas [i ]), lambda_fn ( sigmas [i + 1 ])
868+ h = lambda_t - lambda_s
858869 h_eta = h * (eta + 1 )
859870
860- x = torch .exp (- h_eta ) * x + (- h_eta ).expm1 ().neg () * denoised
871+ alpha_t = sigmas [i + 1 ] * lambda_t .exp ()
872+
873+ x = sigmas [i + 1 ] / sigmas [i ] * (- h * eta ).exp () * x + alpha_t * (- h_eta ).expm1 ().neg () * denoised
861874
862875 if h_2 is not None :
876+ # DPM-Solver++(3M) SDE
863877 r0 = h_1 / h
864878 r1 = h_2 / h
865879 d1_0 = (denoised - denoised_1 ) / r0
@@ -868,20 +882,22 @@ def sample_dpmpp_3m_sde(model, x, sigmas, extra_args=None, callback=None, disabl
868882 d2 = (d1_0 - d1_1 ) / (r0 + r1 )
869883 phi_2 = h_eta .neg ().expm1 () / h_eta + 1
870884 phi_3 = phi_2 / h_eta - 0.5
871- x = x + phi_2 * d1 - phi_3 * d2
885+ x = x + ( alpha_t * phi_2 ) * d1 - ( alpha_t * phi_3 ) * d2
872886 elif h_1 is not None :
887+ # DPM-Solver++(2M) SDE
873888 r = h_1 / h
874889 d = (denoised - denoised_1 ) / r
875890 phi_2 = h_eta .neg ().expm1 () / h_eta + 1
876- x = x + phi_2 * d
891+ x = x + ( alpha_t * phi_2 ) * d
877892
878- if eta :
893+ if eta > 0 and s_noise > 0 :
879894 x = x + noise_sampler (sigmas [i ], sigmas [i + 1 ]) * sigmas [i + 1 ] * (- 2 * h * eta ).expm1 ().neg ().sqrt () * s_noise
880895
881896 denoised_1 , denoised_2 = denoised , denoised_1
882897 h_1 , h_2 = h , h_1
883898 return x
884899
900+
885901@torch .no_grad ()
886902def sample_dpmpp_3m_sde_gpu (model , x , sigmas , extra_args = None , callback = None , disable = None , eta = 1. , s_noise = 1. , noise_sampler = None ):
887903 if len (sigmas ) <= 1 :
@@ -891,6 +907,7 @@ def sample_dpmpp_3m_sde_gpu(model, x, sigmas, extra_args=None, callback=None, di
891907 noise_sampler = BrownianTreeNoiseSampler (x , sigma_min , sigma_max , seed = extra_args .get ("seed" , None ), cpu = False ) if noise_sampler is None else noise_sampler
892908 return sample_dpmpp_3m_sde (model , x , sigmas , extra_args = extra_args , callback = callback , disable = disable , eta = eta , s_noise = s_noise , noise_sampler = noise_sampler )
893909
910+
894911@torch .no_grad ()
895912def sample_dpmpp_2m_sde_gpu (model , x , sigmas , extra_args = None , callback = None , disable = None , eta = 1. , s_noise = 1. , noise_sampler = None , solver_type = 'midpoint' ):
896913 if len (sigmas ) <= 1 :
@@ -900,6 +917,7 @@ def sample_dpmpp_2m_sde_gpu(model, x, sigmas, extra_args=None, callback=None, di
900917 noise_sampler = BrownianTreeNoiseSampler (x , sigma_min , sigma_max , seed = extra_args .get ("seed" , None ), cpu = False ) if noise_sampler is None else noise_sampler
901918 return sample_dpmpp_2m_sde (model , x , sigmas , extra_args = extra_args , callback = callback , disable = disable , eta = eta , s_noise = s_noise , noise_sampler = noise_sampler , solver_type = solver_type )
902919
920+
903921@torch .no_grad ()
904922def sample_dpmpp_sde_gpu (model , x , sigmas , extra_args = None , callback = None , disable = None , eta = 1. , s_noise = 1. , noise_sampler = None , r = 1 / 2 ):
905923 if len (sigmas ) <= 1 :
0 commit comments