diff --git a/reader.go b/reader.go index 04d90f35..958a2ed8 100644 --- a/reader.go +++ b/reader.go @@ -304,6 +304,20 @@ func (r *Reader) run(cg *ConsumerGroup) { for attempt := 1; attempt <= r.config.MaxAttempts; attempt++ { gen, err = cg.Next(r.stctx) if err == nil { + if r.config.AssignmentListener != nil { + assignments := make([]GroupMemberTopic, 0, len(gen.Assignments)) + for topic, partitions := range gen.Assignments { + assignedPartitions := make([]int, 0, len(partitions)) + for _, partition := range partitions { + assignedPartitions = append(assignedPartitions, partition.ID) + } + assignments = append(assignments, GroupMemberTopic{ + Topic: topic, + Partitions: assignedPartitions, + }) + } + r.config.AssignmentListener(assignments) + } break } if errors.Is(err, r.stctx.Err()) { @@ -522,6 +536,9 @@ type ReaderConfig struct { // This flag is being added to retain backwards-compatibility, so it will be // removed in a future version of kafka-go. OffsetOutOfRangeError bool + + // AsignmentListener is called when a reassignment happens indicating what are the new partitions + AssignmentListener func(partitions []GroupMemberTopic) } // Validate method validates ReaderConfig properties. diff --git a/reader_test.go b/reader_test.go index f413d742..54b13fab 100644 --- a/reader_test.go +++ b/reader_test.go @@ -10,11 +10,13 @@ import ( "net" "os" "reflect" + "sort" "strconv" "sync" "testing" "time" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -890,6 +892,57 @@ func TestReaderConsumerGroup(t *testing.T) { } } +func TestPartitionAssignmentListener(t *testing.T) { + // It appears that some of the tests depend on all these tests being + // run concurrently to pass... this is brittle and should be fixed + // at some point. + t.Parallel() + + topic := makeTopic() + createTopic(t, topic, 10) + defer deleteTopic(t, topic) + + var lock sync.Mutex + assignments := make([][]GroupMemberTopic, 0) + groupID := makeGroupID() + r := NewReader(ReaderConfig{ + Brokers: []string{"localhost:9092"}, + Topic: topic, + GroupID: groupID, + HeartbeatInterval: 2 * time.Second, + CommitInterval: 1 * time.Second, + RebalanceTimeout: 2 * time.Second, + RetentionTime: time.Hour, + MinBytes: 1, + MaxBytes: 1e6, + AssignmentListener: func(partitions []GroupMemberTopic) { + lock.Lock() + defer lock.Unlock() + // we sort the received partitions for easier comparison + for _, partition := range partitions { + sort.Slice(partition.Partitions, func(i, j int) bool { + return partition.Partitions[i] < partition.Partitions[j] + }) + } + assignments = append(assignments, partitions) + }, + }) + defer r.Close() + + assert.Eventually(t, func() bool { + lock.Lock() + defer lock.Unlock() + return reflect.DeepEqual(assignments, [][]GroupMemberTopic{ + { + GroupMemberTopic{ + Topic: topic, + Partitions: []int{0, 1, 2, 3, 4, 5, 6, 7, 8, 9}, + }, + }, + }) + }, 10*time.Second, 100*time.Millisecond) +} + func testReaderConsumerGroupHandshake(t *testing.T, ctx context.Context, r *Reader) { prepareReader(t, context.Background(), r, makeTestSequence(5)...)