
Commit 70574d1

KIP-74: Manage assigned partition order in consumer (#2562)
1 parent bd24486 commit 70574d1

3 files changed (+96, -75 lines)

kafka/consumer/fetcher.py  (+64, -52)

@@ -4,7 +4,6 @@
 import copy
 import itertools
 import logging
-import random
 import sys
 import time
 
@@ -57,7 +56,6 @@ class Fetcher(six.Iterator):
         'max_partition_fetch_bytes': 1048576,
         'max_poll_records': sys.maxsize,
         'check_crcs': True,
-        'iterator_refetch_records': 1, # undocumented -- interface may change
         'metric_group_prefix': 'consumer',
         'retry_backoff_ms': 100,
         'enable_incremental_fetch_sessions': True,
@@ -380,10 +378,13 @@ def _append(self, drained, part, max_records, update_offsets):
             # as long as the partition is still assigned
             position = self._subscriptions.assignment[tp].position
             if part.next_fetch_offset == position.offset:
-                part_records = part.take(max_records)
                 log.debug("Returning fetched records at offset %d for assigned"
                           " partition %s", position.offset, tp)
-                drained[tp].extend(part_records)
+                part_records = part.take(max_records)
+                # list.extend([]) is a noop, but because drained is a defaultdict
+                # we should avoid initializing the default list unless there are records
+                if part_records:
+                    drained[tp].extend(part_records)
                 # We want to increment subscription position if (1) we're using consumer.poll(),
                 # or (2) we didn't return any records (consumer iterator will update position
                 # when each message is yielded). There may be edge cases where we re-fetch records
@@ -562,13 +563,11 @@ def _handle_list_offsets_response(self, future, response):
     def _fetchable_partitions(self):
         fetchable = self._subscriptions.fetchable_partitions()
         # do not fetch a partition if we have a pending fetch response to process
+        discard = {fetch.topic_partition for fetch in self._completed_fetches}
         current = self._next_partition_records
-        pending = copy.copy(self._completed_fetches)
         if current:
-            fetchable.discard(current.topic_partition)
-        for fetch in pending:
-            fetchable.discard(fetch.topic_partition)
-        return fetchable
+            discard.add(current.topic_partition)
+        return [tp for tp in fetchable if tp not in discard]
 
     def _create_fetch_requests(self):
         """Create fetch requests for all assigned partitions, grouped by node.
@@ -581,7 +580,7 @@ def _create_fetch_requests(self):
         # create the fetch info as a dict of lists of partition info tuples
         # which can be passed to FetchRequest() via .items()
         version = self._client.api_version(FetchRequest, max_version=10)
-        fetchable = collections.defaultdict(dict)
+        fetchable = collections.defaultdict(collections.OrderedDict)
 
         for partition in self._fetchable_partitions():
             node_id = self._client.cluster.leader_for_partition(partition)
@@ -695,10 +694,7 @@ def _handle_fetch_response(self, node_id, fetch_offsets, send_time, response):
                               for partition_data in partitions])
         metric_aggregator = FetchResponseMetricAggregator(self._sensors, partitions)
 
-        # randomized ordering should improve balance for short-lived consumers
-        random.shuffle(response.topics)
         for topic, partitions in response.topics:
-            random.shuffle(partitions)
             for partition_data in partitions:
                 tp = TopicPartition(topic, partition_data[0])
                 fetch_offset = fetch_offsets[tp]
@@ -733,8 +729,6 @@ def _parse_fetched_data(self, completed_fetch):
                           " since it is no longer fetchable", tp)
 
         elif error_type is Errors.NoError:
-            self._subscriptions.assignment[tp].highwater = highwater
-
             # we are interested in this fetch only if the beginning
             # offset (of the *request*) matches the current consumed position
             # Note that the *response* may return a messageset that starts
@@ -748,30 +742,35 @@ def _parse_fetched_data(self, completed_fetch):
                 return None
 
             records = MemoryRecords(completed_fetch.partition_data[-1])
-            if records.has_next():
-                log.debug("Adding fetched record for partition %s with"
-                          " offset %d to buffered record list", tp,
-                          position.offset)
-                parsed_records = self.PartitionRecords(fetch_offset, tp, records,
-                                                       self.config['key_deserializer'],
-                                                       self.config['value_deserializer'],
-                                                       self.config['check_crcs'],
-                                                       completed_fetch.metric_aggregator)
-                return parsed_records
-            elif records.size_in_bytes() > 0:
-                # we did not read a single message from a non-empty
-                # buffer because that message's size is larger than
-                # fetch size, in this case record this exception
-                record_too_large_partitions = {tp: fetch_offset}
-                raise RecordTooLargeError(
-                    "There are some messages at [Partition=Offset]: %s "
-                    " whose size is larger than the fetch size %s"
-                    " and hence cannot be ever returned."
-                    " Increase the fetch size, or decrease the maximum message"
-                    " size the broker will allow." % (
-                        record_too_large_partitions,
-                        self.config['max_partition_fetch_bytes']),
-                    record_too_large_partitions)
+            log.debug("Preparing to read %s bytes of data for partition %s with offset %d",
+                      records.size_in_bytes(), tp, fetch_offset)
+            parsed_records = self.PartitionRecords(fetch_offset, tp, records,
+                                                   self.config['key_deserializer'],
+                                                   self.config['value_deserializer'],
+                                                   self.config['check_crcs'],
+                                                   completed_fetch.metric_aggregator,
+                                                   self._on_partition_records_drain)
+            if not records.has_next() and records.size_in_bytes() > 0:
+                if completed_fetch.response_version < 3:
+                    # Implement the pre KIP-74 behavior of throwing a RecordTooLargeException.
+                    record_too_large_partitions = {tp: fetch_offset}
+                    raise RecordTooLargeError(
+                        "There are some messages at [Partition=Offset]: %s "
+                        " whose size is larger than the fetch size %s"
+                        " and hence cannot be ever returned. Please consider upgrading your broker to 0.10.1.0 or"
+                        " newer to avoid this issue. Alternatively, increase the fetch size on the client (using"
+                        " max_partition_fetch_bytes)" % (
+                            record_too_large_partitions,
+                            self.config['max_partition_fetch_bytes']),
+                        record_too_large_partitions)
+                else:
+                    # This should not happen with brokers that support FetchRequest/Response V3 or higher (i.e. KIP-74)
+                    raise Errors.KafkaError("Failed to make progress reading messages at %s=%s."
+                                            " Received a non-empty fetch response from the server, but no"
+                                            " complete records were found." % (tp, fetch_offset))
+
+            if highwater >= 0:
+                self._subscriptions.assignment[tp].highwater = highwater
 
         elif error_type in (Errors.NotLeaderForPartitionError,
                             Errors.ReplicaNotAvailableError,
@@ -805,14 +804,25 @@ def _parse_fetched_data(self, completed_fetch):
         if parsed_records is None:
             completed_fetch.metric_aggregator.record(tp, 0, 0)
 
-        return None
+        if error_type is not Errors.NoError:
+            # we move the partition to the end if there was an error. This way, it's more likely that partitions for
+            # the same topic can remain together (allowing for more efficient serialization).
+            self._subscriptions.move_partition_to_end(tp)
+
+        return parsed_records
+
+    def _on_partition_records_drain(self, partition_records):
+        # we move the partition to the end if we received some bytes. This way, it's more likely that partitions
+        # for the same topic can remain together (allowing for more efficient serialization).
+        if partition_records.bytes_read > 0:
+            self._subscriptions.move_partition_to_end(partition_records.topic_partition)
 
     def close(self):
         if self._next_partition_records is not None:
             self._next_partition_records.drain()
 
     class PartitionRecords(object):
-        def __init__(self, fetch_offset, tp, records, key_deserializer, value_deserializer, check_crcs, metric_aggregator):
+        def __init__(self, fetch_offset, tp, records, key_deserializer, value_deserializer, check_crcs, metric_aggregator, on_drain):
             self.fetch_offset = fetch_offset
             self.topic_partition = tp
             self.leader_epoch = -1
@@ -824,6 +834,7 @@ def __init__(self, fetch_offset, tp, records, key_deserializer, value_deserializ
             self.record_iterator = itertools.dropwhile(
                 self._maybe_skip_record,
                 self._unpack_records(tp, records, key_deserializer, value_deserializer))
+            self.on_drain = on_drain
 
         def _maybe_skip_record(self, record):
             # When fetching an offset that is in the middle of a
@@ -845,6 +856,7 @@ def drain(self):
             if self.record_iterator is not None:
                 self.record_iterator = None
                 self.metric_aggregator.record(self.topic_partition, self.bytes_read, self.records_read)
+                self.on_drain(self)
 
         def take(self, n=None):
             return list(itertools.islice(self.record_iterator, 0, n))
@@ -943,6 +955,13 @@ def __init__(self, node_id):
         self.session_partitions = {}
 
     def build_next(self, next_partitions):
+        """
+        Arguments:
+            next_partitions (dict): TopicPartition -> TopicPartitionState
+
+        Returns:
+            FetchRequestData
+        """
         if self.next_metadata.is_full:
             log.debug("Built full fetch %s for node %s with %s partition(s).",
                       self.next_metadata, self.node_id, len(next_partitions))
@@ -965,8 +984,8 @@ def build_next(self, next_partitions):
                 altered.add(tp)
 
         log.debug("Built incremental fetch %s for node %s. Added %s, altered %s, removed %s out of %s",
-                 self.next_metadata, self.node_id, added, altered, removed, self.session_partitions.keys())
-        to_send = {tp: next_partitions[tp] for tp in (added | altered)}
+                  self.next_metadata, self.node_id, added, altered, removed, self.session_partitions.keys())
+        to_send = collections.OrderedDict({tp: next_partitions[tp] for tp in next_partitions if tp in (added | altered)})
         return FetchRequestData(to_send, removed, self.next_metadata)
 
     def handle_response(self, response):
@@ -1106,18 +1125,11 @@ def epoch(self):
     @property
     def to_send(self):
         # Return as list of [(topic, [(partition, ...), ...]), ...]
-        # so it an be passed directly to encoder
+        # so it can be passed directly to encoder
         partition_data = collections.defaultdict(list)
         for tp, partition_info in six.iteritems(self._to_send):
             partition_data[tp.topic].append(partition_info)
-        # As of version == 3 partitions will be returned in order as
-        # they are requested, so to avoid starvation with
-        # `fetch_max_bytes` option we need this shuffle
-        # NOTE: we do have partition_data in random order due to usage
-        # of unordered structures like dicts, but that does not
-        # guarantee equal distribution, and starting in Python3.6
-        # dicts retain insert order.
-        return random.sample(list(partition_data.items()), k=len(partition_data))
+        return list(partition_data.items())
 
     @property
     def to_forget(self):
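
The net effect of the fetcher changes above is that fetch order is no longer shuffled per request. Instead, a partition is rotated to the back of the assignment once its buffered records are drained (or it hits an error), so no partition is starved when fetch_max_bytes caps a response. A minimal sketch of that rotation, assuming only Python 3's collections.OrderedDict; TopicPartition here is a stand-in namedtuple, not the kafka-python type:

from collections import OrderedDict, namedtuple

# Illustrative stand-in for kafka.structs.TopicPartition.
TopicPartition = namedtuple('TopicPartition', ['topic', 'partition'])

# Assignment kept in insertion order, as subscription_state.py now does.
assignment = OrderedDict((TopicPartition('events', p), None) for p in range(3))

def on_partition_records_drain(tp):
    # Mirrors SubscriptionState.move_partition_to_end: once a partition has
    # returned data, push it to the back so the others get the next fetch.
    if tp in assignment:
        assignment.move_to_end(tp)

print([tp.partition for tp in assignment])   # fetch order: [0, 1, 2]
on_partition_records_drain(TopicPartition('events', 0))
print([tp.partition for tp in assignment])   # next fetch order: [1, 2, 0]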

kafka/consumer/subscription_state.py  (+27, -18)

@@ -1,11 +1,13 @@
 from __future__ import absolute_import
 
 import abc
+from collections import defaultdict, OrderedDict
 try:
     from collections import Sequence
 except ImportError:
     from collections.abc import Sequence
 import logging
+import random
 import re
 
 from kafka.vendor import six
@@ -68,7 +70,7 @@ def __init__(self, offset_reset_strategy='earliest'):
         self.subscribed_pattern = None # regex str or None
         self._group_subscription = set()
         self._user_assignment = set()
-        self.assignment = dict()
+        self.assignment = OrderedDict()
         self.listener = None
 
         # initialize to true for the consumers to fetch offset upon starting up
@@ -200,14 +202,8 @@ def assign_from_user(self, partitions):
 
         if self._user_assignment != set(partitions):
             self._user_assignment = set(partitions)
-
-            for partition in partitions:
-                if partition not in self.assignment:
-                    self._add_assigned_partition(partition)
-
-            for tp in set(self.assignment.keys()) - self._user_assignment:
-                del self.assignment[tp]
-
+            self._set_assignment({partition: self.assignment.get(partition, TopicPartitionState())
+                                  for partition in partitions})
             self.needs_fetch_committed_offsets = True
 
     def assign_from_subscribed(self, assignments):
@@ -229,13 +225,25 @@ def assign_from_subscribed(self, assignments):
             if tp.topic not in self.subscription:
                 raise ValueError("Assigned partition %s for non-subscribed topic." % (tp,))
 
-        # after rebalancing, we always reinitialize the assignment state
-        self.assignment.clear()
-        for tp in assignments:
-            self._add_assigned_partition(tp)
+        # after rebalancing, we always reinitialize the assignment value
+        # randomized ordering should improve balance for short-lived consumers
+        self._set_assignment({partition: TopicPartitionState() for partition in assignments}, randomize=True)
         self.needs_fetch_committed_offsets = True
         log.info("Updated partition assignment: %s", assignments)
 
+    def _set_assignment(self, partition_states, randomize=False):
+        """Batch partition assignment by topic (self.assignment is OrderedDict)"""
+        self.assignment.clear()
+        topics = [tp.topic for tp in six.iterkeys(partition_states)]
+        if randomize:
+            random.shuffle(topics)
+        topic_partitions = OrderedDict({topic: [] for topic in topics})
+        for tp in six.iterkeys(partition_states):
+            topic_partitions[tp.topic].append(tp)
+        for topic in six.iterkeys(topic_partitions):
+            for tp in topic_partitions[topic]:
+                self.assignment[tp] = partition_states[tp]
+
     def unsubscribe(self):
         """Clear all topic subscriptions and partition assignments"""
         self.subscription = None
@@ -283,11 +291,11 @@ def paused_partitions(self):
                    if self.is_paused(partition))
 
     def fetchable_partitions(self):
-        """Return set of TopicPartitions that should be Fetched."""
-        fetchable = set()
+        """Return ordered list of TopicPartitions that should be Fetched."""
+        fetchable = list()
         for partition, state in six.iteritems(self.assignment):
             if state.is_fetchable():
-                fetchable.add(partition)
+                fetchable.append(partition)
         return fetchable
 
     def partitions_auto_assigned(self):
@@ -348,8 +356,9 @@ def pause(self, partition):
     def resume(self, partition):
         self.assignment[partition].resume()
 
-    def _add_assigned_partition(self, partition):
-        self.assignment[partition] = TopicPartitionState()
+    def move_partition_to_end(self, partition):
+        if partition in self.assignment:
+            self.assignment.move_to_end(partition)
 
 
 class TopicPartitionState(object):
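
The new _set_assignment helper groups partitions by topic (optionally shuffling the topic order) before writing them into the OrderedDict, so partitions of the same topic stay adjacent in the assignment. A rough standalone sketch of that grouping, using plain (topic, partition) tuples rather than the real TopicPartition/TopicPartitionState types:

import random
from collections import OrderedDict

def group_partitions_by_topic(partitions, randomize=False):
    # partitions: iterable of (topic, partition) tuples -- an assumption for brevity.
    topics = [topic for topic, _ in partitions]
    if randomize:
        random.shuffle(topics)  # randomize topic order, as assign_from_subscribed now does
    grouped = OrderedDict((topic, []) for topic in topics)
    for topic, partition in partitions:
        grouped[topic].append((topic, partition))
    # Flatten so partitions of the same topic end up adjacent in the final order.
    return [tp for tps in grouped.values() for tp in tps]

print(group_partitions_by_topic([('a', 0), ('b', 0), ('a', 1), ('b', 1)]))
# [('a', 0), ('a', 1), ('b', 0), ('b', 1)]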

test/test_fetcher.py  (+5, -5)

@@ -451,7 +451,7 @@ def test__unpack_records(mocker):
         (None, b"c", None),
     ]
     memory_records = MemoryRecords(_build_record_batch(messages))
-    part_records = Fetcher.PartitionRecords(0, tp, memory_records, None, None, False, mocker.MagicMock())
+    part_records = Fetcher.PartitionRecords(0, tp, memory_records, None, None, False, mocker.MagicMock(), lambda x: None)
     records = list(part_records.record_iterator)
     assert len(records) == 3
     assert all(map(lambda x: isinstance(x, ConsumerRecord), records))
@@ -556,7 +556,7 @@ def test_partition_records_offset(mocker):
     tp = TopicPartition('foo', 0)
     messages = [(None, b'msg', None) for i in range(batch_start, batch_end)]
     memory_records = MemoryRecords(_build_record_batch(messages, offset=batch_start))
-    records = Fetcher.PartitionRecords(fetch_offset, tp, memory_records, None, None, False, mocker.MagicMock())
+    records = Fetcher.PartitionRecords(fetch_offset, tp, memory_records, None, None, False, mocker.MagicMock(), lambda x: None)
     assert records
     assert records.next_fetch_offset == fetch_offset
     msgs = records.take(1)
@@ -573,7 +573,7 @@ def test_partition_records_offset(mocker):
 def test_partition_records_empty(mocker):
     tp = TopicPartition('foo', 0)
     memory_records = MemoryRecords(_build_record_batch([]))
-    records = Fetcher.PartitionRecords(0, tp, memory_records, None, None, False, mocker.MagicMock())
+    records = Fetcher.PartitionRecords(0, tp, memory_records, None, None, False, mocker.MagicMock(), lambda x: None)
     msgs = records.take()
     assert len(msgs) == 0
     assert not records
@@ -586,7 +586,7 @@ def test_partition_records_no_fetch_offset(mocker):
     tp = TopicPartition('foo', 0)
     messages = [(None, b'msg', None) for i in range(batch_start, batch_end)]
     memory_records = MemoryRecords(_build_record_batch(messages, offset=batch_start))
-    records = Fetcher.PartitionRecords(fetch_offset, tp, memory_records, None, None, False, mocker.MagicMock())
+    records = Fetcher.PartitionRecords(fetch_offset, tp, memory_records, None, None, False, mocker.MagicMock(), lambda x: None)
     msgs = records.take()
     assert len(msgs) == 0
     assert not records
@@ -610,7 +610,7 @@ def test_partition_records_compacted_offset(mocker):
     builder.append(key=None, value=b'msg', timestamp=None, headers=[])
     builder.close()
     memory_records = MemoryRecords(builder.buffer())
-    records = Fetcher.PartitionRecords(fetch_offset, tp, memory_records, None, None, False, mocker.MagicMock())
+    records = Fetcher.PartitionRecords(fetch_offset, tp, memory_records, None, None, False, mocker.MagicMock(), lambda x: None)
     msgs = records.take()
     assert len(msgs) == batch_end - fetch_offset - 1
     assert msgs[0].offset == fetch_offset + 1
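
The test updates above only thread a no-op callback (lambda x: None) through the new on_drain parameter. A toy illustration, not kafka-python code, of the contract that callback is written against: drain() invokes the hook once with the records object itself, which is how the fetcher learns when to rotate the partition.

class FakePartitionRecords(object):
    # Illustrative stand-in with the same drain/on_drain shape as
    # Fetcher.PartitionRecords; names here are assumptions, not library API.
    def __init__(self, topic_partition, on_drain):
        self.topic_partition = topic_partition
        self.bytes_read = 0
        self.on_drain = on_drain

    def drain(self):
        # Called when the buffered records are emptied; notify the owner.
        self.on_drain(self)

drained = []
records = FakePartitionRecords(('foo', 0), lambda recs: drained.append(recs.topic_partition))
records.drain()
print(drained)   # [('foo', 0)]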
