Skip to content

Commit 14f626c

Browse files
committed
Factor Max Staleness and Server Selection tests
1 parent 306e990 commit 14f626c

File tree

3 files changed

+287
-394
lines changed

3 files changed

+287
-394
lines changed

test/test_max_staleness.py

Lines changed: 5 additions & 189 deletions
Original file line numberDiff line numberDiff line change
@@ -14,215 +14,31 @@
1414

1515
"""Test maxStalenessSeconds support."""
1616

17-
import datetime
1817
import os
19-
import time
2018
import sys
19+
import time
2120
import warnings
2221

2322
sys.path[0:0] = [""]
2423

25-
from bson import json_util
26-
from pymongo import MongoClient, read_preferences
27-
from pymongo.common import clean_node, HEARTBEAT_FREQUENCY
28-
from pymongo.errors import ConfigurationError, ConnectionFailure
29-
from pymongo.ismaster import IsMaster
30-
from pymongo.server_description import ServerDescription
24+
from pymongo import MongoClient
25+
from pymongo.errors import ConfigurationError
3126
from pymongo.server_selectors import writable_server_selector
32-
from pymongo.settings import TopologySettings
33-
from pymongo.topology import Topology
3427

3528
from test import client_context, unittest
3629
from test.utils import rs_or_single_client
30+
from test.utils_selection_tests import create_selection_tests
3731

3832
# Location of JSON test specifications.
3933
_TEST_PATH = os.path.join(
4034
os.path.dirname(os.path.realpath(__file__)),
4135
'max_staleness')
4236

4337

44-
class MockSocketInfo(object):
45-
def close(self):
46-
pass
47-
48-
def __enter__(self):
49-
return self
50-
51-
def __exit__(self, exc_type, exc_val, exc_tb):
52-
pass
53-
54-
55-
class MockPool(object):
56-
def __init__(self, *args, **kwargs):
57-
pass
58-
59-
def reset(self):
60-
pass
61-
62-
63-
class MockMonitor(object):
64-
def __init__(self, server_description, topology, pool, topology_settings):
65-
pass
66-
67-
def open(self):
68-
pass
69-
70-
def request_check(self):
71-
pass
72-
73-
def close(self):
74-
pass
75-
76-
77-
def get_addresses(server_list):
78-
seeds = []
79-
hosts = []
80-
for server in server_list:
81-
seeds.append(clean_node(server['address']))
82-
hosts.append(server['address'])
83-
return seeds, hosts
84-
85-
86-
def make_last_write_date(server):
87-
epoch = datetime.datetime.utcfromtimestamp(0)
88-
millis = server.get('lastWrite', {}).get('lastWriteDate')
89-
if millis:
90-
diff = ((millis % 1000) + 1000) % 1000
91-
seconds = (millis - diff) / 1000
92-
micros = diff * 1000
93-
return epoch + datetime.timedelta(
94-
seconds=seconds, microseconds=micros)
95-
else:
96-
# "Unknown" server.
97-
return epoch
98-
99-
100-
def make_server_description(server, hosts):
101-
"""Make ServerDescription from server info from JSON file."""
102-
server_type = server['type']
103-
if server_type == "Unknown":
104-
return ServerDescription(clean_node(server['address']), IsMaster({}))
105-
106-
ismaster_response = {'ok': True, 'hosts': hosts}
107-
if server_type != "Standalone" and server_type != "Mongos":
108-
ismaster_response['setName'] = "rs"
109-
110-
if server_type == "RSPrimary":
111-
ismaster_response['ismaster'] = True
112-
elif server_type == "RSSecondary":
113-
ismaster_response['secondary'] = True
114-
elif server_type == "Mongos":
115-
ismaster_response['msg'] = 'isdbgrid'
116-
117-
ismaster_response['lastWrite'] = {
118-
'lastWriteDate': make_last_write_date(server)
119-
}
120-
121-
for field in 'maxWireVersion', 'tags', 'idleWritePeriodMillis':
122-
if field in server:
123-
ismaster_response[field] = server[field]
124-
125-
# Sets _last_update_time to now.
126-
sd = ServerDescription(clean_node(server['address']),
127-
IsMaster(ismaster_response),
128-
round_trip_time=server['avg_rtt_ms'])
129-
130-
sd._last_update_time = server['lastUpdateTime'] / 1000.0 # ms to sec.
131-
return sd
132-
133-
134-
class TestAllScenarios(unittest.TestCase):
38+
class TestAllScenarios(create_selection_tests(_TEST_PATH)):
13539
pass
13640

