77
88import asyncio
99import collections
10+ import enum
1011import functools
1112import getpass
1213import os
2829from . import protocol
2930
3031
32+ class SSLMode (enum .IntEnum ):
33+ disable = 0
34+ allow = 1
35+ prefer = 2
36+ require = 3
37+ verify_ca = 4
38+ verify_full = 5
39+
40+ @classmethod
41+ def parse (cls , sslmode ):
42+ if isinstance (sslmode , cls ):
43+ return sslmode
44+ return getattr (cls , sslmode .replace ('-' , '_' ))
45+
46+
3147_ConnectionParameters = collections .namedtuple (
3248 'ConnectionParameters' ,
3349 [
3450 'user' ,
3551 'password' ,
3652 'database' ,
3753 'ssl' ,
38- 'ssl_is_advisory ' ,
54+ 'sslmode ' ,
3955 'connect_timeout' ,
4056 'server_settings' ,
4157 ])
@@ -402,46 +418,29 @@ def _parse_connect_dsn_and_args(*, dsn, host, port, user,
402418 if ssl is None and have_tcp_addrs :
403419 ssl = 'prefer'
404420
405- # ssl_is_advisory is only allowed to come from the sslmode parameter.
406- ssl_is_advisory = None
407- if isinstance (ssl , str ):
408- SSLMODES = {
409- 'disable' : 0 ,
410- 'allow' : 1 ,
411- 'prefer' : 2 ,
412- 'require' : 3 ,
413- 'verify-ca' : 4 ,
414- 'verify-full' : 5 ,
415- }
421+ if isinstance (ssl , (str , SSLMode )):
416422 try :
417- sslmode = SSLMODES [ ssl ]
418- except KeyError :
419- modes = ', ' .join (SSLMODES . keys () )
423+ sslmode = SSLMode . parse ( ssl )
424+ except AttributeError :
425+ modes = ', ' .join (m . name . replace ( '_' , '-' ) for m in SSLMode )
420426 raise exceptions .InterfaceError (
421427 '`sslmode` parameter must be one of: {}' .format (modes ))
422428
423- # sslmode 'allow' is currently handled as 'prefer' because we're
424- # missing the "retry with SSL" behavior for 'allow', but do have the
425- # "retry without SSL" behavior for 'prefer'.
426- # Not changing 'allow' to 'prefer' here would be effectively the same
427- # as changing 'allow' to 'disable'.
428- if sslmode == SSLMODES ['allow' ]:
429- sslmode = SSLMODES ['prefer' ]
430-
431429 # docs at https://www.postgresql.org/docs/10/static/libpq-connect.html
432430 # Not implemented: sslcert & sslkey & sslrootcert & sslcrl params.
433- if sslmode <= SSLMODES [ ' allow' ] :
431+ if sslmode < SSLMode . allow :
434432 ssl = False
435- ssl_is_advisory = sslmode >= SSLMODES ['allow' ]
436433 else :
437434 ssl = ssl_module .create_default_context ()
438- ssl .check_hostname = sslmode >= SSLMODES [ 'verify-full' ]
435+ ssl .check_hostname = sslmode >= SSLMode . verify_full
439436 ssl .verify_mode = ssl_module .CERT_REQUIRED
440- if sslmode <= SSLMODES [ ' require' ] :
437+ if sslmode <= SSLMode . require :
441438 ssl .verify_mode = ssl_module .CERT_NONE
442- ssl_is_advisory = sslmode <= SSLMODES ['prefer' ]
443439 elif ssl is True :
444440 ssl = ssl_module .create_default_context ()
441+ sslmode = SSLMode .verify_full
442+ else :
443+ sslmode = SSLMode .disable
445444
446445 if server_settings is not None and (
447446 not isinstance (server_settings , dict ) or
@@ -453,7 +452,7 @@ def _parse_connect_dsn_and_args(*, dsn, host, port, user,
453452
454453 params = _ConnectionParameters (
455454 user = user , password = password , database = database , ssl = ssl ,
456- ssl_is_advisory = ssl_is_advisory , connect_timeout = connect_timeout ,
455+ sslmode = sslmode , connect_timeout = connect_timeout ,
457456 server_settings = server_settings )
458457
459458 return addrs , params
@@ -520,9 +519,8 @@ def data_received(self, data):
520519 data == b'N' ):
521520 # ssl_is_advisory will imply that ssl.verify_mode == CERT_NONE,
522521 # since the only way to get ssl_is_advisory is from
523- # sslmode=prefer (or sslmode=allow). But be extra sure to
524- # disallow insecure connections when the ssl context asks for
525- # real security.
522+ # sslmode=prefer. But be extra sure to disallow insecure
523+ # connections when the ssl context asks for real security.
526524 self .on_data .set_result (False )
527525 else :
528526 self .on_data .set_exception (
@@ -566,6 +564,7 @@ async def _create_ssl_connection(protocol_factory, host, port, *,
566564 new_tr = tr
567565
568566 pg_proto = protocol_factory ()
567+ pg_proto .is_ssl = do_ssl_upgrade
569568 pg_proto .connection_made (new_tr )
570569 new_tr .set_protocol (pg_proto )
571570
@@ -584,7 +583,9 @@ async def _create_ssl_connection(protocol_factory, host, port, *,
584583 tr .close ()
585584
586585 try :
587- return await conn_factory (sock = sock )
586+ new_tr , pg_proto = await conn_factory (sock = sock )
587+ pg_proto .is_ssl = do_ssl_upgrade
588+ return new_tr , pg_proto
588589 except (Exception , asyncio .CancelledError ):
589590 sock .close ()
590591 raise
@@ -605,8 +606,6 @@ async def _connect_addr(
605606 if timeout <= 0 :
606607 raise asyncio .TimeoutError
607608
608- connected = _create_future (loop )
609-
610609 params_input = params
611610 if callable (params .password ):
612611 if inspect .iscoroutinefunction (params .password ):
@@ -615,6 +614,49 @@ async def _connect_addr(
615614 password = params .password ()
616615
617616 params = params ._replace (password = password )
617+ args = (addr , loop , config , connection_class , record_class , params_input )
618+
619+ # prepare the params (which attempt has ssl) for the 2 attempts
620+ if params .sslmode == SSLMode .allow :
621+ params_retry = params
622+ params = params ._replace (ssl = None )
623+ elif params .sslmode == SSLMode .prefer :
624+ params_retry = params ._replace (ssl = None )
625+ else :
626+ # skip retry if we don't have to
627+ return await __connect_addr (params , timeout , False , * args )
628+
629+ # first attempt
630+ before = time .monotonic ()
631+ try :
632+ return await __connect_addr (params , timeout , True , * args )
633+ except _Retry :
634+ pass
635+
636+ # second attempt
637+ timeout -= time .monotonic () - before
638+ if timeout <= 0 :
639+ raise asyncio .TimeoutError
640+ else :
641+ return await __connect_addr (params_retry , timeout , False , * args )
642+
643+
644+ class _Retry (Exception ):
645+ pass
646+
647+
648+ async def __connect_addr (
649+ params ,
650+ timeout ,
651+ retry ,
652+ addr ,
653+ loop ,
654+ config ,
655+ connection_class ,
656+ record_class ,
657+ params_input ,
658+ ):
659+ connected = _create_future (loop )
618660
619661 proto_factory = lambda : protocol .Protocol (
620662 addr , connected , params , record_class , loop )
@@ -625,7 +667,7 @@ async def _connect_addr(
625667 elif params .ssl :
626668 connector = _create_ssl_connection (
627669 proto_factory , * addr , loop = loop , ssl_context = params .ssl ,
628- ssl_is_advisory = params .ssl_is_advisory )
670+ ssl_is_advisory = params .sslmode == SSLMode . prefer )
629671 else :
630672 connector = loop .create_connection (proto_factory , * addr )
631673
@@ -638,6 +680,35 @@ async def _connect_addr(
638680 if timeout <= 0 :
639681 raise asyncio .TimeoutError
640682 await compat .wait_for (connected , timeout = timeout )
683+ except (
684+ exceptions .InvalidAuthorizationSpecificationError ,
685+ exceptions .ConnectionDoesNotExistError , # seen on Windows
686+ ):
687+ tr .close ()
688+
689+ # retry=True here is a redundant check because we don't want to
690+ # accidentally raise the internal _Retry to the outer world
691+ if retry and (
692+ params .sslmode == SSLMode .allow and not pr .is_ssl or
693+ params .sslmode == SSLMode .prefer and pr .is_ssl
694+ ):
695+ # Trigger retry when:
696+ # 1. First attempt with sslmode=allow, ssl=None failed
697+ # 2. First attempt with sslmode=prefer, ssl=ctx failed while the
698+ # server claimed to support SSL (returning "S" for SSLRequest)
699+ # (likely because pg_hba.conf rejected the connection)
700+ raise _Retry ()
701+
702+ else :
703+ # but will NOT retry if:
704+ # 1. First attempt with sslmode=prefer failed but the server
705+ # doesn't support SSL (returning 'N' for SSLRequest), because
706+ # we already tried to connect without SSL thru ssl_is_advisory
707+ # 2. Second attempt with sslmode=prefer, ssl=None failed
708+ # 3. Second attempt with sslmode=allow, ssl=ctx failed
709+ # 4. Any other sslmode
710+ raise
711+
641712 except (Exception , asyncio .CancelledError ):
642713 tr .close ()
643714 raise
@@ -684,6 +755,7 @@ class CancelProto(asyncio.Protocol):
684755
685756 def __init__ (self ):
686757 self .on_disconnect = _create_future (loop )
758+ self .is_ssl = False
687759
688760 def connection_lost (self , exc ):
689761 if not self .on_disconnect .done ():
@@ -692,13 +764,13 @@ def connection_lost(self, exc):
692764 if isinstance (addr , str ):
693765 tr , pr = await loop .create_unix_connection (CancelProto , addr )
694766 else :
695- if params .ssl :
767+ if params .ssl and params . sslmode != SSLMode . allow :
696768 tr , pr = await _create_ssl_connection (
697769 CancelProto ,
698770 * addr ,
699771 loop = loop ,
700772 ssl_context = params .ssl ,
701- ssl_is_advisory = params .ssl_is_advisory )
773+ ssl_is_advisory = params .sslmode == SSLMode . prefer )
702774 else :
703775 tr , pr = await loop .create_connection (
704776 CancelProto , * addr )
0 commit comments