// internal/render/markdown/extensions/toc.go

package extensions

import (
	"bytes"
	"fmt"
	"html"
	"io"
	"strings"

	"github.com/yuin/goldmark/ast"
)

// TOCEntry is a single heading in the generated table of contents.
type TOCEntry struct {
	Level    int         // heading level copied from the heading node; the synthetic root uses 0
	Text     string      // heading text, stored unescaped (HTML-escaped when rendered)
	ID       string      // anchor the rendered link points at, emitted after "#" in the href
	Children []*TOCEntry // entries nested beneath this one
}

// TOC is a tree of entries. The zero value is a valid empty TOC.
//
// Root is a synthetic entry that is never rendered itself; its Children are
// the top-level entries of the table of contents.
type TOC struct {
	Root TOCEntry
}

// BuildTOC walks the AST rooted at root and builds a nested TOC from its
// headings.
//
// Only headings whose level lies in [minLevel, maxLevel] are included; all
// other headings are skipped. A minLevel <= 0 is treated as 1, and a maxLevel
// that is <= 0 or greater than 6 is treated as 6. A nil root yields an empty
// TOC.
//
// Each entry's ID is the heading's "id" attribute when one is set (stored as
// either []byte or string); otherwise it is a slug derived from the heading
// text. Every entry is nested under the nearest preceding entry that has a
// smaller level, or under the root when there is none.
func BuildTOC(root ast.Node, source []byte, minLevel, maxLevel int) *TOC {
	if minLevel <= 0 {
		minLevel = 1
	}
	if maxLevel <= 0 || maxLevel > 6 {
		maxLevel = 6
	}
	toc := &TOC{Root: TOCEntry{Level: 0}}
	if root == nil {
		return toc
	}
	// stack holds the chain of currently open ancestors; stack[0] is always
	// the synthetic root and is never popped.
	stack := []*TOCEntry{&toc.Root}

	// The callback never returns a non-nil error, so Walk's error result is
	// deliberately discarded.
	_ = ast.Walk(root, func(n ast.Node, entering bool) (ast.WalkStatus, error) {
		if !entering {
			return ast.WalkContinue, nil
		}
		h, ok := n.(*ast.Heading)
		if !ok {
			return ast.WalkContinue, nil
		}
		if h.Level < minLevel || h.Level > maxLevel {
			return ast.WalkSkipChildren, nil
		}
		text := extractHeadingText(h, source)
		id := slugify(text)
		if attr, ok := h.AttributeString("id"); ok {
			// Attribute values are untyped, so a blind []byte assertion would
			// panic on an ID stored as a string. Any other type falls back to
			// the slug computed above.
			switch v := attr.(type) {
			case []byte:
				id = string(v)
			case string:
				id = v
			}
		}
		entry := &TOCEntry{Level: h.Level, Text: text, ID: id}
		// Close every open entry that is at the same or a deeper level, so
		// the new entry attaches to the nearest shallower ancestor.
		for len(stack) > 1 && stack[len(stack)-1].Level >= h.Level {
			stack = stack[:len(stack)-1]
		}
		parent := stack[len(stack)-1]
		parent.Children = append(parent.Children, entry)
		stack = append(stack, entry)
		return ast.WalkSkipChildren, nil
	})
	return toc
}

// Render emits the table of contents to w as nested <ul> lists.
// A nil receiver or a TOC without any entries writes nothing and returns nil;
// otherwise the result of rendering the top-level entries is returned.
func (t *TOC) Render(w io.Writer) error {
	if t == nil {
		return nil
	}
	top := t.Root.Children
	if len(top) == 0 {
		return nil
	}
	return renderEntries(w, top)
}

// RenderString renders the TOC into an in-memory buffer and returns the
// resulting markup. It returns "" whenever Render writes nothing.
func (t *TOC) RenderString() string {
	out := new(bytes.Buffer)
	// The error from Render is intentionally dropped; the only writer
	// involved is the in-memory buffer.
	_ = t.Render(out)
	return out.String()
}

