@@ -25,7 +25,11 @@ def __init__(
2525 self .optimization_config  =  optimization_config 
2626
def optimize(self, model, use_llm_runtime=False):
    """Quantize/optimize ``model`` according to ``self.optimization_config``.

    Parameters:
        model: either a model name/path string, or an already-loaded HF model,
            in which case its ``config._name_or_path`` supplies the name.
        use_llm_runtime: forwarded to the ITREX ``from_pretrained`` loaders.

    Returns:
        The model reloaded through the matching ITREX auto class with the
        quantization config applied, or the input ``model`` unchanged when no
        known model-family pattern matches.
    """
    # Accept both a raw name/path string and a loaded model object.
    if isinstance(model, str):
        model_name = model
    else:
        model_name = model.config._name_or_path
    optimized_model = model
    from intel_extension_for_transformers.transformers import (
        MixedPrecisionConfig,
        WeightOnlyQuantConfig,
        BitsAndBytesConfig,  # NOTE(review): reconstructed across a diff-hunk gap -- confirm
    )
    # NOTE(review): the assert head fell in a hidden hunk gap and is
    # reconstructed from the visible message text -- confirm against the file.
    assert type(self.optimization_config) in (
        MixedPrecisionConfig,
        WeightOnlyQuantConfig,
        BitsAndBytesConfig,
    ), (
        "Expect optimization_config be an object of MixedPrecisionConfig, "
        "WeightOnlyQuantConfig or BitsAndBytesConfig, "
        # Bug fix: the continuation string lacked the f-prefix, so the
        # placeholder was printed literally; also add the missing space
        # after the comma.
        f"got {type(self.optimization_config)}."
    )
    config = self.optimization_config
    # Model families served by AutoModelForCausalLM; collapsing the original
    # eight chained re.search(...) or-terms into one any() over this tuple.
    causal_families = (
        "gpt", "mpt", "bloom", "llama", "opt",
        "neural-chat-7b-v1", "neural-chat-7b-v2", "neural-chat-7b-v3",
    )
    # Dispatch on model-family name patterns to the matching ITREX loader.
    if re.search("flan-t5", model_name, re.IGNORECASE):
        from intel_extension_for_transformers.transformers import AutoModelForSeq2SeqLM
        optimized_model = AutoModelForSeq2SeqLM.from_pretrained(
            model_name,
            quantization_config=config,
            use_llm_runtime=use_llm_runtime,
            trust_remote_code=True)
    elif any(re.search(pat, model_name, re.IGNORECASE) for pat in causal_families):
        from intel_extension_for_transformers.transformers import AutoModelForCausalLM
        optimized_model = AutoModelForCausalLM.from_pretrained(
            model_name,
            quantization_config=config,
            use_llm_runtime=use_llm_runtime,
            trust_remote_code=True)
    elif re.search("starcoder", model_name, re.IGNORECASE):
        from intel_extension_for_transformers.transformers import GPTBigCodeForCausalLM
        optimized_model = GPTBigCodeForCausalLM.from_pretrained(
            model_name,
            quantization_config=config,
            use_llm_runtime=use_llm_runtime,
            trust_remote_code=True)
    elif re.search("chatglm", model_name, re.IGNORECASE):
        from intel_extension_for_transformers.transformers import AutoModel
        optimized_model = AutoModel.from_pretrained(
            model_name,
            quantization_config=config,
            use_llm_runtime=use_llm_runtime,
            trust_remote_code=True)
    # NOTE(review): the method's tail is cut off by the visible diff window;
    # the original presumably ends by returning the result -- confirm.
    return optimized_model
0 commit comments