Skip to content

Commit 08ffa4f

Browse files
author
Mike Dirolf
committed
add batch_size() to Cursor PYTHON-161
1 parent 64f9641 commit 08ffa4f

File tree

2 files changed

+84
-18
lines changed

2 files changed

+84
-18
lines changed

pymongo/cursor.py

Lines changed: 33 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,7 @@ def __init__(self, collection, spec=None, fields=None, skip=0, limit=0,
7979
self.__fields = fields
8080
self.__skip = skip
8181
self.__limit = limit
82+
self.__batch_size = 0
8283

8384
# This is ugly. People want to be able to do cursor[5:5] and
8485
# get an empty result set (old behavior was an
@@ -153,6 +154,7 @@ def clone(self):
153154
copy.__ordering = self.__ordering
154155
copy.__explain = self.__explain
155156
copy.__hint = self.__hint
157+
copy.__batch_size = self.__batch_size
156158
return copy
157159

158160
def __die(self):
@@ -223,6 +225,30 @@ def limit(self, limit):
223225
self.__limit = limit
224226
return self
225227

228+
def batch_size(self, batch_size):
229+
"""Set the size for batches of results returned by this cursor.
230+
231+
Raises :class:`TypeError` if `batch_size` is not an instance
232+
of :class:`int`. Raises :class:`ValueError` if `batch_size` is
233+
less than ``0``. Raises
234+
:class:`~pymongo.errors.InvalidOperation` if this
235+
:class:`Cursor` has already been used. The last `batch_size`
236+
applied to this cursor takes precedence.
237+
238+
:Parameters:
239+
- `batch_size`: The size of each batch of results requested.
240+
241+
.. versionadded:: 1.8.1+
242+
"""
243+
if not isinstance(batch_size, int):
244+
raise TypeError("batch_size must be an int")
245+
if batch_size < 0:
246+
raise ValueError("batch_size must be >= 0")
247+
self.__check_okay_to_chain()
248+
249+
self.__batch_size = batch_size == 1 and 2 or batch_size
250+
return self
251+
226252
def skip(self, skip):
227253
"""Skips the first `skip` results of this cursor.
228254
@@ -530,24 +556,21 @@ def _refresh(self):
530556
if len(self.__data) or self.__killed:
531557
return len(self.__data)
532558

533-
if self.__id is None:
534-
# Query
559+
if self.__id is None: # Query
535560
self.__send_message(
536561
message.query(self.__query_options(),
537562
self.__collection.full_name,
538563
self.__skip, self.__limit,
539564
self.__query_spec(), self.__fields))
540565
if not self.__id:
541566
self.__killed = True
542-
elif self.__id:
543-
# Get More
544-
limit = 0
567+
elif self.__id: # Get More
545568
if self.__limit:
546-
if self.__limit > self.__retrieved:
547-
limit = self.__limit - self.__retrieved
548-
else:
549-
self.__killed = True
550-
return 0
569+
limit = self.__limit - self.__retrieved
570+
if self.__batch_size:
571+
limit = min(limit, self.__batch_size)
572+
else:
573+
limit = self.__batch_size
551574

552575
self.__send_message(
553576
message.get_more(self.__collection.full_name,

test/test_cursor.py

Lines changed: 51 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -52,8 +52,7 @@ def test_explain(self):
5252
def test_hint(self):
5353
db = self.db
5454
self.assertRaises(TypeError, db.test.find().hint, 5.5)
55-
db.test.remove({})
56-
db.test.drop_indexes()
55+
db.test.drop()
5756

5857
for i in range(100):
5958
db.test.insert({"num": i, "foo": i})
@@ -94,7 +93,7 @@ def test_limit(self):
9493
self.assertRaises(TypeError, db.test.find().limit, "hello")
9594
self.assertRaises(TypeError, db.test.find().limit, 5.5)
9695

97-
db.test.remove({})
96+
db.test.drop()
9897
for i in range(100):
9998
db.test.save({"x": i})
10099

@@ -134,6 +133,50 @@ def test_limit(self):
134133
break
135134
self.assertRaises(InvalidOperation, a.limit, 5)
136135

136+
137+
def test_batch_size(self):
138+
db = self.db
139+
db.test.drop()
140+
for x in range(200):
141+
db.test.save({"x": x})
142+
143+
self.assertRaises(TypeError, db.test.find().batch_size, None)
144+
self.assertRaises(TypeError, db.test.find().batch_size, "hello")
145+
self.assertRaises(TypeError, db.test.find().batch_size, 5.5)
146+
self.assertRaises(ValueError, db.test.find().batch_size, -1)
147+
a = db.test.find()
148+
for _ in a:
149+
break
150+
self.assertRaises(InvalidOperation, a.batch_size, 5)
151+
152+
def cursor_count(cursor, expected_count):
153+
count = 0
154+
for _ in cursor:
155+
count += 1
156+
self.assertEqual(expected_count, count)
157+
158+
cursor_count(db.test.find().batch_size(0), 200)
159+
cursor_count(db.test.find().batch_size(1), 200)
160+
cursor_count(db.test.find().batch_size(2), 200)
161+
cursor_count(db.test.find().batch_size(5), 200)
162+
cursor_count(db.test.find().batch_size(100), 200)
163+
cursor_count(db.test.find().batch_size(500), 200)
164+
165+
cursor_count(db.test.find().batch_size(0).limit(1), 1)
166+
cursor_count(db.test.find().batch_size(1).limit(1), 1)
167+
cursor_count(db.test.find().batch_size(2).limit(1), 1)
168+
cursor_count(db.test.find().batch_size(5).limit(1), 1)
169+
cursor_count(db.test.find().batch_size(100).limit(1), 1)
170+
cursor_count(db.test.find().batch_size(500).limit(1), 1)
171+
172+
cursor_count(db.test.find().batch_size(0).limit(10), 10)
173+
cursor_count(db.test.find().batch_size(1).limit(10), 10)
174+
cursor_count(db.test.find().batch_size(2).limit(10), 10)
175+
cursor_count(db.test.find().batch_size(5).limit(10), 10)
176+
cursor_count(db.test.find().batch_size(100).limit(10), 10)
177+
cursor_count(db.test.find().batch_size(500).limit(10), 10)
178+
179+
137180
def test_skip(self):
138181
db = self.db
139182

@@ -189,7 +232,7 @@ def test_sort(self):
189232
[("hello", DESCENDING)], DESCENDING)
190233
self.assertRaises(TypeError, db.test.find().sort, "hello", "world")
191234

192-
db.test.remove({})
235+
db.test.drop()
193236

194237
unsort = range(10)
195238
random.shuffle(unsort)
@@ -218,7 +261,7 @@ def test_sort(self):
218261
shuffled = list(expected)
219262
random.shuffle(shuffled)
220263

221-
db.test.remove({})
264+
db.test.drop()
222265
for (a, b) in shuffled:
223266
db.test.save({"a": a, "b": b})
224267

@@ -235,7 +278,7 @@ def test_sort(self):
235278

236279
def test_count(self):
237280
db = self.db
238-
db.test.remove({})
281+
db.test.drop()
239282

240283
self.assertEqual(0, db.test.find().count())
241284

@@ -260,7 +303,7 @@ def test_count(self):
260303

261304
def test_where(self):
262305
db = self.db
263-
db.test.remove({})
306+
db.test.drop()
264307

265308
a = db.test.find()
266309
self.assertRaises(TypeError, a.where, 5)
@@ -405,7 +448,7 @@ def test_clone(self):
405448
self.assertNotEqual(cursor, cursor.clone())
406449

407450
def test_count_with_fields(self):
408-
self.db.test.remove({})
451+
self.db.test.drop()
409452
self.db.test.save({"x": 1})
410453

411454
if not version.at_least(self.db.connection, (1, 1, 3, -1)):

0 commit comments

Comments
 (0)