
Commit cfefd09

awaelchli authored and lexierule committed
Make all_reduce consistent for both NCCL and GLOO (#18235)
Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com> (cherry picked from commit 70e31b6)
1 parent 03f57f9 commit cfefd09

File tree

6 files changed: +57 -18 lines changed

src/lightning/fabric/CHANGELOG.md

Lines changed: 2 additions & 0 deletions

@@ -22,6 +22,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed an issue that would prevent the user to set the multiprocessing start method after importing lightning ([#18177](https://github.com/Lightning-AI/lightning/pull/18177))


+- Fixed an issue with `Fabric.all_reduce()` not performing an inplace operation for all backends consistently ([#18235](https://github.com/Lightning-AI/lightning/pull/18235))
+

 ## [2.0.5] - 2023-07-07

src/lightning/fabric/fabric.py

Lines changed: 5 additions & 2 deletions

@@ -531,10 +531,13 @@ def all_reduce(
     ) -> Union[Tensor, Dict, List, Tuple]:
         """Reduce tensors or collections of tensors from multiple processes.

-        This method needs to be called on all processes. Failing to do so will cause your program to stall forever.
+        The reduction on tensors is applied in-place, meaning the result will be placed back into the input tensor.
+        This method needs to be called on all processes and the tensors need to have the same shape across all
+        processes, otherwise your program will stall forever.

         Args:
-            data: int, float, tensor of shape (batch, ...), or a (possibly nested) collection thereof.
+            data: int, float, tensor of shape (batch, ...), or a (possibly nested) collection thereof. Tensor will be
+                modified in-place.
             group: the process group to reduce results across. Defaults to all processes (world).
             reduce_op: the reduction operation. Defaults to 'mean'. Can also be a string 'sum' or ReduceOp.
                 Some strategies may limit the choices here.
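To illustrate the documented behavior, here is a minimal sketch of calling `Fabric.all_reduce` on a bare tensor in a two-process CPU run. The `Fabric(...)` arguments and the `run` function are illustrative only and are not part of this commit:

    import torch
    from lightning.fabric import Fabric

    def run(fabric):
        t = torch.tensor(float(fabric.global_rank + 1))   # 1.0 on rank 0, 2.0 on rank 1
        original = t.clone()                               # keep a copy if the pre-reduction value is needed
        fabric.all_reduce(t, reduce_op="mean")
        # Per the docstring above, the reduced result is written back into the input tensor.
        assert t.item() == 1.5
        assert original.item() == fabric.global_rank + 1.0

    if __name__ == "__main__":
        fabric = Fabric(accelerator="cpu", devices=2, strategy="ddp_spawn")
        fabric.launch(run)

Cloning before the call, as sketched here, is the way to keep the pre-reduction value around now that the reduction mutates its input.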

src/lightning/fabric/utilities/distributed.py

Lines changed: 18 additions & 11 deletions

@@ -113,7 +113,9 @@ def _sync_ddp_if_available(


 def _sync_ddp(result: Tensor, group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = None) -> Tensor:
-    """Function to reduce the tensors from several DDP processes to one main process.
+    """Reduces a tensor across several distributed processes.
+
+    This operation is performed in-place, meaning the result will be placed back into the input tensor on all processes.

     Args:
         result: The value to sync and reduce (typically tensor or number)
@@ -122,25 +124,26 @@ def _sync_ddp(result: Tensor, group: Optional[Any] = None, reduce_op: Optional[U
             Can also be a string of 'avg', 'mean' to calculate the mean during reduction.

     Return:
-        reduced value
+        The reduced value.

     """
     divide_by_world_size = False
-
-    if group is None:
-        group = torch.distributed.group.WORLD
+    group = torch.distributed.group.WORLD if group is None else group

     op: Optional[ReduceOp]
     if isinstance(reduce_op, str):
-        if reduce_op.lower() in ("avg", "mean"):
+        reduce_op = "avg" if reduce_op == "mean" else reduce_op
+        if reduce_op.lower() == "avg" and torch.distributed.get_backend(group) == "gloo":
+            # The GLOO backend does not support the `ReduceOp.AVG` operation
             op = ReduceOp.SUM  # type: ignore[assignment]
             divide_by_world_size = True
         else:
             op = getattr(ReduceOp, reduce_op.upper())
     else:
         op = reduce_op

-    # WA for HPU. HPU doesn't support Long types, forcefully set it to float
+    # HPU doesn't support Long types, forcefully set it to float
+    # TODO: move this to the `lightning_habana` package
     if (
         package_available("habana_frameworks")
         and os.environ.get("HCCL_DISTRIBUTED_BACKEND") == "1"
@@ -156,11 +159,15 @@ def _sync_ddp(result: Tensor, group: Optional[Any] = None, reduce_op: Optional[U
     # Sync all processes before reduction
     torch.distributed.barrier(group=group)
     torch.distributed.all_reduce(result, op=op, group=group, async_op=False)
+    world_size = torch.distributed.get_world_size(group)

-    if divide_by_world_size:
-        result = result / torch.distributed.get_world_size(group)
-
-    return result
+    if not divide_by_world_size:
+        return result
+    # `torch.distributed.all_reduce` is in-place, so we should do the division in-place to leave the modified tensors
+    # with the expected value
+    if not torch.is_floating_point(result):
+        return result.copy_(result / world_size)
+    return result.div_(world_size)


 class _AllGather(torch.autograd.Function):
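Two details in the last hunk are worth spelling out. GLOO has no `ReduceOp.AVG`, so the mean is emulated with a SUM followed by a division by the world size, and that division must itself happen in-place to keep the input tensor consistent with the returned value. The `copy_` branch exists because in-place true division is not defined for integer tensors in PyTorch; a small standalone sketch of that distinction (independent of the diff):

    import torch

    world_size = 2

    # Float tensors: true division can be applied in-place directly.
    f = torch.tensor(3.0)
    f.div_(world_size)             # f now holds 1.5

    # Integer tensors: in-place true division would have to change the dtype, so it raises.
    i = torch.tensor(3)
    try:
        i.div_(world_size)
    except RuntimeError:
        i.copy_(i / world_size)    # compute out-of-place, then copy back into the same storage

    assert f.item() == 1.5
    assert i.item() == 1           # the copy truncates the float result back to an integer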

src/lightning/pytorch/trainer/connectors/logger_connector/result.py

Lines changed: 1 addition & 1 deletion

@@ -239,7 +239,7 @@ def update(self, value: _VALUE, batch_size: int) -> None:

     def compute(self) -> Tensor:
         if self.is_tensor:
-            value = self.meta.sync(self.value)
+            value = self.meta.sync(self.value.clone())  # `clone` because `sync` is in-place
             if self.meta.is_mean_reduction:
                 cumulated_batch_size = self.meta.sync(self.cumulated_batch_size)
                 return value / cumulated_batch_size
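The `clone()` here guards the metric's running state: `sync` ultimately goes through the in-place `_sync_ddp` above, so passing `self.value` directly would overwrite the accumulator with the cross-process result. A toy sketch of the difference, with an in-place doubling standing in for a two-process sum (`sync` below is a hypothetical stand-in, not the real `_Metadata.sync`):

    import torch

    def sync(t):
        # Stand-in for the in-place all-reduce across two processes.
        return t.mul_(2)

    # Without the clone, the stored running total is overwritten by the reduction.
    value = torch.tensor(3.0)
    sync(value)
    assert value.item() == 6.0     # accumulator corrupted for any later updates

    # With the clone (as in the change above), the accumulator keeps its local total.
    value = torch.tensor(3.0)
    synced = sync(value.clone())
    assert synced.item() == 6.0    # reduced result returned to the caller
    assert value.item() == 3.0     # stored value left untouched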

tests/tests_fabric/utilities/test_distributed.py

Lines changed: 30 additions & 3 deletions

@@ -7,7 +7,7 @@
 from lightning.fabric.plugins.environments import LightningEnvironment
 from lightning.fabric.strategies import DDPStrategy
 from lightning.fabric.strategies.launchers.multiprocessing import _MultiProcessingLauncher
-from lightning.fabric.utilities.distributed import _gather_all_tensors
+from lightning.fabric.utilities.distributed import _gather_all_tensors, _sync_ddp
 from tests_fabric.helpers.runif import RunIf


@@ -62,20 +62,47 @@ def _test_all_gather_uneven_tensors_multidim(strategy):
     assert (val == torch.ones_like(val)).all()


+def _test_all_reduce(strategy):
+    rank = strategy.local_rank
+    device = strategy.root_device
+    world_size = strategy.num_processes
+
+    for dtype in (torch.long, torch.int, torch.float, torch.half):
+        # max
+        tensor = torch.tensor(rank + 1, device=device, dtype=dtype)
+        expected = torch.tensor(2, device=device, dtype=dtype)
+        result = _sync_ddp(tensor, reduce_op="max")
+        assert torch.equal(result, expected)
+        assert result is tensor  # inplace
+        # sum
+        tensor = torch.tensor(rank + 1, device=device, dtype=dtype)
+        expected = torch.tensor(sum(range(1, world_size + 1)), device=device, dtype=dtype)
+        result = _sync_ddp(tensor, reduce_op="sum")
+        assert torch.equal(result, expected)
+        assert result is tensor  # inplace
+        # average
+        tensor = torch.tensor(rank + 1, device=device, dtype=dtype)
+        expected = torch.tensor(sum(range(1, world_size + 1)) / 2, device=device, dtype=dtype)
+        result = _sync_ddp(tensor, reduce_op="avg")
+        assert torch.equal(result, expected)
+        assert result is tensor  # inplace
+
+
 @RunIf(skip_windows=True)
 @pytest.mark.parametrize(
     "process",
     [
         _test_all_gather_uneven_tensors_multidim,
         _test_all_gather_uneven_tensors,
+        _test_all_reduce,
     ],
 )
 @pytest.mark.parametrize(
     "devices",
     [
         pytest.param([torch.device("cuda:0"), torch.device("cuda:1")], marks=RunIf(min_cuda_gpus=2)),
-        [torch.device("cpu")] * 2,
+        [torch.device("cpu"), torch.device("cpu")],
     ],
 )
-def test_gather_all_tensors(devices, process):
+def test_collective_operations(devices, process):
     spawn_launch(process, devices)
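A note on the hard-coded expectations in `_test_all_reduce`: each parametrized device list contains two entries, so the spawned world size is 2; that is why `max` expects 2 and the average divides the sum by 2. A quick sanity check of the arithmetic, assuming a world size of 2:

    world_size = 2
    contributions = [rank + 1 for rank in range(world_size)]        # each rank sends rank + 1
    assert max(contributions) == 2                                   # reduce_op="max"
    assert sum(contributions) == 3                                   # reduce_op="sum"
    assert sum(contributions) / world_size == 1.5                    # reduce_op="avg" (hence the "/ 2" above)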

tests/tests_pytorch/core/test_metric_result_integration.py

Lines changed: 1 addition & 1 deletion

@@ -357,7 +357,7 @@ def on_train_epoch_end(self) -> None:
         assert metrics["callback"]["tracking"] == expected
         assert computed_value == 2

-        assert self.results["training_step.tracking_2"].value == total * devices
+        assert self.results["training_step.tracking_2"].value == total
         assert metrics["callback"]["tracking_2"] == expected
         assert computed_value == 2
         self.has_validated_sum = True
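This expectation changes from `total * devices` to `total` as a direct consequence of the `clone()` added in `compute()`: the cross-process sum is no longer written back into the stored accumulator, so the raw `.value` keeps its per-process total while only the returned tensor carries the reduced result. A compressed sketch of the before/after, assuming 2 devices (an in-place doubling stands in for the sum across processes):

    import torch

    devices = 2
    total = torch.tensor(3.0)            # per-process running total

    # Before this change: sync() reduced the stored value in-place.
    stored = total.clone()
    stored.mul_(devices)                 # stand-in for the in-place sum across processes
    assert stored == total * devices     # old assertion held on the mutated accumulator

    # After this change: compute() syncs a clone, so the stored value is untouched.
    stored = total.clone()
    reduced = stored.clone().mul_(devices)
    assert stored == total               # new assertion: the raw value stays the local total
    assert reduced == total * devices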
