44import  traceback 
55import  collections 
66import  asyncio  as  aio 
7- from  .utils  import  _get_future_result 
7+ from  .utils  import  result_noraise 
88
99
1010class  BaseAioPool (object ):
11+  ''' BaseAioPool implements features, supposed to work in all supported 
12+  python versions. Other features supposed to be implemented as mixins.''' 
1113
1214 def  __init__ (self , size = 1024 , * , loop = None ):
1315 self .loop  =  loop  or  aio .get_event_loop ()
1416
1517 self .size  =  size 
1618 self ._executed  =  0 
17-  self ._joined  =  collections .deque ()
18-  self ._waiting  =  collections .deque ()
19+  self ._joined  =  set ()
20+  self ._waiting  =  {} # future -> task 
21+  self ._spawned  =  {} # future -> task 
1922 self .semaphore  =  aio .Semaphore (value = self .size , loop = self .loop )
2023
2124 async  def  __aenter__ (self ):
@@ -41,7 +44,7 @@ async def join(self):
4144 return  True 
4245
4346 fut  =  self .loop .create_future ()
44-  self ._joined .append (fut )
47+  self ._joined .add (fut )
4548 try :
4649 return  await  fut 
4750 finally :
@@ -72,33 +75,46 @@ async def _wrap(self, coro, future, cb=None, ctx=None):
7275 return 
7376
7477 self .semaphore .release ()
75-  if  not  exc :
76-  future .set_result (res )
77-  else :
78-  future .set_exception (exc )
7978
79+  if  not  future .done ():
80+  if  exc :
81+  future .set_exception (exc )
82+  else :
83+  future .set_result (res )
84+ 
85+  del  self ._spawned [future ]
8086 if  self .is_empty :
8187 self ._release_joined ()
8288
8389 async  def  _spawn (self , future , coro , cb = None , ctx = None ):
90+  acq_error  =  False 
8491 try :
8592 await  self .semaphore .acquire ()
8693 except  Exception  as  e :
87-  future .set_exception (e )
88-  self ._waiting .remove (future )
89-  wrapped  =  self ._wrap (coro , future , cb = cb , ctx = ctx )
90-  self .loop .create_task (wrapped )
94+  acq_error  =  True 
95+  if  not  future .done ():
96+  future .set_exception (e )
97+  finally :
98+  del  self ._waiting [future ]
99+ 
100+  if  future .done ():
101+  if  not  acq_error  and  future .cancelled (): # outside action 
102+  self .semaphore .release ()
103+  else : # all good, can spawn now 
104+  wrapped  =  self ._wrap (coro , future , cb = cb , ctx = ctx )
105+  task  =  self .loop .create_task (wrapped )
106+  self ._spawned [future ] =  task 
91107 return  future 
92108
93109 async  def  spawn_n (self , coro , cb = None , ctx = None ):
94110 future  =  self .loop .create_future ()
95-  self ._waiting . append ( future )
96-  self .loop . create_task ( self . _spawn ( future ,  coro ,  cb = cb ,  ctx = ctx )) 
111+  task   =   self .loop . create_task ( self . _spawn ( future ,  coro ,  cb = cb ,  ctx = ctx ) )
112+  self ._waiting [ future ]  =   task 
97113 return  future 
98114
99115 async  def  spawn (self , coro , cb = None , ctx = None ):
100116 future  =  self .loop .create_future ()
101-  self ._waiting . append ( future ) 
117+  self ._waiting [ future ]  =   self . loop . create_future ()  # TODO omg ??? 
102118 return  await  self ._spawn (future , coro , cb = cb , ctx = ctx )
103119
104120 async  def  exec (self , coro , cb = None , ctx = None ):
@@ -113,16 +129,36 @@ async def map_n(self, fn, iterable):
113129
114130 async  def  map (self , fn , iterable , exc_as_result = True ):
115131 futures  =  await  self .map_n (fn , iterable )
116-  await  self .join ()
117- 
118-  results  =  []
119-  for  fut  in  futures :
120-  res  =  _get_future_result (fut , exc_as_result )
121-  results .append (res )
122-  return  results 
132+  await  aio .wait (futures )
133+  return  [result_noraise (fut , exc_as_result ) for  fut  in  futures ]
123134
124135 async  def  iterwait (self , * arg , ** kw ): # TODO there's a way to support 3.5? 
125136 raise  NotImplementedError ('python3.6+ required' )
126137
127138 async  def  itermap (self , * arg , ** kw ): # TODO there's a way to support 3.5? 
128139 raise  NotImplementedError ('python3.6+ required' )
140+ 
141+  def  _cancel (self , * futures ):
142+  tasks , _futures  =  [], []
143+ 
144+  if  not  len (futures ): # meaning cancel all 
145+  tasks .extend (self ._waiting .values ())
146+  tasks .extend (self ._spawned .values ())
147+  _futures .extend (self ._waiting .keys ())
148+  _futures .extend (self ._spawned .keys ())
149+  else :
150+  for  fut  in  futures :
151+  task  =  self ._spawned .get (fut , self ._waiting .get (fut ))
152+  if  task :
153+  tasks .append (task )
154+  _futures .append (fut )
155+ 
156+  cancelled  =  sum ([1  for  fut  in  tasks  if  fut .cancel ()])
157+  return  cancelled , _futures 
158+ 
159+  async  def  cancel (self , * futures , exc_as_result = True ):
160+  cancelled , _futures  =  self ._cancel (* futures )
161+  await  aio .sleep (0 ) # let them actually cancel 
162+  # need to collect them anyway, to supress warnings 
163+  results  =  [result_noraise (fut , exc_as_result ) for  fut  in  _futures ]
164+  return  cancelled , results 
0 commit comments