@@ -504,6 +504,95 @@ def _parse_connect_arguments(*, dsn, host, port, user, password, passfile,
504504 return  addrs , params , config 
505505
506506
507+ class  TLSUpgradeProto (asyncio .Protocol ):
508+  def  __init__ (self , loop , host , port , ssl_context , ssl_is_advisory ):
509+  self .on_data  =  _create_future (loop )
510+  self .host  =  host 
511+  self .port  =  port 
512+  self .ssl_context  =  ssl_context 
513+  self .ssl_is_advisory  =  ssl_is_advisory 
514+ 
515+  def  data_received (self , data ):
516+  if  data  ==  b'S' :
517+  self .on_data .set_result (True )
518+  elif  (self .ssl_is_advisory  and 
519+  self .ssl_context .verify_mode  ==  ssl_module .CERT_NONE  and 
520+  data  ==  b'N' ):
521+  # ssl_is_advisory will imply that ssl.verify_mode == CERT_NONE, 
522+  # since the only way to get ssl_is_advisory is from 
523+  # sslmode=prefer (or sslmode=allow). But be extra sure to 
524+  # disallow insecure connections when the ssl context asks for 
525+  # real security. 
526+  self .on_data .set_result (False )
527+  else :
528+  self .on_data .set_exception (
529+  ConnectionError (
530+  'PostgreSQL server at "{host}:{port}" ' 
531+  'rejected SSL upgrade' .format (
532+  host = self .host , port = self .port )))
533+ 
534+  def  connection_lost (self , exc ):
535+  if  not  self .on_data .done ():
536+  if  exc  is  None :
537+  exc  =  ConnectionError ('unexpected connection_lost() call' )
538+  self .on_data .set_exception (exc )
539+ 
540+ 
541+ async  def  _create_ssl_connection (protocol_factory , host , port , * ,
542+  loop , ssl_context , ssl_is_advisory = False ):
543+ 
544+  if  ssl_context  is  True :
545+  ssl_context  =  ssl_module .create_default_context ()
546+ 
547+  tr , pr  =  await  loop .create_connection (
548+  lambda : TLSUpgradeProto (loop , host , port ,
549+  ssl_context , ssl_is_advisory ),
550+  host , port )
551+ 
552+  tr .write (struct .pack ('!ll' , 8 , 80877103 )) # SSLRequest message. 
553+ 
554+  try :
555+  do_ssl_upgrade  =  await  pr .on_data 
556+  except  (Exception , asyncio .CancelledError ):
557+  tr .close ()
558+  raise 
559+ 
560+  if  hasattr (loop , 'start_tls' ):
561+  if  do_ssl_upgrade :
562+  try :
563+  new_tr  =  await  loop .start_tls (
564+  tr , pr , ssl_context , server_hostname = host )
565+  except  (Exception , asyncio .CancelledError ):
566+  tr .close ()
567+  raise 
568+  else :
569+  new_tr  =  tr 
570+ 
571+  pg_proto  =  protocol_factory ()
572+  pg_proto .connection_made (new_tr )
573+  new_tr .set_protocol (pg_proto )
574+ 
575+  return  new_tr , pg_proto 
576+  else :
577+  conn_factory  =  functools .partial (
578+  loop .create_connection , protocol_factory )
579+ 
580+  if  do_ssl_upgrade :
581+  conn_factory  =  functools .partial (
582+  conn_factory , ssl = ssl_context , server_hostname = host )
583+ 
584+  sock  =  _get_socket (tr )
585+  sock  =  sock .dup ()
586+  _set_nodelay (sock )
587+  tr .close ()
588+ 
589+  try :
590+  return  await  conn_factory (sock = sock )
591+  except  (Exception , asyncio .CancelledError ):
592+  sock .close ()
593+  raise 
594+ 
595+ 
507596async  def  _connect_addr (* , addr , loop , timeout , params , config ,
508597 connection_class ):
509598 assert  loop  is  not None 
@@ -526,8 +615,6 @@ async def _connect_addr(*, addr, loop, timeout, params, config,
526615 else :
527616 connector  =  loop .create_connection (proto_factory , * addr )
528617
529-  connector  =  asyncio .ensure_future (connector )
530- 
531618 before  =  time .monotonic ()
532619 try :
533620 tr , pr  =  await  asyncio .wait_for (
@@ -575,79 +662,41 @@ async def _connect(*, loop, timeout, connection_class, **kwargs):
575662 raise  last_error 
576663
577664
578- async  def  _negotiate_ssl_connection (host , port , conn_factory , * , loop , ssl ,
579-  server_hostname , ssl_is_advisory = False ):
580-  # Note: ssl_is_advisory only affects behavior when the server does not 
581-  # accept SSLRequests. If the SSLRequest is accepted but either the SSL 
582-  # negotiation fails or the PostgreSQL user isn't permitted to use SSL, 
583-  # there's nothing that would attempt to reconnect with a non-SSL socket. 
584-  reader , writer  =  await  asyncio .open_connection (host , port )
585- 
586-  tr  =  writer .transport 
587-  try :
588-  sock  =  _get_socket (tr )
589-  _set_nodelay (sock )
590- 
591-  writer .write (struct .pack ('!ll' , 8 , 80877103 )) # SSLRequest message. 
592-  await  writer .drain ()
593-  resp  =  await  reader .readexactly (1 )
594- 
595-  if  resp  ==  b'S' :
596-  conn_factory  =  functools .partial (
597-  conn_factory , ssl = ssl , server_hostname = server_hostname )
598-  elif  (ssl_is_advisory  and 
599-  ssl .verify_mode  ==  ssl_module .CERT_NONE  and 
600-  resp  ==  b'N' ):
601-  # ssl_is_advisory will imply that ssl.verify_mode == CERT_NONE, 
602-  # since the only way to get ssl_is_advisory is from sslmode=prefer 
603-  # (or sslmode=allow). But be extra sure to disallow insecure 
604-  # connections when the ssl context asks for real security. 
605-  pass 
606-  else :
607-  raise  ConnectionError (
608-  'PostgreSQL server at "{}:{}" rejected SSL upgrade' .format (
609-  host , port ))
610- 
611-  sock  =  sock .dup () # Must come before tr.close() 
612-  finally :
613-  writer .close ()
614-  await  compat .wait_closed (writer )
615- 
616-  try :
617-  return  await  conn_factory (sock = sock ) # Must come after tr.close() 
618-  except  (Exception , asyncio .CancelledError ):
619-  sock .close ()
620-  raise 
665+ async  def  _cancel (* , loop , addr , params : _ConnectionParameters ,
666+  backend_pid , backend_secret ):
621667
668+  class  CancelProto (asyncio .Protocol ):
622669
623- async  def  _create_ssl_connection (protocol_factory , host , port , * ,
624-  loop , ssl_context , ssl_is_advisory = False ):
625-  return  await  _negotiate_ssl_connection (
626-  host , port ,
627-  functools .partial (loop .create_connection , protocol_factory ),
628-  loop = loop ,
629-  ssl = ssl_context ,
630-  server_hostname = host ,
631-  ssl_is_advisory = ssl_is_advisory )
670+  def  __init__ (self ):
671+  self .on_disconnect  =  _create_future (loop )
632672
673+  def  connection_lost (self , exc ):
674+  if  not  self .on_disconnect .done ():
675+  self .on_disconnect .set_result (True )
633676
634- async  def  _open_connection (* , loop , addr , params : _ConnectionParameters ):
635677 if  isinstance (addr , str ):
636-  r ,  w  =  await  asyncio . open_unix_connection ( addr )
678+  tr ,  pr  =  await  loop . create_unix_connection ( CancelProto ,  addr )
637679 else :
638680 if  params .ssl :
639-  r , w  =  await  _negotiate_ssl_connection (
681+  tr , pr  =  await  _create_ssl_connection (
682+  CancelProto ,
640683 * addr ,
641-  asyncio .open_connection ,
642684 loop = loop ,
643-  ssl = params .ssl ,
644-  server_hostname = addr [0 ],
685+  ssl_context = params .ssl ,
645686 ssl_is_advisory = params .ssl_is_advisory )
646687 else :
647-  r , w  =  await  asyncio .open_connection (* addr )
648-  _set_nodelay (_get_socket (w .transport ))
688+  tr , pr  =  await  loop .create_connection (
689+  CancelProto , * addr )
690+  _set_nodelay (_get_socket (tr ))
691+ 
692+  # Pack a CancelRequest message 
693+  msg  =  struct .pack ('!llll' , 16 , 80877102 , backend_pid , backend_secret )
649694
650-  return  r , w 
695+  try :
696+  tr .write (msg )
697+  await  pr .on_disconnect 
698+  finally :
699+  tr .close ()
651700
652701
653702def  _get_socket (transport ):
0 commit comments