3535 'password' ,
3636 'database' ,
3737 'ssl' ,
38-  'ssl_is_advisory ' ,
38+  'alt_retry_ssl_first ' ,
3939 'connect_timeout' ,
4040 'server_settings' ,
4141 ])
@@ -402,8 +402,13 @@ def _parse_connect_dsn_and_args(*, dsn, host, port, user,
402402 if  ssl  is  None  and  have_tcp_addrs :
403403 ssl  =  'prefer' 
404404
405-  # ssl_is_advisory is only allowed to come from the sslmode parameter. 
406-  ssl_is_advisory  =  None 
405+  # alt_retry_ssl_first is particularly for "allow" and "prefer" 
406+  # to alternatively try SSL/non-SSL connections (once each if supported): 
407+  # False - allow (try non-SSL first) 
408+  # True - prefer (try SSL first) 
409+  # None - other (don't retry, stick with the "ssl" parameter) 
410+  alt_retry_ssl_first  =  None 
411+ 
407412 if  isinstance (ssl , str ):
408413 SSLMODES  =  {
409414 'disable' : 0 ,
@@ -420,26 +425,21 @@ def _parse_connect_dsn_and_args(*, dsn, host, port, user,
420425 raise  exceptions .InterfaceError (
421426 '`sslmode` parameter must be one of: {}' .format (modes ))
422427
423-  # sslmode 'allow' is currently handled as 'prefer' because we're 
424-  # missing the "retry with SSL" behavior for 'allow', but do have the 
425-  # "retry without SSL" behavior for 'prefer'. 
426-  # Not changing 'allow' to 'prefer' here would be effectively the same 
427-  # as changing 'allow' to 'disable'. 
428428 if  sslmode  ==  SSLMODES ['allow' ]:
429-  sslmode  =  SSLMODES ['prefer' ]
429+  alt_retry_ssl_first  =  False 
430+  elif  sslmode  ==  SSLMODES ['prefer' ]:
431+  alt_retry_ssl_first  =  True 
430432
431433 # docs at https://www.postgresql.org/docs/10/static/libpq-connect.html 
432434 # Not implemented: sslcert & sslkey & sslrootcert & sslcrl params. 
433-  if  sslmode  <=   SSLMODES ['allow' ]:
435+  if  sslmode  <  SSLMODES ['allow' ]:
434436 ssl  =  False 
435-  ssl_is_advisory  =  sslmode  >=  SSLMODES ['allow' ]
436437 else :
437438 ssl  =  ssl_module .create_default_context ()
438439 ssl .check_hostname  =  sslmode  >=  SSLMODES ['verify-full' ]
439440 ssl .verify_mode  =  ssl_module .CERT_REQUIRED 
440441 if  sslmode  <=  SSLMODES ['require' ]:
441442 ssl .verify_mode  =  ssl_module .CERT_NONE 
442-  ssl_is_advisory  =  sslmode  <=  SSLMODES ['prefer' ]
443443 elif  ssl  is  True :
444444 ssl  =  ssl_module .create_default_context ()
445445
@@ -453,7 +453,8 @@ def _parse_connect_dsn_and_args(*, dsn, host, port, user,
453453
454454 params  =  _ConnectionParameters (
455455 user = user , password = password , database = database , ssl = ssl ,
456-  ssl_is_advisory = ssl_is_advisory , connect_timeout = connect_timeout ,
456+  alt_retry_ssl_first = alt_retry_ssl_first ,
457+  connect_timeout = connect_timeout ,
457458 server_settings = server_settings )
458459
459460 return  addrs , params 
@@ -520,9 +521,8 @@ def data_received(self, data):
520521 data  ==  b'N' ):
521522 # ssl_is_advisory will imply that ssl.verify_mode == CERT_NONE, 
522523 # since the only way to get ssl_is_advisory is from 
523-  # sslmode=prefer (or sslmode=allow). But be extra sure to 
524-  # disallow insecure connections when the ssl context asks for 
525-  # real security. 
524+  # sslmode=prefer. But be extra sure to disallow insecure 
525+  # connections when the ssl context asks for real security. 
526526 self .on_data .set_result (False )
527527 else :
528528 self .on_data .set_exception (
@@ -566,6 +566,7 @@ async def _create_ssl_connection(protocol_factory, host, port, *,
566566 new_tr  =  tr 
567567
568568 pg_proto  =  protocol_factory ()
569+  pg_proto .is_ssl  =  do_ssl_upgrade 
569570 pg_proto .connection_made (new_tr )
570571 new_tr .set_protocol (pg_proto )
571572
@@ -584,7 +585,9 @@ async def _create_ssl_connection(protocol_factory, host, port, *,
584585 tr .close ()
585586
586587 try :
587-  return  await  conn_factory (sock = sock )
588+  new_tr , pg_proto  =  await  conn_factory (sock = sock )
589+  pg_proto .is_ssl  =  do_ssl_upgrade 
590+  return  new_tr , pg_proto 
588591 except  (Exception , asyncio .CancelledError ):
589592 sock .close ()
590593 raise 
@@ -605,8 +608,6 @@ async def _connect_addr(
605608 if  timeout  <=  0 :
606609 raise  asyncio .TimeoutError 
607610
608-  connected  =  _create_future (loop )
609- 
610611 params_input  =  params 
611612 if  callable (params .password ):
612613 if  inspect .iscoroutinefunction (params .password ):
@@ -615,6 +616,44 @@ async def _connect_addr(
615616 password  =  params .password ()
616617
617618 params  =  params ._replace (password = password )
619+  args  =  (addr , loop , config , connection_class , record_class , params_input )
620+ 
621+  # skip retry if alt_retry is not enabled 
622+  if  params .alt_retry_ssl_first  is  None :
623+  return  await  __connect_addr (params , timeout , * args )
624+ 
625+  # prepare the params (which attempt has ssl) for the 2 attempts 
626+  params_retry  =  params ._replace (ssl = None )
627+  if  not  params .alt_retry_ssl_first :
628+  params , params_retry  =  params_retry , params 
629+ 
630+  # first attempt 
631+  before  =  time .monotonic ()
632+  try :
633+  return  await  __connect_addr (params , timeout , * args )
634+  except  ConnectionError :
635+  pass 
636+ 
637+  # the second attempt with alt_retry_ssl_first=None 
638+  timeout  -=  time .monotonic () -  before 
639+  if  timeout  <=  0 :
640+  raise  asyncio .TimeoutError 
641+  else :
642+  params_retry  =  params_retry ._replace (alt_retry_ssl_first = None )
643+  return  await  __connect_addr (params_retry , timeout , * args )
644+ 
645+ 
646+ async  def  __connect_addr (
647+  params ,
648+  timeout ,
649+  addr ,
650+  loop ,
651+  config ,
652+  connection_class ,
653+  record_class ,
654+  params_input ,
655+ ):
656+  connected  =  _create_future (loop )
618657
619658 proto_factory  =  lambda : protocol .Protocol (
620659 addr , connected , params , record_class , loop )
@@ -625,7 +664,7 @@ async def _connect_addr(
625664 elif  params .ssl :
626665 connector  =  _create_ssl_connection (
627666 proto_factory , * addr , loop = loop , ssl_context = params .ssl ,
628-  ssl_is_advisory = params .ssl_is_advisory )
667+  ssl_is_advisory = params .alt_retry_ssl_first )
629668 else :
630669 connector  =  loop .create_connection (proto_factory , * addr )
631670
@@ -638,6 +677,23 @@ async def _connect_addr(
638677 if  timeout  <=  0 :
639678 raise  asyncio .TimeoutError 
640679 await  compat .wait_for (connected , timeout = timeout )
680+  except  exceptions .InvalidAuthorizationSpecificationError :
681+  tr .close ()
682+ 
683+  # pr.is_ssl is a bool, so this equal test implies 
684+  # alt_retry_ssl_first is not None (should do alt_retry) 
685+  if  params .alt_retry_ssl_first  ==  pr .is_ssl :
686+  # Elevate the error to ConnectionError to trigger retry 
687+  raise  ConnectionError ("Connection rejected trying {} SSL" .format (
688+  'with'  if  pr .is_ssl  else  'without' ))
689+ 
690+  else :
691+  # Don't retry if alt_retry_ssl_first is None, or we don't need to 
692+  # (alt_retry_ssl_first=True and pr.is_ssl=False means the server 
693+  # doesn't support SSL, and we've already tried to Startup without 
694+  # SSL but failed; The opposite case doesn't exist). 
695+  raise 
696+ 
641697 except  (Exception , asyncio .CancelledError ):
642698 tr .close ()
643699 raise 
@@ -684,6 +740,7 @@ class CancelProto(asyncio.Protocol):
684740
685741 def  __init__ (self ):
686742 self .on_disconnect  =  _create_future (loop )
743+  self .is_ssl  =  False 
687744
688745 def  connection_lost (self , exc ):
689746 if  not  self .on_disconnect .done ():
@@ -692,17 +749,31 @@ def connection_lost(self, exc):
692749 if  isinstance (addr , str ):
693750 tr , pr  =  await  loop .create_unix_connection (CancelProto , addr )
694751 else :
695-  if  params .ssl :
696-  tr , pr  =  await  _create_ssl_connection (
697-  CancelProto ,
698-  * addr ,
699-  loop = loop ,
700-  ssl_context = params .ssl ,
701-  ssl_is_advisory = params .ssl_is_advisory )
752+  async  def  _connect (params_in , ssl_is_advisory ):
753+  if  params_in .ssl :
754+  return  await  _create_ssl_connection (
755+  CancelProto ,
756+  * addr ,
757+  loop = loop ,
758+  ssl_context = params_in .ssl ,
759+  ssl_is_advisory = ssl_is_advisory )
760+  else :
761+  rv  =  await  loop .create_connection (
762+  CancelProto , * addr )
763+  _set_nodelay (_get_socket (rv [0 ]))
764+  return  rv 
765+ 
766+  if  params .alt_retry_ssl_first  is  None :
767+  tr , pr  =  await  _connect (params , False )
702768 else :
703-  tr , pr  =  await  loop .create_connection (
704-  CancelProto , * addr )
705-  _set_nodelay (_get_socket (tr ))
769+  params_retry  =  params ._replace (ssl = None )
770+  if  not  params .alt_retry_ssl_first :
771+  params , params_retry  =  params_retry , params 
772+  try :
773+  tr , pr  =  await  _connect (params , True )
774+  except  ConnectionError :
775+  tr , pr  =  await  _connect (
776+  params ._replace (alt_retry_ssl_first = None ), False )
706777
707778 # Pack a CancelRequest message 
708779 msg  =  struct .pack ('!llll' , 16 , 80877102 , backend_pid , backend_secret )
0 commit comments