scanner.go

// Worker-pool TCP connect scanner.
//
// See mercemay.top/src/portr/ for context.
package main

import (
	"context"
	"fmt"
	"net"
	"sort"
	"strconv"
	"strings"
	"sync"
	"time"
)

// DialFunc has the same signature as (*net.Dialer).DialContext, so a
// Scanner can be given a fake dialer (for example in tests) instead of
// opening real network connections. When Scanner.Dial is nil, Run falls
// back to the DialContext method of a net.Dialer.
type DialFunc func(ctx context.Context, network, address string) (net.Conn, error)

// Result is the open/closed verdict for a single (host, port) probe.
type Result struct {
	// Host is the target exactly as it appeared in Scanner.Targets.
	Host string
	// Port is the TCP port that was dialed.
	Port int
	// Open is true when the dial returned a nil error. Any dial error,
	// including a timeout, is reported as false.
	Open bool
	// When is the elapsed time measured around the dial attempt (not a
	// timestamp, despite the name).
	When time.Duration
}

// Scanner runs a pool of goroutines hitting every (host, port) pair.
//
// Fields:
//   - Targets: hosts to probe; each one is paired with every entry in Ports.
//   - Ports: TCP ports to probe on each target.
//   - Concurrency: how many worker goroutines Run may start; Run treats
//     values <= 0 as 64.
//   - Timeout: when Dial is nil, Run builds a net.Dialer with this Timeout
//     (zero leaves that dialer without an explicit timeout).
//   - Dial: opens the connection for each probe; nil means the net.Dialer
//     described above.
//
// Run has a value receiver, so the defaults it fills in never modify the
// Scanner it was called on.
type Scanner struct {
	Targets     []string
	Ports       []int
	Concurrency int
	Timeout     time.Duration
	Dial        DialFunc
}

// job is one unit of work that Run's producer goroutine sends to a worker:
// a single host paired with a single port.
type job struct {
	host string
	port int
}

// Run launches the worker pool and returns a channel of results, one per
// (host, port) pair probed. The channel is closed once every worker has
// exited, which happens when all pairs have been probed or ctx is cancelled.
//
// Zero-valued settings get defaults here: Concurrency <= 0 becomes 64, and a
// nil Dial becomes a net.Dialer configured with s.Timeout. Run has a value
// receiver, so these defaults apply to its own copy only.
//
// The caller should either drain the channel or cancel ctx. Every send
// selects on ctx.Done(), so cancelling ctx lets all goroutines exit even if
// nobody reads from the channel again. After cancellation, failed dials are
// not reported, because the failure may be caused by the cancellation itself
// rather than by a closed port.
func (s Scanner) Run(ctx context.Context) <-chan Result {
	if s.Concurrency <= 0 {
		s.Concurrency = 64
	}
	if s.Dial == nil {
		s.Dial = (&net.Dialer{Timeout: s.Timeout}).DialContext
	}

	// Never start more workers than there are pairs to probe.
	workers := s.Concurrency
	if total := len(s.Targets) * len(s.Ports); total < workers {
		workers = total
	}

	jobs := make(chan job, workers)
	out := make(chan Result, workers)

	var wg sync.WaitGroup
	wg.Add(workers)
	for i := 0; i < workers; i++ {
		go func() {
			defer wg.Done()
			for j := range jobs {
				if ctx.Err() != nil {
					return
				}
				r := s.probe(ctx, j)
				// A dial that failed after cancellation says nothing about
				// the port; drop it rather than report a false "closed".
				if !r.Open && ctx.Err() != nil {
					return
				}
				// Select on ctx so an abandoned consumer cannot block this
				// worker forever (which would also keep out from closing).
				select {
				case out <- r:
				case <-ctx.Done():
					return
				}
			}
		}()
	}

	// Producer: feeds every (host, port) pair, stopping early on cancellation.
	go func() {
		defer close(jobs)
		for _, h := range s.Targets {
			for _, p := range s.Ports {
				select {
				case <-ctx.Done():
					return
				case jobs <- job{host: h, port: p}:
				}
			}
		}
	}()

	// Closer: out is closed only after every worker has stopped sending.
	go func() {
		wg.Wait()
		close(out)
	}()

	return out
}

// probe dials j.host:j.port over TCP and reports whether the dial succeeded.
// Any dial error is treated as "not open" and the error itself is discarded.
//
// When s.Timeout > 0 the attempt is bounded by a per-probe deadline derived
// from ctx. This makes the timeout apply to any DialFunc, not only to the
// default net.Dialer. A successful connection is closed immediately, because
// it is only used to prove that the port accepts connections.
func (s Scanner) probe(ctx context.Context, j job) Result {
	if s.Timeout > 0 {
		var cancel context.CancelFunc
		ctx, cancel = context.WithTimeout(ctx, s.Timeout)
		defer cancel()
	}

	addr := net.JoinHostPort(j.host, strconv.Itoa(j.port))
	start := time.Now()
	conn, err := s.Dial(ctx, "tcp", addr)
	r := Result{Host: j.host, Port: j.port, When: time.Since(start)}
	if err == nil {
		r.Open = true
		// Guard against fake dialers that return (nil, nil).
		if conn != nil {
			// Best effort: the verdict is already known, so a close error is irrelevant.
			_ = conn.Close()
		}
	}
	return r
}

// parsePorts parses a comma-separated port specification into a sorted list
// of unique ports. It accepts single ports and inclusive ranges, for example
// "22,80,443", "1-1024", or a mix like "22,80-90".
//
// Whitespace around list items and around range bounds is ignored, so
// "22, 80 - 90" is accepted. Empty items, such as a trailing comma, are
// skipped, which means an empty spec yields an empty list and no error.
//
// Every port must be in 1..65535, and a range must satisfy lo <= hi.
// Otherwise an error naming the offending (trimmed) item is returned.
func parsePorts(spec string) ([]int, error) {
	const minPort, maxPort = 1, 65535

	set := make(map[int]struct{})
	for _, part := range strings.Split(spec, ",") {
		part = strings.TrimSpace(part)
		if part == "" {
			continue
		}
		if loStr, hiStr, isRange := strings.Cut(part, "-"); isRange {
			lo, err1 := strconv.Atoi(strings.TrimSpace(loStr))
			hi, err2 := strconv.Atoi(strings.TrimSpace(hiStr))
			if err1 != nil || err2 != nil || lo > hi || lo < minPort || hi > maxPort {
				return nil, fmt.Errorf("invalid range %q", part)
			}
			for p := lo; p <= hi; p++ {
				set[p] = struct{}{}
			}
			continue
		}
		p, err := strconv.Atoi(part)
		if err != nil || p < minPort || p > maxPort {
			return nil, fmt.Errorf("invalid port %q", part)
		}
		set[p] = struct{}{}
	}

	out := make([]int, 0, len(set))
	for p := range set {
		out = append(out, p)
	}
	sort.Ints(out)
	return out, nil
}