Skip to content

Commit 865d2af

Browse files
committed
model converter endpoint merged
1 parent 4095438 commit 865d2af

File tree

3 files changed

+87
-71
lines changed

3 files changed

+87
-71
lines changed

backends/model_converter/build.sh

Lines changed: 0 additions & 1 deletion
This file was deleted.

backends/model_converter/convert_model.py

Lines changed: 72 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -8,97 +8,102 @@
88
# pyinstaller convert_model.py --onefile --noconfirm --clean # build using intel machine so that its cross platform lol
99

1010

11-
try:
12-
optlist, args = getopt.getopt(sys.argv[1:], "hu", ["help" ])
13-
except getopt.GetoptError as err:
14-
print(err)
15-
sys.exit(2)
16-
for o, a in optlist:
17-
if o in ("-h", "--help"):
18-
usage()
19-
sys.exit()
20-
else:
21-
assert False, "unhandled option"
2211

23-
def usage():
24-
print("\nConverts .cpkt model files into .tdict model files for Diffusion Bee")
25-
print("\npython3 convert_py input.ckpt output.tdict")
26-
print("\tNormal use.")
27-
print("\n\tPlease report any errors on the Diffusion Bee GitHub project or the official Discord server.")
28-
print("\npython3 convert_py --help")
29-
print("\tDisplays this message")
12+
from fake_torch import extract_weights_from_checkpoint
13+
from sd_shapes import get_model_type , possible_model_shapes , ctdict_ids
14+
from tdict import TDict
3015

31-
if len(args) != 2:
32-
print("Incorrect number of arguments")
33-
usage()
34-
sys.exit(2)
3516

36-
checkpoint_filename = args[0]
37-
out_filename = args[1]
3817

18+
def convert_model(checkpoint_filename, out_filename ):
19+
torch_weights = extract_weights_from_checkpoint(open(checkpoint_filename, "rb"))
3920

