Skip to content

Commit 852a872

Browse files
authored
Added attr & tensor type mapping for final state codegen (#39997)
1 parent 72e462c commit 852a872

File tree

1 file changed

+21
-1
lines changed
  • paddle/fluid/eager/auto_code_generator/final_state_generator

1 file changed

+21
-1
lines changed

paddle/fluid/eager/auto_code_generator/final_state_generator/eager_gen.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,17 @@
2424
core_ops_args_type_info = {}
2525

2626

27+
yaml_types_mapping = {
28+
'int' : 'int', 'int32_t' : 'int32_t', 'int64_t' : 'int64_t', 'size_t' : 'size_t', \
29+
'float' : 'float', 'double' : 'double', 'bool' : 'bool', \
30+
'Backend' : 'Backend', 'DataLayout' : 'DataLayout', 'DataType' : 'DataType', \
31+
'int64_t[]' : 'std::vector<int64_t>', 'int[]' : 'std::vector<int>',
32+
'Tensor' : 'Tensor',
33+
'Tensor[]' : 'std::vector<Tensor>',
34+
'Tensor[Tensor[]]' : 'std::vector<std::vector<Tensor>>'
35+
}
36+
37+
2738
def ParseArguments():
2839
parser = argparse.ArgumentParser(
2940
description='Eager Code Generator Args Parser')
@@ -59,7 +70,9 @@ def IsPlainTensorType(string):
5970

6071

6172
def IsVectorTensorType(string):
62-
vector_tensor_types = ['list(Tensor)']
73+
vector_tensor_types = [
74+
'std::vector<std::vector<Tensor>>', 'std::vector<Tensor>'
75+
]
6376
if string in vector_tensor_types:
6477
return True
6578
return False
@@ -180,6 +193,9 @@ def ParseYamlArgs(string):
180193
arg_name = m.group(3).split("=")[0].strip()
181194
default_value = m.group(3).split("=")[1].strip() if len(
182195
m.group(3).split("=")) > 1 else None
196+
197+
assert arg_type in yaml_types_mapping.keys()
198+
arg_type = yaml_types_mapping[arg_type]
183199
if "Tensor" in arg_type:
184200
assert default_value is None
185201
inputs_list.append([arg_name, arg_type, i])
@@ -219,6 +235,10 @@ def ParseYamlReturnsWithName(string):
219235
m = re.search(pattern, ret)
220236
ret_type = m.group(1)
221237
ret_name = m.group(2)
238+
239+
assert ret_type in yaml_types_mapping.keys()
240+
ret_type = yaml_types_mapping[ret_type]
241+
222242
assert "Tensor" in ret_type
223243
returns_list.append([ret_name, ret_type, i])
224244

0 commit comments

Comments
 (0)