1111import collections .abc
1212import functools
1313import itertools
14+ import inspect
1415import os
1516import sys
1617import time
1718import traceback
19+ import typing
1820import warnings
1921import weakref
2022
@@ -133,27 +135,32 @@ async def add_listener(self, channel, callback):
133135 :param str channel: Channel to listen on.
134136
135137 :param callable callback:
136- A callable receiving the following arguments:
138+ A callable or a coroutine function receiving the following
139+ arguments:
137140 **connection**: a Connection the callback is registered with;
138141 **pid**: PID of the Postgres server that sent the notification;
139142 **channel**: name of the channel the notification was sent to;
140143 **payload**: the payload.
144+
145+ .. versionchanged:: 0.24.0
146+ The ``callback`` argument may be a coroutine function.
141147 """
142148 self ._check_open ()
143149 if channel not in self ._listeners :
144150 await self .fetch ('LISTEN {}' .format (utils ._quote_ident (channel )))
145151 self ._listeners [channel ] = set ()
146- self ._listeners [channel ].add (callback )
152+ self ._listeners [channel ].add (_Callback . from_callable ( callback ) )
147153
148154 async def remove_listener (self , channel , callback ):
149155 """Remove a listening callback on the specified channel."""
150156 if self .is_closed ():
151157 return
152158 if channel not in self ._listeners :
153159 return
154- if callback not in self ._listeners [channel ]:
160+ cb = _Callback .from_callable (callback )
161+ if cb not in self ._listeners [channel ]:
155162 return
156- self ._listeners [channel ].remove (callback )
163+ self ._listeners [channel ].remove (cb )
157164 if not self ._listeners [channel ]:
158165 del self ._listeners [channel ]
159166 await self .fetch ('UNLISTEN {}' .format (utils ._quote_ident (channel )))
@@ -166,44 +173,51 @@ def add_log_listener(self, callback):
166173 DEBUG, INFO, or LOG.
167174
168175 :param callable callback:
169- A callable receiving the following arguments:
176+ A callable or a coroutine function receiving the following
177+ arguments:
170178 **connection**: a Connection the callback is registered with;
171179 **message**: the `exceptions.PostgresLogMessage` message.
172180
173181 .. versionadded:: 0.12.0
182+
183+ .. versionchanged:: 0.24.0
184+ The ``callback`` argument may be a coroutine function.
174185 """
175186 if self .is_closed ():
176187 raise exceptions .InterfaceError ('connection is closed' )
177- self ._log_listeners .add (callback )
188+ self ._log_listeners .add (_Callback . from_callable ( callback ) )
178189
179190 def remove_log_listener (self , callback ):
180191 """Remove a listening callback for log messages.
181192
182193 .. versionadded:: 0.12.0
183194 """
184- self ._log_listeners .discard (callback )
195+ self ._log_listeners .discard (_Callback . from_callable ( callback ) )
185196
186197 def add_termination_listener (self , callback ):
187198 """Add a listener that will be called when the connection is closed.
188199
189200 :param callable callback:
190- A callable receiving one argument:
201+ A callable or a coroutine function receiving one argument:
191202 **connection**: a Connection the callback is registered with.
192203
193204 .. versionadded:: 0.21.0
205+
206+ .. versionchanged:: 0.24.0
207+ The ``callback`` argument may be a coroutine function.
194208 """
195- self ._termination_listeners .add (callback )
209+ self ._termination_listeners .add (_Callback . from_callable ( callback ) )
196210
197211 def remove_termination_listener (self , callback ):
198212 """Remove a listening callback for connection termination.
199213
200214 :param callable callback:
201- The callable that was passed to
215+ The callable or coroutine function that was passed to
202216 :meth:`Connection.add_termination_listener`.
203217
204218 .. versionadded:: 0.21.0
205219 """
206- self ._termination_listeners .discard (callback )
220+ self ._termination_listeners .discard (_Callback . from_callable ( callback ) )
207221
208222 def get_server_pid (self ):
209223 """Return the PID of the Postgres server the connection is bound to."""
@@ -1430,8 +1444,11 @@ def _process_log_message(self, fields, last_query):
14301444
14311445 con_ref = self ._unwrap ()
14321446 for cb in self ._log_listeners :
1433- self ._loop .call_soon (
1434- self ._call_log_listener , cb , con_ref , message )
1447+ if cb .is_async :
1448+ self ._loop .create_task (cb .cb (con_ref , message ))
1449+ else :
1450+ self ._loop .call_soon (
1451+ self ._call_log_listener , cb .cb , con_ref , message )
14351452
14361453 def _call_log_listener (self , cb , con_ref , message ):
14371454 try :
@@ -1449,16 +1466,19 @@ def _call_termination_listeners(self):
14491466
14501467 con_ref = self ._unwrap ()
14511468 for cb in self ._termination_listeners :
1452- try :
1453- cb (con_ref )
1454- except Exception as ex :
1455- self ._loop .call_exception_handler ({
1456- 'message' : (
1457- 'Unhandled exception in asyncpg connection '
1458- 'termination listener callback {!r}' .format (cb )
1459- ),
1460- 'exception' : ex
1461- })
1469+ if cb .is_async :
1470+ self ._loop .create_task (cb .cb (con_ref ))
1471+ else :
1472+ try :
1473+ cb .cb (con_ref )
1474+ except Exception as ex :
1475+ self ._loop .call_exception_handler ({
1476+ 'message' : (
1477+ 'Unhandled exception in asyncpg connection '
1478+ 'termination listener callback {!r}' .format (cb )
1479+ ),
1480+ 'exception' : ex
1481+ })
14621482
14631483 self ._termination_listeners .clear ()
14641484
@@ -1468,8 +1488,11 @@ def _process_notification(self, pid, channel, payload):
14681488
14691489 con_ref = self ._unwrap ()
14701490 for cb in self ._listeners [channel ]:
1471- self ._loop .call_soon (
1472- self ._call_listener , cb , con_ref , pid , channel , payload )
1491+ if cb .is_async :
1492+ self ._loop .create_task (cb .cb (con_ref , pid , channel , payload ))
1493+ else :
1494+ self ._loop .call_soon (
1495+ self ._call_listener , cb .cb , con_ref , pid , channel , payload )
14731496
14741497 def _call_listener (self , cb , con_ref , pid , channel , payload ):
14751498 try :
@@ -2154,6 +2177,26 @@ def _maybe_cleanup(self):
21542177 self ._on_remove (old_entry ._statement )
21552178
21562179
2180+ class _Callback (typing .NamedTuple ):
2181+
2182+ cb : typing .Callable [..., None ]
2183+ is_async : bool
2184+
2185+ @classmethod
2186+ def from_callable (cls , cb : typing .Callable [..., None ]) -> '_Callback' :
2187+ if inspect .iscoroutinefunction (cb ):
2188+ is_async = True
2189+ elif callable (cb ):
2190+ is_async = False
2191+ else :
2192+ raise exceptions .InterfaceError (
2193+ 'expected a callable or an `async def` function,'
2194+ 'got {!r}' .format (cb )
2195+ )
2196+
2197+ return cls (cb , is_async )
2198+
2199+
21572200class _Atomic :
21582201 __slots__ = ('_acquired' ,)
21592202
0 commit comments