// Worker-pool TCP connect scanner.
//
// See mercemay.top/src/portr/ for context.
package main
import (
"context"
"fmt"
"net"
"sort"
"strconv"
"strings"
"sync"
"time"
)
// DialFunc matches net.Dialer.DialContext so tests can inject a fake.
type DialFunc func(ctx context.Context, network, address string) (net.Conn, error)

// Result is the verdict for one (host, port) pair.
type Result struct {
	Host string
	Port int
	// Open reports whether the dial succeeded. False means the dial
	// returned an error of any kind (refused, timed out, unreachable, ...).
	Open bool
	// When is how long the dial attempt took (a duration, not a timestamp).
	When time.Duration
}

// Scanner runs a pool of goroutines hitting every (host, port) pair.
type Scanner struct {
	// Targets are the hosts to probe; each is combined with a port via
	// net.JoinHostPort.
	Targets []string
	// Ports are the TCP ports probed on every target.
	Ports []int
	// Concurrency is the number of worker goroutines; values <= 0 mean 64.
	// Run never starts more workers than there are pairs to probe.
	Concurrency int
	// Timeout bounds each individual dial when positive, for both the
	// default dialer and an injected Dial. Zero means no per-dial limit
	// beyond ctx.
	Timeout time.Duration
	// Dial opens the connection; nil means a net.Dialer is used.
	Dial DialFunc
}

// job is one (host, port) pair handed from the feeder to a worker.
type job struct {
	host string
	port int
}

// Run launches the worker pool and returns a channel of results, one per
// probed (host, port) pair, in no particular order.
//
// The channel is closed once every worker has exited: either all pairs have
// been probed or ctx has been cancelled. The caller must either drain the
// channel or cancel ctx. After cancellation, workers stop sending, so a
// caller that cancels may stop reading without leaking goroutines. Dials
// that fail after ctx is cancelled are dropped rather than reported as
// closed ports, because the failure says nothing about the port.
func (s Scanner) Run(ctx context.Context) <-chan Result {
	if s.Concurrency <= 0 {
		s.Concurrency = 64
	}
	// Idle workers would only sit on an empty jobs channel.
	if total := len(s.Targets) * len(s.Ports); s.Concurrency > total {
		s.Concurrency = total
	}
	if s.Dial == nil {
		s.Dial = (&net.Dialer{Timeout: s.Timeout}).DialContext
	}

	jobs := make(chan job, s.Concurrency)
	out := make(chan Result, s.Concurrency)

	var wg sync.WaitGroup
	for i := 0; i < s.Concurrency; i++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			for j := range jobs {
				if ctx.Err() != nil {
					return
				}
				r := s.probe(ctx, j)
				// A failed dial once ctx is done is (almost certainly) the
				// cancellation itself, not a closed port.
				if !r.Open && ctx.Err() != nil {
					return
				}
				// Never block forever on a reader that may have walked away
				// after cancelling.
				select {
				case out <- r:
				case <-ctx.Done():
					return
				}
			}
		}()
	}

	// Feeder: enumerate every pair, stopping early on cancellation. Closing
	// jobs lets workers' range loops finish.
	go func() {
		defer close(jobs)
		for _, h := range s.Targets {
			for _, p := range s.Ports {
				select {
				case <-ctx.Done():
					return
				case jobs <- job{host: h, port: p}:
				}
			}
		}
	}()

	// Only close out after the last worker is done sending.
	go func() {
		wg.Wait()
		close(out)
	}()
	return out
}

// probe dials one (host, port) pair over TCP and reports whether the
// connection succeeded; any dial error counts as "not open". When s.Timeout
// is positive the dial is bounded by it via ctx, so the limit also applies
// to an injected Dial and not just the default net.Dialer.
func (s Scanner) probe(ctx context.Context, j job) Result {
	if s.Timeout > 0 {
		var cancel context.CancelFunc
		ctx, cancel = context.WithTimeout(ctx, s.Timeout)
		defer cancel()
	}
	start := time.Now()
	addr := net.JoinHostPort(j.host, strconv.Itoa(j.port))
	conn, err := s.Dial(ctx, "tcp", addr)
	r := Result{Host: j.host, Port: j.port, When: time.Since(start)}
	if err == nil {
		r.Open = true
		// The connection only proves the port accepted; a close error
		// doesn't change that verdict.
		_ = conn.Close()
	}
	return r
}
// parsePorts accepts "22,80,443", "1-1024", or a mix like "22,80-90".
//
// Whitespace around items and around range bounds is ignored, and empty
// items are skipped. The result is deduplicated and sorted ascending. Every
// port must lie in 1-65535, and a range's low bound must not exceed its high
// bound. An empty spec yields an empty slice and a nil error.
func parsePorts(spec string) ([]int, error) {
	seen := make(map[int]struct{})
	for _, part := range strings.Split(spec, ",") {
		part = strings.TrimSpace(part)
		if part == "" {
			continue
		}
		loStr, hiStr, isRange := strings.Cut(part, "-")
		if !isRange {
			p, err := strconv.Atoi(part)
			if err != nil || p < 1 || p > 65535 {
				return nil, fmt.Errorf("invalid port %q", part)
			}
			seen[p] = struct{}{}
			continue
		}
		// Trim the bounds too, so "80 - 90" is accepted just like "80, 90".
		lo, err1 := strconv.Atoi(strings.TrimSpace(loStr))
		hi, err2 := strconv.Atoi(strings.TrimSpace(hiStr))
		if err1 != nil || err2 != nil || lo > hi || lo < 1 || hi > 65535 {
			return nil, fmt.Errorf("invalid range %q", part)
		}
		for p := lo; p <= hi; p++ {
			seen[p] = struct{}{}
		}
	}
	ports := make([]int, 0, len(seen))
	for p := range seen {
		ports = append(ports, p)
	}
	sort.Ints(ports)
	return ports, nil
}