Skip to content

Commit dfb84f9

Browse files
committed
Automatic per-socket authentication now occurs in Connection after initial Database.authenticate call
1 parent d425cbf commit dfb84f9

File tree

5 files changed

+68
-111
lines changed

5 files changed

+68
-111
lines changed

pymongo/connection.py

Lines changed: 27 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -135,14 +135,18 @@ class _Pool(threading.local):
135135
"""
136136

137137
# Non thread-locals
138-
__slots__ = ["sockets", "socket_factory", "pool_size"]
138+
__slots__ = ["sockets", "socket_factory", "pool_size",
139+
"connection", "auth_credentials"]
139140
sock = None
140141

141-
def __init__(self, socket_factory):
142+
def __init__(self, socket_factory, connection):
142143
self.pool_size = 10
143144
self.socket_factory = socket_factory
145+
self.connection = connection
144146
if not hasattr(self, "sockets"):
145147
self.sockets = []
148+
if not hasattr(self, "auth_credentials"):
149+
self.auth_credentials = {}
146150

147151
def socket(self):
148152
# we store the pid here to avoid issues with fork /
@@ -159,6 +163,15 @@ def socket(self):
159163
except IndexError:
160164
self.sock = (pid, self.socket_factory())
161165

166+
# Authenticate new socket for known DBs, 'admin' by preference
167+
if 'admin' in self.auth_credentials:
168+
username, password = self.auth_credentials['admin']
169+
self.connection['admin'].authenticate(username, password)
170+
else:
171+
# Authenticate against all known databases
172+
for db_name, (u, p) in self.auth_credentials.items():
173+
self.connection[db_name].authenticate(u, p)
174+
162175
return self.sock[1]
163176

164177
def return_socket(self):
@@ -172,8 +185,12 @@ def return_socket(self):
172185
self.sock[1].close()
173186
self.sock = None
174187

175-
def socket_ids(self):
176-
return [id(sock) for sock in self.sockets]
188+
def add_db_auth(self, db_name, username, password):
189+
self.auth_credentials[db_name] = (username, password)
190+
191+
def remove_db_auth(self, db_name):
192+
if db_name in self.auth_credentials:
193+
del(self.auth_credentials[db_name])
177194

178195

179196
class Connection(object):
@@ -294,7 +311,7 @@ def __init__(self, host=None, port=None, pool_size=None,
294311

295312
self.__cursor_manager = CursorManager(self)
296313

297-
self.__pool = _Pool(self.__connect)
314+
self.__pool = _Pool(self.__connect, self)
298315
self.__last_checkout = time.time()
299316

300317
self.__network_timeout = network_timeout
@@ -307,15 +324,10 @@ def __init__(self, host=None, port=None, pool_size=None,
307324
if _connect:
308325
self.__find_master()
309326

310-
# cache of auth username/password credential keyed by DB name
311-
self.__auth_credentials = {}
312-
self.__sock_auths_by_id = {}
313327
if username:
314328
database = database or "admin"
315329
if not self[database].authenticate(username, password):
316330
raise ConfigurationError("authentication failed")
317-
# Add database auth credentials for auto-auth later
318-
self.add_db_auth(database, username, password)
319331

320332
@classmethod
321333
def from_uri(cls, uri="mongodb://localhost", **connection_args):
@@ -569,7 +581,7 @@ def disconnect(self):
569581
.. seealso:: :meth:`end_request`
570582
.. versionadded:: 1.3
571583
"""
572-
self.__pool = _Pool(self.__connect)
584+
self.__pool = _Pool(self.__connect, self)
573585
self.__host = None
574586
self.__port = None
575587

@@ -622,30 +634,6 @@ def __check_response_to_last_error(self, response):
622634
else:
623635
raise OperationFailure(error["err"])
624636

