
Commit 4a76d1d

Add enable/disable for delayed ops
1 parent 03673c8 commit 4a76d1d

7 files changed: +37 -15 lines

paddle/fluid/framework/details/threaded_ssa_graph_executor.cc

Lines changed: 4 additions & 3 deletions
@@ -23,14 +23,15 @@ ThreadedSSAGraphExecutor::ThreadedSSAGraphExecutor(
     size_t num_threads, bool use_event,
     const std::vector<Scope *> &local_scopes,
     const std::vector<platform::Place> &places,
-    std::unique_ptr<SSAGraph> &&graph)
+    std::unique_ptr<SSAGraph> &&graph, bool allow_op_delay)
     : SSAGraphExecutor(std::move(graph)),
       pool_(num_threads >= 2 ? new ::ThreadPool(num_threads) : nullptr),
       local_scopes_(local_scopes),
       places_(places),
       fetch_ctxs_(places),
       use_event_(use_event),
-      running_ops_(0) {}
+      running_ops_(0),
+      allow_op_delay_(allow_op_delay) {}

 void ThreadedSSAGraphExecutor::RunDelayedOps(
     const std::unordered_set<OpHandleBase *> &delayed_ops) {
@@ -119,7 +120,7 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(

   auto run_all_ready_ops = [&] {
     for (auto *op : ready_ops) {
-      if (op->IsMultiDeviceTransfer()) {
+      if (op->IsMultiDeviceTransfer() && allow_op_delay_) {
         delayed_ops.insert(op);
         delayed_vars.insert(op->outputs_.begin(), op->outputs_.end());
         ready_vars.Extend(op->outputs_);
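
In short, a multi-device transfer op is only set aside as a delayed op when allow_op_delay_ is true; with the flag off, transfers run immediately like any other ready op, which matches the previous behaviour. Below is a minimal Python sketch of the scheduling idea behind this gate, not the C++ implementation; the op attributes and methods used are hypothetical stand-ins for the OpHandleBase interface.

def run_ready_ops(ready_ops, allow_op_delay):
    # Conceptual sketch only: models the gating added above, not the real
    # ThreadedSSAGraphExecutor. op.is_multi_device_transfer and op.run()
    # are hypothetical stand-ins for the C++ OpHandleBase interface.
    delayed_ops = []
    for op in ready_ops:
        if op.is_multi_device_transfer and allow_op_delay:
            # Set the transfer aside so compute ops keep running; the caller
            # flushes delayed_ops later (RunDelayedOps in the C++ code).
            delayed_ops.append(op)
        else:
            # With allow_op_delay=False, every ready op runs right away.
            op.run()
    return delayed_ops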

paddle/fluid/framework/details/threaded_ssa_graph_executor.h

Lines changed: 3 additions & 1 deletion
@@ -75,7 +75,8 @@ class ThreadedSSAGraphExecutor : public SSAGraphExecutor {
   ThreadedSSAGraphExecutor(size_t num_threads, bool use_event,
                            const std::vector<Scope *> &local_scopes,
                            const std::vector<platform::Place> &places,
-                           std::unique_ptr<SSAGraph> &&graph);
+                           std::unique_ptr<SSAGraph> &&graph,
+                           bool allow_op_delay);

   // Run a SSAGraph by a thread pool
   // Use topological sort algorithm
@@ -97,6 +98,7 @@ class ThreadedSSAGraphExecutor : public SSAGraphExecutor {
   const bool use_event_;
   std::unique_ptr<platform::EnforceNotMet> exception_;
   std::atomic<int> running_ops_;
+  bool allow_op_delay_;

   size_t computation_count_{0};
   size_t max_async_computation{100};

paddle/fluid/framework/parallel_executor.cc

Lines changed: 3 additions & 3 deletions
@@ -49,7 +49,7 @@ ParallelExecutor::ParallelExecutor(
     const std::vector<platform::Place> &places,
     const std::unordered_set<std::string> &params,
     const ProgramDesc &startup_program, const ProgramDesc &main_program,
-    const std::string &loss_var_name, Scope *scope)
+    const std::string &loss_var_name, Scope *scope, bool allow_op_delay)
     : member_(new ParallelExecutorPrivate(places)) {
   member_->global_scope_ = scope;

@@ -84,8 +84,8 @@ ParallelExecutor::ParallelExecutor(
   auto graph = builder.Build(main_program);

   member_->executor_.reset(new details::ThreadedSSAGraphExecutor(
-      num_threads, use_event, member_->local_scopes_, places,
-      std::move(graph)));
+      num_threads, use_event, member_->local_scopes_, places, std::move(graph),
+      allow_op_delay));

   // Step 3. Create vars in each scope;
   for (auto *scope : member_->local_scopes_) {

paddle/fluid/framework/parallel_executor.h

Lines changed: 4 additions & 2 deletions
@@ -14,8 +14,9 @@ limitations under the License. */

 #pragma once

-#include <future>
+#include <string>
 #include <unordered_set>
+#include <vector>
 #include "paddle/fluid/framework/executor.h"
 #include "paddle/fluid/framework/op_info.h"
 #include "paddle/fluid/framework/program_desc.h"
@@ -37,7 +38,8 @@ class ParallelExecutor {
                    const std::unordered_set<std::string>& params,
                    const ProgramDesc& startup_program,
                    const ProgramDesc& main_program,
-                   const std::string& loss_var_name, Scope* scope);
+                   const std::string& loss_var_name, Scope* scope,
+                   bool allow_op_delay);

   void Run(const std::vector<std::string>& fetch_tensors,
            const std::string& fetched_var_name = "fetched_var");

paddle/fluid/pybind/pybind.cc

Lines changed: 2 additions & 2 deletions
@@ -504,10 +504,10 @@ All parameter, weight, gradient are variables in Paddle.
              const std::unordered_set<std::string> &params,
              const ProgramDesc &startup_program,
              const ProgramDesc &main_program, const std::string &loss_var_name,
-             Scope *scope) {
+             Scope *scope, bool allow_op_delay) {
            new (&self) ParallelExecutor(num_threads, use_event, places,
                                         params, startup_program, main_program,
-                                        loss_var_name, scope);
+                                        loss_var_name, scope, allow_op_delay);
          })
       .def("run", &ParallelExecutor::Run);

python/paddle/fluid/parallel_executor.py

Lines changed: 7 additions & 2 deletions
@@ -21,7 +21,11 @@


 class ParallelExecutor(object):
-    def __init__(self, loss_name, use_cuda, num_threads=None):
+    def __init__(self,
+                 loss_name,
+                 use_cuda,
+                 num_threads=None,
+                 allow_op_delay=False):
         places = []
         if use_cuda:
             for i in xrange(core.get_cuda_device_count()):
@@ -57,7 +61,8 @@ def __init__(self, loss_name, use_cuda, num_threads=None):
             startup.desc,
             main.desc,
             loss_name,
-            scope)
+            scope,
+            allow_op_delay)
         self.scope = scope

     def run(self, fetch_list):
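
On the Python side the new argument defaults to False, so existing callers keep the old behaviour. A short usage sketch based on the updated unit test below; the network-building code is elided, and loss is assumed to be a loss variable built in the current program.

import paddle.fluid as fluid

# ... build the network and obtain a loss variable `loss` ...

# allow_op_delay=False keeps the previous behaviour; True lets the executor
# defer multi-device transfer ops instead of running them as soon as they
# become ready.
exe = fluid.ParallelExecutor(
    loss_name=loss.name,
    use_cuda=True,
    allow_op_delay=True)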

python/paddle/fluid/tests/unittests/test_parallel_executor.py

Lines changed: 14 additions & 2 deletions
@@ -184,7 +184,8 @@ def check_network_convergence(self,
                                   method,
                                   memory_opt=True,
                                   iter=10,
-                                  batch_size=None):
+                                  batch_size=None,
+                                  allow_op_delay=False):
         main = fluid.Program()
         startup = fluid.Program()
         with fluid.program_guard(main, startup):
@@ -194,7 +195,10 @@
             if memory_opt:
                 fluid.memory_optimize(main)

-            exe = fluid.ParallelExecutor(loss_name=loss.name, use_cuda=True)
+            exe = fluid.ParallelExecutor(
+                loss_name=loss.name,
+                use_cuda=True,
+                allow_op_delay=allow_op_delay)
             if batch_size is not None:
                 batch_size *= fluid.core.get_cuda_device_count()
             begin = time.time()
@@ -236,9 +240,11 @@ def setUpClass(cls):

     def test_simple_fc(self):
         self.check_network_convergence(simple_fc_net)
+        self.check_network_convergence(simple_fc_net, allow_op_delay=True)

     def test_batchnorm_fc(self):
         self.check_network_convergence(fc_with_batchnorm)
+        self.check_network_convergence(fc_with_batchnorm, allow_op_delay=True)


 class TestResnet(TestParallelExecutorBase):
@@ -268,6 +274,12 @@ def test_resnet(self):
                 SE_ResNeXt152, batch_size=batch_size),
             iter=20,
             batch_size=batch_size)
+        self.check_network_convergence(
+            functools.partial(
+                SE_ResNeXt152, batch_size=batch_size),
+            iter=20,
+            batch_size=batch_size,
+            allow_op_delay=True)


 class ModelHyperParams(object):
