Skip to content

Commit 76b952b

Browse files
eqy authored and pytorchmergebot committed
[CUBLAS][TF32] Skip test_cublas_allow_tf32_get_set if TORCH_ALLOW_TF32_CUBLAS_OVERRIDE is set (pytorch#77298)
Follow-up to pytorch#77114 to prevent test breakages when the environment variable is set. CC @xwang233 @ngimel @ptrblck Pull Request resolved: pytorch#77298 Approved by: https://github.com/xwang233, https://github.com/ngimel
1 parent 14ab3ff commit 76b952b

File tree

1 file changed

+15
-3
lines changed

1 file changed

+15
-3
lines changed

test/test_cuda.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import ctypes
88
import gc
99
import io
10+
import os
1011
import pickle
1112
import queue
1213
import sys
@@ -574,6 +575,12 @@ def test_serialization_array_with_storage(self):
574575
self.assertEqual(q_copy[3], torch.cuda.IntStorage(10).fill_(10))
575576

576577
def test_cublas_allow_tf32_get_set(self):
578+
skip_tf32_cublas = 'TORCH_ALLOW_TF32_CUBLAS_OVERRIDE' in os.environ and\
579+
int(os.environ['TORCH_ALLOW_TF32_CUBLAS_OVERRIDE'])
580+
if skip_tf32_cublas:
581+
self.assertTrue(torch.backends.cuda.matmul.allow_tf32)
582+
return
583+
577584
orig = torch.backends.cuda.matmul.allow_tf32
578585
self.assertEqual(torch._C._get_cublas_allow_tf32(), orig)
579586
torch.backends.cuda.matmul.allow_tf32 = not orig
@@ -582,14 +589,19 @@ def test_cublas_allow_tf32_get_set(self):
582589

583590
def test_float32_matmul_precision_get_set(self):
584591
self.assertEqual(torch.get_float32_matmul_precision(), 'highest')
585-
self.assertFalse(torch.backends.cuda.matmul.allow_tf32, False)
592+
skip_tf32_cublas = 'TORCH_ALLOW_TF32_CUBLAS_OVERRIDE' in os.environ and\
593+
int(os.environ['TORCH_ALLOW_TF32_CUBLAS_OVERRIDE'])
594+
if not skip_tf32_cublas:
595+
self.assertFalse(torch.backends.cuda.matmul.allow_tf32)
586596
for p in ('medium', 'high'):
587597
torch.set_float32_matmul_precision(p)
588598
self.assertEqual(torch.get_float32_matmul_precision(), p)
589-
self.assertTrue(torch.backends.cuda.matmul.allow_tf32, True)
599+
if not skip_tf32_cublas:
600+
self.assertTrue(torch.backends.cuda.matmul.allow_tf32)
590601
torch.set_float32_matmul_precision('highest')
591602
self.assertEqual(torch.get_float32_matmul_precision(), 'highest')
592-
self.assertFalse(torch.backends.cuda.matmul.allow_tf32, False)
603+
if not skip_tf32_cublas:
604+
self.assertFalse(torch.backends.cuda.matmul.allow_tf32)
593605

594606
def test_cublas_allow_fp16_reduced_precision_reduction_get_set(self):
595607
orig = torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction

0 commit comments

Comments (0)