1111import  collections .abc 
1212import  functools 
1313import  itertools 
14+ import  inspect 
1415import  os 
1516import  sys 
1617import  time 
1718import  traceback 
19+ import  typing 
1820import  warnings 
1921import  weakref 
2022
@@ -133,27 +135,32 @@ async def add_listener(self, channel, callback):
133135 :param str channel: Channel to listen on. 
134136
135137 :param callable callback: 
136-  A callable receiving the following arguments: 
138+  A callable or a coroutine function receiving the following 
139+  arguments: 
137140 **connection**: a Connection the callback is registered with; 
138141 **pid**: PID of the Postgres server that sent the notification; 
139142 **channel**: name of the channel the notification was sent to; 
140143 **payload**: the payload. 
144+ 
145+  .. versionchanged:: 0.24.0 
146+  The ``callback`` argument may be a coroutine function. 
141147 """ 
142148 self ._check_open ()
143149 if  channel  not  in self ._listeners :
144150 await  self .fetch ('LISTEN {}' .format (utils ._quote_ident (channel )))
145151 self ._listeners [channel ] =  set ()
146-  self ._listeners [channel ].add (callback )
152+  self ._listeners [channel ].add (_Callback . from_callable ( callback ) )
147153
148154 async  def  remove_listener (self , channel , callback ):
149155 """Remove a listening callback on the specified channel.""" 
150156 if  self .is_closed ():
151157 return 
152158 if  channel  not  in self ._listeners :
153159 return 
154-  if  callback  not  in self ._listeners [channel ]:
160+  cb  =  _Callback .from_callable (callback )
161+  if  cb  not  in self ._listeners [channel ]:
155162 return 
156-  self ._listeners [channel ].remove (callback )
163+  self ._listeners [channel ].remove (cb )
157164 if  not  self ._listeners [channel ]:
158165 del  self ._listeners [channel ]
159166 await  self .fetch ('UNLISTEN {}' .format (utils ._quote_ident (channel )))
@@ -166,44 +173,51 @@ def add_log_listener(self, callback):
166173 DEBUG, INFO, or LOG. 
167174
168175 :param callable callback: 
169-  A callable receiving the following arguments: 
176+  A callable or a coroutine function receiving the following 
177+  arguments: 
170178 **connection**: a Connection the callback is registered with; 
171179 **message**: the `exceptions.PostgresLogMessage` message. 
172180
173181 .. versionadded:: 0.12.0 
182+ 
183+  .. versionchanged:: 0.24.0 
184+  The ``callback`` argument may be a coroutine function. 
174185 """ 
175186 if  self .is_closed ():
176187 raise  exceptions .InterfaceError ('connection is closed' )
177-  self ._log_listeners .add (callback )
188+  self ._log_listeners .add (_Callback . from_callable ( callback ) )
178189
179190 def  remove_log_listener (self , callback ):
180191 """Remove a listening callback for log messages. 
181192
182193 .. versionadded:: 0.12.0 
183194 """ 
184-  self ._log_listeners .discard (callback )
195+  self ._log_listeners .discard (_Callback . from_callable ( callback ) )
185196
186197 def  add_termination_listener (self , callback ):
187198 """Add a listener that will be called when the connection is closed. 
188199
189200 :param callable callback: 
190-  A callable receiving one argument: 
201+  A callable or a coroutine function  receiving one argument: 
191202 **connection**: a Connection the callback is registered with. 
192203
193204 .. versionadded:: 0.21.0 
205+ 
206+  .. versionchanged:: 0.24.0 
207+  The ``callback`` argument may be a coroutine function. 
194208 """ 
195-  self ._termination_listeners .add (callback )
209+  self ._termination_listeners .add (_Callback . from_callable ( callback ) )
196210
197211 def  remove_termination_listener (self , callback ):
198212 """Remove a listening callback for connection termination. 
199213
200214 :param callable callback: 
201-  The callable that was passed to 
215+  The callable or coroutine function  that was passed to 
202216 :meth:`Connection.add_termination_listener`. 
203217
204218 .. versionadded:: 0.21.0 
205219 """ 
206-  self ._termination_listeners .discard (callback )
220+  self ._termination_listeners .discard (_Callback . from_callable ( callback ) )
207221
208222 def  get_server_pid (self ):
209223 """Return the PID of the Postgres server the connection is bound to.""" 
@@ -1449,35 +1463,21 @@ def _process_log_message(self, fields, last_query):
14491463
14501464 con_ref  =  self ._unwrap ()
14511465 for  cb  in  self ._log_listeners :
1452-  self ._loop .call_soon (
1453-  self ._call_log_listener , cb , con_ref , message )
1454- 
1455-  def  _call_log_listener (self , cb , con_ref , message ):
1456-  try :
1457-  cb (con_ref , message )
1458-  except  Exception  as  ex :
1459-  self ._loop .call_exception_handler ({
1460-  'message' : 'Unhandled exception in asyncpg log message ' 
1461-  'listener callback {!r}' .format (cb ),
1462-  'exception' : ex 
1463-  })
1466+  if  cb .is_async :
1467+  self ._loop .create_task (cb .cb (con_ref , message ))
1468+  else :
1469+  self ._loop .call_soon (cb .cb , con_ref , message )
14641470
14651471 def  _call_termination_listeners (self ):
14661472 if  not  self ._termination_listeners :
14671473 return 
14681474
14691475 con_ref  =  self ._unwrap ()
14701476 for  cb  in  self ._termination_listeners :
1471-  try :
1472-  cb (con_ref )
1473-  except  Exception  as  ex :
1474-  self ._loop .call_exception_handler ({
1475-  'message' : (
1476-  'Unhandled exception in asyncpg connection ' 
1477-  'termination listener callback {!r}' .format (cb )
1478-  ),
1479-  'exception' : ex 
1480-  })
1477+  if  cb .is_async :
1478+  self ._loop .create_task (cb .cb (con_ref ))
1479+  else :
1480+  self ._loop .call_soon (cb .cb , con_ref )
14811481
14821482 self ._termination_listeners .clear ()
14831483
@@ -1487,18 +1487,10 @@ def _process_notification(self, pid, channel, payload):
14871487
14881488 con_ref  =  self ._unwrap ()
14891489 for  cb  in  self ._listeners [channel ]:
1490-  self ._loop .call_soon (
1491-  self ._call_listener , cb , con_ref , pid , channel , payload )
1492- 
1493-  def  _call_listener (self , cb , con_ref , pid , channel , payload ):
1494-  try :
1495-  cb (con_ref , pid , channel , payload )
1496-  except  Exception  as  ex :
1497-  self ._loop .call_exception_handler ({
1498-  'message' : 'Unhandled exception in asyncpg notification ' 
1499-  'listener callback {!r}' .format (cb ),
1500-  'exception' : ex 
1501-  })
1490+  if  cb .is_async :
1491+  self ._loop .create_task (cb .cb (con_ref , pid , channel , payload ))
1492+  else :
1493+  self ._loop .call_soon (cb .cb , con_ref , pid , channel , payload )
15021494
15031495 def  _unwrap (self ):
15041496 if  self ._proxy  is  None :
@@ -2173,6 +2165,26 @@ def _maybe_cleanup(self):
21732165 self ._on_remove (old_entry ._statement )
21742166
21752167
2168+ class  _Callback (typing .NamedTuple ):
2169+ 
2170+  cb : typing .Callable [..., None ]
2171+  is_async : bool 
2172+ 
2173+  @classmethod  
2174+  def  from_callable (cls , cb : typing .Callable [..., None ]) ->  '_Callback' :
2175+  if  inspect .iscoroutinefunction (cb ):
2176+  is_async  =  True 
2177+  elif  callable (cb ):
2178+  is_async  =  False 
2179+  else :
2180+  raise  exceptions .InterfaceError (
2181+  'expected a callable or an `async def` function,' 
2182+  'got {!r}' .format (cb )
2183+  )
2184+ 
2185+  return  cls (cb , is_async )
2186+ 
2187+ 
21762188class  _Atomic :
21772189 __slots__  =  ('_acquired' ,)
21782190
0 commit comments