Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 16 additions & 1 deletion pkg/logger/BUILD.bazel
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
load("@rules_go//go:def.bzl", "go_library")
load("@rules_go//go:def.bzl", "go_library", "go_test")

go_library(
name = "logger",
Expand All @@ -16,3 +16,18 @@ go_library(
visibility = ["//visibility:public"],
deps = ["//pkg/fault"],
)

go_test(
name = "logger_test",
srcs = [
"fault_handler_test.go",
"wide_test.go",
],
deps = [
":logger",
"//pkg/codes",
"//pkg/fault",
"//pkg/logger/loggertest",
"@com_github_stretchr_testify//require",
],
)
33 changes: 25 additions & 8 deletions pkg/logger/event.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package logger

import (
"context"
"fmt"
"log/slog"
"sync"
Expand All @@ -26,6 +27,10 @@ type Event struct {
attrs []slog.Attr
errors []error
written bool
// pc is the program counter captured at [StartWideEvent] so the emitted
// log record's source attribute points at the caller that opened the
// event, not at this file. Zero means no PC was captured.
pc uintptr
}

// Add appends attributes to the event. These attributes will be included
Expand Down Expand Up @@ -86,21 +91,33 @@ func (e *Event) End() {
))
}

casted := []any{
attrs := make([]slog.Attr, 0, len(e.attrs)+2)
attrs = append(attrs,
slog.GroupAttrs("errors", errors...),
slog.GroupAttrs("log_meta",
slog.Time("start", e.start),
slog.Duration("duration", time.Since(e.start)),
),
}
for _, attr := range e.attrs {
casted = append(casted, attr)
}
)
attrs = append(attrs, e.attrs...)

level := slog.LevelInfo
msg := e.message
if len(e.errors) > 0 {
logger.Error("error", casted...)
} else {
logger.Info(e.message, casted...)
level = slog.LevelError
msg = "error"
}

ctx := context.Background()
if !logger.Enabled(ctx, level) {
return
}

// Build the record manually so the source attribute points at the
// caller of StartWideEvent (the handler/middleware that opened the
// event), not at this file. Using logger.Error/Info here would capture
// the PC of this line instead, which is useless for debugging.
r := slog.NewRecord(time.Now(), level, msg, e.pc)
r.AddAttrs(attrs...)
_ = logger.Handler().Handle(ctx, r)
}
118 changes: 118 additions & 0 deletions pkg/logger/fault_handler_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
package logger_test

import (
"errors"
"log/slog"
"strings"
"testing"

"github.com/stretchr/testify/require"

"github.com/unkeyed/unkey/pkg/codes"
"github.com/unkeyed/unkey/pkg/fault"
"github.com/unkeyed/unkey/pkg/logger"
"github.com/unkeyed/unkey/pkg/logger/loggertest"
)

func TestFaultHandler_EnrichesFaultError(t *testing.T) {
h := loggertest.Install(t)

root := fault.New("root cause", fault.Code(codes.App.Internal.UnexpectedError.URN()))
wrapped := fault.Wrap(root, fault.Internal("outer"))

logger.Error("boom", "error", wrapped)

attrs := loggertest.FlatAttrs(h.Last(t))

steps, ok := attrs["error.steps"].([]fault.Step)
require.True(t, ok, "expected []fault.Step for error.steps, got %T", attrs["error.steps"])
require.Len(t, steps, 3, "fault chain should produce 3 steps (root + Wrap-merge + outer)")
require.Equal(t, "root cause", steps[0].Message)
require.Equal(t, "outer", steps[len(steps)-1].Message)

loc, ok := attrs["error.location"].(string)
require.True(t, ok, "error.location should be a string")
require.Equal(t, steps[len(steps)-1].Location, loc,
"error.location should match the outermost wrap")
}

func TestFaultHandler_IgnoresNonFaultError(t *testing.T) {
h := loggertest.Install(t)

logger.Error("boom", "error", errors.New("plain error"))

attrs := loggertest.FlatAttrs(h.Last(t))
_, hasSteps := attrs["error.steps"]
_, hasLoc := attrs["error.location"]
require.False(t, hasSteps, "stdlib errors should not produce error.steps")
require.False(t, hasLoc, "stdlib errors should not produce error.location")
}

