package worker
import (
"context"
"errors"
"net"
"sync/atomic"
"testing"
"time"
"github.com/mercemay/portr/internal/check/rate"
"github.com/mercemay/portr/internal/check/retry"
"github.com/mercemay/portr/internal/config/target"
)
// TestPool_runsAllTargets verifies that Run yields exactly one result per
// target and that the dialer is invoked at least once per target. The dialer
// always fails, so no result may be reported as open.
func TestPool_runsAllTargets(t *testing.T) {
	var dialCount int32
	failingDial := func(_ context.Context, _, _ string) (net.Conn, error) {
		atomic.AddInt32(&dialCount, 1)
		return nil, errors.New("nope")
	}
	pool := NewPool(4, failingDial, rate.New(0), retry.New(0, 0, time.Second))

	// Ports 1..20 on a single host; pre-sized so the slice matches make(…, 20).
	targets := make([]target.Target, 0, 20)
	for port := 1; port <= 20; port++ {
		targets = append(targets, target.Target{Host: "h", Port: port})
	}

	ctx, cancel := context.WithTimeout(context.Background(), time.Second)
	defer cancel()

	received := 0
	for res := range pool.Run(ctx, targets) {
		received++
		if res.Open {
			t.Errorf("should not be open: %v", res)
		}
	}

	if received != 20 {
		t.Errorf("got %d results, want 20", received)
	}
	if n := atomic.LoadInt32(&dialCount); n < 20 {
		t.Errorf("dial called %d times, want >= 20", n)
	}
}
// TestPool_respectsCancel checks that Run on an already-cancelled context
// terminates: the result stream must finish promptly even though the dialer
// blocks until its context is done. It also checks that a single target does
// not produce more than one result.
func TestPool_respectsCancel(t *testing.T) {
	// The dialer never succeeds on its own; it only returns once ctx is done,
	// so any progress depends on the pool propagating cancellation to dial.
	dial := func(ctx context.Context, _, _ string) (net.Conn, error) {
		<-ctx.Done()
		return nil, ctx.Err()
	}
	p := NewPool(2, dial, rate.New(0), retry.New(0, 0, time.Second))
	ctx, cancel := context.WithCancel(context.Background())
	cancel() // cancel before Run so every dial observes a done context

	results := p.Run(ctx, []target.Target{{Host: "h", Port: 1}})

	// Drain in a goroutine so a pool that never finishes its result stream
	// fails this test quickly instead of hanging until the go test timeout.
	// Buffered so the goroutine can still exit if the timeout fires first.
	done := make(chan int, 1)
	go func() {
		count := 0
		for range results {
			count++
		}
		done <- count
	}()

	select {
	case count := <-done:
		// TestPool_runsAllTargets expects one result per target; with one
		// target there must never be more than one, cancelled or not.
		if count > 1 {
			t.Errorf("got %d results for 1 target, want at most 1", count)
		}
	case <-time.After(5 * time.Second):
		t.Fatal("Run did not finish after context cancellation")
	}
}