// context.go

package lambdalog

import (
	"context"

	"github.com/aws/aws-lambda-go/lambdacontext"
)

// ctxKey is the type of every context key this package installs. It is
// unexported, so no other package can construct a value of it, and therefore
// these keys cannot collide with keys set by other packages even when the
// underlying integers coincide.
type ctxKey int

// Context keys used by this package. Numbering starts at 1 so the zero ctxKey
// is not used as a key.
const (
	fieldsKey      ctxKey = 1 // holds the []attr accumulated by WithField
	ridOverrideKey ctxKey = 2 // holds the request-id string set by WithRequestID
)

// WithField returns a copy of ctx that carries one more logging attribute,
// key=value. When the context is passed to Logger.FromContext, the attribute
// is emitted as a JSON field on every record.
//
// Attributes accumulate rather than replace: each call appends to the
// attributes already on ctx, so two calls with different keys leave both keys
// present. Calling it again with a key that is already present appends a
// second entry; no de-duplication is done here. A nil ctx is treated as
// context.Background().
func WithField(ctx context.Context, key string, value any) context.Context {
	parent := ctx
	if parent == nil {
		parent = context.Background()
	}
	prev, _ := parent.Value(fieldsKey).([]attr)

	// Always build a fresh backing array, sized exactly for the result, so
	// contexts derived from the same parent never share storage. This copies
	// the attr structs only; the Values they hold are shared, not cloned.
	fields := make([]attr, 0, len(prev)+1)
	fields = append(fields, prev...)
	fields = append(fields, attr{Key: key, Value: value})

	return context.WithValue(parent, fieldsKey, fields)
}

// WithRequestID returns a copy of ctx that carries rid as an explicit request
// id override, taking precedence over the id supplied by the Lambda runtime
// when read by requestIDFromContext. It is intended mainly for tests and for
// code paths that synthesize a request id outside of a Lambda invocation.
// A nil ctx is treated as context.Background().
//
// The override is stored verbatim, but an empty rid does not mask the runtime
// value: requestIDFromContext only honors non-empty overrides.
func WithRequestID(ctx context.Context, rid string) context.Context {
	parent := context.Background()
	if ctx != nil {
		parent = ctx
	}
	return context.WithValue(parent, ridOverrideKey, rid)
}

// requestIDFromContext returns the Lambda request id for ctx. A non-empty
// override set via WithRequestID wins; otherwise the AwsRequestID exposed by
// lambdacontext.FromContext is used. It returns "" when neither source yields
// a value, including when ctx is nil.
func requestIDFromContext(ctx context.Context) string {
	if ctx == nil {
		// WithField and WithRequestID accept a nil ctx; stay consistent with
		// them instead of panicking in ctx.Value / lambdacontext.FromContext.
		return ""
	}
	// An empty override is deliberately ignored so it cannot hide the
	// runtime-provided id.
	if v, ok := ctx.Value(ridOverrideKey).(string); ok && v != "" {
		return v
	}
	if lc, ok := lambdacontext.FromContext(ctx); ok && lc != nil {
		return lc.AwsRequestID
	}
	return ""
}

// fieldsFromContext returns the attributes accumulated on ctx via WithField,
// in the order they were added. The result is a fresh slice the caller may
// modify without affecting ctx. It returns nil when ctx is nil or carries no
// attributes.
func fieldsFromContext(ctx context.Context) []attr {
	if ctx == nil {
		// WithField and WithRequestID accept a nil ctx; mirror that here
		// rather than panicking on the Value call below.
		return nil
	}
	stored, _ := ctx.Value(fieldsKey).([]attr)
	if len(stored) == 0 {
		return nil
	}
	// Copy so callers can't write through to the slice shared by every
	// context derived from ctx.
	out := make([]attr, len(stored))
	copy(out, stored)
	return out
}