@@ -26,7 +26,7 @@ USE_INTERCEPTOR(Compute);
2626
2727void Carrier::Init (
2828 const std::unordered_map<int64_t , TaskNode*>& interceptor_id_to_node,
29- framework::Scope* minibatch_scope,
29+ framework::Scope* root_scope, framework::Scope* minibatch_scope,
3030 const std::vector<framework::Scope*>& microbatch_scopes,
3131 const platform::Place& place) {
3232 PADDLE_ENFORCE_EQ (is_init_, false , platform::errors::AlreadyExists (
@@ -35,6 +35,8 @@ void Carrier::Init(
3535 minibatch_scope_ = minibatch_scope;
3636 microbatch_scopes_ = microbatch_scopes;
3737 place_ = place;
38+ root_scope_ = root_scope;
39+ dev_ctx_ = platform::DeviceContextPool::Instance ().Get (place_);
3840 CreateInterceptors ();
3941 is_init_ = true ;
4042}
@@ -105,6 +107,7 @@ void Carrier::Start() {
105107 }
106108 std::unique_lock<std::mutex> lock (running_mutex_);
107109 cond_var_.wait (lock);
110+ dev_ctx_->Wait ();
108111}
109112
110113std::condition_variable& Carrier::GetCondVar () { return cond_var_; }
@@ -164,6 +167,10 @@ void Carrier::CreateInterceptors() {
164167 // TODO(wangxi): use node_type to select different Interceptor
165168 auto interceptor =
166169 std::make_unique<Interceptor>(interceptor_id, task_node);
170+ interceptor->SetPlace (place_);
171+ interceptor->SetMiniBatchScope (minibatch_scope_);
172+ interceptor->SetMicroBatchScope (microbatch_scopes_);
173+ interceptor->SetRootScope (root_scope_);
167174 SetInterceptor (interceptor_id, std::move (interceptor));
168175 VLOG (3 ) << " Create Interceptor with interceptor id: " << interceptor_id
169176 << " ." ;
0 commit comments