1515from __future__ import print_function
1616
1717from paddle .utils import gast
18- from .utils import is_paddle_api , is_dygraph_api , is_numpy_api , index_in_list
18+ from .logging_utils import warn
19+ from .utils import is_paddle_api , is_dygraph_api , is_numpy_api , index_in_list , ast_to_source_code
1920
2021__all__ = ['AstNodeWrapper' , 'NodeVarType' , 'StaticAnalysisVisitor' ]
2122
@@ -57,6 +58,15 @@ class NodeVarType(object):
5758 # If node.node_var_type in TENSOR_TYPES, it can be considered as tensor-dependent.
5859 TENSOR_TYPES = {TENSOR , PADDLE_RETURN_TYPES }
5960
61+ Annotation_map = {
62+ "Tensor" : TENSOR ,
63+ "paddle.Tensor" : TENSOR ,
64+ "int" : INT ,
65+ "float" : FLOAT ,
66+ "bool" : BOOLEAN ,
67+ "str" : STRING
68+ }
69+
6070 @staticmethod
6171 def binary_op_output_type (in_type1 , in_type2 ):
6272 if in_type1 == in_type2 :
@@ -83,6 +93,16 @@ def binary_op_output_type(in_type1, in_type2):
8393 return NodeVarType .UNKNOWN
8494 return max (in_type1 , in_type2 )
8595
96+ @staticmethod
97+ def type_from_annotation (annotation ):
98+ annotation_str = ast_to_source_code (annotation ).strip ()
99+ if annotation_str in NodeVarType .Annotation_map :
100+ return NodeVarType .Annotation_map [annotation_str ]
101+
102+ # raise warning if not found
103+ warn ("Currently we don't support annotation: %s" % annotation_str )
104+ return NodeVarType .UNKNOWN
105+
86106
87107class AstNodeWrapper (object ):
88108 """
@@ -316,6 +336,18 @@ def _get_node_var_type(self, cur_wrapper):
316336 self .var_env .set_var_type (target .id , ret_type )
317337 return ret_type
318338
339+ if isinstance (node , gast .AnnAssign ):
340+ # TODO(0x45f): To determine whether need to support assignment statements
341+ # like `self.x: float = 2.1`.
342+ ret_type = {NodeVarType .type_from_annotation (node .annotation )}
343+ # if annotation and value(Constant) are diffent type, we use value type
344+ if node .value :
345+ ret_type = self .node_to_wrapper_map [node .value ].node_var_type
346+ if isinstance (node .target , gast .Name ):
347+ self .node_to_wrapper_map [node .target ].node_var_type = ret_type
348+ self .var_env .set_var_type (node .target .id , ret_type )
349+ return ret_type
350+
319351 if isinstance (node , gast .Name ):
320352 if node .id == "None" :
321353 return {NodeVarType .NONE }
@@ -325,21 +357,8 @@ def _get_node_var_type(self, cur_wrapper):
325357 parent_node_wrapper = cur_wrapper .parent
326358 if parent_node_wrapper and isinstance (parent_node_wrapper .node ,
327359 gast .arguments ):
328- parent_node = parent_node_wrapper .node
329- var_type = {NodeVarType .UNKNOWN }
330- if parent_node .defaults :
331- index = index_in_list (parent_node .args , node )
332- args_len = len (parent_node .args )
333- if index != - 1 and args_len - index <= len (
334- parent_node .defaults ):
335- defaults_node = parent_node .defaults [index - args_len ]
336- if isinstance (defaults_node , gast .Constant ):
337- var_type = self ._get_constant_node_type (
338- defaults_node )
339-
340- # Add node with identified type into cur_env.
341- self .var_env .set_var_type (node .id , var_type )
342- return var_type
360+
361+ return self ._get_func_argument_type (parent_node_wrapper , node )
343362
344363 return self .var_env .get_var_type (node .id )
345364
@@ -373,3 +392,42 @@ def _get_node_var_type(self, cur_wrapper):
373392 return {NodeVarType .TENSOR }
374393
375394 return {NodeVarType .STATEMENT }
395+
396+ def _get_func_argument_type (self , parent_node_wrapper , node ):
397+ """
398+ Returns type information by parsing annotation or default values.
399+
400+ For example:
401+ 1. parse by default values.
402+ foo(x, y=1, z='s') -> x: UNKNOWN, y: INT, z: STR
403+
404+ 2. parse by Py3 type annotation.
405+ foo(x: Tensor, y: int, z: str) -> x: Tensor, y: INT, z: STR
406+
407+ 3. parse by type annotation and default values.
408+ foo(x: Tensor, y: int, z: str = 'abc') -> x: Tensor, y: INT, z: STR
409+
410+ NOTE: Currently, we only support Tensor, int, bool, float, str et.al.
411+ Other complicate types will be supported later.
412+ """
413+ assert isinstance (node , gast .Name )
414+
415+ parent_node = parent_node_wrapper .node
416+ var_type = {NodeVarType .UNKNOWN }
417+ if node .annotation is not None :
418+ var_type = {NodeVarType .type_from_annotation (node .annotation )}
419+ self .var_env .set_var_type (node .id , var_type )
420+
421+ # if annotation and value(Constant) are diffent type, we use value type
422+ if parent_node .defaults :
423+ index = index_in_list (parent_node .args , node )
424+ args_len = len (parent_node .args )
425+ if index != - 1 and args_len - index <= len (parent_node .defaults ):
426+ defaults_node = parent_node .defaults [index - args_len ]
427+ if isinstance (defaults_node , gast .Constant ):
428+ var_type = self ._get_constant_node_type (defaults_node )
429+
430+ # Add node with identified type into cur_env.
431+ self .var_env .set_var_type (node .id , var_type )
432+
433+ return var_type
0 commit comments