@@ -37,6 +37,8 @@ std::unordered_map<std::string, std::vector<std::string>>
3737 core_ops_returns_info = {};
3838std::unordered_map<std::string, std::vector<std::string>> core_ops_args_info =
3939 {};
40+ std::unordered_map<std::string, std::vector<std::string>>
41+ core_ops_args_type_info = {};
4042
4143/* --- Static maps to handle corner cases --- */
4244static std::unordered_map<std::string, paddle::framework::AttributeMap>
@@ -1225,10 +1227,16 @@ static std::pair<std::string, std::string> GenerateForwardFunctionContents(
12251227 */
12261228 VLOG (6 ) << " Generating Dygraph Forward Function" ;
12271229
1228- std::string generated_function_body = " " ;
1230+ const char * FORWARD_FUNCTION_TEMPLATE =
1231+ " VLOG(3) << \" Running Eager Forward Op: %s\" ;\n " ;
1232+ std::string generated_function_body =
1233+ paddle::string::Sprintf (FORWARD_FUNCTION_TEMPLATE, op_type);
1234+
12291235 std::string dygraph_function_args_str = " " ;
12301236 core_ops_args_info[op_type] = {};
1237+ core_ops_args_type_info[op_type] = {};
12311238 core_ops_args_info[op_type].resize (in_vars.size ());
1239+ core_ops_args_type_info[op_type].resize (in_vars.size ());
12321240
12331241 /* ------ Dygraph forward function generation ------ */
12341242 generated_function_body += " // Dygraph Forward Pass\n " ;
@@ -1246,10 +1254,14 @@ static std::pair<std::string, std::string> GenerateForwardFunctionContents(
12461254 " const std::vector<egr::EagerTensor>& %s" ;
12471255 input_args_str_list[input_position] =
12481256 paddle::string::Sprintf (FWD_INS_ARG_TEMPLATE, input_name);
1257+
1258+ core_ops_args_type_info[op_type][input_position] = " list" ;
12491259 } else {
12501260 const char * FWD_INS_ARG_TEMPLATE = " const egr::EagerTensor& %s" ;
12511261 input_args_str_list[input_position] =
12521262 paddle::string::Sprintf (FWD_INS_ARG_TEMPLATE, input_name);
1263+
1264+ core_ops_args_type_info[op_type][input_position] = " tensor" ;
12531265 }
12541266 core_ops_args_info[op_type][input_position] = input_name;
12551267
@@ -1318,11 +1330,14 @@ static std::pair<std::string, std::string> GenerateForwardFunctionContents(
13181330 paddle::string::Sprintf (FWD_NUM_ARG_TEMPLATE, output_var_name);
13191331 dygraph_function_args_str += arg_str;
13201332
1333+ core_ops_args_type_info[op_type].push_back (" list" );
13211334 } else {
13221335 const char * FWD_NUM_ARG_TEMPLATE = " , egr::EagerTensor* %s" ;
13231336 std::string arg_str =
13241337 paddle::string::Sprintf (FWD_NUM_ARG_TEMPLATE, output_var_name);
13251338 dygraph_function_args_str += arg_str;
1339+
1340+ core_ops_args_type_info[op_type].push_back (" tensor" );
13261341 }
13271342 const char * FWD_OUTS_CONTENT_TEMPLATE =
13281343 " { \" %s\" , egr::EagerUtils::TrySyncToVars(%s) }," ;
@@ -1344,6 +1359,7 @@ static std::pair<std::string, std::string> GenerateForwardFunctionContents(
13441359 outs_contents_str += paddle::string::Sprintf (FWD_OUTS_CONTENT_TEMPLATE,
13451360 output_name, outnum);
13461361 core_ops_args_info[op_type].push_back (outnum);
1362+ core_ops_args_type_info[op_type].push_back (" int" );
13471363 } else {
13481364 const char * FWD_OUTS_CONTENT_TEMPLATE =
13491365 " { \" %s\" , "
@@ -1811,6 +1827,11 @@ static std::string GenerateGradNodeCCContents(
18111827 }
18121828 */
18131829
1830+ const char * EAGER_LOG_TEMPLATE =
1831+ " VLOG(3) << \" Running Eager Backward Node: GradNode%s\" ;\n " ;
1832+ std::string generated_grad_function_body =
1833+ paddle::string::Sprintf (EAGER_LOG_TEMPLATE, fwd_op_type);
1834+
18141835 // This is a Copy
18151836 auto op_base_infos = bwd_info.GetOpBaseInfos ();
18161837
@@ -1829,7 +1850,6 @@ static std::string GenerateGradNodeCCContents(
18291850 op_base_infos.emplace_back (std::move (op_base_info));
18301851 }
18311852
1832- std::string generated_grad_function_body = " " ;
18331853 size_t outs_size = 0 ;
18341854 for (size_t i = 0 ; i < op_base_infos.size (); i++) {
18351855 const auto & op_base_info = op_base_infos[i];
@@ -2030,6 +2050,9 @@ static std::string GenerateDygraphHFileIncludes() {
20302050 dygraph_forward_api_includes_str +=
20312051 " extern std::unordered_map<std::string, std::vector<std::string>> "
20322052 " core_ops_args_info;\n " ;
2053+ dygraph_forward_api_includes_str +=
2054+ " extern std::unordered_map<std::string, std::vector<std::string>> "
2055+ " core_ops_args_type_info;\n " ;
20332056 dygraph_forward_api_includes_str +=
20342057 " extern std::unordered_map<std::string, std::vector<std::string>> "
20352058 " core_ops_returns_info;\n\n " ;
@@ -2126,16 +2149,20 @@ static std::string GenerateCoreOpsReturnsInfo() {
21262149 " std::unordered_map<std::string, std::vector<std::string>> "
21272150 " core_ops_args_info = { %s };\n "
21282151 " std::unordered_map<std::string, std::vector<std::string>> "
2152+ " core_ops_args_type_info = { %s };\n "
2153+ " std::unordered_map<std::string, std::vector<std::string>> "
21292154 " core_ops_returns_info = { %s };\n " ;
21302155
21312156 std::string core_ops_args_info_init_str =
21322157 ConvertCoreOpsInfosToString (core_ops_args_info);
2158+ std::string core_ops_args_type_info_init_str =
2159+ ConvertCoreOpsInfosToString (core_ops_args_type_info);
21332160 std::string core_ops_returns_info_init_str =
21342161 ConvertCoreOpsInfosToString (core_ops_returns_info);
21352162
21362163 std::string core_ops_info_str = paddle::string::Sprintf (
21372164 Core_Ops_Returns_MAP_TEMPLATE, core_ops_args_info_init_str,
2138- core_ops_returns_info_init_str);
2165+ core_ops_args_type_info_init_str, core_ops_returns_info_init_str);
21392166
21402167 return core_ops_info_str;
21412168}
0 commit comments