 from contextlib import contextmanager

 import torch
+import torch.cuda
 import torch.distributed as dist
 from common import TestCase

@@ -22,6 +23,20 @@
     print('Distributed not available, skipping tests')
     sys.exit(0)

+SKIP_IF_NO_CUDA_EXIT_CODE = 75
+
+
+def skip_if_no_cuda_distributed(func):
+    func.skip_if_no_cuda_distributed = True
+
+    @wraps(func)
+    def wrapper(*args, **kwargs):
+        if not torch.cuda.is_available():
+            sys.exit(SKIP_IF_NO_CUDA_EXIT_CODE)
+
+        return func(*args, **kwargs)
+    return wrapper
+

 @contextmanager
 def _lock():
@@ -228,6 +243,7 @@ def test_broadcast(self):
         self._test_broadcast_helper(group, group_id, rank)

     @unittest.skipIf(BACKEND != 'gloo', "Only Gloo backend supports CUDA allReduce")
+    @skip_if_no_cuda_distributed
     def test_broadcast_cuda(self):
         group, group_id, rank = self._init_global_test()
         self._test_broadcast_helper(group, group_id, rank, True)
@@ -333,6 +349,7 @@ def test_all_reduce_sum(self):
         )

     @unittest.skipIf(BACKEND != 'gloo', "Only Gloo backend supports CUDA allReduce")
+    @skip_if_no_cuda_distributed
     def test_all_reduce_sum_cuda(self):
         group, group_id, rank = self._init_global_test()
         self._test_all_reduce_helper(
@@ -487,7 +504,7 @@ def manager_join(fn):
             @wraps(fn)
             def wrapper(self):
                 if self.rank == self.MANAGER_PROCESS_RANK:
-                    self._join_and_reduce()
+                    self._join_and_reduce(fn)
                 else:
                     fn(self)
             return wrapper
@@ -533,10 +550,22 @@ def _run(self, rank):
             getattr(self, self.id().split(".")[2])()
             sys.exit(0)

-        def _join_and_reduce(self):
+        def _join_and_reduce(self, fn):
+            skip_ok = getattr(fn, "skip_if_no_cuda_distributed", False)
             for p in self.processes:
                 p.join(self.JOIN_TIMEOUT)
-                self.assertEqual(p.exitcode, 0)
+                if not skip_ok:
+                    self.assertEqual(p.exitcode, 0)
+
+            if skip_ok:
+                first_process = self.processes[0]
+                # do this first so we don't give an error message about mismatched exit codes if the first isn't valid
+                assert first_process.exitcode == 0 or first_process.exitcode == SKIP_IF_NO_CUDA_EXIT_CODE
+
+                for p in self.processes:
+                    self.assertEqual(p.exitcode, first_process.exitcode)
+                if first_process.exitcode == SKIP_IF_NO_CUDA_EXIT_CODE:
+                    raise unittest.SkipTest("cuda is not available")

 elif BACKEND == 'mpi':
     dist.init_process_group(init_method=INIT_METHOD, backend='mpi')
|
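For reference, the pattern this diff introduces works around the fact that a worker process cannot raise unittest.SkipTest across a process boundary: the worker exits with a reserved code (SKIP_IF_NO_CUDA_EXIT_CODE = 75), and the manager process converts a unanimous skip code into a SkipTest. Below is a minimal, self-contained sketch of the same idea with hypothetical names (SKIP_EXIT_CODE, skip_if_unavailable, worker); it is an illustration under those assumptions, not the code added by this patch.

import sys
import unittest
from functools import wraps
from multiprocessing import Process

# Hypothetical names; the patch above uses SKIP_IF_NO_CUDA_EXIT_CODE and
# skip_if_no_cuda_distributed instead.
SKIP_EXIT_CODE = 75  # reserved exit code meaning "environment cannot run this test"


def skip_if_unavailable(predicate):
    """Make the calling process exit with SKIP_EXIT_CODE when predicate() is false."""
    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            if not predicate():
                sys.exit(SKIP_EXIT_CODE)
            return func(*args, **kwargs)
        return wrapper
    return decorator


def worker(available):
    # Stand-in for one rank of a multi-process test.
    @skip_if_unavailable(lambda: available)
    def body():
        pass
    body()
    sys.exit(0)


class SkipByExitCodeTest(unittest.TestCase):
    def _join_and_check(self, procs, skip_allowed):
        for p in procs:
            p.join()
        first = procs[0].exitcode
        if skip_allowed and first == SKIP_EXIT_CODE:
            # Only skip when every rank agreed that the capability is missing.
            self.assertTrue(all(p.exitcode == SKIP_EXIT_CODE for p in procs))
            raise unittest.SkipTest("capability not available")
        for p in procs:
            self.assertEqual(p.exitcode, 0)

    def test_workers_report_skip(self):
        # With available=False every worker exits with the skip code,
        # so the whole test is reported as skipped rather than failed.
        procs = [Process(target=worker, args=(False,)) for _ in range(2)]
        for p in procs:
            p.start()
        self._join_and_check(procs, skip_allowed=True)


if __name__ == '__main__':
    unittest.main()

In the actual patch the "skipping is allowed here" flag is recorded as an attribute on the decorated test function (skip_if_no_cuda_distributed) and the test function itself is forwarded to the manager, which is why manager_join now calls self._join_and_reduce(fn).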