Skip to content

Commit 5f90f8b

Browse files
committed
added safetensors support
1 parent 6d3a5e3 commit 5f90f8b

File tree

8 files changed

+121
-24
lines changed

8 files changed

+121
-24
lines changed

backends/model_converter/convert_model.py

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,15 +8,21 @@
88
# pyinstaller convert_model.py --onefile --noconfirm --clean # build using intel machine so that its cross platform lol
99

1010

11-
11+
from safetensor_wrapper import SafetensorWrapper
1212
from fake_torch import extract_weights_from_checkpoint
1313
from sd_shapes import get_model_type , possible_model_shapes , ctdict_ids
1414
from tdict import TDict
1515

1616

1717

1818
def convert_model(checkpoint_filename, out_filename ):
19-
torch_weights = extract_weights_from_checkpoint(open(checkpoint_filename, "rb"))
19+
20+
if checkpoint_filename.lower().endswith(".ckpt"):
21+
torch_weights = extract_weights_from_checkpoint(open(checkpoint_filename, "rb"))
22+
elif checkpoint_filename.lower().endswith(".safetensors"):
23+
torch_weights = SafetensorWrapper(checkpoint_filename)
24+
else:
25+
raise ValueError("Invalid import format")
2026

2127
if 'state_dict' in torch_weights:
2228
state_dict = torch_weights['state_dict']
@@ -65,6 +71,13 @@ def convert_model(checkpoint_filename, out_filename ):
6571
if model_type is None:
6672
raise ValueError("The model is not supported. Please make sure it is a valid SD 1.4/1.5 .ckpt file")
6773

74+
if "float16" in model_type:
75+
cur_dtype = "float16"
76+
elif "float32" in model_type:
77+
cur_dtype = "float32"
78+
else:
79+
assert False
80+
6881
print("model type " , model_type)
6982

7083
model_shapes = possible_model_shapes[model_type]
@@ -76,6 +89,8 @@ def convert_model(checkpoint_filename, out_filename ):
7689

7790
for k in model_shapes:
7891
np_arr = state_dict[k]
92+
if "float" in str(np_arr.dtype):
93+
np_arr = np_arr.astype(cur_dtype)
7994
shape = list(np_arr.shape)
8095
assert tuple(shape) == tuple(model_shapes[k]), ( "shape mismatch at" , k , shape , SD_SHAPES[k] )
8196
outfile.write_key(key=k , tensor=np_arr)
@@ -113,4 +128,6 @@ def usage():
113128
checkpoint_filename = args[0]
114129
out_filename = args[1]
115130

131+
convert_model(checkpoint_filename , out_filename )
132+
116133

backends/model_converter/fake_torch.py

Lines changed: 31 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -140,13 +140,17 @@ def examine_pickle(fb0, return_special=False):
140140
## 3: this massive line also assigns values to keys, but does so differently
141141
## _var2262.update({ 'cond_stage_model.transformer.text_model.encoder.layers.3.layer_norm2.bias': _var2001, [ .... and on and on ]})
142142
##
143+
## 4: in some pruned models, the last line is instead a combination of 2/3 into the final variable:
144+
## result = {'model.diffusion_model.input_blocks.0.0.weight': _var1, 'model.diffusion_model.input_blocks.0.0.bias': _var3, }
145+
##
143146
## that's it
144147

145148
# make some REs to match the above.
146149
re_rebuild = re.compile('^_var\d+ = _rebuild_tensor_v2\(UNPICKLER\.persistent_load\(\(.*\)$')
147150
re_assign = re.compile('^_var\d+ = \{.*\}$')
148151
re_update = re.compile('^_var\d+\.update\(\{.*\}\)$')
149152
re_ordered_dict = re.compile('^_var\d+ = OrderedDict\(\)$')
153+
re_result = re.compile('^result = \{.*\}$')
150154

151155
load_instructions = {}
152156
assign_instructions = AssignInstructions()
@@ -157,7 +161,7 @@ def examine_pickle(fb0, return_special=False):
157161
if re_rebuild.match(line):
158162
variable_name, load_instruction = line.split(' = ', 1)
159163
load_instructions[variable_name] = LoadInstruction(line, variable_name)
160-
elif re_assign.match(line):
164+
elif re_assign.match(line) or re_result.match(line):
161165
assign_instructions.parse_assign_line(line)
162166
elif re_update.match(line):
163167
assign_instructions.parse_update_line(line)
@@ -184,11 +188,34 @@ def __init__(self, collect_special=False):
184188
self.integrated_instructions = {}
185189
self.collect_special = collect_special;
186190

