// handler_test.go

package lambdalog

import (
	"bytes"
	"context"
	"encoding/json"
	"errors"
	"strings"
	"testing"
	"time"
)

// TestHandler_InjectsLoggerWithRequestID builds a handler from a logger that
// writes into an in-memory buffer, invokes it with a context carrying the
// request ID "rid-42", and checks that the record logged from inside the
// wrapped function has that ID under the "rid" key.
func TestHandler_InjectsLoggerWithRequestID(t *testing.T) {
	t.Parallel()

	out := new(bytes.Buffer)
	root := New(out)
	// Swap the logger's now hook for one that always reports the same instant.
	fixed := time.Date(2025, 1, 1, 0, 0, 0, 0, time.UTC)
	root.now = func() time.Time { return fixed }

	var invoked bool
	wrapped := root.Handler(func(ctx context.Context) error {
		invoked = true
		LoggerFromContext(ctx).Info("hello")
		return nil
	})

	_, err := wrapped(WithRequestID(context.Background(), "rid-42"), nil)
	if err != nil {
		t.Fatalf("handler returned error: %v", err)
	}
	if !invoked {
		t.Fatal("wrapped handler was not invoked")
	}

	// The buffer is parsed as a single JSON object after trimming whitespace.
	var record map[string]any
	raw := bytes.TrimSpace(out.Bytes())
	if err := json.Unmarshal(raw, &record); err != nil {
		t.Fatalf("parse record: %v", err)
	}
	if got := record["rid"]; got != "rid-42" {
		t.Fatalf("expected rid attached, got %v", got)
	}
}

// TestHandler_ClassifyRejectsBadSignatures feeds Handler a series of values
// with unsupported shapes and expects the handler it returns to report an
// error when invoked.
func TestHandler_ClassifyRejectsBadSignatures(t *testing.T) {
	t.Parallel()

	type badCase struct {
		name string
		fn   any
	}

	table := []badCase{
		{name: "not a func", fn: 42},
		{name: "no args", fn: func() error { return nil }},
		{name: "first not ctx", fn: func(s string) error { return nil }},
		{name: "too many args", fn: func(ctx context.Context, a, b int) error { return nil }},
		{name: "returns wrong", fn: func(ctx context.Context) string { return "" }},
	}

	for i := range table {
		// Copy the entry so each parallel subtest closes over its own value.
		entry := table[i]
		t.Run(entry.name, func(t *testing.T) {
			t.Parallel()

			wrapped := Handler(entry.fn)
			_, err := wrapped(context.Background(), nil)
			if err == nil {
				t.Fatal("expected error from malformed handler")
			}
		})
	}
}

// TestHandler_PropagatesUserError checks that an error returned by the wrapped
// function is still reachable via errors.Is on the error the handler returns.
func TestHandler_PropagatesUserError(t *testing.T) {
	t.Parallel()

	want := errors.New("boom")
	failing := func(ctx context.Context) error {
		return want
	}

	_, got := Handler(failing)(context.Background(), nil)
	if errors.Is(got, want) {
		return
	}
	t.Fatalf("expected sentinel error, got %v", got)
}

// TestHandler_DecodesTypedInput passes a raw JSON payload to a handler whose
// wrapped function takes a struct argument, then checks both that the struct
// field was populated and that the returned map shows up in the output.
func TestHandler_DecodesTypedInput(t *testing.T) {
	t.Parallel()

	type payload struct {
		User string `json:"user"`
	}

	// decodedUser records what the wrapped function actually received.
	var decodedUser string
	greet := func(ctx context.Context, in payload) (map[string]string, error) {
		decodedUser = in.User
		reply := map[string]string{"greeted": in.User}
		return reply, nil
	}

	wrapped := Handler(greet)
	body := []byte(`{"user":"alice"}`)

	out, err := wrapped(context.Background(), body)
	if err != nil {
		t.Fatalf("handler error: %v", err)
	}
	if decodedUser != "alice" {
		t.Fatalf("expected typed decode, got %q", decodedUser)
	}
	if text := string(out); !strings.Contains(text, `"greeted":"alice"`) {
		t.Fatalf("unexpected output: %s", out)
	}
}

// TestLoggerFromContext_FallsBackToRoot verifies that LoggerFromContext never
// returns nil, both for a nil context and for a context that carries no
// values.
//
// The test deliberately does not call t.Helper: marking the test function
// itself as a helper makes the testing package skip its frame when printing
// the file and line of a failure.
func TestLoggerFromContext_FallsBackToRoot(t *testing.T) {
	t.Parallel()

	// A typed nil variable keeps the nil-context case explicit and intentional
	// instead of passing a bare nil literal where a context is expected.
	var nilCtx context.Context
	if LoggerFromContext(nilCtx) == nil {
		t.Fatal("nil context should still yield the root logger")
	}
	if LoggerFromContext(context.Background()) == nil {
		t.Fatal("empty context should still yield the root logger")
	}
}

// TestWithLogger_RoundTrip stores a logger on a context with WithLogger and
// checks that LoggerFromContext hands back the very same value.
//
// No cleanup is registered: the test acquires nothing that needs tearing down,
// and an empty callback would only be dead code.
func TestWithLogger_RoundTrip(t *testing.T) {
	t.Parallel()

	l := New(nil).With("svc", "checkout")
	ctx := WithLogger(context.Background(), l)

	// Identity comparison: the context must return the stored logger itself.
	if got := LoggerFromContext(ctx); got != l {
		t.Fatalf("logger round-trip failed: %p != %p", got, l)
	}
}