Commit b55d0e5

[Custom] Add custom XCCLCommContext init in new_executor (#71357)
* modify setup wrong path
* add xccl init
* recover
* add instruction xccl init
* set stream
* modify build bug
* update custom_xccl
* add comm op in instruction
* fix py3 ci
1 parent 8532d71 commit b55d0e5

8 files changed: 187 additions, 99 deletions


paddle/fluid/distributed/collective/process_group_custom.cc (22 additions, 13 deletions)
@@ -148,7 +148,6 @@ phi::DeviceContext* ProcessGroupCustom::GetDeviceContext(
     return iter->second.get();
   }
 }
-
 phi::ccl::CCLComm ProcessGroupCustom::XCCLComm(const Place& place) {
   const std::string& key = GetKeyFromPlace(place);
   phi::DeviceGuard guard(place);
@@ -164,6 +163,16 @@ phi::ccl::CCLComm ProcessGroupCustom::XCCLComm(const Place& place) {
   return iter->second->xccl_comm();
 }
 
+phi::distributed::XCCLCommContext* ProcessGroupCustom::GetOrCreateCommContext(
+    const Place& place) {
+  const std::string& key = GetKeyFromPlace(place);
+  phi::DeviceGuard guard(place);
+  if (place_to_comm_ctx_.find(key) == place_to_comm_ctx_.end()) {
+    CreateXCCLEnvCache(place, key);
+  }
+  return this->GetCommContext();
+}
+
 std::string ProcessGroupCustom::GetCommName(int rank) {
   PADDLE_ENFORCE_GE(rank,
                     0,
@@ -603,24 +612,24 @@ void ProcessGroupCustom::CreateXCCLEnvCache(const Place& place,
 
   auto* calc_ctx = static_cast<phi::CustomContext*>(
      phi::DeviceContextPool::Instance().Get(place));
-  auto comm_ctx = std::make_unique<phi::CustomContext>(place);
-  comm_ctx->SetAllocator(
+  auto custom_context = std::make_unique<phi::CustomContext>(place);
+  custom_context->SetAllocator(
      &(phi::DeviceContextPool::Instance().Get(place)->GetAllocator()));
-  comm_ctx->SetHostAllocator(
+  custom_context->SetHostAllocator(
      &(phi::DeviceContextPool::Instance().Get(place)->GetHostAllocator()));
-  comm_ctx->SetZeroAllocator(
+  custom_context->SetZeroAllocator(
      &(phi::DeviceContextPool::Instance().Get(place)->GetZeroAllocator()));
-  comm_ctx->SetHostZeroAllocator(
+  custom_context->SetHostZeroAllocator(
      &(phi::DeviceContextPool::Instance().Get(place)->GetHostZeroAllocator()));
 
   auto xccl_comm_ctx = this->GetCommContext();
-  comm_ctx->set_xccl_comm(xccl_comm_ctx->GetXcclComm());
+  custom_context->set_xccl_comm(xccl_comm_ctx->GetXcclComm());
 
   auto xccl_event = std::make_unique<phi::event::Event>();
   xccl_event->Init(place);
   place_to_calc_event_.emplace(place_key, std::move(xccl_event));
   place_to_calc_ctx_.emplace(place_key, calc_ctx);
-  place_to_comm_ctx_.emplace(place_key, std::move(comm_ctx));
+  place_to_comm_ctx_.emplace(place_key, std::move(custom_context));
 
   // TODO(sunyilun): for compatibility, will be removed later
   std::vector<phi::CustomContext*> comm_ctx_wrapper{
@@ -632,9 +641,9 @@ void ProcessGroupCustom::SyncCalcStream(const Place& place) {
   const std::string& key = GetKeyFromPlace(place);
   auto& calc_event = place_to_calc_event_.at(key);
   const auto* calc_ctx = place_to_calc_ctx_.at(key);
-  const auto* comm_ctx = place_to_comm_ctx_.at(key).get();
+  const auto* custom_context = place_to_comm_ctx_.at(key).get();
   calc_event->Record(calc_ctx->GetStream().get());
-  comm_ctx->GetStream()->WaitEvent(calc_event.get());
+  custom_context->GetStream()->WaitEvent(calc_event.get());
 }
 
 std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::RunFnInXCCLEnv(
@@ -663,9 +672,9 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::RunFnInXCCLEnv(
   auto task = CreateTask(place, rank_, comm_type, sync_op, use_calc_stream);
 
   const auto* calc_ctx = place_to_calc_ctx_.at(key);
-  const auto& comm_ctx = place_to_comm_ctx_.at(key);
+  const auto& custom_context = place_to_comm_ctx_.at(key);
   auto& xccl_stream =
-      use_calc_stream ? *calc_ctx->GetStream() : *comm_ctx->GetStream();
+      use_calc_stream ? *calc_ctx->GetStream() : *custom_context->GetStream();
   fn(xccl_stream);
 
   if (!use_calc_stream) {
@@ -674,7 +683,7 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::RunFnInXCCLEnv(
         memory::RecordStream(tensors[i].Holder(), xccl_stream.raw_stream());
       }
     }
-    task->UpdateWaitChain(*comm_ctx);
+    task->UpdateWaitChain(*custom_context);
   }
 
   return task;

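The new GetOrCreateCommContext makes comm-context setup lazy: on a cache miss for the place key it first builds the full XCCL environment through CreateXCCLEnvCache, then returns the managed XCCLCommContext, so callers no longer need a prior collective to have warmed the cache. Below is a minimal, self-contained sketch of this get-or-create pattern; the Cache and CommContext names are stand-ins for illustration, not Paddle APIs.

#include <map>
#include <memory>
#include <string>

// Stand-in for phi::distributed::XCCLCommContext.
struct CommContext {};

class Cache {
 public:
  // Mirrors GetOrCreateCommContext: build the per-key environment once,
  // then always hand back the cached instance.
  CommContext* GetOrCreate(const std::string& key) {
    auto it = contexts_.find(key);
    if (it == contexts_.end()) {
      // Cache miss: this is where CreateXCCLEnvCache would set up the
      // CustomContext, event, and stream bindings for the place.
      it = contexts_.emplace(key, std::make_unique<CommContext>()).first;
    }
    return it->second.get();
  }

 private:
  std::map<std::string, std::unique_ptr<CommContext>> contexts_;
};

Keying by the place string lets one process group hold an independent context per device while guaranteeing each is created at most once.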
paddle/fluid/distributed/collective/process_group_custom.h (5 additions, 1 deletion)
@@ -160,13 +160,15 @@ class ProcessGroupCustom final : public ProcessGroupWithStream {
                                            bool sync_op,
                                            bool use_calc_stream) override;
 
+  using ProcessGroupWithStream::Recv;
+
   std::shared_ptr<ProcessGroup::Task> Recv(phi::DenseTensor* tensor,
                                            int src_rank,
                                            int64_t offset,
                                            int64_t numel,
                                            bool sync_op,
                                            bool use_calc_stream) override;
-
+  using ProcessGroupWithStream::Send;
   std::shared_ptr<ProcessGroup::Task> Send(const phi::DenseTensor& tensor,
                                            int dst_rank,
                                            int64_t offset,
@@ -180,6 +182,8 @@ class ProcessGroupCustom final : public ProcessGroupWithStream {
 
   phi::ccl::CCLComm XCCLComm(const Place& place);
 
+  phi::distributed::XCCLCommContext* GetOrCreateCommContext(const Place& place);
+
   // TODO(liyurui): This API will be moved later
   std::shared_ptr<ProcessGroup::Task> AllReduce(
       std::vector<phi::DenseTensor>& in_tensors,

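A note on the two using-declarations added above: in C++, declaring any overload of a name in a derived class hides every base-class overload of that name, so declaring these Recv/Send overloads in ProcessGroupCustom would otherwise make the remaining ProcessGroupWithStream overloads invisible to callers. The lines using ProcessGroupWithStream::Recv; and using ProcessGroupWithStream::Send; re-expose those overload sets. A minimal, self-contained reproduction of the rule with stand-in types:

#include <iostream>

struct Base {
  void Recv(int) { std::cout << "Base::Recv(int)\n"; }
  void Recv(double, double) { std::cout << "Base::Recv(double, double)\n"; }
};

struct Derived : Base {
  using Base::Recv;  // re-exposes Base's Recv overload set
  // This declaration alone would hide BOTH Base::Recv overloads.
  void Recv(int) { std::cout << "Derived::Recv(int)\n"; }
};

int main() {
  Derived d;
  d.Recv(1);         // calls Derived::Recv(int)
  d.Recv(1.0, 2.0);  // compiles only because of the using-declaration
  return 0;
}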
paddle/fluid/framework/CMakeLists.txt (28 additions, 64 deletions)
@@ -268,72 +268,36 @@ else()
     DEPS lod_tensor selected_rows_utils phi common var_type_traits op_info)
 endif()
 
-if(WITH_XPU)
-  cc_library(
-    operator
-    SRCS operator.cc transfer_scope_cache.cc unused_var_check.cc
-         infershape_utils.cc
-    DEPS op_info
-         proto_desc
-         tensor
-         scope
-         glog
-         shape_inference
-         data_transform
-         lod_tensor
-         op_kernel_type
-         op_call_stack
-         detail_op_handle
-         phi_utils
-         phi
-         common
-         op_compat_infos
-         type_info)
-elseif(WITH_NCCL OR WITH_RCCL)
-  cc_library(
-    operator
-    SRCS operator.cc transfer_scope_cache.cc unused_var_check.cc
-         infershape_utils.cc
-    DEPS op_info
-         proto_desc
-         tensor
-         scope
-         glog
-         shape_inference
-         data_transform
-         lod_tensor
-         op_kernel_type
-         op_call_stack
-         detail_op_handle
-         phi_utils
-         phi
-         common
-         op_compat_infos
-         type_info
-         process_group_nccl)
-else()
-  cc_library(
-    operator
-    SRCS operator.cc transfer_scope_cache.cc unused_var_check.cc
-         infershape_utils.cc
-    DEPS op_info
-         proto_desc
-         tensor
-         scope
-         glog
-         shape_inference
-         data_transform
-         lod_tensor
-         op_kernel_type
-         op_call_stack
-         detail_op_handle
-         phi_utils
-         phi
-         common
-         op_compat_infos
-         type_info)
+set(OPERETER_DEPS
+    op_info
+    proto_desc
+    tensor
+    scope
+    glog
+    shape_inference
+    data_transform
+    lod_tensor
+    op_kernel_type
+    op_call_stack
+    detail_op_handle
+    phi_utils
+    phi
+    common
+    op_compat_infos
+    type_info)
+
+if(WITH_NCCL OR WITH_RCCL)
+  set(OPERETER_DEPS ${OPERETER_DEPS} process_group_nccl)
+elseif(WITH_CUSTOM_DEVICE)
+  set(OPERETER_DEPS ${OPERETER_DEPS} process_group_custom)
 endif()
 
+cc_library(
+  operator
+  SRCS operator.cc transfer_scope_cache.cc unused_var_check.cc
+       infershape_utils.cc
+  DEPS ${OPERETER_DEPS})
+
 cc_library(version SRCS version.cc)
 
 add_library(proto_desc_base OBJECT var_desc.cc op_desc.cc block_desc.cc

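The effect of this hunk: three near-identical cc_library(operator ...) branches (WITH_XPU, WITH_NCCL OR WITH_RCCL, and the fallback) collapse into a single shared dependency list, OPERETER_DEPS (the spelling is the commit's own), and the conditional now only appends the backend-specific process group: process_group_nccl for NCCL/RCCL builds and, newly, process_group_custom for custom-device builds. Since the former WITH_XPU branch carried exactly the same dependencies as the fallback, folding it away changes no behavior for XPU builds.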
paddle/fluid/framework/new_executor/instruction/instruction_util.cc (47 additions, 10 deletions)
@@ -36,20 +36,23 @@
 #include "paddle/phi/core/dense_tensor.h"
 #include "paddle/pir/include/core/block_argument.h"
 #if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) || \
-    defined(PADDLE_WITH_XPU_BKCL)
+    defined(PADDLE_WITH_XPU_BKCL) || defined(PADDLE_WITH_CUSTOM_DEVICE)
 #include "paddle/common/flags.h"
 #include "paddle/fluid/distributed/collective/process_group.h"
 #include "paddle/phi/core/distributed/comm_context_manager.h"
-#include "paddle/phi/core/platform/collective_helper.h"
-COMMON_DECLARE_bool(dynamic_static_unified_comm);
-#endif
-#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
+#if defined(PADDLE_WITH_CUSTOM_DEVICE)
+#include "paddle/fluid/distributed/collective/process_group_custom.h"
+#include "paddle/phi/backends/custom/custom_context.h"
+#include "paddle/phi/core/distributed/xccl_comm_context.h"
+#elif defined(PADDLE_WITH_XPU_BKCL)
+#include "paddle/fluid/distributed/collective/process_group_bkcl.h"
+#include "paddle/phi/core/distributed/bkcl_comm_context.h"
+#else
 #include "paddle/fluid/distributed/collective/process_group_nccl.h"
 #include "paddle/phi/core/distributed/nccl_comm_context.h"
 #endif
-#if defined(PADDLE_WITH_XPU_BKCL)
-#include "paddle/fluid/distributed/collective/process_group_bkcl.h"
-#include "paddle/phi/core/distributed/bkcl_comm_context.h"
+#include "paddle/phi/core/platform/collective_helper.h"
+COMMON_DECLARE_bool(dynamic_static_unified_comm);
 #endif
 
 #if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
@@ -66,10 +69,14 @@ COMMON_DECLARE_bool(dynamic_static_unified_comm);
   phi::distributed::CommContextManager::CreateBKCLCommContext
 #define PLATFORM_COMM_CONTEXT platform::BKCLCommContext
 #define PROCESS_GROUP paddle::distributed::ProcessGroupBKCL
+#elif defined(PADDLE_WITH_CUSTOM_DEVICE)
+#define COMM_CONTEXT phi::distributed::XCCLCommContext
+#define CREATE_COMM_CONTEXT \
+  phi::distributed::CommContextManager::CreateXCCLCommContext
+#define PROCESS_GROUP paddle::distributed::ProcessGroupCustom
 #endif
 
 namespace paddle::framework {
-
 std::vector<int> GetValueIds(pir::Value value,
                              const ValueExecutionInfo& value_exec_info) {
   std::vector<int> ids;
@@ -131,7 +138,7 @@ phi::DeviceContext* ParseDeviceContext(pir::Operation* op,
   }
 
 #if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) || \
-    defined(PADDLE_WITH_XPU_BKCL)
+    defined(PADDLE_WITH_XPU_BKCL) || defined(PADDLE_WITH_CUSTOM_DEVICE)
   // NOTE(Ruibiao): Here supports multi-stream overlap for c_allreduce_sum
   // with use_cal_stream==false by returning a device context getting from the
   // global NCCLCommContext instance. Because when use_calc_stream==false, in
@@ -155,9 +162,16 @@ phi::DeviceContext* ParseDeviceContext(pir::Operation* op,
                   comm_context_manager.Get(std::to_string(ring_id)))
                   ->GetDevContext());
     } else {
+#if defined(PADDLE_WITH_CUSTOM_DEVICE)
+      PADDLE_ENFORCE(
+          false,
+          common::errors::InvalidArgument(
+              "Custom device does not support old communication context."));
+#else
       dev_ctx = PLATFORM_COMM_CONTEXT::Instance()
                     .Get(ring_id, place)
                     ->dev_context();
+#endif
     }
     return dev_ctx;
   }
@@ -175,6 +189,8 @@ phi::DeviceContext* ParseDeviceContext(pir::Operation* op,
       comm_context = comm_context_manager.Get(std::to_string(ring_id));
     } else if (op_name.compare(paddle::dialect::MpAllreduceSum_Op::name()) ==
                    0 ||
+               op_name.compare(paddle::dialect::MpAllreduceSumOp::name()) ==
+                   0 ||
               op_name.compare(paddle::dialect::AllReduce_Op::name()) == 0 ||
               op_name.compare(paddle::dialect::CIdentity_Op::name()) == 0 ||
               op_name.compare(paddle::dialect::CConcatOp::name()) == 0 ||
@@ -189,12 +205,14 @@ phi::DeviceContext* ParseDeviceContext(pir::Operation* op,
       dev_ctx = static_cast<platform::DeviceContext*>(
          static_cast<COMM_CONTEXT*>(comm_context)->GetDevContext());
       dev_ctx->SetCommContext(comm_context);
+
       if (op_name.compare(paddle::dialect::ReduceScatterOp::name()) == 0 ||
          op_name.compare(paddle::dialect::AllReduceOp::name()) == 0 ||
          op_name.compare(paddle::dialect::AllReduce_Op::name()) == 0 ||
          op_name.compare(paddle::dialect::Broadcast_Op::name()) == 0 ||
          op_name.compare(paddle::dialect::BroadcastOp::name()) == 0 ||
          op_name.compare(paddle::dialect::AllGatherOp::name()) == 0 ||
+          op_name.compare(paddle::dialect::MpAllreduceSumOp::name()) == 0 ||
          op_name.compare(paddle::dialect::MpAllreduceSum_Op::name()) == 0 ||
          op_name.compare(paddle::dialect::CIdentity_Op::name()) == 0 ||
          op_name.compare(paddle::dialect::CConcatOp::name()) == 0 ||
@@ -203,7 +221,25 @@ phi::DeviceContext* ParseDeviceContext(pir::Operation* op,
          op_name.compare(paddle::dialect::AllToAllOp::name()) == 0 ||
          op_name.compare(
              paddle::dialect::CSoftmaxWithCrossEntropyOp::name()) == 0) {
+#ifdef PADDLE_WITH_CUSTOM_DEVICE
+        if (phi::is_custom_place(place) &&
+            execution_stream == kDefaultStream) {
+          VLOG(3) << "set stream for " << op_name << "in Custom device";
+          if (origin_dev_ctx != nullptr) {
+            // set stream
+            auto default_stream =
+                static_cast<phi::CustomContext*>(origin_dev_ctx)->GetStream();
+            static_cast<phi::CustomContext*>(dev_ctx)->SetStream(
+                default_stream);
+            // todo set allocator
+          } else {
+            VLOG(3) << "CUSTOM DEVICE op " << op_name << " ring_id "
+                    << ring_id << " origin_dev_ctx is nullptr";
+          }
+        }
+#else
         if (phi::is_gpu_place(place) && execution_stream == kDefaultStream) {
+          VLOG(3) << "set stream for " << op_name << "in GPU device";
          if (origin_dev_ctx != nullptr) {
 #if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
            // set stream
@@ -226,6 +262,7 @@ phi::DeviceContext* ParseDeviceContext(pir::Operation* op,
                    << " origin_dev_ctx is nullptr";
          }
        }
+#endif
        return dev_ctx;
      }
    } else {

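The custom-device branch added to ParseDeviceContext mirrors the existing GPU logic: when a collective op is scheduled on the default stream, the communication context's phi::CustomContext is rebound to the compute context's stream, so compute and communication serialize on one queue rather than relying on cross-stream events. A minimal sketch of that rebinding with stand-in Stream/Context types (illustrative, not the phi classes):

#include <iostream>
#include <memory>

struct Stream {};

struct Context {
  std::shared_ptr<Stream> stream;
  std::shared_ptr<Stream> GetStream() const { return stream; }
  void SetStream(std::shared_ptr<Stream> s) { stream = std::move(s); }
};

int main() {
  Context calc_ctx{std::make_shared<Stream>()};
  Context comm_ctx{std::make_shared<Stream>()};
  std::cout << (calc_ctx.GetStream() == comm_ctx.GetStream()) << "\n";  // 0: separate streams
  // The step performed for kDefaultStream on a custom place:
  comm_ctx.SetStream(calc_ctx.GetStream());
  std::cout << (calc_ctx.GetStream() == comm_ctx.GetStream()) << "\n";  // 1: shared stream
  return 0;
}

When execution_stream is not kDefaultStream, the comm context keeps its own stream, presumably preserving the multi-stream overlap path described in the NOTE(Ruibiao) comment above.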