1111import collections .abc
1212import functools
1313import itertools
14+ import inspect
1415import os
1516import sys
1617import time
1718import traceback
19+ import typing
1820import warnings
1921import weakref
2022
@@ -133,27 +135,32 @@ async def add_listener(self, channel, callback):
133135 :param str channel: Channel to listen on.
134136
135137 :param callable callback:
136- A callable receiving the following arguments:
138+ A callable or a coroutine function receiving the following
139+ arguments:
137140 **connection**: a Connection the callback is registered with;
138141 **pid**: PID of the Postgres server that sent the notification;
139142 **channel**: name of the channel the notification was sent to;
140143 **payload**: the payload.
144+
145+ .. versionchanged:: 0.24.0
146+ The ``callback`` argument may be a coroutine function.
141147 """
142148 self ._check_open ()
143149 if channel not in self ._listeners :
144150 await self .fetch ('LISTEN {}' .format (utils ._quote_ident (channel )))
145151 self ._listeners [channel ] = set ()
146- self ._listeners [channel ].add (callback )
152+ self ._listeners [channel ].add (_Callback . from_callable ( callback ) )
147153
148154 async def remove_listener (self , channel , callback ):
149155 """Remove a listening callback on the specified channel."""
150156 if self .is_closed ():
151157 return
152158 if channel not in self ._listeners :
153159 return
154- if callback not in self ._listeners [channel ]:
160+ cb = _Callback .from_callable (callback )
161+ if cb not in self ._listeners [channel ]:
155162 return
156- self ._listeners [channel ].remove (callback )
163+ self ._listeners [channel ].remove (cb )
157164 if not self ._listeners [channel ]:
158165 del self ._listeners [channel ]
159166 await self .fetch ('UNLISTEN {}' .format (utils ._quote_ident (channel )))
@@ -166,44 +173,51 @@ def add_log_listener(self, callback):
166173 DEBUG, INFO, or LOG.
167174
168175 :param callable callback:
169- A callable receiving the following arguments:
176+ A callable or a coroutine function receiving the following
177+ arguments:
170178 **connection**: a Connection the callback is registered with;
171179 **message**: the `exceptions.PostgresLogMessage` message.
172180
173181 .. versionadded:: 0.12.0
182+
183+ .. versionchanged:: 0.24.0
184+ The ``callback`` argument may be a coroutine function.
174185 """
175186 if self .is_closed ():
176187 raise exceptions .InterfaceError ('connection is closed' )
177- self ._log_listeners .add (callback )
188+ self ._log_listeners .add (_Callback . from_callable ( callback ) )
178189
179190 def remove_log_listener (self , callback ):
180191 """Remove a listening callback for log messages.
181192
182193 .. versionadded:: 0.12.0
183194 """
184- self ._log_listeners .discard (callback )
195+ self ._log_listeners .discard (_Callback . from_callable ( callback ) )
185196
186197 def add_termination_listener (self , callback ):
187198 """Add a listener that will be called when the connection is closed.
188199
189200 :param callable callback:
190- A callable receiving one argument:
201+ A callable or a coroutine function receiving one argument:
191202 **connection**: a Connection the callback is registered with.
192203
193204 .. versionadded:: 0.21.0
205+
206+ .. versionchanged:: 0.24.0
207+ The ``callback`` argument may be a coroutine function.
194208 """
195- self ._termination_listeners .add (callback )
209+ self ._termination_listeners .add (_Callback . from_callable ( callback ) )
196210
197211 def remove_termination_listener (self , callback ):
198212 """Remove a listening callback for connection termination.
199213
200214 :param callable callback:
201- The callable that was passed to
215+ The callable or coroutine function that was passed to
202216 :meth:`Connection.add_termination_listener`.
203217
204218 .. versionadded:: 0.21.0
205219 """
206- self ._termination_listeners .discard (callback )
220+ self ._termination_listeners .discard (_Callback . from_callable ( callback ) )
207221
208222 def get_server_pid (self ):
209223 """Return the PID of the Postgres server the connection is bound to."""
@@ -1430,35 +1444,21 @@ def _process_log_message(self, fields, last_query):
14301444
14311445 con_ref = self ._unwrap ()
14321446 for cb in self ._log_listeners :
1433- self ._loop .call_soon (
1434- self ._call_log_listener , cb , con_ref , message )
1435-
1436- def _call_log_listener (self , cb , con_ref , message ):
1437- try :
1438- cb (con_ref , message )
1439- except Exception as ex :
1440- self ._loop .call_exception_handler ({
1441- 'message' : 'Unhandled exception in asyncpg log message '
1442- 'listener callback {!r}' .format (cb ),
1443- 'exception' : ex
1444- })
1447+ if cb .is_async :
1448+ self ._loop .create_task (cb .cb (con_ref , message ))
1449+ else :
1450+ self ._loop .call_soon (cb .cb , con_ref , message )
14451451
14461452 def _call_termination_listeners (self ):
14471453 if not self ._termination_listeners :
14481454 return
14491455
14501456 con_ref = self ._unwrap ()
14511457 for cb in self ._termination_listeners :
1452- try :
1453- cb (con_ref )
1454- except Exception as ex :
1455- self ._loop .call_exception_handler ({
1456- 'message' : (
1457- 'Unhandled exception in asyncpg connection '
1458- 'termination listener callback {!r}' .format (cb )
1459- ),
1460- 'exception' : ex
1461- })
1458+ if cb .is_async :
1459+ self ._loop .create_task (cb .cb (con_ref ))
1460+ else :
1461+ self ._loop .call_soon (cb .cb , con_ref )
14621462
14631463 self ._termination_listeners .clear ()
14641464
@@ -1468,18 +1468,10 @@ def _process_notification(self, pid, channel, payload):
14681468
14691469 con_ref = self ._unwrap ()
14701470 for cb in self ._listeners [channel ]:
1471- self ._loop .call_soon (
1472- self ._call_listener , cb , con_ref , pid , channel , payload )
1473-
1474- def _call_listener (self , cb , con_ref , pid , channel , payload ):
1475- try :
1476- cb (con_ref , pid , channel , payload )
1477- except Exception as ex :
1478- self ._loop .call_exception_handler ({
1479- 'message' : 'Unhandled exception in asyncpg notification '
1480- 'listener callback {!r}' .format (cb ),
1481- 'exception' : ex
1482- })
1471+ if cb .is_async :
1472+ self ._loop .create_task (cb .cb (con_ref , pid , channel , payload ))
1473+ else :
1474+ self ._loop .call_soon (cb .cb , con_ref , pid , channel , payload )
14831475
14841476 def _unwrap (self ):
14851477 if self ._proxy is None :
@@ -2154,6 +2146,26 @@ def _maybe_cleanup(self):
21542146 self ._on_remove (old_entry ._statement )
21552147
21562148
2149+ class _Callback (typing .NamedTuple ):
2150+
2151+ cb : typing .Callable [..., None ]
2152+ is_async : bool
2153+
2154+ @classmethod
2155+ def from_callable (cls , cb : typing .Callable [..., None ]) -> '_Callback' :
2156+ if inspect .iscoroutinefunction (cb ):
2157+ is_async = True
2158+ elif callable (cb ):
2159+ is_async = False
2160+ else :
2161+ raise exceptions .InterfaceError (
2162+ 'expected a callable or an `async def` function,'
2163+ 'got {!r}' .format (cb )
2164+ )
2165+
2166+ return cls (cb , is_async )
2167+
2168+
21572169class _Atomic :
21582170 __slots__ = ('_acquired' ,)
21592171
0 commit comments