@@ -950,12 +950,12 @@ def parse_op_info(op_name):
950950 op_proto = OpProtoHolder .instance ().get_op_proto (op_name )
951951
952952 in_names = [x .name for x in op_proto .inputs ]
953- out_names = [x .name for x in op_proto .outputs ]
954953 attr_names = [
955954 x .name for x in op_proto .attrs if x .name not in DEFAULT_OP_ATTR_NAMES
956955 ]
956+ out_names = [x .name for x in op_proto .outputs ]
957957
958- return in_names , out_names , attr_names
958+ return in_names , attr_names , out_names
959959
960960
961961def _import_module_from_library (module_name , build_directory , verbose = False ):
@@ -1038,28 +1038,58 @@ def remove_if_exit(filepath):
10381038 return custom_module
10391039
10401040
1041+ def _gen_output_content (in_names , out_names , inplace_reverse_idx ):
1042+ # ' ' * tab space * tab number
1043+ indent = ' ' * 4 * 2
1044+ dynamic_content = ""
1045+ static_content = ""
1046+ for out_idx , out_name in enumerate (out_names ):
1047+ in_idx = - 1
1048+ if out_idx in inplace_reverse_idx :
1049+ in_idx = inplace_reverse_idx [out_idx ]
1050+ if in_idx != - 1 and "@VECTOR" in in_names [in_idx ]:
1051+ lower_in_names = in_names [in_idx ].split ("@" )[0 ].lower ()
1052+ dynamic_content += f"""
1053+ { indent } outs['{ out_name } '] = [core.eager.Tensor() for _ in range(len({ lower_in_names } ))]
1054+ { indent } ctx.add_outputs(outs['{ out_name } '])"""
1055+ static_content += f"""
1056+ { indent } outs['{ out_name } '] = [helper.create_variable(dtype='float32') for _ in range(len({ lower_in_names } ))]"""
1057+ else :
1058+ dynamic_content += f"""
1059+ { indent } outs['{ out_name } '] = core.eager.Tensor()
1060+ { indent } ctx.add_outputs(outs['{ out_name } '])"""
1061+ static_content += f"""
1062+ { indent } outs['{ out_name } '] = helper.create_variable(dtype='float32')"""
1063+
1064+ return dynamic_content , static_content
1065+
1066+
10411067def _custom_api_content (op_name ):
10421068 (
1043- params_str ,
1044- ins_str ,
1045- attrs_str ,
1046- outs_str ,
1069+ params_list ,
1070+ ins_map ,
1071+ attrs_map ,
1072+ outs_list ,
10471073 in_names ,
1048- attrs_names ,
1074+ attr_names ,
1075+ out_names ,
1076+ inplace_reverse_idx ,
10491077 ) = _get_api_inputs_str (op_name )
1050- lower_in_names = [p .split ("@" )[0 ].lower () for p in in_names ]
1078+ dynamic_content , static_content = _gen_output_content (
1079+ in_names , out_names , inplace_reverse_idx
1080+ )
1081+ lower_in_list = [p .split ("@" )[0 ].lower () for p in in_names ]
10511082 API_TEMPLATE = textwrap .dedent (
10521083 """
10531084 import paddle.fluid.core as core
10541085 from paddle.fluid.core import VarBase, CustomOpKernelContext
10551086 from paddle.fluid.framework import _dygraph_tracer, in_dygraph_mode
10561087 from paddle.fluid.layer_helper import LayerHelper
10571088
1058- def {op_name}({inputs }):
1089+ def {op_name}({params_list }):
10591090 # prepare inputs and outputs
1060- attrs = {attrs}
10611091 outs = {{}}
1062- out_names = {out_names }
1092+ outs_list = {outs_list }
10631093
10641094 # The output variable's dtype use default value 'float32',
10651095 # and the actual dtype of output variable will be inferred in runtime.
@@ -1069,23 +1099,19 @@ def {op_name}({inputs}):
10691099 ctx.add_inputs(i)
10701100 for j in {attr_names}:
10711101 ctx.add_attr(j)
1072- for out_name in out_names:
1073- outs[out_name] = core.eager.Tensor()
1074- ctx.add_outputs(outs[out_name])
1102+ {dynamic_content}
10751103 core.eager._run_custom_op(ctx, "{op_name}", True)
10761104 else:
10771105 ins = {{}}
1078- for key, value in dict({ins }).items():
1106+ for key, value in dict({ins_map }).items():
10791107 # handle optional inputs
10801108 if value is not None:
10811109 ins[key] = value
10821110 helper = LayerHelper("{op_name}", **locals())
1083- for out_name in out_names:
1084- outs[out_name] = helper.create_variable(dtype='float32')
1085-
1086- helper.append_op(type="{op_name}", inputs=ins, outputs=outs, attrs=attrs)
1111+ {static_content}
1112+ helper.append_op(type="{op_name}", inputs=ins, outputs=outs, attrs={attrs_map})
10871113
1088- res = [outs[out_name] for out_name in out_names ]
1114+ res = [outs[out_name] for out_name in outs_list ]
10891115
10901116 return res[0] if len(res)==1 else res
10911117 """
@@ -1094,13 +1120,15 @@ def {op_name}({inputs}):
10941120 # generate python api file
10951121 api_content = API_TEMPLATE .format (
10961122 op_name = op_name ,
1097- inputs = params_str ,
1098- ins = ins_str ,
1099- attrs = attrs_str ,
1123+ params_list = params_list ,
1124+ ins_map = ins_map ,
1125+ attrs_map = attrs_map ,
11001126 # "[x, y, z]""
1101- in_names = "[" + "," .join (lower_in_names ) + "]" ,
1102- attr_names = "[" + "," .join (attrs_names ) + "]" ,
1103- out_names = outs_str ,
1127+ in_names = "[" + "," .join (lower_in_list ) + "]" ,
1128+ attr_names = "[" + "," .join (attr_names ) + "]" ,
1129+ outs_list = outs_list ,
1130+ dynamic_content = dynamic_content ,
1131+ static_content = static_content ,
11041132 )
11051133
11061134 return api_content
@@ -1132,30 +1160,42 @@ def _get_api_inputs_str(op_name):
11321160 """
11331161 Returns string of api parameters and inputs dict.
11341162 """
1135- in_names , out_names , attr_names = parse_op_info (op_name )
1163+ in_names , attr_names , out_names = parse_op_info (op_name )
11361164 # e.g: x, y, z
11371165 param_names = in_names + attr_names
11381166 # NOTE(chenweihang): we add suffix `@VECTOR` for std::vector<Tensor> input,
11391167 # but the string contains `@` cannot used as argument name, so we split
11401168 # input name by `@`, and only use first substr as argument
1141- params_str = ',' .join ([p .split ("@" )[0 ].lower () for p in param_names ])
1169+ params_list = ',' .join ([p .split ("@" )[0 ].lower () for p in param_names ])
11421170 # e.g: {'X': x, 'Y': y, 'Z': z}
1143- ins_str = "{%s}" % ',' .join (
1171+ ins_map = "{%s}" % ',' .join (
11441172 [
11451173 "'{}' : {}" .format (in_name , in_name .split ("@" )[0 ].lower ())
11461174 for in_name in in_names
11471175 ]
11481176 )
11491177 # e.g: {'num': n}
1150- attrs_str = "{%s}" % "," .join (
1178+ attrs_map = "{%s}" % "," .join (
11511179 [
11521180 "'{}' : {}" .format (attr_name , attr_name .split ("@" )[0 ].lower ())
11531181 for attr_name in attr_names
11541182 ]
11551183 )
11561184 # e.g: ['Out', 'Index']
1157- outs_str = "[%s]" % ',' .join (["'{}'" .format (name ) for name in out_names ])
1158- return params_str , ins_str , attrs_str , outs_str , in_names , attr_names
1185+ outs_list = "[%s]" % ',' .join (["'{}'" .format (name ) for name in out_names ])
1186+
1187+ inplace_reverse_idx = core .eager ._get_custom_operator_inplace_map (op_name )
1188+
1189+ return (
1190+ params_list ,
1191+ ins_map ,
1192+ attrs_map ,
1193+ outs_list ,
1194+ in_names ,
1195+ attr_names ,
1196+ out_names ,
1197+ inplace_reverse_idx ,
1198+ )
11591199
11601200
11611201def _write_setup_file (
0 commit comments