@@ -482,9 +482,11 @@ def test_model_from_config_torch_dtype_str(self):
         # test that from_pretrained works with torch_dtype being strings like "float32" for PyTorch backend
         model = AutoModel.from_pretrained(TINY_T5, torch_dtype="float32")
         self.assertEqual(model.dtype, torch.float32)
+        self.assertIsInstance(model.config.torch_dtype, torch.dtype)

         model = AutoModel.from_pretrained(TINY_T5, torch_dtype="float16")
         self.assertEqual(model.dtype, torch.float16)
+        self.assertIsInstance(model.config.torch_dtype, torch.dtype)

         # torch.set_default_dtype() supports only float dtypes, so will fail with non-float type
         with self.assertRaises(ValueError):
@@ -495,14 +497,22 @@ def test_model_from_config_torch_dtype_composite(self):
         Test that from_pretrained works with torch_dtype being as a dict per each sub-config in composite config
         Tiny-Llava has saved auto dtype as `torch.float32` for all modules.
         """
+        # Load without dtype specified
+        model = LlavaForConditionalGeneration.from_pretrained(TINY_LLAVA)
+        self.assertEqual(model.language_model.dtype, torch.float32)
+        self.assertEqual(model.vision_tower.dtype, torch.float32)
+        self.assertIsInstance(model.config.torch_dtype, torch.dtype)
+
         # should be able to set torch_dtype as a simple string and the model loads it correctly
         model = LlavaForConditionalGeneration.from_pretrained(TINY_LLAVA, torch_dtype="float32")
         self.assertEqual(model.language_model.dtype, torch.float32)
         self.assertEqual(model.vision_tower.dtype, torch.float32)
+        self.assertIsInstance(model.config.torch_dtype, torch.dtype)

         model = LlavaForConditionalGeneration.from_pretrained(TINY_LLAVA, torch_dtype=torch.float16)
         self.assertEqual(model.language_model.dtype, torch.float16)
         self.assertEqual(model.vision_tower.dtype, torch.float16)
+        self.assertIsInstance(model.config.torch_dtype, torch.dtype)

         # should be able to set torch_dtype as a dict for each sub-config
         model = LlavaForConditionalGeneration.from_pretrained(
@@ -511,6 +521,7 @@ def test_model_from_config_torch_dtype_composite(self):
         self.assertEqual(model.language_model.dtype, torch.float32)
         self.assertEqual(model.vision_tower.dtype, torch.float16)
         self.assertEqual(model.multi_modal_projector.linear_1.weight.dtype, torch.bfloat16)
+        self.assertIsInstance(model.config.torch_dtype, torch.dtype)

         # should be able to set the values as torch.dtype (not str)
         model = LlavaForConditionalGeneration.from_pretrained(
@@ -519,6 +530,7 @@ def test_model_from_config_torch_dtype_composite(self):
         self.assertEqual(model.language_model.dtype, torch.float32)
         self.assertEqual(model.vision_tower.dtype, torch.float16)
         self.assertEqual(model.multi_modal_projector.linear_1.weight.dtype, torch.bfloat16)
+        self.assertIsInstance(model.config.torch_dtype, torch.dtype)

         # should be able to set the values in configs directly and pass it to `from_pretrained`
         config = copy.deepcopy(model.config)
@@ -529,13 +541,15 @@ def test_model_from_config_torch_dtype_composite(self):
         self.assertEqual(model.language_model.dtype, torch.float32)
         self.assertEqual(model.vision_tower.dtype, torch.bfloat16)
         self.assertEqual(model.multi_modal_projector.linear_1.weight.dtype, torch.float16)
+        self.assertIsInstance(model.config.torch_dtype, torch.dtype)

         # but if the model has `_keep_in_fp32_modules` then those modules should be in fp32 no matter what
         LlavaForConditionalGeneration._keep_in_fp32_modules = ["multi_modal_projector"]
         model = LlavaForConditionalGeneration.from_pretrained(TINY_LLAVA, config=config, torch_dtype="auto")
         self.assertEqual(model.language_model.dtype, torch.float32)
         self.assertEqual(model.vision_tower.dtype, torch.bfloat16)
         self.assertEqual(model.multi_modal_projector.linear_1.weight.dtype, torch.float32)
+        self.assertIsInstance(model.config.torch_dtype, torch.dtype)

         # torch.set_default_dtype() supports only float dtypes, so will fail with non-float type
         with self.assertRaises(ValueError):
0 commit comments