@@ -83,6 +83,7 @@ def recover_export_model(model, current_key_name=None):
     Return optimum format model.
     """
     from ..llm.quantization.nn.modules import QuantizedLinearQBits
+
     for name, module in model.named_children():
         if current_key_name is None:
             current_key_name = []
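
For orientation, `recover_export_model` uses the standard recursive-traversal idiom: walk `named_children()`, track the dotted key path, and recurse. A minimal standalone sketch of that idiom (the helper name and the `print` are illustrative, not this file's code):

```python
import torch.nn as nn

def walk_modules(model: nn.Module, current_key_name=None):
    """Visit every submodule, tracking its dotted key path."""
    if current_key_name is None:
        current_key_name = []
    for name, module in model.named_children():
        current_key_name.append(name)
        print(".".join(current_key_name), type(module).__name__)
        walk_modules(module, current_key_name)  # recurse into children
        current_key_name.pop()  # restore the path on the way back up
```
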
@@ -194,8 +195,13 @@ def save_low_bit(
         )
         return
 
-    if self.quantization_config.weight_dtype not in \
-            ["fp8_e5m2", "fp8_e4m3", "nf4", "fp4", "int4_fullrange"]:
+    if self.quantization_config.weight_dtype not in [
+        "fp8_e5m2",
+        "fp8_e4m3",
+        "nf4",
+        "fp4",
+        "int4_fullrange",
+    ]:
         convert_model_to_public(self)
     os.makedirs(save_directory, exist_ok=True)
     # use transformers original `save_pretrained` function
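
The dtype allow-list added here reappears verbatim twice more in `load_low_bit` below. As a hedged refactoring sketch (not part of this PR; `QBITS_UNPACKABLE_DTYPES` is an invented name), hoisting it into one module-level constant would keep the three checks from drifting apart, which is exactly the failure mode of the missing-comma variant further down:

```python
# Invented constant name; the dtype strings come from this PR.
QBITS_UNPACKABLE_DTYPES = frozenset(
    {"fp8_e5m2", "fp8_e4m3", "nf4", "fp4", "int4_fullrange"}
)

def needs_public_conversion(weight_dtype: str) -> bool:
    """True when a checkpoint must be converted back to the public format."""
    return weight_dtype not in QBITS_UNPACKABLE_DTYPES

assert needs_public_conversion("int8")
assert not needs_public_conversion("nf4")
```
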
@@ -336,7 +342,27 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
             return_unused_kwargs=True,
             **kwargs,
         )
-        if hasattr(config, "quantization_config"):
+
+        if kwargs.get("use_llm_runtime", None) is not None:
+            use_neural_speed = kwargs.pop("use_llm_runtime", True) and not use_xpu
+            logger.warning(
+                "use_llm_runtime is deprecated in version 1.3.2, please use use_neural_speed instead."
+            )
+        elif kwargs.get("use_neural_speed", None) is not None:
+            use_neural_speed = kwargs.pop("use_neural_speed", True) and not use_xpu
+        else:
+            if not hasattr(config, "model_type"):
+                logger.error(
+                    "Can't get the model_type. Please check that the model_type is correct."
+                )
+                exit(0)
+
+            if config.model_type in cls.model_type_list and not use_xpu:
+                use_neural_speed = True
+            else:
+                use_neural_speed = False
+
+        if hasattr(config, "quantization_config") and not use_neural_speed:
             if config.quantization_config is None:
                 logger.warning(
                     "Quantization_config loading failed. If you want to load saved "
@@ -369,26 +395,6 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
369395 "Saved low bit model loading failed, please check your model."
370396 )
371397 exit (0 )
372- if kwargs .get ("use_llm_runtime" , None ) is not None :
373- use_neural_speed = kwargs .pop ("use_llm_runtime" , True ) and not use_xpu
374- logger .warning (
375- "use_llm_runtime is deprecated in version 1.3.2, please use_neural_speed instead."
376- )
377- elif kwargs .get ("use_neural_speed" , None ) is not None :
378- use_neural_speed = kwargs .pop ("use_neural_speed" , True ) and not use_xpu
379- else :
380- if hasattr (config , "model_type" ) == False :
381- logger .error (
382- "Can't get the model_type. Please check the correct model_type"
383- )
384- exit (0 )
385-
386- if config .model_type in cls .model_type_list and not use_xpu :
387- logger .info ("Using Neural Speed..." )
388- use_neural_speed = True
389- else :
390- logger .info ("Using Pytorch..." )
391- use_neural_speed = False
392398
393399 import intel_extension_for_transformers .transformers .modeling .modeling_map
394400
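
Moving this block ahead of the `quantization_config` check means the Neural Speed decision is settled first, and saved low-bit checkpoints now bypass Neural Speed entirely. A caller-side sketch of the two flag spellings (the model id is illustrative; `AutoModelForCausalLM` is assumed to be this repo's wrapper class):

```python
from intel_extension_for_transformers.transformers import AutoModelForCausalLM

# Deprecated spelling: still honored, but logs a deprecation warning.
model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m", use_llm_runtime=False)

# Preferred spelling as of 1.3.2.
model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m", use_neural_speed=False)
```
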
@@ -437,7 +443,9 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
             if quantization_config is None:
                 if use_neural_speed:
                     # use wnf4_sfp32_cfp32_g32_sym by default
-                    quantization_config = RtnConfig(compute_dtype="fp32", weight_dtype="nf4")
+                    quantization_config = RtnConfig(
+                        compute_dtype="fp32", weight_dtype="nf4"
+                    )
                 else:
                     quantization_config = RtnConfig(
                         bits=4,
@@ -502,7 +510,7 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
         ):
             logger.info("Applying Weight Only Quantization.")
             if use_neural_speed:
-                logger.info("Using LLM runtime.")
+                logger.info("Using Neural Speed.")
                 quantization_config.post_init_runtime()
                 from neural_speed import Model
 
@@ -966,6 +974,7 @@ def load_low_bit(cls, pretrained_model_name_or_path, *model_args, **kwargs):
         kwargs["torch_dtype"] = "auto"
         config = kwargs.pop("config", None)
         quantization_config = config.quantization_config
+
         if quantization_config["quant_method"] == "rtn":
             quantization_config = RtnConfig.from_dict(quantization_config)
         elif quantization_config["quant_method"] == "awq":
@@ -976,7 +985,6 @@ def load_low_bit(cls, pretrained_model_name_or_path, *model_args, **kwargs):
             quantization_config = GPTQConfig.from_dict(quantization_config)
         elif quantization_config["quant_method"] == "autoround":
             quantization_config = AutoRoundConfig.from_dict(quantization_config)
-
         assert (
             quantization_config is not None
         ), "Detect this model is not a low-bit model."
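
The `quant_method` dispatch above is an `elif` chain over string keys. A common alternative shape, shown only as a sketch (the import path, and `AwqConfig` by symmetry with the `"awq"` branch, are assumptions), is a lookup table that fails loudly on unknown methods:

```python
from intel_extension_for_transformers.transformers import (
    AutoRoundConfig, AwqConfig, GPTQConfig, RtnConfig,
)

QUANT_CONFIG_BY_METHOD = {
    "rtn": RtnConfig,
    "awq": AwqConfig,
    "gptq": GPTQConfig,
    "autoround": AutoRoundConfig,
}

def config_from_dict(quantization_config: dict):
    """Resolve a quantization config class from its quant_method string."""
    method = quantization_config.get("quant_method")
    if method not in QUANT_CONFIG_BY_METHOD:
        raise ValueError(f"Unknown quant_method: {method!r}")
    return QUANT_CONFIG_BY_METHOD[method].from_dict(quantization_config)
```
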
@@ -1170,8 +1178,13 @@ def load_low_bit(cls, pretrained_model_name_or_path, *model_args, **kwargs):
             model = model_class(config, *model_args, **kwargs)
         else:
             model = model_class(config, *model_args, **kwargs)
-        if config.quantization_config["weight_dtype"] not in \
-                ["fp8_e5m2", "fp8_e4m3", "fp4", "nf4", "int4_fullrange"]:
+        if config.quantization_config["weight_dtype"] not in [
+            "fp8_e5m2",
+            "fp8_e4m3",
+            "fp4",
+            "nf4",
+            "int4_fullrange",
+        ]:
             model = build_woq_model(model, quantization_config)
         else:
             model = replace_linear(
@@ -1221,8 +1234,13 @@ def load_low_bit(cls, pretrained_model_name_or_path, *model_args, **kwargs):
 
         # Set model in evaluation mode to deactivate DropOut modules by default
         model.eval()
-        if config.quantization_config["weight_dtype"] not in \
-                ["fp8_e5m2", "fp8_e4m3", "nf4", "fp4" "int4_fullrange"]:
+        if config.quantization_config["weight_dtype"] not in [
+            "fp8_e5m2",
+            "fp8_e4m3",
+            "nf4",
+            "fp4",
+            "int4_fullrange",
+        ]:
             model = replace_linear(
                 model,
                 quantization_config=quantization_config,
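
One subtle point in this hunk: the removed condition contains `"fp4" "int4_fullrange"` with no comma between the literals. Python concatenates adjacent string literals, so the old list silently held `"fp4int4_fullrange"` and matched neither dtype; the itemized rewrite restores both entries. A quick demonstration:

```python
# Adjacent string literals fuse into a single element.
broken = ["fp8_e5m2", "fp8_e4m3", "nf4", "fp4" "int4_fullrange"]
assert broken == ["fp8_e5m2", "fp8_e4m3", "nf4", "fp4int4_fullrange"]
assert "fp4" not in broken and "int4_fullrange" not in broken
```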