internal/parser/parser_test.go

package parser

import (
	"bytes"
	"io"
	"strings"
	"testing"

	"github.com/google/go-cmp/cmp"
)

// TestParserSniff checks that sniff classifies the opening bytes of a
// stream: an HTTP/1.1 request line selects mode "h1", and the HTTP/2
// connection preface ("PRI * HTTP/2.0...") selects mode "h2".
func TestParserSniff(t *testing.T) {
	tests := []struct {
		name     string
		input    string
		wantMode string
	}{
		{name: "http1_request", input: "GET / HTTP/1.1\r\nHost: x\r\n\r\n", wantMode: "h1"},
		{name: "http2_preface", input: "PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n", wantMode: "h2"},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			// No messages are expected from sniffing alone; drop any that arrive.
			discard := func(Message) {}
			p := New(strings.NewReader(tt.input), discard)

			err := p.sniff()
			if err != nil {
				t.Fatalf("sniff: %v", err)
			}
			if got := p.mode; got != tt.wantMode {
				t.Fatalf("mode = %q, want %q", got, tt.wantMode)
			}
		})
	}
}

// TestParserRunHTTP1 feeds Run one HTTP/1.1 request followed by its
// response and checks that both reach the callback, in stream order, with
// the expected start lines. An io.EOF from Run is accepted as the normal
// end of input.
func TestParserRunHTTP1(t *testing.T) {
	raw := "GET /a HTTP/1.1\r\nHost: example.com\r\n\r\n" +
		"HTTP/1.1 200 OK\r\nContent-Length: 2\r\n\r\nhi"
	wantLines := []string{"GET /a HTTP/1.1", "HTTP/1.1 200 OK"}

	var got []Message
	p := New(strings.NewReader(raw), func(m Message) {
		got = append(got, m)
	})
	if err := p.Run(); err != nil && err != io.EOF {
		t.Fatalf("run: %v", err)
	}
	// Derive the expected count from wantLines so the two cannot drift, and
	// dump what was parsed so a count mismatch is diagnosable.
	if len(got) != len(wantLines) {
		t.Fatalf("got %d messages, want %d: %+v", len(got), len(wantLines), got)
	}
	for i, m := range got {
		if m.StartLine != wantLines[i] {
			t.Errorf("msg[%d]=%q want %q", i, m.StartLine, wantLines[i])
		}
	}
}

// TestParserFlushIncomplete checks that a message truncated mid-body is
// surfaced as incomplete. The input declares a 10-byte body but supplies
// only 2 bytes ("hi") before the stream ends. After Run and Flush, exactly
// one message must have been delivered, and it must be marked Incomplete.
func TestParserFlushIncomplete(t *testing.T) {
	raw := "POST /x HTTP/1.1\r\nContent-Length: 10\r\n\r\nhi"
	var got []Message
	p := New(strings.NewReader(raw), func(m Message) {
		got = append(got, m)
	})
	// Run may return an error on truncated input. This test does not pin
	// down which one, so log it for diagnosis rather than discard it silently.
	if err := p.Run(); err != nil {
		t.Logf("run: %v", err)
	}
	p.Flush()
	if len(got) != 1 {
		t.Fatalf("got %d messages, want 1: %+v", len(got), got)
	}
	if !got[0].Incomplete {
		t.Fatalf("message not marked Incomplete: %+v", got[0])
	}
}

// TestMessageDiff checks that two independently constructed Message values
// with the same start line and headers produce an empty cmp.Diff.
func TestMessageDiff(t *testing.T) {
	// Build each value from scratch so the two do not share a Headers slice.
	build := func() Message {
		return Message{
			StartLine: "GET /",
			Headers:   [][2]string{{"Host", "x"}},
		}
	}
	first, second := build(), build()

	diff := cmp.Diff(first, second)
	if diff != "" {
		t.Fatalf("unexpected diff:\n%s", diff)
	}
}

// BenchmarkParserRun measures constructing a parser and running it over a
// stream of 32 back-to-back HTTP/1.1 GET requests. It reports throughput
// (bytes of input per op) and allocations. The input is built and the
// reader allocated before the timer starts, so per-iteration cost covers
// the parser rather than the test scaffolding. An io.EOF from Run is
// treated as normal end of input.
func BenchmarkParserRun(b *testing.B) {
	raw := bytes.Repeat([]byte("GET / HTTP/1.1\r\nHost: h\r\n\r\n"), 32)
	r := bytes.NewReader(raw)
	b.SetBytes(int64(len(raw)))
	b.ReportAllocs()
	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		// Rewind the shared reader instead of allocating a new one each time;
		// parsers run sequentially, so only one reads it at a time.
		r.Reset(raw)
		p := New(r, func(Message) {})
		if err := p.Run(); err != nil && err != io.EOF {
			b.Fatal(err)
		}
	}
}