Skip to content

Commit 6c2c25d

Browse files
authored
Use SubscriptionType to track topics/pattern/user assignment (#2565)
1 parent 70574d1 commit 6c2c25d

File tree

3 files changed

+36
-12
lines changed

3 files changed

+36
-12
lines changed

kafka/consumer/subscription_state.py

+33-10
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,12 @@
66
from collections import Sequence
77
except ImportError:
88
from collections.abc import Sequence
9+
try:
10+
# enum in stdlib as of py3.4
11+
from enum import IntEnum # pylint: disable=import-error
12+
except ImportError:
13+
# vendored backport module
14+
from kafka.vendor.enum34 import IntEnum
915
import logging
1016
import random
1117
import re
@@ -20,6 +26,13 @@
2026
log = logging.getLogger(__name__)
2127

2228

29+
class SubscriptionType(IntEnum):
    """Enumerates how a consumer's partitions are being selected.

    SubscriptionState stores one of these values in its
    ``subscription_type`` attribute so that topic subscription, pattern
    subscription and manual assignment can be told apart.
    """

    # Initial state; also restored by unsubscribe().
    NONE = 0
    # Set by subscribe() when an explicit list of topics is given.
    AUTO_TOPICS = 1
    # Set by subscribe() when a regex pattern is given.
    AUTO_PATTERN = 2
    # Set by assign_from_user() for manually chosen partitions.
    USER_ASSIGNED = 3
34+
35+
2336
class SubscriptionState(object):
2437
"""
2538
A class for tracking the topics, partitions, and offsets for the consumer.
@@ -67,6 +80,7 @@ def __init__(self, offset_reset_strategy='earliest'):
6780
self._default_offset_reset_strategy = offset_reset_strategy
6881

6982
self.subscription = None # set() or None
83+
self.subscription_type = SubscriptionType.NONE
7084
self.subscribed_pattern = None # regex str or None
7185
self._group_subscription = set()
7286
self._user_assignment = set()
@@ -76,6 +90,14 @@ def __init__(self, offset_reset_strategy='earliest'):
7690
# initialize to true for the consumers to fetch offset upon starting up
7791
self.needs_fetch_committed_offsets = True
7892

93+
def _set_subscription_type(self, subscription_type):
    """Record the subscription mode, refusing to switch between modes.

    The first call after the state is ``SubscriptionType.NONE`` stores
    the requested mode. Later calls succeed only when they request the
    mode that is already stored.

    Arguments:
        subscription_type (SubscriptionType): the mode being requested.

    Raises:
        ValueError: if subscription_type is not a SubscriptionType member.
        IllegalStateError: if a different mode is already stored.
    """
    if not isinstance(subscription_type, SubscriptionType):
        raise ValueError('SubscriptionType enum required')

    current = self.subscription_type
    if current == SubscriptionType.NONE:
        # Nothing chosen yet: adopt the requested mode.
        self.subscription_type = subscription_type
        return

    if current != subscription_type:
        # A mode is already in effect and the caller asked for another one.
        raise IllegalStateError(self._SUBSCRIPTION_EXCEPTION_MESSAGE)
100+
79101
def subscribe(self, topics=(), pattern=None, listener=None):
80102
"""Subscribe to a list of topics, or a topic regex pattern.
81103
@@ -111,17 +133,19 @@ def subscribe(self, topics=(), pattern=None, listener=None):
111133
guaranteed, however, that the partitions revoked/assigned
112134
through this interface are from topics subscribed in this call.
113135
"""
114-
if self._user_assignment or (topics and pattern):
115-
raise IllegalStateError(self._SUBSCRIPTION_EXCEPTION_MESSAGE)
116136
assert topics or pattern, 'Must provide topics or pattern'
137+
if (topics and pattern):
138+
raise IllegalStateError(self._SUBSCRIPTION_EXCEPTION_MESSAGE)
117139

118-
if pattern:
140+
elif pattern:
141+
self._set_subscription_type(SubscriptionType.AUTO_PATTERN)
119142
log.info('Subscribing to pattern: /%s/', pattern)
120143
self.subscription = set()
121144
self.subscribed_pattern = re.compile(pattern)
122145
else:
123146
if isinstance(topics, str) or not isinstance(topics, Sequence):
124147
raise TypeError('Topics must be a list (or non-str sequence)')
148+
self._set_subscription_type(SubscriptionType.AUTO_TOPICS)
125149
self.change_subscription(topics)
126150

127151
if listener and not isinstance(listener, ConsumerRebalanceListener):
@@ -141,7 +165,7 @@ def change_subscription(self, topics):
141165
- a topic name is '.' or '..' or
142166
- a topic name does not consist of ASCII-characters/'-'/'_'/'.'
143167
"""
144-
if self._user_assignment:
168+
if not self.partitions_auto_assigned():
145169
raise IllegalStateError(self._SUBSCRIPTION_EXCEPTION_MESSAGE)
146170

147171
if isinstance(topics, six.string_types):
@@ -168,13 +192,13 @@ def group_subscribe(self, topics):
168192
Arguments:
169193
topics (list of str): topics to add to the group subscription
170194
"""
171-
if self._user_assignment:
195+
if not self.partitions_auto_assigned():
172196
raise IllegalStateError(self._SUBSCRIPTION_EXCEPTION_MESSAGE)
173197
self._group_subscription.update(topics)
174198

175199
def reset_group_subscription(self):
176200
"""Reset the group's subscription to only contain topics subscribed by this consumer."""
177-
if self._user_assignment:
201+
if not self.partitions_auto_assigned():
178202
raise IllegalStateError(self._SUBSCRIPTION_EXCEPTION_MESSAGE)
179203
assert self.subscription is not None, 'Subscription required'
180204
self._group_subscription.intersection_update(self.subscription)
@@ -197,9 +221,7 @@ def assign_from_user(self, partitions):
197221
Raises:
198222
IllegalStateError: if consumer has already called subscribe()
199223
"""
200-
if self.subscription is not None:
201-
raise IllegalStateError(self._SUBSCRIPTION_EXCEPTION_MESSAGE)
202-
224+
self._set_subscription_type(SubscriptionType.USER_ASSIGNED)
203225
if self._user_assignment != set(partitions):
204226
self._user_assignment = set(partitions)
205227
self._set_assignment({partition: self.assignment.get(partition, TopicPartitionState())
@@ -250,6 +272,7 @@ def unsubscribe(self):
250272
self._user_assignment.clear()
251273
self.assignment.clear()
252274
self.subscribed_pattern = None
275+
self.subscription_type = SubscriptionType.NONE
253276

254277
def group_subscription(self):
255278
"""Get the topic subscription for the group.
@@ -300,7 +323,7 @@ def fetchable_partitions(self):
300323

301324
def partitions_auto_assigned(self):
    """Return True when subscribed by topic list or by pattern.

    Returns False for every other subscription type, i.e. when
    partitions were assigned manually or when nothing is subscribed.
    """
    auto_modes = (SubscriptionType.AUTO_TOPICS, SubscriptionType.AUTO_PATTERN)
    return self.subscription_type in auto_modes
304327

305328
def all_consumed_offsets(self):
306329
"""Returns consumed offsets as {TopicPartition: OffsetAndMetadata}"""

test/test_consumer_integration.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -68,8 +68,8 @@ def test_kafka_consumer_unsupported_encoding(
6868
def test_kafka_consumer__blocking(kafka_consumer_factory, topic, send_messages):
6969
TIMEOUT_MS = 500
7070
consumer = kafka_consumer_factory(auto_offset_reset='earliest',
71-
enable_auto_commit=False,
72-
consumer_timeout_ms=TIMEOUT_MS)
71+
enable_auto_commit=False,
72+
consumer_timeout_ms=TIMEOUT_MS)
7373

7474
# Manual assignment avoids overhead of consumer group mgmt
7575
consumer.unsubscribe()

test/test_coordinator.py

+1
Original file line numberDiff line numberDiff line change
@@ -189,6 +189,7 @@ def test_subscription_listener_failure(mocker, coordinator):
189189

190190

191191
def test_perform_assignment(mocker, coordinator):
192+
coordinator._subscription.subscribe(topics=['foo1'])
192193
member_metadata = {
193194
'member-foo': ConsumerProtocolMemberMetadata(0, ['foo1'], b''),
194195
'member-bar': ConsumerProtocolMemberMetadata(0, ['foo1'], b'')

0 commit comments

Comments
 (0)