Skip to content

Commit a227198

Browse files
committed
PYTHON-1613 Invalidate cache on changed salt or iterations
1 parent 47b0d8e commit a227198

File tree

2 files changed

+17
-10
lines changed

2 files changed

+17
-10
lines changed

pymongo/auth.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -272,15 +272,19 @@ def _authenticate_scram(credentials, sock_info, mechanism):
272272
raise OperationFailure("Server returned an invalid nonce.")
273273

274274
without_proof = b"c=biws,r=" + rnonce
275-
keys = cache.data
276-
if keys:
277-
client_key, server_key = keys
275+
if cache.data:
276+
client_key, server_key, csalt, citerations = cache.data
278277
else:
278+
client_key, server_key, csalt, citerations = None, None, None, None
279+
280+
# Salt and / or iterations could change for a number of different
281+
# reasons. Either changing invalidates the cache.
282+
if not client_key or salt != csalt or iterations != citerations:
279283
salted_pass = _hi(
280284
digest, data, standard_b64decode(salt), iterations)
281285
client_key = _hmac(salted_pass, b"Client Key", digestmod).digest()
282286
server_key = _hmac(salted_pass, b"Server Key", digestmod).digest()
283-
cache.data = (client_key, server_key)
287+
cache.data = (client_key, server_key, salt, iterations)
284288
stored_key = digestmod(client_key).digest()
285289
auth_msg = b",".join((first_bare, server_first, without_proof))
286290
client_sig = _hmac(stored_key, auth_msg, digestmod).digest()

test/test_auth.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -584,11 +584,14 @@ def test_cache(self):
584584
credentials = all_credentials.get('admin')
585585
cache = credentials.cache
586586
self.assertIsNotNone(cache)
587-
keys = cache.data
588-
self.assertIsNotNone(keys)
589-
self.assertEqual(len(keys), 2)
590-
for elt in keys:
591-
self.assertIsInstance(elt, bytes)
587+
data = cache.data
588+
self.assertIsNotNone(data)
589+
self.assertEqual(len(data), 4)
590+
ckey, skey, salt, iterations = data
591+
self.assertIsInstance(ckey, bytes)
592+
self.assertIsInstance(skey, bytes)
593+
self.assertIsInstance(salt, bytes)
594+
self.assertIsInstance(iterations, int)
592595

593596
pool = next(iter(client._topology._servers.values()))._pool
594597
with pool.get_socket(all_credentials) as sock_info:
@@ -601,7 +604,7 @@ def test_cache(self):
601604
sock_credentials = next(iter(authset))
602605
sock_cache = sock_credentials.cache
603606
self.assertIsNotNone(sock_cache)
604-
self.assertEqual(sock_cache.data, keys)
607+
self.assertEqual(sock_cache.data, data)
605608

606609
def test_scram_threaded(self):
607610

0 commit comments

Comments
 (0)