13741

138-
def create_test(scenario_def):
139-
def run_scenario(self):
140-
if 'heartbeatFrequencyMS' in scenario_def:
141-
frequency = int(scenario_def['heartbeatFrequencyMS']) / 1000.0
142-
else:
143-
frequency = HEARTBEAT_FREQUENCY
144-
145-
# Initialize topologies.
146-
seeds, hosts = get_addresses(
147-
scenario_def['topology_description']['servers'])
148-
149-
topology = Topology(
150-
TopologySettings(seeds=seeds,
151-
monitor_class=MockMonitor,
152-
pool_class=MockPool,
153-
heartbeat_frequency=frequency))
154-
155-
# Update topologies with server descriptions.
156-
for server in scenario_def['topology_description']['servers']:
157-
server_description = make_server_description(server, hosts)
158-
topology.on_change(server_description)
159-
160-
# Create server selector.
161-
# Make first letter lowercase to match read_pref's modes.
162-
pref_def = scenario_def['read_preference']
163-
mode_string = pref_def.get('mode', 'primary')
164-
mode_string = mode_string[:1].lower() + mode_string[1:]
165-
mode = read_preferences.read_pref_mode_from_name(mode_string)
166-
max_staleness = pref_def.get('maxStalenessSeconds', -1)
167-
tag_sets = pref_def.get('tag_sets')
168-
169-
if scenario_def.get('error'):
170-
with self.assertRaises(ConfigurationError):
171-
# Error can be raised when making Read Pref or selecting.
172-
pref = read_preferences.make_read_preference(
173-
mode, tag_sets=tag_sets, max_staleness=max_staleness)
174-
175-
topology.select_server(pref)
176-
return
177-
178-
expected_addrs = set([
179-
server['address'] for server in scenario_def['in_latency_window']])
180-
181-
# Select servers.
182-
pref = read_preferences.make_read_preference(
183-
mode, tag_sets=tag_sets, max_staleness=max_staleness)
184-
185-
if not expected_addrs:
186-
with self.assertRaises(ConnectionFailure):
187-
topology.select_servers(pref, server_selection_timeout=0)
188-
return
189-
190-
servers = topology.select_servers(pref, server_selection_timeout=0)
191-
actual_addrs = set(['%s:%d' % s.description.address for s in servers])
192-
193-
for unexpected in actual_addrs - expected_addrs:
194-
self.fail("'%s' shouldn't have been selected, but was" % unexpected)
195-
196-
for unselected in expected_addrs - actual_addrs:
197-
self.fail("'%s' should have been selected, but wasn't" % unselected)
198-
199-
return run_scenario
200-
201-
202-
def create_tests():
203-
for dirpath, _, filenames in os.walk(_TEST_PATH):
204-
dirname = os.path.split(dirpath)
205-
dirname = os.path.split(dirname[-2])[-1] + '_' + dirname[-1]
206-
207-
for filename in filenames:
208-
if not filename.endswith('.json'):
209-
continue
210-
211-
with open(os.path.join(dirpath, filename)) as scenario_stream:
212-
scenario_def = json_util.loads(scenario_stream.read())
213-
214-
# Construct test from scenario.
215-
new_test = create_test(scenario_def)
216-
test_name = 'test_%s_%s' % (
217-
dirname, os.path.splitext(filename)[0])
218-
219-
new_test.__name__ = test_name
220-
setattr(TestAllScenarios, new_test.__name__, new_test)
221-
222-
223-
create_tests()
224-
225-
22642
class TestMaxStaleness(unittest.TestCase):
22743
def test_max_staleness(self):
22844
client = MongoClient()

0 commit comments

Comments
 (0)