func TestFaultHandler_IgnoresWhenNoErrorArg(t *testing.T) {
h := loggertest.Install(t)

logger.Info("hello", "user_id", "u_123", "count", 42)

attrs := loggertest.FlatAttrs(h.Last(t))
_, hasSteps := attrs["error.steps"]
require.False(t, hasSteps, "records without an error value should not be enriched")
require.Equal(t, "u_123", attrs["user_id"])
}

func TestFaultHandler_FirstFaultErrorWins(t *testing.T) {
h := loggertest.Install(t)

first := fault.New("first error")
second := fault.New("second error")

logger.Error("boom", "first", first, "second", second)

attrs := loggertest.FlatAttrs(h.Last(t))
steps := attrs["error.steps"].([]fault.Step)
require.Equal(t, "first error", steps[0].Message,
"only the first fault error in args should drive enrichment")
}

func TestFaultHandler_AppliesAcrossAllSinks(t *testing.T) {
// Two captures installed back-to-back must BOTH see the enriched
// record — proves enrichment runs at the top of the fan-out instead
// of being baked into a single inner handler.
a := loggertest.Install(t)
b := loggertest.Install(t)

logger.Error("boom", "error", fault.New("oops"))

for name, h := range map[string]*loggertest.CaptureHandler{"a": a, "b": b} {
attrs := loggertest.FlatAttrs(h.Last(t))
_, ok := attrs["error.steps"]
require.True(t, ok, "handler %s should have received the enriched record", name)
}
}

func TestFaultHandler_FaultInSlogAttrValue(t *testing.T) {
h := loggertest.Install(t)

err := fault.New("via attr")
logger.Error("boom", slog.Any("err", err))

attrs := loggertest.FlatAttrs(h.Last(t))
require.NotNil(t, attrs["error.steps"],
"errors passed via slog.Any(...) should still be detected")
}

func TestLoggerAliases_SourceIsCaller(t *testing.T) {
// Aliasing logger.Error -> slog.Error (no wrapper frame) means the PC
// captured by stdlib slog points at the caller, not at this package.
h := loggertest.Install(t)

logger.Error("boom") // <-- expected source line

rec := h.Last(t)
require.NotZero(t, rec.PC, "record PC must be set")

frame := loggertest.PCFrame(rec.PC)
require.True(t, strings.HasSuffix(frame.File, "fault_handler_test.go"),
"source file should be the test file, got %s", frame.File)
require.Contains(t, frame.Function, "TestLoggerAliases_SourceIsCaller",
"source function should be the test, got %s", frame.Function)
}
12 changes: 12 additions & 0 deletions pkg/logger/loggertest/BUILD.bazel
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
load("@rules_go//go:def.bzl", "go_library")

go_library(
name = "loggertest",
srcs = ["capture.go"],
importpath = "github.com/unkeyed/unkey/pkg/logger/loggertest",
visibility = ["//visibility:public"],
deps = [
"//pkg/logger",
"@com_github_stretchr_testify//require",
],
)
160 changes: 160 additions & 0 deletions pkg/logger/loggertest/capture.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,160 @@
// Package loggertest provides shared testing utilities for the global
// [logger] package: a slog handler that records every emitted record and
// helpers for inspecting them.
//
// This is testing-only — it lives in its own package so production code
// can't accidentally depend on it, but assertions across services can
// reuse the same capture implementation instead of redefining one in
// every test file.
package loggertest

import (
"context"
"log/slog"
"runtime"
"sync"
"testing"

"github.com/stretchr/testify/require"

"github.com/unkeyed/unkey/pkg/logger"
)

// PCFrame turns a single program counter into its first stack frame.
// Useful for asserting on the source attribute slog will derive from a
// record's PC.
func PCFrame(pc uintptr) runtime.Frame {
frames := runtime.CallersFrames([]uintptr{pc})
f, _ := frames.Next()
return f
}

