3131from schedulers .k_euler import KEulerSampler
3232
3333from utils .logging import log_object
34-
34+ from utils . extra_model_utils import add_lora_ti_weights
3535
3636image_preprocess_model_paths = {
3737 "body_pose" : None ,
@@ -86,6 +86,7 @@ class SDRun():
8686 num_steps :int = 25
8787 guidance_scale :float = 7.5
8888 seed :int = None
89+ seed_type :str = "np"
8990 soft_seed :int = None
9091 img_id :int = 1
9192
@@ -96,6 +97,7 @@ class SDRun():
9697
9798 input_image_strength :float = 0.5
9899 second_tdict_path :str = None
100+ weight_additions :tuple = ()
99101
100102
101103def get_scheduler (name ):
@@ -179,6 +181,7 @@ def __init__(self , ModelInterfaceClass , tdict_path , model_name="sd_1x", ca
179181 self .current_model_name = model_name
180182 self .current_tdict_path = tdict_path
181183 self .second_current_tdict_path = None
184+ self .current_weight_additions = () #represents weights which are added on top, eg. Lora, TI etc
182185 self .current_dtype = self .ModelInterfaceClass .default_float_type
183186
184187 if model_name is not None :
@@ -235,18 +238,37 @@ def prepare_model_interface(self , sd_run=None ):
235238 self .current_dtype = dtype
236239 self .current_model_name = model_name
237240
238- if tdict_path != self .current_tdict_path or second_tdict_path != self .second_current_tdict_path :
241+ weight_additions = sd_run .weight_additions
242+
243+ if tdict_path != self .current_tdict_path or second_tdict_path != self .second_current_tdict_path or weight_additions != self .current_weight_additions :
239244 assert tdict_path is not None
240-
241- self .current_tdict_path = tdict_path
242- self .second_current_tdict_path = second_tdict_path
243245
244- if second_tdict_path is not None :
245- tdict2 = TDict (second_tdict_path )
246+ tdict_1 = None
247+
248+ if (tdict_path == self .current_tdict_path and second_tdict_path == self .second_current_tdict_path and self .current_weight_additions == ()):
249+ pass
250 # the current weights have already been loaded from these tdicts, and nothing has been added on top
246251 else :
247- tdict2 = None
252+ if second_tdict_path is not None :
253+ tdict2 = TDict (second_tdict_path )
254+ else :
255+ tdict2 = None
256+
257+ tdict_1 = TDict (tdict_path )
258+
259+ self .model .load_from_tdict (tdict_1 , tdict2 )
248260
249- self .model .load_from_tdict (TDict (tdict_path ), tdict2 )
261+ self .current_tdict_path = tdict_path
262+ self .second_current_tdict_path = second_tdict_path
263+
264+ if weight_additions is not None and weight_additions != ():
265+ if tdict_1 is None :
266+ tdict_1 = TDict (tdict_path )
267+
268+ print ("Loading LoRA weights" )
269+ extra_weights = add_lora_ti_weights (tdict_1 , weight_additions )
270+ self .model .load_from_state_dict (extra_weights )
271+ self .current_weight_additions = weight_additions
250272
251273
252274 def tokenize (self , prompt ):
@@ -346,8 +368,18 @@ def prepare_timesteps(self, sd_run):
346368
347369
348370
349- def get_noise (self , seed , shape ):
350- return np .random .RandomState (seed ).normal (size = shape ).astype ('float32' )
def get_noise(self, seed, shape, seed_type="np"):
    """Draw standard-normal float32 noise of the requested ``shape``.

    Args:
        seed: Seed handed to the selected random number generator.
        shape: Shape of the returned array (e.g. ``prepare_init_latent``
            passes ``(batch_size, n_h, n_w, 4)``).
        seed_type: ``'np'`` samples from ``np.random.RandomState(seed)``;
            ``'pt'`` samples with ``torch.randn`` after
            ``torch.manual_seed(seed)``. Defaults to ``'np'``.

    Returns:
        A float32 NumPy array whose shape equals ``shape``.

    Raises:
        ValueError: If ``seed_type`` is neither ``'np'`` nor ``'pt'``.
    """
    if seed_type == 'np':
        return np.random.RandomState(seed).normal(size=shape).astype('float32')

    if seed_type == 'pt':
        import torch  # local import: torch is only required for this seed type

        # NOTE: this reseeds torch's *global* RNG, exactly as before.
        torch.manual_seed(seed)

        shape = tuple(shape)
        if len(shape) != 4:
            # No channel axis to move; sample the requested shape directly.
            return torch.randn(shape).numpy().astype('float32')

        # Sample in channel-first order and then move the channel axis last,
        # so the result has the requested channel-last shape. Presumably this
        # keeps the values aligned with channel-first PyTorch pipelines --
        # TODO confirm. Previously a fixed (1, 4, 64, 64) tensor was sampled
        # here regardless of the requested shape.
        channels_first = (shape[0], shape[3], shape[1], shape[2])
        sample = torch.randn(channels_first).numpy().astype('float32')
        return np.transpose(sample, (0, 2, 3, 1))

    raise ValueError("Invalid seed type")
351383
352384
353385 def get_encoded_img (self , sd_run , processes_img ):
@@ -357,7 +389,7 @@ def get_encoded_img(self, sd_run , processes_img):
357389 enc_out_logvar = enc_out [... , 4 :]
358390 enc_out_std = np .exp (0.5 * enc_out_logvar )
359391
360- latent = enc_out_mu + enc_out_std * self .get_noise (sd_run .seed + 1 , enc_out_mu .shape )
392+ latent = enc_out_mu + enc_out_std * self .get_noise (sd_run .seed + 1 , enc_out_mu .shape , seed_type = sd_run . seed_type )
361393 latent = 0.18215 * latent
362394 return latent
363395
@@ -368,15 +400,15 @@ def prepare_init_latent(self , sd_run):
368400 n_w = sd_run .img_width // 8
369401
370402 if not sd_run .starting_img_given :
371- latent_np = self .get_noise (sd_run .seed ,(sd_run .batch_size , n_h , n_w , 4 ) )
403+ latent_np = self .get_noise (sd_run .seed ,(sd_run .batch_size , n_h , n_w , 4 ) , seed_type = sd_run . seed_type )
372404
373405 if self .debug_output_path is not None :
374406 log_object (latent_np , self .debug_output_path , key = "latent_np" )
375407
376408 if sd_run .soft_seed is not None and sd_run .soft_seed >= 0 :
377- # latent_np = latent_np + 0.1*self.get_noise(sd_run.soft_seed, latent_np.shape ) #option 1
409+ # latent_np = latent_np + 0.1*self.get_noise(sd_run.soft_seed, latent_np.shape, seed_type=sd_run.seed_type ) #option 1
378410 nmask = (np .random .RandomState (sd_run .soft_seed ).rand (* latent_np .shape ) > 0.99 )
379- latent_np = latent_np * (1 - nmask ) + nmask * self .get_noise (sd_run .soft_seed , latent_np .shape )
411+ latent_np = latent_np * (1 - nmask ) + nmask * self .get_noise (sd_run .soft_seed , latent_np .shape , seed_type = sd_run . seed_type )
380412
381413 # latent_np = latent_np * np.float64(self.scheduler.init_noise_sigma)
382414 sd_run .latent = latent_np
@@ -393,7 +425,7 @@ def prepare_init_latent(self , sd_run):
393425
394426 start_timestep = np .array ([self .t_to_i (sd_run .start_timestep )] * sd_run .batch_size , dtype = np .int64 )
395427
396- noise = self .get_noise (sd_run .seed , latent .shape )
428+ noise = self .get_noise (sd_run .seed , latent .shape , seed_type = sd_run . seed_type )
397429
398430 if self .debug_output_path is not None :
399431 log_object (noise , self .debug_output_path , key = "noise_e" )
@@ -509,7 +541,7 @@ def get_next_latent(self, sd_run ):
509541
510542 latent_proper = np .copy (sd_run .encoded_img_unmasked )
511543
512- noise = self .get_noise (sd_run .seed , latent_proper .shape )
544+ noise = self .get_noise (sd_run .seed , latent_proper .shape , seed_type = sd_run . seed_type )
513545 latent_proper = self .scheduler .add_noise (latent_proper , noise , np .array ([self .t_to_i (sd_run .current_t )] * sd_run .batch_size , dtype = np .int64 ) )
514546
515547 latents = (latent_proper * sd_run .processed_mask_downscaled ) + (latents * (1 - sd_run .processed_mask_downscaled ))
@@ -531,6 +563,7 @@ def generate(
531563 guidance_scale = 7.5 ,
532564 temperature = None ,
533565 seed = None ,
566+ seed_type = "np" ,
534567 soft_seed = None ,
535568 img_id = 0 ,
536569 input_image = None ,
@@ -540,6 +573,7 @@ def generate(
540573 scheduler = 'k_euler' ,
541574 tdict_path = None , # if none then it will just use current one
542575 second_tdict_path = None ,
576+ lora_tdict_paths = {}, # {tdict_path: ratio}
543577 inp_img_preprocesser = None , # for controlnet
544578 dtype = 'float16' ,
545579 mode = "txt2img" , # txt2img , img2img, inpaint_15
@@ -557,6 +591,10 @@ def generate(
557591 if tdict_path is None :
558592 tdict_path = self .current_tdict_path
559593
594+ weight_additions = ()
595+ for tpath in lora_tdict_paths :
596+ weight_additions += (('lora' ,tpath , lora_tdict_paths [tpath ] ),)
597+
560598 sd_run = SDRun (
561599 prompt = prompt ,
562600 img_height = img_height ,
@@ -565,6 +603,7 @@ def generate(
565603 num_steps = num_steps ,
566604 guidance_scale = guidance_scale ,
567605 seed = seed ,
606+ seed_type = seed_type ,
568607 soft_seed = soft_seed ,
569608 img_id = img_id ,
570609 input_image = input_image ,
@@ -573,6 +612,7 @@ def generate(
573612 input_image_strength = input_image_strength ,
574613 tdict_path = tdict_path ,
575614 second_tdict_path = second_tdict_path ,
615+ weight_additions = weight_additions ,
576616 mode = mode ,
577617 dtype = dtype ,
578618 inp_img_preprocesser = inp_img_preprocesser ,
0 commit comments