1111else :
1212 from astunparse import unparse
1313
14+ NO_PICKLE_DEBUG = True
1415
1516def extract_weights_from_checkpoint (fb0 ):
1617 torch_weights = {}
@@ -21,11 +22,13 @@ def extract_weights_from_checkpoint(fb0):
2122 raise ValueError ("Looks like the checkpoints file is in the wrong format" )
2223 folder_name = folder_name [0 ].replace ("/data.pkl" , "" ).replace ("\\ data.pkl" , "" )
2324 with myzip .open (folder_name + '/data.pkl' ) as myfile :
24- instructions = examine_pickle (myfile )
25- for sd_key ,load_instruction in instructions .items ():
25+ load_instructions , special_instructions = examine_pickle (myfile )
26+ for sd_key ,load_instruction in load_instructions .items ():
2627 with myzip .open (folder_name + f'/data/{ load_instruction .obj_key } ' ) as myfile :
2728 if (load_instruction .load_from_file_buffer (myfile )):
2829 torch_weights ['state_dict' ][sd_key ] = load_instruction .get_data ()
30+ for sd_key ,special in special_instructions .items ():
31+ torch_weights ['state_dict' ][sd_key ] = special
2932 return torch_weights
3033
3134def examine_pickle (fb0 ):
@@ -47,6 +50,7 @@ def examine_pickle(fb0):
4750 re_rebuild = re .compile ('^_var\d+ = _rebuild_tensor_v2\(UNPICKLER\.persistent_load\(\(.*\)$' )
4851 re_assign = re .compile ('^_var\d+ = \{.*\}$' )
4952 re_update = re .compile ('^_var\d+\.update\(\{.*\}\)$' )
53+ re_ordered_dict = re .compile ('^_var\d+ = OrderedDict\(\)$' )
5054
5155 load_instructions = {}
5256 assign_instructions = AssignInstructions ()
@@ -61,18 +65,25 @@ def examine_pickle(fb0):
6165 assign_instructions .parse_assign_line (line )
6266 elif re_update .match (line ):
6367 assign_instructions .parse_update_line (line )
64- #else:
65- # print('kicking rocks')
68+ elif re_ordered_dict .match (line ):
69+ #do nothing
70+ continue
71+ else :
72+ print (f'unmatched line: { line } ' )
73+
6674
67- #print(f"Found {len(load_instructions)} load instructions")
75+ if NO_PICKLE_DEBUG :
76+ print (f"Found { len (load_instructions )} load instructions" )
6877
6978 assign_instructions .integrate (load_instructions )
7079
71- return assign_instructions .integrated_instructions
80+ #return assign_instructions.integrated_instructions, assign_instructions.special_instructions
81+ return assign_instructions .integrated_instructions , {}
7282
7383class AssignInstructions :
7484 def __init__ (self ):
7585 self .instructions = {}
86+ self .special_instructions = {}
7687 self .integrated_instructions = {}
7788
7889 def parse_assign_line (self , line ):
@@ -84,20 +95,53 @@ def parse_assign_line(self, line):
8495 assignments = huge_mess .split (', ' )
8596 del huge_mess
8697 assignments [- 1 ] = assignments [- 1 ].strip ('}' )
98+ re_var = re .compile ('^_var\d+$' )
99+ assignment_count
87100 for a in assignments :
88- self ._add_assignment (a )
89- #print(f"Added/merged {len(assignments)} assignments. Total of {len(self.instructions)} assignment instructions")
90-
91- def _add_assignment (self , assignment ):
92- sd_key , fickling_var = assignment .split (': ' )
101+ if self ._add_assignment (a , re_var ):
102+ assignment_count = assignment_count + 1
103+ if NO_PICKLE_DEBUG :
104+ print (f"Added/merged { assignment_count } assignments. Total of { len (self .instructions )} assignment instructions" )
105+
106+ def _add_assignment (self , assignment , re_var ):
107+ # assignment can look like this:
108+ # 'cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.out_proj.weight': _var2009
109+ # or assignment can look like this:
110+ # 'embedding_manager.embedder.transformer.text_model.encoder.layers.6.mlp.fc1': {'version': 1}
111+ sd_key , fickling_var = assignment .split (': ' , 1 )
93112 sd_key = sd_key .strip ("'" )
94- self .instructions [sd_key ] = fickling_var
113+ if re_var .match (fickling_var ):
114+ self .instructions [sd_key ] = fickling_var
115+ return True
116+ else :
117+ # now convert the string "{'version': 1}" into a dictionary {'version': 1}
118+ entries = fickling_var .split (',' )
119+ special_dict = {}
120+ for e in entries :
121+ e = e .strip ("{}" )
122+ k , v = e .split (': ' )
123+ k = k .strip ("'" )
124+ v = v .strip ("'" )
125+ special_dict [k ] = v
126+ self .special_instructions [sd_key ] = special_dict
127+ return False
95128
96129 def integrate (self , load_instructions ):
130+ unfound_keys = {}
97131 for sd_key , fickling_var in self .instructions .items ():
98132 if fickling_var in load_instructions :
99133 self .integrated_instructions [sd_key ] = load_instructions [fickling_var ]
100- #print(f"Have {len(self.integrated_instructions)} integrated load/assignment instructions")
134+ if sd_key in self .special_instructions :
135+ if NO_PICKLE_DEBUG :
136+ print (f"Key found in both load and special instructions: { sd_key } " )
137+ else :
138+ unfound_keys [sd_key ] = True ;
139+ #for sd_key, special in self.special_instructions.items():
140+ # if sd_key in unfound_keys:
141+ # #todo
142+
143+ if NO_PICKLE_DEBUG :
144+ print (f"Have { len (self .integrated_instructions )} integrated load/assignment instructions" )
101145
102146 def parse_update_line (self , line ):
103147 # input looks like:
@@ -109,9 +153,13 @@ def parse_update_line(self, line):
109153 updates = huge_mess .split (', ' )
110154 del huge_mess
111155 updates [- 1 ] = updates [- 1 ].strip ('})' )
156+ re_var = re .compile ('^_var\d+$' )
157+ update_count = 0
112158 for u in updates :
113- self ._add_assignment (u )
114- #print(f"Added/merged {len(updates)} updates. Total of {len(self.instructions)} assignment instructions")
159+ if self ._add_assignment (u , re_var ):
160+ update_count = update_count + 1
161+ if NO_PICKLE_DEBUG :
162+ print (f"Added/merged { update_count } updates. Total of { len (self .instructions )} assignment instructions" )
115163
116164class LoadInstruction :
117165 def __init__ (self , instruction_string ):
0 commit comments