Commit 4f507cc

added support for LoRA
1 parent 7b9351f commit 4f507cc

File tree: 6 files changed (+343, −32 lines)

backends/model_converter/convert_model.py

Lines changed: 8 additions & 7 deletions

@@ -15,14 +15,15 @@
-def convert_model(checkpoint_filename, out_filename ):
+def convert_model(checkpoint_filename=None, out_filename=None, torch_weights=None):
 
-    if checkpoint_filename.lower().endswith(".ckpt"):
-        torch_weights = extract_weights_from_checkpoint(open(checkpoint_filename, "rb"))
-    elif checkpoint_filename.lower().endswith(".safetensors"):
-        torch_weights = SafetensorWrapper(checkpoint_filename)
-    else:
-        raise ValueError("Invalid import format")
+    if torch_weights is None:
+        if checkpoint_filename.lower().endswith(".ckpt"):
+            torch_weights = extract_weights_from_checkpoint(open(checkpoint_filename, "rb"))
+        elif checkpoint_filename.lower().endswith(".safetensors"):
+            torch_weights = SafetensorWrapper(checkpoint_filename)
+        else:
+            raise ValueError("Invalid import format")
 
     if 'state_dict' in torch_weights:
         state_dict = torch_weights['state_dict']
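With the new signature, convert_model() can either sniff the format from a checkpoint path, as before, or accept an already-loaded weight mapping via torch_weights and skip the file handling entirely. A minimal sketch of both entry points (the file names and the my_state_dict variable are placeholders, not from the commit):

    # existing path-based usage: format is inferred from the extension
    convert_model("model.safetensors", "model.tdict")

    # new usage: hand over pre-loaded weights directly
    convert_model(out_filename="model.tdict", torch_weights=my_state_dict)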

backends/model_converter/tdict.py

Lines changed: 3 additions & 0 deletions

@@ -50,6 +50,9 @@ def write_block(self, block_array):
                       "n_end_data": n_end_data,
                       "n_bytes": len(data_bytes)}
 
+    def keys(self):
+        return self.keys_info.keys()
+
 
 
     def read_block(self, header_pos, np_shape=None, np_dtype=None):
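The new keys() helper simply surfaces the names recorded in keys_info, so callers (such as the LoRA loading code) can enumerate what a TDict file contains. A small usage sketch, assuming TDict(path) opens an existing file as it does elsewhere in this commit (the path is a placeholder):

    td = TDict("sd_model.tdict")
    for name in td.keys():        # helper added in this commit
        print(name)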

backends/stable_diffusion/stable_diffusion.py

Lines changed: 57 additions & 17 deletions
@@ -31,7 +31,7 @@
 from schedulers.k_euler import KEulerSampler
 
 from utils.logging import log_object
-
+from utils.extra_model_utils import add_lora_ti_weights
 
 image_preprocess_model_paths = {
     "body_pose" : None ,
@@ -86,6 +86,7 @@ class SDRun():
     num_steps:int =25
     guidance_scale:float=7.5
     seed:int=None
+    seed_type:str="np"
     soft_seed:int=None
     img_id:int = 1

@@ -96,6 +97,7 @@ class SDRun():
 
     input_image_strength:float=0.5
     second_tdict_path:str = None
+    weight_additions:tuple = ()
 
 
 def get_scheduler(name):
@@ -179,6 +181,7 @@ def __init__(self , ModelInterfaceClass , tdict_path , model_name="sd_1x", ca
         self.current_model_name = model_name
         self.current_tdict_path = tdict_path
         self.second_current_tdict_path = None
+        self.current_weight_additions = () #represents weights which are added on top, eg. Lora, TI etc
         self.current_dtype = self.ModelInterfaceClass.default_float_type
 
         if model_name is not None:
@@ -235,18 +238,37 @@ def prepare_model_interface(self , sd_run=None ):
         self.current_dtype = dtype
         self.current_model_name = model_name
 
-        if tdict_path != self.current_tdict_path or second_tdict_path != self.second_current_tdict_path:
+        weight_additions = sd_run.weight_additions
+
+        if tdict_path != self.current_tdict_path or second_tdict_path != self.second_current_tdict_path or weight_additions != self.current_weight_additions:
             assert tdict_path is not None
-
-            self.current_tdict_path = tdict_path
-            self.second_current_tdict_path = second_tdict_path
 
-            if second_tdict_path is not None:
-                tdict2 = TDict(second_tdict_path)
+            tdict_1 = None
+
+            if (tdict_path == self.current_tdict_path and second_tdict_path == self.second_current_tdict_path and self.current_weight_additions == ()):
+                pass
+                # current weight has already been loaded with some tdicts , and nothing has been added
             else:
-                tdict2 = None
+                if second_tdict_path is not None:
+                    tdict2 = TDict(second_tdict_path)
+                else:
+                    tdict2 = None
+
+                tdict_1 = TDict(tdict_path)
+
+                self.model.load_from_tdict(tdict_1, tdict2 )
 
-            self.model.load_from_tdict(TDict(tdict_path), tdict2 )
+            self.current_tdict_path = tdict_path
+            self.second_current_tdict_path = second_tdict_path
+
+            if weight_additions is not None and weight_additions != ():
+                if tdict_1 is None:
+                    tdict_1 = TDict(tdict_path)
+
+                print("Loading LoRA weights")
+                extra_weights = add_lora_ti_weights(tdict_1 , weight_additions )
+                self.model.load_from_state_dict(extra_weights )
+                self.current_weight_additions = weight_additions
 
 
     def tokenize(self , prompt):
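In the loading path above, weight_additions is a tuple of (kind, tdict_path, ratio) entries; add_lora_ti_weights() merges them against the freshly opened base TDict, and the result is applied with load_from_state_dict(). The helper itself is not part of this hunk; as a rough illustration only (an assumption about the usual LoRA arithmetic, not this repo's code), merging one low-rank pair into a base weight typically looks like:

    import numpy as np

    def merge_lora_into_weight(base_w, lora_up, lora_down, ratio, alpha=None):
        # conventional LoRA merge: W' = W + ratio * scale * (up @ down)
        rank = lora_down.shape[0]
        scale = (alpha / rank) if alpha is not None else 1.0
        return base_w + ratio * scale * np.matmul(lora_up, lora_down)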
@@ -346,8 +368,18 @@ def prepare_timesteps(self, sd_run):
 
 
 
-    def get_noise(self, seed, shape):
-        return np.random.RandomState(seed).normal(size=shape).astype('float32')
+    def get_noise(self, seed, shape, seed_type):
+        if seed_type == 'np':
+            return np.random.RandomState(seed).normal(size=shape).astype('float32')
+        elif seed_type == 'pt':
+            import torch
+            torch.manual_seed(seed)
+            if len(shape) == 4:
+                shape = (shape[0] , shape[3] , shape[1] , shape[2])
+            a = torch.randn( (1,4,64,64) ).numpy().astype('float32')
+            return np.rollaxis(a , 1 , 4 )
+        else:
+            raise ValueError("Invalid seed type")
 
 
     def get_encoded_img(self, sd_run , processes_img):
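The new 'pt' branch draws noise with torch.manual_seed / torch.randn in PyTorch's channels-first (NCHW) layout and then rolls the channel axis to the end so it lines up with the channels-last latents used elsewhere in this file (note that the committed code samples a fixed (1,4,64,64) tensor, so the transposed shape is computed but not used). A quick standalone check of the axis roll, for illustration:

    import numpy as np

    a = np.zeros((1, 4, 64, 64), dtype='float32')   # NCHW, as torch.randn returns it
    b = np.rollaxis(a, 1, 4)                        # move channels to the last axis
    assert b.shape == (1, 64, 64, 4)                # NHWC, matching the latents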
@@ -357,7 +389,7 @@ def get_encoded_img(self, sd_run , processes_img):
         enc_out_logvar = enc_out[... , 4:]
         enc_out_std = np.exp(0.5 * enc_out_logvar)
 
-        latent = enc_out_mu + enc_out_std*self.get_noise(sd_run.seed+1, enc_out_mu.shape )
+        latent = enc_out_mu + enc_out_std*self.get_noise(sd_run.seed+1, enc_out_mu.shape , seed_type=sd_run.seed_type)
         latent = 0.18215 * latent
         return latent
 
@@ -368,15 +400,15 @@ def prepare_init_latent(self , sd_run):
         n_w = sd_run.img_width // 8
 
         if not sd_run.starting_img_given:
-            latent_np = self.get_noise(sd_run.seed ,(sd_run.batch_size, n_h, n_w, 4) )
+            latent_np = self.get_noise(sd_run.seed ,(sd_run.batch_size, n_h, n_w, 4) , seed_type=sd_run.seed_type)
 
             if self.debug_output_path is not None:
                 log_object(latent_np , self.debug_output_path , key="latent_np")
 
             if sd_run.soft_seed is not None and sd_run.soft_seed >= 0:
-                # latent_np = latent_np + 0.1*self.get_noise(sd_run.soft_seed, latent_np.shape ) #option 1
+                # latent_np = latent_np + 0.1*self.get_noise(sd_run.soft_seed, latent_np.shape, seed_type=sd_run.seed_type ) #option 1
                 nmask = (np.random.RandomState(sd_run.soft_seed).rand(*latent_np.shape ) > 0.99)
-                latent_np = latent_np*(1-nmask) + nmask*self.get_noise(sd_run.soft_seed, latent_np.shape )
+                latent_np = latent_np*(1-nmask) + nmask*self.get_noise(sd_run.soft_seed, latent_np.shape , seed_type=sd_run.seed_type )
 
             # latent_np = latent_np * np.float64(self.scheduler.init_noise_sigma)
             sd_run.latent = latent_np
@@ -393,7 +425,7 @@ def prepare_init_latent(self , sd_run):
 
         start_timestep = np.array([self.t_to_i(sd_run.start_timestep)] * sd_run.batch_size, dtype=np.int64 )
 
-        noise = self.get_noise(sd_run.seed , latent.shape )
+        noise = self.get_noise(sd_run.seed , latent.shape , seed_type=sd_run.seed_type )
 
         if self.debug_output_path is not None:
             log_object(noise , self.debug_output_path , key="noise_e")
@@ -509,7 +541,7 @@ def get_next_latent(self, sd_run ):
 
         latent_proper = np.copy(sd_run.encoded_img_unmasked)
 
-        noise = self.get_noise(sd_run.seed , latent_proper.shape )
+        noise = self.get_noise(sd_run.seed , latent_proper.shape , seed_type=sd_run.seed_type )
         latent_proper = self.scheduler.add_noise(latent_proper, noise, np.array([self.t_to_i(sd_run.current_t)] * sd_run.batch_size, dtype=np.int64 ) )
 
         latents = (latent_proper * sd_run.processed_mask_downscaled) + (latents * (1 - sd_run.processed_mask_downscaled))
@@ -531,6 +563,7 @@ def generate(
         guidance_scale=7.5,
         temperature=None,
         seed=None,
+        seed_type="np",
         soft_seed=None,
         img_id=0,
         input_image=None,
@@ -540,6 +573,7 @@ def generate(
         scheduler='k_euler',
         tdict_path=None, # if none then it will just use current one
         second_tdict_path=None,
+        lora_tdict_paths={}, # {tdict_path: ratio}
         inp_img_preprocesser=None, # for controlnet
         dtype='float16',
         mode="txt2img", # txt2img , img2img, inpaint_15
@@ -557,6 +591,10 @@ def generate(
         if tdict_path is None:
             tdict_path = self.current_tdict_path
 
+        weight_additions = ()
+        for tpath in lora_tdict_paths:
+            weight_additions += (('lora' , tpath , lora_tdict_paths[tpath] ),)
+
         sd_run = SDRun(
             prompt=prompt,
             img_height=img_height,
@@ -565,6 +603,7 @@
             num_steps=num_steps,
             guidance_scale=guidance_scale,
             seed=seed,
+            seed_type=seed_type,
             soft_seed=soft_seed,
             img_id=img_id,
             input_image=input_image,
@@ -573,6 +612,7 @@
             input_image_strength=input_image_strength,
             tdict_path=tdict_path,
             second_tdict_path=second_tdict_path,
+            weight_additions=weight_additions,
             mode=mode,
             dtype=dtype,
             inp_img_preprocesser=inp_img_preprocesser,
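Taken together, the new generate() arguments let a caller attach one or more LoRA TDicts as a {tdict_path: ratio} mapping and choose which RNG flavour the seed drives. A hypothetical call, assuming sd is an already-constructed instance of the class above (the file name and prompt are placeholders):

    img = sd.generate(
        prompt="a watercolor fox",
        lora_tdict_paths={"fox_style_lora.tdict": 0.8},  # {tdict_path: ratio}
        seed=1234,
        seed_type="pt",   # sample noise the PyTorch way; default is "np"
    )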
