Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 58 additions & 11 deletions adaptive.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,15 +42,21 @@ type AdaptiveThrottle struct {
k float64
minPerWindow float64

requests []windowedCounter
accepts []windowedCounter
priorities int
requests []windowedCounter
accepts []windowedCounter
validate func(p Priority, priorities int) (Priority, error)
}

// NewAdaptiveThrottle returns an AdaptiveThrottle.
//
// priorities is the number of priorities that the throttle will accept. Giving a priority outside
// of `[0, priorities)` will panic.
func NewAdaptiveThrottle(priorities int, options ...AdaptiveThrottleOption) *AdaptiveThrottle {
if priorities <= 0 {
panic("bulwark: priorities must be greater than 0")
Comment thread
basgys marked this conversation as resolved.
}

opts := adaptiveThrottleOptions{
d: time.Minute,
k: K,
Expand All @@ -68,11 +74,18 @@ func NewAdaptiveThrottle(priorities int, options ...AdaptiveThrottleOption) *Ada
accepts[i] = newWindowedCounter(now, opts.d/10, 10)
}

validate := opts.validate
if validate == nil {
validate = ClampInvalidPriority
}

return &AdaptiveThrottle{
k: opts.k,
priorities: priorities,
requests: requests,
accepts: accepts,
minPerWindow: opts.minRate * opts.d.Seconds(),
validate: validate,
}
}