// renderEntries writes entries to w as a single <ul class="toc"> list,
// recursing into each entry's Children to emit a nested list inside its <li>.
//
// Both the link text and the ID are HTML-escaped. The ID may be taken
// verbatim from a heading's author-supplied "id" attribute, so writing it raw
// into the href would let a crafted ID break out of the attribute and inject
// markup.
//
// It returns the first error reported by w and stops writing at that point.
func renderEntries(w io.Writer, entries []*TOCEntry) error {
	if _, err := io.WriteString(w, "<ul class=\"toc\">\n"); err != nil {
		return err
	}
	for _, e := range entries {
		if _, err := fmt.Fprintf(w, `  <li><a href="#%s">%s</a>`, html.EscapeString(e.ID), html.EscapeString(e.Text)); err != nil {
			return err
		}
		if len(e.Children) > 0 {
			// The nested list starts on its own line, inside the still-open <li>.
			if _, err := io.WriteString(w, "\n"); err != nil {
				return err
			}
			if err := renderEntries(w, e.Children); err != nil {
				return err
			}
		}
		if _, err := io.WriteString(w, "</li>\n"); err != nil {
			return err
		}
	}
	_, err := io.WriteString(w, "</ul>\n")
	return err
}

// extractHeadingText returns the plain text of heading h, read from src, with
// leading and trailing whitespace trimmed.
//
// Text is collected from every descendant of h rather than only its direct
// children, so text wrapped in inline markup (emphasis, links, code spans) is
// kept instead of being silently dropped. A line break after a text node
// contributes a single space so that adjacent lines do not run together.
func extractHeadingText(h *ast.Heading, src []byte) string {
	var buf bytes.Buffer
	// The callback never returns a non-nil error, so Walk's error result is
	// deliberately discarded.
	_ = ast.Walk(h, func(n ast.Node, entering bool) (ast.WalkStatus, error) {
		if !entering {
			return ast.WalkContinue, nil
		}
		switch t := n.(type) {
		case *ast.Text:
			buf.Write(t.Segment.Value(src))
			if t.SoftLineBreak() || t.HardLineBreak() {
				buf.WriteByte(' ')
			}
		case *ast.String:
			// String nodes carry their content directly instead of pointing
			// into src.
			buf.Write(t.Value)
		}
		return ast.WalkContinue, nil
	})
	return strings.TrimSpace(buf.String())
}

// slugify converts s into an ASCII anchor slug.
//
// The input is trimmed and lower-cased; the letters a-z and digits 0-9 are
// kept, each run of spaces, hyphens and underscores collapses into a single
// "-", and every other character is dropped. Leading and trailing hyphens are
// removed from the result, which may therefore be empty.
func slugify(s string) string {
	normalized := strings.ToLower(strings.TrimSpace(s))

	var out strings.Builder
	out.Grow(len(normalized))

	prevSep := false
	for _, ch := range normalized {
		if ('a' <= ch && ch <= 'z') || ('0' <= ch && ch <= '9') {
			out.WriteRune(ch)
			prevSep = false
			continue
		}
		if ch != ' ' && ch != '-' && ch != '_' {
			// Dropped characters do not end a separator run.
			continue
		}
		if prevSep {
			continue
		}
		out.WriteByte('-')
		prevSep = true
	}
	return strings.Trim(out.String(), "-")
}

// Flatten returns all TOC entries as a flat slice in pre-order: each entry is
// followed by its descendants before its next sibling. The synthetic root is
// not included.
//
// A nil receiver or an empty TOC returns a nil slice, matching the nil-safety
// of Render.
func (t *TOC) Flatten() []*TOCEntry {
	if t == nil {
		return nil
	}
	var out []*TOCEntry
	var walk func([]*TOCEntry)
	walk = func(list []*TOCEntry) {
		for _, e := range list {
			out = append(out, e)
			walk(e.Children)
		}
	}
	walk(t.Root.Children)
	return out
}