40-
from fake_torch import extract_weights_from_checkpoint
41-
torch_weights = extract_weights_from_checkpoint(open(checkpoint_filename, "rb"))
21+
torch_weights['state_dict']['temb_coefficients_fp32'] = np.array([1.0, 0.94406086, 0.8912509, 0.84139514, 0.7943282, 0.7498942, 0.70794576, 0.6683439, 0.63095737, 0.5956621, 0.56234133, 0.53088444, 0.50118726, 0.47315124, 0.44668362, 0.42169648, 0.39810717, 0.3758374, 0.35481337, 0.33496544, 0.31622776, 0.29853827, 0.2818383, 0.2660725, 0.25118864, 0.23713735, 0.2238721, 0.21134889, 0.19952625, 0.18836491, 0.17782794, 0.16788039, 0.15848932, 0.14962357, 0.14125374, 0.13335215, 0.12589253, 0.11885023, 0.11220184, 0.10592538, 0.099999994, 0.094406076, 0.08912509, 0.0841395, 0.07943282, 0.074989416, 0.07079458, 0.06683439, 0.06309574, 0.05956621, 0.056234125, 0.05308844, 0.05011872, 0.047315124, 0.044668354, 0.04216965, 0.039810725, 0.037583742, 0.035481337, 0.033496536, 0.031622775, 0.029853828, 0.028183822, 0.026607247, 0.025118865, 0.02371374, 0.022387212, 0.021134889, 0.01995262, 0.018836489, 0.017782794, 0.01678804, 0.015848929, 0.014962353, 0.014125377, 0.013335214, 0.012589253, 0.01188502, 0.011220186, 0.010592538, 0.01, 0.009440607, 0.0089125065, 0.008413952, 0.007943282, 0.007498941, 0.007079456, 0.00668344, 0.0063095735, 0.005956621, 0.0056234123, 0.0053088428, 0.005011873, 0.0047315126, 0.0044668354, 0.004216964, 0.003981072, 0.0037583741, 0.0035481334, 0.0033496537, 0.0031622767, 0.0029853827, 0.0028183828, 0.0026607246, 0.0025118857, 0.0023713738, 0.0022387211, 0.0021134887, 0.0019952618, 0.0018836485, 0.0017782794, 0.0016788039, 0.0015848937, 0.001496236, 0.0014125376, 0.0013335207, 0.0012589252, 0.001188502, 0.0011220181, 0.0010592537, 0.0009999999, 0.00094406115, 0.0008912511, 0.0008413952, 0.0007943278, 0.00074989407, 0.0007079456, 0.0006683437, 0.00063095737, 0.0005956621, 0.0005623415, 0.00053088454, 0.0005011872, 0.000473151, 0.00044668352, 0.00042169637, 0.00039810702, 0.0003758374, 0.00035481335, 0.00033496553, 0.00031622782, 0.00029853827, 0.00028183826, 0.00026607246, 0.00025118855, 0.00023713727, 0.00022387199, 0.00021134898, 0.00019952627, 0.00018836492, 0.00017782794, 0.00016788038, 0.00015848929, 0.00014962352, 0.0001412537, 0.00013335208, 0.00012589258, 0.00011885024, 0.00011220186, 0.00010592537]).astype('float32')
22+
torch_weights['state_dict']['causal_mask'] = np.triu(np.ones((1,1,77,77), dtype=np.float16) * -65500.0, k=1).astype(np.float32)
23+
torch_weights['state_dict']['aux_output_conv.weight'] = np.array([0.14013671875, 0.0711669921875, -0.03271484375, -0.11407470703125, 0.126220703125, 0.10101318359375, 0.034515380859375, -0.1383056640625, 0.126220703125, 0.07733154296875, 0.042633056640625, -0.177978515625]).astype(np.float32)
24+
torch_weights['state_dict']['aux_output_conv.bias'] = np.array([0.423828125, 0.471923828125, 0.473876953125]).astype(np.float32)
25+
torch_weights['state_dict']['alphas_cumprod'] = np.array(_ALPHAS_CUMPROD).astype(np.float32)
26+
torch_weights['state_dict']['temb_coefficients_fp16'] = np.array( [1.0, 0.944, 0.891, 0.8413, 0.7944, 0.75, 0.708, 0.6685, 0.631, 0.5957, 0.5625, 0.531, 0.501, 0.4731, 0.4468, 0.4216, 0.3982, 0.3757, 0.3547, 0.335, 0.3162, 0.2986, 0.2817, 0.266, 0.2512, 0.2372, 0.2239, 0.2113, 0.1996, 0.1884, 0.1779, 0.1678, 0.1584, 0.1497, 0.1412, 0.1333, 0.1259, 0.11884, 0.1122, 0.1059, 0.1, 0.0944, 0.0891, 0.08417, 0.0794, 0.075, 0.0708, 0.06683, 0.0631, 0.05957, 0.05624, 0.0531, 0.0501, 0.0473, 0.04468, 0.04218, 0.03983, 0.0376, 0.0355, 0.0335, 0.03162, 0.02986, 0.02818, 0.02661, 0.02512, 0.02371, 0.02238, 0.02113, 0.01996, 0.01883, 0.01778, 0.01678, 0.01585, 0.01496, 0.01412, 0.013336, 0.01259, 0.01189, 0.01122, 0.01059, 0.01, 0.00944, 0.00891, 0.008415, 0.00794, 0.0075, 0.00708, 0.006683, 0.00631, 0.005955, 0.005623, 0.00531, 0.005013, 0.00473, 0.004467, 0.004215, 0.003983, 0.003757, 0.003548, 0.00335, 0.003162, 0.002985, 0.00282, 0.00266, 0.002512, 0.00237, 0.00224, 0.002113, 0.001995, 0.0018835, 0.001779, 0.001678, 0.001585, 0.001496, 0.001412, 0.001333, 0.001259, 0.001188, 0.001122, 0.00106, 0.001, 0.000944, 0.000891, 0.0008416, 0.0007944, 0.00075, 0.000708, 0.0006685, 0.000631, 0.0005956, 0.000562, 0.0005307, 0.000501, 0.0004733, 0.0004468, 0.0004218, 0.0003982, 0.0003757, 0.0003548, 0.000335, 0.0003161, 0.0002985, 0.0002818, 0.000266, 0.0002513, 0.0002371, 0.0002239, 0.0002114, 0.0001996, 0.0001884, 0.0001779, 0.0001678, 0.0001585, 0.0001496, 0.0001413, 0.0001334, 0.0001259, 0.00011885, 0.0001122, 0.0001059]).astype('float16')
4227

28+
extra_keys = ['temb_coefficients_fp32', 'temb_coefficients_fp16' , 'causal_mask' , 'aux_output_conv.weight' , 'aux_output_conv.bias', 'alphas_cumprod']
4329

44-
from sd_shapes import get_model_type , possible_model_shapes , ctdict_ids
45-
from tdict import TDict
4630

4731

32+
for k in (list(torch_weights['state_dict'].keys())):
33+
if '.norm' in k and '.bias' in k:
34+
k2 = k.replace(".bias" , ".weight")
35+
k3 = k.replace(".bias" , ".bias_by_weight")
36+
torch_weights['state_dict'][k3] = torch_weights['state_dict'][k]/torch_weights['state_dict'][k2]
4837

