@@ -7,6 +7,7 @@
 import ctypes
 import gc
 import io
+import os
 import pickle
 import queue
 import sys
@@ -574,6 +575,13 @@ def test_serialization_array_with_storage(self):
         self.assertEqual(q_copy[3], torch.cuda.IntStorage(10).fill_(10))

     def test_cublas_allow_tf32_get_set(self):
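+        # If TORCH_ALLOW_TF32_CUBLAS_OVERRIDE is set to a truthy value, TF32 is
+        # forced on, so just assert that and skip the toggle round-trip below.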
+        skip_tf32_cublas = 'TORCH_ALLOW_TF32_CUBLAS_OVERRIDE' in os.environ and \
+            int(os.environ['TORCH_ALLOW_TF32_CUBLAS_OVERRIDE'])
+        if skip_tf32_cublas:
+            self.assertTrue(torch.backends.cuda.matmul.allow_tf32)
+            return
+
         orig = torch.backends.cuda.matmul.allow_tf32
         self.assertEqual(torch._C._get_cublas_allow_tf32(), orig)
         torch.backends.cuda.matmul.allow_tf32 = not orig
@@ -582,14 +589,21 @@ def test_cublas_allow_tf32_get_set(self):

     def test_float32_matmul_precision_get_set(self):
         self.assertEqual(torch.get_float32_matmul_precision(), 'highest')
-        self.assertFalse(torch.backends.cuda.matmul.allow_tf32, False)
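+        # With TORCH_ALLOW_TF32_CUBLAS_OVERRIDE forcing TF32 on, allow_tf32 no
+        # longer tracks the float32 matmul precision, so skip those assertions.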
+        skip_tf32_cublas = 'TORCH_ALLOW_TF32_CUBLAS_OVERRIDE' in os.environ and \
+            int(os.environ['TORCH_ALLOW_TF32_CUBLAS_OVERRIDE'])
+        if not skip_tf32_cublas:
+            self.assertFalse(torch.backends.cuda.matmul.allow_tf32)
         for p in ('medium', 'high'):
             torch.set_float32_matmul_precision(p)
             self.assertEqual(torch.get_float32_matmul_precision(), p)
-            self.assertTrue(torch.backends.cuda.matmul.allow_tf32, True)
+            if not skip_tf32_cublas:
+                self.assertTrue(torch.backends.cuda.matmul.allow_tf32)
         torch.set_float32_matmul_precision('highest')
         self.assertEqual(torch.get_float32_matmul_precision(), 'highest')
-        self.assertFalse(torch.backends.cuda.matmul.allow_tf32, False)
+        if not skip_tf32_cublas:
+            self.assertFalse(torch.backends.cuda.matmul.allow_tf32)

     def test_cublas_allow_fp16_reduced_precision_reduction_get_set(self):
         orig = torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction