3636#include  " paddle/phi/core/dense_tensor.h" 
3737#include  " paddle/pir/include/core/block_argument.h" 
3838#if  defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) || \
39-  defined (PADDLE_WITH_XPU_BKCL)
39+  defined (PADDLE_WITH_XPU_BKCL) || defined(PADDLE_WITH_CUSTOM_DEVICE) 
4040#include  " paddle/common/flags.h" 
4141#include  " paddle/fluid/distributed/collective/process_group.h" 
4242#include  " paddle/phi/core/distributed/comm_context_manager.h" 
43- #include  " paddle/phi/core/platform/collective_helper.h" 
44- COMMON_DECLARE_bool (dynamic_static_unified_comm);
45- #endif 
46- #if  defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
43+ #if  defined(PADDLE_WITH_CUSTOM_DEVICE)
44+ #include  " paddle/fluid/distributed/collective/process_group_custom.h" 
45+ #include  " paddle/phi/backends/custom/custom_context.h" 
46+ #include  " paddle/phi/core/distributed/xccl_comm_context.h" 
47+ #elif  defined(PADDLE_WITH_XPU_BKCL)
48+ #include  " paddle/fluid/distributed/collective/process_group_bkcl.h" 
49+ #include  " paddle/phi/core/distributed/bkcl_comm_context.h" 
50+ #else 
4751#include  " paddle/fluid/distributed/collective/process_group_nccl.h" 
4852#include  " paddle/phi/core/distributed/nccl_comm_context.h" 
4953#endif 
50- #if  defined(PADDLE_WITH_XPU_BKCL)
51- #include  " paddle/fluid/distributed/collective/process_group_bkcl.h" 
52- #include  " paddle/phi/core/distributed/bkcl_comm_context.h" 
54+ #include  " paddle/phi/core/platform/collective_helper.h" 
55+ COMMON_DECLARE_bool (dynamic_static_unified_comm);
5356#endif 
5457
5558#if  defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
@@ -66,10 +69,14 @@ COMMON_DECLARE_bool(dynamic_static_unified_comm);
6669 phi::distributed::CommContextManager::CreateBKCLCommContext
6770#define  PLATFORM_COMM_CONTEXT  platform::BKCLCommContext
6871#define  PROCESS_GROUP  paddle::distributed::ProcessGroupBKCL
72+ #elif  defined(PADDLE_WITH_CUSTOM_DEVICE)
73+ #define  COMM_CONTEXT  phi::distributed::XCCLCommContext
74+ #define  CREATE_COMM_CONTEXT  \
75+  phi::distributed::CommContextManager::CreateXCCLCommContext
76+ #define  PROCESS_GROUP  paddle::distributed::ProcessGroupCustom
6977#endif 
7078
7179namespace  paddle ::framework {
72- 
7380std::vector<int > GetValueIds (pir::Value value,
7481 const  ValueExecutionInfo& value_exec_info) {
7582 std::vector<int > ids;
@@ -131,7 +138,7 @@ phi::DeviceContext* ParseDeviceContext(pir::Operation* op,
131138 }
132139
133140#if  defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) || \
134-  defined (PADDLE_WITH_XPU_BKCL)
141+  defined (PADDLE_WITH_XPU_BKCL) ||  defined (PADDLE_WITH_CUSTOM_DEVICE) 
135142 //  NOTE(Ruibiao): Here supports multi-stream overlap for c_allreduce_sum
136143 //  with use_cal_stream==false by returning a device context getting from the
137144 //  global NCCLCommContext instance. Because when use_calc_stream==false, in
@@ -155,9 +162,16 @@ phi::DeviceContext* ParseDeviceContext(pir::Operation* op,
155162 comm_context_manager.Get (std::to_string (ring_id)))
156163 ->GetDevContext ());
157164 } else  {
165+ #if  defined(PADDLE_WITH_CUSTOM_DEVICE)
166+  PADDLE_ENFORCE (
167+  false ,
168+  common::errors::InvalidArgument (
169+  " Custom device does not support old communication context." 
170+ #else 
158171 dev_ctx = PLATFORM_COMM_CONTEXT::Instance ()
159172 .Get (ring_id, place)
160173 ->dev_context ();
174+ #endif 
161175 }
162176 return  dev_ctx;
163177 }
@@ -175,6 +189,8 @@ phi::DeviceContext* ParseDeviceContext(pir::Operation* op,
175189 comm_context = comm_context_manager.Get (std::to_string (ring_id));
176190 } else  if  (op_name.compare (paddle::dialect::MpAllreduceSum_Op::name ()) ==
177191 0  ||
192+  op_name.compare (paddle::dialect::MpAllreduceSumOp::name ()) ==
193+  0  ||
178194 op_name.compare (paddle::dialect::AllReduce_Op::name ()) == 0  ||
179195 op_name.compare (paddle::dialect::CIdentity_Op::name ()) == 0  ||
180196 op_name.compare (paddle::dialect::CConcatOp::name ()) == 0  ||
@@ -189,12 +205,14 @@ phi::DeviceContext* ParseDeviceContext(pir::Operation* op,
189205 dev_ctx = static_cast <platform::DeviceContext*>(
190206 static_cast <COMM_CONTEXT*>(comm_context)->GetDevContext ());
191207 dev_ctx->SetCommContext (comm_context);
208+ 
192209 if  (op_name.compare (paddle::dialect::ReduceScatterOp::name ()) == 0  ||
193210 op_name.compare (paddle::dialect::AllReduceOp::name ()) == 0  ||
194211 op_name.compare (paddle::dialect::AllReduce_Op::name ()) == 0  ||
195212 op_name.compare (paddle::dialect::Broadcast_Op::name ()) == 0  ||
196213 op_name.compare (paddle::dialect::BroadcastOp::name ()) == 0  ||
197214 op_name.compare (paddle::dialect::AllGatherOp::name ()) == 0  ||
215+  op_name.compare (paddle::dialect::MpAllreduceSumOp::name ()) == 0  ||
198216 op_name.compare (paddle::dialect::MpAllreduceSum_Op::name ()) == 0  ||
199217 op_name.compare (paddle::dialect::CIdentity_Op::name ()) == 0  ||
200218 op_name.compare (paddle::dialect::CConcatOp::name ()) == 0  ||
@@ -203,7 +221,25 @@ phi::DeviceContext* ParseDeviceContext(pir::Operation* op,
203221 op_name.compare (paddle::dialect::AllToAllOp::name ()) == 0  ||
204222 op_name.compare (
205223 paddle::dialect::CSoftmaxWithCrossEntropyOp::name ()) == 0 ) {
224+ #ifdef  PADDLE_WITH_CUSTOM_DEVICE
225+  if  (phi::is_custom_place (place) &&
226+  execution_stream == kDefaultStream ) {
227+  VLOG (3 ) << " set stream for " " in Custom device" 
228+  if  (origin_dev_ctx != nullptr ) {
229+  //  set stream
230+  auto  default_stream =
231+  static_cast <phi::CustomContext*>(origin_dev_ctx)->GetStream ();
232+  static_cast <phi::CustomContext*>(dev_ctx)->SetStream (
233+  default_stream);
234+  //  todo set allocator
235+  } else  {
236+  VLOG (3 ) << " CUSTOM DEVICE op " "  ring_id " 
237+  << ring_id << "  origin_dev_ctx is nullptr" 
238+  }
239+  }
240+ #else 
206241 if  (phi::is_gpu_place (place) && execution_stream == kDefaultStream ) {
242+  VLOG (3 ) << " set stream for " " in GPU device" 
207243 if  (origin_dev_ctx != nullptr ) {
208244#if  defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
209245 //  set stream
@@ -226,6 +262,7 @@ phi::DeviceContext* ParseDeviceContext(pir::Operation* op,
226262 << "  origin_dev_ctx is nullptr" 
227263 }
228264 }
265+ #endif 
229266 return  dev_ctx;
230267 }
231268 } else  {
0 commit comments