package sqs_test
import (
"context"
"errors"
"sort"
"testing"
"github.com/aws/aws-lambda-go/events"
adapter "mercemay.top/src/lambdalog/adapters/sqs"
)
// capture is a test double for adapter.Logger. It records every message it
// receives, prefixed with its level ("info:" or "error:"), and discards all
// structured fields.
type capture struct{ msgs []string }

// record appends msg, prefixed by prefix, to the captured messages.
func (c *capture) record(prefix, msg string) {
	c.msgs = append(c.msgs, prefix+msg)
}

// Info records msg at info level; fields are ignored.
func (c *capture) Info(msg string, _ ...adapter.Field) { c.record("info:", msg) }

// Error records msg at error level; fields are ignored.
func (c *capture) Error(msg string, _ ...adapter.Field) { c.record("error:", msg) }

// With ignores the supplied fields and returns the receiver itself, so any
// derived logger writes into the same message buffer.
func (c *capture) With(_ ...adapter.Field) adapter.Logger { return c }
// TestHandler_CollectsPartialFailures checks that Handler reports exactly the
// message whose processor returned an error as a batch item failure, while
// returning a nil error for the batch as a whole.
func TestHandler_CollectsPartialFailures(t *testing.T) {
	logger := &capture{}
	process := func(_ context.Context, msg events.SQSMessage) error {
		if msg.MessageId != "bad" {
			return nil
		}
		return errors.New("boom")
	}
	h := adapter.Handler(logger, process)

	records := []events.SQSMessage{
		{MessageId: "a"},
		{MessageId: "bad"},
		{MessageId: "b"},
	}
	resp, err := h(context.Background(), events.SQSEvent{Records: records})

	switch {
	case err != nil:
		t.Fatalf("err: %v", err)
	case len(resp.BatchItemFailures) != 1:
		t.Fatalf("failures: %#v", resp.BatchItemFailures)
	case resp.BatchItemFailures[0].ItemIdentifier != "bad":
		t.Fatalf("wrong failure: %+v", resp.BatchItemFailures)
	}
}
// TestHandler_AllOK checks that when every message in the batch is processed
// successfully, Handler returns a nil error and no batch item failures.
func TestHandler_AllOK(t *testing.T) {
	c := &capture{}
	h := adapter.Handler(c, func(ctx context.Context, m events.SQSMessage) error { return nil })
	resp, err := h(context.Background(), events.SQSEvent{Records: []events.SQSMessage{{MessageId: "a"}}})
	// Handler reports per-message problems via BatchItemFailures, not via its
	// error return (see TestHandler_CollectsPartialFailures), so a fully
	// successful batch must not produce an error either.
	if err != nil {
		t.Fatalf("err: %v", err)
	}
	if len(resp.BatchItemFailures) != 0 {
		t.Fatalf("unexpected failures: %+v", resp.BatchItemFailures)
	}
}
// TestStopBatchHandler checks that once the processor returns
// adapter.ErrStopBatch, StopBatchHandler stops handing messages to the
// processor and reports the triggering message plus every later message as
// batch item failures.
func TestStopBatchHandler(t *testing.T) {
	c := &capture{}
	call := 0
	h := adapter.StopBatchHandler(c, func(ctx context.Context, m events.SQSMessage) error {
		call++
		if call == 2 {
			return adapter.ErrStopBatch
		}
		return nil
	})
	resp, err := h(context.Background(), events.SQSEvent{Records: []events.SQSMessage{
		{MessageId: "a"}, {MessageId: "b"}, {MessageId: "c"},
	}})
	// Lambda discards the partial-batch response when the handler itself
	// returns an error, so BatchItemFailures is only meaningful with nil err.
	if err != nil {
		t.Fatalf("err: %v", err)
	}
	// "c" is expected to be reported as failed even though the processor would
	// have returned nil for it, which only holds if it was never processed.
	if call != 2 {
		t.Fatalf("processor called %d times, want 2 (stop after ErrStopBatch)", call)
	}
	if len(resp.BatchItemFailures) != 2 {
		t.Fatalf("failures: %#v", resp.BatchItemFailures)
	}
	// Sort so the assertion does not depend on the order failures are reported.
	ids := []string{resp.BatchItemFailures[0].ItemIdentifier, resp.BatchItemFailures[1].ItemIdentifier}
	sort.Strings(ids)
	if ids[0] != "b" || ids[1] != "c" {
		t.Fatalf("wrong ids: %v", ids)
	}
}
// TestGroupIDs checks that GroupIDs yields each distinct MessageGroupId exactly
// once and contributes nothing for a record without that attribute. The
// result is sorted before comparison, so ordering is not part of the check.
func TestGroupIDs(t *testing.T) {
	withGroup := func(id string) events.SQSMessage {
		return events.SQSMessage{Attributes: map[string]string{"MessageGroupId": id}}
	}
	records := []events.SQSMessage{
		withGroup("g1"),
		withGroup("g1"),
		withGroup("g2"),
		{Attributes: map[string]string{}},
	}

	got := adapter.GroupIDs(events.SQSEvent{Records: records})
	sort.Strings(got)

	want := []string{"g1", "g2"}
	match := len(got) == len(want)
	for i := 0; match && i < len(want); i++ {
		match = got[i] == want[i]
	}
	if !match {
		t.Fatalf("groups: %v", got)
	}
}