625-
def _authenticate_socket_for_db(self, sock, db_name):
626-
# Periodically remove cached auth flags of expired sockets
627-
if len(self.__sock_auths_by_id) > self.pool_size:
628-
cached_sock_ids = self.__sock_auths_by_id.keys()
629-
current_sock_ids = self.__pool.socket_ids()
630-
for sock_id in cached_sock_ids:
631-
if not sock_id in current_sock_ids:
632-
del(self.__sock_auths_by_id[sock_id])
633-
if not self.__auth_credentials:
634-
return # No credentials for any database
635-
sock_id = id(sock)
636-
if db_name in self.__sock_auths_by_id.get(sock_id, {}):
637-
return # Already authenticated for database
638-
if not self.has_db_auth(db_name):
639-
return # No credentials for database
640-
username, password = self.get_db_auth(db_name)
641-
if not self[db_name].authenticate(username, password):
642-
raise ConfigurationError("authentication to db %s failed for %s"
643-
% (db_name, username))
644-
if not sock_id in self.__sock_auths_by_id:
645-
self.__sock_auths_by_id[sock_id] = {}
646-
self.__sock_auths_by_id[sock_id][db_name] = 1
647-
return True
648-
649637
def _send_message(self, message, with_last_error=False,
650638
collection_name=None):
651639
"""Say something to Mongo.
@@ -663,14 +651,6 @@ def _send_message(self, message, with_last_error=False,
663651
"""
664652
sock = self.__socket()
665653
try:
666-
# Always authenticate for admin database, if possible
667-
if self._authenticate_socket_for_db(sock, 'admin'):
668-
pass # No need for futher auth with admin login
669-
elif collection_name and collection_name.split('.') >= 1:
670-
# Authenticate for specific database
671-
db_name = collection_name.split('.')[0]
672-
self._authenticate_socket_for_db(sock, db_name)
673-
674654
(request_id, data) = message
675655
sock.sendall(data)
676656
# Safe mode. We pack the message together with a lastError
@@ -928,28 +908,8 @@ def __iter__(self):
928908
def next(self):
929909
raise TypeError("'Connection' object is not iterable")
930910

931-
def add_db_auth(self, db_name, username, password):
932-
if not username or not isinstance(username, basestring):
933-
raise ConfigurationError('invalid username')
934-
if not password or not isinstance(password, basestring):
935-
raise ConfigurationError('invalid password')
936-
self.__auth_credentials[db_name] = (username, password)
937-
938-
def has_db_auth(self, db_name):
939-
return db_name in self.__auth_credentials
911+
def _add_db_auth(self, db_name, username, password):
912+
self.__pool.add_db_auth(db_name, username, password)
940913

941-
def get_db_auth(self, db_name):
942-
if self.has_db_auth(db_name):
943-
return self.__auth_credentials[db_name]
944-
return None
945-
946-
def remove_db_auth(self, db_name):
947-
if self.has_db_auth(db_name):
948-
del(self.__auth_credentials[db_name])
949-
# Force close any existing sockets to flush auths
950-
self.disconnect()
951-
952-
def clear_db_auths(self):
953-
self.__auth_credentials = {} # Forget all credentials
954-
# Force close any existing sockets to flush auths
955-
self.disconnect()
914+
def _remove_db_auth(self, db_name):
915+
self.__pool.remove_db_auth(db_name)

pymongo/database.py

