@@ -34,6 +34,8 @@ std::unordered_map<std::string, std::vector<std::string>>
3434 core_ops_returns_info = {};
3535std::unordered_map<std::string, std::vector<std::string>> core_ops_args_info =
3636 {};
37+ std::unordered_map<std::string, std::vector<std::string>>
38+ core_ops_args_type_info = {};
3739
3840/* --- Static maps to handle corner cases --- */
3941static std::unordered_map<std::string, paddle::framework::AttributeMap>
@@ -1120,7 +1122,9 @@ static std::pair<std::string, std::string> GenerateForwardFunctionContents(
11201122 std::string generated_function_body = " " ;
11211123 std::string dygraph_function_args_str = " " ;
11221124 core_ops_args_info[op_type] = {};
1125+ core_ops_args_type_info[op_type] = {};
11231126 core_ops_args_info[op_type].resize (in_vars.size ());
1127+ core_ops_args_type_info[op_type].resize (in_vars.size ());
11241128
11251129 /* ------ Dygraph forward function generation ------ */
11261130 generated_function_body += " // Dygraph Forward Pass\n " ;
@@ -1138,10 +1142,14 @@ static std::pair<std::string, std::string> GenerateForwardFunctionContents(
11381142 " const std::vector<egr::EagerTensor>& %s" ;
11391143 input_args_str_list[input_position] =
11401144 paddle::string::Sprintf (FWD_INS_ARG_TEMPLATE, input_name);
1145+
1146+ core_ops_args_type_info[op_type][input_position] = " list" ;
11411147 } else {
11421148 const char * FWD_INS_ARG_TEMPLATE = " const egr::EagerTensor& %s" ;
11431149 input_args_str_list[input_position] =
11441150 paddle::string::Sprintf (FWD_INS_ARG_TEMPLATE, input_name);
1151+
1152+ core_ops_args_type_info[op_type][input_position] = " tensor" ;
11451153 }
11461154 core_ops_args_info[op_type][input_position] = input_name;
11471155
@@ -1210,11 +1218,14 @@ static std::pair<std::string, std::string> GenerateForwardFunctionContents(
12101218 paddle::string::Sprintf (FWD_NUM_ARG_TEMPLATE, output_var_name);
12111219 dygraph_function_args_str += arg_str;
12121220
1221+ core_ops_args_type_info[op_type].push_back (" list" );
12131222 } else {
12141223 const char * FWD_NUM_ARG_TEMPLATE = " , egr::EagerTensor* %s" ;
12151224 std::string arg_str =
12161225 paddle::string::Sprintf (FWD_NUM_ARG_TEMPLATE, output_var_name);
12171226 dygraph_function_args_str += arg_str;
1227+
1228+ core_ops_args_type_info[op_type].push_back (" tensor" );
12181229 }
12191230 const char * FWD_OUTS_CONTENT_TEMPLATE =
12201231 " { \" %s\" , egr::EagerUtils::TrySyncToVars(%s) }," ;
@@ -1236,6 +1247,7 @@ static std::pair<std::string, std::string> GenerateForwardFunctionContents(
12361247 outs_contents_str += paddle::string::Sprintf (FWD_OUTS_CONTENT_TEMPLATE,
12371248 output_name, outnum);
12381249 core_ops_args_info[op_type].push_back (outnum);
1250+ core_ops_args_type_info[op_type].push_back (" int" );
12391251 } else {
12401252 const char * FWD_OUTS_CONTENT_TEMPLATE =
12411253 " { \" %s\" , "
@@ -1840,6 +1852,9 @@ static std::string GenerateDygraphHFileIncludes() {
18401852 dygraph_forward_api_includes_str +=
18411853 " extern std::unordered_map<std::string, std::vector<std::string>> "
18421854 " core_ops_args_info;\n " ;
1855+ dygraph_forward_api_includes_str +=
1856+ " extern std::unordered_map<std::string, std::vector<std::string>> "
1857+ " core_ops_args_type_info;\n " ;
18431858 dygraph_forward_api_includes_str +=
18441859 " extern std::unordered_map<std::string, std::vector<std::string>> "
18451860 " core_ops_returns_info;\n\n " ;
@@ -1936,16 +1951,20 @@ static std::string GenerateCoreOpsReturnsInfo() {
19361951 " std::unordered_map<std::string, std::vector<std::string>> "
19371952 " core_ops_args_info = { %s };\n "
19381953 " std::unordered_map<std::string, std::vector<std::string>> "
1954+ " core_ops_args_type_info = { %s };\n "
1955+ " std::unordered_map<std::string, std::vector<std::string>> "
19391956 " core_ops_returns_info = { %s };\n " ;
19401957
19411958 std::string core_ops_args_info_init_str =
19421959 ConvertCoreOpsInfosToString (core_ops_args_info);
1960+ std::string core_ops_args_type_info_init_str =
1961+ ConvertCoreOpsInfosToString (core_ops_args_type_info);
19431962 std::string core_ops_returns_info_init_str =
19441963 ConvertCoreOpsInfosToString (core_ops_returns_info);
19451964
19461965 std::string core_ops_info_str = paddle::string::Sprintf (
19471966 Core_Ops_Returns_MAP_TEMPLATE, core_ops_args_info_init_str,
1948- core_ops_returns_info_init_str);
1967+ core_ops_args_type_info_init_str, core_ops_returns_info_init_str);
19491968
19501969 return core_ops_info_str;
19511970}
0 commit comments