4242from schedulers .scheduling_ddim import DDIMScheduler
4343from schedulers .scheduling_lms_discrete import LMSDiscreteScheduler
4444from schedulers .scheduling_pndm import PNDMScheduler
45+ from schedulers .k_euler_ancestral import KEulerAncestralSampler
46+ from schedulers .k_euler import KEulerSampler
4547
4648def process_inp_img (input_image ):
4749 input_image = Image .open (input_image )
@@ -128,6 +130,12 @@ def get_scheduler(name):
128130 skip_prk_steps = True ,
129131 tensor_format = "np" )
130132
133+ if name == "k_euler_ancestral" :
134+ return KEulerAncestralSampler ()
135+
136+ if name == "k_euler" :
137+ return KEulerSampler ()
138+
131139
def dummy_callback(state="", progress=-1):
    """Accept a state and a progress value and do nothing with them.

    Args:
        state: Ignored. Defaults to an empty string.
        progress: Ignored. Defaults to -1.

    Returns:
        None, always.
    """
    # NOTE(review): presumably used as the default no-op where a
    # (state, progress) callback is expected — confirm against callers.
    return None
@@ -283,18 +291,14 @@ def prepare_init_latent(self , sd_run):
283291 # latent_np = latent_np * np.float64(self.scheduler.init_noise_sigma)
284292 sd_run .latent = latent_np
285293
286- if isinstance (self .scheduler , LMSDiscreteScheduler ):
287- sd_run .latent = sd_run .latent * self .scheduler .sigmas [0 ]
294+ sd_run .latent = sd_run .latent * self .scheduler .initial_scale
288295
289296 else :
290297 latent = self .get_encoded_img (sd_run , sd_run .input_image_processed )
291298 sd_run .encoded_img_unmasked = np .copy (latent )
292299
293- if isinstance (self .scheduler , LMSDiscreteScheduler ):
294- start_timestep = np .array ([self .t_to_i (sd_run .start_timestep )] * sd_run .batch_size , dtype = np .int64 )
295- else :
296- start_timestep = np .array ([sd_run .start_timestep ] * sd_run .batch_size , dtype = np .int64 )
297-
300+ start_timestep = np .array ([self .t_to_i (sd_run .start_timestep )] * sd_run .batch_size , dtype = np .int64 )
301+
298302 noise = self .get_noise (sd_run .seed , latent .shape )
299303 latent = self .scheduler .add_noise (latent , noise , start_timestep )
300304 sd_run .latent = latent
@@ -344,12 +348,8 @@ def get_unet_out(self, sd_run):
344348 np .repeat ( (1 - sd_run .processed_mask_downscaled ), sd_run .batch_size , axis = 0 ) ,
345349 sd_run .encoded_masked_img
346350 ], axis = - 1 )
347-
348- if isinstance (self .scheduler , LMSDiscreteScheduler ):
349- sigma = self .scheduler .sigmas [self .t_to_i (t )]
350- latent_model_input = latent_model_input / ((sigma ** 2 + 1 ) ** 0.5 )
351-
352- # latent_model_input = self.scheduler.scale_model_input(sd_run.latent, t)
351+
352+ latent_model_input = latent_model_input * self .scheduler .get_input_scale (self .t_to_i (t ))
353353
354354 if sd_run .combine_unet_run :
355355 latent_combined = np .concatenate ([latent_model_input ,latent_model_input ])
@@ -374,22 +374,16 @@ def get_next_latent(self, sd_run ):
374374 eta = 0.0 # should be between 0 and 1, but 0 for now
375375 extra_step_kwargs ["eta" ] = eta
376376
377- if isinstance (self .scheduler , LMSDiscreteScheduler ):
378- latents = self .scheduler .step (noise_pred , self .t_to_i (t ), sd_run .latent , ** extra_step_kwargs )["prev_sample" ]
379- else :
380- latents = self .scheduler .step (noise_pred , t , sd_run .latent , ** extra_step_kwargs )["prev_sample" ]
377+ step_seed = sd_run .seed + 10000 + self .t_to_i (t )* 4242
378+ latents = self .scheduler .step (noise_pred , self .t_to_i (t ), sd_run .latent , seed = step_seed , ** extra_step_kwargs )["prev_sample" ]
379+
381380
382381 if sd_run .do_masking :
383382
384383 latent_proper = np .copy (sd_run .encoded_img_unmasked )
385384
386385 noise = self .get_noise (sd_run .seed , latent_proper .shape )
387-
388- if isinstance (self .scheduler , LMSDiscreteScheduler ):
389- latent_proper = self .scheduler .add_noise (latent_proper , noise , np .array ([self .t_to_i (sd_run .current_t )] * sd_run .batch_size , dtype = np .int64 ) )
390- else :
391- latent_proper = self .scheduler .add_noise (latent_proper , noise , np .array ([sd_run .current_t ] * sd_run .batch_size , dtype = np .int64 ) )
392-
386+ latent_proper = self .scheduler .add_noise (latent_proper , noise , np .array ([self .t_to_i (sd_run .current_t )] * sd_run .batch_size , dtype = np .int64 ) )
393387
394388 latents = (latent_proper * sd_run .processed_mask_downscaled ) + (latents * (1 - sd_run .processed_mask_downscaled ))
395389
@@ -413,7 +407,7 @@ def generate(
413407 mask_image = None ,
414408 negative_prompt = "" ,
415409 input_image_strength = 0.5 ,
416- scheduler = 'pndm ' ,
410+ scheduler = 'k_euler ' ,
417411 tdict_path = None , # if none then it will just use current one
418412 dtype = 'float16' ,
419413 mode = "txt2img" # txt2img , img2img, inpaint_15
0 commit comments