@@ -504,6 +504,95 @@ def _parse_connect_arguments(*, dsn, host, port, user, password, passfile,
504504 return addrs , params , config
505505
506506
507+ class TLSUpgradeProto (asyncio .Protocol ):
508+ def __init__ (self , loop , host , port , ssl_context , ssl_is_advisory ):
509+ self .on_data = _create_future (loop )
510+ self .host = host
511+ self .port = port
512+ self .ssl_context = ssl_context
513+ self .ssl_is_advisory = ssl_is_advisory
514+
515+ def data_received (self , data ):
516+ if data == b'S' :
517+ self .on_data .set_result (True )
518+ elif (self .ssl_is_advisory and
519+ self .ssl_context .verify_mode == ssl_module .CERT_NONE and
520+ data == b'N' ):
521+ # ssl_is_advisory will imply that ssl.verify_mode == CERT_NONE,
522+ # since the only way to get ssl_is_advisory is from
523+ # sslmode=prefer (or sslmode=allow). But be extra sure to
524+ # disallow insecure connections when the ssl context asks for
525+ # real security.
526+ self .on_data .set_result (False )
527+ else :
528+ self .on_data .set_exception (
529+ ConnectionError (
530+ 'PostgreSQL server at "{host}:{port}" '
531+ 'rejected SSL upgrade' .format (
532+ host = self .host , port = self .port )))
533+
534+ def connection_lost (self , exc ):
535+ if not self .on_data .done ():
536+ if exc is None :
537+ exc = ConnectionError ('unexpected connection_lost() call' )
538+ self .on_data .set_exception (exc )
539+
540+
541+ async def _create_ssl_connection (protocol_factory , host , port , * ,
542+ loop , ssl_context , ssl_is_advisory = False ):
543+
544+ if ssl_context is True :
545+ ssl_context = ssl_module .create_default_context ()
546+
547+ tr , pr = await loop .create_connection (
548+ lambda : TLSUpgradeProto (loop , host , port ,
549+ ssl_context , ssl_is_advisory ),
550+ host , port )
551+
552+ tr .write (struct .pack ('!ll' , 8 , 80877103 )) # SSLRequest message.
553+
554+ try :
555+ do_ssl_upgrade = await pr .on_data
556+ except (Exception , asyncio .CancelledError ):
557+ tr .close ()
558+ raise
559+
560+ if hasattr (loop , 'start_tls' ):
561+ if do_ssl_upgrade :
562+ try :
563+ new_tr = await loop .start_tls (
564+ tr , pr , ssl_context , server_hostname = host )
565+ except (Exception , asyncio .CancelledError ):
566+ tr .close ()
567+ raise
568+ else :
569+ new_tr = tr
570+
571+ pg_proto = protocol_factory ()
572+ pg_proto .connection_made (new_tr )
573+ new_tr .set_protocol (pg_proto )
574+
575+ return new_tr , pg_proto
576+ else :
577+ conn_factory = functools .partial (
578+ loop .create_connection , protocol_factory )
579+
580+ if do_ssl_upgrade :
581+ conn_factory = functools .partial (
582+ conn_factory , ssl = ssl_context , server_hostname = host )
583+
584+ sock = _get_socket (tr )
585+ sock = sock .dup ()
586+ _set_nodelay (sock )
587+ tr .close ()
588+
589+ try :
590+ return await conn_factory (sock = sock )
591+ except (Exception , asyncio .CancelledError ):
592+ sock .close ()
593+ raise
594+
595+
507596async def _connect_addr (* , addr , loop , timeout , params , config ,
508597 connection_class ):
509598 assert loop is not None
@@ -526,8 +615,6 @@ async def _connect_addr(*, addr, loop, timeout, params, config,
526615 else :
527616 connector = loop .create_connection (proto_factory , * addr )
528617
529- connector = asyncio .ensure_future (connector )
530-
531618 before = time .monotonic ()
532619 try :
533620 tr , pr = await asyncio .wait_for (
@@ -575,79 +662,41 @@ async def _connect(*, loop, timeout, connection_class, **kwargs):
575662 raise last_error
576663
577664
578- async def _negotiate_ssl_connection (host , port , conn_factory , * , loop , ssl ,
579- server_hostname , ssl_is_advisory = False ):
580- # Note: ssl_is_advisory only affects behavior when the server does not
581- # accept SSLRequests. If the SSLRequest is accepted but either the SSL
582- # negotiation fails or the PostgreSQL user isn't permitted to use SSL,
583- # there's nothing that would attempt to reconnect with a non-SSL socket.
584- reader , writer = await asyncio .open_connection (host , port )
585-
586- tr = writer .transport
587- try :
588- sock = _get_socket (tr )
589- _set_nodelay (sock )
590-
591- writer .write (struct .pack ('!ll' , 8 , 80877103 )) # SSLRequest message.
592- await writer .drain ()
593- resp = await reader .readexactly (1 )
594-
595- if resp == b'S' :
596- conn_factory = functools .partial (
597- conn_factory , ssl = ssl , server_hostname = server_hostname )
598- elif (ssl_is_advisory and
599- ssl .verify_mode == ssl_module .CERT_NONE and
600- resp == b'N' ):
601- # ssl_is_advisory will imply that ssl.verify_mode == CERT_NONE,
602- # since the only way to get ssl_is_advisory is from sslmode=prefer
603- # (or sslmode=allow). But be extra sure to disallow insecure
604- # connections when the ssl context asks for real security.
605- pass
606- else :
607- raise ConnectionError (
608- 'PostgreSQL server at "{}:{}" rejected SSL upgrade' .format (
609- host , port ))
610-
611- sock = sock .dup () # Must come before tr.close()
612- finally :
613- writer .close ()
614- await compat .wait_closed (writer )
615-
616- try :
617- return await conn_factory (sock = sock ) # Must come after tr.close()
618- except (Exception , asyncio .CancelledError ):
619- sock .close ()
620- raise
665+ async def _cancel (* , loop , addr , params : _ConnectionParameters ,
666+ backend_pid , backend_secret ):
621667
668+ class CancelProto (asyncio .Protocol ):
622669
623- async def _create_ssl_connection (protocol_factory , host , port , * ,
624- loop , ssl_context , ssl_is_advisory = False ):
625- return await _negotiate_ssl_connection (
626- host , port ,
627- functools .partial (loop .create_connection , protocol_factory ),
628- loop = loop ,
629- ssl = ssl_context ,
630- server_hostname = host ,
631- ssl_is_advisory = ssl_is_advisory )
670+ def __init__ (self ):
671+ self .on_disconnect = _create_future (loop )
632672
673+ def connection_lost (self , exc ):
674+ if not self .on_disconnect .done ():
675+ self .on_disconnect .set_result (True )
633676
634- async def _open_connection (* , loop , addr , params : _ConnectionParameters ):
635677 if isinstance (addr , str ):
636- r , w = await asyncio . open_unix_connection ( addr )
678+ tr , pr = await loop . create_unix_connection ( CancelProto , addr )
637679 else :
638680 if params .ssl :
639- r , w = await _negotiate_ssl_connection (
681+ tr , pr = await _create_ssl_connection (
682+ CancelProto ,
640683 * addr ,
641- asyncio .open_connection ,
642684 loop = loop ,
643- ssl = params .ssl ,
644- server_hostname = addr [0 ],
685+ ssl_context = params .ssl ,
645686 ssl_is_advisory = params .ssl_is_advisory )
646687 else :
647- r , w = await asyncio .open_connection (* addr )
648- _set_nodelay (_get_socket (w .transport ))
688+ tr , pr = await loop .create_connection (
689+ CancelProto , * addr )
690+ _set_nodelay (_get_socket (tr ))
691+
692+ # Pack a CancelRequest message
693+ msg = struct .pack ('!llll' , 16 , 80877102 , backend_pid , backend_secret )
649694
650- return r , w
695+ try :
696+ tr .write (msg )
697+ await pr .on_disconnect
698+ finally :
699+ tr .close ()
651700
652701
653702def _get_socket (transport ):
0 commit comments