@@ -135,14 +135,18 @@ class _Pool(threading.local):
135
135
"""
136
136
137
137
# Non thread-locals
138
- __slots__ = ["sockets" , "socket_factory" , "pool_size" ]
138
+ __slots__ = ["sockets" , "socket_factory" , "pool_size" ,
139
+ "connection" , "auth_credentials" ]
139
140
sock = None
140
141
141
- def __init__ (self , socket_factory ):
142
+ def __init__ (self , socket_factory , connection ):
142
143
self .pool_size = 10
143
144
self .socket_factory = socket_factory
145
+ self .connection = connection
144
146
if not hasattr (self , "sockets" ):
145
147
self .sockets = []
148
+ if not hasattr (self , "auth_credentials" ):
149
+ self .auth_credentials = {}
146
150
147
151
def socket (self ):
148
152
# we store the pid here to avoid issues with fork /
@@ -159,6 +163,15 @@ def socket(self):
159
163
except IndexError :
160
164
self .sock = (pid , self .socket_factory ())
161
165
166
+ # Authenticate new socket for known DBs, 'admin' by preference
167
+ if 'admin' in self .auth_credentials :
168
+ username , password = self .auth_credentials ['admin' ]
169
+ self .connection ['admin' ].authenticate (username , password )
170
+ else :
171
+ # Authenticate against all known databases
172
+ for db_name , (u , p ) in self .auth_credentials .items ():
173
+ self .connection [db_name ].authenticate (u , p )
174
+
162
175
return self .sock [1 ]
163
176
164
177
def return_socket (self ):
@@ -172,8 +185,12 @@ def return_socket(self):
172
185
self .sock [1 ].close ()
173
186
self .sock = None
174
187
175
- def socket_ids (self ):
176
- return [id (sock ) for sock in self .sockets ]
188
+ def add_db_auth (self , db_name , username , password ):
189
+ self .auth_credentials [db_name ] = (username , password )
190
+
191
+ def remove_db_auth (self , db_name ):
192
+ if db_name in self .auth_credentials :
193
+ del (self .auth_credentials [db_name ])
177
194
178
195
179
196
class Connection (object ):
@@ -294,7 +311,7 @@ def __init__(self, host=None, port=None, pool_size=None,
294
311
295
312
self .__cursor_manager = CursorManager (self )
296
313
297
- self .__pool = _Pool (self .__connect )
314
+ self .__pool = _Pool (self .__connect , self )
298
315
self .__last_checkout = time .time ()
299
316
300
317
self .__network_timeout = network_timeout
@@ -307,15 +324,10 @@ def __init__(self, host=None, port=None, pool_size=None,
307
324
if _connect :
308
325
self .__find_master ()
309
326
310
- # cache of auth username/password credential keyed by DB name
311
- self .__auth_credentials = {}
312
- self .__sock_auths_by_id = {}
313
327
if username :
314
328
database = database or "admin"
315
329
if not self [database ].authenticate (username , password ):
316
330
raise ConfigurationError ("authentication failed" )
317
- # Add database auth credentials for auto-auth later
318
- self .add_db_auth (database , username , password )
319
331
320
332
@classmethod
321
333
def from_uri (cls , uri = "mongodb://localhost" , ** connection_args ):
@@ -569,7 +581,7 @@ def disconnect(self):
569
581
.. seealso:: :meth:`end_request`
570
582
.. versionadded:: 1.3
571
583
"""
572
- self .__pool = _Pool (self .__connect )
584
+ self .__pool = _Pool (self .__connect , self )
573
585
self .__host = None
574
586
self .__port = None
575
587
@@ -622,30 +634,6 @@ def __check_response_to_last_error(self, response):
622
634
else :
623
635
raise OperationFailure (error ["err" ])
624
636
625
- def _authenticate_socket_for_db (self , sock , db_name ):
626
- # Periodically remove cached auth flags of expired sockets
627
- if len (self .__sock_auths_by_id ) > self .pool_size :
628
- cached_sock_ids = self .__sock_auths_by_id .keys ()
629
- current_sock_ids = self .__pool .socket_ids ()
630
- for sock_id in cached_sock_ids :
631
- if not sock_id in current_sock_ids :
632
- del (self .__sock_auths_by_id [sock_id ])
633
- if not self .__auth_credentials :
634
- return # No credentials for any database
635
- sock_id = id (sock )
636
- if db_name in self .__sock_auths_by_id .get (sock_id , {}):
637
- return # Already authenticated for database
638
- if not self .has_db_auth (db_name ):
639
- return # No credentials for database
640
- username , password = self .get_db_auth (db_name )
641
- if not self [db_name ].authenticate (username , password ):
642
- raise ConfigurationError ("authentication to db %s failed for %s"
643
- % (db_name , username ))
644
- if not sock_id in self .__sock_auths_by_id :
645
- self .__sock_auths_by_id [sock_id ] = {}
646
- self .__sock_auths_by_id [sock_id ][db_name ] = 1
647
- return True
648
-
649
637
def _send_message (self , message , with_last_error = False ,
650
638
collection_name = None ):
651
639
"""Say something to Mongo.
@@ -663,14 +651,6 @@ def _send_message(self, message, with_last_error=False,
663
651
"""
664
652
sock = self .__socket ()
665
653
try :
666
- # Always authenticate for admin database, if possible
667
- if self ._authenticate_socket_for_db (sock , 'admin' ):
668
- pass # No need for futher auth with admin login
669
- elif collection_name and collection_name .split ('.' ) >= 1 :
670
- # Authenticate for specific database
671
- db_name = collection_name .split ('.' )[0 ]
672
- self ._authenticate_socket_for_db (sock , db_name )
673
-
674
654
(request_id , data ) = message
675
655
sock .sendall (data )
676
656
# Safe mode. We pack the message together with a lastError
@@ -928,28 +908,8 @@ def __iter__(self):
928
908
def next (self ):
929
909
raise TypeError ("'Connection' object is not iterable" )
930
910
931
- def add_db_auth (self , db_name , username , password ):
932
- if not username or not isinstance (username , basestring ):
933
- raise ConfigurationError ('invalid username' )
934
- if not password or not isinstance (password , basestring ):
935
- raise ConfigurationError ('invalid password' )
936
- self .__auth_credentials [db_name ] = (username , password )
937
-
938
- def has_db_auth (self , db_name ):
939
- return db_name in self .__auth_credentials
911
+ def _add_db_auth (self , db_name , username , password ):
912
+ self .__pool .add_db_auth (db_name , username , password )
940
913
941
- def get_db_auth (self , db_name ):
942
- if self .has_db_auth (db_name ):
943
- return self .__auth_credentials [db_name ]
944
- return None
945
-
946
- def remove_db_auth (self , db_name ):
947
- if self .has_db_auth (db_name ):
948
- del (self .__auth_credentials [db_name ])
949
- # Force close any existing sockets to flush auths
950
- self .disconnect ()
951
-
952
- def clear_db_auths (self ):
953
- self .__auth_credentials = {} # Forget all credentials
954
- # Force close any existing sockets to flush auths
955
- self .disconnect ()
914
+ def _remove_db_auth (self , db_name ):
915
+ self .__pool .remove_db_auth (db_name )
0 commit comments