diff --git a/reader.go b/reader.go index 04d90f35..a22da582 100644 --- a/reader.go +++ b/reader.go @@ -1487,25 +1487,48 @@ func (r *reader) initialize(ctx context.Context, offset int64) (conn *Conn, star return } -func (r *reader) read(ctx context.Context, offset int64, conn *Conn) (int64, error) { - r.stats.fetches.observe(1) - r.stats.offset.observe(offset) +// readBatch wraps the call to conn.ReadBatchWith to make it interruptible. +// Conn methods are written in a non-interruptible style, so the only way to +// interrupt them is to close the connection in another goroutine. +func (r *reader) readBatch(ctx context.Context, conn *Conn) (*Batch, error) { + done := make(chan struct{}) + defer close(done) - t0 := time.Now() - conn.SetReadDeadline(t0.Add(r.maxWait)) + go func() { + select { + case <-ctx.Done(): + conn.Close() + case <-done: + return + } + }() batch := conn.ReadBatchWith(ReadBatchConfig{ MinBytes: r.minBytes, MaxBytes: r.maxBytes, IsolationLevel: r.isolationLevel, }) + return batch, ctx.Err() +} + +func (r *reader) read(ctx context.Context, offset int64, conn *Conn) (int64, error) { + r.stats.fetches.observe(1) + r.stats.offset.observe(offset) + + t0 := time.Now() + conn.SetReadDeadline(t0.Add(r.maxWait)) + + batch, err := r.readBatch(ctx, conn) + if err != nil { + return offset, err + } + highWaterMark := batch.HighWaterMark() t1 := time.Now() r.stats.waitTime.observeDuration(t1.Sub(t0)) var msg Message - var err error var size int64 var bytes int64 diff --git a/reader_test.go b/reader_test.go index f413d742..3a8bb607 100644 --- a/reader_test.go +++ b/reader_test.go @@ -1846,6 +1846,49 @@ func TestReaderReadCompactedMessage(t *testing.T) { } } +func TestReaderClose(t *testing.T) { + t.Parallel() + + r := NewReader(ReaderConfig{ + Brokers: []string{"localhost:9092"}, + Topic: makeTopic(), + MaxWait: 2 * time.Second, + }) + defer r.Close() + + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + + _, err := r.FetchMessage(ctx) + if errors.Is(err, context.DeadlineExceeded) { + t.Errorf("bad err: %v", err) + } + + t0 := time.Now() + r.Close() + if time.Since(t0) > 100*time.Millisecond { + t.Errorf("r.Close took too long") + } +} + +func BenchmarkReaderClose(b *testing.B) { + r := NewReader(ReaderConfig{ + Brokers: []string{"localhost:9092"}, + Topic: makeTopic(), + MaxWait: 2 * time.Second, + }) + defer r.Close() + for i := 0; i < b.N; i++ { + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + + _, err := r.FetchMessage(ctx) + if errors.Is(err, context.DeadlineExceeded) { + b.Errorf("bad err: %v", err) + } + } +} + // writeMessagesForCompactionCheck writes messages with specific writer configuration. func writeMessagesForCompactionCheck(t *testing.T, topic string, msgs []Message) { t.Helper()