diff --git a/pkg/agent/billing/billing.go b/pkg/agent/billing/billing.go index c2bdfca1e..71c7e8925 100644 --- a/pkg/agent/billing/billing.go +++ b/pkg/agent/billing/billing.go @@ -14,6 +14,7 @@ import ( "github.com/neondatabase/autoscaling/pkg/api" "github.com/neondatabase/autoscaling/pkg/billing" "github.com/neondatabase/autoscaling/pkg/reporting" + "github.com/neondatabase/autoscaling/pkg/util/taskgroup" ) type Config struct { @@ -72,6 +73,7 @@ type MetricsCollector struct { func NewMetricsCollector( ctx context.Context, parentLogger *zap.Logger, + tg taskgroup.Group, conf *Config, metrics PromMetrics, ) (*MetricsCollector, error) { @@ -82,7 +84,10 @@ func NewMetricsCollector( return nil, err } - sink := reporting.NewEventSink(logger, metrics.reporting, clients...) + // Note: we pass context.TODO() because we want to manually shut down the event senders only + // once we've had a chance to run a final collection of remaining billing events. + // So instead, we defer a call to sink.Finish() in (*MetricsCollector).Run() + sink := reporting.NewEventSink(context.TODO(), logger, tg, metrics.reporting, clients...) return &MetricsCollector{ conf: conf, @@ -105,6 +110,8 @@ func (mc *MetricsCollector) Run( accumulateTicker := time.NewTicker(time.Second * time.Duration(mc.conf.AccumulateEverySeconds)) defer accumulateTicker.Stop() + defer mc.sink.Finish() + state := metricsState{ historical: make(map[metricsKey]vmMetricsHistory), present: make(map[metricsKey]vmMetricsInstant), diff --git a/pkg/agent/entrypoint.go b/pkg/agent/entrypoint.go index ae315161f..fe3c3e186 100644 --- a/pkg/agent/entrypoint.go +++ b/pkg/agent/entrypoint.go @@ -52,8 +52,10 @@ func (r MainRunner) Run(logger *zap.Logger, ctx context.Context) error { } defer schedTracker.Stop() + tg := taskgroup.NewGroup(logger, taskgroup.WithParentContext(ctx)) + scalingEventsMetrics := scalingevents.NewPromMetrics() - scalingReporter, err := scalingevents.NewReporter(ctx, logger, &r.Config.ScalingEvents, scalingEventsMetrics) + scalingReporter, err := scalingevents.NewReporter(ctx, logger, tg, &r.Config.ScalingEvents, scalingEventsMetrics) if err != nil { return fmt.Errorf("Error creating scaling events reporter: %w", err) } @@ -82,12 +84,11 @@ func (r MainRunner) Run(logger *zap.Logger, ctx context.Context) error { } } - mc, err := billing.NewMetricsCollector(ctx, logger, &r.Config.Billing, metrics) + mc, err := billing.NewMetricsCollector(ctx, logger, tg, &r.Config.Billing, metrics) if err != nil { return fmt.Errorf("error creating billing metrics collector: %w", err) } - tg := taskgroup.NewGroup(logger, taskgroup.WithParentContext(ctx)) tg.Go("billing", func(logger *zap.Logger) error { return mc.Run(tg.Ctx(), logger, storeForNode) }) diff --git a/pkg/agent/scalingevents/reporter.go b/pkg/agent/scalingevents/reporter.go index cd9165f01..504f55b2d 100644 --- a/pkg/agent/scalingevents/reporter.go +++ b/pkg/agent/scalingevents/reporter.go @@ -9,6 +9,7 @@ import ( "go.uber.org/zap" "github.com/neondatabase/autoscaling/pkg/reporting" + "github.com/neondatabase/autoscaling/pkg/util/taskgroup" ) type Config struct { @@ -61,6 +62,7 @@ const ( func NewReporter( ctx context.Context, parentLogger *zap.Logger, + tg taskgroup.Group, conf *Config, metrics PromMetrics, ) (*Reporter, error) { @@ -71,7 +73,7 @@ func NewReporter( return nil, err } - sink := reporting.NewEventSink(logger, metrics.reporting, clients...) + sink := reporting.NewEventSink(ctx, logger, tg, metrics.reporting, clients...) return &Reporter{ conf: conf, diff --git a/pkg/reporting/send.go b/pkg/reporting/send.go index 0a2780d06..d97603241 100644 --- a/pkg/reporting/send.go +++ b/pkg/reporting/send.go @@ -12,7 +12,6 @@ type eventSender[E any] struct { metrics *EventSinkMetrics queue eventQueuePuller[E] - done <-chan struct{} // lastSendDuration tracks the "real" last full duration of (eventSender).sendAllCurrentEvents(). // @@ -37,7 +36,7 @@ type eventSender[E any] struct { lastSendDuration time.Duration } -func (s eventSender[E]) senderLoop(logger *zap.Logger) { +func (s eventSender[E]) senderLoop(ctx context.Context, logger *zap.Logger) { ticker := time.NewTicker(time.Second * time.Duration(s.client.BaseConfig.PushEverySeconds)) defer ticker.Stop() @@ -45,7 +44,7 @@ func (s eventSender[E]) senderLoop(logger *zap.Logger) { final := false select { - case <-s.done: + case <-ctx.Done(): logger.Info("Received notification that events submission is done") final = true case <-ticker.C: diff --git a/pkg/reporting/sink.go b/pkg/reporting/sink.go index 941d156dc..99e1daee9 100644 --- a/pkg/reporting/sink.go +++ b/pkg/reporting/sink.go @@ -3,11 +3,14 @@ package reporting // public API for event reporting import ( + "context" "fmt" "sync" "github.com/prometheus/client_golang/prometheus" "go.uber.org/zap" + + "github.com/neondatabase/autoscaling/pkg/util/taskgroup" ) type EventSink[E any] struct { @@ -15,9 +18,16 @@ type EventSink[E any] struct { done func() } -func NewEventSink[E any](logger *zap.Logger, metrics *EventSinkMetrics, clients ...Client[E]) *EventSink[E] { +func NewEventSink[E any]( + ctx context.Context, + logger *zap.Logger, + tg taskgroup.Group, + metrics *EventSinkMetrics, + clients ...Client[E], +) *EventSink[E] { var queueWriters []eventQueuePusher[E] - signalDone := make(chan struct{}) + + senderCtx, cancelSenders := context.WithCancel(ctx) for _, c := range clients { qw, qr := newEventQueue[E](metrics.queueSizeCurrent.WithLabelValues(c.Name)) @@ -28,15 +38,18 @@ func NewEventSink[E any](logger *zap.Logger, metrics *EventSinkMetrics, clients client: c, metrics: metrics, queue: qr, - done: signalDone, lastSendDuration: 0, } - go sender.senderLoop(logger.Named(fmt.Sprintf("send-%s", c.Name))) + taskName := fmt.Sprintf("send-%s", c.Name) + tg.Go(taskName, func(_ *zap.Logger) error { + sender.senderLoop(senderCtx, logger.Named(taskName)) + return nil + }) } return &EventSink[E]{ queueWriters: queueWriters, - done: sync.OnceFunc(func() { close(signalDone) }), + done: sync.OnceFunc(cancelSenders), } } @@ -47,6 +60,12 @@ func (s *EventSink[E]) Enqueue(event E) { } } +// Finish signals that the last events have been Enqueue'd, and so they should be sent off before +// shutting down. +func (s *EventSink[E]) Finish() { + s.done() +} + type EventSinkMetrics struct { queueSizeCurrent *prometheus.GaugeVec lastSendDuration *prometheus.GaugeVec