@@ -577,11 +577,6 @@ static std::string GenerateGradNodeCreationContent(
577577 // If single output slotname and not duplicable,
578578 // then generate: "egr::AutogradMeta* p_autograd_out =
579579 // egr::EagerUtils::autograd_meta("op_proto->outputs()[0].name()")"
580-
581- // TODO(zhanlve): in case of multiple slotname but none of which are
582- // duplicable,
583- // avoid constructing vector<AutogradMeta*>, generate seperate
584- // AutogradMeta* objects respectively.
585580 std::string get_autograd_meta_str = " // Prepare Autograd Meta \n " ;
586581 for (const proto::OpProto::Var& input : op_proto.inputs ()) {
587582 const std::string& input_name = input.name ();
@@ -607,11 +602,6 @@ static std::string GenerateGradNodeCreationContent(
607602 // If single output slotname and not duplicable,
608603 // then generate: "egr::AutogradMeta* p_autograd_out =
609604 // egr::EagerUtils::autograd_meta("op_proto.outputs()[0].name()")"
610-
611- // TODO(zhanlve): in case of multiple slotname but none of which are
612- // duplicable,
613- // avoid constructing vector<AutogradMeta*>, generate seperate
614- // AutogradMeta* objects respectively.
615605 for (const proto::OpProto::Var& output : op_proto.outputs ()) {
616606 const std::string& output_name = output.name ();
617607 const std::string& output_autograd_name = " p_autograd_" + output_name;
@@ -725,9 +715,9 @@ static std::string GenerateGradNodeCreationContent(
725715 // [Generation] GradNode Creation
726716 const char * GRAD_NODE_CREATION_TEMPLATE =
727717 " %s"
728- " bool require_any_grad = egr::ComputeRequireGrad(%s);\n "
718+ " bool require_any_grad = egr::EagerUtils:: ComputeRequireGrad(%s);\n "
729719 " if(require_any_grad) {\n "
730- " egr::PassStopGradient(%s);\n "
720+ " egr::EagerUtils:: PassStopGradient(%s);\n "
731721 " %s\n }" ;
732722 std::string grad_node_creation_body_str = paddle::string::Sprintf (
733723 GRAD_NODE_CREATION_TEMPLATE, prepare_autograd_meta_str,
@@ -793,7 +783,7 @@ static std::pair<std::string, std::string> GenerateForwardFunctionContents(
793783 Controller.Instance().GetExpectedPlace(), {});
794784
795785 // According to fwd_outputs_names
796- std::vector<egr::EagerTensor> Out0 = GetOutputs (outs["Out0"]);
796786+ std::vector<egr::EagerTensor> Out0 = egr::EagerUtils::GetOutputs(outs["Out0"]);
797787 egr::EagerTensor Out1 = GetOutputs(outs["Out1"][0]);
798788 std::vector<egr::EagerTensor> Out2 = GetOutputs(outs["Out2"]);
799789
@@ -830,7 +820,8 @@ static std::pair<std::string, std::string> GenerateForwardFunctionContents(
830820 input_args_str_list[input_position] =
831821 paddle::string::Sprintf (FWD_INS_ARG_TEMPLATE, input_name);
832822 }
833- const char * FWD_INS_CONTENT_TEMPLATE = " { \" %s\" , egr::SyncToVars(%s) }," ;
823+ const char * FWD_INS_CONTENT_TEMPLATE =
824+ " { \" %s\" , egr::EagerUtils::SyncToVars(%s) }," ;
834825 ins_contents_str += paddle::string::Sprintf (FWD_INS_CONTENT_TEMPLATE,
835826 input_name, input_name);
836827 }
@@ -925,14 +916,14 @@ static std::pair<std::string, std::string> GenerateForwardFunctionContents(
925916 if (output.duplicable ()) {
926917 const char * FWD_OUT_TENSORS_TEMPLATE =
927918 " std::vector<egr::EagerTensor> %s = "
928- " egr::GetOutputs(outs[\" %s\" ]);\n " ;
919+ " egr::EagerUtils:: GetOutputs(outs[\" %s\" ]);\n " ;
929920 out_tensor_str = paddle::string::Sprintf (FWD_OUT_TENSORS_TEMPLATE,
930921 output_name, output_name);
931922 return_types[return_position] = " std::vector<egr::EagerTensor>" ;
932923 } else {
933924 const char * FWD_OUT_TENSOR_TEMPLATE =
934925 " egr::EagerTensor %s = "
935- " egr::GetOutput(outs[\" %s\" ][0]);\n " ;
926+ " egr::EagerUtils:: GetOutput(outs[\" %s\" ][0]);\n " ;
936927 out_tensor_str = paddle::string::Sprintf (FWD_OUT_TENSOR_TEMPLATE,
937928 output_name, output_name);
938929 return_types[return_position] = " egr::EagerTensor" ;
@@ -1093,7 +1084,8 @@ static std::string GenerateGradNodeCCContents(
10931084 grad_ins_fwd_slotname_map.at (grad_input_name) + " _" ;
10941085 const char * GRAD_INS_FWD_CONTENT_TEMPLATE =
10951086 " { \" %s\" , "
1096- " egr::SyncToVars(egr::EagerUtils::RecoverTensorWrapper(&this->%s, "
1087+ " egr::EagerUtils::SyncToVars(egr::EagerUtils::RecoverTensorWrapper(&"
1088+ " this->%s, "
10971089 " nullptr)) }," ;
10981090 ins_contents_str +=
10991091 paddle::string::Sprintf (GRAD_INS_FWD_CONTENT_TEMPLATE,
@@ -1104,7 +1096,7 @@ static std::string GenerateGradNodeCCContents(
11041096 size_t fwd_output_position = fwd_outputs_name_pos_map.at (
11051097 grad_ins_grad_slotname_map.at (grad_input_name));
11061098 const char * GRAD_INS_GRAD_CONTENT_TEMPLATE =
1107- " { \" %s\" , egr::SyncToVars(grads[%d]) }," ;
1099+ " { \" %s\" , egr::EagerUtils:: SyncToVars(grads[%d]) }," ;
11081100 ins_contents_str += paddle::string::Sprintf (
11091101 GRAD_INS_GRAD_CONTENT_TEMPLATE, grad_input_name, fwd_output_position);
11101102
@@ -1206,7 +1198,7 @@ static std::string GenerateGradNodeCCContents(
12061198 fwd_inputs_name_pos_map.at (grad_outs_slotname_map.at (grad_out_name));
12071199
12081200 const char * BWD_OUTPUT_TEMPLATE =
1209- " outputs[%d] = GetOutputs(outs[\" %s\" ]);\n " ;
1201+ " outputs[%d] = egr::EagerUtils:: GetOutputs(outs[\" %s\" ]);\n " ;
12101202 outputs_str += paddle::string::Sprintf (BWD_OUTPUT_TEMPLATE,
12111203 fwd_input_position, grad_out_name);
12121204 }
@@ -1526,6 +1518,9 @@ static void DygraphCodeGeneration(const std::string& output_dir) {
15261518 GenerateForwardHFile (output_dir, dygraph_forward_api_str);
15271519}
15281520
1521+ } // namespace framework
1522+ } // namespace paddle
1523+
15291524int main (int argc, char * argv[]) {
15301525 if (argc != 2 ) {
15311526 std::cerr << " argc must be 2" << std::endl;
@@ -1537,6 +1532,3 @@ int main(int argc, char* argv[]) {
15371532
15381533 return 0 ;
15391534}
1540-
1541- } // namespace framework
1542- } // namespace paddle