This repository was archived by the owner on Mar 24, 2021. It is now read-only.

Commit f001ccb

Merge pull request #513 from Parsely/testing/shared_consumer
stop producer tests from sharing a consumer instance
2 parents: 34ccb49 + 33d2f0a
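Before this change, every producer test consumed messages through a single SimpleConsumer created once in setUpClass, so a message left unread by one test could be fetched by the next, making the tests order-dependent. The diff below drops the shared instance in favor of a per-test `_get_consumer()` helper whose consumers start at the latest offset. A minimal sketch of that pattern, assuming a reachable broker (the address below is a placeholder; the real tests get their cluster from a test harness):

```python
from pykafka import KafkaClient
from pykafka.common import OffsetType

# Hypothetical connection details; the tests in this diff get brokers
# from their managed test cluster instead.
client = KafkaClient(hosts="127.0.0.1:9092")
topic = client.topics[b"test-data"]

# Each test builds its own consumer positioned at the *end* of the topic,
# so it only ever sees messages produced after this point.
consumer = topic.get_simple_consumer(
    consumer_timeout_ms=1000,             # consume() returns None after 1s idle
    auto_offset_reset=OffsetType.LATEST,  # start from the latest offset...
    reset_offset_on_start=True,           # ...even if offsets were held before
)
```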


tests/pykafka/test_producer.py

Lines changed: 45 additions & 35 deletions
```diff
@@ -7,6 +7,7 @@
 from uuid import uuid4
 
 from pykafka import KafkaClient
+from pykafka.common import OffsetType
 from pykafka.exceptions import MessageSizeTooLarge, ProducerQueueFullError
 from pykafka.partitioners import hashing_partitioner
 from pykafka.protocol import Message
@@ -26,27 +27,33 @@ def setUpClass(cls):
         cls.topic_name = b'test-data'
         cls.kafka.create_topic(cls.topic_name, 3, 2)
         cls.client = KafkaClient(cls.kafka.brokers, use_greenlets=cls.USE_GEVENT)
-        cls.consumer = cls.client.topics[cls.topic_name].get_simple_consumer(
-            consumer_timeout_ms=1000)
 
     @classmethod
     def tearDownClass(cls):
-        cls.consumer.stop()
         stop_cluster(cls.kafka)
 
     def _get_producer(self, **kwargs):
         topic = self.client.topics[self.topic_name]
         return topic.get_producer(use_rdkafka=self.USE_RDKAFKA, **kwargs)
 
+    def _get_consumer(self):
+        return self.client.topics[self.topic_name].get_simple_consumer(
+            consumer_timeout_ms=1000,
+            auto_offset_reset=OffsetType.LATEST,
+            reset_offset_on_start=True,
+        )
+
     def test_produce(self):
         # unique bytes, just to be absolutely sure we're not fetching data
         # produced in a previous test
         payload = uuid4().bytes
 
+        consumer = self._get_consumer()
+
         prod = self._get_producer(sync=True, min_queued_messages=1)
         prod.produce(payload)
 
-        message = self.consumer.consume()
+        message = consumer.consume()
         assert message.value == payload
 
     def test_sync_produce_raises(self):
@@ -60,41 +67,38 @@ def test_produce_hashing_partitioner(self):
         # produced in a previous test
         payload = uuid4().bytes
 
+        consumer = self._get_consumer()
+
         prod = self._get_producer(
             sync=True,
             min_queued_messages=1,
             partitioner=hashing_partitioner)
         prod.produce(payload, partition_key=b"dummy")
 
         # set a timeout so we don't wait forever if we break producer code
-        message = self.consumer.consume()
+        message = consumer.consume()
         assert message.value == payload
 
     def test_async_produce(self):
         payload = uuid4().bytes
 
+        consumer = self._get_consumer()
+
         prod = self._get_producer(min_queued_messages=1, delivery_reports=True)
         prod.produce(payload)
 
         report = prod.get_delivery_report()
         self.assertEqual(report[0].value, payload)
         self.assertIsNone(report[1])
 
-        message = self.consumer.consume()
+        message = consumer.consume()
         assert message.value == payload
 
     def test_recover_disconnected(self):
         """Test our retry-loop with a recoverable error"""
         payload = uuid4().bytes
         prod = self._get_producer(min_queued_messages=1, delivery_reports=True)
-
-        # We must stop the consumer for this test, to ensure that it is the
-        # producer that will encounter the disconnected brokers and initiate
-        # a cluster update
-        self.consumer.stop()
-        for t in self.consumer._fetch_workers:
-            t.join()
-        part_offsets = self.consumer.held_offsets
+        consumer = self._get_consumer()
 
         for broker in self.client.brokers.values():
             broker._connection.disconnect()
@@ -103,56 +107,55 @@ def test_recover_disconnected(self):
         report = prod.get_delivery_report()
         self.assertIsNone(report[1])
 
-        self.consumer.start()
-        self.consumer.reset_offsets(
-            # This is just a reset_offsets, but works around issue #216:
-            [(self.consumer.partitions[pid], offset if offset != -1 else -2)
-             for pid, offset in part_offsets.items()])
-        message = self.consumer.consume()
+        message = consumer.consume()
         self.assertEqual(message.value, payload)
 
     def test_async_produce_context(self):
         """Ensure that the producer works as a context manager"""
         payload = uuid4().bytes
 
+        consumer = self._get_consumer()
         with self._get_producer(min_queued_messages=1) as producer:
             producer.produce(payload)
 
-        message = self.consumer.consume()
+        message = consumer.consume()
         assert message.value == payload
 
     def test_async_produce_queue_full(self):
         """Ensure that the producer raises an error when its queue is full"""
+        consumer = self._get_consumer()
         with self._get_producer(block_on_queue_full=False,
                                 max_queued_messages=1,
                                 linger_ms=1000) as producer:
            with self.assertRaises(ProducerQueueFullError):
                 while True:
                     producer.produce(uuid4().bytes)
-        while self.consumer.consume() is not None:
+        while consumer.consume() is not None:
             time.sleep(.05)
 
     def test_async_produce_lingers(self):
         """Ensure that the context manager waits for linger_ms milliseconds"""
         linger = 3
+        consumer = self._get_consumer()
         with self._get_producer(linger_ms=linger * 1000) as producer:
             start = time.time()
             producer.produce(uuid4().bytes)
             producer.produce(uuid4().bytes)
         self.assertTrue(int(time.time() - start) >= int(linger))
-        self.consumer.consume()
-        self.consumer.consume()
+        consumer.consume()
+        consumer.consume()
 
     def test_async_produce_thread_exception(self):
         """Ensure that an exception on a worker thread is raised to the main thread"""
+        consumer = self._get_consumer()
         with self.assertRaises(AttributeError):
             with self._get_producer(min_queued_messages=1) as producer:
                 # get some dummy data into the queue that will cause a crash
                 # when flushed:
                 msg = Message("stuff", partition_id=0)
                 del msg.value
                 producer._produce(msg)
-        while self.consumer.consume() is not None:
+        while consumer.consume() is not None:
             time.sleep(.05)
 
     def test_required_acks(self):
@@ -171,13 +174,14 @@ def test_required_acks(self):
 
     def test_null_payloads(self):
         """Test that None is accepted as a null payload"""
+        consumer = self._get_consumer()
         prod = self._get_producer(sync=True, min_queued_messages=1)
         prod.produce(None)
-        self.assertIsNone(self.consumer.consume().value)
+        self.assertIsNone(consumer.consume().value)
         prod.produce(None, partition_key=b"whatever")
-        self.assertIsNone(self.consumer.consume().value)
+        self.assertIsNone(consumer.consume().value)
         prod.produce(b"")  # empty string should be distinguished from None
-        self.assertEqual(b"", self.consumer.consume().value)
+        self.assertEqual(b"", consumer.consume().value)
 
     def test_owned_broker_flush_message_larger_then_max_request_size(self):
         """Test that producer batches messages into the batches no larger then
@@ -239,18 +243,19 @@ def test_async_produce_compression_large_message(self):
         # TODO: make payload size bigger once pypy snappy compression issue is
         # fixed
         large_payload = b''.join([uuid4().bytes for i in range(5)])
+        consumer = self._get_consumer()
 
         prod = self._get_producer(
-                compression=CompressionType.SNAPPY,
-                delivery_reports=True
-                )
+            compression=CompressionType.SNAPPY,
+            delivery_reports=True
+        )
         prod.produce(large_payload)
 
         report = prod.get_delivery_report()
         self.assertEqual(report[0].value, large_payload)
         self.assertIsNone(report[1])
 
-        message = self.consumer.consume()
+        message = consumer.consume()
         assert message.value == large_payload
 
         for i in range(10):
@@ -259,6 +264,7 @@ def test_async_produce_compression_large_message(self):
         # use retry logic to loop over delivery reports and ensure we can
         # produce a group of large messages
         reports = []
+
         def ensure_all_messages_produced():
             report = prod.get_delivery_report()
             reports.append(report)
@@ -271,15 +277,17 @@ def ensure_all_messages_produced():
 
         # cleanup and consumer all messages
         msgs = []
+
         def ensure_all_messages_consumed():
-            msg = self.consumer.consume()
+            msg = consumer.consume()
             if msg:
                 msgs.append(msg)
             assert len(msgs) == 10
         retry(ensure_all_messages_consumed, retry_time=15)
 
     def test_async_produce_large_message(self):
 
+        consumer = self._get_consumer()
         large_payload = b''.join([uuid4().bytes for i in range(50000)])
         assert len(large_payload) / 1024 / 1024 < 1.0
 
@@ -290,7 +298,7 @@ def test_async_produce_large_message(self):
         self.assertEqual(report[0].value, large_payload)
         self.assertIsNone(report[1])
 
-        message = self.consumer.consume()
+        message = consumer.consume()
         assert message.value == large_payload
 
         for i in range(10):
@@ -299,6 +307,7 @@ def test_async_produce_large_message(self):
         # use retry logic to loop over delivery reports and ensure we can
         # produce a group of large messages
         reports = []
+
         def ensure_all_messages_produced():
             report = prod.get_delivery_report()
             reports.append(report)
@@ -311,8 +320,9 @@ def ensure_all_messages_produced():
 
         # cleanup and consumer all messages
         msgs = []
+
         def ensure_all_messages_consumed():
-            msg = self.consumer.consume()
+            msg = consumer.consume()
             if msg:
                 msgs.append(msg)
             assert len(msgs) == 10
```
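
Several of the tests above drain the topic on their way out (`while consumer.consume() is not None`). That loop terminates only because the helper sets `consumer_timeout_ms=1000`: `consume()` returns `None` after a second with no new messages instead of blocking forever. Continuing the sketch above:

```python
import time

# Drain anything left on the topic. consume() returns None once
# consumer_timeout_ms (1000 ms) elapses without a new message,
# which ends the loop.
while consumer.consume() is not None:
    time.sleep(.05)
```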
