diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index 98360f5..7dbb344 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -3,11 +3,16 @@ name: Run unit tests on: push -env: - GO_VERSION: 1.22 - jobs: test: + strategy: + matrix: + module: [., v2] + include: + - module: . + go-version: "1.22" + - module: v2 + go-version: "1.26" runs-on: ubuntu-latest steps: - name: Checkout code @@ -16,9 +21,10 @@ jobs: - name: Set up Go uses: actions/setup-go@v5 with: - go-version: ${{ env.GO_VERSION }} + go-version: ${{ matrix.go-version }} - name: Run tests + working-directory: ${{ matrix.module }} # Run up to 3 times in case there is a flaky test run: | retry() { diff --git a/README.md b/README.md index dc5b185..2ab3ff2 100644 --- a/README.md +++ b/README.md @@ -4,7 +4,7 @@ - [Quick start](#quick-start) - [Error handling](#error-handling) - [Situational](#situational) - - [Global](#global) + - [Per-instance classifier](#per-instance-classifier) - [`deixis/faults`](#deixisfaults) - [Fallback](#fallback) - [Priority](#priority) @@ -15,7 +15,7 @@ - [Throttle ratio](#throttle-ratio) - [Throttle minimum rate](#throttle-minimum-rate) - [Throttle window](#throttle-window) - - [Accepted errors](#accepted-errors) + - [Rejected error classifier](#rejected-error-classifier) - [Under the hood](#under-the-hood) - [Inspirations](#inspirations) - [Further reading](#further-reading) @@ -28,6 +28,8 @@ Distributed services are particularly susceptible to cascading failures when par In normal conditions, when resources meet demand, Bulwark operates passively, allowing all traffic to flow without interference. Unlike traditional throttling mechanisms, Bulwark does not queue requests, ensuring no additional latency is introduced to request handling. +**Requires Go 1.26+** (`github.com/deixis/bulwark/v2`). The v1 module (`github.com/deixis/bulwark`) requires Go 1.22+. 
+ ## Quick start ```go @@ -35,9 +37,10 @@ package main import ( "context" + "errors" "fmt" - "github.com/deixis/bulwark" + "github.com/deixis/bulwark/v2" ) func main() { @@ -68,28 +71,20 @@ func main() { return nil }) if err != nil { - if err == bulwark.ClientSideRejectionError { + if errors.Is(err, bulwark.ErrClientSideRejection) { // Call dropped } // Handle error } - // When the throttled function needs to return a value, this function can be used. + // When the throttled function needs to return a value, use the generic Throttle function. msg, err := bulwark.Throttle(ctx, throttle, bulwark.Medium, func(ctx context.Context) (string, error) { // Call external service here... - var err error - if err != nil { - // Wrap error when it should be considered for throttling. - // By default, errors are ignored unless they are from the `faults` package. - // See the Error handling section for more info. - return bulwark.RejectedError(err) - } - return "Hello", nil }) if err != nil { - if err == bulwark.ClientSideRejectionError { + if errors.Is(err, bulwark.ErrClientSideRejection) { // Call dropped } @@ -119,39 +114,38 @@ if err != nil { > 💡 Wrapping errors with `bulwark.RejectedError` is suitable for initial implementations and simple use cases. However, avoid adding excessive error-handling logic within the throttled function, because it is not easily reusable and can lead to inconsistencies. -### Global +### Per-instance classifier -Bulwark provides a global function, `bulwark.IsRejectedError`, to classify errors for throttling. This is especially useful for handling well-known error types across the codebase, reducing logic duplication in throttled functions. +`WithRejectedErrorFunc` sets the per-instance function that classifies errors as capacity rejections. This is especially useful for handling well-known error types across the codebase, reducing logic duplication in throttled functions. 
-Errors wrapped with `bulwark.RejectedError(err)` are always treated as capacity issues, so you don't need to include them in your `bulwark.IsRejectedError` implementation. +Errors wrapped with `bulwark.RejectedError(err)` are always treated as capacity issues, so you don't need to include them in your classifier. ```go -bulwark.IsRejectedError = func(err error) bool { - // For example, all timeouts could be considered as a capacity problem. - tempErr, ok := err.(interface { - Timeout() bool - }) - if ok && tempErr.Timeout() { - return true - } - // a "Connection Reset by Peer" could also show symptoms of a capacity problem. - if errors.Is(err, syscall.ECONNRESET) { - return true - } - // Include the default logic - if bulwark.DefaultRejectedError(err) { - return true - } - - return false // Use true or false by default to have a white/black list approach. -} +throttle := bulwark.NewAdaptiveThrottle( + bulwark.StandardPriorities, + bulwark.WithRejectedErrorFunc(func(err error) bool { + // For example, all timeouts could be considered as a capacity problem. + tempErr, ok := err.(interface { + Timeout() bool + }) + if ok && tempErr.Timeout() { + return true + } + // a "Connection Reset by Peer" could also show symptoms of a capacity problem. + if errors.Is(err, syscall.ECONNRESET) { + return true + } + // Include the default logic + return bulwark.DefaultRejectedErrorFunc(err) + }), +) ``` -> 💡 This approach works well in codebases with consistent error definitions for capacity-related issues. For instance, an [Echo](https://echo.labstack.com) server might override `bulwark.IsRejectedError` to include `echo.*HTTPError`. +> 💡 This approach works well in codebases with consistent error definitions for capacity-related issues. For instance, an [Echo](https://echo.labstack.com) server might use `WithRejectedErrorFunc` to include `echo.*HTTPError`. 
### `deixis/faults` -Bulwark integrates with the [`deixis/faults`](https://github.com/deixis/faults) library through `bulwark.DefaultRejectedError`. This integration provides a structured and consistent way to categorise errors using well-defined primitives, offering significant benefits beyond load shedding. +Bulwark integrates with the [`deixis/faults`](https://github.com/deixis/faults) library through `bulwark.DefaultRejectedErrorFunc`. This integration provides a structured and consistent way to categorise errors using well-defined primitives, offering significant benefits beyond load shedding. ```go err := throttle.Throttle(ctx, bulwark.Medium, func(ctx context.Context) error { @@ -288,7 +282,7 @@ Higher values of `k` mean that the throttle will react more slowly when a backen ```go throttle := bulwark.NewAdaptiveThrottle( bulwark.StandardPriorities, - bulwark.WithAdaptivethrottleatio(1.1), + bulwark.WithAdaptiveThrottleRatio(1.1), ) ``` @@ -318,21 +312,24 @@ throttle := bulwark.NewAdaptiveThrottle( ) ``` -### Accepted errors +### Rejected error classifier -Set the function that determines whether an error should be considered for the throttling. When the call to `fn` returns true, the error is NOT counted towards the throttling. +Set the per-instance function that determines whether an error returned by the throttled function should be counted as a capacity rejection. Defaults to `DefaultRejectedErrorFunc`. ```go -isAcceptedErrors := func(err error) bool { - return errors.Is(err, context.Canceled) // || other conditions -} throttle := bulwark.NewAdaptiveThrottle( bulwark.StandardPriorities, - bulwark.WithAcceptedErrors(isAcceptedErrors), + bulwark.WithRejectedErrorFunc(func(err error) bool { + // context.Canceled is not a capacity issue — don't count it. 
+ if errors.Is(err, context.Canceled) { + return false + } + return bulwark.DefaultRejectedErrorFunc(err) + }), ) ``` -> Errors unrelated to resource constraints or a service's inability to handle traffic should be allowed. For instance, errors caused by invalid user requests or authentication failures should be accepted. +> Only errors that indicate the backend is under resource pressure should return true. Errors caused by invalid requests, authentication failures, or client cancellations should return false. ## Under the hood diff --git a/go.mod b/go.mod index c654cd1..19b712d 100644 --- a/go.mod +++ b/go.mod @@ -3,6 +3,6 @@ module github.com/deixis/bulwark go 1.22 require ( - github.com/deixis/faults v0.0.0-20240817153531-c0ec10db827f + github.com/deixis/faults v1.0.1 golang.org/x/time v0.6.0 ) diff --git a/go.sum b/go.sum index ed6119d..36f42ec 100644 --- a/go.sum +++ b/go.sum @@ -1,4 +1,4 @@ -github.com/deixis/faults v0.0.0-20240817153531-c0ec10db827f h1:n8+Ze8qDZh8DzSdknFqzXpvU3xjVrhqShgyx1xwC8ek= -github.com/deixis/faults v0.0.0-20240817153531-c0ec10db827f/go.mod h1:TmAFyR/M6swaIznYCjZBqZMVJg5MYOJFOsTYOawLZK4= +github.com/deixis/faults v1.0.1 h1:4KbZaJvqOfc2cWh3CjWU2ynGWRY/OpDr2DOgp2j6zeQ= +github.com/deixis/faults v1.0.1/go.mod h1:TmAFyR/M6swaIznYCjZBqZMVJg5MYOJFOsTYOawLZK4= golang.org/x/time v0.6.0 h1:eTDhh4ZXt5Qf0augr54TN6suAUudPcawVZeIAPU7D4U= golang.org/x/time v0.6.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM= diff --git a/v2/adaptive.go b/v2/adaptive.go new file mode 100644 index 0000000..53218f8 --- /dev/null +++ b/v2/adaptive.go @@ -0,0 +1,401 @@ +package bulwark + +import ( + "context" + "errors" + "math/rand/v2" + "sync" + "time" + + "github.com/deixis/faults" +) + +const ( + // K is the default accept multiplier, which is used to determine the number + // of requests that are allowed to reach the backend. 
+ // + // A value of 2 means that the throttle will allow twice as many requests to + // actually reach the backend as it believes will succeed. + K = 2 + // MinRPS is the minimum number of requests per second that the adaptive + // throttle will allow (approximately) through to the upstream, even if every + // request is failing. + MinRPS = 1 +) + +// ErrClientSideRejection is the default error returned when the throttle +// rejects a request on the client side without forwarding it to the backend. +// Use WithClientSideRejectionError to override this per instance. +var ErrClientSideRejection = errors.New("bulwark: client-side rejection") + +// DefaultRejectedErrorFunc is the default function used to classify errors as +// rejections. It treats Unavailable and ResourceExhausted errors as rejections. +// Use WithRejectedErrorFunc to override this per instance. +func DefaultRejectedErrorFunc(err error) bool { + return faults.IsUnavailable(err) || faults.IsResourceExhausted(err) +} + +// AdaptiveThrottle is used in a client to throttle requests to a backend as it becomes unhealthy to +// help it recover from overload more quickly. Because backends must expend resources to reject +// requests over their capacity it is vital for clients to ease off on sending load when they are +// in trouble, lest the backend spend all of its resources on rejecting requests and have none left +// over to actually serve any. +// +// The adaptive throttle works by tracking the success rate of requests over some time interval +// (usually a minute or so), and randomly rejecting requests without sending them to avoid sending +// too much more than the rate that are expected to actually be successful. Some slop is included, +// because even if the backend is serving zero requests successfully, we do need to occasionally +// send it requests to learn when it becomes healthy again. 
+// +// More on adaptive throttles in https://sre.google/sre-book/handling-overload/ +type AdaptiveThrottle struct { + m sync.Mutex + + k float64 + minPerWindow float64 + + priorities int + requests []windowedCounter + accepts []windowedCounter + validate func(p Priority, priorities int) (Priority, error) + + isRejectedError func(error) bool + clientSideRejectionError error + now func() time.Time + rng *rand.Rand +} + +// NewAdaptiveThrottle returns an AdaptiveThrottle. +// +// priorities is the number of priorities that the throttle will accept. Giving a priority outside +// of `[0, priorities)` will panic. +func NewAdaptiveThrottle(priorities int, options ...AdaptiveThrottleOption) *AdaptiveThrottle { + if priorities <= 0 { + panic("bulwark: priorities must be greater than 0") + } + + opts := adaptiveThrottleOptions{ + d: time.Minute, + k: K, + minRate: MinRPS, + isRejectedError: DefaultRejectedErrorFunc, + clientSideRejectionError: ErrClientSideRejection, + now: time.Now, + } + for _, option := range options { + option.f(&opts) + } + + now := opts.now() + requests := make([]windowedCounter, priorities) + accepts := make([]windowedCounter, priorities) + for i := range requests { + requests[i] = newWindowedCounter(now, opts.d/10, 10) + accepts[i] = newWindowedCounter(now, opts.d/10, 10) + } + + validate := opts.validate + if validate == nil { + validate = ClampInvalidPriority + } + + var rng *rand.Rand + if opts.randSource != nil { + rng = rand.New(opts.randSource) + } + + return &AdaptiveThrottle{ + k: opts.k, + priorities: priorities, + requests: requests, + accepts: accepts, + minPerWindow: opts.minRate * opts.d.Seconds(), + validate: validate, + isRejectedError: opts.isRejectedError, + clientSideRejectionError: opts.clientSideRejectionError, + now: opts.now, + rng: rng, + } +} + +// Throttle sends a request to the backend when the adaptive throttle allows it. +// The request is throttled based on the priority of the request. 
+// +// The default priority is used when the given `ctx` does not have a priority set. +// The `ctx` can set the priority using `WithPriority`. +// +// When `throttledFn` returns an error, the error is considered as a rejection +// when `WithRejectedErrorFunc` returns true or when the error is wrapped in a +// `RejectedError`. +// +// If there are enough rejections within a given time window, further calls to +// `Throttle` may begin returning `ErrClientSideRejection` immediately +// without invoking `throttledFn`. Lower-priority requests are preferred to be +// rejected first. +func (t *AdaptiveThrottle) Throttle( + ctx context.Context, defaultPriority Priority, fn throttledFn, fallbackFn ...fallbackFn, +) error { + var fb []fallbackArgsFn[struct{}] + if len(fallbackFn) > 0 { + f := fallbackFn[0] + fb = []fallbackArgsFn[struct{}]{ + func(ctx context.Context, err error, local bool) (struct{}, error) { + return struct{}{}, f(ctx, err, local) + }, + } + } + _, err := Throttle(ctx, t, defaultPriority, func(ctx context.Context) (struct{}, error) { + return struct{}{}, fn(ctx) + }, fb...) + return err +} + +// rejectionProbability returns the probability that a request of the given +// priority will be rejected. The result is clamped to the range [0, 1]. +// +// It uses the formula from https://sre.google/sre-book/handling-overload/ to +// calculate the probability that a request will be rejected. The formula is: +// +// clamp(0, (requests - k * accepts) / (requests + minPerWindow), 1) +// +// Where: +// - requests is the number of requests of the given priority in the last d time window. +// - accepts is the number of requests of the given priority that were accepted in the last d time +// window. +// - k is the ratio of the measured success rate and the rate that the throttle will admit. +// - minPerWindow is the minimum number of requests per second that the adaptive throttle will allow +// (approximately) through to the upstream, even if every request is failing. 
+func (t *AdaptiveThrottle) rejectionProbability(p Priority, now time.Time) float64 { + t.m.Lock() + requests := float64(t.requests[int(p)].get(now)) + accepts := float64(t.accepts[int(p)].get(now)) + for i := range int(p) { + // Also count non-accepted requests for every higher priority as + // non-accepted for this priority. + requests += float64(t.requests[i].get(now) - t.accepts[i].get(now)) + } + t.m.Unlock() + + return clamp(0, (requests-t.k*accepts)/(requests+t.minPerWindow), 1) +} + +// accept records that a request of the given priority was accepted. +func (t *AdaptiveThrottle) accept(p Priority, now time.Time) { + t.m.Lock() + t.requests[int(p)].add(now, 1) + t.accepts[int(p)].add(now, 1) + t.m.Unlock() +} + +// reject records that a request of the given priority was rejected. +func (t *AdaptiveThrottle) reject(p Priority, now time.Time) { + t.m.Lock() + t.requests[int(p)].add(now, 1) + t.m.Unlock() +} + +func (t *AdaptiveThrottle) randFloat64() float64 { + if t.rng != nil { + return t.rng.Float64() + } + return rand.Float64() +} + +// Additional options for the AdaptiveThrottle type. These options do not frequently need to be +// tuned as the defaults work in a majority of cases. +type AdaptiveThrottleOption struct { + f func(*adaptiveThrottleOptions) +} + +type adaptiveThrottleOptions struct { + k float64 + minRate float64 + d time.Duration + validate func(p Priority, priorities int) (Priority, error) + isRejectedError func(error) bool + clientSideRejectionError error + now func() time.Time + randSource rand.Source +} + +// WithAdaptiveThrottleRatio sets the ratio of the measured success rate and the rate that the throttle +// will admit. For example, when k is 2 the throttle will allow twice as many requests to actually +// reach the backend as it believes will succeed. 
Higher values of k mean that the throttle will +// react more slowly when a backend becomes unhealthy, but react more quickly when it becomes +// healthy again, and will allow more load to an unhealthy backend. k=2 is usually a good place to +// start, but backends that serve "cheap" requests (e.g. in-memory caches) may need a lower value. +func WithAdaptiveThrottleRatio(k float64) AdaptiveThrottleOption { + return AdaptiveThrottleOption{func(opts *adaptiveThrottleOptions) { + opts.k = k + }} +} + +// WithAdaptiveThrottleMinimumRate sets the minimum number of requests per second that the adaptive +// throttle will allow (approximately) through to the upstream, even if every request is failing. +// This is important because this is how the adaptive throttle 'learns' when the upstream becomes +// healthy again. +func WithAdaptiveThrottleMinimumRate(x float64) AdaptiveThrottleOption { + return AdaptiveThrottleOption{func(opts *adaptiveThrottleOptions) { + opts.minRate = x + }} +} + +// WithAdaptiveThrottleWindow sets the time window over which the throttle remembers requests for use in +// figuring out the success rate. +func WithAdaptiveThrottleWindow(d time.Duration) AdaptiveThrottleOption { + return AdaptiveThrottleOption{func(opts *adaptiveThrottleOptions) { + opts.d = d + }} +} + +// WithPriorityValidator sets the function that validates input priority values. +// +// The function should return the validated priority value. If the priority is +// invalid, the function should return an error. +func WithPriorityValidator(fn func(p Priority, priorities int) (Priority, error)) AdaptiveThrottleOption { + return AdaptiveThrottleOption{func(opts *adaptiveThrottleOptions) { + opts.validate = func(p Priority, priorities int) (Priority, error) { + p, err := fn(p, priorities) + if err != nil { + return p, err + } + + // Safeguard in case the validator returns an out-of-range priority + // without an error. 
Clamp rather than panic to stay consistent with + // the default behaviour. + return ClampInvalidPriority(p, priorities) + } + }} +} + +// WithRejectedErrorFunc sets the per-instance function that determines whether +// an error returned by the throttled function should be counted as a rejection. +// Defaults to DefaultRejectedErrorFunc. +func WithRejectedErrorFunc(fn func(error) bool) AdaptiveThrottleOption { + return AdaptiveThrottleOption{func(opts *adaptiveThrottleOptions) { + opts.isRejectedError = fn + }} +} + +// WithClientSideRejectionError sets the per-instance error returned when the +// throttle rejects a request on the client side without forwarding it to the +// backend. Defaults to ErrClientSideRejection. +func WithClientSideRejectionError(err error) AdaptiveThrottleOption { + return AdaptiveThrottleOption{func(opts *adaptiveThrottleOptions) { + opts.clientSideRejectionError = err + }} +} + +// WithNow sets the per-instance time source. This is primarily useful in tests +// to control the clock without affecting other AdaptiveThrottle instances. +func WithNow(fn func() time.Time) AdaptiveThrottleOption { + return AdaptiveThrottleOption{func(opts *adaptiveThrottleOptions) { + opts.now = fn + }} +} + +// WithRandomSource sets the per-instance random source used to sample the +// rejection probability. This is primarily useful in tests to produce +// deterministic behaviour: a source that always returns 0 will shed every +// request whose rejection probability is greater than zero, and a source that +// always returns math.MaxUint64 will never shed. +// +// The provided source must be safe for concurrent use if the AdaptiveThrottle +// is used concurrently. rand.NewPCG and rand.NewChaCha8 are not concurrent-safe +// by default. 
+func WithRandomSource(src rand.Source) AdaptiveThrottleOption { + return AdaptiveThrottleOption{func(opts *adaptiveThrottleOptions) { + opts.randSource = src + }} +} + +// Throttle executes throttledFn through the given AdaptiveThrottle and returns +// the result. It is the generic counterpart to AdaptiveThrottle.Throttle for +// functions that return a value. +// +// The default priority is used when the given `ctx` does not have a priority set. +// The `ctx` can set the priority using `WithPriority`. +func Throttle[T any]( + ctx context.Context, + at *AdaptiveThrottle, + defaultPriority Priority, + throttledFn throttledArgsFn[T], + fallbackFn ...fallbackArgsFn[T], +) (res T, err error) { + priority, err := at.validate(PriorityFromContext(ctx, defaultPriority), at.priorities) + if err != nil { + return res, err + } + + now := at.now() + rejectionProbability := at.rejectionProbability(priority, now) + if at.randFloat64() < rejectionProbability { + // As Bulwark starts rejecting requests, requests will continue to exceed + // accepts. While it may seem counterintuitive, given that locally rejected + // requests aren't actually propagated, this is the preferred behavior. As the + // rate at which the application attempts requests to Bulwark grows + // (relative to the rate at which the backend accepts them), we want to + // increase the probability of dropping new requests. + at.reject(priority, now) + var zero T + + if len(fallbackFn) > 0 { + return fallbackFn[0](ctx, at.clientSideRejectionError, true) + } + + return zero, at.clientSideRejectionError + } + + res, err = throttledFn(ctx) + + now = at.now() + switch { + case err == nil: + at.accept(priority, now) + case errors.Is(err, errRejected{}): + // Unwrap error to return the original error to the caller. 
+ if re, ok := errors.AsType[errRejected](err); ok { + err = re.inner + } + + fallthrough + case at.isRejectedError(err): + at.reject(priority, now) + default: + at.accept(priority, now) + } + + if err != nil && len(fallbackFn) > 0 { + return fallbackFn[0](ctx, err, false) + } + + return res, err +} + +// RejectedError wraps an error to indicate that the error should be considered +// for the throttling. +// +// Any error that indicates that the backend is unhealthy should be wrapped with +// `RejectedError`. But other errors, such as bad requests, authentication failures, +// pre-condition failures, etc., should not be wrapped with `RejectedError`. +func RejectedError(err error) error { return errRejected{inner: err} } + +type errRejected struct{ inner error } + +func (err errRejected) Error() string { return err.inner.Error() } +func (err errRejected) Unwrap() error { return err.inner } +func (err errRejected) Is(target error) bool { + _, ok := target.(errRejected) + + return ok +} + + +func clamp(lo, x, hi float64) float64 { return max(lo, min(x, hi)) } + +type ( + throttledFn func(ctx context.Context) error + fallbackFn func(ctx context.Context, err error, local bool) error + throttledArgsFn[T any] func(ctx context.Context) (T, error) + fallbackArgsFn[T any] func(ctx context.Context, err error, local bool) (T, error) +) diff --git a/v2/adaptive_test.go b/v2/adaptive_test.go new file mode 100644 index 0000000..9fa41bd --- /dev/null +++ b/v2/adaptive_test.go @@ -0,0 +1,566 @@ +package bulwark + +import ( + "context" + "errors" + "math" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/deixis/faults" + "golang.org/x/time/rate" +) + +// TestAdaptiveThrottlePriorityShedding verifies that under server overload, +// Bulwark sheds lower-priority requests at a higher rate than higher-priority +// ones. It measures the client-side shed rate (requests blocked by Bulwark +// before reaching the server), which is what Bulwark directly controls. 
+// +// The test uses only High and Low priorities. Comparing adjacent intermediate +// priorities is not stable because once a mid-tier priority enters a shed +// spiral its "failures" count inflates the next tier's formula, creating a +// feedback loop whose resolution depends on demand ratios. The High-vs-Low +// ordering is always guaranteed: Low's rejection-probability formula explicitly +// adds High's failure count to its numerator, so P(reject,Low) ≥ P(reject,High). +func TestAdaptiveThrottlePriorityShedding(t *testing.T) { + const ( + duration = 10 * time.Second + highDemand = 15 // rps + lowDemand = 30 // rps + ) + + // Each priority gets its own server limiter to avoid goroutine scheduling + // artifacts from a shared resource. A shared limiter creates a token lottery: + // the Low goroutine (running at 2× rate) tends to win more tokens, collapsing + // High's accept count and triggering a death spiral regardless of priority. + // + // With separate limiters we control the accept rates directly: + // High server: 10 rps capacity → 67% accept → P(reject,High) ≤ 0 → no shedding + // Low server: 10 rps capacity → 33% accept → P(reject,Low) > 0 → Bulwark sheds + // + // This tests the invariant cleanly: Low's rejection-probability formula adds + // High's failure count to its numerator, so P(reject,Low) ≥ P(reject,High). 
+ highServerLimiter := rate.NewLimiter(rate.Limit(float64(highDemand)*2/3), 1) // 10 rps + lowServerLimiter := rate.NewLimiter(rate.Limit(float64(lowDemand)/3), 1) // 10 rps + + clientThrottle := NewAdaptiveThrottle( + StandardPriorities, + WithAdaptiveThrottleWindow(3*time.Second), + ) + + start := time.Now() + var highAttempts, highSent int64 + var lowAttempts, lowSent int64 + + var wg sync.WaitGroup + + wg.Add(1) + go func() { + defer wg.Done() + lim := rate.NewLimiter(rate.Limit(highDemand), 1) + for time.Since(start) < duration { + if err := lim.Wait(context.Background()); err != nil { + return + } + atomic.AddInt64(&highAttempts, 1) + _, _ = Throttle(context.Background(), clientThrottle, High, func(ctx context.Context) (struct{}, error) { + atomic.AddInt64(&highSent, 1) + if !highServerLimiter.Allow() { + return struct{}{}, RejectedError(faults.Unavailable(0)) + } + return struct{}{}, nil + }) + } + }() + + wg.Add(1) + go func() { + defer wg.Done() + lim := rate.NewLimiter(rate.Limit(lowDemand), 1) + for time.Since(start) < duration { + if err := lim.Wait(context.Background()); err != nil { + return + } + atomic.AddInt64(&lowAttempts, 1) + _, _ = Throttle(context.Background(), clientThrottle, Low, func(ctx context.Context) (struct{}, error) { + atomic.AddInt64(&lowSent, 1) + if !lowServerLimiter.Allow() { + return struct{}{}, RejectedError(faults.Unavailable(0)) + } + return struct{}{}, nil + }) + } + }() + + wg.Wait() + + ha, hs := atomic.LoadInt64(&highAttempts), atomic.LoadInt64(&highSent) + la, ls := atomic.LoadInt64(&lowAttempts), atomic.LoadInt64(&lowSent) + + highShedRate := float64(ha-hs) / float64(ha) + lowShedRate := float64(la-ls) / float64(la) + + t.Logf("High: attempts=%d sent=%d shed=%.1f%%", ha, hs, highShedRate*100) + t.Logf("Low: attempts=%d sent=%d shed=%.1f%%", la, ls, lowShedRate*100) + + if highShedRate >= lowShedRate { + t.Errorf("High shed rate (%.1f%%) should be lower than Low (%.1f%%)", + highShedRate*100, lowShedRate*100) + } +} + 
+// TestAdaptiveThrottleRecovery verifies that once a server recovers from +// overload, Bulwark stops shedding and resumes forwarding requests. This +// validates the other half of the adaptive throttle contract: not only must it +// shed under pressure, it must also relax when the pressure is gone. +// +// The clock is controlled via WithNow so the window can be expired instantly +// without real sleeps. +func TestAdaptiveThrottleRecovery(t *testing.T) { + const ( + k = 2.0 + minRate = 1.0 + window = time.Second + // minPerWindow = minRate * window.Seconds() = 1 + ) + + now := time.Date(2020, 1, 1, 0, 0, 0, 0, time.UTC) + at := NewAdaptiveThrottle(1, + WithAdaptiveThrottleRatio(k), + WithAdaptiveThrottleMinimumRate(minRate), + WithAdaptiveThrottleWindow(window), + WithNow(func() time.Time { return now }), + ) + + // Phase 1: healthy — 100 requests all accepted. + // P = (100 - 2*100) / (100+1) = -100/101 → clamped to 0. + for range 100 { + at.accept(0, now) + } + if p := at.rejectionProbability(0, now); p != 0 { + t.Errorf("healthy phase: expected P=0, got %f", p) + } + + // Expire the healthy window, then enter overload. + now = now.Add(window) + + // Phase 2: overloaded — 100 requests all rejected. + // P = (100 - 0) / (100+1) = 100/101 ≈ 0.99. + for range 100 { + at.reject(0, now) + } + wantOverload := 100.0 / 101.0 + if p := at.rejectionProbability(0, now); math.Abs(p-wantOverload) > 1e-10 { + t.Errorf("overloaded phase: expected P=%.10f, got %.10f", wantOverload, p) + } + + // Phase 3: expire the overload window — recovery. + now = now.Add(window) + if p := at.rejectionProbability(0, now); p != 0 { + t.Errorf("recovery phase: expected P=0, got %f", p) + } +} + +// TestRejectionProbabilityFormula verifies the exact rejection probability +// formula at known request/accept counts, including window expiry. 
+func TestRejectionProbabilityFormula(t *testing.T) { + const ( + k = 2.0 + minRate = 1.0 + window = time.Minute + // minPerWindow = 1.0 * 60 = 60 + ) + + now := time.Date(2020, 1, 1, 0, 0, 0, 0, time.UTC) + newThrottle := func() *AdaptiveThrottle { + return NewAdaptiveThrottle(1, + WithAdaptiveThrottleRatio(k), + WithAdaptiveThrottleMinimumRate(minRate), + WithAdaptiveThrottleWindow(window), + WithNow(func() time.Time { return now }), + ) + } + + t.Run("no requests", func(t *testing.T) { + // P = (0 - 0) / (0 + 60) = 0 + at := newThrottle() + if p := at.rejectionProbability(0, now); p != 0 { + t.Errorf("expected 0, got %f", p) + } + }) + + t.Run("requests equal k*accepts", func(t *testing.T) { + // 10 requests, 5 accepts: P = (10 - 2*5) / (10+60) = 0/70 = 0 + at := newThrottle() + for range 5 { + at.accept(0, now) + } + for range 5 { + at.reject(0, now) + } + if p := at.rejectionProbability(0, now); p != 0 { + t.Errorf("expected 0, got %f", p) + } + }) + + t.Run("all rejections", func(t *testing.T) { + // 100 requests, 0 accepts: P = 100 / (100+60) = 100/160 = 0.625 + at := newThrottle() + for range 100 { + at.reject(0, now) + } + want := 100.0 / 160.0 + if p := at.rejectionProbability(0, now); math.Abs(p-want) > 1e-10 { + t.Errorf("expected %.10f, got %.10f", want, p) + } + }) + + t.Run("window expiry resets probability", func(t *testing.T) { + at := newThrottle() + for range 100 { + at.reject(0, now) + } + if p := at.rejectionProbability(0, now); p <= 0 { + t.Fatal("expected non-zero probability before expiry") + } + // Advance past the full window — all 10 buckets expire. + now = now.Add(window) + if p := at.rejectionProbability(0, now); p != 0 { + t.Errorf("expected 0 after window expiry, got %f", p) + } + }) +} + +// TestRejectionProbabilityPriorityCrossContamination verifies that failures at +// higher priority (lower number) are added to lower priority's request count, +// raising its rejection probability. 
+func TestRejectionProbabilityPriorityCrossContamination(t *testing.T) { + const ( + k = 2.0 + minRate = 1.0 + window = time.Minute + // minPerWindow = 60 + ) + + now := time.Date(2020, 1, 1, 0, 0, 0, 0, time.UTC) + at := NewAdaptiveThrottle(2, + WithAdaptiveThrottleRatio(k), + WithAdaptiveThrottleMinimumRate(minRate), + WithAdaptiveThrottleWindow(window), + WithNow(func() time.Time { return now }), + ) + + // Priority 0 (high): 100 requests, 10 accepts → 90 rejections. + for range 10 { + at.accept(0, now) + } + for range 90 { + at.reject(0, now) + } + + // Priority 1 (low): 10 requests, 8 accepts → 2 rejections. + for range 8 { + at.accept(1, now) + } + for range 2 { + at.reject(1, now) + } + + // P(0) = (100 - 2*10) / (100+60) = 80/160 = 0.5 + wantP0 := 80.0 / 160.0 + if p := at.rejectionProbability(0, now); math.Abs(p-wantP0) > 1e-10 { + t.Errorf("priority 0: expected %.10f, got %.10f", wantP0, p) + } + + // P(1): requests = r1 + (r0-a0) = 10 + 90 = 100; accepts = a1 = 8 + // P(1) = (100 - 2*8) / (100+60) = 84/160 = 0.525 + wantP1 := 84.0 / 160.0 + if p := at.rejectionProbability(1, now); math.Abs(p-wantP1) > 1e-10 { + t.Errorf("priority 1: expected %.10f, got %.10f", wantP1, p) + } +} + +// TestAdaptiveThrottleNonRejectionErrors verifies that only errors signalling +// genuine server overload (Unavailable, ResourceExhausted, or wrapped with +// RejectedError) are counted as failures. Errors that indicate a correct server +// response to a bad request — NotFound, Unauthenticated, FailedPrecondition, +// etc. — must count as accepts and must not drive shedding. 
+func TestAdaptiveThrottleNonRejectionErrors(t *testing.T) { + ctx := context.Background() + + t.Run("non-rejection errors do not trigger shedding", func(t *testing.T) { + throttle := NewAdaptiveThrottle(1) + for i := 0; i < 100; i++ { + _, _ = Throttle(ctx, throttle, Priority(0), func(ctx context.Context) (struct{}, error) { + return struct{}{}, faults.WithNotFound(errors.New("item not found")) + }) + } + + // P(reject) = (req - 2*acc) / (req + min). With req==acc==100: P ≤ 0. + // All subsequent requests must go through. + calls := 0 + for i := 0; i < 10; i++ { + _, _ = Throttle(ctx, throttle, Priority(0), func(ctx context.Context) (struct{}, error) { + calls++ + return struct{}{}, nil + }) + } + if calls != 10 { + t.Errorf("expected all 10 requests forwarded after non-rejection errors, got %d", calls) + } + }) + + t.Run("rejection errors do trigger shedding", func(t *testing.T) { + // alwaysShed ensures rand returns 0, so any P > 0 causes a shed. + throttle := NewAdaptiveThrottle(1, WithRandomSource(alwaysShed{})) + for i := 0; i < 100; i++ { + throttle.Throttle(ctx, 0, func(ctx context.Context) error { + return RejectedError(faults.Unavailable(0)) + }) + } + + calls := 0 + for i := 0; i < 10; i++ { + throttle.Throttle(ctx, 0, func(ctx context.Context) error { + calls++ + return nil + }) + } + if calls != 0 { + t.Errorf("expected all requests shed after repeated rejection errors, got %d forwarded", calls) + } + }) +} + +// TestAdaptiveThrottleMinimumRate verifies that the minimum rate floor is +// correctly encoded in the rejection probability formula. Because minPerWindow +// sits in the denominator, P = N/(N+minPerWindow) approaches but never reaches +// 1, guaranteeing that some probe traffic always reaches the backend. 
+func TestAdaptiveThrottleMinimumRate(t *testing.T) { + const ( + k = 2.0 + minRate = 1.0 + window = time.Second + // minPerWindow = minRate * window.Seconds() = 1 + ) + + now := time.Date(2020, 1, 1, 0, 0, 0, 0, time.UTC) + + for _, n := range []int{1, 10, 100, 10_000} { + at := NewAdaptiveThrottle(1, + WithAdaptiveThrottleRatio(k), + WithAdaptiveThrottleMinimumRate(minRate), + WithAdaptiveThrottleWindow(window), + WithNow(func() time.Time { return now }), + ) + for range n { + at.reject(0, now) + } + // P = (n - k*0) / (n + minPerWindow) = n / (n+1) + want := float64(n) / float64(n+1) + got := at.rejectionProbability(0, now) + if math.Abs(got-want) > 1e-10 { + t.Errorf("n=%d: expected P=%.10f, got %.10f", n, want, got) + } + if got >= 1.0 { + t.Errorf("n=%d: P must be < 1 (got %f) — minimum rate floor is broken", n, got) + } + } +} + +// TestFallback ensures the fallback function is called when a request is +// rejected by the throttle. +func TestFallback(t *testing.T) { + ctx := context.Background() + throttle := NewAdaptiveThrottle(StandardPriorities, + WithRandomSource(alwaysShed{}), + ) + for i := 0; i < 100; i++ { + throttle.Throttle(ctx, 0, func(ctx context.Context) error { + return faults.Unavailable(0) + }) + } + + throttledFnCalls := 0 + fallbackFnCalls := 0 + throttle.Throttle(ctx, 0, func(ctx context.Context) error { + throttledFnCalls++ + return nil + }, func(ctx context.Context, err error, local bool) error { + fallbackFnCalls++ + return err + }) + + if throttledFnCalls != 0 { + t.Errorf("expected throttled function to not be called, got %d", throttledFnCalls) + } + if fallbackFnCalls != 1 { + t.Errorf("expected fallback function to be called once, got %d", fallbackFnCalls) + } +} + +// TestWithNow verifies that WithNow controls the clock used by the throttle, +// allowing windowed counters to be fast-forwarded without real time passing. 
+// It saturates the throttle with failures, then advances the fake clock past +// the window and confirms the throttle stops shedding. +func TestWithNow(t *testing.T) { + const window = time.Second + + now := time.Now() + throttle := NewAdaptiveThrottle(1, + WithAdaptiveThrottleWindow(window), + WithNow(func() time.Time { return now }), + WithRandomSource(alwaysShed{}), + ) + + ctx := context.Background() + + // Saturate with rejections so shedding kicks in. + for i := 0; i < 100; i++ { + throttle.Throttle(ctx, 0, func(ctx context.Context) error { + return RejectedError(faults.Unavailable(0)) + }) + } + + // With alwaysShed and P > 0, every subsequent call must be shed. + calls := 0 + for i := 0; i < 10; i++ { + throttle.Throttle(ctx, 0, func(ctx context.Context) error { + calls++ + return nil + }) + } + if calls != 0 { + t.Fatalf("expected all calls shed before clock advance, got %d forwarded", calls) + } + + // Advance fake clock past the window — all failure buckets should expire. + now = now.Add(window) + + // With P = 0, alwaysShed returns 0 but 0 < 0 is false — all calls go through. + calls = 0 + for i := 0; i < 10; i++ { + throttle.Throttle(ctx, 0, func(ctx context.Context) error { + calls++ + return nil + }) + } + if calls != 10 { + t.Errorf("expected all 10 calls forwarded after window expired, got %d", calls) + } +} + +// TestWithRejectedErrorFunc verifies that WithRejectedErrorFunc replaces the +// default error classifier for a single instance without affecting others. +func TestWithRejectedErrorFunc(t *testing.T) { + ctx := context.Background() + sentinel := errors.New("custom rejection") + + // This throttle treats sentinel as a rejection; the default would not. + custom := NewAdaptiveThrottle(1, + WithRejectedErrorFunc(func(err error) bool { + return errors.Is(err, sentinel) + }), + WithRandomSource(alwaysShed{}), + ) + // This throttle uses the default classifier; sentinel is not a rejection. 
+ // neverShed ensures rand never triggers a shed regardless of P. + defaultThrottle := NewAdaptiveThrottle(1, WithRandomSource(neverShed{})) + + saturate := func(at *AdaptiveThrottle, err error) { + for i := 0; i < 100; i++ { + at.Throttle(ctx, 0, func(ctx context.Context) error { return err }) + } + } + + countForwarded := func(at *AdaptiveThrottle) int { + calls := 0 + for i := 0; i < 10; i++ { + at.Throttle(ctx, 0, func(ctx context.Context) error { calls++; return nil }) + } + return calls + } + + saturate(custom, sentinel) + if got := countForwarded(custom); got != 0 { + t.Errorf("custom throttle: expected all calls shed, got %d forwarded", got) + } + + saturate(defaultThrottle, sentinel) + if got := countForwarded(defaultThrottle); got != 10 { + t.Errorf("default throttle: sentinel should not count as rejection, got %d forwarded (want 10)", got) + } +} + +// TestWithClientSideRejectionError verifies that WithClientSideRejectionError +// controls the error returned for locally rejected requests. +func TestWithClientSideRejectionError(t *testing.T) { + ctx := context.Background() + customErr := errors.New("custom client rejection") + + throttle := NewAdaptiveThrottle(1, + WithClientSideRejectionError(customErr), + WithRandomSource(alwaysShed{}), + ) + + // Saturate so P > 0. + for i := 0; i < 100; i++ { + throttle.Throttle(ctx, 0, func(ctx context.Context) error { + return RejectedError(faults.Unavailable(0)) + }) + } + + // With alwaysShed every call must be shed with the custom error. + for i := 0; i < 10; i++ { + err := throttle.Throttle(ctx, 0, func(ctx context.Context) error { return nil }) + if !errors.Is(err, customErr) { + t.Errorf("call %d: expected custom rejection error, got %v", i, err) + } + } +} + +// TestInvalidFallback ensures errors from the throttled function are passed +// through correctly and do not trigger client-side rejection logic unexpectedly. 
+func TestInvalidFallback(t *testing.T) { + stdError := errors.New("standard error") + + table := []struct { + name string + err error + expect error + }{ + {"no error", nil, nil}, + {"client rejection error", ErrClientSideRejection, ErrClientSideRejection}, + {"wrapped rejection error", RejectedError(faults.ResourceExhausted()), faults.ResourceExhausted()}, + {"standard error", stdError, stdError}, + } + + ctx := context.Background() + for _, tt := range table { + t.Run(tt.name, func(t *testing.T) { + throttle := NewAdaptiveThrottle(StandardPriorities) + err := throttle.Throttle(ctx, 0, func(ctx context.Context) error { + return tt.err + }, func(ctx context.Context, err error, local bool) error { + return err + }) + if !errors.Is(err, tt.expect) { + t.Errorf("expected %v, got %v", tt.expect, err) + } + }) + } +} + +// alwaysShed is a rand.Source whose Float64() always returns 0, causing the +// throttle to shed every request whose rejection probability is greater than 0. +type alwaysShed struct{} + +func (alwaysShed) Uint64() uint64 { return 0 } + +// neverShed is a rand.Source whose Float64() always returns a value just below +// 1, causing the throttle to never shed (since rejection probability ≤ 1). +type neverShed struct{} + +func (neverShed) Uint64() uint64 { return math.MaxUint64 } diff --git a/v2/context.go b/v2/context.go new file mode 100644 index 0000000..367e007 --- /dev/null +++ b/v2/context.go @@ -0,0 +1,31 @@ +package bulwark + +import "context" + +type priorityKey struct{} + +var activePriorityKey = priorityKey{} + +// PriorityFromContext returns the `Priority` attached to the context. +// If no priority is attached, it returns the default priority. +// +// It is good practice to attach a global priority to requests, so all throttles +// can adapt their behaviour accordingly. 
+func PriorityFromContext(ctx context.Context, defaultPriority Priority) Priority { + if p, ok := ctx.Value(activePriorityKey).(Priority); ok { + return p + } + + return defaultPriority +} + +// WithPriority attaches the given `Priority` to the context. +// It is good practice to call the adaptive throttle this way: +// +// `bulwark.WithAdaptiveThrottle(at, bulwark.PriorityFromContext(ctx, priority), f)` +// +// Then requests should have a priority attached to them, so all throttles can +// adapt their behaviour accordingly. +func WithPriority(ctx context.Context, priority Priority) context.Context { + return context.WithValue(ctx, activePriorityKey, priority) +} diff --git a/v2/context_test.go b/v2/context_test.go new file mode 100644 index 0000000..1878be4 --- /dev/null +++ b/v2/context_test.go @@ -0,0 +1,24 @@ +package bulwark_test + +import ( + "context" + "testing" + + bulwark "github.com/deixis/bulwark/v2" +) + +func TestPriorityContext(t *testing.T) { + ctx := context.Background() + defaultPriority := bulwark.Medium + got := bulwark.PriorityFromContext(ctx, defaultPriority) + if got != defaultPriority { + t.Errorf("PriorityFromContext(ctx) = %v; want %v", got, defaultPriority) + } + + priority := bulwark.High + ctx = bulwark.WithPriority(ctx, priority) + got = bulwark.PriorityFromContext(ctx, defaultPriority) + if got != priority { + t.Errorf("PriorityFromContext(ctx) = %v; want %v", got, priority) + } +} diff --git a/v2/counter.go b/v2/counter.go new file mode 100644 index 0000000..c3e2963 --- /dev/null +++ b/v2/counter.go @@ -0,0 +1,62 @@ +package bulwark + +// Source: https://github.com/bradenaw/backpressure + +import ( + "time" +) + +// windowedCounter counts events in an approximate time window. It does this by splitting time into +// buckets of some width and removing buckets that are too old. +type windowedCounter struct { + // The width of a single bucket. + width time.Duration + + // The last time the bucket was read or written. 
+ last time.Time + // The sum of all buckets. + count int + // The count of events that happened in each bucket. This is a circular buffer. + buckets []int + // The index of the 'head' of the circular buffer, that is, the bucket that corresponds to + // `last`. + head int +} + +func newWindowedCounter(now time.Time, width time.Duration, n int) windowedCounter { + return windowedCounter{ + width: width, + last: now, + buckets: make([]int, n), + } +} + +func (c *windowedCounter) add(now time.Time, x int) { + c.get(now) + c.buckets[c.head] += x + c.count += x +} + +func (c *windowedCounter) get(now time.Time) int { + elapsed := now.Sub(c.last) + + // How many buckets have we passed since `last`? + // Since it's a circular buffer, passing more than all of the buckets is the same as passing all + // of them. + bucketsPassed := min(max(int(elapsed/c.width), 0), len(c.buckets)) + + // For all of the buckets that already happened, zero them out, advance head, and remove their + // amounts from c.count. 
+ for range bucketsPassed { + nextIdx := (c.head + 1) % len(c.buckets) + c.count -= c.buckets[nextIdx] + c.buckets[nextIdx] = 0 + c.head = nextIdx + } + + if bucketsPassed > 0 { + c.last = now + } + + return c.count +} diff --git a/v2/go.mod b/v2/go.mod new file mode 100644 index 0000000..c37f90a --- /dev/null +++ b/v2/go.mod @@ -0,0 +1,8 @@ +module github.com/deixis/bulwark/v2 + +go 1.26 + +require ( + github.com/deixis/faults v1.0.1 + golang.org/x/time v0.15.0 +) diff --git a/v2/go.sum b/v2/go.sum new file mode 100644 index 0000000..8102a55 --- /dev/null +++ b/v2/go.sum @@ -0,0 +1,4 @@ +github.com/deixis/faults v1.0.1 h1:4KbZaJvqOfc2cWh3CjWU2ynGWRY/OpDr2DOgp2j6zeQ= +github.com/deixis/faults v1.0.1/go.mod h1:TmAFyR/M6swaIznYCjZBqZMVJg5MYOJFOsTYOawLZK4= +golang.org/x/time v0.15.0 h1:bbrp8t3bGUeFOx08pvsMYRTCVSMk89u4tKbNOZbp88U= +golang.org/x/time v0.15.0/go.mod h1:Y4YMaQmXwGQZoFaVFk4YpCt4FLQMYKZe9oeV/f4MSno= diff --git a/v2/priority.go b/v2/priority.go new file mode 100644 index 0000000..bfc7f5b --- /dev/null +++ b/v2/priority.go @@ -0,0 +1,34 @@ +package bulwark + +// StandardPriorities is the number of priority levels that are available. +// This value should be used when creating a new AdaptiveThrottle when the +// default Priority constants are used. +// +// throttler := bulwark.NewAdaptiveThrottle(bulwark.StandardPriorities) +// _, err := bulwark.Throttle(ctx, throttler, bulwark.High, throttledFn) +// if err != nil { +// // handle the error +// } +const StandardPriorities = 4 + +// Priority determines the importance of a request in ascending order. +// e.g. priority 0 is more important than priority 1. +// +// When a system reaches its capacity, it will sort requests by their priority +// and process them. Lower-priority requests can either be delayed or dropped. +type Priority int8 + +// These are pre-defined priority levels that can be used, but any int value +// can be used as a priority. 
+const ( + // Use High for requests that are critical to the overall experience. + High Priority = 0 + // Use Important for requests that are important, but not critical. + Important Priority = 1 + // Use Medium for noncritical requests where an elevated latency or + // failure rate would not significantly impact the experience. + Medium Priority = 2 + // Use Low for trivial requests and good for any system that can retry + // later when the system has spare capacity. + Low Priority = 3 +) diff --git a/v2/validator.go b/v2/validator.go new file mode 100644 index 0000000..34297ee --- /dev/null +++ b/v2/validator.go @@ -0,0 +1,46 @@ +package bulwark + +import ( + "fmt" + "log/slog" + + "github.com/deixis/faults" +) + +// AssertValidPriority panics when a priority is out of range. +// A priority is out of range when it is less than 0 or greater than or equal +// to priorities. +func AssertValidPriority(p Priority, priorities int) (Priority, error) { + if p < 0 || int(p) >= priorities { + panic(fmt.Sprintf("bulwark: priority must be in the range [0, %d), but got %d", priorities, p)) + } + + return p, nil +} + +// ClampInvalidPriority clamps any out-of-range priority to the lowest valid +// priority (priorities-1). This applies to both negative values and values +// that exceed the configured number of priorities, preventing invalid or +// malicious input from being promoted to a higher-importance tier. +func ClampInvalidPriority(p Priority, priorities int) (Priority, error) { + if p >= 0 && int(p) < priorities { + return p, nil + } + slog.Warn("bulwark: priority is out of range", "max", priorities-1, "priority", p) + + return Priority(priorities - 1), nil +} + +// RejectInvalidPriority returns an error when a priority is out of range. +// A priority is out of range when it is less than 0 or greater than or equal +// to priorities. 
+func RejectInvalidPriority(p Priority, priorities int) (Priority, error) { + if p < 0 || int(p) >= priorities { + return p, faults.Bad(&faults.FieldViolation{ + Field: "priority", + Description: fmt.Sprintf("priority must be in the range [0, %d), but got %d", priorities, p), + }) + } + + return p, nil +}