diff --git a/pkg/source/pulsar.go b/pkg/source/pulsar.go
index 7dab3e5..7024b87 100644
--- a/pkg/source/pulsar.go
+++ b/pkg/source/pulsar.go
@@ -5,6 +5,7 @@ import (
 	"encoding/hex"
 	"fmt"
 	"os"
+	"sync"
 	"time"
 
 	"github.com/apache/pulsar-client-go/pulsar"
@@ -165,10 +166,11 @@ type PulsarConsumerSource struct {
 	PulsarReplicateState bool
 	PulsarMaxReconnect   *uint
 
-	client      pulsar.Client
-	consumer    pulsar.Consumer
-	log         *logrus.Entry
-	ackTrackers map[string]*ackTracker
+	client   pulsar.Client
+	consumer pulsar.Consumer
+	log      *logrus.Entry
+
+	ackTrackers *ackTrackers
 }
 
 func (p *PulsarConsumerSource) Capture(cp cursor.Checkpoint) (changes chan Change, err error) {
@@ -195,7 +197,7 @@ func (p *PulsarConsumerSource) Capture(cp cursor.Checkpoint) (changes chan Chang
 		return nil, err
 	}
 
-	p.ackTrackers = make(map[string]*ackTracker, AckTrackerSize)
+	p.ackTrackers = newAckTrackers()
 
 	p.log = logrus.WithFields(logrus.Fields{
 		"From": "PulsarConsumerSource",
@@ -230,13 +232,11 @@ func (p *PulsarConsumerSource) Capture(cp cursor.Checkpoint) (changes chan Chang
 			first = true
 		}
 
+		// Register an ack tracker only for messages that are part of a batch (BatchSize > 1).
+		// NOTE(review): stated rationale is that Pulsar acks the whole batch when one of its messages is acked - confirm against the client docs.
 		if msg.ID().BatchSize() > 1 {
-			key := p.ackTrackerKey(msg.ID())
-			if _, ok := p.ackTrackers[key]; !ok {
-				p.ackTrackers[key] = newAckTracker(uint(msg.ID().BatchSize()))
-			}
+			p.ackTrackers.tryAdd(msg.ID())
 		}
-
 		change = Change{Checkpoint: checkpoint, Message: m}
 		return
 	}, func() {
@@ -249,11 +249,10 @@ func (p *PulsarConsumerSource) Capture(cp cursor.Checkpoint) (changes chan Chang
 
 func (p *PulsarConsumerSource) Commit(cp cursor.Checkpoint) {
 	if mid, err := pulsar.DeserializeMessageID(cp.Data); err == nil {
-		tracker, ok := p.ackTrackers[p.ackTrackerKey(mid)]
-		if ok && tracker.ack(int(mid.BatchIdx())) {
+		// Ack once the tracker reports every index of the batch acked; IDs that have no tracker are acked directly in the else branch.
+		if ack, exist := p.ackTrackers.tryAck(mid); ack && exist {
 			_ = p.consumer.AckID(mid)
-			delete(p.ackTrackers, p.ackTrackerKey(mid))
-		} else if !ok {
+		} else if !ack && !exist {
 			_ = p.consumer.AckID(mid)
 		}
 	}
@@ -272,6 +271,7 @@ func (p *PulsarConsumerSource) ackTrackerKey(id pulsar.MessageID) string {
 type ackTracker struct {
 	size     uint
 	batchIDs *bitset.BitSet
+	mu       sync.Mutex
 }
 
 func newAckTracker(size uint) *ackTracker {
@@ -286,9 +286,45 @@ func newAckTracker(size uint) *ackTracker {
 }
 
 func (t *ackTracker) ack(batchID int) bool {
+	t.mu.Lock()
+	defer t.mu.Unlock()
 	if batchID < 0 {
 		return true
 	}
 	t.batchIDs.Clear(uint(batchID))
 	return t.batchIDs.None()
 }
+
+type ackTrackers struct {
+	trackers sync.Map
+}
+
+func newAckTrackers() *ackTrackers {
+	return &ackTrackers{}
+}
+
+func (a *ackTrackers) key(msg pulsar.MessageID) string {
+	return fmt.Sprintf("%d:%d", msg.LedgerID(), msg.EntryID())
+}
+
+func (a *ackTrackers) tryAdd(msg pulsar.MessageID) (ok bool) {
+	key := a.key(msg)
+	_, ok = a.trackers.Load(key)
+	if !ok {
+		_, ok = a.trackers.LoadOrStore(key, newAckTracker(uint(msg.BatchSize())))
+	}
+	return !ok
+}
+
+func (a *ackTrackers) tryAck(msg pulsar.MessageID) (success bool, exist bool) {
+	key := a.key(msg)
+	v, ok := a.trackers.Load(key)
+	if ok {
+		tracker := v.(*ackTracker)
+		success, exist = tracker.ack(int(msg.BatchIdx())), ok
+		if success {
+			_, exist = a.trackers.LoadAndDelete(key)
+		}
+	}
+	return
+}