Skip to content

Commit 4095438

Browse files
committed
update to new tdict format
1 parent 0abee08 commit 4095438

File tree

8 files changed

+2601
-142
lines changed

8 files changed

+2601
-142
lines changed

backends/model_converter/constants.py

Lines changed: 5 additions & 1 deletion
Large diffs are not rendered by default.
Lines changed: 71 additions & 78 deletions
Original file line numberDiff line numberDiff line change
@@ -1,111 +1,104 @@
11
import json
22
import numpy as np
3-
from constants import SD_SHAPES, _ALPHAS_CUMPROD
3+
from constants import _ALPHAS_CUMPROD
44
import sys, getopt
55

66
# python convert_model.py "/Users/divamgupta/Downloads/hollie-mengert.ckpt" "/Users/divamgupta/Downloads/hollie-mengert.tdict"
77

88
# pyinstaller convert_model.py --onefile --noconfirm --clean # build using intel machine so that its cross platform lol
99

10-
unpickle = False
11-
10+
1211
try:
13-
optlist, args = getopt.getopt(sys.argv[1:], "hu", ["help", "unpickle"])
12+
optlist, args = getopt.getopt(sys.argv[1:], "hu", ["help" ])
1413
except getopt.GetoptError as err:
15-
print(err)
16-
#usage()
17-
sys.exit(2)
14+
print(err)
15+
sys.exit(2)
1816
for o, a in optlist:
19-
if o in ("-h", "--help"):
20-
usage()
21-
sys.exit()
22-
elif o in ("-u", "--unpickle"):
23-
unpickle = True
24-
else:
25-
assert False, "unhandled option"
17+
if o in ("-h", "--help"):
18+
usage()
19+
sys.exit()
20+
else:
21+
assert False, "unhandled option"
2622

2723
def usage():
28-
print("\nConverts .cpkt model files into .tdict model files for Diffusion Bee")
29-
print("\npython3 convert_py [--unpickle] input.ckpt output.tdict")
30-
print("\tNormal use.")
31-
print("\n\t--unpickle")
32-
print("\t\tWill use unpickling to extract the model, please use with caution as malicious code")
33-
print("\t\tcan be hidden in the .ckpt file, executed by unpickling. Without this option, the pickle")
34-
print("\t\tinside the .ckpt will instead be decompiled and the weights extracted from that with")
35-
print("\t\tno arbitrary code execution.")
36-
print("\n\tPlease report any errors on the Diffusion Bee GitHub project or the official Discord server.")
37-
print("\npython3 convert_py --help")
38-
print("\tDisplays this message")
24+
print("\nConverts .cpkt model files into .tdict model files for Diffusion Bee")
25+
print("\npython3 convert_py input.ckpt output.tdict")
26+
print("\tNormal use.")
27+
print("\n\tPlease report any errors on the Diffusion Bee GitHub project or the official Discord server.")
28+
print("\npython3 convert_py --help")
29+
print("\tDisplays this message")
3930

4031
if len(args) != 2:
41-
print("Incorrect number of arguments")
42-
usage()
43-
sys.exit(2)
32+
print("Incorrect number of arguments")
33+
usage()
34+
sys.exit(2)
4435

4536
checkpoint_filename = args[0]
4637
out_filename = args[1]
4738

48-
if unpickle:
49-
from fake_torch import fake_torch_load_zipped
50-
torch_weights = fake_torch_load_zipped(open(checkpoint_filename, "rb"))
51-
else:
52-
from fake_torch import extract_weights_from_checkpoint
53-
torch_weights = extract_weights_from_checkpoint(open(checkpoint_filename, "rb"))
54-
55-
56-
#TODO add MD5s
57-
58-
_HEADER_BYTES = [42, 10 , 8, 42] + [0]*20
5939

40+
from fake_torch import extract_weights_from_checkpoint
41+
torch_weights = extract_weights_from_checkpoint(open(checkpoint_filename, "rb"))
6042

61-
s = 24
6243

63-
keys_info = {}
64-
out_file = open( out_filename , "wb")
44+
from sd_shapes import get_model_type , possible_model_shapes , ctdict_ids
45+
from tdict import TDict
6546

66-
out_file.write(bytes(_HEADER_BYTES))
6747

6848

