package http_test
import (
stdhttp "net/http"
"net/http/httptest"
"testing"
mw "mercemay.top/src/lambdalog/middleware/http"
)
// stubLogger is a minimal logger test double that records every message it
// is given, in call order, in msgs.
type stubLogger struct {
	msgs []string
}

// Info records msg verbatim; any fields are ignored.
func (s *stubLogger) Info(msg string, _ ...mw.Field) {
	s.msgs = append(s.msgs, msg)
}

// Error records msg with an "err:" prefix so error entries can be told apart
// from Info entries; any fields are ignored.
func (s *stubLogger) Error(msg string, _ ...mw.Field) {
	s.msgs = append(s.msgs, "err:"+msg)
}

// With ignores the supplied fields and returns the receiver itself, so
// derived loggers keep recording into the same msgs slice.
func (s *stubLogger) With(_ ...mw.Field) mw.Logger {
	return s
}
// TestInjectLogger_Attaches verifies that a handler wrapped by
// mw.InjectLogger is actually invoked and can obtain a non-nil logger from
// the request context.
func TestInjectLogger_Attaches(t *testing.T) {
	l := &stubLogger{}

	// httptest.ResponseRecorder reports status 200 even when nothing was
	// written, so the status check below cannot by itself prove the wrapped
	// handler ran. Track the invocation explicitly.
	called := false

	h := mw.InjectLogger(l)(stdhttp.HandlerFunc(func(w stdhttp.ResponseWriter, r *stdhttp.Request) {
		called = true
		// With a nil fallback, a non-nil result presumably means the logger
		// came from the request context.
		if got := mw.LoggerFrom(r.Context(), nil); got == nil {
			t.Fatal("logger missing")
		}
		w.WriteHeader(stdhttp.StatusOK)
	}))

	rec := httptest.NewRecorder()
	req := httptest.NewRequest(stdhttp.MethodGet, "/", nil)
	h.ServeHTTP(rec, req)

	if !called {
		t.Fatal("wrapped handler was not invoked")
	}
	if rec.Code != stdhttp.StatusOK {
		t.Fatalf("code %d", rec.Code)
	}
}
// TestLoggerFrom_Fallback checks that mw.LoggerFrom hands back the supplied
// fallback logger when it is called with a nil context.
func TestLoggerFrom_Fallback(t *testing.T) {
	fallback := &stubLogger{}

	// NOTE(review): a nil context is passed here, presumably to exercise
	// nil-safety in LoggerFrom — confirm before replacing it with
	// context.Background().
	if got := mw.LoggerFrom(nil, fallback); got != fallback {
		t.Fatalf("fallback not returned")
	}
}
// TestChain_AppliesInOrder checks the nesting order produced by mw.Chain:
// the first middleware passed must be the outermost wrapper, so its "before"
// step runs first and its "after" step runs last around the inner handler.
func TestChain_AppliesInOrder(t *testing.T) {
	var trace []string
	record := func(step string) { trace = append(trace, step) }

	// tagged builds a middleware that records "<label>-before" and
	// "<label>-after" around the call to the next handler.
	tagged := func(label string) func(stdhttp.Handler) stdhttp.Handler {
		return func(next stdhttp.Handler) stdhttp.Handler {
			wrapped := func(w stdhttp.ResponseWriter, r *stdhttp.Request) {
				record(label + "-before")
				next.ServeHTTP(w, r)
				record(label + "-after")
			}
			return stdhttp.HandlerFunc(wrapped)
		}
	}

	innermost := stdhttp.HandlerFunc(func(stdhttp.ResponseWriter, *stdhttp.Request) {
		record("inner")
	})

	handler := mw.Chain(tagged("a"), tagged("b"))(innermost)
	handler.ServeHTTP(httptest.NewRecorder(), httptest.NewRequest("GET", "/", nil))

	expected := []string{"a-before", "b-before", "inner", "b-after", "a-after"}
	// Guard the length first so the element-wise loop cannot index out of range.
	if len(trace) != len(expected) {
		t.Fatalf("order: %v", trace)
	}
	for i, step := range expected {
		if trace[i] != step {
			t.Fatalf("order: %v, want %v", trace, expected)
		}
	}
}
// TestCounter exercises a zero-value mw.Counter: after two Inc calls Load
// must report 2, and after Reset it must report 0.
func TestCounter(t *testing.T) {
	var counter mw.Counter

	for i := 0; i < 2; i++ {
		counter.Inc()
	}
	if got := counter.Load(); got != 2 {
		t.Fatalf("load: %d", got)
	}

	counter.Reset()
	if got := counter.Load(); got != 0 {
		t.Fatalf("reset failed")
	}
}