// internal/render/markdown/renderer.go

package markdown

import (
	"bytes"
	"fmt"
	"html"
	"io"
	"strings"
	"unicode"

	"github.com/yuin/goldmark/ast"
	extast "github.com/yuin/goldmark/extension/ast"
)

// Renderer walks a parsed Document and writes HTML. It intentionally
// re-implements the render loop instead of delegating to goldmark's html
// renderer so that tilstream can inject section wrappers, footnote anchors,
// and custom code-block handlers.
type Renderer struct {
	// CodeBlockHandler, when set, is called for each fenced code block with
	// the fence's language, the concatenated body lines, and the output
	// writer. A non-nil error from it stops the render and is returned to
	// the caller of Render.
	CodeBlockHandler func(lang string, body []byte, w io.Writer) error
	// HeadingPrefix is prepended to the id attribute written on every
	// heading element.
	HeadingPrefix    string
}

// NewRenderer builds a Renderer preconfigured with the package defaults:
// defaultCodeBlock handles fenced code blocks and heading ids are prefixed
// with "til-".
func NewRenderer() *Renderer {
	r := new(Renderer)
	r.HeadingPrefix = "til-"
	r.CodeBlockHandler = defaultCodeBlock
	return r
}

// Render walks doc's tree starting at doc.Root and writes the resulting HTML
// to w. A nil doc, or a doc whose Root is nil, is rejected with an error;
// otherwise the outcome of the walk is returned.
func (r *Renderer) Render(w io.Writer, doc *Document) error {
	if doc == nil || doc.Root == nil {
		return fmt.Errorf("renderer: nil document")
	}

	// Bind the writer and document so the walker matches ast.Walk's callback
	// signature.
	walker := func(node ast.Node, entering bool) (ast.WalkStatus, error) {
		return r.visit(w, doc, node, entering)
	}
	return ast.Walk(doc.Root, walker)
}

// RenderString renders doc into an in-memory buffer and returns the HTML as
// a string. It is a convenience used in tests and summaries. On failure the
// returned string is empty and the error from Render is passed through.
func (r *Renderer) RenderString(doc *Document) (string, error) {
	out := new(bytes.Buffer)
	err := r.Render(out, doc)
	if err != nil {
		return "", err
	}
	return out.String(), nil
}

// visit is the ast.Walk callback behind Render. It is called with
// entering=true when the walk reaches a node and with entering=false when the
// walk leaves it, and writes the opening or closing markup accordingly. Node
// kinds without a case in the switch emit nothing themselves and the walk
// simply continues.
//
// Text content, link destinations, and heading ids are HTML-escaped before
// being written. The first failed write is reported (together with
// ast.WalkStop) so that Render does not return nil after producing truncated
// output. Errors from the code-block handler are returned unchanged.
func (r *Renderer) visit(w io.Writer, d *Document, n ast.Node, entering bool) (ast.WalkStatus, error) {
	// werr latches the first write failure; once set, emit becomes a no-op.
	var werr error
	emit := func(s string) {
		if werr == nil {
			_, werr = io.WriteString(w, s)
		}
	}
	// wrap emits open when entering the node and closing when leaving it.
	wrap := func(open, closing string) {
		if entering {
			emit(open)
		} else {
			emit(closing)
		}
	}

	switch node := n.(type) {
	case *ast.Document:
		// nothing to do
	case *ast.Paragraph:
		wrap("<p>", "</p>\n")
	case *ast.Heading:
		if entering {
			// The id may come from a user-written attribute, so escape it
			// (and the configurable prefix) before placing it in the tag.
			id := html.EscapeString(r.HeadingPrefix + headingID(d, node))
			emit(fmt.Sprintf(`<h%d id="%s">`, node.Level, id))
		} else {
			emit(fmt.Sprintf("</h%d>\n", node.Level))
		}
	case *ast.Text:
		if entering {
			seg := node.Segment
			emit(html.EscapeString(string(seg.Value(d.Source))))
			switch {
			case node.HardLineBreak():
				emit("<br>")
			case node.SoftLineBreak():
				emit("\n")
			}
		}
	case *ast.Emphasis:
		if node.Level == 2 {
			wrap("<strong>", "</strong>")
		} else {
			wrap("<em>", "</em>")
		}
	case *ast.Link:
		// NOTE(review): the destination is escaped but its scheme is not
		// checked (e.g. "javascript:"); confirm the input is trusted or add
		// filtering upstream.
		if entering {
			emit(`<a href="` + html.EscapeString(string(node.Destination)) + `">`)
		} else {
			emit("</a>")
		}
	case *ast.CodeSpan:
		wrap("<code>", "</code>")
	case *ast.FencedCodeBlock:
		if entering {
			// A Renderer built without NewRenderer has a nil handler; fall
			// back to the default instead of silently dropping the block.
			handler := r.CodeBlockHandler
			if handler == nil {
				handler = defaultCodeBlock
			}
			lang := string(node.Language(d.Source))
			body := extractLines(d.Source, node.Lines())
			if err := handler(lang, body, w); err != nil {
				return ast.WalkStop, err
			}
			// The handler has already written the whole block.
			return ast.WalkSkipChildren, nil
		}
	case *ast.List:
		if node.IsOrdered() {
			wrap("<ol>\n", "</ol>\n")
		} else {
			wrap("<ul>\n", "</ul>\n")
		}
	case *ast.ListItem:
		wrap("<li>", "</li>\n")
	case *ast.Blockquote:
		wrap("<blockquote>", "</blockquote>\n")
	case *extast.Strikethrough:
		wrap("<del>", "</del>")
	}

	if werr != nil {
		return ast.WalkStop, fmt.Errorf("renderer: write: %w", werr)
	}
	return ast.WalkContinue, nil
}

