@@ -34,6 +34,14 @@ std::unordered_map<GradNodeBase*, int> getInDegreeMap(
   // pass
   std::unordered_map<GradNodeBase*, int> node_in_degree_map;

+  // init potential startup node's indegree
+  std::queue<GradNodeBase*> queue_tmp = init_queue;
+  while (!queue_tmp.empty()) {
+    GradNodeBase* node = queue_tmp.front();
+    queue_tmp.pop();
+    node_in_degree_map[node] = 0;
+  }
+
   // Copy nodes
   std::queue<GradNodeBase*> queue = init_queue;
   std::unordered_set<GradNodeBase*> visited;
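
The added block pre-seeds every starting node's in-degree to 0 before the BFS counts incoming edges; without it, a startup node that nothing else points to would be missing from the map and never recognized later as ready to run. A minimal self-contained sketch of the same pattern, using a hypothetical generic `Node` type rather than Paddle's `GradNodeBase`:

```cpp
#include <queue>
#include <unordered_map>
#include <unordered_set>
#include <vector>

struct Node {
  std::vector<Node*> next;  // edges to downstream nodes
};

std::unordered_map<Node*, int> GetInDegreeMap(const std::queue<Node*>& init_queue) {
  std::unordered_map<Node*, int> in_degree;

  // Seed every startup node with in-degree 0 so it is present in the map
  // even when nothing in the graph points at it.
  std::queue<Node*> seed = init_queue;
  while (!seed.empty()) {
    in_degree[seed.front()] = 0;
    seed.pop();
  }

  // BFS from the startup nodes, counting each edge exactly once.
  std::queue<Node*> queue = init_queue;
  std::unordered_set<Node*> visited;
  while (!queue.empty()) {
    Node* node = queue.front();
    queue.pop();
    if (!visited.insert(node).second) continue;  // already expanded
    for (Node* next : node->next) {
      ++in_degree[next];
      queue.push(next);
    }
  }
  return in_degree;
}
```
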
@@ -164,6 +172,7 @@ void GetGraphInfoBetweenTargets(
       }
     }
   }
+
   UpdateGraphInfo(target_nodes, depending_nodes, potential_stop_nodes);
 }

@@ -193,17 +202,33 @@ void GetTargetNodesInfo(const std::vector<paddle::experimental::Tensor>& inputs,

 std::vector<paddle::experimental::Tensor> GetResults(
     const std::vector<paddle::experimental::Tensor>& inputs,
-    std::unordered_map<GradNodeBase*, paddle::experimental::Tensor>&
-        result_map) {
+    const std::unordered_map<GradNodeBase*, paddle::experimental::Tensor>&
+        results_map,
+    bool allow_unused) {
   VLOG(1) << "Run in GetResults";
   if (inputs.empty()) return {};

   std::vector<paddle::experimental::Tensor> results;
   results.reserve(inputs.size());
-  for (auto input : inputs) {
+  auto results_map_ = results_map;
+  for (size_t i = 0; i < inputs.size(); ++i) {
+    auto& input = inputs[i];
     AutogradMeta* auto_grad_meta = EagerUtils::unsafe_autograd_meta(input);
     auto target_node = auto_grad_meta->GetMutableGradNode().get();
-    results.emplace_back(result_map[target_node]);
+
+    if (results_map_.find(target_node) != results_map_.end()) {
+      // TODO(wuweilong): set StopGradient
+      // result_map[target_node].SetOverridedStopGradient(!create_graph_);
+      results.emplace_back(results_map_[target_node]);
+    } else {
+      PADDLE_ENFORCE_EQ(allow_unused, true,
+                        paddle::platform::errors::InvalidArgument(
+                            "The %d-th input does not appear in the backward "
+                            "graph. Please check the input variable or set "
+                            "allow_unused=True to get None result.",
+                            i));
+      results.emplace_back();
+    }
   }
   return results;
 }
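
The rewritten `GetResults` looks each input's grad node up in the results map instead of indexing blindly: a hit returns the accumulated gradient, a miss either trips the enforce (when `allow_unused` is false) or appends an empty tensor that surfaces as `None` on the Python side. A stripped-down sketch of that lookup policy, with placeholder types standing in for the real tensor and node classes and a plain exception standing in for `PADDLE_ENFORCE_EQ`:

```cpp
#include <stdexcept>
#include <string>
#include <unordered_map>
#include <vector>

struct Result { bool defined = false; };  // stand-in for a gradient tensor
using Key = const void*;                  // stand-in for GradNodeBase*

std::vector<Result> CollectResults(
    const std::vector<Key>& keys,
    const std::unordered_map<Key, Result>& results_map,
    bool allow_unused) {
  std::vector<Result> out;
  out.reserve(keys.size());
  for (size_t i = 0; i < keys.size(); ++i) {
    auto it = results_map.find(keys[i]);
    if (it != results_map.end()) {
      out.push_back(it->second);  // gradient produced by the backward pass
    } else if (allow_unused) {
      out.emplace_back();         // empty placeholder, i.e. "None"
    } else {
      throw std::invalid_argument(
          "input " + std::to_string(i) + " does not appear in the backward "
          "graph; pass allow_unused=true to get an empty result");
    }
  }
  return out;
}
```

Using `find` on the const map would also avoid the full copy that `results_map_` makes above just so `operator[]` can be called on it.
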
@@ -220,6 +245,20 @@ std::vector<paddle::experimental::Tensor> RunBackward(
   // *Inplace version check should perform at node-level
   // *Cross-batch accumulation happens at forward pass

+  /* --- Preprocess --- */
+
+  // TODO(wuweilong): output tensor duplicate check
+  // TODO(wuweilong): build no_grad_vars_grads according to no_grad_vars
+  // TODO(wuweilong): output tensor's gradient is not in no_grad_vars
+
+  // TODO(wuweilong): check input tensor has grad op and stop_gradient = False
+  // TODO(wuweilong): input tensor duplicate check
+  // TODO(wuweilong): input tensor's gradient is not in no_grad_vars
+
+  // TODO(wuweilong): prune output_targets that are not inputs of startup_ops
+  // TODO(wuweilong): handle the input == output case
+  // TODO(wuweilong): output_targets.size() should equal output_grads.size()
+
   /* --- Initialization --- */
   // 1. Init queue with starting nodes
   // 2. Prepare initial input buffers
@@ -288,14 +327,28 @@ std::vector<paddle::experimental::Tensor> RunBackward(
       getInDegreeMap(queue);

   std::unordered_map<GradNodeBase*, AutogradMeta*> target_nodes_inputmeta_map;
-  std::unordered_set<GradNodeBase*> target_nodes;
+  std::unordered_set<GradNodeBase*> target_nodes;  // should be updated?
   GetTargetNodesInfo(inputs, &target_nodes, &target_nodes_inputmeta_map);

   std::unordered_map<GradNodeBase*, GradNodeBase*> depending_nodes;
   std::unordered_set<GradNodeBase*> potential_stop_nodes;
   GetGraphInfoBetweenTargets(queue, &target_nodes, &depending_nodes,
                              &potential_stop_nodes);

+  std::unordered_set<GradNodeBase*> startup_ops_;
+  // ready_queue stores all startup nodes
+  std::queue<GradNodeBase*> ready_queue;
+
+  // a startup op's indegree should be 0
+  for (auto& pair : node_in_degree_map) {
+    if (pair.second == 0) {
+      auto* op = pair.first;
+      startup_ops_.emplace(op);
+      ready_queue.emplace(op);
+    }
+  }
+  VLOG(1) << "startup_ops' size is :" << startup_ops_.size();
+
   std::unordered_map<GradNodeBase*, paddle::experimental::Tensor> results_map;

   /* --- Topological Visit --- */
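
Taken together, the in-degree map, `startup_ops_`, and `ready_queue` implement Kahn's topological traversal: a node runs only after all of its predecessors have finished, which is exactly when its in-degree drops to zero. A generic sketch of that loop, again with a hypothetical `Node` type rather than the Paddle classes:

```cpp
#include <queue>
#include <unordered_map>
#include <vector>

struct Node {
  std::vector<Node*> next;   // downstream nodes
  void Run() { /* node-specific backward work */ }
};

// Kahn's algorithm: run every node whose predecessors are all done, then
// release its successors by decrementing their in-degree. `in_degree` must
// contain an entry for every reachable node (see GetInDegreeMap above).
void TopologicalRun(std::unordered_map<Node*, int> in_degree) {
  std::queue<Node*> ready_queue;
  for (const auto& pair : in_degree) {
    if (pair.second == 0) ready_queue.push(pair.first);  // startup nodes
  }

  while (!ready_queue.empty()) {
    Node* node = ready_queue.front();
    ready_queue.pop();
    node->Run();
    for (Node* next : node->next) {
      if (--in_degree[next] == 0) ready_queue.push(next);  // now runnable
    }
  }
}
```

The actual `RunBackward` loop additionally holds back potential stop nodes and accumulates gradients into per-node input buffers, but the ready-queue mechanics are the same.
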
@@ -306,9 +359,9 @@ std::vector<paddle::experimental::Tensor> RunBackward(
   // |- Prepare for next node
   // 3. Update queue
   VLOG(6) << "Run Backward";
-  while (!queue.empty()) {
-    GradNodeBase* node = queue.front();
-    queue.pop();
+  while (!ready_queue.empty()) {
+    GradNodeBase* node = ready_queue.front();
+    ready_queue.pop();

     // Run node: This is where Hook happens
     PADDLE_ENFORCE(
@@ -334,7 +387,7 @@ std::vector<paddle::experimental::Tensor> RunBackward(

     // Run Pre Backward Node and get outputs
     std::vector<std::vector<paddle::experimental::Tensor>> grad_output_tensors =
-        (*node)(node_input_buffer->Buffers());
+        (*node)(node_input_buffer->Buffers(), create_graph);
     // TODO(jiabin): Should we erase it or find a more efficient way.
     node_input_buffers_dict.erase(node);

@@ -410,13 +463,13 @@ std::vector<paddle::experimental::Tensor> RunBackward(
         }

         if (node_in_degree_map[next_node] == 0 && !is_potential_stop_node) {
-          queue.emplace(std::move(next_node));
+          ready_queue.emplace(std::move(next_node));
         }
       }
     }
   }
   if (!inputs.empty()) {
-    return GetResults(inputs, results_map);
+    return GetResults(inputs, results_map, allow_unused);
   }

   VLOG(1) << "Run backward in the end, return {}";