33from __future__ import annotations
44
55import ast
6+ from collections import defaultdict
67from typing import TYPE_CHECKING , Any , cast
78
89from .flake8asyncvisitor import Flake8AsyncVisitor , Flake8AsyncVisitor_cst
@@ -181,20 +182,25 @@ def __init__(self, *args: Any, **kwargs: Any):
181182 self .asynccontextmanager = False
182183 self .aenter = False
183184
185+ self .potential_errors : dict [str , list [ast .Call ]] = defaultdict (list )
186+
184187 def visit_AsyncFunctionDef (self , node : ast .AsyncFunctionDef ):
185- self .save_state (node , "aenter" )
188+ self .save_state (node , "aenter" , "asynccontextmanager" , "potential_errors" )
186189
187- self .aenter = node .name == "__aenter__" or has_decorator (
188- node , "asynccontextmanager"
189- )
190+ self .aenter = node .name == "__aenter__"
191+ self .asynccontextmanager = has_decorator (node , "asynccontextmanager" )
190192
191193 def visit_FunctionDef (self , node : ast .FunctionDef ):
192- self .save_state (node , "aenter" )
194+ self .save_state (node , "aenter" , "asynccontextmanager" , "potential_errors" )
193195 # sync function should never be named __aenter__ or have @asynccontextmanager
194- self .aenter = False
196+ self .aenter = self . asynccontextmanager = False
195197
196198 def visit_Yield (self , node : ast .Yield ):
197- self .aenter = False
199+ for nodes in self .potential_errors .values ():
200+ for n in nodes :
201+ self .error (n )
202+ self .potential_errors .clear ()
203+ self .aenter = self .asynccontextmanager = False
198204
199205 def visit_Call (self , node : ast .Call ) -> None :
200206 def is_startable (n : ast .expr , * startable_list : str ) -> bool :
@@ -210,14 +216,14 @@ def is_startable(n: ast.expr, *startable_list: str) -> bool:
210216 return any (is_startable (nn , * startable_list ) for nn in n .args )
211217 return False
212218
213- def is_nursery_call (node : ast .expr ):
219+ def is_nursery_call (node : ast .expr ) -> str | None :
214220 if not isinstance (node , ast .Attribute ) or node .attr not in (
215221 "start_soon" ,
216222 "create_task" ,
217223 ):
218- return False
224+ return None
219225 var = ast .unparse (node .value )
220- return (
226+ if (
221227 ("trio" in self .library and var .endswith ("nursery" ))
222228 or ("anyio" in self .library and var .endswith ("task_group" ))
223229 or (
@@ -228,11 +234,12 @@ def is_nursery_call(node: ast.expr):
228234 "asyncio.TaskGroup" ,
229235 )
230236 )
231- )
237+ ):
238+ return var
239+ return None
232240
233241 if (
234- self .aenter
235- and is_nursery_call (node .func )
242+ (var := is_nursery_call (node .func )) is not None
236243 and len (node .args ) > 0
237244 and is_startable (
238245 node .args [0 ],
@@ -241,7 +248,24 @@ def is_nursery_call(node: ast.expr):
241248 * self .options .startable_in_context_manager ,
242249 )
243250 ):
244- self .error (node )
251+ if self .aenter :
252+ self .error (node )
253+ elif self .asynccontextmanager :
254+ self .potential_errors [var ].append (node )
255+
256+ def visit_AsyncWith (self , node : ast .AsyncWith | ast .With ):
257+ # Entirely skip any nurseries that doesn't have any yields in them.
258+ # This fixes an otherwise very thorny false alarm.
259+ # In the worst case this does mean we iterate over the body twice, but might
260+ # actually be a performance gain on average due to setting `novisit`
261+ if not any (isinstance (n , ast .Yield ) for b in node .body for n in ast .walk (b )):
262+ self .novisit = True
263+ return
264+
265+ # open_nursery/create_task_group only works with AsyncWith, but in case somebody
266+ # is doing something very weird we'll be conservative and possibly avoid
267+ # some potential false alarms
268+ visit_With = visit_AsyncWith
245269
246270
247271# Checks that all async functions with a "task_status" parameter have a match in
0 commit comments