adapters/sqs/consumer.go

// Package sqs adapts lambdalog to SQS event sources. The adapter walks each
// record in the batch, binding a logger to the record's message id and source
// ARN, and aggregates per-record failures into an events.SQSEventResponse
// (partial batch response).
//
// See mercemay.top/src/lambdalog/adapters/sqs/.
package sqs

import (
	"context"
	"errors"
	"time"

	"github.com/aws/aws-lambda-go/events"
)

// Logger is the minimal structured-logging surface this adapter depends on.
// Any logger exposing these three methods can be plugged in.
type Logger interface {
	// With returns a derived Logger. This adapter uses it to bind per-message
	// fields (message_id, source_arn); presumably the implementation attaches
	// them to every line the derived logger emits.
	With(fields ...Field) Logger
	// Info emits msg at informational level along with fields.
	Info(msg string, fields ...Field)
	// Error emits msg at error level along with fields.
	Error(msg string, fields ...Field)
}

// Field is a single structured key/value pair attached to a log line.
// It mirrors encoder.Field (declared outside this file), presumably so this
// adapter does not need to import the encoder package — confirm.
//
// Field order is part of the API: unkeyed literals such as Field{"k", v}
// depend on Key coming before Value.
type Field struct {
	Key   string // log attribute name, e.g. "message_id"
	Value any    // attribute value; this adapter passes strings and int64s
}

// MessageHandler is the user callback invoked once per SQS record. Returning
// nil marks the record as handled. Returning a non-nil error causes Handler
// and StopBatchHandler to report that record in BatchItemFailures.
type MessageHandler func(context.Context, events.SQSMessage) error

// Handler turns a per-record MessageHandler into a Lambda SQS batch handler.
//
// Records are processed sequentially in delivery order, and each is isolated
// from the others:
//   - a child logger bound to message_id and source_arn is derived;
//   - "sqs.message.start" is logged before the callback runs;
//   - "sqs.message.end" with duration_ms is logged afterwards, at Info on
//     success or at Error (with the error text) on failure.
//
// Every failed record's message id is appended to BatchItemFailures, so that
// with partial batch responses only those records are redelivered. The
// returned error is always nil.
func Handler(logger Logger, msg MessageHandler) func(context.Context, events.SQSEvent) (events.SQSEventResponse, error) {
	return func(ctx context.Context, evt events.SQSEvent) (events.SQSEventResponse, error) {
		var out events.SQSEventResponse
		for i := range evt.Records {
			rec := evt.Records[i]
			began := time.Now()
			recLog := logger.With(
				Field{Key: "message_id", Value: rec.MessageId},
				Field{Key: "source_arn", Value: rec.EventSourceARN},
			)
			recLog.Info("sqs.message.start")
			handlerErr := msg(ctx, rec)
			elapsed := Field{Key: "duration_ms", Value: time.Since(began).Milliseconds()}
			if handlerErr == nil {
				recLog.Info("sqs.message.end", elapsed)
				continue
			}
			recLog.Error("sqs.message.end", elapsed, Field{Key: "error", Value: handlerErr.Error()})
			out.BatchItemFailures = append(out.BatchItemFailures, events.SQSBatchItemFailure{ItemIdentifier: rec.MessageId})
		}
		return out, nil
	}
}

// ErrStopBatch can be returned by a MessageHandler, directly or wrapped
// (matching uses errors.Is), to abort processing of the remaining messages in
// the current batch.
//
// It is only honoured by StopBatchHandler. That wrapper reports the message
// that returned it, and every later message in the batch, as
// BatchItemFailures so that SQS can redeliver them (this requires partial
// batch responses to be enabled on the event source mapping). Handler treats
// ErrStopBatch like any other error and fails only that one message.
var ErrStopBatch = errors.New("sqs: stop batch")

// StopBatchHandler is an alternative to Handler that honours ErrStopBatch.
//
// Records are processed in delivery order with the same per-record logging as
// Handler:
//   - a child logger is bound to message_id and source_arn;
//   - "sqs.message.start" is logged before the callback runs;
//   - "sqs.message.end" with duration_ms is logged afterwards, at Info on
//     success or at Error with the error text on failure.
//
// A record whose handler returns any error is reported in BatchItemFailures.
// If that error matches ErrStopBatch (errors.Is, so wrapped errors count),
// every remaining record is also reported as failed without invoking the user
// handler. A single "sqs.batch.stop" line records the triggering message and
// how many records were skipped, so the redelivery is visible in the logs.
// The returned error is always nil.
func StopBatchHandler(logger Logger, msg MessageHandler) func(context.Context, events.SQSEvent) (events.SQSEventResponse, error) {
	return func(ctx context.Context, evt events.SQSEvent) (events.SQSEventResponse, error) {
		var resp events.SQSEventResponse
		// fail marks a record for redelivery via the partial batch response.
		fail := func(m events.SQSMessage) {
			resp.BatchItemFailures = append(resp.BatchItemFailures, events.SQSBatchItemFailure{ItemIdentifier: m.MessageId})
		}
		for i, m := range evt.Records {
			start := time.Now()
			l := logger.With(
				Field{Key: "message_id", Value: m.MessageId},
				Field{Key: "source_arn", Value: m.EventSourceARN},
			)
			l.Info("sqs.message.start")
			err := msg(ctx, m)
			dur := Field{Key: "duration_ms", Value: time.Since(start).Milliseconds()}
			if err == nil {
				l.Info("sqs.message.end", dur)
				continue
			}
			l.Error("sqs.message.end", dur, Field{Key: "error", Value: err.Error()})
			fail(m)
			if errors.Is(err, ErrStopBatch) {
				// Everything after the stop signal is failed unprocessed so
				// SQS redelivers it in a later invocation.
				rest := evt.Records[i+1:]
				l.Info("sqs.batch.stop", Field{Key: "skipped", Value: len(rest)})
				for _, r := range rest {
					fail(r)
				}
				break
			}
		}
		return resp, nil
	}
}

// GroupIDs collects the distinct, non-empty values of each record's
// "MessageGroupId" attribute, in first-seen order. Records without the
// attribute, or with an empty value, are ignored. It returns nil when no
// record carries a group id. Callers may use the result to tune
// consumer-side concurrency per message group.
func GroupIDs(evt events.SQSEvent) []string {
	var ids []string
	dedup := make(map[string]struct{})
	for i := range evt.Records {
		gid := evt.Records[i].Attributes["MessageGroupId"]
		if gid == "" {
			continue
		}
		if _, dup := dedup[gid]; dup {
			continue
		}
		dedup[gid] = struct{}{}
		ids = append(ids, gid)
	}
	return ids
}