// middleware/http/middleware_test.go

package http_test

import (
	stdhttp "net/http"
	"net/http/httptest"
	"testing"

	mw "mercemay.top/src/lambdalog/middleware/http"
)

// stubLogger is a minimal in-memory mw.Logger for these tests. It records
// every message it receives, prefixing Error messages with "err:" so the two
// levels can be told apart. Structured fields are ignored.
type stubLogger struct {
	msgs []string // recorded messages, in call order
}

// record appends a single entry to the message buffer.
func (s *stubLogger) record(m string) {
	s.msgs = append(s.msgs, m)
}

// Info records m unchanged; fields are discarded.
func (s *stubLogger) Info(m string, _ ...mw.Field) {
	s.record(m)
}

// Error records m with an "err:" prefix; fields are discarded.
func (s *stubLogger) Error(m string, _ ...mw.Field) {
	s.record("err:" + m)
}

// With discards the fields and returns the receiver itself, so any derived
// logger shares the same message buffer.
func (s *stubLogger) With(_ ...mw.Field) mw.Logger {
	return s
}

// TestInjectLogger_Attaches verifies that InjectLogger makes a logger
// retrievable via LoggerFrom inside the wrapped handler, and that the wrapped
// handler is actually invoked.
func TestInjectLogger_Attaches(t *testing.T) {
	l := &stubLogger{}
	called := false
	h := mw.InjectLogger(l)(stdhttp.HandlerFunc(func(w stdhttp.ResponseWriter, r *stdhttp.Request) {
		called = true
		// A nil fallback is passed so a non-nil result is attributed to the
		// context (assumes LoggerFrom returns the fallback verbatim when no
		// logger is stored — confirm against the implementation).
		if got := mw.LoggerFrom(r.Context(), nil); got == nil {
			t.Fatal("logger missing")
		}
		w.WriteHeader(stdhttp.StatusOK)
	}))
	rec := httptest.NewRecorder()
	req := httptest.NewRequest(stdhttp.MethodGet, "/", nil)
	h.ServeHTTP(rec, req)
	// httptest.NewRecorder initialises Code to 200, so the status check below
	// would pass even if the middleware never called the inner handler (and
	// the logger assertion inside it would never run). Check explicitly.
	if !called {
		t.Fatal("inner handler was not invoked")
	}
	if rec.Code != stdhttp.StatusOK {
		t.Fatalf("code %d", rec.Code)
	}
}

// TestLoggerFrom_Fallback verifies that LoggerFrom returns the supplied
// fallback when the context carries no logger, both for an ordinary context
// that InjectLogger never touched and for a nil context.
func TestLoggerFrom_Fallback(t *testing.T) {
	l := &stubLogger{}

	// An ordinary request context with no logger stored in it: the normal
	// "middleware not installed" case.
	ctx := httptest.NewRequest(stdhttp.MethodGet, "/", nil).Context()
	if got := mw.LoggerFrom(ctx, l); got != l {
		t.Fatalf("empty ctx: fallback not returned")
	}

	// A nil context: pins the nil-safety the original version of this test
	// relied on. Passing nil is deliberate here.
	//lint:ignore SA1012 nil context is the case under test
	if got := mw.LoggerFrom(nil, l); got != l {
		t.Fatalf("nil ctx: fallback not returned")
	}
}

// TestChain_AppliesInOrder checks that Chain nests middleware with its first
// argument outermost: "before" hooks fire in argument order on the way in,
// and "after" hooks fire in reverse order on the way out.
func TestChain_AppliesInOrder(t *testing.T) {
	var trace []string
	record := func(step string) { trace = append(trace, step) }

	// layer builds a middleware that records entry and exit around next.
	layer := func(tag string) func(stdhttp.Handler) stdhttp.Handler {
		return func(next stdhttp.Handler) stdhttp.Handler {
			return stdhttp.HandlerFunc(func(w stdhttp.ResponseWriter, r *stdhttp.Request) {
				record(tag + "-before")
				next.ServeHTTP(w, r)
				record(tag + "-after")
			})
		}
	}

	inner := stdhttp.HandlerFunc(func(stdhttp.ResponseWriter, *stdhttp.Request) {
		record("inner")
	})
	handler := mw.Chain(layer("a"), layer("b"))(inner)
	handler.ServeHTTP(httptest.NewRecorder(), httptest.NewRequest("GET", "/", nil))

	expected := []string{"a-before", "b-before", "inner", "b-after", "a-after"}
	if len(trace) != len(expected) {
		t.Fatalf("order: %v", trace)
	}
	for i, step := range expected {
		if trace[i] != step {
			t.Fatalf("order: %v, want %v", trace, expected)
		}
	}
}

// TestCounter exercises the Counter lifecycle: the zero value reads 0, two
// Inc calls read back as 2, and Reset returns the count to 0.
func TestCounter(t *testing.T) {
	var c mw.Counter
	// The "== 2" check below is only meaningful if the zero value starts at 0.
	if got := c.Load(); got != 0 {
		t.Fatalf("zero value: load %d, want 0", got)
	}
	c.Inc()
	c.Inc()
	// Load once so the value reported on failure is the value that was compared.
	if got := c.Load(); got != 2 {
		t.Fatalf("load: %d", got)
	}
	c.Reset()
	if got := c.Load(); got != 0 {
		t.Fatalf("reset failed: load %d, want 0", got)
	}
}