Skip to content

Commit 6c2e009

Browse files
author
A. Jesse Jiryu Davis
committed
Fix MongoReplicaSetClient race conditions PYTHON-467
* RSState, Member, and MovingAverage are now immutable * In refresh(), try up members before down ones * A test_ha fixup (clear process-list after killing them in each test, so we don't re-kill previous tests' processes)
1 parent 55f1df7 commit 6c2e009

File tree

8 files changed

+564
-357
lines changed

8 files changed

+564
-357
lines changed

pymongo/mongo_replica_set_client.py

Lines changed: 390 additions & 193 deletions
Large diffs are not rendered by default.

pymongo/read_preferences.py

Lines changed: 13 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,6 @@
1515
"""Utilities for choosing which member of a replica set to read from."""
1616

1717
import random
18-
import threading
19-
from collections import deque
2018

2119
from pymongo.errors import ConfigurationError
2220

@@ -86,7 +84,6 @@ def mongos_mode(mode):
8684
def mongos_enum(enum):
8785
return _mongos_modes.index(enum)
8886

89-
9087
def select_primary(members):
9188
for member in members:
9289
if member.is_primary:
@@ -151,6 +148,7 @@ def select_member(
151148
return select_primary(members)
152149

153150
elif mode == PRIMARY_PREFERRED:
151+
# Recurse.
154152
candidate_primary = select_member(members, PRIMARY, [{}], latency)
155153
if candidate_primary:
156154
return candidate_primary
@@ -166,6 +164,7 @@ def select_member(
166164
return None
167165

168166
elif mode == SECONDARY_PREFERRED:
167+
# Recurse.
169168
candidate_secondary = select_member(
170169
members, SECONDARY, tag_sets, latency)
171170
if candidate_secondary:
@@ -196,32 +195,16 @@ def select_member(
196195

197196

198197
class MovingAverage(object):
199-
"""Tracks a moving average.
200-
"""
201-
def __init__(self, window_sz):
202-
self.window_sz = window_sz
203-
self.samples = deque()
204-
self.total = 0
205-
self.lock = threading.Lock()
206-
207-
def update(self, sample):
208-
# One reason we synchronize MovingAverage is that Jython's
209-
# popleft isn't safe: http://bugs.jython.org/issue2001
210-
self.lock.acquire()
211-
try:
212-
self.samples.append(sample)
213-
self.total += sample
214-
if len(self.samples) > self.window_sz:
215-
self.total -= self.samples.popleft()
216-
finally:
217-
self.lock.release()
198+
def __init__(self, samples):
199+
"""Immutable structure to track a 5-sample moving average.
200+
"""
201+
self.samples = samples[-5:]
202+
assert self.samples
203+
self.average = sum(self.samples) / float(len(self.samples))
204+
205+
def clone_with(self, sample):
206+
"""Get a copy of this instance plus a new sample"""
207+
return MovingAverage(self.samples + [sample])
218208

219209
def get(self):
220-
self.lock.acquire()
221-
try:
222-
if self.samples:
223-
return self.total / float(len(self.samples))
224-
else:
225-
return None
226-
finally:
227-
self.lock.release()
210+
return self.average

test/__init__.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,14 +20,17 @@
2020
import pymongo
2121
from pymongo.errors import ConnectionFailure
2222

23-
host = os.environ.get("DB_IP", 'localhost')
23+
# hostnames retrieved by MongoReplicaSetClient from isMaster will be of unicode
24+
# type in Python 2, so ensure these hostnames are unicodes, too. It makes tests
25+
# like `test_repr` predictable.
26+
host = unicode(os.environ.get("DB_IP", 'localhost'))
2427
port = int(os.environ.get("DB_PORT", 27017))
2528
pair = '%s:%d' % (host, port)
2629

27-
host2 = os.environ.get("DB_IP2", 'localhost')
30+
host2 = unicode(os.environ.get("DB_IP2", 'localhost'))
2831
port2 = int(os.environ.get("DB_PORT2", 27018))
2932

30-
host3 = os.environ.get("DB_IP3", 'localhost')
33+
host3 = unicode(os.environ.get("DB_IP3", 'localhost'))
3134
port3 = int(os.environ.get("DB_PORT3", 27019))
3235

3336

test/high_availability/test_ha.py

Lines changed: 36 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -27,14 +27,11 @@
2727
from ha_tools import use_greenlets
2828

2929

30-
from pymongo import (MongoReplicaSetClient,
31-
ReadPreference)
32-
from pymongo.mongo_replica_set_client import (
33-
Member, Monitor, MongoReplicaSetClient)
34-
from pymongo.mongo_client import _partition_node
35-
from pymongo.mongo_client import MongoClient
3630
from pymongo.errors import AutoReconnect, OperationFailure, ConnectionFailure
37-
from pymongo.read_preferences import modes
31+
from pymongo.mongo_replica_set_client import Member, Monitor
32+
from pymongo.mongo_replica_set_client import MongoReplicaSetClient
33+
from pymongo.mongo_client import MongoClient, _partition_node
34+
from pymongo.read_preferences import ReadPreference, modes
3835

3936
from test import utils
4037
from test.utils import one
@@ -52,7 +49,15 @@
5249
NEAREST = ReadPreference.NEAREST
5350

5451

55-
class TestDirectConnection(unittest.TestCase):
52+
class HATestCase(unittest.TestCase):
53+
"""A test case for connections to replica sets or mongos."""
54+
55+
def tearDown(self):
56+
ha_tools.kill_all_members()
57+
ha_tools.nodes.clear()
58+
59+
60+
class TestDirectConnection(HATestCase):
5661

5762
def setUp(self):
5863
members = [{}, {}, {'arbiterOnly': True}]
@@ -145,10 +150,10 @@ def test_secondary_connection(self):
145150

146151
def tearDown(self):
147152
self.c.close()
148-
ha_tools.kill_all_members()
153+
super(TestDirectConnection, self).tearDown()
149154

150155

151-
class TestPassiveAndHidden(unittest.TestCase):
156+
class TestPassiveAndHidden(HATestCase):
152157

153158
def setUp(self):
154159
members = [{},
@@ -177,10 +182,10 @@ def test_passive_and_hidden(self):
177182

178183
def tearDown(self):
179184
self.c.close()
180-
ha_tools.kill_all_members()
185+
super(TestPassiveAndHidden, self).tearDown()
181186

182187

183-
class TestMonitorRemovesRecoveringMember(unittest.TestCase):
188+
class TestMonitorRemovesRecoveringMember(HATestCase):
184189
# Members in STARTUP2 or RECOVERING states are shown in the primary's
185190
# isMaster response, but aren't secondaries and shouldn't be read from.
186191
# Verify that if a secondary goes into RECOVERING mode, the Monitor removes
@@ -211,10 +216,10 @@ def test_monitor_removes_recovering_member(self):
211216

212217
def tearDown(self):
213218
self.c.close()
214-
ha_tools.kill_all_members()
219+
super(TestMonitorRemovesRecoveringMember, self).tearDown()
215220

216221

217-
class TestTriggeredRefresh(unittest.TestCase):
222+
class TestTriggeredRefresh(HATestCase):
218223
# Verify that if a secondary goes into RECOVERING mode or if the primary
219224
# changes, the next exception triggers an immediate refresh.
220225

@@ -282,14 +287,14 @@ def test_stepdown_triggers_refresh(self):
282287
# We've detected the stepdown
283288
self.assertTrue(
284289
not c_find_one.primary
285-
or primary != _partition_node(c_find_one.primary))
290+
or _partition_node(primary) != c_find_one.primary)
286291

287292
def tearDown(self):
288293
Monitor._refresh_interval = MONITOR_INTERVAL
289-
ha_tools.kill_all_members()
294+
super(TestTriggeredRefresh, self).tearDown()
290295

291296

292-
class TestHealthMonitor(unittest.TestCase):
297+
class TestHealthMonitor(HATestCase):
293298

294299
def setUp(self):
295300
res = ha_tools.start_replica_set([{}, {}, {}])
@@ -357,17 +362,10 @@ def primary_changed():
357362

358363
ha_tools.stepdown_primary()
359364
self.assertTrue(primary_changed())
360-
361-
# There can be a delay between finding the primary and updating
362-
# secondaries
363-
sleep(5)
364365
self.assertNotEqual(secondaries, c.secondaries)
365366

366-
def tearDown(self):
367-
ha_tools.kill_all_members()
368-
369367

370-
class TestWritesWithFailover(unittest.TestCase):
368+
class TestWritesWithFailover(HATestCase):
371369

372370
def setUp(self):
373371
res = ha_tools.start_replica_set([{}, {}, {}])
@@ -398,11 +396,8 @@ def try_write():
398396
self.assertTrue(primary != c.primary)
399397
self.assertEqual('baz', db.test.find_one({'bar': 'baz'})['bar'])
400398

401-
def tearDown(self):
402-
ha_tools.kill_all_members()
403-
404399

405-
class TestReadWithFailover(unittest.TestCase):
400+
class TestReadWithFailover(HATestCase):
406401

407402
def setUp(self):
408403
res = ha_tools.start_replica_set([{}, {}, {}])
@@ -435,11 +430,8 @@ def iter_cursor(cursor):
435430
self.assertTrue(iter_cursor(cursor))
436431
self.assertEqual(10, cursor._Cursor__retrieved)
437432

438-
def tearDown(self):
439-
ha_tools.kill_all_members()
440433

441-
442-
class TestReadPreference(unittest.TestCase):
434+
class TestReadPreference(HATestCase):
443435
def setUp(self):
444436
members = [
445437
# primary
@@ -784,11 +776,10 @@ def test_pinning(self):
784776

785777
def tearDown(self):
786778
self.c.close()
787-
ha_tools.kill_all_members()
788-
self.clear_ping_times()
779+
super(TestReadPreference, self).tearDown()
789780

790781

791-
class TestReplicaSetAuth(unittest.TestCase):
782+
class TestReplicaSetAuth(HATestCase):
792783
def setUp(self):
793784
members = [
794785
{},
@@ -829,10 +820,10 @@ def test_auth_during_failover(self):
829820

830821
def tearDown(self):
831822
self.c.close()
832-
ha_tools.kill_all_members()
823+
super(TestReplicaSetAuth, self).tearDown()
833824

834825

835-
class TestAlive(unittest.TestCase):
826+
class TestAlive(HATestCase):
836827
def setUp(self):
837828
members = [{}, {}]
838829
self.seed, self.name = ha_tools.start_replica_set(members)
@@ -866,11 +857,8 @@ def test_alive(self):
866857
finally:
867858
rsc.close()
868859

869-
def tearDown(self):
870-
ha_tools.kill_all_members()
871860

872-
873-
class TestMongosHighAvailability(unittest.TestCase):
861+
class TestMongosHighAvailability(HATestCase):
874862
def setUp(self):
875863
seed_list = ha_tools.create_sharded_cluster()
876864
self.dbname = 'pymongo_mongos_ha'
@@ -910,10 +898,10 @@ def test_mongos_ha(self):
910898

911899
def tearDown(self):
912900
self.client.drop_database(self.dbname)
913-
ha_tools.kill_all_members()
901+
super(TestMongosHighAvailability, self).tearDown()
914902

915903

916-
class TestReplicaSetRequest(unittest.TestCase):
904+
class TestReplicaSetRequest(HATestCase):
917905
def setUp(self):
918906
members = [{}, {}, {'arbiterOnly': True}]
919907
res = ha_tools.start_replica_set(members)
@@ -928,8 +916,9 @@ def test_request_during_failover(self):
928916
self.assertTrue(self.c.auto_start_request)
929917
self.assertTrue(self.c.in_request())
930918

931-
primary_pool = self.c._MongoReplicaSetClient__members[primary].pool
932-
secondary_pool = self.c._MongoReplicaSetClient__members[secondary].pool
919+
rs_state = self.c._MongoReplicaSetClient__rs_state
920+
primary_pool = rs_state.get(primary).pool
921+
secondary_pool = rs_state.get(secondary).pool
933922

934923
# Trigger start_request on primary pool
935924
utils.assertReadFrom(self, self.c, primary, PRIMARY)
@@ -967,7 +956,7 @@ def test_request_during_failover(self):
967956

968957
def tearDown(self):
969958
self.c.close()
970-
ha_tools.kill_all_members()
959+
super(TestReplicaSetRequest, self).tearDown()
971960

972961

973962
if __name__ == '__main__':

test/test_client.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,9 @@ def test_host_w_port(self):
100100
ConnectionFailure, MongoClient, "%s:1234567" % (host,), port)
101101

102102
def test_repr(self):
103-
self.assertEqual(repr(MongoClient(host, port)),
103+
# Making host a str avoids the 'u' prefix in Python 2, so the repr is
104+
# the same in Python 2 and 3.
105+
self.assertEqual(repr(MongoClient(str(host), port)),
104106
"MongoClient('%s', %d)" % (host, port))
105107

106108
def test_getters(self):

test/test_read_preferences.py

Lines changed: 18 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -180,7 +180,7 @@ def test_nearest(self):
180180
not_used = data_members.difference(used)
181181
latencies = ', '.join(
182182
'%s: %dms' % (member.host, member.ping_time.get())
183-
for member in c._MongoReplicaSetClient__members.values())
183+
for member in c._MongoReplicaSetClient__rs_state.members)
184184

185185
self.assertFalse(not_used,
186186
"Expected to use primary and all secondaries for mode NEAREST,"
@@ -426,51 +426,24 @@ def test_aggregate(self):
426426

427427

428428
class TestMovingAverage(unittest.TestCase):
429-
def test_empty_moving_average(self):
430-
avg = MovingAverage(0)
431-
self.assertEqual(None, avg.get())
432-
avg.update(10)
433-
self.assertEqual(None, avg.get())
434-
435-
def test_trivial_moving_average(self):
436-
avg = MovingAverage(1)
437-
self.assertEqual(None, avg.get())
438-
avg.update(10)
439-
self.assertEqual(10, avg.get())
440-
avg.update(20)
441-
self.assertEqual(20, avg.get())
442-
avg.update(0)
443-
self.assertEqual(0, avg.get())
444-
445-
def test_2_sample_moving_average(self):
446-
avg = MovingAverage(2)
447-
self.assertEqual(None, avg.get())
448-
avg.update(10)
449-
self.assertEqual(10, avg.get())
450-
avg.update(20)
451-
self.assertEqual(15, avg.get())
452-
avg.update(30)
453-
self.assertEqual(25, avg.get())
454-
avg.update(-100)
455-
self.assertEqual(-35, avg.get())
456-
457-
def test_5_sample_moving_average(self):
458-
avg = MovingAverage(5)
459-
self.assertEqual(None, avg.get())
460-
avg.update(10)
429+
def test_empty_init(self):
430+
self.assertRaises(AssertionError, MovingAverage, [])
431+
432+
def test_moving_average(self):
433+
avg = MovingAverage([10])
461434
self.assertEqual(10, avg.get())
462-
avg.update(20)
463-
self.assertEqual(15, avg.get())
464-
avg.update(30)
465-
self.assertEqual(20, avg.get())
466-
avg.update(-100)
467-
self.assertEqual((10 + 20 + 30 - 100) / 4, avg.get())
468-
avg.update(17)
469-
self.assertEqual((10 + 20 + 30 - 100 + 17) / 5., avg.get())
470-
avg.update(43)
471-
self.assertEqual((20 + 30 - 100 + 17 + 43) / 5., avg.get())
472-
avg.update(-1111)
473-
self.assertEqual((30 - 100 + 17 + 43 - 1111) / 5., avg.get())
435+
avg2 = avg.clone_with(20)
436+
self.assertEqual(15, avg2.get())
437+
avg3 = avg2.clone_with(30)
438+
self.assertEqual(20, avg3.get())
439+
avg4 = avg3.clone_with(-100)
440+
self.assertEqual((10 + 20 + 30 - 100) / 4., avg4.get())
441+
avg5 = avg4.clone_with(17)
442+
self.assertEqual((10 + 20 + 30 - 100 + 17) / 5., avg5.get())
443+
avg6 = avg5.clone_with(43)
444+
self.assertEqual((20 + 30 - 100 + 17 + 43) / 5., avg6.get())
445+
avg7 = avg6.clone_with(-1111)
446+
self.assertEqual((30 - 100 + 17 + 43 - 1111) / 5., avg7.get())
474447

475448

476449
class TestMongosConnection(unittest.TestCase):

0 commit comments

Comments
 (0)