
Commit f5aa90d

🚨🚨🚨🚨🚨🚨🚨🚨🚨 default to "auto" dtype (#34919)
* default to `"auto"` dtype
* the actual change?
* up?
* style
* up?
* only sam models were broken with this
1 parent 0af2381 commit f5aa90d

File tree

2 files changed: +31 −7 lines


src/transformers/integrations/hub_kernels.py

Lines changed: 3 additions & 0 deletions
@@ -341,6 +341,9 @@ def lazy_load_kernel(kernel_name: str, mapping: dict[str, ModuleType | None] = _
                 mapping[kernel_name] = kernel
             except FileNotFoundError:
                 mapping[kernel_name] = None
+            except AssertionError:
+                # Happens when torch is built without an accelerator backend; fall back to slow path.
+                mapping[kernel_name] = None

         else:
             # Try to import is_{kernel_name}_available from ..utils
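The new except branch is the whole fix for this file: a kernel package that asserts on a missing accelerator backend is now treated like a kernel that isn't there at all, so callers fall back to the slow path. Below is a minimal, self-contained sketch of that fallback pattern, not the real lazy_load_kernel; load_kernel_from_hub is a made-up stand-in for the hub loader.

# Hedged sketch only: the real function lives in src/transformers/integrations/hub_kernels.py;
# `load_kernel_from_hub` is hypothetical and simply simulates the failure mode this commit handles.
from types import ModuleType
from typing import Optional


def load_kernel_from_hub(kernel_name: str) -> ModuleType:
    # Simulate a kernel package asserting on an accelerator backend that is not there.
    raise AssertionError("torch was not built with an accelerator backend")


def lazy_load_kernel_sketch(kernel_name: str, mapping: dict[str, Optional[ModuleType]]) -> Optional[ModuleType]:
    if kernel_name not in mapping:
        try:
            mapping[kernel_name] = load_kernel_from_hub(kernel_name)
        except FileNotFoundError:
            mapping[kernel_name] = None
        except AssertionError:
            # Happens when torch is built without an accelerator backend; fall back to slow path.
            mapping[kernel_name] = None
    return mapping[kernel_name]


cache: dict[str, Optional[ModuleType]] = {}
print(lazy_load_kernel_sketch("some-hub-kernel", cache))  # prints None: the caller takes the slow PyTorch path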

src/transformers/modeling_utils.py

Lines changed: 28 additions & 7 deletions
@@ -268,16 +268,24 @@ def get_torch_context_manager_or_global_device():
     return device_in_context


-def get_state_dict_dtype(state_dict):
+def get_state_dict_dtype(state_dict, config_dtype: Optional[torch.dtype] = None):
     """
     Returns the first found floating dtype in `state_dict` if there is one, otherwise returns the first dtype.
+
+    If `config_dtype` is provided (for instance when `dtype="auto"` and the config already carries a dtype), it is used.
     """
+    if config_dtype is not None:
+        return config_dtype
+
+    if len(state_dict) == 0:
+        return torch.get_default_dtype()
+
     for t in state_dict.values():
         if t.is_floating_point():
             return t.dtype

     # if no floating dtype was found return whatever the first dtype is
-    return next(state_dict.values()).dtype
+    return next(iter(state_dict.values())).dtype


 str_to_torch_dtype = {
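To make the new resolution order concrete, here is a standalone sketch that mirrors the patched helper (it only assumes torch is installed; the function name is illustrative): an explicit config dtype wins, an empty state dict falls back to torch's global default, otherwise the first floating tensor, or failing that the very first tensor, decides.

from typing import Optional

import torch


# A sketch mirroring the patched helper above; the `_sketch` suffix marks it as illustrative.
def get_state_dict_dtype_sketch(state_dict: dict, config_dtype: Optional[torch.dtype] = None) -> torch.dtype:
    if config_dtype is not None:
        return config_dtype
    if len(state_dict) == 0:
        return torch.get_default_dtype()
    for t in state_dict.values():
        if t.is_floating_point():
            return t.dtype
    # No floating tensor found: fall back to whatever the first tensor's dtype is.
    return next(iter(state_dict.values())).dtype


print(get_state_dict_dtype_sketch({}, config_dtype=torch.bfloat16))            # torch.bfloat16 (config wins)
print(get_state_dict_dtype_sketch({}))                                         # torch.float32 (empty -> default)
print(get_state_dict_dtype_sketch({"w": torch.ones(2, dtype=torch.float16)}))  # torch.float16
print(get_state_dict_dtype_sketch({"ids": torch.ones(2, dtype=torch.int64)}))  # torch.int64 (first dtype)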
@@ -722,12 +730,16 @@ def _get_dtype(
         if is_sharded and "dtype" in sharded_metadata:
             dtype = sharded_metadata["dtype"]
         elif state_dict is not None:
-            dtype = get_state_dict_dtype(state_dict)
+            dtype = get_state_dict_dtype(state_dict, getattr(config, "dtype", None))
         else:
             state_dict = load_state_dict(
                 checkpoint_files[0], map_location="meta", weights_only=weights_only
             )
-            dtype = get_state_dict_dtype(state_dict)
+            dtype = get_state_dict_dtype(state_dict, getattr(config, "dtype", None))
+        config.dtype = dtype
+        for sub_config_key in config.sub_configs:
+            if (sub_config := getattr(config, sub_config_key)) is not None:
+                sub_config.dtype = dtype
         logger.info(
             "Since the `dtype` attribute can't be found in model's config object, "
             "will use dtype={dtype} as derived from model's weights"
@@ -1219,6 +1231,14 @@ def __init__(self, config: PreTrainedConfig, *inputs, **kwargs):
                 f"`model = {self.__class__.__name__}.from_pretrained(PRETRAINED_MODEL_NAME)`"
             )
         self.config = config
+        if getattr(self.config, "dtype", None) is None:
+            default_dtype = torch.get_default_dtype()
+            self.config.dtype = default_dtype
+            for sub_config_key in self.config.sub_configs:
+                if (sub_config := getattr(self.config, sub_config_key)) is not None and getattr(
+                    sub_config, "dtype", None
+                ) is None:
+                    sub_config.dtype = default_dtype

         # Check the attention implementation is supported, or set it if not yet set (on the internal attr, to avoid
         # setting it recursively)
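The constructor change only fills gaps: a config (or sub-config) with no dtype now inherits torch's global default at model construction time. As a quick reminder of what that default is, a pure-torch snippet:

import torch

# torch.get_default_dtype() is what a dtype-less config now inherits at model construction time.
print(torch.get_default_dtype())        # torch.float32 in a fresh process
torch.set_default_dtype(torch.float64)
print(torch.get_default_dtype())        # torch.float64 once the global default is changed
torch.set_default_dtype(torch.float32)  # restore the usual default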
@@ -3789,7 +3809,8 @@ def from_pretrained(
         output_loading_info = kwargs.pop("output_loading_info", False)
         from_pipeline = kwargs.pop("_from_pipeline", None)
         from_auto_class = kwargs.pop("_from_auto", False)
-        dtype = kwargs.pop("dtype", None)
+        dtype_kwarg_provided = "dtype" in kwargs
+        dtype = kwargs.pop("dtype", "auto")
         torch_dtype = kwargs.pop("torch_dtype", None)  # kept for BC
         device_map = kwargs.pop("device_map", None)
         max_memory = kwargs.pop("max_memory", None)
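For callers, the net effect is that from_pretrained now behaves as if dtype="auto" had been passed whenever nothing is specified, instead of implicitly loading in float32. A hedged usage sketch (requires the transformers library, network access, and a real checkpoint; the model id is only an example, and for float32 checkpoints such as GPT-2 both calls end up identical):

import torch
from transformers import AutoModel

# With this commit, omitting dtype is equivalent to dtype="auto": the dtype comes from the
# checkpoint/config rather than an implicit float32 upcast.
model_auto = AutoModel.from_pretrained("openai-community/gpt2")
# The old behavior can still be requested explicitly.
model_fp32 = AutoModel.from_pretrained("openai-community/gpt2", dtype=torch.float32)
# The difference only shows for fp16/bf16 checkpoints; gpt2 ships float32 weights.
print(model_auto.dtype, model_fp32.dtype)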
@@ -3820,8 +3841,8 @@ def from_pretrained(
                 _ = kwargs.pop(name, None)

         # For BC on torch_dtype argument
-        if torch_dtype is not None:
-            dtype = dtype if dtype is not None else torch_dtype
+        if torch_dtype is not None and (not dtype_kwarg_provided or dtype is None):
+            dtype = torch_dtype

         if is_offline_mode() and not local_files_only:
             local_files_only = True
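Tracking dtype_kwarg_provided before the pop matters here: with the new "auto" default, dtype is never None when the caller passes nothing, so without the explicit provided-check the legacy torch_dtype kwarg would silently stop working. A small self-contained sketch of the resulting precedence (an illustrative function, not the transformers API):

# Illustrative only: a stripped-down model of the kwargs handling in from_pretrained above,
# showing how the legacy torch_dtype kwarg keeps working despite the new "auto" default.
def resolve_dtype(**kwargs):
    dtype_kwarg_provided = "dtype" in kwargs
    dtype = kwargs.pop("dtype", "auto")
    torch_dtype = kwargs.pop("torch_dtype", None)  # kept for BC

    # For BC on torch_dtype argument
    if torch_dtype is not None and (not dtype_kwarg_provided or dtype is None):
        dtype = torch_dtype
    return dtype


print(resolve_dtype())                                        # "auto"    (new default)
print(resolve_dtype(torch_dtype="float16"))                   # "float16" (legacy kwarg still honored)
print(resolve_dtype(dtype="float32", torch_dtype="float16"))  # "float32" (explicit dtype wins)
print(resolve_dtype(dtype=None, torch_dtype="float16"))       # "float16" (explicit None defers to torch_dtype)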
