Commit 00c46d9

Controlnet added

committed · 1 parent 5d0a237 · commit 00c46d9

File tree: 10 files changed, +912 -96 lines changed


backends/model_converter/convert_model.py

Lines changed: 26 additions & 19 deletions
@@ -18,42 +18,49 @@
 def convert_model(checkpoint_filename, out_filename ):
     torch_weights = extract_weights_from_checkpoint(open(checkpoint_filename, "rb"))
 
-    torch_weights['state_dict']['temb_coefficients_fp32'] = np.array([1.0, 0.94406086, 0.8912509, 0.84139514, 0.7943282, 0.7498942, 0.70794576, 0.6683439, 0.63095737, 0.5956621, 0.56234133, 0.53088444, 0.50118726, 0.47315124, 0.44668362, 0.42169648, 0.39810717, 0.3758374, 0.35481337, 0.33496544, 0.31622776, 0.29853827, 0.2818383, 0.2660725, 0.25118864, 0.23713735, 0.2238721, 0.21134889, 0.19952625, 0.18836491, 0.17782794, 0.16788039, 0.15848932, 0.14962357, 0.14125374, 0.13335215, 0.12589253, 0.11885023, 0.11220184, 0.10592538, 0.099999994, 0.094406076, 0.08912509, 0.0841395, 0.07943282, 0.074989416, 0.07079458, 0.06683439, 0.06309574, 0.05956621, 0.056234125, 0.05308844, 0.05011872, 0.047315124, 0.044668354, 0.04216965, 0.039810725, 0.037583742, 0.035481337, 0.033496536, 0.031622775, 0.029853828, 0.028183822, 0.026607247, 0.025118865, 0.02371374, 0.022387212, 0.021134889, 0.01995262, 0.018836489, 0.017782794, 0.01678804, 0.015848929, 0.014962353, 0.014125377, 0.013335214, 0.012589253, 0.01188502, 0.011220186, 0.010592538, 0.01, 0.009440607, 0.0089125065, 0.008413952, 0.007943282, 0.007498941, 0.007079456, 0.00668344, 0.0063095735, 0.005956621, 0.0056234123, 0.0053088428, 0.005011873, 0.0047315126, 0.0044668354, 0.004216964, 0.003981072, 0.0037583741, 0.0035481334, 0.0033496537, 0.0031622767, 0.0029853827, 0.0028183828, 0.0026607246, 0.0025118857, 0.0023713738, 0.0022387211, 0.0021134887, 0.0019952618, 0.0018836485, 0.0017782794, 0.0016788039, 0.0015848937, 0.001496236, 0.0014125376, 0.0013335207, 0.0012589252, 0.001188502, 0.0011220181, 0.0010592537, 0.0009999999, 0.00094406115, 0.0008912511, 0.0008413952, 0.0007943278, 0.00074989407, 0.0007079456, 0.0006683437, 0.00063095737, 0.0005956621, 0.0005623415, 0.00053088454, 0.0005011872, 0.000473151, 0.00044668352, 0.00042169637, 0.00039810702, 0.0003758374, 0.00035481335, 0.00033496553, 0.00031622782, 0.00029853827, 0.00028183826, 0.00026607246, 0.00025118855, 0.00023713727, 0.00022387199, 0.00021134898, 0.00019952627, 0.00018836492, 0.00017782794, 0.00016788038, 0.00015848929, 0.00014962352, 0.0001412537, 0.00013335208, 0.00012589258, 0.00011885024, 0.00011220186, 0.00010592537]).astype('float32')
-    torch_weights['state_dict']['causal_mask'] = np.triu(np.ones((1,1,77,77), dtype=np.float16) * -65500.0, k=1).astype(np.float32)
-    torch_weights['state_dict']['aux_output_conv.weight'] = np.array([0.14013671875, 0.0711669921875, -0.03271484375, -0.11407470703125, 0.126220703125, 0.10101318359375, 0.034515380859375, -0.1383056640625, 0.126220703125, 0.07733154296875, 0.042633056640625, -0.177978515625]).astype(np.float32)
-    torch_weights['state_dict']['aux_output_conv.bias'] = np.array([0.423828125, 0.471923828125, 0.473876953125]).astype(np.float32)
-    torch_weights['state_dict']['alphas_cumprod'] = np.array(_ALPHAS_CUMPROD).astype(np.float32)
-    torch_weights['state_dict']['temb_coefficients_fp16'] = np.array( [1.0, 0.944, 0.891, 0.8413, 0.7944, 0.75, 0.708, 0.6685, 0.631, 0.5957, 0.5625, 0.531, 0.501, 0.4731, 0.4468, 0.4216, 0.3982, 0.3757, 0.3547, 0.335, 0.3162, 0.2986, 0.2817, 0.266, 0.2512, 0.2372, 0.2239, 0.2113, 0.1996, 0.1884, 0.1779, 0.1678, 0.1584, 0.1497, 0.1412, 0.1333, 0.1259, 0.11884, 0.1122, 0.1059, 0.1, 0.0944, 0.0891, 0.08417, 0.0794, 0.075, 0.0708, 0.06683, 0.0631, 0.05957, 0.05624, 0.0531, 0.0501, 0.0473, 0.04468, 0.04218, 0.03983, 0.0376, 0.0355, 0.0335, 0.03162, 0.02986, 0.02818, 0.02661, 0.02512, 0.02371, 0.02238, 0.02113, 0.01996, 0.01883, 0.01778, 0.01678, 0.01585, 0.01496, 0.01412, 0.013336, 0.01259, 0.01189, 0.01122, 0.01059, 0.01, 0.00944, 0.00891, 0.008415, 0.00794, 0.0075, 0.00708, 0.006683, 0.00631, 0.005955, 0.005623, 0.00531, 0.005013, 0.00473, 0.004467, 0.004215, 0.003983, 0.003757, 0.003548, 0.00335, 0.003162, 0.002985, 0.00282, 0.00266, 0.002512, 0.00237, 0.00224, 0.002113, 0.001995, 0.0018835, 0.001779, 0.001678, 0.001585, 0.001496, 0.001412, 0.001333, 0.001259, 0.001188, 0.001122, 0.00106, 0.001, 0.000944, 0.000891, 0.0008416, 0.0007944, 0.00075, 0.000708, 0.0006685, 0.000631, 0.0005956, 0.000562, 0.0005307, 0.000501, 0.0004733, 0.0004468, 0.0004218, 0.0003982, 0.0003757, 0.0003548, 0.000335, 0.0003161, 0.0002985, 0.0002818, 0.000266, 0.0002513, 0.0002371, 0.0002239, 0.0002114, 0.0001996, 0.0001884, 0.0001779, 0.0001678, 0.0001585, 0.0001496, 0.0001413, 0.0001334, 0.0001259, 0.00011885, 0.0001122, 0.0001059]).astype('float16')
+    if 'state_dict' in torch_weights:
+        state_dict = torch_weights['state_dict']
+    else:
+        state_dict = torch_weights
+
+    state_dict['temb_coefficients_fp32'] = np.array([1.0, 0.94406086, 0.8912509, 0.84139514, 0.7943282, 0.7498942, 0.70794576, 0.6683439, 0.63095737, 0.5956621, 0.56234133, 0.53088444, 0.50118726, 0.47315124, 0.44668362, 0.42169648, 0.39810717, 0.3758374, 0.35481337, 0.33496544, 0.31622776, 0.29853827, 0.2818383, 0.2660725, 0.25118864, 0.23713735, 0.2238721, 0.21134889, 0.19952625, 0.18836491, 0.17782794, 0.16788039, 0.15848932, 0.14962357, 0.14125374, 0.13335215, 0.12589253, 0.11885023, 0.11220184, 0.10592538, 0.099999994, 0.094406076, 0.08912509, 0.0841395, 0.07943282, 0.074989416, 0.07079458, 0.06683439, 0.06309574, 0.05956621, 0.056234125, 0.05308844, 0.05011872, 0.047315124, 0.044668354, 0.04216965, 0.039810725, 0.037583742, 0.035481337, 0.033496536, 0.031622775, 0.029853828, 0.028183822, 0.026607247, 0.025118865, 0.02371374, 0.022387212, 0.021134889, 0.01995262, 0.018836489, 0.017782794, 0.01678804, 0.015848929, 0.014962353, 0.014125377, 0.013335214, 0.012589253, 0.01188502, 0.011220186, 0.010592538, 0.01, 0.009440607, 0.0089125065, 0.008413952, 0.007943282, 0.007498941, 0.007079456, 0.00668344, 0.0063095735, 0.005956621, 0.0056234123, 0.0053088428, 0.005011873, 0.0047315126, 0.0044668354, 0.004216964, 0.003981072, 0.0037583741, 0.0035481334, 0.0033496537, 0.0031622767, 0.0029853827, 0.0028183828, 0.0026607246, 0.0025118857, 0.0023713738, 0.0022387211, 0.0021134887, 0.0019952618, 0.0018836485, 0.0017782794, 0.0016788039, 0.0015848937, 0.001496236, 0.0014125376, 0.0013335207, 0.0012589252, 0.001188502, 0.0011220181, 0.0010592537, 0.0009999999, 0.00094406115, 0.0008912511, 0.0008413952, 0.0007943278, 0.00074989407, 0.0007079456, 0.0006683437, 0.00063095737, 0.0005956621, 0.0005623415, 0.00053088454, 0.0005011872, 0.000473151, 0.00044668352, 0.00042169637, 0.00039810702, 0.0003758374, 0.00035481335, 0.00033496553, 0.00031622782, 0.00029853827, 0.00028183826, 0.00026607246, 0.00025118855, 0.00023713727, 0.00022387199, 0.00021134898, 0.00019952627, 0.00018836492, 0.00017782794, 0.00016788038, 0.00015848929, 0.00014962352, 0.0001412537, 0.00013335208, 0.00012589258, 0.00011885024, 0.00011220186, 0.00010592537]).astype('float32')
+    state_dict['causal_mask'] = np.triu(np.ones((1,1,77,77), dtype=np.float16) * -65500.0, k=1).astype(np.float32)
+    state_dict['aux_output_conv.weight'] = np.array([0.14013671875, 0.0711669921875, -0.03271484375, -0.11407470703125, 0.126220703125, 0.10101318359375, 0.034515380859375, -0.1383056640625, 0.126220703125, 0.07733154296875, 0.042633056640625, -0.177978515625]).astype(np.float32)
+    state_dict['aux_output_conv.bias'] = np.array([0.423828125, 0.471923828125, 0.473876953125]).astype(np.float32)
+    state_dict['alphas_cumprod'] = np.array(_ALPHAS_CUMPROD).astype(np.float32)
+    state_dict['temb_coefficients_fp16'] = np.array( [1.0, 0.944, 0.891, 0.8413, 0.7944, 0.75, 0.708, 0.6685, 0.631, 0.5957, 0.5625, 0.531, 0.501, 0.4731, 0.4468, 0.4216, 0.3982, 0.3757, 0.3547, 0.335, 0.3162, 0.2986, 0.2817, 0.266, 0.2512, 0.2372, 0.2239, 0.2113, 0.1996, 0.1884, 0.1779, 0.1678, 0.1584, 0.1497, 0.1412, 0.1333, 0.1259, 0.11884, 0.1122, 0.1059, 0.1, 0.0944, 0.0891, 0.08417, 0.0794, 0.075, 0.0708, 0.06683, 0.0631, 0.05957, 0.05624, 0.0531, 0.0501, 0.0473, 0.04468, 0.04218, 0.03983, 0.0376, 0.0355, 0.0335, 0.03162, 0.02986, 0.02818, 0.02661, 0.02512, 0.02371, 0.02238, 0.02113, 0.01996, 0.01883, 0.01778, 0.01678, 0.01585, 0.01496, 0.01412, 0.013336, 0.01259, 0.01189, 0.01122, 0.01059, 0.01, 0.00944, 0.00891, 0.008415, 0.00794, 0.0075, 0.00708, 0.006683, 0.00631, 0.005955, 0.005623, 0.00531, 0.005013, 0.00473, 0.004467, 0.004215, 0.003983, 0.003757, 0.003548, 0.00335, 0.003162, 0.002985, 0.00282, 0.00266, 0.002512, 0.00237, 0.00224, 0.002113, 0.001995, 0.0018835, 0.001779, 0.001678, 0.001585, 0.001496, 0.001412, 0.001333, 0.001259, 0.001188, 0.001122, 0.00106, 0.001, 0.000944, 0.000891, 0.0008416, 0.0007944, 0.00075, 0.000708, 0.0006685, 0.000631, 0.0005956, 0.000562, 0.0005307, 0.000501, 0.0004733, 0.0004468, 0.0004218, 0.0003982, 0.0003757, 0.0003548, 0.000335, 0.0003161, 0.0002985, 0.0002818, 0.000266, 0.0002513, 0.0002371, 0.0002239, 0.0002114, 0.0001996, 0.0001884, 0.0001779, 0.0001678, 0.0001585, 0.0001496, 0.0001413, 0.0001334, 0.0001259, 0.00011885, 0.0001122, 0.0001059]).astype('float16')
 
     extra_keys = ['temb_coefficients_fp32', 'temb_coefficients_fp16' , 'causal_mask' , 'aux_output_conv.weight' , 'aux_output_conv.bias', 'alphas_cumprod']
 
 
 
-    for k in (list(torch_weights['state_dict'].keys())):
+    for k in (list(state_dict.keys())):
         if '.norm' in k and '.bias' in k:
             k2 = k.replace(".bias" , ".weight")
             k3 = k.replace(".bias" , ".bias_by_weight")
-            torch_weights['state_dict'][k3] = torch_weights['state_dict'][k]/torch_weights['state_dict'][k2]
+            state_dict[k3] = state_dict[k]/state_dict[k2]
 
         if ".ff." in k:
-            pp = torch_weights['state_dict'][k]
-            torch_weights['state_dict'][k + "._split_1"] = pp[:pp.shape[0]//2].copy()
-            torch_weights['state_dict'][k + "._split_2"] = pp[pp.shape[0]//2:].copy()
+            pp = state_dict[k]
+            state_dict[k + "._split_1"] = pp[:pp.shape[0]//2].copy()
+            state_dict[k + "._split_2"] = pp[pp.shape[0]//2:].copy()
 
+    keys_list = list(state_dict.keys())
+    mid_key = keys_list[len(keys_list)//2]
 
     for i in range(1,21):
         nn = 320*i
-        dtype = torch_weights['state_dict']['model.diffusion_model.input_blocks.1.0.in_layers.0.weight'].dtype
-        torch_weights['state_dict']["zeros_"+str(nn)] = np.zeros(nn).astype(dtype)
-        torch_weights['state_dict']["ones_"+str(nn)] = np.ones(nn).astype(dtype)
+        dtype = state_dict[mid_key].dtype
+        state_dict["zeros_"+str(nn)] = np.zeros(nn).astype(dtype)
+        state_dict["ones_"+str(nn)] = np.ones(nn).astype(dtype)
 
         nn = 128*i
-        dtype = torch_weights['state_dict']['model.diffusion_model.input_blocks.1.0.in_layers.0.weight'].dtype
-        torch_weights['state_dict']["zeros_"+str(nn)] = np.zeros(nn).astype(dtype)
-        torch_weights['state_dict']["ones_"+str(nn)] = np.ones(nn).astype(dtype)
+        dtype = state_dict[mid_key].dtype
+        state_dict["zeros_"+str(nn)] = np.zeros(nn).astype(dtype)
+        state_dict["ones_"+str(nn)] = np.ones(nn).astype(dtype)
 
 
-    model_type = get_model_type(torch_weights['state_dict'])
+    model_type = get_model_type(state_dict)
 
     if model_type is None:
         raise ValueError("The model is not supported. Please make sure it is a valid SD 1.4/1.5 .ckpt file")
@@ -68,7 +75,7 @@ def convert_model(checkpoint_filename, out_filename ):
     outfile.init_write(ctdict_version=ctdict_id )
 
     for k in model_shapes:
-        np_arr = torch_weights['state_dict'][k]
+        np_arr = state_dict[k]
         shape = list(np_arr.shape)
         assert tuple(shape) == tuple(model_shapes[k]), ( "shape mismatch at" , k , shape , SD_SHAPES[k] )
         outfile.write_key(key=k , tensor=np_arr)
backends/model_converter/sd_shapes.py

Lines changed: 4 additions & 1 deletion
@@ -49,7 +49,8 @@ def add_aux_shapes(d):
 ctdict_ids = {"SD_1x_float32": 12 ,
     "SD_1x_inpaint_float32": 13,
     "SD_1x_float16": 1012 ,
-    "SD_1x_inpaint_float16": 1013}
+    "SD_1x_inpaint_float16": 1013 ,
+    "SD_1x_just_controlnet_16" : 1014}
 
 
 extra_keys = ['temb_coefficients_fp32' , 'temb_coefficients_fp16' , 'causal_mask' , 'aux_output_conv.weight' , 'aux_output_conv.bias', 'alphas_cumprod']
@@ -59,8 +60,10 @@ def add_aux_shapes(d):
 def are_shapes_matching(state_dict , template_shapes):
     for k in template_shapes:
         if k not in state_dict:
+            print("key", k , "not found in state_dict" , state_dict.keys())
             return False
         if tuple(template_shapes[k]) != tuple(state_dict[k].shape):
+            print("shape mismatch", k , tuple(template_shapes[k]) ,tuple(state_dict[k].shape) )
             return False
 
     return True
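
For context, a self-contained toy run of `are_shapes_matching` with the new debug prints; the template and weights below are hypothetical stand-ins, not values from the real shape tables:

```python
import numpy as np

def are_shapes_matching(state_dict, template_shapes):
    for k in template_shapes:
        if k not in state_dict:
            print("key", k, "not found in state_dict", state_dict.keys())
            return False
        if tuple(template_shapes[k]) != tuple(state_dict[k].shape):
            print("shape mismatch", k, tuple(template_shapes[k]), tuple(state_dict[k].shape))
            return False
    return True

# Hypothetical template and weights; the real ones come from sd_shapes.py.
template = {"aux_output_conv.bias": (3,)}
weights = {"aux_output_conv.bias": np.zeros((4,), dtype=np.float32)}
are_shapes_matching(weights, template)
# prints: shape mismatch aux_output_conv.bias (3,) (4,)  -- and returns False
```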

backends/stable_diffusion/stable_diffusion.py

Lines changed: 43 additions & 7 deletions
@@ -80,6 +80,7 @@ class SDRun():
     mask_image:str =None
 
     input_image_strength:float=0.5
+    second_tdict_path:str = None
 
 
 def get_scheduler(name):
@@ -136,6 +137,7 @@ def __init__(self , ModelInterfaceClass , tdict_path , model_name="sd_1x", ca
 
         self.current_model_name = model_name
         self.current_tdict_path = tdict_path
+        self.second_current_tdict_path = None
         self.current_dtype = self.ModelInterfaceClass.default_float_type
 
         self.model = self.ModelInterfaceClass( TDict(self.current_tdict_path ), dtype=self.current_dtype, model_name=self.current_model_name )
@@ -145,25 +147,43 @@ def prepare_model_interface(self , sd_run=None ):
 
         if sd_run.mode == 'inpaint_15':
             model_name = "sd_1x_inpaint"
+        elif sd_run.mode == "controlnet":
+            model_name = "sd_1x_controlnet"
         else:
             model_name = "sd_1x"
 
         dtype = sd_run.dtype
         tdict_path = sd_run.tdict_path
+        second_tdict_path = sd_run.second_tdict_path
 
         if self.current_model_name != model_name or self.current_dtype != dtype :
             print("Creating model interface")
             assert tdict_path is not None
             self.model.destroy()
-            self.model = self.ModelInterfaceClass(TDict(tdict_path ) , dtype=dtype, model_name=model_name )
+
+            if second_tdict_path is not None:
+                tdict2 = TDict(second_tdict_path)
+            else:
+                tdict2 = None
+
+            self.model = self.ModelInterfaceClass(TDict(tdict_path ) , dtype=dtype, model_name=model_name , second_tdict=tdict2)
             self.current_tdict_path = tdict_path
+            self.second_current_tdict_path = second_tdict_path
             self.current_dtype = dtype
             self.current_model_name = model_name
 
-        if tdict_path != self.current_tdict_path:
+        if tdict_path != self.current_tdict_path or second_tdict_path != self.second_current_tdict_path:
             assert tdict_path is not None
-            self.model.load_from_tdict(TDict(tdict_path))
+
             self.current_tdict_path = tdict_path
+            self.second_current_tdict_path = second_tdict_path
+
+            if second_tdict_path is not None:
+                tdict2 = TDict(second_tdict_path)
+            else:
+                tdict2 = None
+
+            self.model.load_from_tdict(TDict(tdict_path), tdict2 )
 
 
     def tokenize(self , prompt):
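
The "build the second TDict only when a path is given" pattern appears twice in the hunk above; a condensed, runnable sketch of the idea, with `TDict` stubbed out as a stand-in for the project's tensor-dictionary loader:

```python
class TDict:  # stand-in stub for the project's tensor-dictionary loader
    def __init__(self, path):
        self.path = path

def make_optional_tdict(path):
    # Construct the second weights file only when a path was actually given,
    # mirroring the two if/else blocks in prepare_model_interface above.
    return TDict(path) if path is not None else None

print(make_optional_tdict(None))                     # -> None
print(make_optional_tdict("controlnet.tdict").path)  # -> controlnet.tdict
```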
@@ -333,17 +353,25 @@ def get_unet_out(self, sd_run):
 
         latent_model_input = latent_model_input * self.scheduler.get_input_scale(self.t_to_i(t))
 
+        if sd_run.mode == "controlnet":
+            hint_img = (sd_run.input_image_processed+1)/2
+            if sd_run.combine_unet_run:
+                hint_img = np.repeat(hint_img, sd_run.batch_size, axis=0)
+            controls = self.model.run_controlnet(unet_inp=latent_model_input, time_emb=t_emb, text_emb=sd_run.context, hint_img=hint_img )
+        else:
+            controls = None
+
         if sd_run.combine_unet_run:
             latent_combined = np.concatenate([latent_model_input,latent_model_input])
             temb_combined = np.concatenate([t_emb,t_emb])
             text_emb_combined = np.concatenate([sd_run.unconditional_context , sd_run.context ])
 
-            o = self.model.run_unet(unet_inp=latent_combined, time_emb=temb_combined, text_emb=text_emb_combined )
+            o = self.model.run_unet(unet_inp=latent_combined, time_emb=temb_combined, text_emb=text_emb_combined , control_inp=controls)
             sd_run.predicted_unconditional_latent = o[0: o.shape[0]//2 ]
             sd_run.predicted_latent = o[o.shape[0]//2 :]
         else:
-            sd_run.predicted_unconditional_latent = self.model.run_unet(unet_inp=latent_model_input, time_emb=t_emb, text_emb=sd_run.unconditional_context )
-            sd_run.predicted_latent = self.model.run_unet(unet_inp=latent_model_input, time_emb=t_emb, text_emb=sd_run.context)
+            sd_run.predicted_unconditional_latent = self.model.run_unet(unet_inp=latent_model_input, time_emb=t_emb, text_emb=sd_run.unconditional_context , control_inp=controls)
+            sd_run.predicted_latent = self.model.run_unet(unet_inp=latent_model_input, time_emb=t_emb, text_emb=sd_run.context, control_inp=controls)
 
 
     def get_next_latent(self, sd_run ):
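
For context, a toy illustration of the hint handling added above: the processed input image lives in [-1, 1], is remapped to [0, 1] before being fed to the ControlNet, and is tiled across the batch when the conditional and unconditional UNet passes run as one combined batch. The shapes and batch size below are hypothetical stand-ins:

```python
import numpy as np

# Stand-in for sd_run.input_image_processed, stored in [-1, 1].
input_image_processed = np.zeros((1, 512, 512, 3), dtype=np.float32)

hint_img = (input_image_processed + 1) / 2   # remap [-1, 1] -> [0, 1]
hint_img = np.repeat(hint_img, 2, axis=0)    # tile for a combined batch of 2

# The combined run's output is then split back into its two halves:
o = np.ones((2, 64, 64, 4), dtype=np.float32)  # stand-in for run_unet output
uncond = o[0: o.shape[0] // 2]                 # unconditional prediction
cond = o[o.shape[0] // 2:]                     # conditional prediction
assert uncond.shape[0] == cond.shape[0] == 1
```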
@@ -391,13 +419,17 @@ def generate(
         input_image_strength=0.5,
         scheduler='k_euler',
         tdict_path=None, # if none then it will just use current one
+        second_tdict_path=None,
         dtype='float16',
         mode="txt2img" # txt2img , img2img, inpaint_15
     ):
 
         self.scheduler = get_scheduler(scheduler)
 
-        assert mode in ['txt2img' , 'img2img' , 'inpaint_15']
+        assert mode in ['txt2img' , 'img2img' , 'inpaint_15', 'controlnet']
+
+        if dtype not in self.ModelInterfaceClass.avail_float_types:
+            dtype = self.ModelInterfaceClass.default_float_type
 
         if tdict_path is None:
             tdict_path = self.current_tdict_path
@@ -417,6 +449,7 @@ def generate(
             negative_prompt=negative_prompt,
             input_image_strength=input_image_strength,
             tdict_path=tdict_path,
+            second_tdict_path=second_tdict_path,
             mode=mode,
             dtype=dtype,
         )
@@ -428,6 +461,9 @@ def generate(
         if mask_image is not None and mask_image != "":
             sd_run.do_masking = True
 
+        if mode == "controlnet":
+            assert input_image is not None and input_image != ""
+
         signal = self.callback(state="Starting" , progress=-1 )
         if signal == "stop":
             return
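
Taken together, the changes in this file suggest a call shape like the following sketch. The object name, file paths, and prompt are hypothetical; only `mode="controlnet"`, `second_tdict_path`, and the non-empty `input_image` requirement come from the diff itself:

```python
# Hypothetical driver code for the new mode; `sd` stands in for however the
# host app constructs the class whose __init__ is shown in this diff.
img = sd.generate(
    prompt="a house in the forest",
    input_image="edge_map.png",                       # controlnet mode asserts a non-empty input_image
    mode="controlnet",                                # new mode added by this commit
    tdict_path="sd-v1-5_fp16.tdict",                  # base SD weights
    second_tdict_path="controlnet_canny_fp16.tdict",  # ControlNet weights, loaded as the second TDict
    dtype="float16",                                  # falls back to default_float_type if unsupported
)
```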