Skip to content

Commit 1a32391

Browse files
authored
[Dygraph] Refactoring of reducer in DataParallel (#40389)
* refactor reducer * modify cmakelists * solve conflicts * rename group and update process_group * fix bugs of ProcessGroupNCCL * modify for CIs * refactoring reducer
1 parent af6ef88 commit 1a32391

File tree

9 files changed

+736
-23
lines changed

9 files changed

+736
-23
lines changed

paddle/fluid/distributed/collective/CMakeLists.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
cc_library(processgroup SRCS ProcessGroup.cc DEPS phi phi_api eager_api)
2+
cc_library(eager_reducer SRCS reducer.cc DEPS eager_api processgroup phi phi_api)
3+
24
if (WITH_DISTRIBUTE)
35
cc_library(processgroup_gloo SRCS ProcessGroupGloo.cc DEPS phi phi_api eager_api gloo_wrapper)
46
endif()
5-
cc_library(eager_reducer SRCS reducer.cc DEPS eager_api processgroup)
67

78
if(WITH_NCCL)
89
cc_library(processgroup_nccl SRCS ProcessGroupNCCL.cc DEPS place cuda_stream enforce collective_helper device_context phi phi_api eager_api)

paddle/fluid/distributed/collective/ProcessGroupNCCL.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -88,8 +88,8 @@ void SyncDefaultStream(
8888
for (size_t i = 0; i < places.size(); ++i) {
8989
auto* default_ctx = static_cast<platform::CUDADeviceContext*>(
9090
platform::DeviceContextPool::Instance().Get(places[i]));
91-
ncclEvents[i].Record(*dev_ctx[i]);
92-
ncclEvents[i].Block(*default_ctx);
91+
ncclEvents[i].Record(*default_ctx);
92+
ncclEvents[i].Block(*dev_ctx[i]);
9393
}
9494
}
9595

0 commit comments

Comments
 (0)