Commit 03881b0

pritamdamania87 authored and pytorchmergebot committed
Ensure ncclCommAbort can abort stuck ncclCommInitRank (pytorch#103264)
pytorch#95715 made it possible to abort `ncclCommInitRankConfig` by specifying `blocking=0`, which enables non-blocking behavior. However, calling `pg._abort()` did not recover from a stuck `ncclCommInitRankConfig`, since `_abort` only looked through the `devNCCLCommMap_` map and aborted those communicators. Because `ncclCommInitRankConfig` was stuck, the communicator was never added to that map, and the host thread remained blocked on this line: https://github.com/pytorch/pytorch/blob/main/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp#L1171. As a result, `_abort` was a no-op.

To resolve this, I added the communicators to `inInitializationCommMap_` as soon as they are created and removed them once they have been moved into `devNCCLCommMap_`, so `_abort` can also reach communicators that are still initializing. I also added a unit test that fails without the changes to ProcessGroupNCCL.cpp.

Pull Request resolved: pytorch#103264
Approved by: https://github.com/kwen2501
1 parent 1985c49 commit 03881b0
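
The core idea of the fix is easiest to see in isolation: keep a second map of communicators that are still inside their (possibly blocking) init call, and have the abort path walk both maps. The following is a minimal standalone sketch of that pattern, not PyTorch source; FakeComm, CommCache, getOrCreate, and abortAll are made-up names, and the real NCCL init/abort calls are replaced by placeholders.

// Minimal sketch of the two-map pattern this commit adopts (illustrative only).
#include <iostream>
#include <memory>
#include <mutex>
#include <string>
#include <unordered_map>
#include <vector>

struct FakeComm {
  bool aborted = false;
  void abort() { aborted = true; }  // stand-in for ncclCommAbort
};

class CommCache {
 public:
  // Register the communicators *before* the potentially blocking init so
  // that abortAll() can reach them even if init never completes.
  std::vector<std::shared_ptr<FakeComm>>& getOrCreate(const std::string& key) {
    auto comms = std::vector<std::shared_ptr<FakeComm>>{std::make_shared<FakeComm>()};
    {
      std::lock_guard<std::mutex> lock(mutex_);
      inInitializationCommMap_.emplace(key, comms);
    }

    // ... a potentially blocking ncclCommInitRank-style call would happen here ...

    std::lock_guard<std::mutex> lock(mutex_);
    // Another thread may have already promoted this key to the main map.
    auto it = inInitializationCommMap_.find(key);
    if (it != inInitializationCommMap_.end()) {
      devCommMap_.emplace(key, std::move(it->second));
      inInitializationCommMap_.erase(it);
    }
    return devCommMap_.at(key);
  }

  // Abort everything we know about: fully initialized communicators *and*
  // the ones still stuck in initialization.
  void abortAll() {
    std::lock_guard<std::mutex> lock(mutex_);
    for (auto* map : {&devCommMap_, &inInitializationCommMap_}) {
      for (auto& kv : *map) {
        for (auto& comm : kv.second) comm->abort();
      }
    }
  }

 private:
  std::mutex mutex_;
  std::unordered_map<std::string, std::vector<std::shared_ptr<FakeComm>>> devCommMap_;
  std::unordered_map<std::string, std::vector<std::shared_ptr<FakeComm>>> inInitializationCommMap_;
};

int main() {
  CommCache cache;
  auto& comms = cache.getOrCreate("0,1");
  cache.abortAll();
  std::cout << "aborted: " << std::boolalpha << comms.front()->aborted << "\n";
}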

3 files changed: +76 −7 lines

torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp

Lines changed: 31 additions & 7 deletions
@@ -772,12 +772,14 @@ uint64_t ProcessGroupNCCL::getSequenceNumberForGroup() {
   return seq_;
 }

-// Abort all communicators on this rank
-void ProcessGroupNCCL::abort(c10::optional<std::string> abortReason) {
-  std::lock_guard<std::mutex> lock(mutex_);
+void abortCommsFromMap(
+    std::unordered_map<std::string, std::vector<std::shared_ptr<NCCLComm>>>&
+        ncclCommsMap,
+    const int rank,
+    c10::optional<std::string> abortReason) {
   // The process may control multiple devices, loop through the communicators on
   // each device
-  for (auto& it : devNCCLCommMap_) {
+  for (auto& it : ncclCommsMap) {
     auto& devName = it.first;
     auto& ncclComms = it.second;

@@ -794,11 +796,18 @@ void ProcessGroupNCCL::abort(c10::optional<std::string> abortReason) {
     // their responsibility to destroy the process group and recreate
     // it to recover from errors.

-    LOG(INFO) << "[Rank " << rank_ << "] Destroyed " << ncclComms.size()
+    LOG(INFO) << "[Rank " << rank << "] Destroyed " << ncclComms.size()
              << "communicators on CUDA device " << devName;
   }
 }

+// Abort all communicators on this rank
+void ProcessGroupNCCL::abort(c10::optional<std::string> abortReason) {
+  std::lock_guard<std::mutex> lock(mutex_);
+  abortCommsFromMap(devNCCLCommMap_, rank_, abortReason);
+  abortCommsFromMap(inInitializationCommMap_, rank_, abortReason);
+}
+
 ProcessGroupNCCL::~ProcessGroupNCCL() {
   terminateProcessGroup_.store(true);

@@ -1160,6 +1169,11 @@ std::vector<std::shared_ptr<NCCLComm>>& ProcessGroupNCCL::getNCCLComm(
         at::cuda::getStreamFromPool(options_->is_high_priority_stream));
   }

+  {
+    std::lock_guard<std::mutex> lock(mutex_);
+    inInitializationCommMap_.emplace(devicesKey, ncclComms);
+  }
+
   // [Note 2 ]
 #ifndef NCCL_HAS_COMM_NONBLOCKING
   C10D_NCCL_CHECK(ncclGroupEnd(), c10::nullopt);
@@ -1201,8 +1215,18 @@ std::vector<std::shared_ptr<NCCLComm>>& ProcessGroupNCCL::getNCCLComm(
   ncclIdToCommMap_.emplace(buildNcclUniqueIdStr(ncclID), ncclComms);

   // Move the NCCL resource to cache
-  devNCCLCommMap_.emplace(devicesKey, std::move(ncclComms));
-  return devNCCLCommMap_[devicesKey];
+  auto it = inInitializationCommMap_.find(devicesKey);
+  // A previous thread could've already removed devicesKey from
+  // inInitializationCommMap_ and added it to devNCCLCommMap_
+  if (it != inInitializationCommMap_.end()) {
+    devNCCLCommMap_.emplace(devicesKey, std::move(it->second));
+    inInitializationCommMap_.erase(devicesKey);
+  }
+
+  it = devNCCLCommMap_.find(devicesKey);
+  TORCH_INTERNAL_ASSERT(
+      it != devNCCLCommMap_.end(), "Communicators not populated in cache!");
+  return it->second;
 }

 namespace {
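
The reason the emplace into inInitializationCommMap_ happens before the potentially blocking group-end/wait is that another thread must be able to find and abort the communicator while this thread is still stuck, which is exactly what the new test below exercises with a background _abort() thread. A standalone sketch of that interaction, assuming made-up names (StuckComm, waitForInit, the watchdog thread) and no real NCCL calls:

// Illustrative only: register the communicator first, then block; a second
// thread can then abort it and unstick the blocked initialization.
#include <chrono>
#include <condition_variable>
#include <iostream>
#include <memory>
#include <mutex>
#include <string>
#include <thread>
#include <unordered_map>

struct StuckComm {
  std::mutex m;
  std::condition_variable cv;
  bool aborted = false;

  // Simulates a non-blocking-mode init that never finishes on its own:
  // it only returns once someone aborts the communicator.
  void waitForInit() {
    std::unique_lock<std::mutex> lk(m);
    cv.wait(lk, [this] { return aborted; });
  }
  void abort() {
    { std::lock_guard<std::mutex> lk(m); aborted = true; }
    cv.notify_all();
  }
};

int main() {
  std::mutex mapMutex;
  std::unordered_map<std::string, std::shared_ptr<StuckComm>> inInitialization;

  // Register first, then block: this is the ordering the commit introduces.
  auto comm = std::make_shared<StuckComm>();
  {
    std::lock_guard<std::mutex> lk(mapMutex);
    inInitialization.emplace("0,1", comm);
  }

  // Background "watchdog" thread, analogous to the abort() loop in the new test.
  std::thread watchdog([&] {
    std::this_thread::sleep_for(std::chrono::milliseconds(200));
    std::lock_guard<std::mutex> lk(mapMutex);
    for (auto& kv : inInitialization) kv.second->abort();
  });

  comm->waitForInit();  // would hang forever if abort() could not see the comm
  std::cout << "init unblocked by abort\n";
  watchdog.join();
}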

torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp

Lines changed: 4 additions & 0 deletions
@@ -640,6 +640,10 @@ class TORCH_API ProcessGroupNCCL : public Backend {
   std::unordered_map<std::string, std::vector<std::shared_ptr<NCCLComm>>>
       devNCCLCommMap_;

+  // The NCCL communicators currently in process of being initialized.
+  std::unordered_map<std::string, std::vector<std::shared_ptr<NCCLComm>>>
+      inInitializationCommMap_;
+
   // Map from ncclUniqueId to appropriate communicator.
   std::unordered_map<std::string, std::vector<std::shared_ptr<NCCLComm>>>
       ncclIdToCommMap_;
torch/testing/_internal/distributed/distributed_test.py

Lines changed: 41 additions & 0 deletions
@@ -9725,6 +9725,47 @@ def forward(self, inp):
         ddp._check_reducer_finalized()
         ddp(input)

+    @skip_if_lt_x_gpu(2)
+    @skip_but_pass_in_sandcastle_if(
+        BACKEND != "nccl",
+        "TORCH_NCCL_USE_COMM_NONBLOCKING only applies to NCCL"
+    )
+    def test_nccl_init_abort(self):
+        """
+        Tests that we can abort a NCCL communicator during initialization and
+        recover appropriately.
+        """
+        # Reinitialize global process group with TORCH_NCCL_USE_COMM_NONBLOCKING=1
+        os.environ["TORCH_NCCL_USE_COMM_NONBLOCKING"] = "1"
+        dist.destroy_process_group()
+        timeout = timedelta(seconds=1)
+        dist.init_process_group(
+            init_method=INIT_METHOD,
+            backend=BACKEND,
+            world_size=int(os.environ["WORLD_SIZE"]),
+            rank=self.rank,
+            timeout=timeout,
+        )
+
+        # Abort pg in background thread.
+        running = True
+
+        def abort():
+            pg = _get_default_group()
+            while running:
+                pg._get_backend(torch.device(0))._abort()
+                time.sleep(1)
+
+        if self.rank != 1:
+            import threading
+            t = threading.Thread(target=abort)
+            t.start()
+            with self.assertRaises(RuntimeError):
+                # First collective triggers initialization via ncclCommInitRank.
+                torch.distributed.barrier()
+            running = False
+            t.join()
+


     @skip_if_lt_x_gpu(2)
