@@ -268,16 +268,24 @@ def get_torch_context_manager_or_global_device():
     return device_in_context
 
 
-def get_state_dict_dtype(state_dict):
+def get_state_dict_dtype(state_dict, config_dtype: Optional[torch.dtype] = None):
     """
     Returns the first found floating dtype in `state_dict` if there is one, otherwise returns the first dtype.
+
+    If `config_dtype` is provided (for instance when `dtype="auto"` and the config already carries a dtype), it is used.
     """
+    if config_dtype is not None:
+        return config_dtype
+
+    if len(state_dict) == 0:
+        return torch.get_default_dtype()
+
     for t in state_dict.values():
         if t.is_floating_point():
             return t.dtype
 
     # if no floating dtype was found return whatever the first dtype is
-    return next(state_dict.values()).dtype
+    return next(iter(state_dict.values())).dtype
 
 
 str_to_torch_dtype = {
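Two behavioural changes in this hunk are worth spelling out: `next(state_dict.values())` raised a `TypeError` because `dict.values()` returns a view, not an iterator, hence the `iter(...)` wrapper; and an explicit `config_dtype` now short-circuits the scan of the weights entirely. A minimal standalone sketch of the resulting logic (a reproduction for illustration, not the library function itself):

import torch

# Standalone sketch of the updated helper's behaviour (simplified, not the library code).
def get_state_dict_dtype_sketch(state_dict, config_dtype=None):
    if config_dtype is not None:          # an explicit config dtype wins
        return config_dtype
    if len(state_dict) == 0:              # empty state dict: fall back to torch's default
        return torch.get_default_dtype()
    for t in state_dict.values():
        if t.is_floating_point():         # first floating tensor decides the dtype
            return t.dtype
    return next(iter(state_dict.values())).dtype  # no floating tensors: first dtype

state_dict = {"ids": torch.tensor([1, 2]), "w": torch.ones(2, dtype=torch.bfloat16)}
print(get_state_dict_dtype_sketch(state_dict))                 # torch.bfloat16
print(get_state_dict_dtype_sketch(state_dict, torch.float16))  # torch.float16 (config wins)
print(get_state_dict_dtype_sketch({}))                         # torch.float32 (torch default)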
@@ -722,12 +730,16 @@ def _get_dtype(
                     if is_sharded and "dtype" in sharded_metadata:
                         dtype = sharded_metadata["dtype"]
                     elif state_dict is not None:
-                        dtype = get_state_dict_dtype(state_dict)
+                        dtype = get_state_dict_dtype(state_dict, getattr(config, "dtype", None))
                     else:
                         state_dict = load_state_dict(
                             checkpoint_files[0], map_location="meta", weights_only=weights_only
                         )
-                        dtype = get_state_dict_dtype(state_dict)
+                        dtype = get_state_dict_dtype(state_dict, getattr(config, "dtype", None))
+                    config.dtype = dtype
+                    for sub_config_key in config.sub_configs:
+                        if (sub_config := getattr(config, sub_config_key)) is not None:
+                            sub_config.dtype = dtype
                     logger.info(
                         "Since the `dtype` attribute can't be found in model's config object, "
                         "will use dtype={dtype} as derived from model's weights"
                     )
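The lines added here write the dtype derived from the checkpoint back onto the config and every sub-config. A hedged sketch of that propagation step, using a `SimpleNamespace` stand-in instead of a real `PreTrainedConfig` (the attribute names mirror the diff; the stand-in itself is hypothetical):

import torch
from types import SimpleNamespace

# Stand-in config with one sub-config; real configs expose `sub_configs` as key names.
text_config = SimpleNamespace(dtype=None)
config = SimpleNamespace(dtype=None, sub_configs=["text_config"], text_config=text_config)

dtype = torch.bfloat16  # pretend this was derived from the checkpoint weights
config.dtype = dtype
for sub_config_key in config.sub_configs:
    if (sub_config := getattr(config, sub_config_key)) is not None:
        sub_config.dtype = dtype

print(config.dtype, config.text_config.dtype)  # torch.bfloat16 torch.bfloat16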
@@ -1219,6 +1231,14 @@ def __init__(self, config: PreTrainedConfig, *inputs, **kwargs):
                 f"`model = {self.__class__.__name__}.from_pretrained(PRETRAINED_MODEL_NAME)`"
             )
         self.config = config
+        if getattr(self.config, "dtype", None) is None:
+            default_dtype = torch.get_default_dtype()
+            self.config.dtype = default_dtype
+            for sub_config_key in self.config.sub_configs:
+                if (sub_config := getattr(self.config, sub_config_key)) is not None and getattr(
+                    sub_config, "dtype", None
+                ) is None:
+                    sub_config.dtype = default_dtype
 
         # Check the attention implementation is supported, or set it if not yet set (on the internal attr, to avoid
         # setting it recursively)
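With this fallback in `__init__`, a model built from a freshly constructed config should end up with `config.dtype` set to `torch.get_default_dtype()` rather than `None`. A small sanity check, assuming the patched behaviour and default torch settings (the tiny Llama hyperparameters below are arbitrary, chosen only to build a model quickly):

import torch
from transformers import LlamaConfig, LlamaForCausalLM

# Arbitrary, tiny hyperparameters just to instantiate a model cheaply.
config = LlamaConfig(
    vocab_size=32, hidden_size=8, intermediate_size=16,
    num_hidden_layers=1, num_attention_heads=2, num_key_value_heads=2,
)
model = LlamaForCausalLM(config)
# Expected with this patch applied: torch.float32, i.e. torch.get_default_dtype().
print(model.config.dtype)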
@@ -3789,7 +3809,8 @@ def from_pretrained(
         output_loading_info = kwargs.pop("output_loading_info", False)
         from_pipeline = kwargs.pop("_from_pipeline", None)
         from_auto_class = kwargs.pop("_from_auto", False)
-        dtype = kwargs.pop("dtype", None)
+        dtype_kwarg_provided = "dtype" in kwargs
+        dtype = kwargs.pop("dtype", "auto")
         torch_dtype = kwargs.pop("torch_dtype", None)  # kept for BC
         device_map = kwargs.pop("device_map", None)
         max_memory = kwargs.pop("max_memory", None)
@@ -3820,8 +3841,8 @@ def from_pretrained(
             _ = kwargs.pop(name, None)
 
         # For BC on torch_dtype argument
-        if torch_dtype is not None:
-            dtype = dtype if dtype is not None else torch_dtype
+        if torch_dtype is not None and (not dtype_kwarg_provided or dtype is None):
+            dtype = torch_dtype
 
         if is_offline_mode() and not local_files_only:
             local_files_only = True
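The last two hunks change the default of `dtype` from `None` to `"auto"` while keeping the deprecated `torch_dtype` working: an explicitly passed `dtype` now wins over `torch_dtype`, and `torch_dtype` only takes effect when `dtype` was not passed at all (or was passed as `None`). A hypothetical helper isolating just that precedence logic:

import torch

# Hypothetical helper mirroring the kwargs handling above (not library code).
def resolve_dtype(**kwargs):
    dtype_kwarg_provided = "dtype" in kwargs
    dtype = kwargs.pop("dtype", "auto")            # `dtype` now defaults to "auto"
    torch_dtype = kwargs.pop("torch_dtype", None)  # kept for BC
    if torch_dtype is not None and (not dtype_kwarg_provided or dtype is None):
        dtype = torch_dtype
    return dtype

print(resolve_dtype())                                                  # "auto"
print(resolve_dtype(torch_dtype=torch.float16))                         # torch.float16 (BC path)
print(resolve_dtype(dtype=torch.bfloat16, torch_dtype=torch.float16))   # torch.bfloat16 (dtype wins)
print(resolve_dtype(dtype=None, torch_dtype=torch.float16))             # torch.float16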