191+
def parse_result_line(self, line):
    """Parse a fickling output line of the form  result = {'key': _varN, ...}
    and record each key -> _varN mapping as an assignment instruction.

    Same payload format as parse_assign_line, except the left-hand side is
    the literal name `result` rather than a `_varNNNN` variable.
    NOTE(review): examine_pickle currently routes `result = {...}` lines to
    parse_assign_line via re_result — confirm which entry point is intended.
    """
    garbage, huge_mess = line.split(' = {', 1)
    assignments = huge_mess.split(', ')
    del huge_mess  # this string can be huge; release it early
    # Last item carries the dict's closing brace; strip it off.
    assignments[-1] = assignments[-1].strip('}')

    # Compile the RE once here to avoid doing it every loop iteration.
    # Raw string: '\d' in a plain literal is an invalid escape (warning
    # on Python 3.12+).
    re_var = re.compile(r'^_var\d+$')

    assignment_count = 0
    for a in assignments:
        if self._add_assignment(a, re_var):
            assignment_count = assignment_count + 1
    if NO_PICKLE_DEBUG:
        print(f"Added/merged {assignment_count} assignments. Total of {len(self.instructions)} assignment instructions")
206+
187207
def parse_assign_line(self, line):
188208
# input looks like this:
189209
# _var2262 = {'model.diffusion_model.input_blocks.0.0.weight': _var1, 'model.diffusion_model.input_blocks.0.0.bias': _var3,\
190210
# ...\
191211
# 'cond_stage_model.transformer.text_model.encoder.layers.3.layer_norm2.weight': _var1999}
212+
213+
# input looks like the above, but with 'result' in place of _var2262:
214+
# result = {'model.diffusion_model.input_blocks.0.0.weight': _var1, ... }
215+
#
216+
# or also look like:
217+
# result = {'state_dict': _var2314}
218+
# ... which will be ignored later
192219
garbage, huge_mess = line.split(' = {', 1)
193220
assignments = huge_mess.split(', ')
194221
del huge_mess
@@ -211,7 +238,7 @@ def _add_assignment(self, assignment, re_var):
211238
# 'embedding_manager.embedder.transformer.text_model.encoder.layers.6.mlp.fc1': {'version': 1}
212239
sd_key, fickling_var = assignment.split(': ', 1)
213240
sd_key = sd_key.strip("'")
214-
if re_var.match(fickling_var):
241+
if sd_key != 'state_dict' and re_var.match(fickling_var):
215242
self.instructions[sd_key] = fickling_var
216243
return True
217244
elif self.collect_special:
@@ -225,7 +252,8 @@ def _add_assignment(self, assignment, re_var):
225252
v = v.strip("'")
226253
special_dict[k] = v
227254
self.special_instructions[sd_key] = special_dict
228-
return False
255+
256+
return False
229257

230258
def integrate(self, load_instructions):
231259
unfound_keys = {}
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
from safetensors import safe_open
2+
3+
4+
class SafetensorWrapper:
    """Dict-like, lazy view of a .safetensors file, mimicking a state-dict.

    Tensors are read on demand via safetensors.safe_open (framework="np",
    so get_tensor returns numpy arrays). Items stored with __setitem__ go
    into an in-memory overlay (`new_items`) — the file itself is never
    modified — and shadow same-named tensors from the file.
    """

    def __init__(self, fname):
        self.file = safe_open(fname, framework="np", device="cpu")
        self.new_items = {}  # overlay for keys added/overwritten after load

    def keys(self):
        # File keys first (in file order), then overlay-only keys.
        # De-duplicated: previously a key overwritten via __setitem__ was
        # reported twice, so iteration visited it twice.
        ordered = dict.fromkeys(self.file.keys())
        ordered.update(dict.fromkeys(self.new_items.keys()))
        return list(ordered)

    def __contains__(self, k):
        # Check the cheap in-memory overlay before scanning file keys.
        return k in self.new_items or k in self.file.keys()

    def __getitem__(self, k):
        # Overlay wins over the on-disk tensor.
        if k in self.new_items:
            return self.new_items[k]
        return self.file.get_tensor(k)

    def __setitem__(self, key, item):
        self.new_items[key] = item

    def __iter__(self):
        return iter(self.keys())

backends/model_converter/sd_shapes.py

Lines changed: 36 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11

22
from sd_shapes_consts import shapes_unet , shapes_encoder, shapes_decoder , shapes_text_encoder, shapes_params
33
import copy
4-
4+
from collections import Counter
55

66
def add_aux_shapes(d):
77
for k in list(d.keys()):
@@ -57,7 +57,7 @@ def add_aux_shapes(d):
5757

5858

5959

60-
def are_shapes_matching(state_dict , template_shapes):
60+
def are_shapes_matching(state_dict , template_shapes , name=None):
6161
for k in template_shapes:
6262
if k not in state_dict:
6363
print("key", k , "not found in state_dict" , state_dict.keys())
@@ -68,28 +68,50 @@ def are_shapes_matching(state_dict , template_shapes):
6868

