package lambdalog
import (
"bytes"
"context"
"encoding/json"
"strings"
"sync"
"testing"
"time"
"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
)
// fixedClock returns a clock function that reports the same instant, t, on
// every call. Tests assign it to Logger.now so emitted timestamps are stable.
func fixedClock(t time.Time) func() time.Time {
	frozen := t
	clock := func() time.Time {
		return frozen
	}
	return clock
}
// decode parses line as a single JSON object and returns it as a generic map.
// If line is not valid JSON, the calling test is aborted via t.Fatalf.
func decode(t *testing.T, line []byte) map[string]any {
	t.Helper()
	var record map[string]any
	err := json.Unmarshal(line, &record)
	if err != nil {
		t.Fatalf("invalid json %q: %v", line, err)
	}
	return record
}
// TestLoggerEmitsExpectedFields checks the exact JSON record produced by a
// single logging call for several logger configurations. Each case freezes
// the logger clock so the "ts" field is deterministic.
func TestLoggerEmitsExpectedFields(t *testing.T) {
	t.Parallel()
	// epoch is the instant every case freezes its clock at; the expected
	// "ts" strings below are its RFC 3339 rendering.
	epoch := time.Date(2025, 1, 1, 0, 0, 0, 0, time.UTC)
	cases := []struct {
		name   string
		build  func(w *bytes.Buffer) *Logger
		call   func(l *Logger)
		expect map[string]any
	}{
		{
			name: "info with extra kvs",
			build: func(w *bytes.Buffer) *Logger {
				l := New(w)
				l.now = fixedClock(epoch)
				return l
			},
			call: func(l *Logger) { l.Info("hello", "a", 1, "b", "x") },
			expect: map[string]any{
				"ts":    "2025-01-01T00:00:00Z",
				"level": "info",
				"msg":   "hello",
				// encoding/json decodes every JSON number into float64.
				"a": float64(1),
				"b": "x",
			},
		},
		{
			name: "with attrs inherited",
			build: func(w *bytes.Buffer) *Logger {
				l := New(w).With("service", "billing")
				l.now = fixedClock(epoch)
				return l
			},
			call: func(l *Logger) { l.Warn("careful") },
			expect: map[string]any{
				"ts":      "2025-01-01T00:00:00Z",
				"level":   "warn",
				"msg":     "careful",
				"service": "billing",
			},
		},
		{
			// A trailing key with no value is expected under "MISSING".
			name: "odd kv count yields MISSING",
			build: func(w *bytes.Buffer) *Logger {
				l := New(w)
				l.now = fixedClock(epoch)
				return l
			},
			call: func(l *Logger) { l.Error("oops", "only_key") },
			expect: map[string]any{
				"ts":      "2025-01-01T00:00:00Z",
				"level":   "error",
				"msg":     "oops",
				"MISSING": "only_key",
			},
		},
	}
	for _, tc := range cases {
		tc := tc // capture the range variable for the parallel subtest (pre-Go 1.22 semantics)
		t.Run(tc.name, func(t *testing.T) {
			t.Parallel()
			var buf bytes.Buffer
			l := tc.build(&buf)
			tc.call(l)
			got := decode(t, buf.Bytes())
			if diff := cmp.Diff(tc.expect, got, cmpopts.EquateEmpty()); diff != "" {
				t.Errorf("record mismatch (-want +got):\n%s", diff)
			}
		})
	}
}
// TestLoggerFromContextAddsRequestID verifies that a logger derived via
// FromContext from a context carrying a request ID emits that ID as "rid".
func TestLoggerFromContextAddsRequestID(t *testing.T) {
	t.Parallel()
	var out bytes.Buffer
	base := New(&out)
	base.now = fixedClock(time.Date(2025, 2, 1, 0, 0, 0, 0, time.UTC))
	reqCtx := WithRequestID(context.Background(), "req-42")
	base.FromContext(reqCtx).Info("hi")
	record := decode(t, out.Bytes())
	if rid := record["rid"]; rid != "req-42" {
		t.Fatalf("expected rid=req-42, got %v", rid)
	}
}
// TestLevelFiltering checks that a logger configured with LevelWarn drops an
// Info record while still emitting an Error record.
func TestLevelFiltering(t *testing.T) {
	t.Parallel()
	var out bytes.Buffer
	logger := New(&out).WithLevel(LevelWarn)
	logger.now = fixedClock(time.Now())
	logger.Info("suppressed")
	logger.Error("kept")
	// Nothing writes to the buffer after this point, so one snapshot suffices.
	written := out.String()
	switch {
	case strings.Contains(written, "suppressed"):
		t.Fatalf("Info should have been dropped: %q", written)
	case !strings.Contains(written, "kept"):
		t.Fatalf("Error should have been emitted: %q", written)
	}
}
// TestLoggerConcurrentWritesAreLineSafe logs from many goroutines at once and
// checks that every record lands on its own line as intact JSON, with no
// record lost or duplicated. All goroutines share one bytes.Buffer, which is
// not safe for concurrent use by itself, so this test relies on the Logger
// serializing its writes; run it with -race to catch regressions.
func TestLoggerConcurrentWritesAreLineSafe(t *testing.T) {
	t.Parallel()
	const writers = 50
	var buf bytes.Buffer
	l := New(&buf)
	l.now = fixedClock(time.Now())
	var wg sync.WaitGroup
	for i := 0; i < writers; i++ {
		wg.Add(1)
		go func(i int) {
			defer wg.Done()
			l.Info("m", "i", i)
		}(i)
	}
	wg.Wait()
	lines := strings.Split(strings.TrimRight(buf.String(), "\n"), "\n")
	if len(lines) != writers {
		t.Fatalf("expected %d lines, got %d", writers, len(lines))
	}
	// Line count plus JSON validity alone would not notice one record being
	// dropped while another is duplicated, so also check that every
	// goroutine's index shows up exactly once.
	seen := make(map[int]bool, writers)
	for _, line := range lines {
		var rec map[string]any
		if err := json.Unmarshal([]byte(line), &rec); err != nil {
			t.Fatalf("line is not valid json: %q: %v", line, err)
		}
		// JSON numbers decode as float64.
		n, ok := rec["i"].(float64)
		if !ok {
			t.Fatalf("line has no numeric \"i\" field: %q", line)
		}
		idx := int(n)
		if seen[idx] {
			t.Fatalf("duplicate record for i=%d: %q", idx, line)
		}
		seen[idx] = true
	}
	for i := 0; i < writers; i++ {
		if !seen[i] {
			t.Errorf("no record emitted for i=%d", i)
		}
	}
}