49-
torch_weights['state_dict']['temb_coefficients_fp32'] = np.array([1.0, 0.94406086, 0.8912509, 0.84139514, 0.7943282, 0.7498942, 0.70794576, 0.6683439, 0.63095737, 0.5956621, 0.56234133, 0.53088444, 0.50118726, 0.47315124, 0.44668362, 0.42169648, 0.39810717, 0.3758374, 0.35481337, 0.33496544, 0.31622776, 0.29853827, 0.2818383, 0.2660725, 0.25118864, 0.23713735, 0.2238721, 0.21134889, 0.19952625, 0.18836491, 0.17782794, 0.16788039, 0.15848932, 0.14962357, 0.14125374, 0.13335215, 0.12589253, 0.11885023, 0.11220184, 0.10592538, 0.099999994, 0.094406076, 0.08912509, 0.0841395, 0.07943282, 0.074989416, 0.07079458, 0.06683439, 0.06309574, 0.05956621, 0.056234125, 0.05308844, 0.05011872, 0.047315124, 0.044668354, 0.04216965, 0.039810725, 0.037583742, 0.035481337, 0.033496536, 0.031622775, 0.029853828, 0.028183822, 0.026607247, 0.025118865, 0.02371374, 0.022387212, 0.021134889, 0.01995262, 0.018836489, 0.017782794, 0.01678804, 0.015848929, 0.014962353, 0.014125377, 0.013335214, 0.012589253, 0.01188502, 0.011220186, 0.010592538, 0.01, 0.009440607, 0.0089125065, 0.008413952, 0.007943282, 0.007498941, 0.007079456, 0.00668344, 0.0063095735, 0.005956621, 0.0056234123, 0.0053088428, 0.005011873, 0.0047315126, 0.0044668354, 0.004216964, 0.003981072, 0.0037583741, 0.0035481334, 0.0033496537, 0.0031622767, 0.0029853827, 0.0028183828, 0.0026607246, 0.0025118857, 0.0023713738, 0.0022387211, 0.0021134887, 0.0019952618, 0.0018836485, 0.0017782794, 0.0016788039, 0.0015848937, 0.001496236, 0.0014125376, 0.0013335207, 0.0012589252, 0.001188502, 0.0011220181, 0.0010592537, 0.0009999999, 0.00094406115, 0.0008912511, 0.0008413952, 0.0007943278, 0.00074989407, 0.0007079456, 0.0006683437, 0.00063095737, 0.0005956621, 0.0005623415, 0.00053088454, 0.0005011872, 0.000473151, 0.00044668352, 0.00042169637, 0.00039810702, 0.0003758374, 0.00035481335, 0.00033496553, 0.00031622782, 0.00029853827, 0.00028183826, 0.00026607246, 0.00025118855, 0.00023713727, 0.00022387199, 0.00021134898, 0.00019952627, 0.00018836492, 0.00017782794, 0.00016788038, 0.00015848929, 0.00014962352, 0.0001412537, 0.00013335208, 0.00012589258, 0.00011885024, 0.00011220186, 0.00010592537]).astype('float32')
50-
torch_weights['state_dict']['causal_mask'] = np.triu(np.ones((1,1,77,77), dtype=np.float16) * -65500.0, k=1).astype(np.float32)
51-
torch_weights['state_dict']['aux_output_conv.weight'] = np.array([0.14013671875, 0.0711669921875, -0.03271484375, -0.11407470703125, 0.126220703125, 0.10101318359375, 0.034515380859375, -0.1383056640625, 0.126220703125, 0.07733154296875, 0.042633056640625, -0.177978515625]).astype(np.float32)
52-
torch_weights['state_dict']['aux_output_conv.bias'] = np.array([0.423828125, 0.471923828125, 0.473876953125]).astype(np.float32)
53-
torch_weights['state_dict']['alphas_cumprod'] = np.array(_ALPHAS_CUMPROD).astype(np.float32)
54-
torch_weights['state_dict']['temb_coefficients_fp16'] = np.array( [1.0, 0.944, 0.891, 0.8413, 0.7944, 0.75, 0.708, 0.6685, 0.631, 0.5957, 0.5625, 0.531, 0.501, 0.4731, 0.4468, 0.4216, 0.3982, 0.3757, 0.3547, 0.335, 0.3162, 0.2986, 0.2817, 0.266, 0.2512, 0.2372, 0.2239, 0.2113, 0.1996, 0.1884, 0.1779, 0.1678, 0.1584, 0.1497, 0.1412, 0.1333, 0.1259, 0.11884, 0.1122, 0.1059, 0.1, 0.0944, 0.0891, 0.08417, 0.0794, 0.075, 0.0708, 0.06683, 0.0631, 0.05957, 0.05624, 0.0531, 0.0501, 0.0473, 0.04468, 0.04218, 0.03983, 0.0376, 0.0355, 0.0335, 0.03162, 0.02986, 0.02818, 0.02661, 0.02512, 0.02371, 0.02238, 0.02113, 0.01996, 0.01883, 0.01778, 0.01678, 0.01585, 0.01496, 0.01412, 0.013336, 0.01259, 0.01189, 0.01122, 0.01059, 0.01, 0.00944, 0.00891, 0.008415, 0.00794, 0.0075, 0.00708, 0.006683, 0.00631, 0.005955, 0.005623, 0.00531, 0.005013, 0.00473, 0.004467, 0.004215, 0.003983, 0.003757, 0.003548, 0.00335, 0.003162, 0.002985, 0.00282, 0.00266, 0.002512, 0.00237, 0.00224, 0.002113, 0.001995, 0.0018835, 0.001779, 0.001678, 0.001585, 0.001496, 0.001412, 0.001333, 0.001259, 0.001188, 0.001122, 0.00106, 0.001, 0.000944, 0.000891, 0.0008416, 0.0007944, 0.00075, 0.000708, 0.0006685, 0.000631, 0.0005956, 0.000562, 0.0005307, 0.000501, 0.0004733, 0.0004468, 0.0004218, 0.0003982, 0.0003757, 0.0003548, 0.000335, 0.0003161, 0.0002985, 0.0002818, 0.000266, 0.0002513, 0.0002371, 0.0002239, 0.0002114, 0.0001996, 0.0001884, 0.0001779, 0.0001678, 0.0001585, 0.0001496, 0.0001413, 0.0001334, 0.0001259, 0.00011885, 0.0001122, 0.0001059]).astype('float16')
38+
if ".ff." in k:
39+
pp = torch_weights['state_dict'][k]
40+
torch_weights['state_dict'][k + "._split_1"] = pp[:pp.shape[0]//2].copy()
41+
torch_weights['state_dict'][k + "._split_2"] = pp[pp.shape[0]//2:].copy()
5542

56-
extra_keys = ['temb_coefficients_fp32', 'temb_coefficients_fp16' , 'causal_mask' , 'aux_output_conv.weight' , 'aux_output_conv.bias', 'alphas_cumprod']
5743

44+
for i in range(1,21):
45+
nn = 320*i
46+
dtype = torch_weights['state_dict']['model.diffusion_model.input_blocks.1.0.in_layers.0.weight'].dtype
47+
torch_weights['state_dict']["zeros_"+str(nn)] = np.zeros(nn).astype(dtype)
48+
torch_weights['state_dict']["ones_"+str(nn)] = np.ones(nn).astype(dtype)
5849

50+
nn = 128*i
51+
dtype = torch_weights['state_dict']['model.diffusion_model.input_blocks.1.0.in_layers.0.weight'].dtype
52+
torch_weights['state_dict']["zeros_"+str(nn)] = np.zeros(nn).astype(dtype)
53+
torch_weights['state_dict']["ones_"+str(nn)] = np.ones(nn).astype(dtype)
5954

60-
for k in (list(torch_weights['state_dict'].keys())):
61-
if '.norm' in k and '.bias' in k:
62-
k2 = k.replace(".bias" , ".weight")
63-
k3 = k.replace(".bias" , ".bias_by_weight")
64-
torch_weights['state_dict'][k3] = torch_weights['state_dict'][k]/torch_weights['state_dict'][k2]
6555

66-
if ".ff." in k:
67-
pp = torch_weights['state_dict'][k]
68-
torch_weights['state_dict'][k + "._split_1"] = pp[:pp.shape[0]//2].copy()
69-
torch_weights['state_dict'][k + "._split_2"] = pp[pp.shape[0]//2:].copy()
56+
model_type = get_model_type(torch_weights['state_dict'])
7057

58+
if model_type is None:
59+
raise ValueError("The model is not supported. Please make sure it is a valid SD 1.4/1.5 .ckpt file")
7160

72-
for i in range(1,21):
73-
nn = 320*i
74-
dtype = torch_weights['state_dict']['model.diffusion_model.input_blocks.1.0.in_layers.0.weight'].dtype
75-
torch_weights['state_dict']["zeros_"+str(nn)] = np.zeros(nn).astype(dtype)
76-
torch_weights['state_dict']["ones_"+str(nn)] = np.ones(nn).astype(dtype)
61+
print("model type " , model_type)
7762

78-
nn = 128*i
79-
dtype = torch_weights['state_dict']['model.diffusion_model.input_blocks.1.0.in_layers.0.weight'].dtype
80-
torch_weights['state_dict']["zeros_"+str(nn)] = np.zeros(nn).astype(dtype)
81-
torch_weights['state_dict']["ones_"+str(nn)] = np.ones(nn).astype(dtype)
63+
model_shapes = possible_model_shapes[model_type]
64+
ctdict_id = ctdict_ids[model_type]
8265

66+
outfile = TDict(fpath=out_filename)
8367

84-
model_type = get_model_type(torch_weights['state_dict'])
68+
outfile.init_write(ctdict_version=ctdict_id )
8569

86-
if model_type is None:
87-
raise ValueError("The model is not supported. Please make sure it is a valid SD 1.4/1.5 .ckpt file")
70+
for k in model_shapes:
71+
np_arr = torch_weights['state_dict'][k]
72+
shape = list(np_arr.shape)
73+
assert tuple(shape) == tuple(model_shapes[k]), ( "shape mismatch at" , k , shape , SD_SHAPES[k] )
74+
outfile.write_key(key=k , tensor=np_arr)
8875

89-
print("model type " , model_type)
76+
outfile.finish_write()
9077

91-
model_shapes = possible_model_shapes[model_type]
92-
ctdict_id = ctdict_ids[model_type]
9378

94-
outfile = TDict(fpath=out_filename)
79+
def usage():
80+
print("\nConverts .cpkt model files into .tdict model files for Diffusion Bee")
81+
print("\npython3 convert_py input.ckpt output.tdict")
82+
print("\tNormal use.")
83+
print("\n\tPlease report any errors on the Diffusion Bee GitHub project or the official Discord server.")
84+
print("\npython3 convert_py --help")
85+
print("\tDisplays this message")
86+
87+
88+
if __name__ == "__main__":
89+
try:
90+
optlist, args = getopt.getopt(sys.argv[1:], "hu", ["help" ])
91+
except getopt.GetoptError as err:
92+
print(err)
93+
sys.exit(2)
94+
for o, a in optlist:
95+
if o in ("-h", "--help"):
96+
usage()
97+
sys.exit()
98+
else:
99+
assert False, "unhandled option"
100+
101+
if len(args) != 2:
102+
print("Incorrect number of arguments")
103+
usage()
104+
sys.exit(2)
95105

96-
outfile.init_write(ctdict_version=ctdict_id )
106+
checkpoint_filename = args[0]
107+
out_filename = args[1]
97108

98-
for k in model_shapes:
99-
np_arr = torch_weights['state_dict'][k]
100-
shape = list(np_arr.shape)
101-
assert tuple(shape) == tuple(model_shapes[k]), ( "shape mismatch at" , k , shape , SD_SHAPES[k] )
102-
outfile.write_key(key=k , tensor=np_arr)
103109

104-
outfile.finish_write()

backends/stable_diffusion/diffusionbee_backend.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,14 @@
1313
import time
1414
import traceback
1515
from stable_diffusion import StableDiffusion
16-
16+
import os
1717
# b2py t2im {"prompt": "sun glasses" , "W":640 , "H" : 640 , "num_imgs" : 10 , "input_image":"/Users/divamgupta/Downloads/inn.png" , "mask_image" : "/Users/divamgupta/Downloads/maa.png" , "is_inpaint":true }
1818

1919

20+
dir_path = os.path.dirname(os.path.realpath(__file__))
21+
sys.path.append(os.path.join(dir_path , "../model_converter"))
22+
from convert_model import convert_model
23+
2024

2125
from pathlib import Path
2226
import os
@@ -149,7 +153,7 @@ def process_opt(d, generator):
149153

150154

151155

152-
def main():
156+
def diffusion_bee_main():
153157

154158

155159
global p_14 , p_14_np
@@ -205,4 +209,12 @@ def callback(state="" , progress=-1):
205209

206210
if __name__ == "__main__":
207211
multiprocessing.freeze_support() # for pyinstaller
208-
main()
212+
213+
if len(sys.argv) > 1 and sys.argv[1] == 'convert_model':
214+
checkpoint_filename = sys.argv[2]
215+
out_filename = sys.argv[3]
216+
convert_model(checkpoint_filename, out_filename )
217+
print("model converted ")
218+
else:
219+
diffusion_bee_main()
220+

0 commit comments

Comments
 (0)