// Tests for the port scanner. Kept simple on purpose — this is the
// table-driven example I send students to.
//
// See mercemay.top/src/portr/
package main
import (
"context"
"errors"
"net"
"reflect"
"sync/atomic"
"testing"
"time"
)
// TestParsePorts drives parsePorts through a table of specs. The table covers
// single ports, comma-separated lists, inclusive ranges, mixes of the two,
// duplicate and empty entries, and malformed specs that must be rejected.
// For accepted specs the returned slice is compared element-by-element.
func TestParsePorts(t *testing.T) {
	type testCase struct {
		name    string
		in      string
		want    []int
		wantErr bool
	}
	table := []testCase{
		// Valid specs.
		{name: "single", in: "80", want: []int{80}},
		{name: "list", in: "22,80,443", want: []int{22, 80, 443}},
		{name: "range", in: "1-5", want: []int{1, 2, 3, 4, 5}},
		{name: "mixed", in: "22,80-82,443", want: []int{22, 80, 81, 82, 443}},
		{name: "dedup", in: "80,80,80", want: []int{80}},
		{name: "empty part", in: "22,,80", want: []int{22, 80}},
		// Invalid specs: only the presence of an error is checked.
		{name: "zero", in: "0", wantErr: true},
		{name: "too high", in: "70000", wantErr: true},
		{name: "reversed range", in: "10-5", wantErr: true},
		{name: "not a number", in: "http", wantErr: true},
	}

	for _, c := range table {
		t.Run(c.name, func(t *testing.T) {
			ports, err := parsePorts(c.in)
			if gotErr := err != nil; gotErr != c.wantErr {
				t.Fatalf("err = %v, want err=%v", err, c.wantErr)
			}
			if c.wantErr {
				// Error cases have no meaningful result to compare.
				return
			}
			if !reflect.DeepEqual(ports, c.want) {
				t.Errorf("got %v, want %v", ports, c.want)
			}
		})
	}
}
// fakeConn is a minimal net.Conn stand-in for tests. The embedded interface is
// left nil, so only Close is safe to call; any other net.Conn method will
// panic with a nil dereference.
type fakeConn struct {
	net.Conn
}

// Close does nothing and always reports success.
func (fakeConn) Close() error {
	return nil
}
// TestScannerRun scans two targets on three ports through a fake dialer that
// accepts connections only on port 80 and refuses everything else. It checks
// three things:
//   - every (target, port) pair is dialled exactly once;
//   - port 80 is reported open on every target;
//   - nothing else is reported open (no false positives).
func TestScannerRun(t *testing.T) {
	var dialed int32
	dial := func(_ context.Context, _, address string) (net.Conn, error) {
		atomic.AddInt32(&dialed, 1)
		// Pretend only :80 is open. If the address cannot be split, port is
		// "" and the dial is simply treated as refused.
		_, port, _ := net.SplitHostPort(address)
		if port == "80" {
			return fakeConn{}, nil
		}
		return nil, errors.New("refused")
	}
	s := Scanner{
		Targets:     []string{"10.0.0.1", "10.0.0.2"},
		Ports:       []int{22, 80, 443},
		Concurrency: 4,
		Timeout:     50 * time.Millisecond,
		Dial:        dial,
	}
	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
	defer cancel()

	open := map[string]bool{}
	for r := range s.Run(ctx) {
		if r.Open {
			open[net.JoinHostPort(r.Host, itoa(r.Port))] = true
		}
	}

	if got, want := int(atomic.LoadInt32(&dialed)), len(s.Targets)*len(s.Ports); got != want {
		t.Errorf("dialled %d times, want %d", got, want)
	}
	for _, h := range s.Targets {
		if !open[net.JoinHostPort(h, "80")] {
			t.Errorf("%s:80 should be open", h)
		}
	}
	// Without this, a scanner that marked every port open would pass the
	// checks above: exactly one endpoint (:80) per target may be open.
	if len(open) != len(s.Targets) {
		t.Errorf("open endpoints = %v, want only port 80 on each target", open)
	}
}
// TestScannerContextCancel runs a scan whose context is cancelled before Run
// is called. The dialer blocks until the context it receives is done. Run
// must still finish, closing its result channel, within a bounded time. It
// must also not emit more results than there are ports.
func TestScannerContextCancel(t *testing.T) {
	slow := func(ctx context.Context, _, _ string) (net.Conn, error) {
		<-ctx.Done()
		return nil, ctx.Err()
	}
	s := Scanner{
		Targets:     []string{"10.0.0.1"},
		Ports:       []int{1, 2, 3, 4, 5},
		Concurrency: 2,
		Timeout:     time.Second,
		Dial:        slow,
	}
	ctx, cancel := context.WithCancel(context.Background())
	cancel()

	// Drain in a goroutine. If Run ignored cancellation, the test then fails
	// after a bounded wait instead of hanging until `go test -timeout`.
	// The channel is buffered so the goroutine can still exit after we give up.
	done := make(chan int, 1)
	go func() {
		count := 0
		for range s.Run(ctx) {
			count++
		}
		done <- count
	}()

	// Suppose Run ignored the cancelled context but still bounded each dial by
	// s.Timeout. Five ports at concurrency 2 would then take three waves of
	// 1s each. A 2s limit separates that from a prompt exit and leaves
	// headroom for slow CI.
	select {
	case count := <-done:
		if count > len(s.Ports) {
			t.Errorf("got %d results, want <= %d", count, len(s.Ports))
		}
	case <-time.After(2 * time.Second):
		t.Fatal("Run did not finish within 2s of a cancelled context")
	}
}
// itoa formats i in base 10, matching strconv.Itoa. It is a tiny local copy
// so the tests don't import strconv for one call.
//
// It covers the whole int range. The magnitude is computed in uint64, so
// negating the minimum int cannot overflow, and the buffer fits the longest
// possible output, "-9223372036854775808".
func itoa(i int) string {
	var b [20]byte
	n := len(b)

	u := uint64(i)
	if i < 0 {
		// Modular negation of the two's-complement bits yields |i|, even for
		// the minimum int.
		u = -u
	}
	// Emit at least one digit, so 0 becomes "0" without a special case.
	for {
		n--
		b[n] = byte('0' + u%10)
		u /= 10
		if u == 0 {
			break
		}
	}
	if i < 0 {
		n--
		b[n] = '-'
	}
	return string(b[n:])
}