Skip to content

Commit 2794b51

Browse files
committed
connect all
1 parent 87e65a9 commit 2794b51

File tree

3 files changed

+13
-3
lines changed

3 files changed

+13
-3
lines changed

paddle/fluid/distributed/fleet_executor/carrier.cc

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ USE_INTERCEPTOR(Compute);
2626

2727
void Carrier::Init(
2828
const std::unordered_map<int64_t, TaskNode*>& interceptor_id_to_node,
29-
framework::Scope* minibatch_scope,
29+
framework::Scope* root_scope, framework::Scope* minibatch_scope,
3030
const std::vector<framework::Scope*>& microbatch_scopes,
3131
const platform::Place& place) {
3232
PADDLE_ENFORCE_EQ(is_init_, false, platform::errors::AlreadyExists(
@@ -35,6 +35,8 @@ void Carrier::Init(
3535
minibatch_scope_ = minibatch_scope;
3636
microbatch_scopes_ = microbatch_scopes;
3737
place_ = place;
38+
root_scope_ = root_scope;
39+
dev_ctx_ = platform::DeviceContextPool::Instance().Get(place_);
3840
CreateInterceptors();
3941
is_init_ = true;
4042
}
@@ -105,6 +107,7 @@ void Carrier::Start() {
105107
}
106108
std::unique_lock<std::mutex> lock(running_mutex_);
107109
cond_var_.wait(lock);
110+
dev_ctx_->Wait();
108111
}
109112

110113
std::condition_variable& Carrier::GetCondVar() { return cond_var_; }
@@ -164,6 +167,10 @@ void Carrier::CreateInterceptors() {
164167
// TODO(wangxi): use node_type to select different Interceptor
165168
auto interceptor =
166169
std::make_unique<Interceptor>(interceptor_id, task_node);
170+
interceptor->SetPlace(place_);
171+
interceptor->SetMiniBatchScope(minibatch_scope_);
172+
interceptor->SetMicroBatchScope(microbatch_scopes_);
173+
interceptor->SetRootScope(root_scope_);
167174
SetInterceptor(interceptor_id, std::move(interceptor));
168175
VLOG(3) << "Create Interceptor with interceptor id: " << interceptor_id
169176
<< ".";

paddle/fluid/distributed/fleet_executor/carrier.h

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323

2424
#include "paddle/fluid/distributed/fleet_executor/interceptor.h"
2525
#include "paddle/fluid/distributed/fleet_executor/interceptor_message.pb.h"
26+
#include "paddle/fluid/platform/device_context.h"
2627
#include "paddle/fluid/platform/enforce.h"
2728
#include "paddle/fluid/platform/errors.h"
2829
#include "paddle/fluid/platform/macros.h"
@@ -48,7 +49,7 @@ class Carrier final {
4849

4950
void Init(
5051
const std::unordered_map<int64_t, TaskNode*>& interceptor_id_to_node,
51-
framework::Scope* minibatch_scope,
52+
framework::Scope* root_scope, framework::Scope* minibatch_scope,
5253
const std::vector<framework::Scope*>& microbatch_scopes,
5354
const platform::Place& place);
5455

@@ -98,8 +99,10 @@ class Carrier final {
9899
std::mutex running_mutex_;
99100
std::condition_variable cond_var_;
100101
std::vector<framework::Scope*> microbatch_scopes_;
102+
framework::Scope* root_scope_;
101103
framework::Scope* minibatch_scope_;
102104
paddle::platform::Place place_;
105+
paddle::platform::DeviceContext* dev_ctx_ = nullptr;
103106
};
104107

105108
} // namespace distributed

paddle/fluid/distributed/fleet_executor/fleet_executor.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ void FleetExecutor::Init(const framework::ProgramDesc& program_desc,
5858
void FleetExecutor::InitCarrier() {
5959
Carrier& carrier_instance = Carrier::Instance();
6060
if (!carrier_instance.IsInit()) {
61-
carrier_instance.Init(runtime_graph_->intercepter_id_to_node(),
61+
carrier_instance.Init(runtime_graph_->intercepter_id_to_node(), root_scope_,
6262
minibatch_scope_, microbatch_scopes_, place_);
6363
}
6464
}

0 commit comments

Comments
 (0)