|
24 | 24 | core_ops_args_type_info = {} |
25 | 25 |
|
26 | 26 |
|
| 27 | +yaml_types_mapping = { |
| 28 | + 'int' : 'int', 'int32_t' : 'int32_t', 'int64_t' : 'int64_t', 'size_t' : 'size_t', \ |
| 29 | + 'float' : 'float', 'double' : 'double', 'bool' : 'bool', \ |
| 30 | + 'Backend' : 'Backend', 'DataLayout' : 'DataLayout', 'DataType' : 'DataType', \ |
| 31 | + 'int64_t[]' : 'std::vector<int64_t>', 'int[]' : 'std::vector<int>', |
| 32 | + 'Tensor' : 'Tensor', |
| 33 | + 'Tensor[]' : 'std::vector<Tensor>', |
| 34 | + 'Tensor[Tensor[]]' : 'std::vector<std::vector<Tensor>>' |
| 35 | +} |
| 36 | + |
| 37 | + |
27 | 38 | def ParseArguments(): |
28 | 39 | parser = argparse.ArgumentParser( |
29 | 40 | description='Eager Code Generator Args Parser') |
@@ -59,7 +70,9 @@ def IsPlainTensorType(string): |
59 | 70 |
|
60 | 71 |
|
61 | 72 | def IsVectorTensorType(string): |
62 | | - vector_tensor_types = ['list(Tensor)'] |
| 73 | + vector_tensor_types = [ |
| 74 | + 'std::vector<std::vector<Tensor>>', 'std::vector<Tensor>' |
| 75 | + ] |
63 | 76 | if string in vector_tensor_types: |
64 | 77 | return True |
65 | 78 | return False |
@@ -180,6 +193,9 @@ def ParseYamlArgs(string): |
180 | 193 | arg_name = m.group(3).split("=")[0].strip() |
181 | 194 | default_value = m.group(3).split("=")[1].strip() if len( |
182 | 195 | m.group(3).split("=")) > 1 else None |
| 196 | + |
| 197 | + assert arg_type in yaml_types_mapping.keys() |
| 198 | + arg_type = yaml_types_mapping[arg_type] |
183 | 199 | if "Tensor" in arg_type: |
184 | 200 | assert default_value is None |
185 | 201 | inputs_list.append([arg_name, arg_type, i]) |
@@ -219,6 +235,10 @@ def ParseYamlReturnsWithName(string): |
219 | 235 | m = re.search(pattern, ret) |
220 | 236 | ret_type = m.group(1) |
221 | 237 | ret_name = m.group(2) |
| 238 | + |
| 239 | + assert ret_type in yaml_types_mapping.keys() |
| 240 | + ret_type = yaml_types_mapping[ret_type] |
| 241 | + |
222 | 242 | assert "Tensor" in ret_type |
223 | 243 | returns_list.append([ret_name, ret_type, i]) |
224 | 244 |
|
|
0 commit comments