Expand All @@ -93,7 +106,11 @@ func NewAdaptiveThrottle(priorities int, options ...AdaptiveThrottleOption) *Ada
func (t *AdaptiveThrottle) Throttle(
ctx context.Context, defaultPriority Priority, fn throttledFn, fallbackFn ...fallbackFn,
) error {
priority := PriorityFromContext(ctx, defaultPriority)
priority, err := t.validate(PriorityFromContext(ctx, defaultPriority), t.priorities)
if err != nil {
return err
}

now := Now()
rejectionProbability := t.rejectionProbability(priority, now)
if rand.Float64() < rejectionProbability {
Expand All @@ -112,7 +129,7 @@ func (t *AdaptiveThrottle) Throttle(
return ClientSideRejectionError
}

err := fn(ctx)
err = fn(ctx)

now = Now()
switch {
Expand Down Expand Up @@ -191,6 +208,7 @@ type adaptiveThrottleOptions struct {
minRate float64
d time.Duration
isErrorAccepted func(err error) bool
validate func(p Priority, priorities int) (Priority, error)
}

// WithAdaptiveThrottleRatio sets the ratio of the measured success rate and the rate that the throttle
Expand Down Expand Up @@ -234,14 +252,38 @@ func WithAcceptedErrors(fn func(err error) bool) AdaptiveThrottleOption {
}}
}

// WithPriorityValidator sets the function that validates input priority values.
//
// The function should return the validated priority value. If the priority is
// invalid, the function should return an error.
func WithPriorityValidator(fn func(p Priority, priorities int) (Priority, error)) AdaptiveThrottleOption {
return AdaptiveThrottleOption{func(opts *adaptiveThrottleOptions) {
opts.validate = func(p Priority, priorities int) (Priority, error) {
p, err := fn(p, priorities)
if err != nil {
return p, err
}

// Safeguard in case the validator returns an out-of-range priority
// without an error. Clamp rather than panic to stay consistent with
// the default behaviour.
return ClampInvalidPriority(p, priorities)
}
}}
}

func Throttle[T any](
ctx context.Context,
at *AdaptiveThrottle,
defaultPriority Priority,
throttledFn throttledArgsFn[T],
fallbackFn ...fallbackArgsFn[T],
) (T, error) {
priority := PriorityFromContext(ctx, defaultPriority)
) (res T, err error) {
priority, err := at.validate(PriorityFromContext(ctx, defaultPriority), at.priorities)
if err != nil {
return res, err
}

now := Now()
rejectionProbability := at.rejectionProbability(priority, now)
if rand.Float64() < rejectionProbability {
Expand All @@ -261,7 +303,7 @@ func Throttle[T any](
return zero, ClientSideRejectionError
}

t, err := throttledFn(ctx)
res, err = throttledFn(ctx)

now = Now()
switch {
Expand All @@ -282,7 +324,7 @@ func Throttle[T any](
return fallbackFn[0](ctx, err, false)
}

return t, err
return res, err
}

// WithAdaptiveThrottle is used to send a request to a backend using the given AdaptiveThrottle for
Expand All @@ -298,7 +340,12 @@ func WithAdaptiveThrottle[T any](
at *AdaptiveThrottle,
priority Priority,
throttledFn func() (T, error),
) (T, error) {
) (res T, err error) {
priority, err = at.validate(priority, at.priorities)
if err != nil {
return res, err
}

now := Now()
rejectionProbability := at.rejectionProbability(priority, now)
if rand.Float64() < rejectionProbability {
Expand All @@ -314,7 +361,7 @@ func WithAdaptiveThrottle[T any](
return zero, ClientSideRejectionError
}

t, err := throttledFn()
res, err = throttledFn()

now = Now()
switch {
Expand All @@ -331,7 +378,7 @@ func WithAdaptiveThrottle[T any](
at.accept(priority, now)
}

return t, err
return res, err
}

// RejectedError wraps an error to indicate that the error should be considered
Expand Down
48 changes: 48 additions & 0 deletions validator.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
package bulwark

import (
"fmt"
"log/slog"

"github.com/deixis/faults"
)

var (
// AssertValidPriority panics when a priority is out of range.
// A priority is out of range when it is less than 0 or greater than or equal
// to priorities.
AssertValidPriority = func(p Priority, priorities int) (Priority, error) {
if p < 0 || int(p) >= priorities {
panic(fmt.Sprintf("bulwark: priority must be in the range [0, %d), but got %d", priorities, p))
}

return p, nil
}

// ClampInvalidPriority clamps any out-of-range priority to the lowest valid
// priority (priorities-1). This applies to both negative values and values
// that exceed the configured number of priorities, preventing invalid or
// malicious input from being promoted to a higher-importance tier.
ClampInvalidPriority = func(p Priority, priorities int) (Priority, error) {
if p >= 0 && int(p) < priorities {
return p, nil
}
slog.Warn("bulwark: priority is out of range", "max", priorities-1, "priority", p)

return Priority(priorities - 1), nil
}

// RejectInvalidPriority returns an error when a priority is out of range.
// A priority is out of range when it is less than 0 or greater than or equal
// to priorities.
RejectInvalidPriority = func(p Priority, priorities int) (Priority, error) {
if p < 0 || int(p) >= priorities {
return p, faults.Bad(&faults.FieldViolation{
Field: "priority",
Description: fmt.Sprintf("priority must be in the range [0, %d), but got %d", priorities, p),
})
}

return p, nil
}
)
163 changes: 163 additions & 0 deletions validator_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,163 @@
package bulwark

import (
"context"
"testing"
)

// TestAssertValidPriority verifies that AssertValidPriority passes valid
// priorities through unchanged and panics for any priority outside [0, priorities).
// This validator is intended for development environments where misuse of the
// library should be caught immediately rather than silently corrected.
func TestAssertValidPriority(t *testing.T) {
t.Run("valid priority passes through", func(t *testing.T) {
for _, p := range []Priority{High, Important, Medium, Low} {
got, err := AssertValidPriority(p, StandardPriorities)
if err != nil {
t.Errorf("priority %d: unexpected error: %v", p, err)
}
if got != p {
t.Errorf("priority %d: expected %d, got %d", p, p, got)
}
}
})

t.Run("negative priority panics", func(t *testing.T) {
defer func() {
if r := recover(); r == nil {
t.Error("expected panic, got none")
}
}()
AssertValidPriority(-1, StandardPriorities)
})

t.Run("out of range priority panics", func(t *testing.T) {
defer func() {
if r := recover(); r == nil {
t.Error("expected panic, got none")
}
}()
AssertValidPriority(Priority(StandardPriorities), StandardPriorities)
})
}

// TestClampInvalidPriority verifies that ClampInvalidPriority passes valid
// priorities through unchanged and clamps any out-of-range priority — including
// negative values — to the lowest valid priority (priorities-1). Clamping to
// the lowest rather than the nearest boundary ensures that invalid or malicious
// input is never silently promoted to a higher-importance tier.
func TestClampInvalidPriority(t *testing.T) {
t.Run("valid priority passes through", func(t *testing.T) {
for _, p := range []Priority{High, Important, Medium, Low} {
got, err := ClampInvalidPriority(p, StandardPriorities)
if err != nil {
t.Errorf("priority %d: unexpected error: %v", p, err)
}
if got != p {
t.Errorf("priority %d: expected %d, got %d", p, p, got)
}
}
})

t.Run("negative priority clamped to lowest", func(t *testing.T) {
got, err := ClampInvalidPriority(-1, StandardPriorities)
if err != nil {
t.Errorf("unexpected error: %v", err)
}
if got != Priority(StandardPriorities-1) {
t.Errorf("expected %d, got %d", StandardPriorities-1, got)
}
})

t.Run("out of range priority adjusted to lowest", func(t *testing.T) {
got, err := ClampInvalidPriority(Priority(StandardPriorities), StandardPriorities)
if err != nil {
t.Errorf("unexpected error: %v", err)
}
if got != Priority(StandardPriorities-1) {
t.Errorf("expected %d, got %d", StandardPriorities-1, got)
}
})
}

// TestRejectInvalidPriority verifies that RejectInvalidPriority passes valid
// priorities through unchanged and returns an error for any priority outside
// [0, priorities). This validator is suited for APIs where the caller is
// responsible for providing a valid priority and must handle the error.
func TestRejectInvalidPriority(t *testing.T) {
t.Run("valid priority passes through", func(t *testing.T) {
for _, p := range []Priority{High, Important, Medium, Low} {
got, err := RejectInvalidPriority(p, StandardPriorities)
if err != nil {
t.Errorf("priority %d: unexpected error: %v", p, err)
}
if got != p {
t.Errorf("priority %d: expected %d, got %d", p, p, got)
}
}
})

t.Run("negative priority returns error", func(t *testing.T) {
_, err := RejectInvalidPriority(-1, StandardPriorities)
if err == nil {
t.Error("expected error, got nil")
}
})

t.Run("out of range priority returns error", func(t *testing.T) {
_, err := RejectInvalidPriority(Priority(StandardPriorities), StandardPriorities)
if err == nil {
t.Error("expected error, got nil")
}
})
}

// TestWithPriorityValidator verifies that WithPriorityValidator wires a custom
// validator into the throttle so that it is called on every request. It also
// confirms that the default behaviour (no option provided) uses
// ClampInvalidPriority: out-of-range priorities are clamped rather than
// rejected, so the throttled function is still invoked.
func TestWithPriorityValidator(t *testing.T) {
t.Run("custom validator is applied", func(t *testing.T) {
called := false
throttle := NewAdaptiveThrottle(StandardPriorities,
WithPriorityValidator(func(p Priority, priorities int) (Priority, error) {
called = true
return p, nil
}),
)
throttle.Throttle(context.Background(), High, func(_ context.Context) error {
return nil
})
if !called {
t.Error("expected custom validator to be called")
}
})

t.Run("RejectInvalidPriority used as validator rejects invalid priority", func(t *testing.T) {
throttle := NewAdaptiveThrottle(StandardPriorities,
WithPriorityValidator(RejectInvalidPriority),
)
err := throttle.Throttle(context.Background(), Priority(StandardPriorities), func(_ context.Context) error {
return nil
})
if err == nil {
t.Error("expected error for out-of-range priority, got nil")
}
})

t.Run("default validator adjusts invalid priority", func(t *testing.T) {
throttle := NewAdaptiveThrottle(StandardPriorities)
called := false
err := throttle.Throttle(context.Background(), Priority(StandardPriorities), func(_ context.Context) error {
called = true
return nil
})
if err != nil {
t.Errorf("unexpected error: %v", err)
}
if !called {
t.Error("expected throttled function to be called after priority adjustment")
}
})
}
Loading