Skip to content

Commit ab6b3a3

Browse files
aherlihybehackett
authored andcommitted
Add support for maxAwaitTimeMS to getMore
1 parent 36129ed commit ab6b3a3

File tree

3 files changed

+153
-13
lines changed

3 files changed

+153
-13
lines changed

pymongo/cursor.py

Lines changed: 35 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -154,6 +154,7 @@ def __init__(self, collection, filter=None, projection=None, skip=0,
154154
self.__hint = None
155155
self.__comment = None
156156
self.__max_time_ms = None
157+
self.__max_await_time_ms = None
157158
self.__max = None
158159
self.__min = None
159160
self.__manipulate = manipulate
@@ -244,10 +245,10 @@ def _clone(self, deepcopy=True):
244245
"""Internal clone helper."""
245246
clone = self._clone_base()
246247
values_to_clone = ("spec", "projection", "skip", "limit",
247-
"max_time_ms", "comment", "max", "min",
248-
"ordering", "explain", "hint", "batch_size",
249-
"max_scan", "manipulate", "query_flags",
250-
"modifiers")
248+
"max_time_ms", "max_await_time_ms", "comment",
249+
"max", "min", "ordering", "explain", "hint",
250+
"batch_size", "max_scan", "manipulate",
251+
"query_flags", "modifiers")
251252
data = dict((k, v) for k, v in iteritems(self.__dict__)
252253
if k.startswith('_Cursor__') and k[9:] in values_to_clone)
253254
if deepcopy:
@@ -470,6 +471,35 @@ def max_time_ms(self, max_time_ms):
470471
self.__max_time_ms = max_time_ms
471472
return self
472473