Lines changed: 24 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -464,28 +464,28 @@ def authenticate(self, name, password):
464464
Once authenticated, the user has full read and write access to
465465
this database. Raises :class:`TypeError` if either `name` or
466466
`password` is not an instance of ``(str,
467-
unicode)``. Authentication lasts for the life of the database
468-
connection, or until :meth:`logout` is called.
467+
unicode)``. Authentication lasts for the life of the underlying
468+
:class:`Connection`, or until :meth:`logout` is called.
469469
470470
The "admin" database is special. Authenticating on "admin"
471471
gives access to *all* databases. Effectively, "admin" access
472472
means root access to the database.
473473
474-
.. note:: Currently, authentication is per
475-
:class:`~socket.socket`. This means that there are a couple
476-
of situations in which re-authentication is necessary:
477-
478-
- On failover (when an
479-
:class:`~pymongo.errors.AutoReconnect` exception is
480-
raised).
481-
482-
- After a call to
483-
:meth:`~pymongo.connection.Connection.disconnect` or
484-
:meth:`~pymongo.connection.Connection.end_request`.
474+
.. note:: This method authenticates the current connection, and
475+
will also cause all new :class:`~socket.socket` connections
476+
in the underlying :class:`~pymongo.connection.Connection` to
477+
be authenticated automatically.
485478
486479
- When sharing a :class:`~pymongo.connection.Connection`
487-
between multiple threads, each thread will need to
488-
authenticate separately.
480+
between multiple threads, all threads will share the
481+
authentication. If you need different authentication profiles
482+
for different purposes (e.g. admin users) you must use
483+
distinct :class:`~pymongo.connection.Connection`s.
484+
485+
- To get authentication to apply immediately to all
486+
connections including existing ones, you may need to
487+
reset the connections sockets using
488+
:meth:`~pymongo.connection.Connection.disconnect`.
489489
490490
.. warning:: Currently, calls to
491491
:meth:`~pymongo.connection.Connection.end_request` will
@@ -511,16 +511,23 @@ def authenticate(self, name, password):
511511
try:
512512
self.command("authenticate", user=unicode(name),
513513
nonce=nonce, key=key)
514+
self.connection._add_db_auth(self.name, unicode(name),
515+
unicode(password))
514516
return True
515517
except OperationFailure:
516518
return False
517519

518520
def logout(self):
519-
"""Deauthorize use of this database for this connection.
521+
"""Deauthorize use of this database for this connection and
522+
future connections.
520523
521-
Note that other databases may still be authorized.
524+
Note that other databases may still be authenticated, and that other
525+
existing :class:`~socket.socket` connections may remain
526+
authenticated unless you reset all sockets with
527+
:meth:`~pymongo.connection.Connection.disconnect`.
522528
"""
523529
self.command("logout")
530+
self.connection._remove_db_auth(self.name)
524531

525532
def dereference(self, dbref):
526533
"""Dereference a DBRef, getting the SON object it points to.

test/test_connection.py

Lines changed: 13 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -270,7 +270,7 @@ def test_parse_uri(self):
270270
self.assertEqual(([("localhost", 27017)], None, None, None),
271271
_parse_uri("localhost/", 27017))
272272

273-
def test_from_uri(self):
273+
def test_auth_from_uri(self):
274274
c = Connection(self.host, self.port)
275275

276276
self.assertRaises(InvalidURI, Connection, "mongodb://localhost/baz")
@@ -450,7 +450,7 @@ def test_tz_aware(self):
450450
self.assertEqual(aware.pymongo_test.test.find_one()["x"].replace(tzinfo=None),
451451
naive.pymongo_test.test.find_one()["x"])
452452

453-
def test_auto_db_authentication(self):
453+
def test_auth_from_database(self):
454454
conn = Connection(self.host, self.port)
455455

456456
# Setup admin user
@@ -465,11 +465,6 @@ def test_auto_db_authentication(self):
465465

466466
conn.pymongo_test.drop_collection("test")
467467

468-
self.assertRaises(TypeError, conn.add_db_auth, "", "password")
469-
self.assertRaises(TypeError, conn.add_db_auth, 5, "password")
470-
self.assertRaises(TypeError, conn.add_db_auth, "test-user", "")
471-
self.assertRaises(TypeError, conn.add_db_auth, "test-user", 5)
472-
473468
# Not yet logged in
474469
conn = Connection(self.host, self.port)
475470
try:
@@ -482,53 +477,48 @@ def test_auto_db_authentication(self):
482477
# Not yet logged in
483478
conn = Connection(self.host, self.port)
484479
self.assertRaises(OperationFailure, conn.pymongo_test.test.count)
485-
self.assertFalse(conn.has_db_auth('admin'))
486-
self.assertEquals(None, conn.get_db_auth('admin'))
487480

488481
# Admin log in via URI
489482
conn = Connection('admin-user:password@%s' % self.host, self.port)
490-
self.assertTrue(conn.has_db_auth('admin'))
491-
self.assertEquals('admin-user', conn.get_db_auth('admin')[0])
492483
conn.admin.system.users.find()
493484
conn.pymongo_test.test.insert({'_id':1, 'test':'data'}, safe=True)
494485
self.assertEquals(1, conn.pymongo_test.test.find({'_id':1}).count())
495486
conn.pymongo_test.test.remove({'_id':1})
496487

