Skip to content

Commit 52d5364

Browse files
committed
support for SD2 added, debug mode added
1 parent ee51c0c commit 52d5364

File tree

13 files changed

+2627
-40
lines changed

13 files changed

+2627
-40
lines changed

backends/model_converter/convert_model.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,13 @@ def convert_model(checkpoint_filename, out_filename ):
5151
state_dict[k + "._split_1"] = pp[:pp.shape[0]//2].copy()
5252
state_dict[k + "._split_2"] = pp[pp.shape[0]//2:].copy()
5353

54+
elif "attn.in_proj" in k and state_dict[k].shape[0] == 3072 :
55+
pp = state_dict[k]
56+
state_dict[k + "._split_1"] = pp[:pp.shape[0]//3].copy()
57+
state_dict[k + "._split_2"] = pp[pp.shape[0]//3:2*pp.shape[0]//3].copy()
58+
state_dict[k + "._split_3"] = pp[2*pp.shape[0]//3: ].copy()
59+
60+
5461
keys_list = list(state_dict.keys())
5562
mid_key = keys_list[len(keys_list)//2]
5663

@@ -69,7 +76,7 @@ def convert_model(checkpoint_filename, out_filename ):
6976
model_type = get_model_type(state_dict)
7077

7178
if model_type is None:
72-
raise ValueError("The model is not supported. Please make sure it is a valid SD 1.4/1.5 .ckpt file")
79+
raise ValueError("The model is not supported. Please make sure it is a valid SD 1.4/1.5/2.1 .ckpt/safetensor file")
7380

7481
if "float16" in model_type:
7582
cur_dtype = "float16"
@@ -88,7 +95,9 @@ def convert_model(checkpoint_filename, out_filename ):
8895
outfile.init_write(ctdict_version=ctdict_id )
8996

9097
for k in model_shapes:
91-
np_arr = state_dict[k]
98+
np_arr = np.copy(state_dict[k])
99+
np_arr = np.reshape(np_arr , model_shapes[k] )
100+
92101
if "float" in str(np_arr.dtype):
93102
np_arr = np_arr.astype(cur_dtype)
94103
shape = list(np_arr.shape)

backends/model_converter/sd_shapes.py

Lines changed: 36 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11

2-
from sd_shapes_consts import shapes_unet , shapes_encoder, shapes_decoder , shapes_text_encoder, shapes_params
2+
from sd_shapes_consts import shapes_unet , shapes_encoder, shapes_decoder , shapes_text_encoder, shapes_params , shapes_unet_v2 , text_encoder_open_clip
33
import copy
44
from collections import Counter
55

@@ -10,11 +10,19 @@ def add_aux_shapes(d):
1010

1111
if ".ff." in k:
1212
sh = list(d[k])
13-
sh[0] /= 2
13+
sh[0] = sh[0] // 2
1414
sh = tuple(sh)
1515
d[k + "._split_1"] = sh
1616
d[k + "._split_2"] = sh
1717

18+
elif "attn.in_proj" in k and d[k][0] == 3072 :
19+
sh = list(d[k])
20+
sh[0] = sh[0] // 3
21+
sh = tuple(sh)
22+
d[k + "._split_1"] = sh
23+
d[k + "._split_2"] = sh
24+
d[k + "._split_3"] = sh
25+
1826
for i in range(1,21):
1927
nn = 320*i
2028
d["zeros_"+str(nn)] = (nn,)
@@ -41,16 +49,35 @@ def add_aux_shapes(d):
4149
add_aux_shapes(sd_1x_inpaint_shapes)
4250

4351

52+
53+
54+
sd_2x_shapes = {}
55+
sd_2x_shapes.update(shapes_unet_v2)
56+
sd_2x_shapes.update(shapes_encoder)
57+
sd_2x_shapes.update(shapes_decoder)
58+
sd_2x_shapes.update(text_encoder_open_clip)
59+
sd_2x_shapes.update(shapes_params)
60+
61+
62+
add_aux_shapes(sd_2x_shapes)
63+
64+
4465
possible_model_shapes = {"SD_1x_float32": sd_1x_shapes ,
4566
"SD_1x_inpaint_float32": sd_1x_inpaint_shapes,
4667
"SD_1x_float16": sd_1x_shapes ,
68+
69+
"SD_2x_float16": sd_2x_shapes ,
70+
"SD_2x_float32": sd_2x_shapes ,
71+
4772
"SD_1x_inpaint_float16": sd_1x_inpaint_shapes}
4873

4974
ctdict_ids = {"SD_1x_float32": 12 ,
5075
"SD_1x_inpaint_float32": 13,
76+
"SD_2x_float32": 15 ,
5177
"SD_1x_float16": 1012 ,
5278
"SD_1x_inpaint_float16": 1013 ,
53-
"SD_1x_just_controlnet_16" : 1014}
79+
"SD_1x_just_controlnet_16" : 1014,
80+
"SD_2x_float16": 1015 }
5481

5582

5683
extra_keys = ['temb_coefficients_fp32' , 'temb_coefficients_fp16' , 'causal_mask' , 'aux_output_conv.weight' , 'aux_output_conv.bias', 'alphas_cumprod']
@@ -63,8 +90,9 @@ def are_shapes_matching(state_dict , template_shapes , name=None):
6390
print("key", k , "not found in state_dict" , state_dict.keys())
6491
return False
6592
if tuple(template_shapes[k]) != tuple(state_dict[k].shape):
66-
print("shape mismatch", k , tuple(template_shapes[k]) ,tuple(state_dict[k].shape) )
67-
return False
93+
if tuple(template_shapes[k]) != tuple(state_dict[k].shape) + (1,1):
94+
print("shape mismatch", k , tuple(template_shapes[k]) ,tuple(state_dict[k].shape) )
95+
return False
6896

6997
return True
7098

@@ -101,6 +129,9 @@ def get_model_type(state_dict):
101129
if are_shapes_matching(state_dict , sd_1x_shapes) :
102130
shapes = sd_1x_shapes
103131
mname = "SD_1x"
132+
elif are_shapes_matching(state_dict , sd_2x_shapes) :
133+
shapes = sd_2x_shapes
134+
mname = "SD_2x"
104135
elif are_shapes_matching(state_dict , sd_1x_inpaint_shapes) :
105136
shapes = sd_1x_inpaint_shapes
106137
mname = "SD_1x_inpaint"

0 commit comments

Comments
 (0)