diff --git a/adaptive.go b/adaptive.go index f332548..57a9dca 100644 --- a/adaptive.go +++ b/adaptive.go @@ -42,8 +42,10 @@ type AdaptiveThrottle struct { k float64 minPerWindow float64 - requests []windowedCounter - accepts []windowedCounter + priorities int + requests []windowedCounter + accepts []windowedCounter + validate func(p Priority, priorities int) (Priority, error) } // NewAdaptiveThrottle returns an AdaptiveThrottle. @@ -51,6 +53,10 @@ type AdaptiveThrottle struct { // priorities is the number of priorities that the throttle will accept. Giving a priority outside // of `[0, priorities)` will panic. func NewAdaptiveThrottle(priorities int, options ...AdaptiveThrottleOption) *AdaptiveThrottle { + if priorities <= 0 { + panic("bulwark: priorities must be greater than 0") + } + opts := adaptiveThrottleOptions{ d: time.Minute, k: K, @@ -68,11 +74,18 @@ func NewAdaptiveThrottle(priorities int, options ...AdaptiveThrottleOption) *Ada accepts[i] = newWindowedCounter(now, opts.d/10, 10) } + validate := opts.validate + if validate == nil { + validate = ClampInvalidPriority + } + return &AdaptiveThrottle{ k: opts.k, + priorities: priorities, requests: requests, accepts: accepts, minPerWindow: opts.minRate * opts.d.Seconds(), + validate: validate, } } @@ -93,7 +106,11 @@ func NewAdaptiveThrottle(priorities int, options ...AdaptiveThrottleOption) *Ada func (t *AdaptiveThrottle) Throttle( ctx context.Context, defaultPriority Priority, fn throttledFn, fallbackFn ...fallbackFn, ) error { - priority := PriorityFromContext(ctx, defaultPriority) + priority, err := t.validate(PriorityFromContext(ctx, defaultPriority), t.priorities) + if err != nil { + return err + } + now := Now() rejectionProbability := t.rejectionProbability(priority, now) if rand.Float64() < rejectionProbability { @@ -112,7 +129,7 @@ func (t *AdaptiveThrottle) Throttle( return ClientSideRejectionError } - err := fn(ctx) + err = fn(ctx) now = Now() switch { @@ -191,6 +208,7 @@ type adaptiveThrottleOptions struct { minRate float64 d time.Duration isErrorAccepted func(err error) bool + validate func(p Priority, priorities int) (Priority, error) } // WithAdaptiveThrottleRatio sets the ratio of the measured success rate and the rate that the throttle @@ -234,14 +252,38 @@ func WithAcceptedErrors(fn func(err error) bool) AdaptiveThrottleOption { }} } +// WithPriorityValidator sets the function that validates input priority values. +// +// The function should return the validated priority value. If the priority is +// invalid, the function should return an error. +func WithPriorityValidator(fn func(p Priority, priorities int) (Priority, error)) AdaptiveThrottleOption { + return AdaptiveThrottleOption{func(opts *adaptiveThrottleOptions) { + opts.validate = func(p Priority, priorities int) (Priority, error) { + p, err := fn(p, priorities) + if err != nil { + return p, err + } + + // Safeguard in case the validator returns an out-of-range priority + // without an error. Clamp rather than panic to stay consistent with + // the default behaviour. + return ClampInvalidPriority(p, priorities) + } + }} +} + func Throttle[T any]( ctx context.Context, at *AdaptiveThrottle, defaultPriority Priority, throttledFn throttledArgsFn[T], fallbackFn ...fallbackArgsFn[T], -) (T, error) { - priority := PriorityFromContext(ctx, defaultPriority) +) (res T, err error) { + priority, err := at.validate(PriorityFromContext(ctx, defaultPriority), at.priorities) + if err != nil { + return res, err + } + now := Now() rejectionProbability := at.rejectionProbability(priority, now) if rand.Float64() < rejectionProbability { @@ -261,7 +303,7 @@ func Throttle[T any]( return zero, ClientSideRejectionError } - t, err := throttledFn(ctx) + res, err = throttledFn(ctx) now = Now() switch { @@ -282,7 +324,7 @@ func Throttle[T any]( return fallbackFn[0](ctx, err, false) } - return t, err + return res, err } // WithAdaptiveThrottle is used to send a request to a backend using the given AdaptiveThrottle for @@ -298,7 +340,12 @@ func WithAdaptiveThrottle[T any]( at *AdaptiveThrottle, priority Priority, throttledFn func() (T, error), -) (T, error) { +) (res T, err error) { + priority, err = at.validate(priority, at.priorities) + if err != nil { + return res, err + } + now := Now() rejectionProbability := at.rejectionProbability(priority, now) if rand.Float64() < rejectionProbability { @@ -314,7 +361,7 @@ func WithAdaptiveThrottle[T any]( return zero, ClientSideRejectionError } - t, err := throttledFn() + res, err = throttledFn() now = Now() switch { @@ -331,7 +378,7 @@ func WithAdaptiveThrottle[T any]( at.accept(priority, now) } - return t, err + return res, err } // RejectedError wraps an error to indicate that the error should be considered diff --git a/validator.go b/validator.go new file mode 100644 index 0000000..b76876e --- /dev/null +++ b/validator.go @@ -0,0 +1,48 @@ +package bulwark + +import ( + "fmt" + "log/slog" + + "github.com/deixis/faults" +) + +var ( + // AssertValidPriority panics when a priority is out of range. + // A priority is out of range when it is less than 0 or greater than or equal + // to priorities. + AssertValidPriority = func(p Priority, priorities int) (Priority, error) { + if p < 0 || int(p) >= priorities { + panic(fmt.Sprintf("bulwark: priority must be in the range [0, %d), but got %d", priorities, p)) + } + + return p, nil + } + + // ClampInvalidPriority clamps any out-of-range priority to the lowest valid + // priority (priorities-1). This applies to both negative values and values + // that exceed the configured number of priorities, preventing invalid or + // malicious input from being promoted to a higher-importance tier. + ClampInvalidPriority = func(p Priority, priorities int) (Priority, error) { + if p >= 0 && int(p) < priorities { + return p, nil + } + slog.Warn("bulwark: priority is out of range", "max", priorities-1, "priority", p) + + return Priority(priorities - 1), nil + } + + // RejectInvalidPriority returns an error when a priority is out of range. + // A priority is out of range when it is less than 0 or greater than or equal + // to priorities. + RejectInvalidPriority = func(p Priority, priorities int) (Priority, error) { + if p < 0 || int(p) >= priorities { + return p, faults.Bad(&faults.FieldViolation{ + Field: "priority", + Description: fmt.Sprintf("priority must be in the range [0, %d), but got %d", priorities, p), + }) + } + + return p, nil + } +) diff --git a/validator_test.go b/validator_test.go new file mode 100644 index 0000000..aebcc81 --- /dev/null +++ b/validator_test.go @@ -0,0 +1,163 @@ +package bulwark + +import ( + "context" + "testing" +) + +// TestAssertValidPriority verifies that AssertValidPriority passes valid +// priorities through unchanged and panics for any priority outside [0, priorities). +// This validator is intended for development environments where misuse of the +// library should be caught immediately rather than silently corrected. +func TestAssertValidPriority(t *testing.T) { + t.Run("valid priority passes through", func(t *testing.T) { + for _, p := range []Priority{High, Important, Medium, Low} { + got, err := AssertValidPriority(p, StandardPriorities) + if err != nil { + t.Errorf("priority %d: unexpected error: %v", p, err) + } + if got != p { + t.Errorf("priority %d: expected %d, got %d", p, p, got) + } + } + }) + + t.Run("negative priority panics", func(t *testing.T) { + defer func() { + if r := recover(); r == nil { + t.Error("expected panic, got none") + } + }() + AssertValidPriority(-1, StandardPriorities) + }) + + t.Run("out of range priority panics", func(t *testing.T) { + defer func() { + if r := recover(); r == nil { + t.Error("expected panic, got none") + } + }() + AssertValidPriority(Priority(StandardPriorities), StandardPriorities) + }) +} + +// TestClampInvalidPriority verifies that ClampInvalidPriority passes valid +// priorities through unchanged and clamps any out-of-range priority — including +// negative values — to the lowest valid priority (priorities-1). Clamping to +// the lowest rather than the nearest boundary ensures that invalid or malicious +// input is never silently promoted to a higher-importance tier. +func TestClampInvalidPriority(t *testing.T) { + t.Run("valid priority passes through", func(t *testing.T) { + for _, p := range []Priority{High, Important, Medium, Low} { + got, err := ClampInvalidPriority(p, StandardPriorities) + if err != nil { + t.Errorf("priority %d: unexpected error: %v", p, err) + } + if got != p { + t.Errorf("priority %d: expected %d, got %d", p, p, got) + } + } + }) + + t.Run("negative priority clamped to lowest", func(t *testing.T) { + got, err := ClampInvalidPriority(-1, StandardPriorities) + if err != nil { + t.Errorf("unexpected error: %v", err) + } + if got != Priority(StandardPriorities-1) { + t.Errorf("expected %d, got %d", StandardPriorities-1, got) + } + }) + + t.Run("out of range priority adjusted to lowest", func(t *testing.T) { + got, err := ClampInvalidPriority(Priority(StandardPriorities), StandardPriorities) + if err != nil { + t.Errorf("unexpected error: %v", err) + } + if got != Priority(StandardPriorities-1) { + t.Errorf("expected %d, got %d", StandardPriorities-1, got) + } + }) +} + +// TestRejectInvalidPriority verifies that RejectInvalidPriority passes valid +// priorities through unchanged and returns an error for any priority outside +// [0, priorities). This validator is suited for APIs where the caller is +// responsible for providing a valid priority and must handle the error. +func TestRejectInvalidPriority(t *testing.T) { + t.Run("valid priority passes through", func(t *testing.T) { + for _, p := range []Priority{High, Important, Medium, Low} { + got, err := RejectInvalidPriority(p, StandardPriorities) + if err != nil { + t.Errorf("priority %d: unexpected error: %v", p, err) + } + if got != p { + t.Errorf("priority %d: expected %d, got %d", p, p, got) + } + } + }) + + t.Run("negative priority returns error", func(t *testing.T) { + _, err := RejectInvalidPriority(-1, StandardPriorities) + if err == nil { + t.Error("expected error, got nil") + } + }) + + t.Run("out of range priority returns error", func(t *testing.T) { + _, err := RejectInvalidPriority(Priority(StandardPriorities), StandardPriorities) + if err == nil { + t.Error("expected error, got nil") + } + }) +} + +// TestWithPriorityValidator verifies that WithPriorityValidator wires a custom +// validator into the throttle so that it is called on every request. It also +// confirms that the default behaviour (no option provided) uses +// ClampInvalidPriority: out-of-range priorities are clamped rather than +// rejected, so the throttled function is still invoked. +func TestWithPriorityValidator(t *testing.T) { + t.Run("custom validator is applied", func(t *testing.T) { + called := false + throttle := NewAdaptiveThrottle(StandardPriorities, + WithPriorityValidator(func(p Priority, priorities int) (Priority, error) { + called = true + return p, nil + }), + ) + throttle.Throttle(context.Background(), High, func(_ context.Context) error { + return nil + }) + if !called { + t.Error("expected custom validator to be called") + } + }) + + t.Run("RejectInvalidPriority used as validator rejects invalid priority", func(t *testing.T) { + throttle := NewAdaptiveThrottle(StandardPriorities, + WithPriorityValidator(RejectInvalidPriority), + ) + err := throttle.Throttle(context.Background(), Priority(StandardPriorities), func(_ context.Context) error { + return nil + }) + if err == nil { + t.Error("expected error for out-of-range priority, got nil") + } + }) + + t.Run("default validator adjusts invalid priority", func(t *testing.T) { + throttle := NewAdaptiveThrottle(StandardPriorities) + called := false + err := throttle.Throttle(context.Background(), Priority(StandardPriorities), func(_ context.Context) error { + called = true + return nil + }) + if err != nil { + t.Errorf("unexpected error: %v", err) + } + if !called { + t.Error("expected throttled function to be called after priority adjustment") + } + }) +}