Skip to content

Commit 9c47f3c

Browse files
committed
Cursor add_option/remove_option PYTHON-242
You can now set and unset arbitrary query flags. Set the tailable flag: cursor = db.coll.find().add_option(2) Now unset it: cursor.remove_option(2)
1 parent 0c45881 commit 9c47f3c

File tree

2 files changed

+78
-4
lines changed

2 files changed

+78
-4
lines changed

pymongo/cursor.py

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,7 @@ def __init__(self, collection, spec=None, fields=None, skip=0, limit=0,
112112
self.__tz_aware = collection.database.connection.tz_aware
113113
self.__must_use_master = _must_use_master
114114
self.__is_command = _is_command
115+
self.__query_flags = 0
115116

116117
self.__data = []
117118
self.__connection_id = None
@@ -174,6 +175,7 @@ def clone(self):
174175
copy.__partial = self.__partial
175176
copy.__must_use_master = self.__must_use_master
176177
copy.__is_command = self.__is_command
178+
copy.__query_flags = self.__query_flags
177179
copy.__kwargs = self.__kwargs
178180
return copy
179181

@@ -209,7 +211,7 @@ def __query_spec(self):
209211
def __query_options(self):
210212
"""Get the query options string to use for this query.
211213
"""
212-
options = 0
214+
options = self.__query_flags
213215
if self.__tailable:
214216
options |= _QUERY_OPTIONS["tailable_cursor"]
215217
if self.__slave_okay:
@@ -228,6 +230,32 @@ def __check_okay_to_chain(self):
228230
if self.__retrieved or self.__id is not None:
229231
raise InvalidOperation("cannot set options after executing query")
230232

233+
def add_option(self, mask):
234+
"""Set arbitary query flags using a bitmask.
235+
236+
To set the tailable flag:
237+
cursor.add_option(2)
238+
"""
239+
if not isinstance(mask, int):
240+
raise TypeError("mask must be an int")
241+
self.__check_okay_to_chain()
242+
243+
self.__query_flags |= mask
244+
return self
245+
246+
def remove_option(self, mask):
247+
"""Unset arbitrary query flags using a bitmask.
248+
249+
To unset the tailable flag:
250+
cursor.remove_option(2)
251+
"""
252+
if not isinstance(mask, int):
253+
raise TypeError("mask must be an int")
254+
self.__check_okay_to_chain()
255+
256+
self.__query_flags &= ~mask
257+
return self
258+
231259
def limit(self, limit):
232260
"""Limits the number of results to be returned by this cursor.
233261

test/test_cursor.py

Lines changed: 49 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -464,16 +464,62 @@ class MyClass(dict):
464464
slave_okay=True,
465465
await_data=True,
466466
partial=True).limit(2)
467+
cursor.add_option(64)
467468

468469
cursor2 = cursor.clone()
469470
self.assertEqual(cursor._Cursor__skip, cursor2._Cursor__skip)
471+
self.assertEqual(cursor._Cursor__limit, cursor2._Cursor__limit)
470472
self.assertEqual(cursor._Cursor__timeout, cursor2._Cursor__timeout)
471473
self.assertEqual(cursor._Cursor__snapshot, cursor2._Cursor__snapshot)
472474
self.assertEqual(cursor._Cursor__tailable, cursor2._Cursor__tailable)
473-
self.assertEqual(type(cursor._Cursor__as_class), type(cursor2._Cursor__as_class))
474-
self.assertEqual(cursor._Cursor__slave_okay, cursor2._Cursor__slave_okay)
475-
self.assertEqual(cursor._Cursor__await_data, cursor2._Cursor__await_data)
475+
self.assertEqual(type(cursor._Cursor__as_class),
476+
type(cursor2._Cursor__as_class))
477+
self.assertEqual(cursor._Cursor__slave_okay,
478+
cursor2._Cursor__slave_okay)
479+
self.assertEqual(cursor._Cursor__await_data,
480+
cursor2._Cursor__await_data)
476481
self.assertEqual(cursor._Cursor__partial, cursor2._Cursor__partial)
482+
self.assertEqual(cursor._Cursor__query_flags,
483+
cursor2._Cursor__query_flags)
484+
485+
def test_add_remove_option(self):
486+
cursor = self.db.test.find()
487+
self.assertEqual(0, cursor._Cursor__query_options())
488+
cursor.add_option(2)
489+
cursor2 = self.db.test.find(tailable=True)
490+
self.assertEqual(2, cursor2._Cursor__query_options())
491+
self.assertEqual(cursor._Cursor__query_options(),
492+
cursor2._Cursor__query_options())
493+
cursor.add_option(32)
494+
cursor2 = self.db.test.find(tailable=True, await_data=True)
495+
self.assertEqual(34, cursor2._Cursor__query_options())
496+
self.assertEqual(cursor._Cursor__query_options(),
497+
cursor2._Cursor__query_options())
498+
cursor.add_option(128)
499+
cursor2 = self.db.test.find(tailable=True,
500+
await_data=True).add_option(128)
501+
self.assertEqual(162, cursor2._Cursor__query_options())
502+
self.assertEqual(cursor._Cursor__query_options(),
503+
cursor2._Cursor__query_options())
504+
505+
self.assertEqual(162, cursor._Cursor__query_options())
506+
cursor.add_option(128)
507+
self.assertEqual(162, cursor._Cursor__query_options())
508+
509+
cursor.remove_option(128)
510+
cursor2 = self.db.test.find(tailable=True, await_data=True)
511+
self.assertEqual(34, cursor2._Cursor__query_options())
512+
self.assertEqual(cursor._Cursor__query_options(),
513+
cursor2._Cursor__query_options())
514+
cursor.remove_option(32)
515+
cursor2 = self.db.test.find(tailable=True)
516+
self.assertEqual(2, cursor2._Cursor__query_options())
517+
self.assertEqual(cursor._Cursor__query_options(),
518+
cursor2._Cursor__query_options())
519+
520+
self.assertEqual(2, cursor._Cursor__query_options())
521+
cursor.remove_option(32)
522+
self.assertEqual(2, cursor._Cursor__query_options())
477523

478524
def test_count_with_fields(self):
479525
self.db.test.drop()

0 commit comments

Comments
 (0)