// adapters/sqs/consumer_test.go

package sqs_test

import (
	"context"
	"errors"
	"sort"
	"testing"

	"github.com/aws/aws-lambda-go/events"

	adapter "mercemay.top/src/lambdalog/adapters/sqs"
)

// capture is a test double satisfying adapter.Logger. It appends every logged
// message, prefixed with its level, to msgs in call order.
type capture struct {
	msgs []string
}

// record stores msg tagged with the given level prefix.
func (c *capture) record(level, msg string) {
	c.msgs = append(c.msgs, level+":"+msg)
}

// Info records msg with an "info:" prefix; the fields are discarded.
func (c *capture) Info(msg string, _ ...adapter.Field) {
	c.record("info", msg)
}

// Error records msg with an "error:" prefix; the fields are discarded.
func (c *capture) Error(msg string, _ ...adapter.Field) {
	c.record("error", msg)
}

// With ignores the supplied fields and returns the receiver itself, so any
// logger derived from a capture writes into the same msgs slice.
func (c *capture) With(_ ...adapter.Field) adapter.Logger {
	return c
}

func TestHandler_CollectsPartialFailures(t *testing.T) {
	c := &capture{}
	h := adapter.Handler(c, func(ctx context.Context, m events.SQSMessage) error {
		if m.MessageId == "bad" {
			return errors.New("boom")
		}
		return nil
	})
	resp, err := h(context.Background(), events.SQSEvent{Records: []events.SQSMessage{
		{MessageId: "a"}, {MessageId: "bad"}, {MessageId: "b"},
	}})
	if err != nil {
		t.Fatalf("err: %v", err)
	}
	if len(resp.BatchItemFailures) != 1 {
		t.Fatalf("failures: %#v", resp.BatchItemFailures)
	}
	if resp.BatchItemFailures[0].ItemIdentifier != "bad" {
		t.Fatalf("wrong failure: %+v", resp.BatchItemFailures)
	}
}

// TestHandler_AllOK verifies that when the callback succeeds for every record
// the handler returns a nil error and an empty BatchItemFailures list.
func TestHandler_AllOK(t *testing.T) {
	c := &capture{}
	h := adapter.Handler(c, func(ctx context.Context, m events.SQSMessage) error { return nil })
	resp, err := h(context.Background(), events.SQSEvent{Records: []events.SQSMessage{{MessageId: "a"}}})
	// The error must be checked first: a handler that fails and returns a zero
	// response would otherwise pass the "no failures" assertion below.
	if err != nil {
		t.Fatalf("err: %v", err)
	}
	if len(resp.BatchItemFailures) != 0 {
		t.Fatalf("unexpected failures: %+v", resp.BatchItemFailures)
	}
}

// TestStopBatchHandler feeds a three-record batch ("a", "b", "c") to
// StopBatchHandler with a callback that returns adapter.ErrStopBatch on its
// second invocation. It expects a nil handler error and exactly the records
// "b" and "c" to be reported in BatchItemFailures.
func TestStopBatchHandler(t *testing.T) {
	c := &capture{}
	call := 0
	h := adapter.StopBatchHandler(c, func(ctx context.Context, m events.SQSMessage) error {
		call++
		if call == 2 {
			return adapter.ErrStopBatch
		}
		return nil
	})
	resp, err := h(context.Background(), events.SQSEvent{Records: []events.SQSMessage{
		{MessageId: "a"}, {MessageId: "b"}, {MessageId: "c"},
	}})
	// Surface a handler error explicitly instead of discarding it, so a failure
	// is reported as such rather than as a confusing failure-count mismatch.
	if err != nil {
		t.Fatalf("err: %v", err)
	}
	if len(resp.BatchItemFailures) != 2 {
		t.Fatalf("failures: %#v", resp.BatchItemFailures)
	}
	// Sort before comparing so the test does not depend on the order in which
	// failures are reported.
	ids := []string{resp.BatchItemFailures[0].ItemIdentifier, resp.BatchItemFailures[1].ItemIdentifier}
	sort.Strings(ids)
	if ids[0] != "b" || ids[1] != "c" {
		t.Fatalf("wrong ids: %v", ids)
	}
}

// TestGroupIDs checks that GroupIDs returns each distinct "MessageGroupId"
// attribute value once and contributes nothing for a record that lacks the
// attribute. The result is sorted before comparison, so its order is not
// asserted.
func TestGroupIDs(t *testing.T) {
	// withGroup builds a record carrying only the given message group ID.
	withGroup := func(id string) events.SQSMessage {
		return events.SQSMessage{Attributes: map[string]string{"MessageGroupId": id}}
	}
	batch := events.SQSEvent{Records: []events.SQSMessage{
		withGroup("g1"),
		withGroup("g1"),
		withGroup("g2"),
		{Attributes: map[string]string{}},
	}}
	want := []string{"g1", "g2"}

	groups := adapter.GroupIDs(batch)
	sort.Strings(groups)
	if len(groups) != len(want) {
		t.Fatalf("groups: %v", groups)
	}
	for i := range want {
		if groups[i] != want[i] {
			t.Fatalf("groups: %v", groups)
		}
	}
}