|
8 | 8 | # pyinstaller convert_model.py --onefile --noconfirm --clean # build using an Intel machine so that it's cross-platform lol
9 | 9 |
|
10 | 10 |
|
11 | | -try: |
12 | | - optlist, args = getopt.getopt(sys.argv[1:], "hu", ["help" ]) |
13 | | -except getopt.GetoptError as err: |
14 | | - print(err) |
15 | | - sys.exit(2) |
16 | | -for o, a in optlist: |
17 | | - if o in ("-h", "--help"): |
18 | | - usage() |
19 | | - sys.exit() |
20 | | - else: |
21 | | - assert False, "unhandled option" |
22 | 11 |
|
23 | | -def usage(): |
24 | | - print("\nConverts .cpkt model files into .tdict model files for Diffusion Bee") |
25 | | - print("\npython3 convert_py input.ckpt output.tdict") |
26 | | - print("\tNormal use.") |
27 | | - print("\n\tPlease report any errors on the Diffusion Bee GitHub project or the official Discord server.") |
28 | | - print("\npython3 convert_py --help") |
29 | | - print("\tDisplays this message") |
| 12 | +from fake_torch import extract_weights_from_checkpoint |
| 13 | +from sd_shapes import get_model_type , possible_model_shapes , ctdict_ids |
| 14 | +from tdict import TDict |
30 | 15 |
|
31 | | -if len(args) != 2: |
32 | | - print("Incorrect number of arguments") |
33 | | - usage() |
34 | | - sys.exit(2) |
35 | 16 |
|
36 | | -checkpoint_filename = args[0] |
37 | | -out_filename = args[1] |
38 | 17 |
|
| 18 | +def convert_model(checkpoint_filename, out_filename ): |
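|  | +    """Convert a Stable Diffusion .ckpt checkpoint into a .tdict model file for Diffusion Bee."""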
| 19 | + torch_weights = extract_weights_from_checkpoint(open(checkpoint_filename, "rb")) |
39 | 20 |
|
40 | | -from fake_torch import extract_weights_from_checkpoint |
41 | | -torch_weights = extract_weights_from_checkpoint(open(checkpoint_filename, "rb")) |
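|  | +    # Inject precomputed constant tensors alongside the checkpoint weights: sinusoidal
|  | +    # timestep-embedding coefficients (fp32 and fp16), the 77x77 causal attention mask,
|  | +    # a small auxiliary output convolution, and the alphas_cumprod noise schedule.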
| 21 | + torch_weights['state_dict']['temb_coefficients_fp32'] = np.array([1.0, 0.94406086, 0.8912509, 0.84139514, 0.7943282, 0.7498942, 0.70794576, 0.6683439, 0.63095737, 0.5956621, 0.56234133, 0.53088444, 0.50118726, 0.47315124, 0.44668362, 0.42169648, 0.39810717, 0.3758374, 0.35481337, 0.33496544, 0.31622776, 0.29853827, 0.2818383, 0.2660725, 0.25118864, 0.23713735, 0.2238721, 0.21134889, 0.19952625, 0.18836491, 0.17782794, 0.16788039, 0.15848932, 0.14962357, 0.14125374, 0.13335215, 0.12589253, 0.11885023, 0.11220184, 0.10592538, 0.099999994, 0.094406076, 0.08912509, 0.0841395, 0.07943282, 0.074989416, 0.07079458, 0.06683439, 0.06309574, 0.05956621, 0.056234125, 0.05308844, 0.05011872, 0.047315124, 0.044668354, 0.04216965, 0.039810725, 0.037583742, 0.035481337, 0.033496536, 0.031622775, 0.029853828, 0.028183822, 0.026607247, 0.025118865, 0.02371374, 0.022387212, 0.021134889, 0.01995262, 0.018836489, 0.017782794, 0.01678804, 0.015848929, 0.014962353, 0.014125377, 0.013335214, 0.012589253, 0.01188502, 0.011220186, 0.010592538, 0.01, 0.009440607, 0.0089125065, 0.008413952, 0.007943282, 0.007498941, 0.007079456, 0.00668344, 0.0063095735, 0.005956621, 0.0056234123, 0.0053088428, 0.005011873, 0.0047315126, 0.0044668354, 0.004216964, 0.003981072, 0.0037583741, 0.0035481334, 0.0033496537, 0.0031622767, 0.0029853827, 0.0028183828, 0.0026607246, 0.0025118857, 0.0023713738, 0.0022387211, 0.0021134887, 0.0019952618, 0.0018836485, 0.0017782794, 0.0016788039, 0.0015848937, 0.001496236, 0.0014125376, 0.0013335207, 0.0012589252, 0.001188502, 0.0011220181, 0.0010592537, 0.0009999999, 0.00094406115, 0.0008912511, 0.0008413952, 0.0007943278, 0.00074989407, 0.0007079456, 0.0006683437, 0.00063095737, 0.0005956621, 0.0005623415, 0.00053088454, 0.0005011872, 0.000473151, 0.00044668352, 0.00042169637, 0.00039810702, 0.0003758374, 0.00035481335, 0.00033496553, 0.00031622782, 0.00029853827, 0.00028183826, 0.00026607246, 0.00025118855, 0.00023713727, 0.00022387199, 0.00021134898, 0.00019952627, 0.00018836492, 0.00017782794, 0.00016788038, 0.00015848929, 0.00014962352, 0.0001412537, 0.00013335208, 0.00012589258, 0.00011885024, 0.00011220186, 0.00010592537]).astype('float32') |
| 22 | + torch_weights['state_dict']['causal_mask'] = np.triu(np.ones((1,1,77,77), dtype=np.float16) * -65500.0, k=1).astype(np.float32) |
| 23 | + torch_weights['state_dict']['aux_output_conv.weight'] = np.array([0.14013671875, 0.0711669921875, -0.03271484375, -0.11407470703125, 0.126220703125, 0.10101318359375, 0.034515380859375, -0.1383056640625, 0.126220703125, 0.07733154296875, 0.042633056640625, -0.177978515625]).astype(np.float32) |
| 24 | + torch_weights['state_dict']['aux_output_conv.bias'] = np.array([0.423828125, 0.471923828125, 0.473876953125]).astype(np.float32) |
| 25 | + torch_weights['state_dict']['alphas_cumprod'] = np.array(_ALPHAS_CUMPROD).astype(np.float32) |
| 26 | + torch_weights['state_dict']['temb_coefficients_fp16'] = np.array( [1.0, 0.944, 0.891, 0.8413, 0.7944, 0.75, 0.708, 0.6685, 0.631, 0.5957, 0.5625, 0.531, 0.501, 0.4731, 0.4468, 0.4216, 0.3982, 0.3757, 0.3547, 0.335, 0.3162, 0.2986, 0.2817, 0.266, 0.2512, 0.2372, 0.2239, 0.2113, 0.1996, 0.1884, 0.1779, 0.1678, 0.1584, 0.1497, 0.1412, 0.1333, 0.1259, 0.11884, 0.1122, 0.1059, 0.1, 0.0944, 0.0891, 0.08417, 0.0794, 0.075, 0.0708, 0.06683, 0.0631, 0.05957, 0.05624, 0.0531, 0.0501, 0.0473, 0.04468, 0.04218, 0.03983, 0.0376, 0.0355, 0.0335, 0.03162, 0.02986, 0.02818, 0.02661, 0.02512, 0.02371, 0.02238, 0.02113, 0.01996, 0.01883, 0.01778, 0.01678, 0.01585, 0.01496, 0.01412, 0.013336, 0.01259, 0.01189, 0.01122, 0.01059, 0.01, 0.00944, 0.00891, 0.008415, 0.00794, 0.0075, 0.00708, 0.006683, 0.00631, 0.005955, 0.005623, 0.00531, 0.005013, 0.00473, 0.004467, 0.004215, 0.003983, 0.003757, 0.003548, 0.00335, 0.003162, 0.002985, 0.00282, 0.00266, 0.002512, 0.00237, 0.00224, 0.002113, 0.001995, 0.0018835, 0.001779, 0.001678, 0.001585, 0.001496, 0.001412, 0.001333, 0.001259, 0.001188, 0.001122, 0.00106, 0.001, 0.000944, 0.000891, 0.0008416, 0.0007944, 0.00075, 0.000708, 0.0006685, 0.000631, 0.0005956, 0.000562, 0.0005307, 0.000501, 0.0004733, 0.0004468, 0.0004218, 0.0003982, 0.0003757, 0.0003548, 0.000335, 0.0003161, 0.0002985, 0.0002818, 0.000266, 0.0002513, 0.0002371, 0.0002239, 0.0002114, 0.0001996, 0.0001884, 0.0001779, 0.0001678, 0.0001585, 0.0001496, 0.0001413, 0.0001334, 0.0001259, 0.00011885, 0.0001122, 0.0001059]).astype('float16') |
42 | 27 |
|
| 28 | + extra_keys = ['temb_coefficients_fp32', 'temb_coefficients_fp16' , 'causal_mask' , 'aux_output_conv.weight' , 'aux_output_conv.bias', 'alphas_cumprod'] |
43 | 29 |
|
44 | | -from sd_shapes import get_model_type , possible_model_shapes , ctdict_ids |
45 | | -from tdict import TDict |
46 | 30 |
|
47 | 31 |
|
| 32 | + for k in (list(torch_weights['state_dict'].keys())): |
| 33 | + if '.norm' in k and '.bias' in k: |
| 34 | + k2 = k.replace(".bias" , ".weight") |
| 35 | + k3 = k.replace(".bias" , ".bias_by_weight") |
| 36 | + torch_weights['state_dict'][k3] = torch_weights['state_dict'][k]/torch_weights['state_dict'][k2] |
48 | 37 |
|
49 | | -torch_weights['state_dict']['temb_coefficients_fp32'] = np.array([1.0, 0.94406086, 0.8912509, 0.84139514, 0.7943282, 0.7498942, 0.70794576, 0.6683439, 0.63095737, 0.5956621, 0.56234133, 0.53088444, 0.50118726, 0.47315124, 0.44668362, 0.42169648, 0.39810717, 0.3758374, 0.35481337, 0.33496544, 0.31622776, 0.29853827, 0.2818383, 0.2660725, 0.25118864, 0.23713735, 0.2238721, 0.21134889, 0.19952625, 0.18836491, 0.17782794, 0.16788039, 0.15848932, 0.14962357, 0.14125374, 0.13335215, 0.12589253, 0.11885023, 0.11220184, 0.10592538, 0.099999994, 0.094406076, 0.08912509, 0.0841395, 0.07943282, 0.074989416, 0.07079458, 0.06683439, 0.06309574, 0.05956621, 0.056234125, 0.05308844, 0.05011872, 0.047315124, 0.044668354, 0.04216965, 0.039810725, 0.037583742, 0.035481337, 0.033496536, 0.031622775, 0.029853828, 0.028183822, 0.026607247, 0.025118865, 0.02371374, 0.022387212, 0.021134889, 0.01995262, 0.018836489, 0.017782794, 0.01678804, 0.015848929, 0.014962353, 0.014125377, 0.013335214, 0.012589253, 0.01188502, 0.011220186, 0.010592538, 0.01, 0.009440607, 0.0089125065, 0.008413952, 0.007943282, 0.007498941, 0.007079456, 0.00668344, 0.0063095735, 0.005956621, 0.0056234123, 0.0053088428, 0.005011873, 0.0047315126, 0.0044668354, 0.004216964, 0.003981072, 0.0037583741, 0.0035481334, 0.0033496537, 0.0031622767, 0.0029853827, 0.0028183828, 0.0026607246, 0.0025118857, 0.0023713738, 0.0022387211, 0.0021134887, 0.0019952618, 0.0018836485, 0.0017782794, 0.0016788039, 0.0015848937, 0.001496236, 0.0014125376, 0.0013335207, 0.0012589252, 0.001188502, 0.0011220181, 0.0010592537, 0.0009999999, 0.00094406115, 0.0008912511, 0.0008413952, 0.0007943278, 0.00074989407, 0.0007079456, 0.0006683437, 0.00063095737, 0.0005956621, 0.0005623415, 0.00053088454, 0.0005011872, 0.000473151, 0.00044668352, 0.00042169637, 0.00039810702, 0.0003758374, 0.00035481335, 0.00033496553, 0.00031622782, 0.00029853827, 0.00028183826, 0.00026607246, 0.00025118855, 0.00023713727, 0.00022387199, 0.00021134898, 0.00019952627, 0.00018836492, 0.00017782794, 0.00016788038, 0.00015848929, 0.00014962352, 0.0001412537, 0.00013335208, 0.00012589258, 0.00011885024, 0.00011220186, 0.00010592537]).astype('float32') |
50 | | -torch_weights['state_dict']['causal_mask'] = np.triu(np.ones((1,1,77,77), dtype=np.float16) * -65500.0, k=1).astype(np.float32) |
51 | | -torch_weights['state_dict']['aux_output_conv.weight'] = np.array([0.14013671875, 0.0711669921875, -0.03271484375, -0.11407470703125, 0.126220703125, 0.10101318359375, 0.034515380859375, -0.1383056640625, 0.126220703125, 0.07733154296875, 0.042633056640625, -0.177978515625]).astype(np.float32) |
52 | | -torch_weights['state_dict']['aux_output_conv.bias'] = np.array([0.423828125, 0.471923828125, 0.473876953125]).astype(np.float32) |
53 | | -torch_weights['state_dict']['alphas_cumprod'] = np.array(_ALPHAS_CUMPROD).astype(np.float32) |
54 | | -torch_weights['state_dict']['temb_coefficients_fp16'] = np.array( [1.0, 0.944, 0.891, 0.8413, 0.7944, 0.75, 0.708, 0.6685, 0.631, 0.5957, 0.5625, 0.531, 0.501, 0.4731, 0.4468, 0.4216, 0.3982, 0.3757, 0.3547, 0.335, 0.3162, 0.2986, 0.2817, 0.266, 0.2512, 0.2372, 0.2239, 0.2113, 0.1996, 0.1884, 0.1779, 0.1678, 0.1584, 0.1497, 0.1412, 0.1333, 0.1259, 0.11884, 0.1122, 0.1059, 0.1, 0.0944, 0.0891, 0.08417, 0.0794, 0.075, 0.0708, 0.06683, 0.0631, 0.05957, 0.05624, 0.0531, 0.0501, 0.0473, 0.04468, 0.04218, 0.03983, 0.0376, 0.0355, 0.0335, 0.03162, 0.02986, 0.02818, 0.02661, 0.02512, 0.02371, 0.02238, 0.02113, 0.01996, 0.01883, 0.01778, 0.01678, 0.01585, 0.01496, 0.01412, 0.013336, 0.01259, 0.01189, 0.01122, 0.01059, 0.01, 0.00944, 0.00891, 0.008415, 0.00794, 0.0075, 0.00708, 0.006683, 0.00631, 0.005955, 0.005623, 0.00531, 0.005013, 0.00473, 0.004467, 0.004215, 0.003983, 0.003757, 0.003548, 0.00335, 0.003162, 0.002985, 0.00282, 0.00266, 0.002512, 0.00237, 0.00224, 0.002113, 0.001995, 0.0018835, 0.001779, 0.001678, 0.001585, 0.001496, 0.001412, 0.001333, 0.001259, 0.001188, 0.001122, 0.00106, 0.001, 0.000944, 0.000891, 0.0008416, 0.0007944, 0.00075, 0.000708, 0.0006685, 0.000631, 0.0005956, 0.000562, 0.0005307, 0.000501, 0.0004733, 0.0004468, 0.0004218, 0.0003982, 0.0003757, 0.0003548, 0.000335, 0.0003161, 0.0002985, 0.0002818, 0.000266, 0.0002513, 0.0002371, 0.0002239, 0.0002114, 0.0001996, 0.0001884, 0.0001779, 0.0001678, 0.0001585, 0.0001496, 0.0001413, 0.0001334, 0.0001259, 0.00011885, 0.0001122, 0.0001059]).astype('float16') |
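|  | +        # Also store each feed-forward (".ff.") tensor split into two halves along its first
|  | +        # axis, presumably so the runtime can apply the two GEGLU projections separately.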
| 38 | + if ".ff." in k: |
| 39 | + pp = torch_weights['state_dict'][k] |
| 40 | + torch_weights['state_dict'][k + "._split_1"] = pp[:pp.shape[0]//2].copy() |
| 41 | + torch_weights['state_dict'][k + "._split_2"] = pp[pp.shape[0]//2:].copy() |
55 | 42 |
|
56 | | -extra_keys = ['temb_coefficients_fp32', 'temb_coefficients_fp16' , 'causal_mask' , 'aux_output_conv.weight' , 'aux_output_conv.bias', 'alphas_cumprod'] |
57 | 43 |
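|  | +    # Constant zero and one vectors in multiples of 320 and 128, cast to the same dtype as the UNet weights.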
|
| 44 | + for i in range(1,21): |
| 45 | + nn = 320*i |
| 46 | + dtype = torch_weights['state_dict']['model.diffusion_model.input_blocks.1.0.in_layers.0.weight'].dtype |
| 47 | + torch_weights['state_dict']["zeros_"+str(nn)] = np.zeros(nn).astype(dtype) |
| 48 | + torch_weights['state_dict']["ones_"+str(nn)] = np.ones(nn).astype(dtype) |
58 | 49 |
|
| 50 | + nn = 128*i |
| 51 | + dtype = torch_weights['state_dict']['model.diffusion_model.input_blocks.1.0.in_layers.0.weight'].dtype |
| 52 | + torch_weights['state_dict']["zeros_"+str(nn)] = np.zeros(nn).astype(dtype) |
| 53 | + torch_weights['state_dict']["ones_"+str(nn)] = np.ones(nn).astype(dtype) |
59 | 54 |
|
60 | | -for k in (list(torch_weights['state_dict'].keys())): |
61 | | - if '.norm' in k and '.bias' in k: |
62 | | - k2 = k.replace(".bias" , ".weight") |
63 | | - k3 = k.replace(".bias" , ".bias_by_weight") |
64 | | - torch_weights['state_dict'][k3] = torch_weights['state_dict'][k]/torch_weights['state_dict'][k2] |
65 | 55 |
|
66 | | - if ".ff." in k: |
67 | | - pp = torch_weights['state_dict'][k] |
68 | | - torch_weights['state_dict'][k + "._split_1"] = pp[:pp.shape[0]//2].copy() |
69 | | - torch_weights['state_dict'][k + "._split_2"] = pp[pp.shape[0]//2:].copy() |
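|  | +    # Identify which supported model variant this checkpoint matches; abort if it is not recognized.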
| 56 | + model_type = get_model_type(torch_weights['state_dict']) |
70 | 57 |
|
| 58 | + if model_type is None: |
| 59 | + raise ValueError("The model is not supported. Please make sure it is a valid SD 1.4/1.5 .ckpt file") |
71 | 60 |
|
72 | | -for i in range(1,21): |
73 | | - nn = 320*i |
74 | | - dtype = torch_weights['state_dict']['model.diffusion_model.input_blocks.1.0.in_layers.0.weight'].dtype |
75 | | - torch_weights['state_dict']["zeros_"+str(nn)] = np.zeros(nn).astype(dtype) |
76 | | - torch_weights['state_dict']["ones_"+str(nn)] = np.ones(nn).astype(dtype) |
| 61 | +    print("model type:", model_type)
77 | 62 |
|
78 | | - nn = 128*i |
79 | | - dtype = torch_weights['state_dict']['model.diffusion_model.input_blocks.1.0.in_layers.0.weight'].dtype |
80 | | - torch_weights['state_dict']["zeros_"+str(nn)] = np.zeros(nn).astype(dtype) |
81 | | - torch_weights['state_dict']["ones_"+str(nn)] = np.ones(nn).astype(dtype) |
| 63 | + model_shapes = possible_model_shapes[model_type] |
| 64 | + ctdict_id = ctdict_ids[model_type] |
82 | 65 |
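|  | +    # Stream every expected tensor into the .tdict output, verifying its shape along the way.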
|
| 66 | + outfile = TDict(fpath=out_filename) |
83 | 67 |
|
84 | | -model_type = get_model_type(torch_weights['state_dict']) |
| 68 | + outfile.init_write(ctdict_version=ctdict_id ) |
85 | 69 |
|
86 | | -if model_type is None: |
87 | | - raise ValueError("The model is not supported. Please make sure it is a valid SD 1.4/1.5 .ckpt file") |
| 70 | + for k in model_shapes: |
| 71 | + np_arr = torch_weights['state_dict'][k] |
| 72 | + shape = list(np_arr.shape) |
| 73 | +        assert tuple(shape) == tuple(model_shapes[k]), ("shape mismatch at", k, shape, model_shapes[k])
| 74 | +        outfile.write_key(key=k, tensor=np_arr)
88 | 75 |
|
89 | | -print("model type " , model_type) |
| 76 | + outfile.finish_write() |
90 | 77 |
|
91 | | -model_shapes = possible_model_shapes[model_type] |
92 | | -ctdict_id = ctdict_ids[model_type] |
93 | 78 |
|
94 | | -outfile = TDict(fpath=out_filename) |
| 79 | +def usage(): |
| 80 | + print("\nConverts .ckpt model files into .tdict model files for Diffusion Bee")
| 81 | + print("\npython3 convert_model.py input.ckpt output.tdict")
| 82 | + print("\tNormal use.") |
| 83 | + print("\n\tPlease report any errors on the Diffusion Bee GitHub project or the official Discord server.") |
| 84 | + print("\npython3 convert_model.py --help")
| 85 | + print("\tDisplays this message") |
| 86 | + |
| 87 | + |
| 88 | +if __name__ == "__main__": |
| 89 | + try: |
| 90 | + optlist, args = getopt.getopt(sys.argv[1:], "hu", ["help" ]) |
| 91 | + except getopt.GetoptError as err: |
| 92 | + print(err) |
| 93 | + sys.exit(2) |
| 94 | + for o, a in optlist: |
| 95 | + if o in ("-h", "--help"): |
| 96 | + usage() |
| 97 | + sys.exit() |
| 98 | + else: |
| 99 | + assert False, "unhandled option" |
| 100 | + |
| 101 | + if len(args) != 2: |
| 102 | + print("Incorrect number of arguments") |
| 103 | + usage() |
| 104 | + sys.exit(2) |
95 | 105 |
|
96 | | -outfile.init_write(ctdict_version=ctdict_id ) |
| 106 | +    checkpoint_filename = args[0]
| 107 | +    out_filename = args[1]
|  | +    convert_model(checkpoint_filename, out_filename)
97 | 108 |
|
98 | | -for k in model_shapes: |
99 | | - np_arr = torch_weights['state_dict'][k] |
100 | | - shape = list(np_arr.shape) |
101 | | - assert tuple(shape) == tuple(model_shapes[k]), ( "shape mismatch at" , k , shape , SD_SHAPES[k] ) |
102 | | - outfile.write_key(key=k , tensor=np_arr) |
103 | 109 |
|
104 | | -outfile.finish_write() |
|