6969
return True
7070

71-
def are_shapes_dtype(state_dict, template_shapes , dtype):
71+
72+
def get_dtype(state_dict, template_shapes):
    """Return the most common float dtype name (e.g. "float32") among the
    model tensors of ``state_dict``.

    Keys listed in the module-level ``extra_keys``, and keys not present in
    ``template_shapes``, are ignored. Raises ValueError when no float tensor
    is found at all (previously this crashed with an opaque IndexError on
    ``c.most_common(1)[0]``).
    """
    c = Counter()

    for k in state_dict:
        if k in extra_keys:
            continue
        if k not in template_shapes:
            continue

        # Only float tensors vote; integer/bool buffers are irrelevant here.
        if 'float' in str(state_dict[k].dtype):
            c[str(state_dict[k].dtype)] += 1

    print(c.most_common())  # debug aid: show the dtype histogram
    if not c:
        raise ValueError("No float weights found; cannot determine the model dtype")
    return c.most_common(1)[0][0]
85+
86+
87+
88+
def check_shapes_float(state_dict, template_shapes):
    """Verify every model tensor in ``state_dict`` has a floating-point dtype.

    Keys in the module-level ``extra_keys``, or absent from
    ``template_shapes``, are skipped. Raises ValueError naming the offending
    key and dtype — the previous bare ``assert`` carried no message and is
    silently stripped when Python runs with ``-O``.
    """
    for k in state_dict:
        if k in extra_keys:
            continue
        if k not in template_shapes:
            continue

        if 'float' not in str(state_dict[k].dtype):
            raise ValueError(
                "Weight %r has non-float dtype %s; expected float16/float32"
                % (k, state_dict[k].dtype)
            )
8197

8298

8399
def get_model_type(state_dict):
    """Identify which supported SD variant (and float precision) the weights
    in ``state_dict`` correspond to.

    Returns a string like "SD_1x_float32" / "SD_1x_inpaint_float16",
    or None when the tensor shapes match no known template.
    Raises ValueError if the dominant float dtype is neither float32
    nor float16 (e.g. float64 weights).
    """

    # Match against the known shape templates first; shape decides the
    # model family, dtype decides the precision suffix.
    if are_shapes_matching(state_dict , sd_1x_shapes) :
        shapes = sd_1x_shapes
        mname = "SD_1x"
    elif are_shapes_matching(state_dict , sd_1x_inpaint_shapes) :
        shapes = sd_1x_inpaint_shapes
        mname = "SD_1x_inpaint"
    else:
        return None

    # All model tensors must be floats; then take the majority float dtype
    # (mixed-precision checkpoints are resolved by get_dtype's vote).
    check_shapes_float(state_dict , shapes)
    c_dtype = get_dtype(state_dict , shapes)
    if c_dtype not in ["float32" , "float16"]:
        raise ValueError("The weights should either be float32 or float16, but these are " + c_dtype)

    return mname + "_" + c_dtype
95117

File renamed without changes.

electron_app/package.json

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
{
22
"name": "DiffusionBee",
3-
"version": "1.6.0",
4-
"build_number": "0017",
3+
"version": "1.7.0",
4+
"build_number": "0018",
55
"website": "https://diffusionbee.com",
66
"description": "Diffusion Bee - Stable Diffusion App.",
77
"is_dev": false,

electron_app/src/components/Settings.vue

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ export default {
8181
add_model(){
8282
let that = this;
8383
84-
let pytorch_model_path = window.ipcRenderer.sendSync('file_dialog', "ckpt_file" );
84+
let pytorch_model_path = window.ipcRenderer.sendSync('file_dialog', "weights_file" );
8585
if(!pytorch_model_path)
8686
return;
8787

electron_app/src/native_functions.js

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -61,10 +61,10 @@ ipcMain.on('file_dialog', (event, arg) => {
6161
properties = ['openFile' ]
6262
options = { filters :[ {name: 'Images', extensions: ['jpg', 'jpeg', 'png', 'bmp']}] , properties: properties } ;
6363
}
64-
else if(arg == 'ckpt_file') // single image file
64+
else if(arg == 'weights_file') // single model-weights file (.ckpt / .safetensors)
6565
{
6666
properties = ['openFile' ]
67-
options = { filters :[ {name: 'Checkpoints', extensions: ['ckpt']}] , properties: properties } ;
67+
options = { filters :[ {name: 'Checkpoints', extensions: ['ckpt' , 'safetensors' ]}] , properties: properties } ;
6868
}
6969
else if(arg == 'img_files') // multi image files
7070
{

0 commit comments

Comments
 (0)