@@ -428,19 +428,19 @@ void build_op_func_list(const platform::Place& place,
428428 op_func_node.dev_ctx_ = dev_ctx;
429429 VLOG (3 ) << op_with_kernel->Type ()
430430 << " : expected_kernel_key : " << expected_kernel_key;
431- auto exec_ctx =
432- ExecutionContext (*op_with_kernel, scope, *dev_ctx, runtime_context);
433431
434432 // see OperatorWithKernel::RunImpl in operator.cc for why
435433 if (!(op->HasAttr (kAllKernelsMustComputeRuntimeShape ) &&
436434 op->Attr <bool >(kAllKernelsMustComputeRuntimeShape ))) {
437435 InterpretercoreInferShapeContext infer_shape_ctx (*op, runtime_context);
438436 // TODO(Aurelius84): In case of control flow ops, they are NOT
439- // inheritted
440- // from OperatorWithKernel.
437+ // inherited from OperatorWithKernel.
441438 op_with_kernel->Info ().infer_shape_ (&infer_shape_ctx);
442439 }
443440
441+ auto exec_ctx =
442+ ExecutionContext (*op_with_kernel, scope, *dev_ctx, runtime_context);
443+
444444 auto run_phi_kernel = false ;
445445 if (phi::KernelFactory::Instance ().HasCompatiblePhiKernel (
446446 op_with_kernel->Type ())) {
@@ -476,7 +476,6 @@ void build_op_func_list(const platform::Place& place,
476476 op_with_kernel->BuildPhiKernelContext (runtime_context, dev_ctx,
477477 &pt_kernel_context);
478478 op_func_node.pt_kernel_ = op_with_kernel->PhiKernel ();
479-
480479 (*op_func_node.pt_kernel_ )(&pt_kernel_context);
481480 } else {
482481 auto kernels_iter = all_op_kernels.find (op->Type ());
@@ -711,6 +710,7 @@ std::map<int, std::list<int>> build_op_downstream_map(
711710 const std::set<std::string> random_op_set = {
712711 " bernoulli" , " poisson" , " multinomial" , " gaussian_random" ,
713712 " uniform_random" , " randint" , " randperm" , " exponential" };
713+
714714 int dependence_op_idx = -1 ;
715715 for (size_t op_idx = 0 ; op_idx < vec_instruction.size (); ++op_idx) {
716716 if (random_op_set.count (vec_instruction[op_idx].OpBase ()->Type ())) {
@@ -721,6 +721,147 @@ std::map<int, std::list<int>> build_op_downstream_map(
721721 }
722722 }
723723
724+ // add dependency for communication op
725+ const std::string communication_op_prefix = " c_" ;
726+ dependence_op_idx = -1 ;
727+ for (size_t op_idx = 0 ; op_idx < vec_instruction.size (); ++op_idx) {
728+ if (vec_instruction[op_idx].OpBase ()->Type ().find (
729+ communication_op_prefix) != std::string::npos) {
730+ if (dependence_op_idx != -1 ) {
731+ op2dependences[op_idx].insert (dependence_op_idx);
732+ }
733+ dependence_op_idx = op_idx;
734+ }
735+ }
736+
737+ // TODO(zhiqiu): there are still some cases not handled
738+ // add dependency for c_sync_comm_stream
739+
740+ // in program, we can add only one c_sync_comm_stream to sync all
741+ // communication ops.
742+ // c_allreduce_sum(a)
743+ // c_allreduce_sum(b)
744+ // c_allreduce_sum(c)
745+ // c_sync_comm_stream(a)
746+ const std::string kSyncComm = " c_sync_comm_stream" ;
747+ dependence_op_idx = -1 ;
748+ for (size_t op_idx = 0 ; op_idx < vec_instruction.size (); ++op_idx) {
749+ if (vec_instruction[op_idx].OpBase ()->Type () == kSyncComm ) {
750+ dependence_op_idx = op_idx;
751+ } else {
752+ if (dependence_op_idx != -1 ) {
753+ VLOG (4 ) << " Add depend from "
754+ << vec_instruction[dependence_op_idx].OpBase ()->Type () << " to "
755+ << vec_instruction[op_idx].OpBase ()->Type ();
756+ op2dependences[op_idx].insert (dependence_op_idx);
757+ }
758+ }
759+ }
760+
761+ // add dependency for coalesce_tensor
762+ const std::string kCoalesceTensor = " coalesce_tensor" ;
763+ for (size_t op_idx = 0 ; op_idx < vec_instruction.size (); ++op_idx) {
764+ if (vec_instruction[op_idx].OpBase ()->Type () == kCoalesceTensor ) {
765+ VLOG (4 ) << " Add depend for " << kCoalesceTensor << " " << op_idx;
766+ auto fused_out = vec_instruction[op_idx].Outputs ().at (" FusedOutput" )[0 ];
767+ auto outputs = vec_instruction[op_idx].Outputs ().at (" Output" );
768+
769+ auto is_read = [](const Instruction& inst, int var_id) -> bool {
770+ for (auto pair : inst.Inputs ()) {
771+ for (auto item : pair.second ) {
772+ if (item == var_id) {
773+ return true ;
774+ }
775+ }
776+ }
777+ return false ;
778+ };
779+
780+ auto is_write = [](const Instruction& inst, int var_id) -> bool {
781+ for (auto pair : inst.Outputs ()) {
782+ for (auto item : pair.second ) {
783+ if (item == var_id) {
784+ return true ;
785+ }
786+ }
787+ }
788+ return false ;
789+ };
790+
791+ // find first op that reads fused_out
792+ auto first_read_fused_out_op = -1 ;
793+ for (auto j = op_idx + 1 ; j < vec_instruction.size (); ++j) {
794+ if (is_read (vec_instruction[j], fused_out)) {
795+ first_read_fused_out_op = j;
796+ break ;
797+ }
798+ }
799+
800+ if (UNLIKELY (first_read_fused_out_op == -1 )) {
801+ VLOG (4 ) << " No op read FusedOutput" ;
802+ continue ;
803+ }
804+
805+ // find ops that write 'outputs' between (op_idx,
806+ // first_read_fused_out_op)
807+ // add depend: them->first_read_fused_out_op
808+ for (auto j = op_idx + 1 ;
809+ j < static_cast <size_t >(first_read_fused_out_op); ++j) {
810+ for (auto var_id : outputs) {
811+ if (is_write (vec_instruction[j], var_id)) {
812+ op2dependences[first_read_fused_out_op].insert (j);
813+ VLOG (4 ) << j << " -> " << first_read_fused_out_op;
814+ VLOG (4 )
815+ << " Add depend from " << vec_instruction[j].OpBase ()->Type ()
816+ << " to "
817+ << vec_instruction[first_read_fused_out_op].OpBase ()->Type ();
818+ }
819+ }
820+ }
821+
822+ // find first op read 'outputs' between (first_read_fused_out_op, end)
823+ // add depend: first_read_fused_out_op -> first op that reads 'outputs'
824+
825+ // special case for consecutive communication ops, for example,
826+ // FusedOutput = c_sync_calc_stream(FusedOutput)
827+ // FusedOutput= c_allreduce_sum(FusedOutput)
828+ // FusedOutput = c_sync_comm_stream(FusedOutput)
829+ // we should take the last one to add depend instead of
830+ // 'first_read_fused_out_op'
831+ size_t target = first_read_fused_out_op;
832+ for (size_t j = first_read_fused_out_op + 1 ; j < vec_instruction.size ();
833+ ++j) {
834+ if (j == target + 1 &&
835+ vec_instruction[target].OpBase ()->Type ().find (
836+ communication_op_prefix) != std::string::npos &&
837+ vec_instruction[j].OpBase ()->Type ().find (communication_op_prefix) !=
838+ std::string::npos) {
839+ VLOG (4 ) << " Found consecutive communication ops, "
840+ << vec_instruction[target].OpBase ()->Type () << " -> "
841+ << vec_instruction[j].OpBase ()->Type ();
842+ target = j;
843+ continue ;
844+ }
845+
846+ for (auto var_id : outputs) {
847+ if (is_read (vec_instruction[j], var_id)) {
848+ op2dependences[j].insert (target);
849+ VLOG (4 ) << target << " -> " << j;
850+ VLOG (4 ) << " Add depend from "
851+ << vec_instruction[target].OpBase ()->Type () << " to "
852+ << vec_instruction[j].OpBase ()->Type ();
853+ }
854+ }
855+ }
856+ }
857+ }
858+ for (auto pair : op2dependences) {
859+ VLOG (10 ) << pair.first << " Depends on " << pair.second .size ();
860+ std::ostringstream oss;
861+ std::copy (pair.second .begin (), pair.second .end (),
862+ std::ostream_iterator<int >(oss, " " ));
863+ VLOG (10 ) << oss.str ();
864+ }
724865 return std::move (get_downstream_map (op2dependences));
725866}
726867
0 commit comments