474+
def max_await_time_ms(self, max_await_time_ms):
475+
"""Specifies a time limit for a getMore operation on a
476+
:attr:`~pymongo.CursorType.TAILABLE_AWAIT` cursor. For all other types
477+
of cursor max_await_time_ms is ignored.
478+
479+
Raises :exc:`TypeError` if `max_await_time_ms` is not an integer or
480+
``None``.
481+
Raises :exc:`~pymongo.errors.InvalidOperation` if this :class:`Cursor`
482+
has already been used.
483+
484+
.. note:: `max_await_time_ms` requires server version **>= 3.2**
485+
486+
:Parameters:
487+
- `max_await_time_ms`: the time limit after which the operation is
488+
aborted
489+
490+
.. versionadded:: 3.2
491+
"""
492+
if (not isinstance(max_await_time_ms, integer_types)
493+
and max_await_time_ms is not None):
494+
raise TypeError("max_await_time_ms must be an integer or None")
495+
self.__check_okay_to_chain()
496+
497+
# Ignore max_await_time_ms if not tailable or await_data is False.
498+
if self.__query_flags & CursorType.TAILABLE_AWAIT:
499+
self.__max_await_time_ms = max_await_time_ms
500+
501+
return self
502+
473503
def __getitem__(self, index):
474504
"""Get a single document or a slice of documents from this cursor.
475505
@@ -1007,7 +1037,7 @@ def _refresh(self):
10071037
limit,
10081038
self.__id,
10091039
self.__codec_options,
1010-
self.__max_time_ms))
1040+
self.__max_await_time_ms))
10111041

10121042
else: # Cursor id is zero nothing else to return
10131043
self.__killed = True

pymongo/message.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -196,14 +196,14 @@ def _gen_find_command(coll, spec, projection, skip, limit, batch_size,
196196
return cmd
197197

198198

199-
def _gen_get_more_command(cursor_id, coll, batch_size, max_time_ms):
199+
def _gen_get_more_command(cursor_id, coll, batch_size, max_await_time_ms):
200200
"""Generate a getMore command document."""
201201
cmd = SON([('getMore', cursor_id),
202202
('collection', coll)])
203203
if batch_size:
204204
cmd['batchSize'] = batch_size
205-
if max_time_ms:
206-
cmd['maxTimeMS'] = max_time_ms
205+
if max_await_time_ms is not None:
206+
cmd['maxTimeMS'] = max_await_time_ms
207207
return cmd
208208

209209

@@ -274,24 +274,25 @@ def get_message(self, set_slave_ok, is_mongos, use_cmd=False):
274274
class _GetMore(object):
275275
"""A getmore operation."""
276276

277-
__slots__ = ('db', 'coll', 'ntoreturn', 'cursor_id', 'max_time_ms',
277+
__slots__ = ('db', 'coll', 'ntoreturn', 'cursor_id', 'max_await_time_ms',
278278
'codec_options')
279279

280280
name = 'getMore'
281281

282282
def __init__(self, db, coll, ntoreturn, cursor_id, codec_options,
283-
max_time_ms=None):
283+
max_await_time_ms=None):
284284
self.db = db
285285
self.coll = coll
286286
self.ntoreturn = ntoreturn
287287
self.cursor_id = cursor_id
288288
self.codec_options = codec_options
289-
self.max_time_ms = max_time_ms
289+
self.max_await_time_ms = max_await_time_ms
290290

291291
def as_command(self):
292292
"""Return a getMore command document for this query."""
293293
return _gen_get_more_command(self.cursor_id, self.coll,
294-
self.ntoreturn, self.max_time_ms), self.db
294+
self.ntoreturn,
295+
self.max_await_time_ms), self.db
295296

296297
def get_message(self, dummy0, dummy1, use_cmd=False):
297298
"""Get a getmore message."""

test/test_cursor.py

Lines changed: 110 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
from bson.py3compat import u, PY3
2626
from bson.son import SON
2727
from pymongo import (MongoClient,
28+
monitoring,
2829
ASCENDING,
2930
DESCENDING,
3031
ALL,
@@ -41,7 +42,7 @@
4142
host,
4243
port,
4344
IntegrationTest)
44-
from test.utils import server_started_with_auth
45+
from test.utils import server_started_with_auth, single_client, EventListener
4546

4647
if PY3:
4748
long = int
@@ -190,6 +191,114 @@ def test_max_time_ms(self):
190191
"maxTimeAlwaysTimeOut",
191192
mode="off")
192193

194+
@client_context.require_version_min(3, 1, 9, -1)
195+
def test_max_await_time_ms(self):
196+
db = self.db
197+
db.pymongo_test.drop()
198+
coll = db.create_collection("pymongo_test", capped=True, size=4096)
199+
200+
self.assertRaises(TypeError, coll.find().max_await_time_ms, 'foo')
201+
coll.insert_one({"amalia": 1})
202+
coll.insert_one({"amalia": 2})
203+
204+
coll.find().max_await_time_ms(None)
205+
coll.find().max_await_time_ms(long(1))
206+
207+
# When cursor is not tailable_await
208+
cursor = coll.find()
209+
self.assertEqual(None, cursor._Cursor__max_await_time_ms)
210+
cursor = coll.find().max_await_time_ms(99)
211+
self.assertEqual(None, cursor._Cursor__max_await_time_ms)
212+
213+
# If cursor is tailable_await and timeout is unset
214+
cursor = coll.find(cursor_type=CursorType.TAILABLE_AWAIT)
215+
self.assertEqual(None, cursor._Cursor__max_await_time_ms)
216+
217+
# If cursor is tailable_await and timeout is set
218+
cursor = coll.find(
219+
cursor_type=CursorType.TAILABLE_AWAIT).max_await_time_ms(99)
220+
self.assertEqual(99, cursor._Cursor__max_await_time_ms)
221+
222+
cursor = coll.find(
223+
cursor_type=CursorType.TAILABLE_AWAIT).max_await_time_ms(
224+
10).max_await_time_ms(90)
225+
self.assertEqual(90, cursor._Cursor__max_await_time_ms)
226+
227+
listener = EventListener()
228+
saved_listeners = monitoring._LISTENERS
229+
monitoring._LISTENERS = monitoring._Listeners([])
230+
coll = single_client(
231+
event_listeners=[listener])[self.db.name].pymongo_test
232+
results = listener.results
233+
234+
try:
235+
# Tailable_await defaults.
236+
list(coll.find(cursor_type=CursorType.TAILABLE_AWAIT))
237+
# find
238+
self.assertFalse('maxTimeMS' in results['started'][0].command)
239+
# getMore
240+
self.assertFalse('maxTimeMS' in results['started'][1].command)
241+
results.clear()
242+
243+
# Tailable_await with max_await_time_ms set.
244+
list(coll.find(
245+
cursor_type=CursorType.TAILABLE_AWAIT).max_await_time_ms(99))
246+
# find
247+
self.assertFalse('maxTimeMS' in results['started'][0].command)
248+
# getMore
249+
self.assertTrue('maxTimeMS' in results['started'][1].command)
250+
self.assertEqual(99, results['started'][1].command['maxTimeMS'])
251+
results.clear()
252+
253+
# Tailable_await with max_time_ms
254+
list(coll.find(
255+
cursor_type=CursorType.TAILABLE_AWAIT).max_time_ms(1))
256+
# find
257+
self.assertTrue('maxTimeMS' in results['started'][0].command)
258+
self.assertEqual(1, results['started'][0].command['maxTimeMS'])
259+
# getMore
260+
self.assertFalse('maxTimeMS' in results['started'][1].command)
261+
results.clear()
262+
263+
# Tailable_await with both max_time_ms and max_await_time_ms
264+
list(coll.find(
265+
cursor_type=CursorType.TAILABLE_AWAIT).max_time_ms(
266+
1).max_await_time_ms(99))
267+
# find
268+
self.assertTrue('maxTimeMS' in results['started'][0].command)
269+
self.assertEqual(1, results['started'][0].command['maxTimeMS'])
270+
# getMore
271+
self.assertTrue('maxTimeMS' in results['started'][1].command)
272+
self.assertEqual(99, results['started'][1].command['maxTimeMS'])
273+
results.clear()
274+
275+
# Non tailable_await with max_await_time_ms
276+
list(coll.find(batch_size=1).max_await_time_ms(99))
277+
# find
278+
self.assertFalse('maxTimeMS' in results['started'][0].command)
279+
# getMore
280+
self.assertFalse('maxTimeMS' in results['started'][1].command)
281+
results.clear()
282+
283+
# Non tailable_await with max_time_ms
284+
list(coll.find(batch_size=1).max_time_ms(99))
285+
# find
286+
self.assertTrue('maxTimeMS' in results['started'][0].command)
287+
self.assertEqual(99, results['started'][0].command['maxTimeMS'])
288+
# getMore
289+
self.assertFalse('maxTimeMS' in results['started'][1].command)
290+
291+
# Non tailable_await with both max_time_ms and max_await_time_ms
292+
list(coll.find(batch_size=1).max_time_ms(99).max_await_time_ms(88))
293+
# find
294+
self.assertTrue('maxTimeMS' in results['started'][0].command)
295+
self.assertEqual(99, results['started'][0].command['maxTimeMS'])
296+
# getMore
297+
self.assertFalse('maxTimeMS' in results['started'][1].command)
298+
299+
finally:
300+
monitoring._LISTENERS = saved_listeners
301+
193302
@client_context.require_version_min(2, 5, 3, -1)
194303
@client_context.require_test_commands
195304
def test_max_time_ms_getmore(self):

0 commit comments

Comments
 (0)