Commit 5b8e2ad

test_distributed cuda tests don't skip if cuda not available. (pytorch#2476)
1 parent 661beb3 commit 5b8e2ad

1 file changed (+32 -3 lines)

test/test_distributed.py

Lines changed: 32 additions & 3 deletions
@@ -8,6 +8,7 @@
 from contextlib import contextmanager

 import torch
+import torch.cuda
 import torch.distributed as dist
 from common import TestCase

@@ -22,6 +23,20 @@
     print('Distributed not available, skipping tests')
     sys.exit(0)

+SKIP_IF_NO_CUDA_EXIT_CODE = 75
+
+
+def skip_if_no_cuda_distributed(func):
+    func.skip_if_no_cuda_distributed = True
+
+    @wraps(func)
+    def wrapper(*args, **kwargs):
+        if not torch.cuda.is_available():
+            sys.exit(SKIP_IF_NO_CUDA_EXIT_CODE)
+
+        return func(*args, **kwargs)
+    return wrapper
+

 @contextmanager
 def _lock():
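The decorator above tags the test function with an attribute the parent process can later inspect, and short-circuits the child process with a sentinel exit code when CUDA is missing (75 happens to be EX_TEMPFAIL in sysexits.h, though the commit does not say whether the value was chosen for that reason). A minimal standalone sketch of the same pattern, with illustrative names (skip_if_unavailable, SKIP_EXIT_CODE, run_cuda_work) that are not from the commit:

import sys
from functools import wraps

SKIP_EXIT_CODE = 75  # mirrors SKIP_IF_NO_CUDA_EXIT_CODE in the diff


def skip_if_unavailable(check):
    """Exit the current process with SKIP_EXIT_CODE when check() is False."""
    def decorator(func):
        func.skippable = True  # marker a parent process can inspect

        @wraps(func)
        def wrapper(*args, **kwargs):
            if not check():
                # signal "skipped", not "failed", to whoever reaps this process
                sys.exit(SKIP_EXIT_CODE)
            return func(*args, **kwargs)
        return wrapper
    return decorator


@skip_if_unavailable(lambda: False)  # pretend the resource is never available
def run_cuda_work():
    print("would run CUDA work here")


if __name__ == "__main__":
    run_cuda_work()  # the process exits with status 75, not 0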
@@ -228,6 +243,7 @@ def test_broadcast(self):
         self._test_broadcast_helper(group, group_id, rank)

     @unittest.skipIf(BACKEND != 'gloo', "Only Gloo backend supports CUDA allReduce")
+    @skip_if_no_cuda_distributed
     def test_broadcast_cuda(self):
         group, group_id, rank = self._init_global_test()
         self._test_broadcast_helper(group, group_id, rank, True)
@@ -333,6 +349,7 @@ def test_all_reduce_sum(self):
         )

     @unittest.skipIf(BACKEND != 'gloo', "Only Gloo backend supports CUDA allReduce")
+    @skip_if_no_cuda_distributed
     def test_all_reduce_sum_cuda(self):
         group, group_id, rank = self._init_global_test()
         self._test_all_reduce_helper(
@@ -487,7 +504,7 @@ def manager_join(fn):
     @wraps(fn)
     def wrapper(self):
         if self.rank == self.MANAGER_PROCESS_RANK:
-            self._join_and_reduce()
+            self._join_and_reduce(fn)
         else:
             fn(self)
     return wrapper
@@ -533,10 +550,22 @@ def _run(self, rank):
         getattr(self, self.id().split(".")[2])()
         sys.exit(0)

-    def _join_and_reduce(self):
+    def _join_and_reduce(self, fn):
+        skip_ok = getattr(fn, "skip_if_no_cuda_distributed", False)
         for p in self.processes:
             p.join(self.JOIN_TIMEOUT)
-            self.assertEqual(p.exitcode, 0)
+            if not skip_ok:
+                self.assertEqual(p.exitcode, 0)
+
+        if skip_ok:
+            first_process = self.processes[0]
+            # do this first so we don't give an error message about mismatched exit codes if the first isn't valid
+            assert first_process.exitcode == 0 or first_process.exitcode == SKIP_IF_NO_CUDA_EXIT_CODE
+
+            for p in self.processes:
+                self.assertEqual(p.exitcode, first_process.exitcode)
+            if first_process.exitcode == SKIP_IF_NO_CUDA_EXIT_CODE:
+                raise unittest.SkipTest("cuda is not available")

 elif BACKEND == 'mpi':
     dist.init_process_group(init_method=INIT_METHOD, backend='mpi')
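The parent-side check in the last hunk requires all workers to agree: either every process exited 0, or every process exited with the sentinel code, in which case the whole test is reported as skipped rather than failed. A self-contained sketch of that logic using multiprocessing directly instead of the test harness, with illustrative names (worker, join_and_check) not taken from the commit:

import sys
import unittest
from multiprocessing import Process

SKIP_EXIT_CODE = 75


def worker(cuda_available):
    if not cuda_available:
        sys.exit(SKIP_EXIT_CODE)  # signal "skipped", not "failed"
    sys.exit(0)


def join_and_check(processes):
    for p in processes:
        p.join()
    first = processes[0].exitcode
    # a valid run is either all-pass (0) or all-skip (75); check the
    # first process before comparing so a bad code gets a clear error
    assert first in (0, SKIP_EXIT_CODE)
    for p in processes:
        assert p.exitcode == first, "workers disagreed on exit code"
    if first == SKIP_EXIT_CODE:
        raise unittest.SkipTest("cuda is not available")


if __name__ == "__main__":
    procs = [Process(target=worker, args=(False,)) for _ in range(2)]
    for p in procs:
        p.start()
    try:
        join_and_check(procs)
    except unittest.SkipTest as e:
        print("skipped:", e)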
