@@ -23,22 +23,36 @@ ThreadedSSAGraphExecutor::ThreadedSSAGraphExecutor(
2323 size_t num_threads, bool use_event,
2424 const std::vector<Scope *> &local_scopes,
2525 const std::vector<platform::Place> &places,
26- std::unique_ptr<SSAGraph> &&graph)
26+ std::unique_ptr<SSAGraph> &&graph, bool allow_op_delay )
2727 : SSAGraphExecutor(std::move(graph)),
2828 pool_ (num_threads >= 2 ? new ::ThreadPool(num_threads) : nullptr),
2929 local_scopes_(local_scopes),
3030 places_(places),
3131 fetch_ctxs_(places),
32- use_event_(use_event) {}
32+ use_event_(use_event),
33+ running_ops_(0 ),
34+ allow_op_delay_(allow_op_delay) {}
35+
36+ void ThreadedSSAGraphExecutor::RunDelayedOps (
37+ const std::unordered_set<OpHandleBase *> &delayed_ops) {
38+ for (auto op : delayed_ops) {
39+ op->Run (use_event_);
40+ }
41+ }
3342
3443FeedFetchList ThreadedSSAGraphExecutor::Run (
3544 const std::vector<std::string> &fetch_tensors) {
3645 std::unordered_map<OpHandleBase *, size_t > pending_ops;
3746 std::unordered_set<VarHandleBase *> pending_vars;
38-
3947 BlockingQueue<VarHandleBase *> ready_vars;
40-
4148 std::unordered_set<OpHandleBase *> ready_ops;
49+ // For ops (e.g. nccl_all_reduce) that need to coordinate multiple
50+ // streams from multiple GPUs, it's faster to buffer them and schedule
51+ // together since we currently cannot overlap computation and memcpy streams.
52+ // Should revisit it if overlapping is available.
53+ std::unordered_set<OpHandleBase *> delayed_ops;
54+ std::unordered_set<OpHandleBase *> blocked_by_delayed_ops;
55+ std::unordered_set<VarHandleBase *> delayed_vars;
4256
4357 auto InsertPendingVar = [&pending_vars, &ready_vars](VarHandleBase &var) {
4458 pending_vars.insert (&var);
@@ -106,7 +120,14 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
106120
107121 auto run_all_ready_ops = [&] {
108122 for (auto *op : ready_ops) {
109- RunOp (ready_vars, op);
123+ if (op->IsMultiDeviceTransfer () && allow_op_delay_) {
124+ delayed_ops.insert (op);
125+ delayed_vars.insert (op->outputs_ .begin (), op->outputs_ .end ());
126+ ready_vars.Extend (op->outputs_ );
127+ continue ;
128+ }
129+ running_ops_++;
130+ RunOp (&ready_vars, op);
110131 }
111132 ready_ops.clear ();
112133 };
@@ -118,13 +139,13 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
118139 }
119140
120141 // Step 3. Execution
121- while (!pending_vars.empty ()) {
142+ while (!pending_vars.empty () || !ready_ops. empty () || !delayed_ops. empty () ) {
122143 // 1. Run All Ready ops
123144 run_all_ready_ops ();
124145
125146 // 2. Find ready variable
126147 bool timeout;
127- auto cur_ready_vars = ready_vars.PopAll (1000 , &timeout);
148+ auto cur_ready_vars = ready_vars.PopAll (1 , &timeout);
128149
129150 if (timeout) {
130151 if (exception_) {
@@ -141,13 +162,29 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
141162 auto &deps = pending_ops[op];
142163 --deps;
143164 if (deps == 0 ) {
144- ready_ops.insert (op);
165+ if (delayed_vars.find (ready_var) != delayed_vars.end ()) {
166+ blocked_by_delayed_ops.insert (op);
167+ } else {
168+ ready_ops.insert (op);
169+ }
145170 }
146171 }
147172 }
173+ // When there are no other ops to schedule, schedule buffered delayed
174+ // ops and unblock other ops.
175+ if (ready_ops.empty () && !delayed_ops.empty () && running_ops_ == 0 ) {
176+ RunDelayedOps (delayed_ops);
177+ delayed_ops.clear ();
178+ for (auto *op : blocked_by_delayed_ops) {
179+ ready_ops.insert (op);
180+ }
181+ blocked_by_delayed_ops.clear ();
182+ }
148183 // Keep loop until all vars are ready.
149184 }
150-
185+ PADDLE_ENFORCE (ready_ops.empty ());
186+ PADDLE_ENFORCE (delayed_ops.empty ());
187+ PADDLE_ENFORCE (blocked_by_delayed_ops.empty ());
151188 ++computation_count_;
152189
153190 auto sync_computation = [&] {
@@ -182,12 +219,13 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
182219}
183220
184221void ThreadedSSAGraphExecutor::RunOp (
185- BlockingQueue<VarHandleBase *> & ready_var_q, details::OpHandleBase *op) {
186- auto op_run = [& ready_var_q, op, this ] {
222+ BlockingQueue<VarHandleBase *> * ready_var_q, details::OpHandleBase *op) {
223+ auto op_run = [ready_var_q, op, this ] {
187224 try {
188225 VLOG (10 ) << op->Name () << " : " << op->DebugString ();
189226 op->Run (use_event_);
190- ready_var_q.Extend (op->outputs_ );
227+ running_ops_--;
228+ ready_var_q->Extend (op->outputs_ );
191229 } catch (platform::EnforceNotMet ex) {
192230 exception_.reset (new platform::EnforceNotMet (ex));
193231 } catch (...) {
0 commit comments