|
7 | 7 | from lightning.fabric.plugins.environments import LightningEnvironment |
8 | 8 | from lightning.fabric.strategies import DDPStrategy |
9 | 9 | from lightning.fabric.strategies.launchers.multiprocessing import _MultiProcessingLauncher |
10 | | -from lightning.fabric.utilities.distributed import _gather_all_tensors |
| 10 | +from lightning.fabric.utilities.distributed import _gather_all_tensors, _sync_ddp |
11 | 11 | from tests_fabric.helpers.runif import RunIf |
12 | 12 |
|
13 | 13 |
|
@@ -62,20 +62,47 @@ def _test_all_gather_uneven_tensors_multidim(strategy): |
62 | 62 | assert (val == torch.ones_like(val)).all() |
63 | 63 |
|
64 | 64 |
|
def _test_all_reduce(strategy):
    """Check that ``_sync_ddp`` reduces correctly across ranks and operates in-place.

    Runs inside each spawned worker process; ``strategy`` supplies the worker's
    rank, device, and world size. Each rank contributes the value ``rank + 1``,
    so with ranks ``0..world_size-1`` the reduced results are known in closed form.
    """
    rank = strategy.local_rank
    device = strategy.root_device
    world_size = strategy.num_processes

    for dtype in (torch.long, torch.int, torch.float, torch.half):
        # max: contributions are 1..world_size, so the maximum is world_size
        # (was hard-coded to 2, which only held for a 2-process world)
        tensor = torch.tensor(rank + 1, device=device, dtype=dtype)
        expected = torch.tensor(world_size, device=device, dtype=dtype)
        result = _sync_ddp(tensor, reduce_op="max")
        assert torch.equal(result, expected)
        assert result is tensor  # inplace
        # sum: 1 + 2 + ... + world_size
        tensor = torch.tensor(rank + 1, device=device, dtype=dtype)
        expected = torch.tensor(sum(range(1, world_size + 1)), device=device, dtype=dtype)
        result = _sync_ddp(tensor, reduce_op="sum")
        assert torch.equal(result, expected)
        assert result is tensor  # inplace
        # average: the sum divided by the world size (was hard-coded `/ 2`);
        # for integer dtypes the expected tensor truncates, matching the original
        tensor = torch.tensor(rank + 1, device=device, dtype=dtype)
        expected = torch.tensor(sum(range(1, world_size + 1)) / world_size, device=device, dtype=dtype)
        result = _sync_ddp(tensor, reduce_op="avg")
        assert torch.equal(result, expected)
        assert result is tensor  # inplace
| 89 | + |
| 90 | + |
@RunIf(skip_windows=True)
@pytest.mark.parametrize(
    "process",
    [
        _test_all_gather_uneven_tensors_multidim,
        _test_all_gather_uneven_tensors,
        _test_all_reduce,
    ],
)
@pytest.mark.parametrize(
    "devices",
    [
        pytest.param(
            [torch.device("cuda:0"), torch.device("cuda:1")],
            marks=RunIf(min_cuda_gpus=2),
        ),
        pytest.param([torch.device("cpu"), torch.device("cpu")]),
    ],
)
def test_collective_operations(devices, process):
    """Spawn one worker per device and run the given collective-op check in each."""
    spawn_launch(process, devices)
0 commit comments