@@ -270,6 +270,16 @@ def copy_modified(self, *,
270270 self .line , self .column )
271271
272272
273+ class TypeGuardType (Type ):
274+ """Only used by find_instance_check() etc."""
275+ def __init__ (self , type_guard : Type ):
276+ super ().__init__ (line = type_guard .line , column = type_guard .column )
277+ self .type_guard = type_guard
278+
279+ def __repr__ (self ) -> str :
280+ return "TypeGuard({})" .format (self .type_guard )
281+
282+
273283class ProperType (Type ):
274284 """Not a type alias.
275285
@@ -1005,6 +1015,7 @@ class CallableType(FunctionLike):
10051015 # tools that consume mypy ASTs
10061016 'def_extras' , # Information about original definition we want to serialize.
10071017 # This is used for more detailed error messages.
1018+ 'type_guard' , # T, if -> TypeGuard[T] (ret_type is bool in this case).
10081019 )
10091020
10101021 def __init__ (self ,
@@ -1024,6 +1035,7 @@ def __init__(self,
10241035 from_type_type : bool = False ,
10251036 bound_args : Sequence [Optional [Type ]] = (),
10261037 def_extras : Optional [Dict [str , Any ]] = None ,
1038+ type_guard : Optional [Type ] = None ,
10271039 ) -> None :
10281040 super ().__init__ (line , column )
10291041 assert len (arg_types ) == len (arg_kinds ) == len (arg_names )
@@ -1058,6 +1070,7 @@ def __init__(self,
10581070 not definition .is_static else None }
10591071 else :
10601072 self .def_extras = {}
1073+ self .type_guard = type_guard
10611074
10621075 def copy_modified (self ,
10631076 arg_types : Bogus [Sequence [Type ]] = _dummy ,
@@ -1075,7 +1088,9 @@ def copy_modified(self,
10751088 special_sig : Bogus [Optional [str ]] = _dummy ,
10761089 from_type_type : Bogus [bool ] = _dummy ,
10771090 bound_args : Bogus [List [Optional [Type ]]] = _dummy ,
1078- def_extras : Bogus [Dict [str , Any ]] = _dummy ) -> 'CallableType' :
1091+ def_extras : Bogus [Dict [str , Any ]] = _dummy ,
1092+ type_guard : Bogus [Optional [Type ]] = _dummy ,
1093+ ) -> 'CallableType' :
10791094 return CallableType (
10801095 arg_types = arg_types if arg_types is not _dummy else self .arg_types ,
10811096 arg_kinds = arg_kinds if arg_kinds is not _dummy else self .arg_kinds ,
@@ -1094,6 +1109,7 @@ def copy_modified(self,
10941109 from_type_type = from_type_type if from_type_type is not _dummy else self .from_type_type ,
10951110 bound_args = bound_args if bound_args is not _dummy else self .bound_args ,
10961111 def_extras = def_extras if def_extras is not _dummy else dict (self .def_extras ),
1112+ type_guard = type_guard if type_guard is not _dummy else self .type_guard ,
10971113 )
10981114
10991115 def var_arg (self ) -> Optional [FormalArgument ]:
@@ -1255,6 +1271,8 @@ def __eq__(self, other: object) -> bool:
12551271 def serialize (self ) -> JsonDict :
12561272 # TODO: As an optimization, leave out everything related to
12571273 # generic functions for non-generic functions.
1274+ assert (self .type_guard is None
1275+ or isinstance (get_proper_type (self .type_guard ), Instance )), str (self .type_guard )
12581276 return {'.class' : 'CallableType' ,
12591277 'arg_types' : [t .serialize () for t in self .arg_types ],
12601278 'arg_kinds' : self .arg_kinds ,
@@ -1269,6 +1287,7 @@ def serialize(self) -> JsonDict:
12691287 'bound_args' : [(None if t is None else t .serialize ())
12701288 for t in self .bound_args ],
12711289 'def_extras' : dict (self .def_extras ),
1290+ 'type_guard' : self .type_guard .serialize () if self .type_guard is not None else None ,
12721291 }
12731292
12741293 @classmethod
@@ -1286,7 +1305,9 @@ def deserialize(cls, data: JsonDict) -> 'CallableType':
12861305 implicit = data ['implicit' ],
12871306 bound_args = [(None if t is None else deserialize_type (t ))
12881307 for t in data ['bound_args' ]],
1289- def_extras = data ['def_extras' ]
1308+ def_extras = data ['def_extras' ],
1309+ type_guard = (deserialize_type (data ['type_guard' ])
1310+ if data ['type_guard' ] is not None else None ),
12901311 )
12911312
12921313
@@ -2097,7 +2118,10 @@ def visit_callable_type(self, t: CallableType) -> str:
20972118 s = '({})' .format (s )
20982119
20992120 if not isinstance (get_proper_type (t .ret_type ), NoneType ):
2100- s += ' -> {}' .format (t .ret_type .accept (self ))
2121+ if t .type_guard is not None :
2122+ s += ' -> TypeGuard[{}]' .format (t .type_guard .accept (self ))
2123+ else :
2124+ s += ' -> {}' .format (t .ret_type .accept (self ))
21012125
21022126 if t .variables :
21032127 vs = []
0 commit comments