// CaptureHandler is a [slog.Handler] that stores every log record it
// receives. Install it into the global logger via [Install] (or call
// [logger.AddHandler] directly) and inspect the captured records in
// assertions.
//
// Safe for concurrent use.
type CaptureHandler struct {
mu sync.Mutex
records []slog.Record
}

// New returns an empty CaptureHandler. Most tests should use [Install]
// instead, which both constructs the handler and wires it into the
// global logger.
func New() *CaptureHandler {
return &CaptureHandler{mu: sync.Mutex{}, records: nil}
}

// Install registers a fresh CaptureHandler with the global logger and
// returns it. Every subsequent log call (via logger.Error, slog.Info,
// wide events, etc.) will be recorded.
//
// The global logger keeps a reference to the handler for the rest of
// the process; tests should not rely on cleanup between runs. Use
// [CaptureHandler.Records] and [CaptureHandler.Snapshot] to scope
// assertions to records emitted after a known point.
func Install(t *testing.T) *CaptureHandler {
t.Helper()
h := New()
logger.AddHandler(h)
return h
}

func (h *CaptureHandler) Enabled(_ context.Context, _ slog.Level) bool { return true }

func (h *CaptureHandler) Handle(_ context.Context, r slog.Record) error {
h.mu.Lock()
defer h.mu.Unlock()
// Clone so later mutations on the record (e.g. AddAttrs by downstream
// handlers in the same fan-out) don't bleed into our snapshot.
h.records = append(h.records, r.Clone())
return nil
}

func (h *CaptureHandler) WithAttrs(_ []slog.Attr) slog.Handler { return h }
func (h *CaptureHandler) WithGroup(_ string) slog.Handler { return h }

// Records returns a copy of every record captured so far.
func (h *CaptureHandler) Records() []slog.Record {
h.mu.Lock()
defer h.mu.Unlock()
out := make([]slog.Record, len(h.records))
copy(out, h.records)
return out
}

// Snapshot returns the current record count. Pair with [CaptureHandler.Since]
// to assert only on records emitted by a specific block of test code, which
// matters because the global logger is shared across tests.
func (h *CaptureHandler) Snapshot() int {
h.mu.Lock()
defer h.mu.Unlock()
return len(h.records)
}

// Since returns all records captured after the given snapshot index.
func (h *CaptureHandler) Since(idx int) []slog.Record {
h.mu.Lock()
defer h.mu.Unlock()
if idx > len(h.records) {
return nil
}
out := make([]slog.Record, len(h.records)-idx)
copy(out, h.records[idx:])
return out
}

// Last returns the most recent record, failing the test if none have
// been captured.
func (h *CaptureHandler) Last(t *testing.T) slog.Record {
t.Helper()
h.mu.Lock()
defer h.mu.Unlock()
require.NotEmpty(t, h.records, "expected at least one log record")
return h.records[len(h.records)-1]
}

// Find returns the first captured record whose message equals msg. Tests
// use this instead of indexing when the global logger is shared and the
// order of records isn't predictable.
func (h *CaptureHandler) Find(t *testing.T, msg string) slog.Record {
t.Helper()
h.mu.Lock()
defer h.mu.Unlock()
for _, r := range h.records {
if r.Message == msg {
return r
}
}
t.Fatalf("no record with message %q (captured %d records)", msg, len(h.records))
return slog.Record{} //nolint:exhaustruct // unreachable; t.Fatalf aborts the test
}

// FlatAttrs collapses a record's attributes into key→value, flattening
// groups with dotted keys (e.g. "http.method"). Keeps test assertions
// readable without having to manually walk slog.Value.Group() trees.
func FlatAttrs(r slog.Record) map[string]any {
out := map[string]any{}
var walk func(prefix string, attrs []slog.Attr)
walk = func(prefix string, attrs []slog.Attr) {
for _, a := range attrs {
key := prefix + a.Key
if a.Value.Kind() == slog.KindGroup {
walk(key+".", a.Value.Group())
continue
}
out[key] = a.Value.Any()
}
}
r.Attrs(func(a slog.Attr) bool {
if a.Value.Kind() == slog.KindGroup {
walk(a.Key+".", a.Value.Group())
} else {
out[a.Key] = a.Value.Any()
}
return true
})
return out
}
Loading