package lambdalog
import (
"bytes"
"context"
"encoding/json"
"errors"
"strings"
"testing"
"time"
)
// TestHandler_InjectsLoggerWithRequestID verifies that a function wrapped by
// the logger's Handler method sees, via LoggerFromContext, a logger whose
// emitted records carry the request ID stored on the incoming context with
// WithRequestID.
func TestHandler_InjectsLoggerWithRequestID(t *testing.T) {
	t.Parallel()

	var out bytes.Buffer
	logger := New(&out)
	// Pin the clock so the emitted record does not depend on wall time.
	fixed := time.Date(2025, 1, 1, 0, 0, 0, 0, time.UTC)
	logger.now = func() time.Time { return fixed }

	invoked := false
	wrapped := logger.Handler(func(ctx context.Context) error {
		invoked = true
		LoggerFromContext(ctx).Info("hello")
		return nil
	})

	reqCtx := WithRequestID(context.Background(), "rid-42")
	_, err := wrapped(reqCtx, nil)
	if err != nil {
		t.Fatalf("handler returned error: %v", err)
	}
	if !invoked {
		t.Fatal("wrapped handler was not invoked")
	}

	// The wrapped function logs once; the buffer is expected to hold a single
	// JSON record (surrounding whitespace such as the trailing newline is trimmed).
	line := bytes.TrimSpace(out.Bytes())
	record := map[string]any{}
	if err := json.Unmarshal(line, &record); err != nil {
		t.Fatalf("parse record: %v", err)
	}
	if got := record["rid"]; got != "rid-42" {
		t.Fatalf("expected rid attached, got %v", got)
	}
}
// TestHandler_ClassifyRejectsBadSignatures checks that Handler, given a value
// that is not a supported handler function shape, produces a handler that
// returns an error when invoked.
func TestHandler_ClassifyRejectsBadSignatures(t *testing.T) {
	t.Parallel()

	type badCase struct {
		name string
		fn   any
	}
	badCases := []badCase{
		{name: "not a func", fn: 42},
		{name: "no args", fn: func() error { return nil }},
		{name: "first not ctx", fn: func(string) error { return nil }},
		{name: "too many args", fn: func(context.Context, int, int) error { return nil }},
		{name: "returns wrong", fn: func(context.Context) string { return "" }},
	}

	for _, bc := range badCases {
		bc := bc // per-iteration copy for the parallel subtest (required before Go 1.22)
		t.Run(bc.name, func(t *testing.T) {
			t.Parallel()
			_, err := Handler(bc.fn)(context.Background(), nil)
			if err == nil {
				t.Fatal("expected error from malformed handler")
			}
		})
	}
}
// TestHandler_PropagatesUserError ensures that an error returned by the user's
// function reaches the caller of the wrapped handler, either as-is or wrapped
// so that errors.Is still matches it.
func TestHandler_PropagatesUserError(t *testing.T) {
	t.Parallel()

	errBoom := errors.New("boom")
	wrapped := Handler(func(context.Context) error {
		return errBoom
	})

	if _, err := wrapped(context.Background(), nil); !errors.Is(err, errBoom) {
		t.Fatalf("expected sentinel error, got %v", err)
	}
}
// TestHandler_DecodesTypedInput verifies two things:
//   - Handler decodes the raw JSON payload into the user function's typed
//     second parameter.
//   - The function's map result appears in the returned output as a
//     "greeted":"alice" JSON member.
func TestHandler_DecodesTypedInput(t *testing.T) {
	t.Parallel()

	type greetRequest struct {
		User string `json:"user"`
	}

	var gotUser string
	wrapped := Handler(func(_ context.Context, r greetRequest) (map[string]string, error) {
		gotUser = r.User
		return map[string]string{"greeted": r.User}, nil
	})

	payload := []byte(`{"user":"alice"}`)
	resp, err := wrapped(context.Background(), payload)
	if err != nil {
		t.Fatalf("handler error: %v", err)
	}
	if gotUser != "alice" {
		t.Fatalf("expected typed decode, got %q", gotUser)
	}
	if body := string(resp); !strings.Contains(body, `"greeted":"alice"`) {
		t.Fatalf("unexpected output: %s", resp)
	}
}
// TestLoggerFromContext_FallsBackToRoot verifies that LoggerFromContext never
// returns nil. A nil context and a context with no attached logger must each
// still yield a logger; the messages call it the root logger.
func TestLoggerFromContext_FallsBackToRoot(t *testing.T) {
	t.Parallel()
	// t.Helper() was removed: it only affects failure line attribution inside
	// helper functions and is meaningless in a top-level Test function.

	// A typed nil variable deliberately exercises the nil-context path without
	// tripping staticcheck SA1012, which flags a literal nil Context argument.
	var nilCtx context.Context
	if LoggerFromContext(nilCtx) == nil {
		t.Fatal("nil context should still yield the root logger")
	}
	if LoggerFromContext(context.Background()) == nil {
		t.Fatal("empty context should still yield the root logger")
	}
}
// TestWithLogger_RoundTrip verifies that a logger stored on a context with
// WithLogger is returned by LoggerFromContext as the identical value (compared
// with !=, not by contents).
func TestWithLogger_RoundTrip(t *testing.T) {
	t.Parallel()

	want := New(nil).With("svc", "checkout")
	ctx := WithLogger(context.Background(), want)
	if got := LoggerFromContext(ctx); got != want {
		t.Fatalf("logger round-trip failed: %p != %p", got, want)
	}
	// The previous empty t.Cleanup registration was dead code: there is no
	// state to tear down, so it has been dropped.
}