@@ -31,7 +31,8 @@ class Connection:
3131 '_type_by_name_stmt' , '_top_xact' , '_uid' , '_aborted' ,
3232 '_stmt_cache_max_size' , '_stmt_cache' , '_stmts_to_close' ,
3333 '_addr' , '_opts' , '_command_timeout' , '_listeners' ,
34- '_server_version' , '_intro_query' )
34+ '_server_version' , '_server_caps' , '_intro_query' ,
35+ '_reset_query' )
3536
3637 def __init__ (self , protocol , transport , loop , addr , opts , * ,
3738 statement_cache_size , command_timeout ):
@@ -55,15 +56,21 @@ def __init__(self, protocol, transport, loop, addr, opts, *,
5556
5657 self ._listeners = {}
5758
58- ver_string = self ._protocol .get_settings ().server_version
59+ settings = self ._protocol .get_settings ()
60+ ver_string = settings .server_version
5961 self ._server_version = \
6062 serverversion .split_server_version_string (ver_string )
6163
64+ self ._server_caps = _detect_server_capabilities (
65+ self ._server_version , settings )
66+
6267 if self ._server_version < (9 , 2 ):
6368 self ._intro_query = introspection .INTRO_LOOKUP_TYPES_91
6469 else :
6570 self ._intro_query = introspection .INTRO_LOOKUP_TYPES
6671
72+ self ._reset_query = None
73+
6774 async def add_listener (self , channel , callback ):
6875 """Add a listener for Postgres notifications.
6976
@@ -107,6 +114,7 @@ def get_server_version(self):
107114 ServerVersion(major=9, minor=6, micro=1,
108115 releaselevel='final', serial=0)
109116
117+ .. versionadded:: 0.8.0
110118 """
111119 return self ._server_version
112120
@@ -394,22 +402,10 @@ def terminate(self):
394402 self ._protocol .abort ()
395403
396404 async def reset (self ):
397- self ._listeners = {}
398-
399- await self .execute ('''
400- DO $$
401- BEGIN
402- PERFORM * FROM pg_listening_channels() LIMIT 1;
403- IF FOUND THEN
404- UNLISTEN *;
405- END IF;
406- END;
407- $$;
408- SET SESSION AUTHORIZATION DEFAULT;
409- RESET ALL;
410- CLOSE ALL;
411- SELECT pg_advisory_unlock_all();
412- ''' )
405+ self ._listeners .clear ()
406+ reset_query = self ._get_reset_query ()
407+ if reset_query :
408+ await self .execute (reset_query )
413409
414410 def _get_unique_id (self ):
415411 self ._uid += 1
@@ -492,6 +488,35 @@ def _notify(self, pid, channel, payload):
492488 'exception' : ex
493489 })
494490
491+ def _get_reset_query (self ):
492+ if self ._reset_query is not None :
493+ return self ._reset_query
494+
495+ caps = self ._server_caps
496+
497+ _reset_query = ''
498+ if caps .advisory_locks :
499+ _reset_query += 'SELECT pg_advisory_unlock_all();\n '
500+ if caps .cursors :
501+ _reset_query += 'CLOSE ALL;\n '
502+ if caps .notifications and caps .plpgsql :
503+ _reset_query += '''
504+ DO $$
505+ BEGIN
506+ PERFORM * FROM pg_listening_channels() LIMIT 1;
507+ IF FOUND THEN
508+ UNLISTEN *;
509+ END IF;
510+ END;
511+ $$;
512+ '''
513+ if caps .sql_reset :
514+ _reset_query += 'RESET ALL;\n '
515+
516+ self ._reset_query = _reset_query
517+
518+ return _reset_query
519+
495520
496521async def connect (dsn = None , * ,
497522 host = None , port = None ,
@@ -730,3 +755,34 @@ def _create_future(loop):
730755 return asyncio .Future (loop = loop )
731756 else :
732757 return create_future ()
758+
759+
760+ ServerCapabilities = collections .namedtuple (
761+ 'ServerCapabilities' ,
762+ ['advisory_locks' , 'cursors' , 'notifications' , 'plpgsql' , 'sql_reset' ])
763+ ServerCapabilities .__doc__ = 'PostgreSQL server capabilities.'
764+
765+
766+ def _detect_server_capabilities (server_version , connection_settings ):
767+ if hasattr (connection_settings , 'crdb_version' ):
768+ # CocroachDB detected.
769+ advisory_locks = False
770+ cursors = False
771+ notifications = False
772+ plpgsql = False
773+ sql_reset = False
774+ else :
775+ # Standard PostgreSQL server assumed.
776+ advisory_locks = True
777+ cursors = True
778+ notifications = True
779+ plpgsql = True
780+ sql_reset = True
781+
782+ return ServerCapabilities (
783+ advisory_locks = advisory_locks ,
784+ cursors = cursors ,
785+ notifications = notifications ,
786+ plpgsql = plpgsql ,
787+ sql_reset = sql_reset
788+ )
0 commit comments