// scanner_test.go

// Tests for the port scanner. Kept simple on purpose — this is the
// table-driven example I send students to.
//
// See mercemay.top/src/portr/
package main

import (
	"context"
	"errors"
	"net"
	"reflect"
	"sync/atomic"
	"testing"
	"time"
)

// TestParsePorts drives parsePorts through a table of inputs: single ports,
// comma lists, ranges, duplicates, empty list entries, and invalid specs.
// Invalid specs must yield an error; valid ones must yield exactly the
// expected ports.
func TestParsePorts(t *testing.T) {
	type parseCase struct {
		name    string
		input   string
		expect  []int
		wantErr bool
	}

	table := []parseCase{
		{name: "single", input: "80", expect: []int{80}},
		{name: "list", input: "22,80,443", expect: []int{22, 80, 443}},
		{name: "range", input: "1-5", expect: []int{1, 2, 3, 4, 5}},
		{name: "mixed", input: "22,80-82,443", expect: []int{22, 80, 81, 82, 443}},
		{name: "dedup", input: "80,80,80", expect: []int{80}},
		{name: "empty part", input: "22,,80", expect: []int{22, 80}},
		{name: "zero", input: "0", wantErr: true},
		{name: "too high", input: "70000", wantErr: true},
		{name: "reversed range", input: "10-5", wantErr: true},
		{name: "not a number", input: "http", wantErr: true},
	}

	for i := range table {
		c := table[i]
		t.Run(c.name, func(t *testing.T) {
			ports, err := parsePorts(c.input)
			if gotErr := err != nil; gotErr != c.wantErr {
				t.Fatalf("err = %v, want err=%v", err, c.wantErr)
			}
			// Error cases have no meaningful result to compare.
			if c.wantErr {
				return
			}
			if !reflect.DeepEqual(ports, c.expect) {
				t.Errorf("got %v, want %v", ports, c.expect)
			}
		})
	}
}

// fakeConn is a minimal net.Conn stand-in for tests. It satisfies the
// interface by embedding a nil net.Conn and overrides only Close; calling any
// other method (Read, Write, ...) goes through the nil embedded interface and
// panics.
type fakeConn struct {
	net.Conn
}

// Close always reports success and releases nothing.
func (c fakeConn) Close() error {
	return nil
}

// TestScannerRun scans two hosts × three ports against a fake dialer that
// accepts connections only on port 80. It checks three things:
//   - the dialer was called once per target/port pair in total;
//   - every host's :80 was reported open;
//   - nothing else was reported open.
func TestScannerRun(t *testing.T) {
	targets := []string{"10.0.0.1", "10.0.0.2"}
	ports := []int{22, 80, 443}

	var dialed int32
	dial := func(ctx context.Context, network, address string) (net.Conn, error) {
		atomic.AddInt32(&dialed, 1)
		_, port, err := net.SplitHostPort(address)
		if err != nil {
			// Errorf (unlike Fatalf) is safe to call from the scanner's
			// worker goroutines.
			t.Errorf("Dial got malformed address %q: %v", address, err)
			return nil, err
		}
		// Pretend only :80 is open.
		if port == "80" {
			return fakeConn{}, nil
		}
		return nil, errors.New("refused")
	}

	s := Scanner{
		Targets:     targets,
		Ports:       ports,
		Concurrency: 4,
		Timeout:     50 * time.Millisecond,
		Dial:        dial,
	}

	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
	defer cancel()

	open := map[string]bool{}
	for r := range s.Run(ctx) {
		if r.Open {
			open[net.JoinHostPort(r.Host, itoa(r.Port))] = true
		}
	}

	if got, want := int(atomic.LoadInt32(&dialed)), len(targets)*len(ports); got != want {
		t.Errorf("dialled %d times, want %d", got, want)
	}

	want := map[string]bool{}
	for _, h := range targets {
		want[net.JoinHostPort(h, "80")] = true
	}
	for addr := range want {
		if !open[addr] {
			t.Errorf("%s should be open", addr)
		}
	}
	// Without this check, a scanner that reports every port open would pass.
	for addr := range open {
		if !want[addr] {
			t.Errorf("%s reported open, want closed", addr)
		}
	}
}

// TestScannerContextCancel starts a scan with an already-cancelled context,
// using a dialer that blocks until its context is done. Run must still close
// its result channel promptly. It must also emit at most one result per
// target/port pair.
func TestScannerContextCancel(t *testing.T) {
	// slow never connects on its own: it returns only once the context the
	// scanner hands it is done.
	slow := func(ctx context.Context, _, _ string) (net.Conn, error) {
		<-ctx.Done()
		return nil, ctx.Err()
	}

	s := Scanner{
		Targets:     []string{"10.0.0.1"},
		Ports:       []int{1, 2, 3, 4, 5},
		Concurrency: 2,
		Timeout:     time.Second,
		Dial:        slow,
	}

	// Read these before Run starts, so the test never reads s while Run
	// might be using it from another goroutine.
	limit := len(s.Targets) * len(s.Ports)
	deadline := s.Timeout

	ctx, cancel := context.WithCancel(context.Background())
	cancel()

	// Drain in a goroutine. If Run ignores cancellation or never closes its
	// channel, the test then fails here instead of hanging until the global
	// `go test` timeout.
	done := make(chan int, 1)
	go func() {
		count := 0
		for range s.Run(ctx) {
			count++
		}
		done <- count
	}()

	// A scan whose context is already cancelled should not need to wait out
	// even one per-dial Timeout.
	select {
	case count := <-done:
		if count > limit {
			t.Errorf("got %d results, want <= %d", count, limit)
		}
	case <-time.After(deadline):
		t.Fatalf("Run did not finish within %v after its context was cancelled", deadline)
	}
}

// itoa formats i in base 10, matching strconv.Itoa. It is a tiny local copy
// so the tests don't import strconv for one call.
func itoa(i int) string {
	if i == 0 {
		return "0"
	}
	// Work on the magnitude as a uint64. Unsigned negation of the
	// two's-complement bits gives the correct magnitude even for the most
	// negative int, where -i would overflow.
	neg := i < 0
	u := uint64(i)
	if neg {
		u = -u
	}
	// 20 bytes fit the 19 digits of the largest 64-bit magnitude plus a sign.
	var b [20]byte
	n := len(b)
	for u > 0 {
		n--
		b[n] = byte('0' + u%10)
		u /= 10
	}
	if neg {
		n--
		b[n] = '-'
	}
	return string(b[n:])
}