6949
torch_weights['state_dict']['temb_coefficients_fp32'] = np.array([1.0, 0.94406086, 0.8912509, 0.84139514, 0.7943282, 0.7498942, 0.70794576, 0.6683439, 0.63095737, 0.5956621, 0.56234133, 0.53088444, 0.50118726, 0.47315124, 0.44668362, 0.42169648, 0.39810717, 0.3758374, 0.35481337, 0.33496544, 0.31622776, 0.29853827, 0.2818383, 0.2660725, 0.25118864, 0.23713735, 0.2238721, 0.21134889, 0.19952625, 0.18836491, 0.17782794, 0.16788039, 0.15848932, 0.14962357, 0.14125374, 0.13335215, 0.12589253, 0.11885023, 0.11220184, 0.10592538, 0.099999994, 0.094406076, 0.08912509, 0.0841395, 0.07943282, 0.074989416, 0.07079458, 0.06683439, 0.06309574, 0.05956621, 0.056234125, 0.05308844, 0.05011872, 0.047315124, 0.044668354, 0.04216965, 0.039810725, 0.037583742, 0.035481337, 0.033496536, 0.031622775, 0.029853828, 0.028183822, 0.026607247, 0.025118865, 0.02371374, 0.022387212, 0.021134889, 0.01995262, 0.018836489, 0.017782794, 0.01678804, 0.015848929, 0.014962353, 0.014125377, 0.013335214, 0.012589253, 0.01188502, 0.011220186, 0.010592538, 0.01, 0.009440607, 0.0089125065, 0.008413952, 0.007943282, 0.007498941, 0.007079456, 0.00668344, 0.0063095735, 0.005956621, 0.0056234123, 0.0053088428, 0.005011873, 0.0047315126, 0.0044668354, 0.004216964, 0.003981072, 0.0037583741, 0.0035481334, 0.0033496537, 0.0031622767, 0.0029853827, 0.0028183828, 0.0026607246, 0.0025118857, 0.0023713738, 0.0022387211, 0.0021134887, 0.0019952618, 0.0018836485, 0.0017782794, 0.0016788039, 0.0015848937, 0.001496236, 0.0014125376, 0.0013335207, 0.0012589252, 0.001188502, 0.0011220181, 0.0010592537, 0.0009999999, 0.00094406115, 0.0008912511, 0.0008413952, 0.0007943278, 0.00074989407, 0.0007079456, 0.0006683437, 0.00063095737, 0.0005956621, 0.0005623415, 0.00053088454, 0.0005011872, 0.000473151, 0.00044668352, 0.00042169637, 0.00039810702, 0.0003758374, 0.00035481335, 0.00033496553, 0.00031622782, 0.00029853827, 0.00028183826, 0.00026607246, 0.00025118855, 0.00023713727, 0.00022387199, 0.00021134898, 0.00019952627, 0.00018836492, 0.00017782794, 0.00016788038, 0.00015848929, 0.00014962352, 0.0001412537, 0.00013335208, 0.00012589258, 0.00011885024, 0.00011220186, 0.00010592537]).astype('float32')
7050
torch_weights['state_dict']['causal_mask'] = np.triu(np.ones((1,1,77,77), dtype=np.float16) * -65500.0, k=1).astype(np.float32)
7151
torch_weights['state_dict']['aux_output_conv.weight'] = np.array([0.14013671875, 0.0711669921875, -0.03271484375, -0.11407470703125, 0.126220703125, 0.10101318359375, 0.034515380859375, -0.1383056640625, 0.126220703125, 0.07733154296875, 0.042633056640625, -0.177978515625]).astype(np.float32)
7252
torch_weights['state_dict']['aux_output_conv.bias'] = np.array([0.423828125, 0.471923828125, 0.473876953125]).astype(np.float32)
7353
torch_weights['state_dict']['alphas_cumprod'] = np.array(_ALPHAS_CUMPROD).astype(np.float32)
74-
extra_keys = ['temb_coefficients_fp32' , 'causal_mask' , 'aux_output_conv.weight' , 'aux_output_conv.bias', 'alphas_cumprod']
54+
torch_weights['state_dict']['temb_coefficients_fp16'] = np.array( [1.0, 0.944, 0.891, 0.8413, 0.7944, 0.75, 0.708, 0.6685, 0.631, 0.5957, 0.5625, 0.531, 0.501, 0.4731, 0.4468, 0.4216, 0.3982, 0.3757, 0.3547, 0.335, 0.3162, 0.2986, 0.2817, 0.266, 0.2512, 0.2372, 0.2239, 0.2113, 0.1996, 0.1884, 0.1779, 0.1678, 0.1584, 0.1497, 0.1412, 0.1333, 0.1259, 0.11884, 0.1122, 0.1059, 0.1, 0.0944, 0.0891, 0.08417, 0.0794, 0.075, 0.0708, 0.06683, 0.0631, 0.05957, 0.05624, 0.0531, 0.0501, 0.0473, 0.04468, 0.04218, 0.03983, 0.0376, 0.0355, 0.0335, 0.03162, 0.02986, 0.02818, 0.02661, 0.02512, 0.02371, 0.02238, 0.02113, 0.01996, 0.01883, 0.01778, 0.01678, 0.01585, 0.01496, 0.01412, 0.013336, 0.01259, 0.01189, 0.01122, 0.01059, 0.01, 0.00944, 0.00891, 0.008415, 0.00794, 0.0075, 0.00708, 0.006683, 0.00631, 0.005955, 0.005623, 0.00531, 0.005013, 0.00473, 0.004467, 0.004215, 0.003983, 0.003757, 0.003548, 0.00335, 0.003162, 0.002985, 0.00282, 0.00266, 0.002512, 0.00237, 0.00224, 0.002113, 0.001995, 0.0018835, 0.001779, 0.001678, 0.001585, 0.001496, 0.001412, 0.001333, 0.001259, 0.001188, 0.001122, 0.00106, 0.001, 0.000944, 0.000891, 0.0008416, 0.0007944, 0.00075, 0.000708, 0.0006685, 0.000631, 0.0005956, 0.000562, 0.0005307, 0.000501, 0.0004733, 0.0004468, 0.0004218, 0.0003982, 0.0003757, 0.0003548, 0.000335, 0.0003161, 0.0002985, 0.0002818, 0.000266, 0.0002513, 0.0002371, 0.0002239, 0.0002114, 0.0001996, 0.0001884, 0.0001779, 0.0001678, 0.0001585, 0.0001496, 0.0001413, 0.0001334, 0.0001259, 0.00011885, 0.0001122, 0.0001059]).astype('float16')
55+
56+
extra_keys = ['temb_coefficients_fp32', 'temb_coefficients_fp16' , 'causal_mask' , 'aux_output_conv.weight' , 'aux_output_conv.bias', 'alphas_cumprod']
57+
58+
59+
60+
for k in (list(torch_weights['state_dict'].keys())):
61+
if '.norm' in k and '.bias' in k:
62+
k2 = k.replace(".bias" , ".weight")
63+
k3 = k.replace(".bias" , ".bias_by_weight")
64+
torch_weights['state_dict'][k3] = torch_weights['state_dict'][k]/torch_weights['state_dict'][k2]
65+
66+
if ".ff." in k:
67+
pp = torch_weights['state_dict'][k]
68+
torch_weights['state_dict'][k + "._split_1"] = pp[:pp.shape[0]//2].copy()
69+
torch_weights['state_dict'][k + "._split_2"] = pp[pp.shape[0]//2:].copy()
7570

76-
for k in torch_weights['state_dict']:
77-
if k not in SD_SHAPES and k not in extra_keys:
78-
continue
79-
if 'model_ema' in k:
80-
continue
71+
72+
for i in range(1,21):
73+
nn = 320*i
74+
dtype = torch_weights['state_dict']['model.diffusion_model.input_blocks.1.0.in_layers.0.weight'].dtype
75+
torch_weights['state_dict']["zeros_"+str(nn)] = np.zeros(nn).astype(dtype)
76+
torch_weights['state_dict']["ones_"+str(nn)] = np.ones(nn).astype(dtype)
77+
78+
nn = 128*i
79+
dtype = torch_weights['state_dict']['model.diffusion_model.input_blocks.1.0.in_layers.0.weight'].dtype
80+
torch_weights['state_dict']["zeros_"+str(nn)] = np.zeros(nn).astype(dtype)
81+
torch_weights['state_dict']["ones_"+str(nn)] = np.ones(nn).astype(dtype)
82+
83+
84+
model_type = get_model_type(torch_weights['state_dict'])
85+
86+
if model_type is None:
87+
raise ValueError("The model is not supported. Please make sure it is a valid SD 1.4/1.5 .ckpt file")
88+
89+
print("model type " , model_type)
90+
91+
model_shapes = possible_model_shapes[model_type]
92+
ctdict_id = ctdict_ids[model_type]
93+
94+
outfile = TDict(fpath=out_filename)
95+
96+
outfile.init_write(ctdict_version=ctdict_id )
97+
98+
for k in model_shapes:
8199
np_arr = torch_weights['state_dict'][k]
82-
key_bytes = np_arr.tobytes()
83100
shape = list(np_arr.shape)
84-
if k not in extra_keys:
85-
assert tuple(shape) == SD_SHAPES[k], ( "shape mismatch at" , k , shape , SD_SHAPES[k] )
86-
dtype = str(np_arr.dtype)
87-
if dtype == 'int64':
88-
np_arr = np_arr.astype('float32')
89-
dtype = 'float32'
90-
assert dtype in ['float16' , 'float32'] , (dtype, k)
91-
e = s + len(key_bytes)
92-
out_file.write(key_bytes)
93-
keys_info[k] = {"start": s , "end" : e , "shape": shape , "dtype" : dtype }
94-
s = e
95-
96-
for k in SD_SHAPES:
97-
if 'model_ema' in k or 'betas' in k or 'alphas' in k or 'posterior_' in k:
98-
continue
99-
assert k in keys_info , k
100-
101-
json_start = s
102-
info_json = bytes( json.dumps(keys_info) , 'ascii')
103-
json_end = s + len(info_json)
104-
105-
out_file.write(info_json)
106-
107-
out_file.seek(5)
108-
out_file.write(np.array(json_start).astype('long').tobytes())
109-
110-
out_file.seek(14)
111-
out_file.write(np.array(json_end).astype('long').tobytes())
101+
assert tuple(shape) == tuple(model_shapes[k]), ( "shape mismatch at" , k , shape , SD_SHAPES[k] )
102+
outfile.write_key(key=k , tensor=np_arr)
103+
104+
outfile.finish_write()
Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
2+
from sd_shapes_consts import shapes_unet , shapes_encoder, shapes_decoder , shapes_text_encoder, shapes_params
3+
import copy
4+
5+
6+
def add_aux_shapes(d):
7+
for k in list(d.keys()):
8+
if '.norm' in k and '.bias' in k:
9+
d[k.replace(".bias" , ".bias_by_weight")] = d[k]
10+
11+
if ".ff." in k:
12+
sh = list(d[k])
13+
sh[0] /= 2
14+
sh = tuple(sh)
15+
d[k + "._split_1"] = sh
16+
d[k + "._split_2"] = sh
17+
18+
for i in range(1,21):
19+
nn = 320*i
20+
d["zeros_"+str(nn)] = (nn,)
21+
d["ones_"+str(nn)] = (nn,)
22+
23+
nn = 128*i
24+
d["zeros_"+str(nn)] = (nn,)
25+
d["ones_"+str(nn)] = (nn,)
26+
27+
28+
29+
30+
sd_1x_shapes = {}
31+
sd_1x_shapes.update(shapes_unet)
32+
sd_1x_shapes.update(shapes_encoder)
33+
sd_1x_shapes.update(shapes_decoder)
34+
sd_1x_shapes.update(shapes_text_encoder)
35+
sd_1x_shapes.update(shapes_params)
36+
37+
sd_1x_inpaint_shapes = copy.deepcopy(sd_1x_shapes)
38+
sd_1x_inpaint_shapes['model.diffusion_model.input_blocks.0.0.weight'] = [320, 9, 3, 3]
39+
40+
add_aux_shapes(sd_1x_shapes)
41+
add_aux_shapes(sd_1x_inpaint_shapes)
42+
43+
44+
possible_model_shapes = {"SD_1x_float32": sd_1x_shapes ,
45+
"SD_1x_inpaint_float32": sd_1x_inpaint_shapes,
46+
"SD_1x_float16": sd_1x_shapes ,
47+
"SD_1x_inpaint_float16": sd_1x_inpaint_shapes}
48+
49+
ctdict_ids = {"SD_1x_float32": 12 ,
50+
"SD_1x_inpaint_float32": 13,
51+
"SD_1x_float16": 1012 ,
52+
"SD_1x_inpaint_float16": 1013}
53+
54+
55+
extra_keys = ['temb_coefficients_fp32' , 'temb_coefficients_fp16' , 'causal_mask' , 'aux_output_conv.weight' , 'aux_output_conv.bias', 'alphas_cumprod']
56+
57+
58+
59+
def are_shapes_matching(state_dict , template_shapes):
60+
for k in template_shapes:
61+
if k not in state_dict:
62+
return False
63+
if tuple(template_shapes[k]) != tuple(state_dict[k].shape):
64+
return False
65+
66+
return True
67+
68+
def are_shapes_dtype(state_dict, template_shapes , dtype):
69+
for k in state_dict:
70+
if k in extra_keys:
71+
continue
72+
if k not in template_shapes:
73+
continue
74+
if state_dict[k].dtype != dtype:
75+
return False
76+
77+
return True
78+
79+
80+
def get_model_type(state_dict):
81+
if are_shapes_matching(state_dict , sd_1x_shapes) and are_shapes_dtype(state_dict , sd_1x_shapes, "float32"):
82+
return "SD_1x_float32"
83+
elif are_shapes_matching(state_dict , sd_1x_inpaint_shapes) and are_shapes_dtype(state_dict , sd_1x_inpaint_shapes , "float32"):
84+
return "SD_1x_inpaint_float32"
85+
elif are_shapes_matching(state_dict , sd_1x_shapes) and are_shapes_dtype(state_dict , sd_1x_shapes , "float16"):
86+
return "SD_1x_float16"
87+
elif are_shapes_matching(state_dict , sd_1x_inpaint_shapes) and are_shapes_dtype(state_dict , sd_1x_inpaint_shapes, "float16"):
88+
return "SD_1x_inpaint_float16"
89+
else:
90+
return None
91+
92+

0 commit comments

Comments
 (0)