@@ -78,8 +78,8 @@ def get_model_type(model_config):
     def init(self, model_name, use_quant=True, use_cache=False, use_gptq=False, **quant_kwargs):
         self.config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
         self.tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
-        model_type = Model.get_model_type(self.config)
-        self.__import_package(model_type)
+        self.model_type = Model.get_model_type(self.config)
+        self.__import_package(self.model_type)

         # check cache and quantization
         if use_quant:
@@ -88,7 +88,7 @@ def init(self, model_name, use_quant=True, use_cache=False, use_gptq=False, **quant_kwargs):
                              " is not currently supported. Please use other combinations.")
         output_path = "runtime_outs"
         os.makedirs(output_path, exist_ok=True)
-        fp32_bin = "{}/ne_{}_f32.bin".format(output_path, model_type)
+        fp32_bin = "{}/ne_{}_f32.bin".format(output_path, self.model_type)
         quant_desc = quant_kwargs['weight_dtype']
         if quant_kwargs['use_ggml']:
             quant_desc += "_ggml"
@@ -100,7 +100,7 @@ def init(self, model_name, use_quant=True, use_cache=False, use_gptq=False, **quant_kwargs):
             quant_desc += "_g{}".format(quant_kwargs['group_size'])
         if use_gptq:
             quant_desc = "gptq"
-        quant_bin = "{}/ne_{}_q_{}.bin".format(output_path, model_type, quant_desc)
+        quant_bin = "{}/ne_{}_q_{}.bin".format(output_path, self.model_type, quant_desc)

         if not use_quant:
             self.bin_file = fp32_bin
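With `model_type` promoted to an instance attribute, the cache-file naming that `init` derives can be reproduced standalone. A minimal sketch of that naming scheme, assuming illustrative values for the model type and the quantization kwargs (none of the concrete values below come from the diff):

```python
# Illustrative reconstruction of the cache-file naming used in init().
# model_type and quant_kwargs here are assumed example values.
output_path = "runtime_outs"
model_type = "llama"  # stand-in for Model.get_model_type(self.config)
quant_kwargs = {"weight_dtype": "int4", "use_ggml": True, "group_size": 128}

fp32_bin = "{}/ne_{}_f32.bin".format(output_path, model_type)

quant_desc = quant_kwargs["weight_dtype"]
if quant_kwargs["use_ggml"]:
    quant_desc += "_ggml"
quant_desc += "_g{}".format(quant_kwargs["group_size"])

quant_bin = "{}/ne_{}_q_{}.bin".format(output_path, model_type, quant_desc)
print(fp32_bin)   # runtime_outs/ne_llama_f32.bin
print(quant_bin)  # runtime_outs/ne_llama_q_int4_ggml_g128.bin
```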
@@ -203,7 +203,7 @@ def is_token_end(self):
         return self.model.is_token_end()

     def eos_token_id(self):
-        if self.tokenizer.eos_token_id == None:
+        if self.model_type == 'qwen':
             return self.tokenizer.special_tokens['<|endoftext|>']
         return self.tokenizer.eos_token_id
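The new condition special-cases Qwen by the stored `model_type` instead of probing `tokenizer.eos_token_id` for `None`. A minimal sketch of the resulting dispatch, using hypothetical stub tokenizers (the Qwen token id shown is illustrative, not taken from the source):

```python
# Hypothetical stand-ins to illustrate the patched eos_token_id() dispatch.
class QwenTokenizerStub:
    special_tokens = {"<|endoftext|>": 151643}  # id is illustrative
    eos_token_id = None  # Qwen-style tokenizers may leave this unset

class LlamaTokenizerStub:
    eos_token_id = 2

def eos_token_id(model_type, tokenizer):
    # Mirrors the patched method: Qwen is keyed off model_type,
    # everything else uses the tokenizer's own eos_token_id.
    if model_type == "qwen":
        return tokenizer.special_tokens["<|endoftext|>"]
    return tokenizer.eos_token_id

assert eos_token_id("qwen", QwenTokenizerStub()) == 151643
assert eos_token_id("llama", LlamaTokenizerStub()) == 2
```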