lambdalog_test.go

package lambdalog

import (
	"bytes"
	"context"
	"encoding/json"
	"strings"
	"sync"
	"testing"
	"time"

	"github.com/google/go-cmp/cmp"
	"github.com/google/go-cmp/cmp/cmpopts"
)

// fixedClock returns a clock function that reports the same instant t on
// every call, letting tests pin the timestamp of emitted records.
func fixedClock(t time.Time) func() time.Time {
	clock := func() time.Time {
		return t
	}
	return clock
}

// decode unmarshals a single JSON log record into a generic map. Numbers
// come back as float64, per encoding/json defaults. The test is aborted via
// t.Fatalf if line is not valid JSON.
func decode(t *testing.T, line []byte) map[string]any {
	t.Helper()

	var record map[string]any
	err := json.Unmarshal(line, &record)
	if err == nil {
		return record
	}
	t.Fatalf("invalid json %q: %v", line, err)
	return nil // unreachable: Fatalf stops the test goroutine
}

// TestLoggerEmitsExpectedFields checks that a single log call produces exactly
// the expected JSON record: timestamp, level, message, any inherited With
// attributes, and the "MISSING" key the test expects for an unpaired argument.
func TestLoggerEmitsExpectedFields(t *testing.T) {
	t.Parallel()

	// Every case pins the clock to the same instant so "ts" is deterministic.
	epoch := time.Date(2025, 1, 1, 0, 0, 0, 0, time.UTC)
	const epochTS = "2025-01-01T00:00:00Z"

	cases := []struct {
		name   string
		build  func(w *bytes.Buffer) *Logger
		call   func(l *Logger)
		expect map[string]any
	}{
		{
			name: "info with extra kvs",
			build: func(w *bytes.Buffer) *Logger {
				l := New(w)
				l.now = fixedClock(epoch)
				return l
			},
			call: func(l *Logger) { l.Info("hello", "a", 1, "b", "x") },
			expect: map[string]any{
				"ts":    epochTS,
				"level": "info",
				"msg":   "hello",
				"a":     float64(1), // encoding/json decodes numbers as float64
				"b":     "x",
			},
		},
		{
			name: "with attrs inherited",
			build: func(w *bytes.Buffer) *Logger {
				l := New(w).With("service", "billing")
				l.now = fixedClock(epoch)
				return l
			},
			call: func(l *Logger) { l.Warn("careful") },
			expect: map[string]any{
				"ts":      epochTS,
				"level":   "warn",
				"msg":     "careful",
				"service": "billing",
			},
		},
		{
			name: "odd kv count yields MISSING",
			build: func(w *bytes.Buffer) *Logger {
				l := New(w)
				l.now = fixedClock(epoch)
				return l
			},
			call: func(l *Logger) { l.Error("oops", "only_key") },
			expect: map[string]any{
				"ts":      epochTS,
				"level":   "error",
				"msg":     "oops",
				"MISSING": "only_key",
			},
		},
	}

	for _, tc := range cases {
		tc := tc // per-iteration copy for the parallel subtest (required before Go 1.22)
		t.Run(tc.name, func(t *testing.T) {
			t.Parallel()
			var buf bytes.Buffer
			l := tc.build(&buf)
			tc.call(l)
			got := decode(t, buf.Bytes())
			if diff := cmp.Diff(tc.expect, got, cmpopts.EquateEmpty()); diff != "" {
				t.Errorf("record mismatch (-want +got):\n%s", diff)
			}
		})
	}
}

// TestLoggerFromContextAddsRequestID verifies that a logger derived via
// FromContext from a context carrying a request ID emits it under "rid".
func TestLoggerFromContextAddsRequestID(t *testing.T) {
	t.Parallel()

	var out bytes.Buffer
	logger := New(&out)
	logger.now = fixedClock(time.Date(2025, 2, 1, 0, 0, 0, 0, time.UTC))

	reqCtx := WithRequestID(context.Background(), "req-42")
	logger.FromContext(reqCtx).Info("hi")

	record := decode(t, out.Bytes())
	if rid := record["rid"]; rid != "req-42" {
		t.Fatalf("expected rid=req-42, got %v", rid)
	}
}

// TestLevelFiltering verifies that a logger with a LevelWarn threshold drops
// Info records but still emits Error records.
func TestLevelFiltering(t *testing.T) {
	t.Parallel()

	var out bytes.Buffer
	logger := New(&out).WithLevel(LevelWarn)
	logger.now = fixedClock(time.Now())

	logger.Info("suppressed")
	logger.Error("kept")

	emitted := out.String()
	switch {
	case strings.Contains(emitted, "suppressed"):
		t.Fatalf("Info should have been dropped: %q", emitted)
	case !strings.Contains(emitted, "kept"):
		t.Fatalf("Error should have been emitted: %q", emitted)
	}
}

// TestLoggerConcurrentWritesAreLineSafe logs from many goroutines at once into
// a single bytes.Buffer. It checks that every record lands on its own line as
// valid JSON, and that each goroutine's record appears exactly once (no drops,
// no duplicates). bytes.Buffer is not safe for concurrent use, so this also
// relies on the Logger serializing writes; run with -race to catch violations.
func TestLoggerConcurrentWritesAreLineSafe(t *testing.T) {
	t.Parallel()

	const writers = 50

	var buf bytes.Buffer
	l := New(&buf)
	l.now = fixedClock(time.Now())

	var wg sync.WaitGroup
	for i := 0; i < writers; i++ {
		wg.Add(1)
		go func(i int) {
			defer wg.Done()
			l.Info("m", "i", i)
		}(i)
	}
	wg.Wait()

	lines := strings.Split(strings.TrimRight(buf.String(), "\n"), "\n")
	if len(lines) != writers {
		t.Fatalf("expected %d lines, got %d", writers, len(lines))
	}

	// With exactly `writers` lines, each carrying a distinct in-range index,
	// every goroutine's record must be present exactly once.
	seen := make([]bool, writers)
	for _, line := range lines {
		rec := decode(t, []byte(line)) // fails the test on invalid JSON
		v, ok := rec["i"].(float64)    // encoding/json decodes numbers as float64
		idx := int(v)
		if !ok || float64(idx) != v || idx < 0 || idx >= writers {
			t.Fatalf("unexpected i value %v in line %q", rec["i"], line)
		}
		if seen[idx] {
			t.Fatalf("duplicate record for i=%d: %q", idx, line)
		}
		seen[idx] = true
	}
}