11
2- from sd_shapes_consts import shapes_unet , shapes_encoder , shapes_decoder , shapes_text_encoder , shapes_params
2+ from sd_shapes_consts import shapes_unet , shapes_encoder , shapes_decoder , shapes_text_encoder , shapes_params , shapes_unet_v2 , text_encoder_open_clip
33import copy
44from collections import Counter
55
@@ -10,11 +10,19 @@ def add_aux_shapes(d):
1010
1111 if ".ff." in k :
1212 sh = list (d [k ])
13- sh [0 ] /= 2
13+ sh [0 ] = sh [ 0 ] // 2
1414 sh = tuple (sh )
1515 d [k + "._split_1" ] = sh
1616 d [k + "._split_2" ] = sh
1717
18+ elif "attn.in_proj" in k and d [k ][0 ] == 3072 :
19+ sh = list (d [k ])
20+ sh [0 ] = sh [0 ] // 3
21+ sh = tuple (sh )
22+ d [k + "._split_1" ] = sh
23+ d [k + "._split_2" ] = sh
24+ d [k + "._split_3" ] = sh
25+
1826 for i in range (1 ,21 ):
1927 nn = 320 * i
2028 d ["zeros_" + str (nn )] = (nn ,)
@@ -41,16 +49,35 @@ def add_aux_shapes(d):
4149add_aux_shapes (sd_1x_inpaint_shapes )
4250
4351
52+
53+
54+ sd_2x_shapes = {}
55+ sd_2x_shapes .update (shapes_unet_v2 )
56+ sd_2x_shapes .update (shapes_encoder )
57+ sd_2x_shapes .update (shapes_decoder )
58+ sd_2x_shapes .update (text_encoder_open_clip )
59+ sd_2x_shapes .update (shapes_params )
60+
61+
62+ add_aux_shapes (sd_2x_shapes )
63+
64+
4465possible_model_shapes = {"SD_1x_float32" : sd_1x_shapes ,
4566 "SD_1x_inpaint_float32" : sd_1x_inpaint_shapes ,
4667 "SD_1x_float16" : sd_1x_shapes ,
68+
69+ "SD_2x_float16" : sd_2x_shapes ,
70+ "SD_2x_float32" : sd_2x_shapes ,
71+
4772 "SD_1x_inpaint_float16" : sd_1x_inpaint_shapes }
4873
4974ctdict_ids = {"SD_1x_float32" : 12 ,
5075 "SD_1x_inpaint_float32" : 13 ,
76+ "SD_2x_float32" : 15 ,
5177 "SD_1x_float16" : 1012 ,
5278 "SD_1x_inpaint_float16" : 1013 ,
53- "SD_1x_just_controlnet_16" : 1014 }
79+ "SD_1x_just_controlnet_16" : 1014 ,
80+ "SD_2x_float16" : 1015 }
5481
5582
5683extra_keys = ['temb_coefficients_fp32' , 'temb_coefficients_fp16' , 'causal_mask' , 'aux_output_conv.weight' , 'aux_output_conv.bias' , 'alphas_cumprod' ]
@@ -63,8 +90,9 @@ def are_shapes_matching(state_dict , template_shapes , name=None):
6390 print ("key" , k , "not found in state_dict" , state_dict .keys ())
6491 return False
6592 if tuple (template_shapes [k ]) != tuple (state_dict [k ].shape ):
66- print ("shape mismatch" , k , tuple (template_shapes [k ]) ,tuple (state_dict [k ].shape ) )
67- return False
93+ if tuple (template_shapes [k ]) != tuple (state_dict [k ].shape ) + (1 ,1 ):
94+ print ("shape mismatch" , k , tuple (template_shapes [k ]) ,tuple (state_dict [k ].shape ) )
95+ return False
6896
6997 return True
7098
@@ -101,6 +129,9 @@ def get_model_type(state_dict):
101129 if are_shapes_matching (state_dict , sd_1x_shapes ) :
102130 shapes = sd_1x_shapes
103131 mname = "SD_1x"
132+ elif are_shapes_matching (state_dict , sd_2x_shapes ) :
133+ shapes = sd_2x_shapes
134+ mname = "SD_2x"
104135 elif are_shapes_matching (state_dict , sd_1x_inpaint_shapes ) :
105136 shapes = sd_1x_inpaint_shapes
106137 mname = "SD_1x_inpaint"
0 commit comments