Skip to content

Commit db8425e

Browse files
0x45fAurelius84
andauthored
[Dy2stat]support Python3 type annotation (#36544)
* Support Py3 type annotations in @to_static * support type hint for args in func * support type hint assign * if annotation and value(Constant) are diffent type, we use value type * polish type_from_annotation() * code format * code format * remove useless commentary * fix review Co-authored-by: Aurelius84 <zhangliujie@baidu.com>
1 parent 0590277 commit db8425e

File tree

5 files changed

+103
-28
lines changed

5 files changed

+103
-28
lines changed

python/paddle/fluid/dygraph/dygraph_to_static/error.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -208,10 +208,10 @@ def create_message(self):
208208
message_lines.append("")
209209

210210
# Add paddle traceback after user code traceback
211-
paddle_traceback_start_idnex = user_code_traceback_index[
211+
paddle_traceback_start_index = user_code_traceback_index[
212212
-1] + 1 if user_code_traceback_index else 0
213213
for filepath, lineno, funcname, code in self.origin_traceback[
214-
paddle_traceback_start_idnex:]:
214+
paddle_traceback_start_index:]:
215215
traceback_frame = TraceBackFrame(
216216
Location(filepath, lineno), funcname, code)
217217
message_lines.append(traceback_frame.formated_message())
@@ -305,10 +305,10 @@ def _simplify_error_value(self):
305305
error_frame.append("")
306306

307307
# Add paddle traceback after user code traceback
308-
paddle_traceback_start_idnex = user_code_traceback_index[
308+
paddle_traceback_start_index = user_code_traceback_index[
309309
-1] + 1 if user_code_traceback_index else 0
310310
for filepath, lineno, funcname, code in error_traceback[
311-
paddle_traceback_start_idnex:]:
311+
paddle_traceback_start_index:]:
312312
traceback_frame = TraceBackFrame(
313313
Location(filepath, lineno), funcname, code)
314314
error_frame.append(traceback_frame.formated_message())

python/paddle/fluid/dygraph/dygraph_to_static/static_analysis.py

Lines changed: 74 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,8 @@
1515
from __future__ import print_function
1616

1717
from paddle.utils import gast
18-
from .utils import is_paddle_api, is_dygraph_api, is_numpy_api, index_in_list
18+
from .logging_utils import warn
19+
from .utils import is_paddle_api, is_dygraph_api, is_numpy_api, index_in_list, ast_to_source_code
1920

2021
__all__ = ['AstNodeWrapper', 'NodeVarType', 'StaticAnalysisVisitor']
2122

@@ -57,6 +58,15 @@ class NodeVarType(object):
5758
# If node.node_var_type in TENSOR_TYPES, it can be considered as tensor-dependent.
5859
TENSOR_TYPES = {TENSOR, PADDLE_RETURN_TYPES}
5960

61+
Annotation_map = {
62+
"Tensor": TENSOR,
63+
"paddle.Tensor": TENSOR,
64+
"int": INT,
65+
"float": FLOAT,
66+
"bool": BOOLEAN,
67+
"str": STRING
68+
}
69+
6070
@staticmethod
6171
def binary_op_output_type(in_type1, in_type2):
6272
if in_type1 == in_type2:
@@ -83,6 +93,16 @@ def binary_op_output_type(in_type1, in_type2):
8393
return NodeVarType.UNKNOWN
8494
return max(in_type1, in_type2)
8595

96+
@staticmethod
97+
def type_from_annotation(annotation):
98+
annotation_str = ast_to_source_code(annotation).strip()
99+
if annotation_str in NodeVarType.Annotation_map:
100+
return NodeVarType.Annotation_map[annotation_str]
101+
102+
# raise warning if not found
103+
warn("Currently we don't support annotation: %s" % annotation_str)
104+
return NodeVarType.UNKNOWN
105+
86106

87107
class AstNodeWrapper(object):
88108
"""
@@ -316,6 +336,18 @@ def _get_node_var_type(self, cur_wrapper):
316336
self.var_env.set_var_type(target.id, ret_type)
317337
return ret_type
318338

339+
if isinstance(node, gast.AnnAssign):
340+
# TODO(0x45f): To determine whether need to support assignment statements
341+
# like `self.x: float = 2.1`.
342+
ret_type = {NodeVarType.type_from_annotation(node.annotation)}
343+
# if annotation and value(Constant) are diffent type, we use value type
344+
if node.value:
345+
ret_type = self.node_to_wrapper_map[node.value].node_var_type
346+
if isinstance(node.target, gast.Name):
347+
self.node_to_wrapper_map[node.target].node_var_type = ret_type
348+
self.var_env.set_var_type(node.target.id, ret_type)
349+
return ret_type
350+
319351
if isinstance(node, gast.Name):
320352
if node.id == "None":
321353
return {NodeVarType.NONE}
@@ -325,21 +357,8 @@ def _get_node_var_type(self, cur_wrapper):
325357
parent_node_wrapper = cur_wrapper.parent
326358
if parent_node_wrapper and isinstance(parent_node_wrapper.node,
327359
gast.arguments):
328-
parent_node = parent_node_wrapper.node
329-
var_type = {NodeVarType.UNKNOWN}
330-
if parent_node.defaults:
331-
index = index_in_list(parent_node.args, node)
332-
args_len = len(parent_node.args)
333-
if index != -1 and args_len - index <= len(
334-
parent_node.defaults):
335-
defaults_node = parent_node.defaults[index - args_len]
336-
if isinstance(defaults_node, gast.Constant):
337-
var_type = self._get_constant_node_type(
338-
defaults_node)
339-
340-
# Add node with identified type into cur_env.
341-
self.var_env.set_var_type(node.id, var_type)
342-
return var_type
360+
361+
return self._get_func_argument_type(parent_node_wrapper, node)
343362

344363
return self.var_env.get_var_type(node.id)
345364

@@ -373,3 +392,42 @@ def _get_node_var_type(self, cur_wrapper):
373392
return {NodeVarType.TENSOR}
374393

375394
return {NodeVarType.STATEMENT}
395+
396+
def _get_func_argument_type(self, parent_node_wrapper, node):
397+
"""
398+
Returns type information by parsing annotation or default values.
399+
400+
For example:
401+
1. parse by default values.
402+
foo(x, y=1, z='s') -> x: UNKNOWN, y: INT, z: STR
403+
404+
2. parse by Py3 type annotation.
405+
foo(x: Tensor, y: int, z: str) -> x: Tensor, y: INT, z: STR
406+
407+
3. parse by type annotation and default values.
408+
foo(x: Tensor, y: int, z: str = 'abc') -> x: Tensor, y: INT, z: STR
409+
410+
NOTE: Currently, we only support Tensor, int, bool, float, str et.al.
411+
Other complicate types will be supported later.
412+
"""
413+
assert isinstance(node, gast.Name)
414+
415+
parent_node = parent_node_wrapper.node
416+
var_type = {NodeVarType.UNKNOWN}
417+
if node.annotation is not None:
418+
var_type = {NodeVarType.type_from_annotation(node.annotation)}
419+
self.var_env.set_var_type(node.id, var_type)
420+
421+
# if annotation and value(Constant) are diffent type, we use value type
422+
if parent_node.defaults:
423+
index = index_in_list(parent_node.args, node)
424+
args_len = len(parent_node.args)
425+
if index != -1 and args_len - index <= len(parent_node.defaults):
426+
defaults_node = parent_node.defaults[index - args_len]
427+
if isinstance(defaults_node, gast.Constant):
428+
var_type = self._get_constant_node_type(defaults_node)
429+
430+
# Add node with identified type into cur_env.
431+
self.var_env.set_var_type(node.id, var_type)
432+
433+
return var_type

python/paddle/fluid/dygraph/dygraph_to_static/utils.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -520,7 +520,8 @@ def remove_if_exit(filepath):
520520

521521
def _inject_import_statements():
522522
import_statements = [
523-
"import paddle", "import paddle.fluid as fluid", "from typing import *",
523+
"import paddle", "from paddle import Tensor",
524+
"import paddle.fluid as fluid", "from typing import *",
524525
"import numpy as np"
525526
]
526527
return '\n'.join(import_statements) + '\n'

python/paddle/fluid/tests/unittests/dygraph_to_static/test_origin_info.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ def set_test_func(self):
6565
self.func = simple_func
6666

6767
def set_static_lineno(self):
68-
self.static_abs_lineno_list = [5, 6, 7]
68+
self.static_abs_lineno_list = [6, 7, 8]
6969

7070
def set_dygraph_info(self):
7171
self.line_num = 3
@@ -149,7 +149,7 @@ def set_test_func(self):
149149
self.func = nested_func
150150

151151
def set_static_lineno(self):
152-
self.static_abs_lineno_list = [5, 7, 8, 9, 10]
152+
self.static_abs_lineno_list = [6, 8, 9, 10, 11]
153153

154154
def set_dygraph_info(self):
155155
self.line_num = 5
@@ -174,7 +174,7 @@ def set_test_func(self):
174174
self.func = decorated_func
175175

176176
def set_static_lineno(self):
177-
self.static_abs_lineno_list = [5, 6]
177+
self.static_abs_lineno_list = [6, 7]
178178

179179
def set_dygraph_info(self):
180180
self.line_num = 2
@@ -208,7 +208,7 @@ def set_test_func(self):
208208
self.func = decorated_func2
209209

210210
def set_static_lineno(self):
211-
self.static_abs_lineno_list = [5, 6]
211+
self.static_abs_lineno_list = [6, 7]
212212

213213
def set_dygraph_info(self):
214214
self.line_num = 2

python/paddle/fluid/tests/unittests/dygraph_to_static/test_static_analysis.py

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,8 @@ def func_to_test3():
5757
h = None
5858
i = False
5959
j = None + 1
60+
k: float = 1.0
61+
l: paddle.Tensor = paddle.to_tensor([1, 2])
6062

6163

6264
result_var_type3 = {
@@ -69,7 +71,9 @@ def func_to_test3():
6971
'g': {NodeVarType.STRING},
7072
'h': {NodeVarType.NONE},
7173
'i': {NodeVarType.BOOLEAN},
72-
'j': {NodeVarType.UNKNOWN}
74+
'j': {NodeVarType.UNKNOWN},
75+
'k': {NodeVarType.FLOAT},
76+
'l': {NodeVarType.PADDLE_RETURN_TYPES}
7377
}
7478

7579

@@ -139,13 +143,25 @@ def add(x, y):
139143
'add': {NodeVarType.INT}
140144
}
141145

146+
147+
def func_to_test7(a: int, b: float, c: paddle.Tensor, d: float='diff'):
148+
a = True
149+
150+
151+
result_var_type7 = {
152+
'a': {NodeVarType.BOOLEAN},
153+
'b': {NodeVarType.FLOAT},
154+
'c': {NodeVarType.TENSOR},
155+
'd': {NodeVarType.STRING}
156+
}
157+
142158
test_funcs = [
143159
func_to_test1, func_to_test2, func_to_test3, func_to_test4, func_to_test5,
144-
func_to_test6
160+
func_to_test6, func_to_test7
145161
]
146162
result_var_type = [
147163
result_var_type1, result_var_type2, result_var_type3, result_var_type4,
148-
result_var_type5, result_var_type6
164+
result_var_type5, result_var_type6, result_var_type7
149165
]
150166

151167

0 commit comments

Comments
 (0)