35 changes: 22 additions & 13 deletions paddle/fluid/distributed/collective/process_group_custom.cc
@@ -148,7 +148,6 @@ phi::DeviceContext* ProcessGroupCustom::GetDeviceContext(
return iter->second.get();
}
}

phi::ccl::CCLComm ProcessGroupCustom::XCCLComm(const Place& place) {
const std::string& key = GetKeyFromPlace(place);
phi::DeviceGuard guard(place);
@@ -164,6 +163,16 @@ phi::ccl::CCLComm ProcessGroupCustom::XCCLComm(const Place& place) {
return iter->second->xccl_comm();
}

phi::distributed::XCCLCommContext* ProcessGroupCustom::GetOrCreateCommContext(
const Place& place) {
const std::string& key = GetKeyFromPlace(place);
phi::DeviceGuard guard(place);
if (place_to_comm_ctx_.find(key) == place_to_comm_ctx_.end()) {
CreateXCCLEnvCache(place, key);
}
return this->GetCommContext();
}

std::string ProcessGroupCustom::GetCommName(int rank) {
PADDLE_ENFORCE_GE(rank,
0,
@@ -592,24 +601,24 @@ void ProcessGroupCustom::CreateXCCLEnvCache(const Place& place,

auto* calc_ctx = static_cast<phi::CustomContext*>(
phi::DeviceContextPool::Instance().Get(place));
auto comm_ctx = std::make_unique<phi::CustomContext>(place);
comm_ctx->SetAllocator(
auto custom_context = std::make_unique<phi::CustomContext>(place);
custom_context->SetAllocator(
&(phi::DeviceContextPool::Instance().Get(place)->GetAllocator()));
comm_ctx->SetHostAllocator(
custom_context->SetHostAllocator(
&(phi::DeviceContextPool::Instance().Get(place)->GetHostAllocator()));
comm_ctx->SetZeroAllocator(
custom_context->SetZeroAllocator(
&(phi::DeviceContextPool::Instance().Get(place)->GetZeroAllocator()));
comm_ctx->SetHostZeroAllocator(
custom_context->SetHostZeroAllocator(
&(phi::DeviceContextPool::Instance().Get(place)->GetHostZeroAllocator()));

auto xccl_comm_ctx = this->GetCommContext();
comm_ctx->set_xccl_comm(xccl_comm_ctx->GetXcclComm());
custom_context->set_xccl_comm(xccl_comm_ctx->GetXcclComm());

auto xccl_event = std::make_unique<phi::event::Event>();
xccl_event->Init(place);
place_to_calc_event_.emplace(place_key, std::move(xccl_event));
place_to_calc_ctx_.emplace(place_key, calc_ctx);
place_to_comm_ctx_.emplace(place_key, std::move(comm_ctx));
place_to_comm_ctx_.emplace(place_key, std::move(custom_context));

// TODO(sunyilun): for compatibility, will be removed later
std::vector<phi::CustomContext*> comm_ctx_wrapper{
@@ -621,9 +630,9 @@ void ProcessGroupCustom::SyncCalcStream(const Place& place) {
const std::string& key = GetKeyFromPlace(place);
auto& calc_event = place_to_calc_event_.at(key);
const auto* calc_ctx = place_to_calc_ctx_.at(key);
const auto* comm_ctx = place_to_comm_ctx_.at(key).get();
const auto* custom_context = place_to_comm_ctx_.at(key).get();
calc_event->Record(calc_ctx->GetStream().get());
comm_ctx->GetStream()->WaitEvent(calc_event.get());
custom_context->GetStream()->WaitEvent(calc_event.get());
}

std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::RunFnInXCCLEnv(
@@ -648,16 +657,16 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::RunFnInXCCLEnv(
auto task = CreateTask(place, rank_, comm_type, sync_op, use_calc_stream);

const auto* calc_ctx = place_to_calc_ctx_.at(key);
const auto& comm_ctx = place_to_comm_ctx_.at(key);
const auto& custom_context = place_to_comm_ctx_.at(key);
auto& xccl_stream =
use_calc_stream ? *calc_ctx->GetStream() : *comm_ctx->GetStream();
use_calc_stream ? *calc_ctx->GetStream() : *custom_context->GetStream();
fn(xccl_stream);

if (!use_calc_stream) {
if (FLAGS_use_stream_safe_cuda_allocator) {
memory::RecordStream(tensor.Holder(), xccl_stream.raw_stream());
}
task->UpdateWaitChain(*comm_ctx);
task->UpdateWaitChain(*custom_context);
}

return task;
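A minimal usage sketch for the GetOrCreateCommContext method added above; the call site is illustrative and not part of this change, assuming only a ProcessGroupCustom instance and a custom-device place.

```cpp
// Illustrative only: make sure the XCCL comm context for `place` exists before
// launching collectives on a custom device. GetOrCreateCommContext lazily
// builds the per-place XCCL env cache on first use, then returns the comm
// context bound to this group's ring id.
#include "paddle/fluid/distributed/collective/process_group_custom.h"

phi::distributed::XCCLCommContext* EnsureCustomCommContext(
    paddle::distributed::ProcessGroupCustom* pg, const phi::Place& place) {
  return pg->GetOrCreateCommContext(place);
}
```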
paddle/fluid/distributed/collective/process_group_custom.h
@@ -154,13 +154,15 @@ class ProcessGroupCustom final : public ProcessGroupWithStream {
bool sync_op,
bool use_calc_stream) override;

using ProcessGroupWithStream::Recv;

std::shared_ptr<ProcessGroup::Task> Recv(phi::DenseTensor* tensor,
int src_rank,
int64_t offset,
int64_t numel,
bool sync_op,
bool use_calc_stream) override;

using ProcessGroupWithStream::Send;
std::shared_ptr<ProcessGroup::Task> Send(const phi::DenseTensor& tensor,
int dst_rank,
int64_t offset,
Expand All @@ -174,6 +176,8 @@ class ProcessGroupCustom final : public ProcessGroupWithStream {

phi::ccl::CCLComm XCCLComm(const Place& place);

phi::distributed::XCCLCommContext* GetOrCreateCommContext(const Place& place);

// TODO(liyurui): This API will be moved later
std::shared_ptr<ProcessGroup::Task> AllReduce(
std::vector<phi::DenseTensor>& in_tensors,
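The `using ProcessGroupWithStream::Recv;` and `using ProcessGroupWithStream::Send;` lines re-expose the base-class overloads that the derived overrides would otherwise hide. A self-contained sketch of the underlying C++ name-hiding rule, using generic names rather than Paddle APIs:

```cpp
// When a derived class declares any function named Recv, all base-class
// overloads of Recv become invisible unless re-introduced with `using`.
#include <iostream>

struct Base {
  virtual ~Base() = default;
  void Recv(int tensor) { std::cout << "convenience overload\n"; }
  virtual void Recv(int tensor, int offset, int numel) {
    std::cout << "full overload\n";
  }
};

struct Derived : Base {
  using Base::Recv;  // without this, Derived{}.Recv(42) fails to compile
  void Recv(int tensor, int offset, int numel) override {
    std::cout << "derived full overload\n";
  }
};

int main() {
  Derived d;
  d.Recv(42);        // resolves to the re-exposed one-argument overload
  d.Recv(42, 0, 1);  // resolves to the derived override
}
```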
92 changes: 28 additions & 64 deletions paddle/fluid/framework/CMakeLists.txt
@@ -268,72 +268,36 @@ else()
DEPS lod_tensor selected_rows_utils phi common var_type_traits op_info)
endif()

if(WITH_XPU)
cc_library(
operator
SRCS operator.cc transfer_scope_cache.cc unused_var_check.cc
infershape_utils.cc
DEPS op_info
proto_desc
tensor
scope
glog
shape_inference
data_transform
lod_tensor
op_kernel_type
op_call_stack
detail_op_handle
phi_utils
phi
common
op_compat_infos
type_info)
elseif(WITH_NCCL OR WITH_RCCL)
cc_library(
operator
SRCS operator.cc transfer_scope_cache.cc unused_var_check.cc
infershape_utils.cc
DEPS op_info
proto_desc
tensor
scope
glog
shape_inference
data_transform
lod_tensor
op_kernel_type
op_call_stack
detail_op_handle
phi_utils
phi
common
op_compat_infos
type_info
process_group_nccl)
else()
cc_library(
operator
SRCS operator.cc transfer_scope_cache.cc unused_var_check.cc
infershape_utils.cc
DEPS op_info
proto_desc
tensor
scope
glog
shape_inference
data_transform
lod_tensor
op_kernel_type
op_call_stack
detail_op_handle
phi_utils
phi
common
op_compat_infos
type_info)
set(OPERETER_DEPS
op_info
proto_desc
tensor
scope
glog
shape_inference
data_transform
lod_tensor
op_kernel_type
op_call_stack
detail_op_handle
phi_utils
phi
common
op_compat_infos
type_info)

if(WITH_NCCL OR WITH_RCCL)
set(OPERETER_DEPS ${OPERETER_DEPS} process_group_nccl)
elseif(WITH_CUSTOM_DEVICE)
set(OPERETER_DEPS ${OPERETER_DEPS} process_group_custom)
endif()

cc_library(
operator
SRCS operator.cc transfer_scope_cache.cc unused_var_check.cc
infershape_utils.cc
DEPS ${OPERETER_DEPS})

cc_library(version SRCS version.cc)

add_library(proto_desc_base OBJECT var_desc.cc op_desc.cc block_desc.cc
57 changes: 47 additions & 10 deletions paddle/fluid/framework/new_executor/instruction/instruction_util.cc
@@ -36,20 +36,23 @@
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/pir/include/core/block_argument.h"
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) || \
defined(PADDLE_WITH_XPU_BKCL)
defined(PADDLE_WITH_XPU_BKCL) || defined(PADDLE_WITH_CUSTOM_DEVICE)
#include "paddle/common/flags.h"
#include "paddle/fluid/distributed/collective/process_group.h"
#include "paddle/phi/core/distributed/comm_context_manager.h"
#include "paddle/phi/core/platform/collective_helper.h"
COMMON_DECLARE_bool(dynamic_static_unified_comm);
#endif
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
#if defined(PADDLE_WITH_CUSTOM_DEVICE)
#include "paddle/fluid/distributed/collective/process_group_custom.h"
#include "paddle/phi/backends/custom/custom_context.h"
#include "paddle/phi/core/distributed/xccl_comm_context.h"
#elif defined(PADDLE_WITH_XPU_BKCL)
#include "paddle/fluid/distributed/collective/process_group_bkcl.h"
#include "paddle/phi/core/distributed/bkcl_comm_context.h"
#else
#include "paddle/fluid/distributed/collective/process_group_nccl.h"
#include "paddle/phi/core/distributed/nccl_comm_context.h"
#endif
#if defined(PADDLE_WITH_XPU_BKCL)
#include "paddle/fluid/distributed/collective/process_group_bkcl.h"
#include "paddle/phi/core/distributed/bkcl_comm_context.h"
#include "paddle/phi/core/platform/collective_helper.h"
COMMON_DECLARE_bool(dynamic_static_unified_comm);
#endif

#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
@@ -66,10 +69,14 @@ COMMON_DECLARE_bool(dynamic_static_unified_comm);
phi::distributed::CommContextManager::CreateBKCLCommContext
#define PLATFORM_COMM_CONTEXT platform::BKCLCommContext
#define PROCESS_GROUP paddle::distributed::ProcessGroupBKCL
#elif defined(PADDLE_WITH_CUSTOM_DEVICE)
#define COMM_CONTEXT phi::distributed::XCCLCommContext
#define CREATE_COMM_CONTEXT \
phi::distributed::CommContextManager::CreateXCCLCommContext
#define PROCESS_GROUP paddle::distributed::ProcessGroupCustom
#endif

namespace paddle::framework {

std::vector<int> GetValueIds(pir::Value value,
const ValueExecutionInfo& value_exec_info) {
std::vector<int> ids;
@@ -131,7 +138,7 @@ phi::DeviceContext* ParseDeviceContext(pir::Operation* op,
}

#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) || \
defined(PADDLE_WITH_XPU_BKCL)
defined(PADDLE_WITH_XPU_BKCL) || defined(PADDLE_WITH_CUSTOM_DEVICE)
// NOTE(Ruibiao): Here supports multi-stream overlap for c_allreduce_sum
// with use_cal_stream==false by returning a device context getting from the
// global NCCLCommContext instance. Because when use_calc_stream==false, in
@@ -155,9 +162,16 @@ phi::DeviceContext* ParseDeviceContext(pir::Operation* op,
comm_context_manager.Get(std::to_string(ring_id)))
->GetDevContext());
} else {
#if defined(PADDLE_WITH_CUSTOM_DEVICE)
PADDLE_ENFORCE(
false,
common::errors::InvalidArgument(
"Custom device does not support old communication context."));
#else
dev_ctx = PLATFORM_COMM_CONTEXT::Instance()
.Get(ring_id, place)
->dev_context();
#endif
}
return dev_ctx;
}
@@ -175,6 +189,8 @@ phi::DeviceContext* ParseDeviceContext(pir::Operation* op,
comm_context = comm_context_manager.Get(std::to_string(ring_id));
} else if (op_name.compare(paddle::dialect::MpAllreduceSum_Op::name()) ==
0 ||
op_name.compare(paddle::dialect::MpAllreduceSumOp::name()) ==
0 ||
op_name.compare(paddle::dialect::AllReduce_Op::name()) == 0 ||
op_name.compare(paddle::dialect::CIdentity_Op::name()) == 0 ||
op_name.compare(paddle::dialect::CConcatOp::name()) == 0 ||
@@ -189,12 +205,14 @@ phi::DeviceContext* ParseDeviceContext(pir::Operation* op,
dev_ctx = static_cast<platform::DeviceContext*>(
static_cast<COMM_CONTEXT*>(comm_context)->GetDevContext());
dev_ctx->SetCommContext(comm_context);

if (op_name.compare(paddle::dialect::ReduceScatterOp::name()) == 0 ||
op_name.compare(paddle::dialect::AllReduceOp::name()) == 0 ||
op_name.compare(paddle::dialect::AllReduce_Op::name()) == 0 ||
op_name.compare(paddle::dialect::Broadcast_Op::name()) == 0 ||
op_name.compare(paddle::dialect::BroadcastOp::name()) == 0 ||
op_name.compare(paddle::dialect::AllGatherOp::name()) == 0 ||
op_name.compare(paddle::dialect::MpAllreduceSumOp::name()) == 0 ||
op_name.compare(paddle::dialect::MpAllreduceSum_Op::name()) == 0 ||
op_name.compare(paddle::dialect::CIdentity_Op::name()) == 0 ||
op_name.compare(paddle::dialect::CConcatOp::name()) == 0 ||
@@ -203,7 +221,25 @@ phi::DeviceContext* ParseDeviceContext(pir::Operation* op,
op_name.compare(paddle::dialect::AllToAllOp::name()) == 0 ||
op_name.compare(
paddle::dialect::CSoftmaxWithCrossEntropyOp::name()) == 0) {
#ifdef PADDLE_WITH_CUSTOM_DEVICE
if (phi::is_custom_place(place) &&
execution_stream == kDefaultStream) {
VLOG(3) << "set stream for " << op_name << "in Custom device";
if (origin_dev_ctx != nullptr) {
// set stream
auto default_stream =
static_cast<phi::CustomContext*>(origin_dev_ctx)->GetStream();
static_cast<phi::CustomContext*>(dev_ctx)->SetStream(
default_stream);
// todo set allocator
} else {
VLOG(3) << "CUSTOM DEVICE op " << op_name << " ring_id "
<< ring_id << " origin_dev_ctx is nullptr";
}
}
#else
if (phi::is_gpu_place(place) && execution_stream == kDefaultStream) {
VLOG(3) << "set stream for " << op_name << "in GPU device";
if (origin_dev_ctx != nullptr) {
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
// set stream
@@ -226,6 +262,7 @@ phi::DeviceContext* ParseDeviceContext(pir::Operation* op,
<< " origin_dev_ctx is nullptr";
}
}
#endif
return dev_ctx;
}
} else {
Expand Down
Loading