main.go

// Command portr is a small concurrent TCP port scanner.
//
// See mercemay.top/src/portr/ for history and why this is archived.
package main

import (
	"context"
	"fmt"
	"net"
	"os"
	"os/signal"
	"strings"
	"syscall"
	"time"

	flag "github.com/spf13/pflag"
)

// main parses the command line, expands the target into hosts, and runs the
// scan. Open ports are printed as plain text with --no-tui, otherwise results
// are handed to the TUI. SIGINT/SIGTERM cancel the context passed to both the
// scanner and the TUI.
func main() {
	var (
		portsArg    string
		concurrency int
		timeout     time.Duration
		noTUI       bool
	)

	flag.StringVarP(&portsArg, "ports", "p", "1-1024", "ports, e.g. 22,80,443 or 1-1024")
	flag.IntVar(&concurrency, "concurrency", 128, "number of concurrent probes")
	flag.DurationVar(&timeout, "timeout", 500*time.Millisecond, "per-connect timeout")
	flag.BoolVar(&noTUI, "no-tui", false, "do not launch the TUI")
	flag.Parse()

	// Exactly one positional target. Extra arguments would otherwise be
	// silently ignored; several hosts go in one comma-separated argument.
	if flag.NArg() != 1 {
		fmt.Fprintln(os.Stderr, "usage: portr [flags] <target>")
		os.Exit(2)
	}

	// Fewer than one concurrent probe is not a meaningful setting.
	if concurrency < 1 {
		fatal("bad --concurrency:", concurrency, "(must be at least 1)")
	}
	// For net.Dialer a zero Timeout means "no timeout" (a filtered port then
	// blocks for the OS connect timeout), and a negative one fails every dial
	// immediately, which would report every port as closed.
	if timeout <= 0 {
		fatal("bad --timeout:", timeout, "(must be positive)")
	}

	target := flag.Arg(0)
	hosts, err := expand(target)
	if err != nil {
		fatal("bad target:", err)
	}

	ports, err := parsePorts(portsArg)
	if err != nil {
		fatal("bad --ports:", err)
	}

	ctx, cancel := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM)
	defer cancel()

	scan := Scanner{
		Targets:     hosts,
		Ports:       ports,
		Concurrency: concurrency,
		Timeout:     timeout,
		Dial:        (&net.Dialer{Timeout: timeout}).DialContext,
	}

	results := scan.Run(ctx)

	if noTUI {
		printPlain(results)
		return
	}

	// The total probe count lets the TUI show progress.
	if err := runTUI(ctx, results, len(hosts)*len(ports)); err != nil {
		fatal("tui:", err)
	}
}

// printPlain drains results until the channel is closed and writes a
// "host:port open" line to stdout for each result whose Open flag is set.
// Closed results produce no output.
func printPlain(results <-chan Result) {
	for res := range results {
		if !res.Open {
			continue
		}
		fmt.Printf("%s:%d open\n", res.Host, res.Port)
	}
}

// expand turns a user-supplied target into the concrete host strings to scan.
//
// target is a comma-separated list. Each element is either a CIDR block
// (anything containing "/", e.g. 192.168.1.0/24) or a single host name or
// address, which is passed through unchanged. Whitespace around elements is
// ignored and empty elements are skipped, so "a, b,,c" yields [a b c].
// Hosts are returned in input order.
//
// It returns an error if target contains no hosts, if a CIDR element fails to
// parse, or if a CIDR spans more than 2^maxCIDRHostBits addresses.
func expand(target string) ([]string, error) {
	var hosts []string
	for _, part := range strings.Split(target, ",") {
		part = strings.TrimSpace(part)
		switch {
		case part == "":
			continue
		case strings.Contains(part, "/"):
			block, err := expandCIDR(part)
			if err != nil {
				return nil, err
			}
			hosts = append(hosts, block...)
		default:
			hosts = append(hosts, part)
		}
	}
	if len(hosts) == 0 {
		return nil, fmt.Errorf("no hosts in target %q", target)
	}
	return hosts, nil
}

// maxCIDRHostBits caps the size of a single CIDR at 2^24 addresses (an IPv4
// /8). Larger blocks, such as 0.0.0.0/0 or a typical IPv6 /64, cannot be
// materialised as a host list.
const maxCIDRHostBits = 24

// expandCIDR lists every address in one CIDR block in ascending order. For
// IPv4 blocks with more than two addresses, the first (network) and last
// (broadcast) addresses are dropped; /31 and /32 are kept whole.
func expandCIDR(cidr string) ([]string, error) {
	_, ipnet, err := net.ParseCIDR(cidr)
	if err != nil {
		return nil, err
	}
	ones, bits := ipnet.Mask.Size()
	hostBits := bits - ones
	if hostBits > maxCIDRHostBits {
		return nil, fmt.Errorf("%s spans 2^%d addresses, more than the 2^%d limit", cidr, hostBits, maxCIDRHostBits)
	}

	n := 1 << hostBits
	hosts := make([]string, 0, n)
	// Count addresses instead of testing ipnet.Contains. inc wraps to
	// all-zeros after the highest address, which is still "contained" in a
	// /0, so a Contains-driven loop would never end there.
	ip := ipnet.IP.Mask(ipnet.Mask)
	for i := 0; i < n; i++ {
		hosts = append(hosts, ip.String())
		inc(ip)
	}

	// Trim network and broadcast for IPv4.
	if len(hosts) > 2 && ipnet.IP.To4() != nil {
		hosts = hosts[1 : len(hosts)-1]
	}
	return hosts, nil
}

// inc increments ip in place, treating its bytes as one big-endian unsigned
// integer. The highest address wraps to all zeros.
func inc(ip net.IP) {
	for j := len(ip) - 1; j >= 0; j-- {
		ip[j]++
		if ip[j] > 0 {
			return
		}
	}
}

// fatal writes args to stderr on one line, space-separated and prefixed with
// "portr:", then terminates the process with exit status 1. Deferred calls in
// the caller do not run.
func fatal(args ...interface{}) {
	line := make([]interface{}, 0, len(args)+1)
	line = append(line, "portr:")
	line = append(line, args...)
	fmt.Fprintln(os.Stderr, line...)
	os.Exit(1)
}