@@ -39,6 +39,7 @@ void SplitTensorAndMoveTensorToScopes(
     const std::vector<framework::Scope *> &sub_scopes,
     const std::vector<platform::Place> &places,
     const std::vector<std::string> &names) {
+  PADDLE_ENFORCE_EQ(sub_scopes.size(), places.size());
   for (auto &argu : names) {
     auto *var = scope.FindVar(argu);
     const auto &tensor = var->Get<LoDTensor>();
@@ -54,6 +55,15 @@ void SplitTensorAndMoveTensorToScopes(
   }
 }
 
+void WaitOnPlaces(const std::vector<platform::Place> places) {
+  platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
+
+  for (auto &place : places) {
+    auto &dev_ctx = *pool.Get(place);
+    dev_ctx.Wait();
+  }
+}
+
 class ParallelDoOp : public framework::OperatorBase {
  public:
   ParallelDoOp(const std::string &type,
@@ -71,19 +81,30 @@ class ParallelDoOp : public framework::OperatorBase {
     auto *block = Attr<framework::BlockDesc *>(kParallelBlock);
     auto *program = block->Program();
 
-    // TODO(tonyyang-svail): get places from input
-    std::vector<platform::Place> places;
-    places.emplace_back(platform::CPUPlace());
-    places.emplace_back(platform::CPUPlace());
+    auto &places = scope.FindVar(Input(kPlaces))->Get<platform::PlaceList>();
 
     auto &sub_scopes = *scope.FindVar(Output(kParallelScopes))
                             ->GetMutable<std::vector<framework::Scope *>>();
     for (size_t place_idx = 0; place_idx < places.size(); ++place_idx) {
       sub_scopes.push_back(&scope.NewScope());
     }
 
+    // split input
     SplitTensorAndMoveTensorToScopes(scope, sub_scopes, places,
                                      Inputs(kInputs));
+    // copy parameter
+    for (auto &param : Inputs(kParameters)) {
+      PADDLE_ENFORCE(scope.FindVar(param)->IsType<LoDTensor>(),
+                     "Only support parameter type as LoDTensor");
+      auto &src = scope.FindVar(param)->Get<LoDTensor>();
+      for (size_t i = 0; i < places.size(); ++i) {
+        auto &place = places[i];
+        auto *sub_scope = sub_scopes[i];
+        auto *dst = sub_scope->Var(param)->GetMutable<LoDTensor>();
+        framework::Copy(src, place, dst);
+      }
+    }
+    WaitOnPlaces(places);
 
     std::vector<std::future<void>> workers;
     workers.reserve(places.size());
@@ -93,12 +114,6 @@ class ParallelDoOp : public framework::OperatorBase {
       auto &place = places[place_idx];
       auto *cur_scope = sub_scopes[place_idx];
 
-      // copy parameter
-      // some version of boost lacks != for boost::variant
-      if (!(dev_ctx.GetPlace() == place)) {
-        PADDLE_THROW("Not Implemented");
-      }
-
       workers.emplace_back(framework::Async([program, cur_scope, place, block] {
         framework::Executor executor(place);
         executor.Run(*program, cur_scope, block->ID(),
@@ -108,6 +123,7 @@ class ParallelDoOp : public framework::OperatorBase {
     for (auto &worker : workers) {
       worker.wait();
     }
+    WaitOnPlaces(places);
 
     // merge output
     for (auto &o_name : Outputs(kOutputs)) {
@@ -121,6 +137,7 @@ class ParallelDoOp : public framework::OperatorBase {
           scope.FindVar(o_name)->GetMutable<LoDTensor>();
       lod_tensor_to_be_merged->MergeLoDTensor(lod_tensors, dev_ctx.GetPlace());
     }
+    WaitOnPlaces(places);
   }
 };
 
@@ -161,15 +178,14 @@ class ParallelDoGradOp : public OperatorBase {
     auto &sub_scopes = scope.FindVar(Input(kParallelScopes))
                            ->Get<std::vector<framework::Scope *>>();
 
-    // TODO(tonyyang-svail): get places from input
-    std::vector<platform::Place> places;
-    places.emplace_back(platform::CPUPlace());
-    places.emplace_back(platform::CPUPlace());
+    auto &places = scope.FindVar(Input(kPlaces))->Get<platform::PlaceList>();
 
     // feed output@grad
     SplitTensorAndMoveTensorToScopes(scope, sub_scopes, places,
                                      Inputs(framework::GradVarName(kOutputs)));
+    WaitOnPlaces(places);
 
+    // for debugging
     for (auto &s : Inputs(framework::GradVarName(kOutputs))) {
       VLOG(3) << s;
       VLOG(3) << scope.FindVar(s)->Get<LoDTensor>();
@@ -196,10 +212,11 @@ class ParallelDoGradOp : public OperatorBase {
     for (auto &worker : workers) {
       worker.wait();
     }
+    WaitOnPlaces(places);
 
     // merge grad
     for (auto &s : Outputs(framework::GradVarName(kParameters))) {
-      VLOG(3) << s;
+      VLOG(3) << "merge grad " << s;
 
       auto &t = sub_scopes[0]->FindVar(s)->Get<LoDTensor>();
       VLOG(3) << t;
@@ -216,7 +233,8 @@ class ParallelDoGradOp : public OperatorBase {
         auto sum_op = framework::OpRegistry::CreateOp(
             "sum", {{"X", {s, s_buf}}}, {{"Out", {s}}},
             framework::AttributeMap{});
-        sum_op->Run(*sub_scopes[0], place);
+        sum_op->Run(*sub_scopes[0], places[0]);
+        WaitOnPlaces(places);
       }
 
       VLOG(3) << t;
@@ -236,8 +254,10 @@ class ParallelDoGradOpDescMaker : public framework::SingleGradOpDescMaker {
     for (auto &input_param : this->InputNames()) {
       VLOG(3) << input_param;
       grad->SetInput(input_param, this->Input(input_param));
-      grad->SetOutput(framework::GradVarName(input_param),
-                      this->InputGrad(input_param, false));
+      if (input_param != kPlaces) {
+        grad->SetOutput(framework::GradVarName(input_param),
+                        this->InputGrad(input_param, false));
+      }
     }
 
     for (auto &output_param : this->OutputNames()) {