handler.go

// Package lambdalog attaches a request-scoped logger to the context of AWS
// Lambda handler invocations.
//
// Handler wraps a user Lambda handler so that every invocation sees a logger
// in its context that is already bound to the Lambda request id.
//
// Usage (inside main):
//
//	lambda.Start(lambdalog.Handler(myHandler))
//
// The wrapper does not build an independent Logger per invocation: the
// per-invocation logger clones the process-global logger's attribute slice
// but reuses its underlying writer and mutex.
package lambdalog

import (
	"context"
	"errors"
	"reflect"
)

// HandlerFunc is the uniform, byte-oriented function that Handler produces.
// It takes the raw invocation payload and returns the raw response. The
// user's typed handler is adapted to this shape via reflection, with a logger
// injected into ctx before the typed handler runs.
type HandlerFunc func(ctx context.Context, payload []byte) ([]byte, error)

// Handler returns a wrapper around fn that enriches ctx with a Logger, using
// the package root logger as the base. It is shorthand for
// rootLogger.Handler(fn).
//
// fn must be a function of one of the following shapes, where ctx is a
// context.Context and In/Out are the payload types decoded from and encoded
// to the invocation bytes:
//
//	func(ctx) error
//	func(ctx) (Out, error)
//	func(ctx, In) error
//	func(ctx, In) (Out, error)
//
// Only handlers that take a context first are supported. If fn is not a
// function, or does not match one of these shapes, fn is never called: the
// returned HandlerFunc returns an error on every invocation.
func Handler(fn any) HandlerFunc {
	return rootLogger.Handler(fn)
}

// Handler returns a wrapper that uses l as the base logger. Prefer this over
// the package-level Handler when you need a non-default logger configuration.
//
// fn is inspected once, when Handler is called. If fn is nil, is not a
// function, is a nil function value, or has a signature that classify rejects,
// the returned HandlerFunc never calls fn. Instead it returns that error on
// every invocation.
//
// Otherwise each invocation does three things, in order:
//  1. Derives a per-request logger with l.FromContext(ctx).
//  2. Attaches that logger to ctx with WithLogger.
//  3. Calls fn through the classified signature.
func (l *Logger) Handler(fn any) HandlerFunc {
	failWith := func(err error) HandlerFunc {
		return func(context.Context, []byte) ([]byte, error) { return nil, err }
	}

	v := reflect.ValueOf(fn)
	// Check Kind before touching Type: reflect.ValueOf(nil) is the zero Value,
	// whose Kind is Invalid but whose Type method panics.
	if v.Kind() != reflect.Func {
		return failWith(errors.New("lambdalog.Handler: argument is not a function"))
	}
	// A typed nil func has Kind Func but would panic inside Value.Call on
	// every invocation; reject it up front instead.
	if v.IsNil() {
		return failWith(errors.New("lambdalog.Handler: argument is a nil function"))
	}

	sig, err := classify(v.Type())
	if err != nil {
		return failWith(err)
	}
	return func(ctx context.Context, payload []byte) ([]byte, error) {
		child := l.FromContext(ctx)
		ctx = WithLogger(ctx, child)
		return sig.call(v, ctx, payload)
	}
}

// LoggerFromContext returns the *Logger stored in ctx by WithLogger, which is
// what the Handler wrapper uses to attach one. It falls back to the package
// root logger when ctx is nil, when no logger is stored, or when the stored
// *Logger is nil. Callers therefore do not need to nil-check the result.
func LoggerFromContext(ctx context.Context) *Logger {
	if ctx != nil {
		if attached, found := ctx.Value(loggerKey).(*Logger); found && attached != nil {
			return attached
		}
	}
	return rootLogger
}

// WithLogger returns a derived context carrying l. LoggerFromContext returns
// that logger for the returned context and any context derived from it. A
// nil ctx is treated as context.Background().
func WithLogger(ctx context.Context, l *Logger) context.Context {
	parent := ctx
	if parent == nil {
		parent = context.Background()
	}
	return context.WithValue(parent, loggerKey, l)
}

// loggerCtxKey is unexported, so no other package can construct a colliding
// context key.
type loggerCtxKey struct{}

var (
	// rootLogger is the base for the package-level Handler and the fallback
	// returned by LoggerFromContext.
	rootLogger = New(nil)

	// loggerKey is the context.Value key under which WithLogger stores the
	// *Logger.
	loggerKey = loggerCtxKey{}
)

