internal/render/markdown/extensions/tasklists.go

package extensions

import (
	"fmt"

	"github.com/yuin/goldmark"
	"github.com/yuin/goldmark/ast"
	"github.com/yuin/goldmark/parser"
	"github.com/yuin/goldmark/renderer"
	"github.com/yuin/goldmark/text"
	"github.com/yuin/goldmark/util"
)

// TaskListItem is an inline AST node attached at the start of a list item
// when the item begins with a "[ ]", "[x]" or "[X]" marker.
type TaskListItem struct {
	ast.BaseInline
	// Checked is true when the marker was "[x]" or "[X]".
	Checked bool
}

// KindTaskListItem is the ast.NodeKind shared by every TaskListItem node.
var KindTaskListItem = ast.NewNodeKind("TaskListItem")

// Kind satisfies ast.Node.
func (n *TaskListItem) Kind() ast.NodeKind { return KindTaskListItem }

// Dump satisfies ast.Node. The Checked flag is included so that AST dumps
// distinguish completed items from open ones.
func (n *TaskListItem) Dump(src []byte, level int) {
	ast.DumpHelper(n, src, level, map[string]string{
		"Checked": fmt.Sprintf("%t", n.Checked),
	}, nil)
}

// TaskLists is a goldmark extender that recognises task-list markers at the
// start of list items and renders them as disabled HTML checkboxes.
type TaskLists struct{}

// Extend registers the task-list inline parser (priority 0) and the
// TaskListItem node renderer (priority 500) on m.
func (TaskLists) Extend(m goldmark.Markdown) {
	inlineParser := util.Prioritized(&taskParser{}, 0)
	nodeRenderer := util.Prioritized(&taskRenderer{}, 500)

	m.Parser().AddOptions(parser.WithInlineParsers(inlineParser))
	m.Renderer().AddOptions(renderer.WithNodeRenderers(nodeRenderer))
}

// taskParser is the inline parser that turns a leading task-list marker in a
// list item into a *TaskListItem node.
type taskParser struct{}

// Trigger reports the byte ('[') that makes goldmark consult this parser.
func (p *taskParser) Trigger() []byte { return []byte{'['} }

// Parse recognises a task-list marker. It only fires when all of these hold:
//   - parent is an inline container (TextBlock for tight lists, Paragraph
//     for loose lists) whose parent is an *ast.ListItem,
//   - parent is that list item's first child,
//   - parent has no inline children yet, i.e. the '[' opens the item text,
//   - the line starts with '[', then ' ', 'x' or 'X', then ']', then a
//     whitespace byte.
//
// On a match it consumes the marker plus the single whitespace byte after it
// (the renderer emits its own trailing space) and returns the new node.
// Otherwise it returns nil without advancing block.
func (p *taskParser) Parse(parent ast.Node, block text.Reader, pc parser.Context) ast.Node {
	switch parent.(type) {
	case *ast.TextBlock, *ast.Paragraph:
	default:
		return nil
	}
	item := parent.Parent()
	if item == nil || item.FirstChild() != parent {
		return nil
	}
	if _, ok := item.(*ast.ListItem); !ok {
		return nil
	}
	// Text preceding the trigger has already been appended to parent, so any
	// existing child means this '[' is mid-line rather than a leading marker.
	if parent.HasChildren() {
		return nil
	}

	line, _ := block.PeekLine()
	if len(line) < 4 || line[0] != '[' || line[2] != ']' || !isTaskMarkerSpace(line[3]) {
		return nil
	}
	var checked bool
	switch line[1] {
	case ' ':
	case 'x', 'X':
		checked = true
	default:
		return nil
	}
	block.Advance(4)
	return &TaskListItem{Checked: checked}
}

// isTaskMarkerSpace reports whether c may follow the closing ']' of a marker.
func isTaskMarkerSpace(c byte) bool {
	switch c {
	case ' ', '\t', '\n', '\r':
		return true
	}
	return false
}

// taskRenderer writes *TaskListItem nodes as HTML checkbox inputs.
type taskRenderer struct{}

// RegisterFuncs binds render to KindTaskListItem.
func (r *taskRenderer) RegisterFuncs(reg renderer.NodeRendererFuncRegisterer) {
	kind := KindTaskListItem
	reg.Register(kind, r.render)
}

// render emits a disabled checkbox (marked checked for completed items) when
// entering a *TaskListItem. The trailing space separates the box from the
// item text. Nothing is written when leaving the node.
func (r *taskRenderer) render(w util.BufWriter, src []byte, n ast.Node, entering bool) (ast.WalkStatus, error) {
	if entering {
		markup := `<input type="checkbox" disabled> `
		if item := n.(*TaskListItem); item.Checked {
			markup = `<input type="checkbox" checked disabled> `
		}
		fmt.Fprint(w, markup)
	}
	return ast.WalkContinue, nil
}

// CountTasks walks the subtree rooted at root and returns (total, done): the
// number of TaskListItem nodes found and how many of them are checked.
func CountTasks(root ast.Node) (int, int) {
	total, done := 0, 0
	visit := func(n ast.Node, entering bool) (ast.WalkStatus, error) {
		item, isTask := n.(*TaskListItem)
		if entering && isTask {
			total++
			if item.Checked {
				done++
			}
		}
		return ast.WalkContinue, nil
	}
	// visit never returns an error, so the error from Walk can be discarded.
	_ = ast.Walk(root, visit)
	return total, done
}