// defaultCodeBlock writes body to w as an HTML-escaped <pre><code> element.
// When lang is non-empty it is escaped and added as a "language-<lang>"
// class; when lang is empty no class attribute is written. The first write
// error is returned, wrapped, and nothing further is written after it.
func defaultCodeBlock(lang string, body []byte, w io.Writer) error {
	open := "<pre><code>"
	if lang != "" {
		open = `<pre><code class="language-` + html.EscapeString(lang) + `">`
	}

	parts := [...]string{open, html.EscapeString(string(body)), "</code></pre>\n"}
	for _, part := range parts {
		if _, err := io.WriteString(w, part); err != nil {
			return fmt.Errorf("renderer: write code block: %w", err)
		}
	}
	return nil
}

// extractLines returns the bytes of every segment in lines, each read out of
// src and joined back to back in order. A nil lines yields nil.
//
// NOTE(review): upstream goldmark declares the segment list type as
// text.Segments in its text package; confirm that ast.TextSegments resolves
// in this build.
func extractLines(src []byte, lines *ast.TextSegments) []byte {
	if lines == nil {
		return nil
	}
	var out []byte
	for i, n := 0, lines.Len(); i < n; i++ {
		seg := lines.At(i)
		out = append(out, seg.Value(src)...)
	}
	return out
}

// headingID returns the id to use for heading h. An explicit "id" attribute
// on the node wins when it holds a []byte or a string; attribute values are
// untyped, so any other type is ignored rather than causing a panic. Without
// a usable attribute the id is the slug of d.TextOf(h) (presumably the
// heading's plain text; TextOf is defined outside this file).
func headingID(d *Document, h *ast.Heading) string {
	if raw, ok := h.AttributeString("id"); ok {
		switch v := raw.(type) {
		case []byte:
			return string(v)
		case string:
			return v
		}
		// Unsupported attribute type: fall through to the derived slug.
	}
	return slug(d.TextOf(h))
}

// slug converts s into an identifier suitable for a heading id. The input is
// trimmed and lower-cased; letters, digits, and combining marks of any script
// are kept; each space, hyphen, or underscore becomes a single '-'; every
// other rune is dropped. Separators are not collapsed, so ASCII input yields
// exactly the same slug as before Unicode support was added.
func slug(s string) string {
	var b strings.Builder
	for _, r := range strings.ToLower(strings.TrimSpace(s)) {
		switch {
		case unicode.IsLetter(r), unicode.IsDigit(r), unicode.IsMark(r):
			// Keeping non-ASCII letters prevents headings in other scripts
			// from collapsing to an empty (and therefore colliding) id.
			b.WriteRune(r)
		case r == ' ', r == '-', r == '_':
			b.WriteByte('-')
		}
	}
	return b.String()
}