1414import  urllib .parse 
1515
1616from  . import  cursor 
17+ from  . import  exceptions 
1718from  . import  introspection 
1819from  . import  prepared_stmt 
1920from  . import  protocol 
2021from  . import  serverversion 
2122from  . import  transaction 
2223
2324
24- class  Connection :
25+ class  ConnectionMeta (type ):
26+ 
27+  def  __instancecheck__ (cls , instance ):
28+  mro  =  type (instance ).__mro__ 
29+  return  Connection  in  mro  or  _ConnectionProxy  in  mro 
30+ 
31+ 
32+ class  Connection (metaclass = ConnectionMeta ):
2533 """A representation of a database session. 
2634
2735 Connections are created by calling :func:`~asyncpg.connection.connect`. 
@@ -32,7 +40,7 @@ class Connection:
3240 '_stmt_cache_max_size' , '_stmt_cache' , '_stmts_to_close' ,
3341 '_addr' , '_opts' , '_command_timeout' , '_listeners' ,
3442 '_server_version' , '_server_caps' , '_intro_query' ,
35-  '_reset_query' )
43+  '_reset_query' ,  '_proxy' )
3644
3745 def  __init__ (self , protocol , transport , loop , addr , opts , * ,
3846 statement_cache_size , command_timeout ):
@@ -70,6 +78,7 @@ def __init__(self, protocol, transport, loop, addr, opts, *,
7078 self ._intro_query  =  introspection .INTRO_LOOKUP_TYPES 
7179
7280 self ._reset_query  =  None 
81+  self ._proxy  =  None 
7382
7483 async  def  add_listener (self , channel , callback ):
7584 """Add a listener for Postgres notifications. 
@@ -478,9 +487,18 @@ def _notify(self, pid, channel, payload):
478487 if  channel  not  in self ._listeners :
479488 return 
480489
490+  if  self ._proxy  is  None :
491+  con_ref  =  self 
492+  else :
493+  # `_proxy` is not None when the connection is a member 
494+  # of a connection pool. Which means that the user is working 
495+  # with a PooledConnectionProxy instance, and expects to see it 
496+  # (and not the actual Connection) in their event callbacks. 
497+  con_ref  =  self ._proxy 
498+ 
481499 for  cb  in  self ._listeners [channel ]:
482500 try :
483-  cb (self , pid , channel , payload )
501+  cb (con_ref , pid , channel , payload )
484502 except  Exception  as  ex :
485503 self ._loop .call_exception_handler ({
486504 'message' : 'Unhandled exception in asyncpg notification ' 
@@ -517,6 +535,14 @@ def _get_reset_query(self):
517535
518536 return  _reset_query 
519537
538+  def  _set_proxy (self , proxy ):
539+  if  self ._proxy  is  not None  and  proxy  is  not None :
540+  # Should not happen unless there is a bug in `Pool`. 
541+  raise  exceptions .InterfaceError (
542+  'internal asyncpg error: connection is already proxied' )
543+ 
544+  self ._proxy  =  proxy 
545+ 
520546
521547async  def  connect (dsn = None , * ,
522548 host = None , port = None ,
@@ -526,7 +552,7 @@ async def connect(dsn=None, *,
526552 timeout = 60 ,
527553 statement_cache_size = 100 ,
528554 command_timeout = None ,
529-  connection_class = Connection ,
555+  __connection_class__ = Connection ,
530556 ** opts ):
531557 """A coroutine to establish a connection to a PostgreSQL server. 
532558
@@ -564,11 +590,7 @@ async def connect(dsn=None, *,
564590 :param float command_timeout: the default timeout for operations on 
565591 this connection (the default is no timeout). 
566592
567-  :param builtins.type connection_class: A class used to represent 
568-  the connection. 
569-  Defaults to :class:`~asyncpg.connection.Connection`. 
570- 
571-  :return: A *connection_class* instance. 
593+  :return: A :class:`~asyncpg.connection.Connection` instance. 
572594
573595 Example: 
574596
@@ -582,10 +604,6 @@ async def connect(dsn=None, *,
582604 ... print(types) 
583605 >>> asyncio.get_event_loop().run_until_complete(run()) 
584606 [<Record typname='bool' typnamespace=11 ... 
585- 
586- 
587-  .. versionadded:: 0.10.0 
588-  *connection_class* argument. 
589607 """ 
590608 if  loop  is  None :
591609 loop  =  asyncio .get_event_loop ()
@@ -629,13 +647,18 @@ async def connect(dsn=None, *,
629647 tr .close ()
630648 raise 
631649
632-  con  =  connection_class (pr , tr , loop , addr , opts ,
633-  statement_cache_size = statement_cache_size ,
634-  command_timeout = command_timeout )
650+  con  =  __connection_class__ (pr , tr , loop , addr , opts ,
651+    statement_cache_size = statement_cache_size ,
652+    command_timeout = command_timeout )
635653 pr .set_connection (con )
636654 return  con 
637655
638656
657+ class  _ConnectionProxy :
658+  # Base class to enable `isinstance(Connection)` check. 
659+  __slots__  =  ()
660+ 
661+ 
639662def  _parse_connect_params (* , dsn , host , port , user ,
640663 password , database , opts ):
641664
0 commit comments