else:
    from astunparse import unparse

-NO_PICKLE_DEBUG = True
+NO_PICKLE_DEBUG = False

def extract_weights_from_checkpoint(fb0):
    torch_weights = {}
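For orientation, the decompile step these changes build on: fickling parses the pickle program and astunparse turns its AST back into Python-like source. A minimal sketch, not part of the diff, assuming fickling's `Pickled` class is importable from `fickling.pickle` and that `archive/data.pkl` is a hypothetical path already extracted from a checkpoint zip:

```python
# Sketch only: decompile a checkpoint's data.pkl into parseable text.
from fickling.pickle import Pickled   # import path is an assumption
from astunparse import unparse

with open('archive/data.pkl', 'rb') as fb0:   # hypothetical path
    decompiled = unparse(Pickled.load(fb0).ast).splitlines()
print(decompiled[:3])  # first few lines of the reconstructed pickle program
```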
@@ -22,32 +22,40 @@ def extract_weights_from_checkpoint(fb0):
        raise ValueError("Looks like the checkpoints file is in the wrong format")
    folder_name = folder_name[0].replace("/data.pkl", "").replace("\\data.pkl", "")
    with myzip.open(folder_name + '/data.pkl') as myfile:
-        load_instructions, special_instructions = examine_pickle(myfile)
+        load_instructions = examine_pickle(myfile)
    for sd_key, load_instruction in load_instructions.items():
        with myzip.open(folder_name + f'/data/{load_instruction.obj_key}') as myfile:
            if (load_instruction.load_from_file_buffer(myfile)):
                torch_weights['state_dict'][sd_key] = load_instruction.get_data()
-    if len(special_instructions) > 0:
-        torch_weights['state_dict']['_metadata'] = {}
-        for sd_key, special in special_instructions.items():
-            torch_weights['state_dict']['_metadata'][sd_key] = special
+    # if len(special_instructions) > 0:
+    #     torch_weights['state_dict']['_metadata'] = {}
+    #     for sd_key, special in special_instructions.items():
+    #         torch_weights['state_dict']['_metadata'][sd_key] = special
    return torch_weights

-def examine_pickle(fb0):
+def examine_pickle(fb0, return_special=False):
+    ## return_special:
+    ## A rabbit hole I chased trying to debug a model that wouldn't import and had 1300 useless metadata statements.
+    ## If for some reason it's needed in the future, turn it on. It is passed into the AssignInstructions class, and
+    ## if turned on, collect_special will be True.
+    ##

+    # turn the pickle file into text we can parse
    decompiled = unparse(Pickled.load(fb0).ast).splitlines()

-    ## LINES WE CARE ABOUT:
-    ## 1: this defines a data file and what kind of data is in it
-    ## _var1 = _rebuild_tensor_v2(UNPICKLER.persistent_load(('storage', HalfStorage, '0', 'cpu', 11520)), 0, (320, 4, 3, 3), (36, 9, 3, 1), False, _var0)
-    ##
-    ## 2: this massive line assigns the previous data to dictionary entries
-    ## _var2262 = {'model.diffusion_model.input_blocks.0.0.weight': _var1, [..... continues forever]}
-    ##
-    ## 3: this massive line also assigns values to keys, but does so differently
-    ## _var2262.update({'cond_stage_model.transformer.text_model.encoder.layers.3.layer_norm2.bias': _var2001, [.... and on and on]})
-    ##
-    ## that's it
+    ## Parsing the decompiled pickle:
+    ## LINES WE CARE ABOUT:
+    ## 1: this defines a data file and what kind of data is in it
+    ## _var1 = _rebuild_tensor_v2(UNPICKLER.persistent_load(('storage', HalfStorage, '0', 'cpu', 11520)), 0, (320, 4, 3, 3), (36, 9, 3, 1), False, _var0)
+    ##
+    ## 2: this massive line assigns the previous data to dictionary entries
+    ## _var2262 = {'model.diffusion_model.input_blocks.0.0.weight': _var1, [..... continues forever]}
+    ##
+    ## 3: this massive line also assigns values to keys, but does so differently
+    ## _var2262.update({'cond_stage_model.transformer.text_model.encoder.layers.3.layer_norm2.bias': _var2001, [.... and on and on]})
+    ##
+    ## that's it
+
    # make some REs to match the above.
    re_rebuild = re.compile('^_var\d+ = _rebuild_tensor_v2\(UNPICKLER\.persistent_load\(\(.*\)$')
    re_assign = re.compile('^_var\d+ = \{.*\}$')
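As a quick sanity check of those patterns, here is a small standalone sketch that runs the same regexes against the three sample lines from the comment block above. The `re_update` pattern is not shown in this hunk, so the one below is a guess:

```python
import re

re_rebuild = re.compile(r'^_var\d+ = _rebuild_tensor_v2\(UNPICKLER\.persistent_load\(\(.*\)$')
re_assign = re.compile(r'^_var\d+ = \{.*\}$')
re_update = re.compile(r'^_var\d+\.update\(\{.*\}\)$')  # guessed; not shown in this hunk

samples = [
    "_var1 = _rebuild_tensor_v2(UNPICKLER.persistent_load(('storage', HalfStorage, '0', 'cpu', 11520)), 0, (320, 4, 3, 3), (36, 9, 3, 1), False, _var0)",
    "_var2262 = {'model.diffusion_model.input_blocks.0.0.weight': _var1}",
    "_var2262.update({'cond_stage_model.transformer.text_model.encoder.layers.3.layer_norm2.bias': _var2001})",
]
for line in samples:
    if re_rebuild.match(line):   print('rebuild:', line[:40])
    elif re_assign.match(line):  print('assign: ', line[:40])
    elif re_update.match(line):  print('update: ', line[:40])
```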
@@ -62,15 +70,15 @@ def examine_pickle(fb0):
        line = line.strip()
        if re_rebuild.match(line):
            variable_name, load_instruction = line.split(' = ', 1)
-            load_instructions[variable_name] = LoadInstruction(line)
+            load_instructions[variable_name] = LoadInstruction(line, variable_name)
        elif re_assign.match(line):
            assign_instructions.parse_assign_line(line)
        elif re_update.match(line):
            assign_instructions.parse_update_line(line)
        elif re_ordered_dict.match(line):
            #do nothing
            continue
-        else:
+        elif NO_PICKLE_DEBUG:
            print(f'unmatched line: {line}')

@@ -79,14 +87,16 @@ def examine_pickle(fb0):

    assign_instructions.integrate(load_instructions)

-    return assign_instructions.integrated_instructions, assign_instructions.special_instructions
-    #return assign_instructions.integrated_instructions, {}
+    if return_special:
+        return assign_instructions.integrated_instructions, assign_instructions.special_instructions
+    return assign_instructions.integrated_instructions

class AssignInstructions:
-    def __init__(self):
+    def __init__(self, collect_special=False):
        self.instructions = {}
        self.special_instructions = {}
        self.integrated_instructions = {}
+        self.collect_special = collect_special

    def parse_assign_line(self, line):
        # input looks like this:
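The call convention after this hunk: callers get a single dict by default, and a tuple only when they opt in. A toy illustration of the pattern with placeholder values, not the real function:

```python
def examine(return_special=False):
    integrated = {'model.w': '<LoadInstruction>'}
    special = {'model': {'version': 1}}
    if return_special:
        return integrated, special
    return integrated

loads = examine()                              # normal path: one dict
loads, special = examine(return_special=True)  # debugging path: tuple
```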
@@ -115,7 +125,7 @@ def _add_assignment(self, assignment, re_var):
        if re_var.match(fickling_var):
            self.instructions[sd_key] = fickling_var
            return True
-        else:
+        elif self.collect_special:
            # now convert the string "{'version': 1}" into a dictionary {'version': 1}
            entries = fickling_var.split(',')
            special_dict = {}
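The branch above converts strings like "{'version': 1}" by hand with split(). For comparison only (this is not what the commit does), Python's `ast.literal_eval` parses the same literal safely and is a common alternative for this kind of string-to-dict conversion:

```python
import ast

fickling_var = "{'version': 1}"
special_dict = ast.literal_eval(fickling_var)  # literals only, no code execution
assert special_dict == {'version': 1}
```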
@@ -133,14 +143,9 @@ def integrate(self, load_instructions):
        for sd_key, fickling_var in self.instructions.items():
            if fickling_var in load_instructions:
                self.integrated_instructions[sd_key] = load_instructions[fickling_var]
-                if sd_key in self.special_instructions:
-                    if NO_PICKLE_DEBUG:
-                        print(f"Key found in both load and special instructions: {sd_key}")
            else:
-                unfound_keys[sd_key] = True;
-        #for sd_key, special in self.special_instructions.items():
-        #    if sd_key in unfound_keys:
-        #        #todo
+                if NO_PICKLE_DEBUG:
+                    print(f"no load instruction found for {sd_key}")

        if NO_PICKLE_DEBUG:
            print(f"Have {len(self.integrated_instructions)} integrated load/assignment instructions")
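integrate() is in effect a join on the fickling variable name between the assignment map and the load map. A runnable toy version of the new behavior, with placeholder values standing in for LoadInstruction objects:

```python
NO_PICKLE_DEBUG = False
loads = {'_var1': '<LoadInstruction for data file 0>'}
assigns = {
    'model.diffusion_model.input_blocks.0.0.weight': '_var1',
    'some.orphaned.key': '_var999',   # no matching load instruction
}
integrated = {}
for sd_key, var in assigns.items():
    if var in loads:
        integrated[sd_key] = loads[var]
    elif NO_PICKLE_DEBUG:
        print(f"no load instruction found for {sd_key}")
print(len(integrated))  # 1
```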
@@ -164,14 +169,16 @@ def parse_update_line(self, line):
            print(f"Added/merged {update_count} updates. Total of {len(self.instructions)} assignment instructions")

class LoadInstruction:
-    def __init__(self, instruction_string):
+    def __init__(self, instruction_string, variable_name, extra_debugging=False):
        self.ident = False
        self.storage_type = False
        self.obj_key = False
        self.location = False #unused
        self.obj_size = False
        self.stride = False #unused
-        self.data = False;
+        self.data = False
+        self.variable_name = variable_name
+        self.extra_debugging = extra_debugging
        self.parse_instruction(instruction_string)

    def parse_instruction(self, instruction_string):
@@ -185,12 +192,24 @@ def parse_instruction(self, instruction_string):
        #
        # the following comments will show the output of each string manipulation as if it started with the above.

+        if self.extra_debugging:
+            print(f"input: '{instruction_string}'")
+
        garbage, storage_etc = instruction_string.split('((', 1)
        # storage_etc = 'storage', HalfStorage, '0', 'cpu', 11520)), 0, (320, 4, 3, 3), (36, 9, 3, 1), False, _var0)
+
+        if self.extra_debugging:
+            print("storage_etc, reference: ''storage', HalfStorage, '0', 'cpu', 11520)), 0, (320, 4, 3, 3), (36, 9, 3, 1), False, _var0)'")
+            print(f"storage_etc, actual: '{storage_etc}'\n")

        storage, etc = storage_etc.split('))', 1)
        # storage = 'storage', HalfStorage, '0', 'cpu', 11520
-        # etc = 0, (320, 4, 3, 3), (36, 9, 3, 1), False, _var0)
+        # etc = , 0, (320, 4, 3, 3), (36, 9, 3, 1), False, _var0)
+        if self.extra_debugging:
+            print("storage, reference: ''storage', HalfStorage, '0', 'cpu', 11520'")
+            print(f"storage, actual: '{storage}'\n")
+            print("etc, reference: ', 0, (320, 4, 3, 3), (36, 9, 3, 1), False, _var0)'")
+            print(f"etc, actual: '{etc}'\n")

        ## call below maps to: ('storage', HalfStorage, '0', 'cpu', 11520)
        self.ident, self.storage_type, self.obj_key, self.location, self.obj_size = storage.split(', ', 4)
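The reference comments can be checked directly. Running the same two splits on the sample line shows why the corrected comment now starts `etc` with ", 0" rather than "0":

```python
line = ("_var1 = _rebuild_tensor_v2(UNPICKLER.persistent_load(("
        "'storage', HalfStorage, '0', 'cpu', 11520)), 0, "
        "(320, 4, 3, 3), (36, 9, 3, 1), False, _var0)")
garbage, storage_etc = line.split('((', 1)
storage, etc = storage_etc.split('))', 1)
print(storage)  # 'storage', HalfStorage, '0', 'cpu', 11520
print(etc)      # , 0, (320, 4, 3, 3), (36, 9, 3, 1), False, _var0)
```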
@@ -201,10 +220,16 @@ def parse_instruction(self, instruction_string):
        self.obj_size = int(self.obj_size)
        self.storage_type = self._torch_to_numpy(self.storage_type)

+        if self.extra_debugging:
+            print(f"{self.ident}, {self.obj_key}, {self.location}, {self.obj_size}, {self.storage_type}")
+
        assert (self.ident == 'storage')

        garbage, etc = etc.split(', (', 1)
        # etc = 320, 4, 3, 3), (36, 9, 3, 1), False, _var0)
+        if self.extra_debugging:
+            print("etc, reference: '320, 4, 3, 3), (36, 9, 3, 1), False, _var0)'")
+            print(f"etc, actual: '{etc}'\n")

        size, stride, garbage = etc.split('), ', 2)
        # size = 320, 4, 3, 3
@@ -223,12 +248,19 @@ def parse_instruction(self, instruction_string):
        else:
            self.stride = tuple(map(int, stride.split(', ')))

+
+        if self.extra_debugging:
+            print(f"size: {self.size_tuple}, stride: {self.stride}")
+
        prod_size = prod(self.size_tuple)
        assert prod(self.size_tuple) == self.obj_size # does the size in the storage call match the size tuple

        # zero out the data
        self.data = np.zeros(self.size_tuple, dtype=self.storage_type)

+    def sayHi(self):
+        print(f"Hi, I'm an instance of LoadInstruction that will be used to load datafile {self.obj_key}")
+
    @staticmethod
    def _torch_to_numpy(storage_type):
        if storage_type == 'FloatStorage':
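Continuing the worked example through the rest of parse_instruction, starting from the intermediate strings shown in the comments above. The final assertion mirrors the prod-based size check in the hunk (320 * 4 * 3 * 3 = 11520):

```python
from math import prod

storage = "'storage', HalfStorage, '0', 'cpu', 11520"
etc = ", 0, (320, 4, 3, 3), (36, 9, 3, 1), False, _var0)"

ident, storage_type, obj_key, location, obj_size = storage.split(', ', 4)
garbage, etc = etc.split(', (', 1)
size, stride, garbage = etc.split('), ', 2)

size_tuple = tuple(map(int, size.split(', ')))
print(size_tuple, obj_size)               # (320, 4, 3, 3) 11520
assert prod(size_tuple) == int(obj_size)  # element count matches the storage call
```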