// signature records what classify learned about a handler's shape. The
// per-invocation path can then call the handler without re-inspecting its
// type.
type signature struct {
	wantsInput  bool         // handler takes a second (payload) parameter
	returnsData bool         // handler returns (T, error) rather than error
	inputType   reflect.Type // type of the second parameter; nil unless wantsInput
}

// classify checks that the func type t has one of the shapes Handler supports
// and records that shape:
//
//	func(ctx) error
//	func(ctx) (T, error)
//	func(ctx, In) error
//	func(ctx, In) (T, error)
//
// call passes the invocation context directly as the first argument. The
// first parameter must therefore accept any context.Context value. The error
// result may be any type that implements error. The returned signature is
// meaningful only when the error is nil.
func classify(t reflect.Type) (signature, error) {
	var s signature
	if t.NumIn() < 1 || t.NumIn() > 2 {
		return s, errors.New("handler must take (ctx) or (ctx, in)")
	}
	// Value.Call spreads a trailing variadic parameter over individual
	// arguments, so passing the decoded payload as one value would panic at
	// invocation time. Reject such handlers while wrapping instead.
	if t.IsVariadic() {
		return s, errors.New("handler must not be variadic")
	}
	// Implements alone also accepts concrete context types and wider
	// interfaces that embed context.Context. For those, Value.Call panics
	// when handed the runtime's context. Also require that a context.Context
	// is assignable to the parameter.
	if in0 := t.In(0); !in0.Implements(ctxType) || !ctxType.AssignableTo(in0) {
		return s, errors.New("first argument must be context.Context")
	}
	if t.NumIn() == 2 {
		s.wantsInput = true
		s.inputType = t.In(1)
	}
	switch t.NumOut() {
	case 1:
		if !t.Out(0).Implements(errType) {
			return s, errors.New("single return must be error")
		}
	case 2:
		if !t.Out(1).Implements(errType) {
			return s, errors.New("second return must be error")
		}
		s.returnsData = true
	default:
		return s, errors.New("handler must return error or (T, error)")
	}
	return s, nil
}

// call invokes fn, whose type classify has already accepted as shape s.
//
// Arguments and results are handled as follows:
//   - The first argument is ctx.
//   - When s.wantsInput, payload is decoded with decodePayload into a new
//     value of s.inputType, which becomes the second argument. A decode
//     error is returned without calling fn.
//   - If fn's error result is non-nil, it is returned with nil data.
//   - Otherwise (T, error) handlers have their T result encoded with
//     encodePayload, and error-only handlers yield (nil, nil).
func (s signature) call(fn reflect.Value, ctx context.Context, payload []byte) ([]byte, error) {
	args := []reflect.Value{reflect.ValueOf(ctx)}
	if s.wantsInput {
		in := reflect.New(s.inputType)
		if err := decodePayload(payload, in.Interface()); err != nil {
			return nil, err
		}
		args = append(args, in.Elem())
	}
	out := fn.Call(args)
	// classify guarantees one or two results, with the error always last.
	if err := resultError(out[len(out)-1]); err != nil {
		return nil, err
	}
	if s.returnsData {
		return encodePayload(out[0].Interface())
	}
	return nil, nil
}

// resultError converts a handler's error-typed result into an error value.
//
// classify only checks that the result type implements error, so the type may
// be concrete. A nil value of a nillable kind counts as success. This covers
// interfaces, pointers, maps, slices, funcs and chans, and so also treats a
// typed nil such as (*MyErr)(nil) as no error. Other kinds (for example a
// struct with an Error method) can never be nil; Value.IsNil would panic on
// them, so they are always returned as an error.
func resultError(v reflect.Value) error {
	switch v.Kind() {
	case reflect.Interface, reflect.Pointer, reflect.Map, reflect.Slice, reflect.Func, reflect.Chan:
		if v.IsNil() {
			return nil
		}
	}
	return v.Interface().(error)
}

// Interface types that classify compares handler parameters and results
// against. Each is the element type of a pointer to the interface, because
// reflect.TypeOf on an interface value reports its dynamic type instead.
var (
	ctxType = reflect.TypeOf(new(context.Context)).Elem()
	errType = reflect.TypeOf(new(error)).Elem()
)