Skip to content

Commit 7ee852a

Browse files
committed
PYTHON-1070 - Make index cache thread safe
1 parent 241a898 commit 7ee852a

File tree

3 files changed

+95
-25
lines changed

3 files changed

+95
-25
lines changed

pymongo/collection.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1402,6 +1402,12 @@ def ensure_index(self, key_or_list, cache_for=300, **kwargs):
14021402
keys = helpers._index_list(key_or_list)
14031403
name = kwargs.setdefault("name", helpers._gen_index_name(keys))
14041404

1405+
# Note that there is a race condition here. One thread could
1406+
# check if the index is cached and be preempted before creating
1407+
# and caching the index. This means multiple threads attempting
1408+
# to create the same index concurrently could send the index
1409+
# to the server two or more times. This has no practical impact
1410+
# other than wasted round trips.
14051411
if not self.__database.client._cached(self.__database.name,
14061412
self.__name, name):
14071413
self.__create_index(keys, kwargs)

pymongo/mongo_client.py

Lines changed: 29 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -353,6 +353,7 @@ def __init__(
353353

354354
# Cache of existing indexes used by ensure_index ops.
355355
self.__index_cache = {}
356+
self.__index_cache_lock = threading.Lock()
356357

357358
super(MongoClient, self).__init__(options.codec_options,
358359
options.read_preference,
@@ -433,27 +434,29 @@ def _cached(self, dbname, coll, index):
433434
"""Test if `index` is cached."""
434435
cache = self.__index_cache
435436
now = datetime.datetime.utcnow()
436-
return (dbname in cache and
437-
coll in cache[dbname] and
438-
index in cache[dbname][coll] and
439-
now < cache[dbname][coll][index])
437+
with self.__index_cache_lock:
438+
return (dbname in cache and
439+
coll in cache[dbname] and
440+
index in cache[dbname][coll] and
441+
now < cache[dbname][coll][index])
440442

441443
def _cache_index(self, dbname, collection, index, cache_for):
    """Add an index to the index cache for ensure_index operations.

    The entry is stored under `dbname` -> `collection` -> `index` and
    maps to the UTC time at which it expires, `cache_for` seconds from
    now.
    """
    now = datetime.datetime.utcnow()
    expire = datetime.timedelta(seconds=cache_for) + now

    # The lookup must be keyed on `dbname`. The earlier membership test
    # used the name `database`, which is not a parameter of this method:
    # it either raised NameError or compared an unrelated object, in
    # which case the per-database dict was replaced on every call and
    # entries cached for other collections/indexes were discarded.
    with self.__index_cache_lock:
        # setdefault creates the missing levels and leaves existing
        # entries for other collections and indexes untouched.
        collections = self.__index_cache.setdefault(dbname, {})
        collections.setdefault(collection, {})[index] = expire
457460

458461
def _purge_index(self, database_name,
459462
collection_name=None, index_name=None):
@@ -463,22 +466,23 @@ def _purge_index(self, database_name,
463466
464467
If `collection_name` is None purge an entire database.
465468
"""
466-
if not database_name in self.__index_cache:
467-
return
469+
with self.__index_cache_lock:
470+
if not database_name in self.__index_cache:
471+
return
468472

469-
if collection_name is None:
470-
del self.__index_cache[database_name]
471-
return
473+
if collection_name is None:
474+
del self.__index_cache[database_name]
475+
return
472476

473-
if not collection_name in self.__index_cache[database_name]:
474-
return
477+
if not collection_name in self.__index_cache[database_name]:
478+
return
475479

476-
if index_name is None:
477-
del self.__index_cache[database_name][collection_name]
478-
return
480+
if index_name is None:
481+
del self.__index_cache[database_name][collection_name]
482+
return
479483

480-
if index_name in self.__index_cache[database_name][collection_name]:
481-
del self.__index_cache[database_name][collection_name][index_name]
484+
if index_name in self.__index_cache[database_name][collection_name]:
485+
del self.__index_cache[database_name][collection_name][index_name]
482486

483487
def _server_property(self, attr_name):
484488
"""An attribute of the current server's description.

test/test_legacy_api.py

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1118,6 +1118,66 @@ def test_ensure_index(self):
11181118
# Clean up indexes for later tests
11191119
db.test.drop_indexes()
11201120

1121+
def test_ensure_index_threaded(self):
    # Ten threads each ensure the same three indexes and record what
    # index_information() reports afterwards; every recorded snapshot
    # must match the first one.
    collection = self.db.threaded_index_creation
    snapshots = []

    class Indexer(threading.Thread):
        def run(self):
            for field in ('foo0', 'foo1', 'foo2'):
                collection.ensure_index(field)
            snapshots.append(collection.index_information())

    try:
        workers = [Indexer() for _ in range(10)]
        for worker in workers:
            worker.setDaemon(True)

        for worker in workers:
            worker.start()

        joinall(workers)

        expected = snapshots[0]
        for snapshot in snapshots[1:]:
            self.assertEqual(snapshot, expected)
    finally:
        # Always remove the scratch collection, even on failure.
        collection.drop()
1150+
def test_ensure_purge_index_threaded(self):
    # Ten threads interleave ensure_index with drop_index/drop_indexes
    # on the same collection; once all have finished, the 'foo_1' index
    # must be present.
    collection = self.db.threaded_index_creation

    class Indexer(threading.Thread):
        def run(self):
            collection.ensure_index('foo')
            try:
                collection.drop_index('foo')
            except OperationFailure:
                # The index may have already been dropped.
                pass
            collection.ensure_index('foo')
            collection.drop_indexes()
            collection.ensure_index('foo')

    try:
        workers = [Indexer() for _ in range(10)]
        for worker in workers:
            worker.setDaemon(True)

        for worker in workers:
            worker.start()

        joinall(workers)

        index_info = collection.index_information()
        self.assertTrue('foo_1' in index_info)
    finally:
        # Always remove the scratch collection, even on failure.
        collection.drop()
11211181
def test_ensure_unique_index_threaded(self):
11221182
coll = self.db.test_unique_threaded
11231183
coll.drop()

0 commit comments

Comments
 (0)