package html
import (
"bytes"
"strings"
"testing"
)
// TestStrictDropsScripts runs the Strict policy over a table of inputs and
// checks, per case, which substrings must survive sanitization and which must
// be gone: <script> elements, onclick attributes and javascript: hrefs are
// expected to disappear, while <p>, <em>, <strong> and an https href remain.
// The subtests share one policy value and run in parallel.
func TestStrictDropsScripts(t *testing.T) {
	t.Parallel()

	type sanitizeCase struct {
		name    string
		in      string
		must    []string // substrings required in the sanitized output
		mustNot []string // substrings forbidden in the sanitized output
	}

	policy := Strict()
	for _, c := range []sanitizeCase{
		{
			name:    "drop_script_tag",
			in:      `<p>hi</p><script>alert(1)</script>`,
			must:    []string{"<p>hi</p>"},
			mustNot: []string{"<script", "alert"},
		},
		{
			name:    "drop_event_handler",
			in:      `<a href="https://example.com" onclick="bad()">x</a>`,
			must:    []string{`href="https://example.com"`},
			mustNot: []string{"onclick"},
		},
		{
			name: "keep_em_strong",
			in:   `<em>a</em> <strong>b</strong>`,
			must: []string{"<em>a</em>", "<strong>b</strong>"},
		},
		{
			name:    "drop_javascript_uri",
			in:      `<a href="javascript:evil()">x</a>`,
			mustNot: []string{"javascript:"},
		},
	} {
		c := c // per-iteration copy for the parallel closure (required before Go 1.22)
		t.Run(c.name, func(t *testing.T) {
			t.Parallel()
			out := policy.Sanitize(c.in)
			for _, s := range c.must {
				if !strings.Contains(out, s) {
					t.Errorf("want substring %q in %q", s, out)
				}
			}
			for _, s := range c.mustNot {
				if strings.Contains(out, s) {
					t.Errorf("did not want %q in %q", s, out)
				}
			}
		})
	}
}
// TestSanitizeBytesAndReader feeds the same markup through the Strict
// policy's byte-slice entry point (SanitizeBytes) and its streaming entry
// point (SanitizeReader). It asserts that neither output contains the text
// "script". A SanitizeReader error aborts the test.
func TestSanitizeBytesAndReader(t *testing.T) {
	t.Parallel()
	policy := Strict()
	raw := []byte(`<p>ok</p><script>x</script>`)

	// Byte-slice variant.
	if out := policy.SanitizeBytes(raw); strings.Contains(string(out), "script") {
		t.Errorf("script survived: %s", out)
	}

	// Streaming variant: read from raw, write into an in-memory sink.
	var sink bytes.Buffer
	if err := policy.SanitizeReader(&sink, bytes.NewReader(raw)); err != nil {
		t.Fatalf("SanitizeReader: %v", err)
	}
	if s := sink.String(); strings.Contains(s, "script") {
		t.Errorf("reader variant let script through: %s", s)
	}
}
// TestStripAll checks that sanitizing nested markup with the StripAll policy
// leaves no '<' character in the result.
func TestStripAll(t *testing.T) {
	t.Parallel()
	const markup = "<p>hello <b>world</b></p>"
	if out := StripAll().Sanitize(markup); strings.ContainsRune(out, '<') {
		t.Errorf("strip left tags: %q", out)
	}
}
// TestToPlainText checks that ToPlainText removes the <p>/<b> tags and
// reduces the padding and blank lines around the words to exactly
// "hello world".
func TestToPlainText(t *testing.T) {
	t.Parallel()
	const (
		input    = "<p> hello\n\n <b>world</b> </p>"
		expected = "hello world"
	)
	if got := ToPlainText(input); got != expected {
		t.Errorf("got %q, want %q", got, expected)
	}
}
// TestCollapseWhitespace checks that collapseWhitespace trims leading and
// trailing whitespace and folds each internal run of tabs or newlines into a
// single space, so " a\t\tb\n\nc " becomes "a b c".
func TestCollapseWhitespace(t *testing.T) {
	t.Parallel()
	in := " a\t\tb\n\nc "
	want := "a b c"
	if got := collapseWhitespace(in); got != want {
		t.Errorf("got %q, want %q", got, want)
	}
}