1
1
from __future__ import absolute_import
2
2
3
- from collections import defaultdict
4
3
from itertools import izip_longest , repeat
5
4
import logging
6
5
import time
@@ -235,6 +234,12 @@ def __init__(self, client, group, topic, auto_commit=True, partitions=None,
235
234
buffer_size = FETCH_BUFFER_SIZE_BYTES ,
236
235
max_buffer_size = MAX_FETCH_BUFFER_SIZE_BYTES ,
237
236
iter_timeout = None ):
237
+ super (SimpleConsumer , self ).__init__ (
238
+ client , group , topic ,
239
+ partitions = partitions ,
240
+ auto_commit = auto_commit ,
241
+ auto_commit_every_n = auto_commit_every_n ,
242
+ auto_commit_every_t = auto_commit_every_t )
238
243
239
244
if max_buffer_size is not None and buffer_size > max_buffer_size :
240
245
raise ValueError ("buffer_size (%d) is greater than "
@@ -245,17 +250,10 @@ def __init__(self, client, group, topic, auto_commit=True, partitions=None,
245
250
self .partition_info = False # Do not return partition info in msgs
246
251
self .fetch_max_wait_time = FETCH_MAX_WAIT_TIME
247
252
self .fetch_min_bytes = fetch_size_bytes
248
- self .fetch_started = defaultdict ( bool ) # defaults to false
253
+ self .fetch_offsets = self . offsets . copy ()
249
254
self .iter_timeout = iter_timeout
250
255
self .queue = Queue ()
251
256
252
- super (SimpleConsumer , self ).__init__ (
253
- client , group , topic ,
254
- partitions = partitions ,
255
- auto_commit = auto_commit ,
256
- auto_commit_every_n = auto_commit_every_n ,
257
- auto_commit_every_t = auto_commit_every_t )
258
-
259
257
def __repr__ (self ):
260
258
return '<SimpleConsumer group=%s, topic=%s, partitions=%s>' % \
261
259
(self .group , self .topic , str (self .offsets .keys ()))
@@ -305,6 +303,10 @@ def seek(self, offset, whence):
305
303
else :
306
304
raise ValueError ("Unexpected value for `whence`, %d" % whence )
307
305
306
+ # Reset queue and fetch offsets since they are invalid
307
+ self .fetch_offsets = self .offsets .copy ()
308
+ self .queue = Queue ()
309
+
308
310
def get_messages (self , count = 1 , block = True , timeout = 0.1 ):
309
311
"""
310
312
Fetch the specified number of messages
@@ -316,33 +318,69 @@ def get_messages(self, count=1, block=True, timeout=0.1):
316
318
it will block forever.
317
319
"""
318
320
messages = []
319
- if timeout :
321
+ if timeout is not None :
320
322
max_time = time .time () + timeout
321
323
324
+ new_offsets = {}
322
325
while count > 0 and (timeout is None or timeout > 0 ):
323
- message = self .get_message (block , timeout )
324
- if message :
325
- messages .append (message )
326
+ result = self ._get_message (block , timeout , get_partition_info = True ,
327
+ update_offset = False )
328
+ if result :
329
+ partition , message = result
330
+ if self .partition_info :
331
+ messages .append (result )
332
+ else :
333
+ messages .append (message )
334
+ new_offsets [partition ] = message .offset + 1
326
335
count -= 1
327
336
else :
328
337
# Ran out of messages for the last request.
329
338
if not block :
330
339
# If we're not blocking, break.
331
340
break
332
- if timeout :
341
+ if timeout is not None :
333
342
# If we're blocking and have a timeout, reduce it to the
334
343
# appropriate value
335
344
timeout = max_time - time .time ()
336
345
346
+ # Update and commit offsets if necessary
347
+ self .offsets .update (new_offsets )
348
+ self .count_since_commit += len (messages )
349
+ self ._auto_commit ()
337
350
return messages
338
351
339
- def get_message (self , block = True , timeout = 0.1 ):
352
+ def get_message (self , block = True , timeout = 0.1 , get_partition_info = None ):
353
+ return self ._get_message (block , timeout , get_partition_info )
354
+
355
+ def _get_message (self , block = True , timeout = 0.1 , get_partition_info = None ,
356
+ update_offset = True ):
357
+ """
358
+ If no messages can be fetched, returns None.
359
+ If get_partition_info is None, it defaults to self.partition_info
360
+ If get_partition_info is True, returns (partition, message)
361
+ If get_partition_info is False, returns message
362
+ """
340
363
if self .queue .empty ():
341
364
# We're out of messages, go grab some more.
342
365
with FetchContext (self , block , timeout ):
343
366
self ._fetch ()
344
367
try :
345
- return self .queue .get_nowait ()
368
+ partition , message = self .queue .get_nowait ()
369
+
370
+ if update_offset :
371
+ # Update partition offset
372
+ self .offsets [partition ] = message .offset + 1
373
+
374
+ # Count, check and commit messages if necessary
375
+ self .count_since_commit += 1
376
+ self ._auto_commit ()
377
+
378
+ if get_partition_info is None :
379
+ get_partition_info = self .partition_info
380
+ if get_partition_info :
381
+ return partition , message
382
+ else :
383
+ return message
346
384
except Empty :
347
385
return None
348
386
@@ -367,11 +405,11 @@ def __iter__(self):
367
405
def _fetch (self ):
368
406
# Create fetch request payloads for all the partitions
369
407
requests = []
370
- partitions = self .offsets .keys ()
408
+ partitions = self .fetch_offsets .keys ()
371
409
while partitions :
372
410
for partition in partitions :
373
411
requests .append (FetchRequest (self .topic , partition ,
374
- self .offsets [partition ],
412
+ self .fetch_offsets [partition ],
375
413
self .buffer_size ))
376
414
# Send request
377
415
responses = self .client .send_fetch_request (
@@ -384,18 +422,9 @@ def _fetch(self):
384
422
partition = resp .partition
385
423
try :
386
424
for message in resp .messages :
387
- # Update partition offset
388
- self .offsets [partition ] = message .offset + 1
389
-
390
- # Count, check and commit messages if necessary
391
- self .count_since_commit += 1
392
- self ._auto_commit ()
393
-
394
425
# Put the message in our queue
395
- if self .partition_info :
396
- self .queue .put ((partition , message ))
397
- else :
398
- self .queue .put (message )
426
+ self .queue .put ((partition , message ))
427
+ self .fetch_offsets [partition ] = message .offset + 1
399
428
except ConsumerFetchSizeTooSmall , e :
400
429
if (self .max_buffer_size is not None and
401
430
self .buffer_size == self .max_buffer_size ):
@@ -585,12 +614,11 @@ def __iter__(self):
585
614
break
586
615
587
616
# Count, check and commit messages if necessary
588
- self .offsets [partition ] = message .offset
617
+ self .offsets [partition ] = message .offset + 1
589
618
self .start .clear ()
590
- yield message
591
-
592
619
self .count_since_commit += 1
593
620
self ._auto_commit ()
621
+ yield message
594
622
595
623
self .start .clear ()
596
624
@@ -613,9 +641,10 @@ def get_messages(self, count=1, block=True, timeout=10):
613
641
self .size .value = count
614
642
self .pause .clear ()
615
643
616
- if timeout :
644
+ if timeout is not None :
617
645
max_time = time .time () + timeout
618
646
647
+ new_offsets = {}
619
648
while count > 0 and (timeout is None or timeout > 0 ):
620
649
# Trigger consumption only if the queue is empty
621
650
# By doing this, we will ensure that consumers do not
@@ -630,16 +659,18 @@ def get_messages(self, count=1, block=True, timeout=10):
630
659
break
631
660
632
661
messages .append (message )
633
-
634
- # Count, check and commit messages if necessary
635
- self .offsets [partition ] = message .offset
636
- self .count_since_commit += 1
637
- self ._auto_commit ()
662
+ new_offsets [partition ] = message .offset + 1
638
663
count -= 1
639
- timeout = max_time - time .time ()
664
+ if timeout is not None :
665
+ timeout = max_time - time .time ()
640
666
641
667
self .size .value = 0
642
668
self .start .clear ()
643
669
self .pause .set ()
644
670
671
+ # Update and commit offsets if necessary
672
+ self .offsets .update (new_offsets )
673
+ self .count_since_commit += len (messages )
674
+ self ._auto_commit ()
675
+
645
676
return messages
0 commit comments