// internal/parser/http1/headers_test.go

package http1

import (
	"bufio"
	"strings"
	"testing"

	"github.com/google/go-cmp/cmp"
)

// TestReadHeaders feeds raw header blocks through ReadHeaders and compares
// the returned name/value pairs, in order, against the expected slice.
func TestReadHeaders(t *testing.T) {
	type headerCase struct {
		name string
		raw  string
		want [][2]string
	}
	tests := []headerCase{
		{
			name: "plain",
			raw:  "Host: example.com\r\nUser-Agent: curl/8.0\r\n\r\n",
			want: [][2]string{
				{"Host", "example.com"},
				{"User-Agent", "curl/8.0"},
			},
		},
		// Lines starting with a space or tab continue the previous value;
		// the pieces are expected back joined by single spaces.
		{
			name: "fold_continuation",
			raw:  "X-Long: a\r\n b\r\n\tc\r\n\r\n",
			want: [][2]string{
				{"X-Long", "a b c"},
			},
		},
		// Field names are expected back in canonical capitalization.
		{
			name: "mixed_case_key",
			raw:  "content-TYPE: application/json\r\n\r\n",
			want: [][2]string{
				{"Content-Type", "application/json"},
			},
		},
		// Repeated names stay separate entries, in input order.
		{
			name: "duplicate_keys_preserved",
			raw:  "Set-Cookie: a=1\r\nSet-Cookie: b=2\r\n\r\n",
			want: [][2]string{
				{"Set-Cookie", "a=1"},
				{"Set-Cookie", "b=2"},
			},
		},
	}
	for i := range tests {
		tt := tests[i]
		t.Run(tt.name, func(t *testing.T) {
			got, err := ReadHeaders(bufio.NewReader(strings.NewReader(tt.raw)))
			if err != nil {
				t.Fatalf("ReadHeaders: %v", err)
			}
			diff := cmp.Diff(tt.want, got)
			if diff != "" {
				t.Errorf("mismatch (-want +got):\n%s", diff)
			}
		})
	}
}

// TestReadHeadersErrors checks that ReadHeaders rejects malformed header
// blocks. Only the presence of an error is asserted, not its type or text.
//
// Cases carry explicit snake_case names (matching TestReadHeaders) instead
// of using the raw input as the subtest name. Raw inputs contain CR/LF and
// leading spaces, which `go test` rewrites into hard-to-target names.
func TestReadHeadersErrors(t *testing.T) {
	cases := []struct {
		name string
		raw  string
	}{
		// A header line with no ':' separator.
		{name: "no_colon", raw: "no-colon-line\r\n\r\n"},
		// A ':' with nothing before it, i.e. an empty field name.
		{name: "empty_name", raw: ":empty-name\r\n\r\n"},
		// A continuation (leading-whitespace) line with no preceding header.
		{name: "fold_first", raw: " fold-first\r\n\r\n"},
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			r := bufio.NewReader(strings.NewReader(tc.raw))
			if _, err := ReadHeaders(r); err == nil {
				t.Fatalf("ReadHeaders(%q): expected error, got nil", tc.raw)
			}
		})
	}
}

// TestGetAndGetAll exercises the lookup helpers on a small header list.
// Get is queried with a lower-case name while the stored key is canonical.
// GetAll is expected to return every value for a repeated name, in order.
func TestGetAndGetAll(t *testing.T) {
	headers := [][2]string{
		{"Set-Cookie", "a"},
		{"Content-Type", "text/plain"},
		{"Set-Cookie", "b"},
	}

	contentType := Get(headers, "content-type")
	if contentType != "text/plain" {
		t.Errorf("Get = %q", contentType)
	}

	cookies := GetAll(headers, "Set-Cookie")
	wantCookies := []string{"a", "b"}
	if !cmp.Equal(cookies, wantCookies) {
		t.Errorf("GetAll = %#v", cookies)
	}
}

// TestCanonicalKey checks CanonicalKey's output for a fixed set of header
// names. Cases live in a slice rather than a map so they run, and report
// failures, in a stable order. Each case is a named subtest for `-run`.
func TestCanonicalKey(t *testing.T) {
	tests := []struct {
		in, want string
	}{
		{"host", "Host"},
		{"CONTENT-LENGTH", "Content-Length"},
		{"x-forwarded-for", "X-Forwarded-For"},
		// Doubled hyphens are expected to be preserved as-is.
		{"x-ray--marker", "X-Ray--Marker"},
	}
	for _, tc := range tests {
		t.Run(tc.in, func(t *testing.T) {
			if got := CanonicalKey(tc.in); got != tc.want {
				t.Errorf("CanonicalKey(%q)=%q want %q", tc.in, got, tc.want)
			}
		})
	}
}