77from uuid import uuid4
88
99from pykafka import KafkaClient
10+ from pykafka .common import OffsetType
1011from pykafka .exceptions import MessageSizeTooLarge , ProducerQueueFullError
1112from pykafka .partitioners import hashing_partitioner
1213from pykafka .protocol import Message
@@ -26,27 +27,33 @@ def setUpClass(cls):
2627 cls .topic_name = b'test-data'
2728 cls .kafka .create_topic (cls .topic_name , 3 , 2 )
2829 cls .client = KafkaClient (cls .kafka .brokers , use_greenlets = cls .USE_GEVENT )
29- cls .consumer = cls .client .topics [cls .topic_name ].get_simple_consumer (
30- consumer_timeout_ms = 1000 )
3130
3231 @classmethod
3332 def tearDownClass (cls ):
34- cls .consumer .stop ()
3533 stop_cluster (cls .kafka )
3634
3735 def _get_producer (self , ** kwargs ):
3836 topic = self .client .topics [self .topic_name ]
3937 return topic .get_producer (use_rdkafka = self .USE_RDKAFKA , ** kwargs )
4038
39+ def _get_consumer (self ):
40+ return self .client .topics [self .topic_name ].get_simple_consumer (
41+ consumer_timeout_ms = 1000 ,
42+ auto_offset_reset = OffsetType .LATEST ,
43+ reset_offset_on_start = True ,
44+ )
45+
4146 def test_produce (self ):
4247 # unique bytes, just to be absolutely sure we're not fetching data
4348 # produced in a previous test
4449 payload = uuid4 ().bytes
4550
51+ consumer = self ._get_consumer ()
52+
4653 prod = self ._get_producer (sync = True , min_queued_messages = 1 )
4754 prod .produce (payload )
4855
49- message = self . consumer .consume ()
56+ message = consumer .consume ()
5057 assert message .value == payload
5158
5259 def test_sync_produce_raises (self ):
@@ -60,41 +67,38 @@ def test_produce_hashing_partitioner(self):
6067 # produced in a previous test
6168 payload = uuid4 ().bytes
6269
70+ consumer = self ._get_consumer ()
71+
6372 prod = self ._get_producer (
6473 sync = True ,
6574 min_queued_messages = 1 ,
6675 partitioner = hashing_partitioner )
6776 prod .produce (payload , partition_key = b"dummy" )
6877
6978 # set a timeout so we don't wait forever if we break producer code
70- message = self . consumer .consume ()
79+ message = consumer .consume ()
7180 assert message .value == payload
7281
7382 def test_async_produce (self ):
7483 payload = uuid4 ().bytes
7584
85+ consumer = self ._get_consumer ()
86+
7687 prod = self ._get_producer (min_queued_messages = 1 , delivery_reports = True )
7788 prod .produce (payload )
7889
7990 report = prod .get_delivery_report ()
8091 self .assertEqual (report [0 ].value , payload )
8192 self .assertIsNone (report [1 ])
8293
83- message = self . consumer .consume ()
94+ message = consumer .consume ()
8495 assert message .value == payload
8596
8697 def test_recover_disconnected (self ):
8798 """Test our retry-loop with a recoverable error"""
8899 payload = uuid4 ().bytes
89100 prod = self ._get_producer (min_queued_messages = 1 , delivery_reports = True )
90-
91- # We must stop the consumer for this test, to ensure that it is the
92- # producer that will encounter the disconnected brokers and initiate
93- # a cluster update
94- self .consumer .stop ()
95- for t in self .consumer ._fetch_workers :
96- t .join ()
97- part_offsets = self .consumer .held_offsets
101+ consumer = self ._get_consumer ()
98102
99103 for broker in self .client .brokers .values ():
100104 broker ._connection .disconnect ()
@@ -103,56 +107,55 @@ def test_recover_disconnected(self):
103107 report = prod .get_delivery_report ()
104108 self .assertIsNone (report [1 ])
105109
106- self .consumer .start ()
107- self .consumer .reset_offsets (
108- # This is just a reset_offsets, but works around issue #216:
109- [(self .consumer .partitions [pid ], offset if offset != - 1 else - 2 )
110- for pid , offset in part_offsets .items ()])
111- message = self .consumer .consume ()
110+ message = consumer .consume ()
112111 self .assertEqual (message .value , payload )
113112
114113 def test_async_produce_context (self ):
115114 """Ensure that the producer works as a context manager"""
116115 payload = uuid4 ().bytes
117116
117+ consumer = self ._get_consumer ()
118118 with self ._get_producer (min_queued_messages = 1 ) as producer :
119119 producer .produce (payload )
120120
121- message = self . consumer .consume ()
121+ message = consumer .consume ()
122122 assert message .value == payload
123123
124124 def test_async_produce_queue_full (self ):
125125 """Ensure that the producer raises an error when its queue is full"""
126+ consumer = self ._get_consumer ()
126127 with self ._get_producer (block_on_queue_full = False ,
127128 max_queued_messages = 1 ,
128129 linger_ms = 1000 ) as producer :
129130 with self .assertRaises (ProducerQueueFullError ):
130131 while True :
131132 producer .produce (uuid4 ().bytes )
132- while self . consumer .consume () is not None :
133+ while consumer .consume () is not None :
133134 time .sleep (.05 )
134135
135136 def test_async_produce_lingers (self ):
136137 """Ensure that the context manager waits for linger_ms milliseconds"""
137138 linger = 3
139+ consumer = self ._get_consumer ()
138140 with self ._get_producer (linger_ms = linger * 1000 ) as producer :
139141 start = time .time ()
140142 producer .produce (uuid4 ().bytes )
141143 producer .produce (uuid4 ().bytes )
142144 self .assertTrue (int (time .time () - start ) >= int (linger ))
143- self . consumer .consume ()
144- self . consumer .consume ()
145+ consumer .consume ()
146+ consumer .consume ()
145147
146148 def test_async_produce_thread_exception (self ):
147149 """Ensure that an exception on a worker thread is raised to the main thread"""
150+ consumer = self ._get_consumer ()
148151 with self .assertRaises (AttributeError ):
149152 with self ._get_producer (min_queued_messages = 1 ) as producer :
150153 # get some dummy data into the queue that will cause a crash
151154 # when flushed:
152155 msg = Message ("stuff" , partition_id = 0 )
153156 del msg .value
154157 producer ._produce (msg )
155- while self . consumer .consume () is not None :
158+ while consumer .consume () is not None :
156159 time .sleep (.05 )
157160
158161 def test_required_acks (self ):
@@ -171,13 +174,14 @@ def test_required_acks(self):
171174
172175 def test_null_payloads (self ):
173176 """Test that None is accepted as a null payload"""
177+ consumer = self ._get_consumer ()
174178 prod = self ._get_producer (sync = True , min_queued_messages = 1 )
175179 prod .produce (None )
176- self .assertIsNone (self . consumer .consume ().value )
180+ self .assertIsNone (consumer .consume ().value )
177181 prod .produce (None , partition_key = b"whatever" )
178- self .assertIsNone (self . consumer .consume ().value )
182+ self .assertIsNone (consumer .consume ().value )
179183 prod .produce (b"" ) # empty string should be distinguished from None
180- self .assertEqual (b"" , self . consumer .consume ().value )
184+ self .assertEqual (b"" , consumer .consume ().value )
181185
182186 def test_owned_broker_flush_message_larger_then_max_request_size (self ):
183187 """Test that producer batches messages into the batches no larger then
@@ -239,18 +243,19 @@ def test_async_produce_compression_large_message(self):
239243 # TODO: make payload size bigger once pypy snappy compression issue is
240244 # fixed
241245 large_payload = b'' .join ([uuid4 ().bytes for i in range (5 )])
246+ consumer = self ._get_consumer ()
242247
243248 prod = self ._get_producer (
244- compression = CompressionType .SNAPPY ,
245- delivery_reports = True
246- )
249+ compression = CompressionType .SNAPPY ,
250+ delivery_reports = True
251+ )
247252 prod .produce (large_payload )
248253
249254 report = prod .get_delivery_report ()
250255 self .assertEqual (report [0 ].value , large_payload )
251256 self .assertIsNone (report [1 ])
252257
253- message = self . consumer .consume ()
258+ message = consumer .consume ()
254259 assert message .value == large_payload
255260
256261 for i in range (10 ):
@@ -259,6 +264,7 @@ def test_async_produce_compression_large_message(self):
259264 # use retry logic to loop over delivery reports and ensure we can
260265 # produce a group of large messages
261266 reports = []
267+
262268 def ensure_all_messages_produced ():
263269 report = prod .get_delivery_report ()
264270 reports .append (report )
@@ -271,15 +277,17 @@ def ensure_all_messages_produced():
271277
272278 # cleanup and consumer all messages
273279 msgs = []
280+
274281 def ensure_all_messages_consumed ():
275- msg = self . consumer .consume ()
282+ msg = consumer .consume ()
276283 if msg :
277284 msgs .append (msg )
278285 assert len (msgs ) == 10
279286 retry (ensure_all_messages_consumed , retry_time = 15 )
280287
281288 def test_async_produce_large_message (self ):
282289
290+ consumer = self ._get_consumer ()
283291 large_payload = b'' .join ([uuid4 ().bytes for i in range (50000 )])
284292 assert len (large_payload ) / 1024 / 1024 < 1.0
285293
@@ -290,7 +298,7 @@ def test_async_produce_large_message(self):
290298 self .assertEqual (report [0 ].value , large_payload )
291299 self .assertIsNone (report [1 ])
292300
293- message = self . consumer .consume ()
301+ message = consumer .consume ()
294302 assert message .value == large_payload
295303
296304 for i in range (10 ):
@@ -299,6 +307,7 @@ def test_async_produce_large_message(self):
299307 # use retry logic to loop over delivery reports and ensure we can
300308 # produce a group of large messages
301309 reports = []
310+
302311 def ensure_all_messages_produced ():
303312 report = prod .get_delivery_report ()
304313 reports .append (report )
@@ -311,8 +320,9 @@ def ensure_all_messages_produced():
311320
312321 # cleanup and consumer all messages
313322 msgs = []
323+
314324 def ensure_all_messages_consumed ():
315- msg = self . consumer .consume ()
325+ msg = consumer .consume ()
316326 if msg :
317327 msgs .append (msg )
318328 assert len (msgs ) == 10
0 commit comments