From cc7ce14042f6f0b021d442bbc38953d9c845d26d Mon Sep 17 00:00:00 2001 From: basgys Date: Sun, 15 Mar 2026 16:36:02 +0100 Subject: [PATCH] feat: per-instance random source for deterministic behaviour MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Add WithRandomSource(src rand.Source) to allow per-instance control of the random sampling used to make shed/pass decisions. The source must be safe for concurrent use if the throttle is used concurrently. Follows the same pattern as WithNow and WithRejectedErrorFunc. Switch the package from math/rand to math/rand/v2. Update tests to use two fixed sources — alwaysShed (Uint64 returns 0) and neverShed (Uint64 returns MaxUint64) — replacing probabilistic assertions with exact ones throughout. --- adaptive.go | 38 ++++- adaptive_test.go | 419 +++++++++++++++++++++++++++++++++++------------ 2 files changed, 348 insertions(+), 109 deletions(-) diff --git a/adaptive.go b/adaptive.go index 36bd421..7124e07 100644 --- a/adaptive.go +++ b/adaptive.go @@ -3,7 +3,7 @@ package bulwark import ( "context" "errors" - "math/rand" + "math/rand/v2" "sync" "time" @@ -50,6 +50,7 @@ type AdaptiveThrottle struct { isRejectedError func(error) bool clientSideRejectionError error now func() time.Time + rng *rand.Rand } // NewAdaptiveThrottle returns an AdaptiveThrottle. 
@@ -88,6 +89,11 @@ func NewAdaptiveThrottle(priorities int, options ...AdaptiveThrottleOption) *Ada validate = ClampInvalidPriority } + var rng *rand.Rand + if opts.randSource != nil { + rng = rand.New(opts.randSource) + } + return &AdaptiveThrottle{ k: opts.k, priorities: priorities, @@ -98,6 +104,7 @@ func NewAdaptiveThrottle(priorities int, options ...AdaptiveThrottleOption) *Ada isRejectedError: opts.isRejectedError, clientSideRejectionError: opts.clientSideRejectionError, now: opts.now, + rng: rng, } } @@ -125,7 +132,7 @@ func (t *AdaptiveThrottle) Throttle( now := t.nowTime() rejectionProbability := t.rejectionProbability(priority, now) - if rand.Float64() < rejectionProbability { + if t.randFloat64() < rejectionProbability { // As Bulwark starts rejecting requests, requests will continue to exceed // accepts. While it may seem counterintuitive, given that locally rejected // requests aren't actually propagated, this is the preferred behavior. As the @@ -231,6 +238,13 @@ func (t *AdaptiveThrottle) rejectionError() error { return ClientSideRejectionError } +func (t *AdaptiveThrottle) randFloat64() float64 { + if t.rng != nil { + return t.rng.Float64() + } + return rand.Float64() +} + // Additional options for the AdaptiveThrottle type. These options do not frequently need to be // tuned as the defaults work in a majority of cases. type AdaptiveThrottleOption struct { @@ -245,6 +259,7 @@ type adaptiveThrottleOptions struct { isRejectedError func(error) bool clientSideRejectionError error now func() time.Time + randSource rand.Source } // WithAdaptiveThrottleRatio sets the ratio of the measured success rate and the rate that the throttle @@ -330,6 +345,21 @@ func WithNow(fn func() time.Time) AdaptiveThrottleOption { }} } +// WithRandomSource sets the per-instance random source used to sample the +// rejection probability. 
This is primarily useful in tests to produce +// deterministic behaviour: a source that always returns 0 will shed every +// request whose rejection probability is greater than zero, and a source that +// always returns math.MaxUint64 will never shed. +// +// The provided source must be safe for concurrent use if the AdaptiveThrottle +// is used concurrently. rand.NewPCG and rand.NewChaCha8 are not concurrent-safe +// by default. +func WithRandomSource(src rand.Source) AdaptiveThrottleOption { + return AdaptiveThrottleOption{func(opts *adaptiveThrottleOptions) { + opts.randSource = src + }} +} + func Throttle[T any]( ctx context.Context, at *AdaptiveThrottle, @@ -344,7 +374,7 @@ func Throttle[T any]( now := at.nowTime() rejectionProbability := at.rejectionProbability(priority, now) - if rand.Float64() < rejectionProbability { + if at.randFloat64() < rejectionProbability { // As Bulwark starts rejecting requests, requests will continue to exceed // accepts. While it may seem counterintuitive, given that locally rejected // requests aren't actually propagated, this is the preferred behavior. As the @@ -407,7 +437,7 @@ func WithAdaptiveThrottle[T any]( now := at.nowTime() rejectionProbability := at.rejectionProbability(priority, now) - if rand.Float64() < rejectionProbability { + if at.randFloat64() < rejectionProbability { // As Bulwark starts rejecting requests, requests will continue to exceed // accepts. While it may seem counterintuitive, given that locally rejected // requests aren't actually propagated, this is the preferred behavior. As the diff --git a/adaptive_test.go b/adaptive_test.go index d41e35b..1852650 100644 --- a/adaptive_test.go +++ b/adaptive_test.go @@ -3,6 +3,7 @@ package bulwark import ( "context" "errors" + "math" "sync" "sync/atomic" "testing" @@ -114,74 +115,170 @@ func TestAdaptiveThrottlePriorityShedding(t *testing.T) { // overload, Bulwark stops shedding and resumes forwarding requests. 
This // validates the other half of the adaptive throttle contract: not only must it // shed under pressure, it must also relax when the pressure is gone. +// +// The clock is controlled via WithNow so the window can be expired instantly +// without real sleeps. func TestAdaptiveThrottleRecovery(t *testing.T) { const ( - window = 1 * time.Second - demand = 20 - supply = demand * 2 // healthy: capacity well above demand - overload = demand / 10 // sick: capacity well below demand + k = 2.0 + minRate = 1.0 + window = time.Second + // minPerWindow = minRate * window.Seconds() = 1 ) - var serverHealthy atomic.Bool - serverHealthy.Store(true) + now := time.Date(2020, 1, 1, 0, 0, 0, 0, time.UTC) + at := NewAdaptiveThrottle(1, + WithAdaptiveThrottleRatio(k), + WithAdaptiveThrottleMinimumRate(minRate), + WithAdaptiveThrottleWindow(window), + WithNow(func() time.Time { return now }), + ) - healthyLimiter := rate.NewLimiter(rate.Limit(supply), 1) - sickLimiter := rate.NewLimiter(rate.Limit(overload), 1) + // Phase 1: healthy — 100 requests all accepted. + // P = (100 - 2*100) / (100+1) = -100/101 → clamped to 0. + for range 100 { + at.accept(0, now) + } + if p := at.rejectionProbability(0, now); p != 0 { + t.Errorf("healthy phase: expected P=0, got %f", p) + } - throttle := NewAdaptiveThrottle(1, WithAdaptiveThrottleWindow(window)) - clientLimiter := rate.NewLimiter(rate.Limit(demand), 1) + // Expire the healthy window, then enter overload. 
+ now = now.Add(window) - measureShedRate := func(d time.Duration) float64 { - var attempts, sent int64 - deadline := time.Now().Add(d) - for time.Now().Before(deadline) { - if err := clientLimiter.Wait(context.Background()); err != nil { - return 0 - } - attempts++ - _, _ = WithAdaptiveThrottle(throttle, Priority(0), func() (struct{}, error) { - sent++ - lim := healthyLimiter - if !serverHealthy.Load() { - lim = sickLimiter - } - if !lim.Allow() { - return struct{}{}, RejectedError(faults.Unavailable(0)) - } - return struct{}{}, nil - }) + // Phase 2: overloaded — 100 requests all rejected. + // P = (100 - 0) / (100+1) = 100/101 ≈ 0.99. + for range 100 { + at.reject(0, now) + } + wantOverload := 100.0 / 101.0 + if p := at.rejectionProbability(0, now); math.Abs(p-wantOverload) > 1e-10 { + t.Errorf("overloaded phase: expected P=%.10f, got %.10f", wantOverload, p) + } + + // Phase 3: expire the overload window — recovery. + now = now.Add(window) + if p := at.rejectionProbability(0, now); p != 0 { + t.Errorf("recovery phase: expected P=0, got %f", p) + } +} + +// TestRejectionProbabilityFormula verifies the exact rejection probability +// formula at known request/accept counts, including window expiry. 
+func TestRejectionProbabilityFormula(t *testing.T) { + const ( + k = 2.0 + minRate = 1.0 + window = time.Minute + // minPerWindow = 1.0 * 60 = 60 + ) + + now := time.Date(2020, 1, 1, 0, 0, 0, 0, time.UTC) + newThrottle := func() *AdaptiveThrottle { + return NewAdaptiveThrottle(1, + WithAdaptiveThrottleRatio(k), + WithAdaptiveThrottleMinimumRate(minRate), + WithAdaptiveThrottleWindow(window), + WithNow(func() time.Time { return now }), + ) + } + + t.Run("no requests", func(t *testing.T) { + // P = (0 - 0) / (0 + 60) = 0 + at := newThrottle() + if p := at.rejectionProbability(0, now); p != 0 { + t.Errorf("expected 0, got %f", p) + } + }) + + t.Run("requests equal k*accepts", func(t *testing.T) { + // 10 requests, 5 accepts: P = (10 - 2*5) / (10+60) = 0/70 = 0 + at := newThrottle() + for range 5 { + at.accept(0, now) + } + for range 5 { + at.reject(0, now) } - if attempts == 0 { - return 0 + if p := at.rejectionProbability(0, now); p != 0 { + t.Errorf("expected 0, got %f", p) } - return float64(attempts-sent) / float64(attempts) + }) + + t.Run("all rejections", func(t *testing.T) { + // 100 requests, 0 accepts: P = 100 / (100+60) = 100/160 = 0.625 + at := newThrottle() + for range 100 { + at.reject(0, now) + } + want := 100.0 / 160.0 + if p := at.rejectionProbability(0, now); math.Abs(p-want) > 1e-10 { + t.Errorf("expected %.10f, got %.10f", want, p) + } + }) + + t.Run("window expiry resets probability", func(t *testing.T) { + at := newThrottle() + for range 100 { + at.reject(0, now) + } + if p := at.rejectionProbability(0, now); p <= 0 { + t.Fatal("expected non-zero probability before expiry") + } + // Advance past the full window — all 10 buckets expire. 
+ now = now.Add(window) + if p := at.rejectionProbability(0, now); p != 0 { + t.Errorf("expected 0 after window expiry, got %f", p) + } + }) +} + +// TestRejectionProbabilityPriorityCrossContamination verifies that failures at +// higher priority (lower number) are added to lower priority's request count, +// raising its rejection probability. +func TestRejectionProbabilityPriorityCrossContamination(t *testing.T) { + const ( + k = 2.0 + minRate = 1.0 + window = time.Minute + // minPerWindow = 60 + ) + + now := time.Date(2020, 1, 1, 0, 0, 0, 0, time.UTC) + at := NewAdaptiveThrottle(2, + WithAdaptiveThrottleRatio(k), + WithAdaptiveThrottleMinimumRate(minRate), + WithAdaptiveThrottleWindow(window), + WithNow(func() time.Time { return now }), + ) + + // Priority 0 (high): 100 requests, 10 accepts → 90 rejections. + for range 10 { + at.accept(0, now) + } + for range 90 { + at.reject(0, now) } - // Phase 1: healthy baseline — no shedding expected. - shedRate := measureShedRate(3 * time.Second) - t.Logf("healthy: shed=%.1f%%", shedRate*100) - if shedRate > 0.05 { - t.Errorf("healthy phase shed rate %.1f%% should be near 0%%", shedRate*100) + // Priority 1 (low): 10 requests, 8 accepts → 2 rejections. + for range 8 { + at.accept(1, now) + } + for range 2 { + at.reject(1, now) } - // Phase 2: overloaded — shedding must kick in. - serverHealthy.Store(false) - shedRate = measureShedRate(3 * time.Second) - t.Logf("overloaded: shed=%.1f%%", shedRate*100) - if shedRate < 0.3 { - t.Errorf("overloaded phase shed rate %.1f%% should be significant", shedRate*100) + // P(0) = (100 - 2*10) / (100+60) = 80/160 = 0.5 + wantP0 := 80.0 / 160.0 + if p := at.rejectionProbability(0, now); math.Abs(p-wantP0) > 1e-10 { + t.Errorf("priority 0: expected %.10f, got %.10f", wantP0, p) } - // Phase 3: recovery — after the failure window expires, shedding must cease. 
- // Sleep for 2× the window so the windowed counters are fully cleared before - // we start measuring; otherwise the first window of recovery still contains - // the failures from phase 2. - serverHealthy.Store(true) - time.Sleep(2 * window) - shedRate = measureShedRate(3 * time.Second) - t.Logf("recovery: shed=%.1f%%", shedRate*100) - if shedRate > 0.05 { - t.Errorf("recovery phase shed rate %.1f%% should be near 0%%", shedRate*100) + // P(1): requests = r1 + (r0-a0) = 10 + 90 = 100; accepts = a1 = 8 + // P(1) = (100 - 2*8) / (100+60) = 84/160 = 0.525 + wantP1 := 84.0 / 160.0 + if p := at.rejectionProbability(1, now); math.Abs(p-wantP1) > 1e-10 { + t.Errorf("priority 1: expected %.10f, got %.10f", wantP1, p) } } @@ -216,15 +313,14 @@ func TestAdaptiveThrottleNonRejectionErrors(t *testing.T) { }) t.Run("rejection errors do trigger shedding", func(t *testing.T) { - throttle := NewAdaptiveThrottle(1) + // alwaysShed ensures rand returns 0, so any P > 0 causes a shed. + throttle := NewAdaptiveThrottle(1, WithRandomSource(alwaysShed{})) for i := 0; i < 100; i++ { throttle.Throttle(ctx, 0, func(ctx context.Context) error { return RejectedError(faults.Unavailable(0)) }) } - // P(reject) = req / (req + min) ≈ 100/160 ≈ 62%. With 10 attempts, - // P(all forwarded) = 0.38^10 < 0.01% — effectively impossible. calls := 0 for i := 0; i < 10; i++ { throttle.Throttle(ctx, 0, func(ctx context.Context) error { @@ -232,60 +328,45 @@ func TestAdaptiveThrottleNonRejectionErrors(t *testing.T) { return nil }) } - if calls == 10 { - t.Error("expected shedding after repeated rejection errors, but all 10 calls were forwarded") + if calls != 0 { + t.Errorf("expected all requests shed after repeated rejection errors, got %d forwarded", calls) } }) } -// TestAdaptiveThrottleMinimumRate verifies that even under total server failure -// Bulwark still forwards approximately MinRPS requests per second. 
This probe -// traffic is essential: without it the throttle could never detect that a -// failed server has recovered. +// TestAdaptiveThrottleMinimumRate verifies that the minimum rate floor is +// correctly encoded in the rejection probability formula. Because minPerWindow +// sits in the denominator, P = N/(N+minPerWindow) approaches but never reaches +// 1, guaranteeing that some probe traffic always reaches the backend. func TestAdaptiveThrottleMinimumRate(t *testing.T) { const ( - window = 1 * time.Second + k = 2.0 minRate = 1.0 - demand = 20 - testDur = 5 * time.Second + window = time.Second + // minPerWindow = minRate * window.Seconds() = 1 ) - throttle := NewAdaptiveThrottle(1, - WithAdaptiveThrottleWindow(window), - WithAdaptiveThrottleMinimumRate(minRate), - ) - clientLimiter := rate.NewLimiter(rate.Limit(demand), 1) - - var attempts, sent int64 - start := time.Now() - for time.Since(start) < testDur { - if err := clientLimiter.Wait(context.Background()); err != nil { - break + now := time.Date(2020, 1, 1, 0, 0, 0, 0, time.UTC) + + for _, n := range []int{1, 10, 100, 10_000} { + at := NewAdaptiveThrottle(1, + WithAdaptiveThrottleRatio(k), + WithAdaptiveThrottleMinimumRate(minRate), + WithAdaptiveThrottleWindow(window), + WithNow(func() time.Time { return now }), + ) + for range n { + at.reject(0, now) + } + // P = (n - k*0) / (n + minPerWindow) = n / (n+1) + want := float64(n) / float64(n+1) + got := at.rejectionProbability(0, now) + if math.Abs(got-want) > 1e-10 { + t.Errorf("n=%d: expected P=%.10f, got %.10f", n, want, got) + } + if got >= 1.0 { + t.Errorf("n=%d: P must be < 1 (got %f) — minimum rate floor is broken", n, got) } - attempts++ - _, _ = WithAdaptiveThrottle(throttle, Priority(0), func() (struct{}, error) { - sent++ - return struct{}{}, RejectedError(faults.Unavailable(0)) - }) - } - - elapsed := time.Since(start).Seconds() - forwardedPerSec := float64(sent) / elapsed - shedRate := float64(attempts-sent) / float64(attempts) - - 
t.Logf("attempts=%d sent=%d forwarded=%.2f/sec shed=%.1f%%", - attempts, sent, forwardedPerSec, shedRate*100) - - // The throttle must be actively shedding under total server failure. - if shedRate < 0.5 { - t.Errorf("shed rate %.1f%% is too low — throttle should shed most requests under total failure", shedRate*100) - } - // But the MinRPS floor must keep probe traffic flowing to enable recovery - // detection. With minRate=1 and demand=20, steady-state forwarding rate is - // demand*(minPerWindow/(demand+minPerWindow)) ≈ 20*(1/21) ≈ 0.95/sec. - // Use a conservative lower bound of 0.4/sec (2 requests over 5s). - if forwardedPerSec < 0.4 { - t.Errorf("forwarded %.2f/sec is below MinRPS floor — probe traffic must reach the server", forwardedPerSec) } } @@ -293,12 +374,8 @@ func TestAdaptiveThrottleMinimumRate(t *testing.T) { // rejected by the throttle. func TestFallback(t *testing.T) { ctx := context.Background() - // Short window keeps minPerWindow = 1 * 0.1 = 0.1, so after 100 rejections: - // P = 100 / (100 + 0.1) = 99.9%. Without this, the default 60s window gives - // P = 100/160 = 62.5% — not deterministic enough for a single-shot assertion. throttle := NewAdaptiveThrottle(StandardPriorities, - WithAdaptiveThrottleRatio(1), - WithAdaptiveThrottleWindow(100*time.Millisecond), + WithRandomSource(alwaysShed{}), ) for i := 0; i < 100; i++ { throttle.Throttle(ctx, 0, func(ctx context.Context) error { @@ -324,6 +401,126 @@ func TestFallback(t *testing.T) { } } +// TestWithNow verifies that WithNow controls the clock used by the throttle, +// allowing windowed counters to be fast-forwarded without real time passing. +// It saturates the throttle with failures, then advances the fake clock past +// the window and confirms the throttle stops shedding. 
+func TestWithNow(t *testing.T) { + const window = time.Second + + now := time.Now() + throttle := NewAdaptiveThrottle(1, + WithAdaptiveThrottleWindow(window), + WithNow(func() time.Time { return now }), + WithRandomSource(alwaysShed{}), + ) + + ctx := context.Background() + + // Saturate with rejections so shedding kicks in. + for i := 0; i < 100; i++ { + throttle.Throttle(ctx, 0, func(ctx context.Context) error { + return RejectedError(faults.Unavailable(0)) + }) + } + + // With alwaysShed and P > 0, every subsequent call must be shed. + calls := 0 + for i := 0; i < 10; i++ { + throttle.Throttle(ctx, 0, func(ctx context.Context) error { + calls++ + return nil + }) + } + if calls != 0 { + t.Fatalf("expected all calls shed before clock advance, got %d forwarded", calls) + } + + // Advance fake clock past the window — all failure buckets should expire. + now = now.Add(window) + + // With P = 0, alwaysShed returns 0 but 0 < 0 is false — all calls go through. + calls = 0 + for i := 0; i < 10; i++ { + throttle.Throttle(ctx, 0, func(ctx context.Context) error { + calls++ + return nil + }) + } + if calls != 10 { + t.Errorf("expected all 10 calls forwarded after window expired, got %d", calls) + } +} + +// TestWithRejectedErrorFunc verifies that WithRejectedErrorFunc replaces the +// global IsRejectedError for a single instance without affecting others. +func TestWithRejectedErrorFunc(t *testing.T) { + ctx := context.Background() + sentinel := errors.New("custom rejection") + + // This throttle treats sentinel as a rejection; the default would not. + custom := NewAdaptiveThrottle(1, + WithRejectedErrorFunc(func(err error) bool { + return errors.Is(err, sentinel) + }), + WithRandomSource(alwaysShed{}), + ) + // This throttle uses the default classifier; sentinel is not a rejection. + // neverShed ensures rand never triggers a shed regardless of P. 
+ defaultThrottle := NewAdaptiveThrottle(1, WithRandomSource(neverShed{})) + + saturate := func(at *AdaptiveThrottle, err error) { + for i := 0; i < 100; i++ { + at.Throttle(ctx, 0, func(ctx context.Context) error { return err }) + } + } + + countForwarded := func(at *AdaptiveThrottle) int { + calls := 0 + for i := 0; i < 10; i++ { + at.Throttle(ctx, 0, func(ctx context.Context) error { calls++; return nil }) + } + return calls + } + + saturate(custom, sentinel) + if got := countForwarded(custom); got != 0 { + t.Errorf("custom throttle: expected all calls shed, got %d forwarded", got) + } + + saturate(defaultThrottle, sentinel) + if got := countForwarded(defaultThrottle); got != 10 { + t.Errorf("default throttle: sentinel should not count as rejection, got %d forwarded (want 10)", got) + } +} + +// TestWithClientSideRejectionError verifies that WithClientSideRejectionError +// controls the error returned for locally rejected requests. +func TestWithClientSideRejectionError(t *testing.T) { + ctx := context.Background() + customErr := errors.New("custom client rejection") + + throttle := NewAdaptiveThrottle(1, + WithClientSideRejectionError(customErr), + WithRandomSource(alwaysShed{}), + ) + + // Saturate so P > 0. + for i := 0; i < 100; i++ { + throttle.Throttle(ctx, 0, func(ctx context.Context) error { + return RejectedError(faults.Unavailable(0)) + }) + } + + // With alwaysShed every call must be shed with the custom error. + for i := 0; i < 10; i++ { + err := throttle.Throttle(ctx, 0, func(ctx context.Context) error { return nil }) + if !errors.Is(err, customErr) { + t.Errorf("call %d: expected custom rejection error, got %v", i, err) + } + } +} + // TestInvalidFallback ensures errors from the throttled function are passed // through correctly and do not trigger client-side rejection logic unexpectedly. 
func TestInvalidFallback(t *testing.T) { @@ -355,3 +552,15 @@ }) } } + +// alwaysShed is a rand.Source whose Uint64 always returns 0, so the derived +// Float64 is 0 and the throttle sheds every request whose rejection probability is greater than 0. +type alwaysShed struct{} + +func (alwaysShed) Uint64() uint64 { return 0 } + +// neverShed is a rand.Source whose Uint64 always returns math.MaxUint64, so the derived Float64 is the largest value +// below 1, and the throttle never sheds (the rejection probability is always strictly less than 1). +type neverShed struct{} + +func (neverShed) Uint64() uint64 { return math.MaxUint64 }