@@ -217,14 +217,15 @@ class PoolOptions(object):
217
217
'__connect_timeout' , '__socket_timeout' ,
218
218
'__wait_queue_timeout' , '__wait_queue_multiple' ,
219
219
'__ssl_context' , '__ssl_match_hostname' , '__socket_keepalive' ,
220
- '__event_listeners' , '__appname' , '__metadata' )
220
+ '__event_listeners' , '__appname' , '__metadata' ,
221
+ '__handshake_callback' )
221
222
222
223
def __init__ (self , max_pool_size = 100 , min_pool_size = 0 ,
223
224
max_idle_time_ms = None , connect_timeout = None ,
224
225
socket_timeout = None , wait_queue_timeout = None ,
225
226
wait_queue_multiple = None , ssl_context = None ,
226
227
ssl_match_hostname = True , socket_keepalive = False ,
227
- event_listeners = None , appname = None ):
228
+ event_listeners = None , appname = None , handshake_callback = None ):
228
229
229
230
self .__max_pool_size = max_pool_size
230
231
self .__min_pool_size = min_pool_size
@@ -242,6 +243,27 @@ def __init__(self, max_pool_size=100, min_pool_size=0,
242
243
if appname :
243
244
self .__metadata ['application' ] = {'name' : appname }
244
245
246
+ self .__handshake_callback = handshake_callback
247
+
248
+ def with_options (self , ** kwargs ):
249
+ options = {
250
+ 'max_pool_size' : self .max_pool_size ,
251
+ 'min_pool_size' : self .min_pool_size ,
252
+ 'max_idle_time_ms' : self .max_idle_time_ms ,
253
+ 'connect_timeout' : self .connect_timeout ,
254
+ 'socket_timeout' : self .socket_timeout ,
255
+ 'wait_queue_timeout' : self .wait_queue_timeout ,
256
+ 'wait_queue_multiple' : self .wait_queue_multiple ,
257
+ 'ssl_context' : self .ssl_context ,
258
+ 'ssl_match_hostname' : self .ssl_match_hostname ,
259
+ 'socket_keepalive' : self .socket_keepalive ,
260
+ 'event_listeners' : self .event_listeners ,
261
+ 'appname' : self .appname ,
262
+ 'handshake_callback' : self .handshake_callback }
263
+
264
+ options .update (kwargs )
265
+ return PoolOptions (** options )
266
+
245
267
@property
246
268
def max_pool_size (self ):
247
269
"""The maximum allowable number of concurrent connections to each
@@ -335,6 +357,11 @@ def metadata(self):
335
357
"""
336
358
return self .__metadata .copy ()
337
359
360
+ @property
361
+ def handshake_callback (self ):
362
+ """Receives an ismaster reply and updates the topology."""
363
+ return self .__handshake_callback
364
+
338
365
339
366
class SocketInfo (object ):
340
367
"""Store a socket with some metadata.
@@ -746,6 +773,8 @@ def connect(self):
746
773
('ismaster' , 1 ),
747
774
('client' , self .opts .metadata )
748
775
])
776
+
777
+ start = _time ()
749
778
ismaster = IsMaster (
750
779
command (sock ,
751
780
'admin' ,
@@ -754,13 +783,20 @@ def connect(self):
754
783
False ,
755
784
ReadPreference .PRIMARY ,
756
785
DEFAULT_CODEC_OPTIONS ))
786
+
787
+ # Can raise ConnectionFailure.
788
+ self ._handshake_callback (ismaster , _time () - start )
757
789
else :
758
790
ismaster = None
759
791
return SocketInfo (sock , self , ismaster , self .address )
760
792
except socket .error as error :
761
793
if sock is not None :
762
794
sock .close ()
763
795
_raise_connection_failure (self .address , error )
796
+ except :
797
+ if sock is not None :
798
+ sock .close ()
799
+ raise
764
800
765
801
@contextlib .contextmanager
766
802
def get_socket (self , all_credentials , checkout = False ):
@@ -889,6 +925,14 @@ def _check(self, sock_info):
889
925
else :
890
926
return self .connect ()
891
927
928
+ def _handshake_callback (self , ismaster , round_trip_time ):
929
+ callback = self .opts .handshake_callback
930
+ if callback :
931
+ kept = callback (self .address , ismaster , round_trip_time )
932
+ if not kept :
933
+ _raise_connection_failure (
934
+ self .address , "server removed from topology" )
935
+
892
936
def _raise_wait_queue_timeout (self ):
893
937
raise ConnectionFailure (
894
938
'Timed out waiting for socket from pool with max_size %r and'
0 commit comments