internal/parser/http1/headers.go

package http1

import (
	"bufio"
	"errors"
	"fmt"
	"io"
	"strings"
)

// ErrMalformedHeader is returned when a header line cannot be decoded.
var ErrMalformedHeader = errors.New("http1: malformed header")

// maxHeaderBytes is the per-block cap; at the limit we return an error so
// a misbehaving peer cannot drive the tap to OOM.
const maxHeaderBytes = 64 * 1024

// ReadHeaders consumes the header block up to and including the empty
// CRLF terminator. Obsolete line folding ("obs-fold") is re-joined onto
// the preceding header value.
func ReadHeaders(r *bufio.Reader) ([][2]string, error) {
	var (
		out   [][2]string
		used  int
		last  int // index in out of the most recently appended header
	)
	last = -1
	for {
		line, err := r.ReadString('\n')
		used += len(line)
		if err != nil && err != io.EOF {
			return nil, fmt.Errorf("%w: %w", ErrMalformedHeader, err)
		}
		if used > maxHeaderBytes {
			return nil, fmt.Errorf("%w: header block exceeds %d bytes",
				ErrMalformedHeader, maxHeaderBytes)
		}
		stripped := strings.TrimRight(line, "\r\n")
		if stripped == "" {
			return out, nil
		}
		// obs-fold: starts with SP or HT -> continuation of previous.
		if stripped[0] == ' ' || stripped[0] == '\t' {
			if last < 0 {
				return nil, fmt.Errorf("%w: fold before any header", ErrMalformedHeader)
			}
			out[last][1] = out[last][1] + " " + strings.TrimSpace(stripped)
			continue
		}
		k, v, ok := splitHeader(stripped)
		if !ok {
			return nil, fmt.Errorf("%w: %q", ErrMalformedHeader, stripped)
		}
		out = append(out, [2]string{CanonicalKey(k), v})
		last = len(out) - 1
		if err == io.EOF {
			return out, nil
		}
	}
}

// splitHeader splits a header line on its first colon into name and value.
//
// SP/HTAB between the name and the colon is tolerated and dropped.
// SP/HTAB on both sides of the value is trimmed, because leading and trailing
// optional whitespace is not part of an HTTP field value. ok is false when
// there is no colon, when the name is empty, or when isToken rejects the name.
func splitHeader(line string) (string, string, bool) {
	rawName, rawValue, found := strings.Cut(line, ":")
	if !found {
		return "", "", false
	}
	name := strings.TrimRight(rawName, " \t")
	if name == "" || !isToken(name) {
		return "", "", false
	}
	return name, strings.Trim(rawValue, " \t"), true
}

// CanonicalKey returns k in MIME-canonical casing, e.g. "content-type" ->
// "Content-Type".
//
// k is treated as a sequence of words separated by '-'. The first byte of
// each word is upper-cased if it is an ASCII letter; every other ASCII letter
// is lower-cased. All other bytes, including non-ASCII ones, are copied
// through unchanged. Keys containing non-ASCII bytes are still re-cased, so
// "\xc3\xa9TAG" becomes "\xc3\xa9tag".
func CanonicalKey(k string) string {
	const caseDelta = 'a' - 'A'

	var sb strings.Builder
	sb.Grow(len(k))
	startOfWord := true
	for _, c := range []byte(k) {
		out := c
		if startOfWord {
			if 'a' <= c && c <= 'z' {
				out = c - caseDelta
			}
		} else if 'A' <= c && c <= 'Z' {
			out = c + caseDelta
		}
		sb.WriteByte(out)
		startOfWord = c == '-'
	}
	return sb.String()
}

// Get returns the value of the first field in h whose name equals
// CanonicalKey(key), or "" if no field matches. An empty result cannot be
// told apart from a field that is present with an empty value.
//
// Names in h are compared verbatim, so h should already hold canonical names
// (ReadHeaders stores them that way). The lookup is a linear scan. h is an
// ordered slice rather than a map so that field order is preserved.
func Get(h [][2]string, key string) string {
	want := CanonicalKey(key)
	for i := range h {
		if name, value := h[i][0], h[i][1]; name == want {
			return value
		}
	}
	return ""
}

// GetAll returns the values of every field in h whose name equals
// CanonicalKey(key), in the order they appear in h. It returns nil when
// nothing matches.
//
// Names in h are compared verbatim, so h should already hold canonical names.
func GetAll(h [][2]string, key string) []string {
	want := CanonicalKey(key)
	var vals []string
	for i := range h {
		if h[i][0] != want {
			continue
		}
		vals = append(vals, h[i][1])
	}
	return vals
}