497-
# Clear and reset database authentication for all sockets
498-
conn.clear_db_auths()
499-
self.assertFalse(conn.has_db_auth('admin'))
488+
# Logout admin
489+
conn.admin.logout()
500490
self.assertRaises(OperationFailure, conn.pymongo_test.test.count)
501491

502-
# Admin log in via add_db_auth
492+
# Admin log in via Database.authenticate
503493
conn = Connection(self.host, self.port)
504494
conn.admin.system.users.find()
505-
conn.add_db_auth('admin', 'admin-user', 'password')
495+
conn.admin.authenticate('admin-user', 'password')
506496
conn.pymongo_test.test.insert({'_id':2, 'test':'data'}, safe=True)
507497
self.assertEquals(1, conn.pymongo_test.test.find({'_id':2}).count())
508498
conn.pymongo_test.test.remove({'_id':2})
509499

510500
# Remove database authentication for specific database
511-
self.assertTrue(conn.has_db_auth('admin'))
512-
conn.remove_db_auth('admin')
513-
self.assertFalse(conn.has_db_auth('admin'))
501+
conn.admin.logout()
514502
self.assertRaises(OperationFailure, conn.pymongo_test.test.count)
515503

516504
# Incorrect admin credentials
517505
conn = Connection(self.host, self.port)
518-
conn.add_db_auth('admin', 'admin-user', 'wrong-password')
506+
self.assertFalse(
507+
conn.admin.authenticate('admin-user', 'wrong-password'))
519508
self.assertRaises(OperationFailure, conn.pymongo_test.test.count)
520509

521510
# Database-specific log in
522511
conn = Connection(self.host, self.port)
523-
conn.add_db_auth('pymongo_test', 'test-user', 'password')
524-
self.assertRaises(OperationFailure, conn.admin.system.users.find_one)
512+
conn.pymongo_test.authenticate('test-user', 'password')
513+
self.assertRaises(OperationFailure,
514+
conn.admin.system.users.find_one)
525515
conn.pymongo_test.test.insert({'_id':3, 'test':'data'}, safe=True)
526516
self.assertEquals(1, conn.pymongo_test.test.find({'_id':3}).count())
527517
conn.pymongo_test.test.remove({'_id':3})
528518

529519
# Incorrect database credentials
530520
conn = Connection(self.host, self.port)
531-
conn.add_db_auth('pymongo_test', 'wrong-user', 'password')
521+
conn.pymongo_test.authenticate('wrong-user', 'password')
532522
self.assertRaises(OperationFailure, conn.pymongo_test.test.find_one)
533523
finally:
534524
# Remove auth users from databases

test/test_pooling.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -141,13 +141,13 @@ def test_disconnect(self):
141141
run_cases(self, [SaveAndFind, Disconnect, Unique])
142142

143143
def test_independent_pools(self):
144-
p = _Pool(None)
144+
p = _Pool(None, None)
145145
self.assertEqual([], p.sockets)
146146
self.c.end_request()
147147
self.assertEqual([], p.sockets)
148148

149149
# Sensical values aren't really important here
150-
p1 = _Pool(5)
150+
p1 = _Pool(5, 32)
151151
self.assertEqual(None, p.socket_factory)
152152
self.assertEqual(5, p1.socket_factory)
153153

test/test_threads.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -229,7 +229,7 @@ def test_auto_auth_login(self):
229229

230230
# Admin auth
231231
conn = get_connection()
232-
conn.add_db_auth("admin", "admin-user", "password")
232+
conn.admin.authenticate("admin-user", "password")
233233

234234
threads = []
235235
for _ in range(10):
@@ -242,7 +242,7 @@ def test_auto_auth_login(self):
242242

243243
# Database-specific auth
244244
conn = get_connection()
245-
conn.add_db_auth("auth_test", "test-user", "password")
245+
conn.auth_test.authenticate("test-user", "password")
246246

247247
threads = []
248248
for _ in range(10):

0 commit comments

Comments
 (0)