From bff87a296b03d330c43c0d067fdf8f0453f97327 Mon Sep 17 00:00:00 2001 From: Christopher Petito Date: Thu, 19 Feb 2026 20:44:36 +0100 Subject: [PATCH] don't fall back on certain errors when using a models gateway Signed-off-by: Christopher Petito --- agent-schema.json | 11 ++ pkg/agent/agent.go | 7 ++ pkg/agent/opts.go | 8 ++ pkg/config/latest/types.go | 16 +++ pkg/config/latest/validate.go | 6 ++ pkg/config/v4/types.go | 3 + pkg/config/v4/validate.go | 6 ++ pkg/runtime/fallback.go | 39 +++++++ pkg/runtime/fallback_test.go | 196 ++++++++++++++++++++++++++++++++++ pkg/teamloader/teamloader.go | 3 + 10 files changed, 295 insertions(+) diff --git a/agent-schema.json b/agent-schema.json index 718912067..29554d684 100644 --- a/agent-schema.json +++ b/agent-schema.json @@ -353,6 +353,17 @@ "2m30s", "5m" ] + }, + "no_fallback_status_codes": { + "type": "array", + "description": "HTTP status codes for which the fallback chain should be skipped entirely when using a models gateway (--models-gateway). If the primary model returns one of these codes, the error is returned immediately without trying fallback models. 
Useful with gateways that perform their own routing.", + "items": { + "type": "integer" + }, + "examples": [ + [401, 403], + [429] + ] } }, "additionalProperties": false diff --git a/pkg/agent/agent.go b/pkg/agent/agent.go index 5bdc1f42e..2dae66a96 100644 --- a/pkg/agent/agent.go +++ b/pkg/agent/agent.go @@ -25,6 +25,7 @@ type Agent struct { fallbackModels []provider.Provider // Fallback models to try if primary fails fallbackRetries int // Number of retries per fallback model with exponential backoff fallbackCooldown time.Duration // Duration to stick with fallback after non-retryable error + noFallbackStatusCodes []int // Status codes that skip fallback when using a gateway modelOverrides atomic.Pointer[[]provider.Provider] // Optional model override(s) set at runtime (supports alloy) subAgents []*Agent handoffs []*Agent @@ -188,6 +189,12 @@ func (a *Agent) FallbackCooldown() time.Duration { return a.fallbackCooldown } +// NoFallbackStatusCodes returns the HTTP status codes for which fallback should +// be skipped when using a models gateway. +func (a *Agent) NoFallbackStatusCodes() []int { + return a.noFallbackStatusCodes +} + // Commands returns the named commands configured for this agent. func (a *Agent) Commands() types.Commands { return a.commands diff --git a/pkg/agent/opts.go b/pkg/agent/opts.go index 3cedf991a..4ec3395a9 100644 --- a/pkg/agent/opts.go +++ b/pkg/agent/opts.go @@ -166,3 +166,11 @@ func WithThinkingConfigured(configured bool) Opt { a.thinkingConfigured = configured } } + +// WithNoFallbackStatusCodes sets the HTTP status codes for which the fallback +// chain should be skipped entirely when using a models gateway. 
+func WithNoFallbackStatusCodes(codes []int) Opt { + return func(a *Agent) { + a.noFallbackStatusCodes = codes + } +} diff --git a/pkg/config/latest/types.go b/pkg/config/latest/types.go index 6d9f6d78e..df0649791 100644 --- a/pkg/config/latest/types.go +++ b/pkg/config/latest/types.go @@ -127,6 +127,13 @@ type FallbackConfig struct { // retrying the primary. Only applies after a non-retryable error (e.g., 429). // Default is 1 minute. Use Go duration format (e.g., "1m", "30s", "2m30s"). Cooldown Duration `json:"cooldown"` + // NoFallbackStatusCodes is a list of HTTP status codes for which the fallback + // chain should be skipped entirely when using a models gateway. If the primary + // model returns one of these status codes, the error is returned immediately + // without trying fallback models. This is useful with gateways that perform + // their own routing and return specific codes (e.g., 401, 403) that should + // not trigger client-side fallback. + NoFallbackStatusCodes []int `json:"no_fallback_status_codes,omitempty"` } // Duration is a wrapper around time.Duration that supports YAML/JSON unmarshaling @@ -257,6 +264,15 @@ func (a *AgentConfig) GetFallbackCooldown() time.Duration { return 0 } +// GetNoFallbackStatusCodes returns the status codes for which fallback should be +// skipped when using a models gateway. +func (a *AgentConfig) GetNoFallbackStatusCodes() []int { + if a.Fallback != nil { + return a.Fallback.NoFallbackStatusCodes + } + return nil +} + // ModelConfig represents the configuration for a model type ModelConfig struct { // Name is the manifest model name (map key), populated at runtime. 
diff --git a/pkg/config/latest/validate.go b/pkg/config/latest/validate.go index 438c1ef11..0a5e2b8e4 100644 --- a/pkg/config/latest/validate.go +++ b/pkg/config/latest/validate.go @@ -53,6 +53,12 @@ func (a *AgentConfig) validateFallback() error { return errors.New("fallback.cooldown must be non-negative") } + for _, code := range a.Fallback.NoFallbackStatusCodes { + if code < 400 || code > 599 { + return errors.New("fallback.no_fallback_status_codes must contain HTTP error codes (400-599)") + } + } + return nil } diff --git a/pkg/config/v4/types.go b/pkg/config/v4/types.go index d443528d6..67860513e 100644 --- a/pkg/config/v4/types.go +++ b/pkg/config/v4/types.go @@ -127,6 +127,9 @@ type FallbackConfig struct { // retrying the primary. Only applies after a non-retryable error (e.g., 429). // Default is 1 minute. Use Go duration format (e.g., "1m", "30s", "2m30s"). Cooldown Duration `json:"cooldown"` + // NoFallbackStatusCodes is a list of HTTP status codes for which the fallback + // chain should be skipped entirely when using a models gateway. 
+ NoFallbackStatusCodes []int `json:"no_fallback_status_codes,omitempty"` } // Duration is a wrapper around time.Duration that supports YAML/JSON unmarshaling diff --git a/pkg/config/v4/validate.go b/pkg/config/v4/validate.go index aede53511..3a2b12521 100644 --- a/pkg/config/v4/validate.go +++ b/pkg/config/v4/validate.go @@ -53,6 +53,12 @@ func (a *AgentConfig) validateFallback() error { return errors.New("fallback.cooldown must be non-negative") } + for _, code := range a.Fallback.NoFallbackStatusCodes { + if code < 400 || code > 599 { + return errors.New("fallback.no_fallback_status_codes must contain HTTP error codes (400-599)") + } + } + return nil } diff --git a/pkg/runtime/fallback.go b/pkg/runtime/fallback.go index 14a5549c6..f1dcde575 100644 --- a/pkg/runtime/fallback.go +++ b/pkg/runtime/fallback.go @@ -389,6 +389,18 @@ func getEffectiveRetries(a *agent.Agent) int { return retries } +// isNoFallbackStatusCode returns true if the given status code is in the +// agent's configured no-fallback set. Used with models gateways to short-circuit +// the fallback chain for specific error codes. +func isNoFallbackStatusCode(statusCode int, codes []int) bool { + for _, c := range codes { + if c == statusCode { + return true + } + } + return false +} + // tryModelWithFallback attempts to create a stream and get a response using the primary model, // falling back to configured fallback models if the primary fails. // @@ -441,6 +453,13 @@ func (r *LocalRuntime) tryModelWithFallback( "cooldown_until", cooldownState.until.Format(time.RFC3339)) } + // When using a models gateway, check if this error's status code should + // skip the entire fallback chain. The gateway handles its own routing, + // so certain status codes (e.g., auth errors) won't resolve by trying + // a different client-side fallback model. 
+ useGateway := r.modelSwitcherCfg != nil && r.modelSwitcherCfg.ModelsGateway != "" + noFallbackCodes := a.NoFallbackStatusCodes() + var lastErr error primaryFailedWithNonRetryable := false @@ -508,6 +527,16 @@ func (r *LocalRuntime) tryModelWithFallback( "model", modelEntry.provider.ID(), "error", err) + if useGateway && len(noFallbackCodes) > 0 { + if sc := extractHTTPStatusCode(err); sc != 0 && isNoFallbackStatusCode(sc, noFallbackCodes) { + slog.Warn("Gateway no-fallback status code, skipping entire fallback chain", + "agent", a.Name(), + "status_code", sc, + "error", err) + return streamResult{}, nil, err + } + } + // Track if primary failed with non-retryable error if !modelEntry.isFallback { primaryFailedWithNonRetryable = true @@ -544,6 +573,16 @@ func (r *LocalRuntime) tryModelWithFallback( "model", modelEntry.provider.ID(), "error", err) + if useGateway && len(noFallbackCodes) > 0 { + if sc := extractHTTPStatusCode(err); sc != 0 && isNoFallbackStatusCode(sc, noFallbackCodes) { + slog.Warn("Gateway no-fallback status code, skipping entire fallback chain", + "agent", a.Name(), + "status_code", sc, + "error", err) + return streamResult{}, nil, err + } + } + // Track if primary failed with non-retryable error if !modelEntry.isFallback { primaryFailedWithNonRetryable = true diff --git a/pkg/runtime/fallback_test.go b/pkg/runtime/fallback_test.go index 31d133fff..641f0037b 100644 --- a/pkg/runtime/fallback_test.go +++ b/pkg/runtime/fallback_test.go @@ -877,6 +877,202 @@ func TestFallbackModelsClonedWithThinkingEnabled(t *testing.T) { }) } +func TestIsNoFallbackStatusCode(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + statusCode int + codes []int + expected bool + }{ + { + name: "empty codes list", + statusCode: 401, + codes: nil, + expected: false, + }, + { + name: "status code in list", + statusCode: 401, + codes: []int{401, 403}, + expected: true, + }, + { + name: "status code not in list", + statusCode: 429, + codes: []int{401, 403}, 
+ expected: false, + }, + { + name: "single code match", + statusCode: 429, + codes: []int{429}, + expected: true, + }, + { + name: "zero status code", + statusCode: 0, + codes: []int{401, 403}, + expected: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + result := isNoFallbackStatusCode(tt.statusCode, tt.codes) + assert.Equal(t, tt.expected, result) + }) + } +} + +func TestGatewayNoFallbackStatusCodes(t *testing.T) { + synctest.Test(t, func(t *testing.T) { + // Primary fails with 401 (configured as no-fallback when using gateway) + primary := &failingProvider{id: "primary/auth-fail", err: errors.New("401 unauthorized")} + + // Fallback should NOT be tried because 401 is in no_fallback_status_codes + fallback := &countingProvider{ + id: "fallback/should-not-be-called", + failCount: 0, + stream: newStreamBuilder(). + AddContent("Fallback content"). + AddStopWithUsage(5, 2). + Build(), + } + + root := agent.New("root", "test", + agent.WithModel(primary), + agent.WithFallbackModel(fallback), + agent.WithFallbackRetries(2), + agent.WithNoFallbackStatusCodes([]int{401, 403}), + ) + + tm := team.New(team.WithAgents(root)) + rt, err := NewLocalRuntime(tm, + WithSessionCompaction(false), + WithModelStore(mockModelStore{}), + WithModelSwitcherConfig(&ModelSwitcherConfig{ + ModelsGateway: "https://gateway.example.com", + }), + ) + require.NoError(t, err) + + sess := session.New(session.WithUserMessage("test")) + sess.Title = "Gateway No-Fallback Test" + + events := rt.RunStream(t.Context(), sess) + + var gotError bool + var gotFallbackContent bool + for ev := range events { + if _, ok := ev.(*ErrorEvent); ok { + gotError = true + } + if choice, ok := ev.(*AgentChoiceEvent); ok { + if choice.Content == "Fallback content" { + gotFallbackContent = true + } + } + } + + assert.True(t, gotError, "should get an error since 401 is in no-fallback codes") + assert.False(t, gotFallbackContent, "fallback should not be tried for 
no-fallback status code") + assert.Equal(t, 0, fallback.callCount, "fallback provider should not be called") + }) +} + +func TestGatewayNoFallbackStatusCodes_AllowsFallbackForOtherCodes(t *testing.T) { + synctest.Test(t, func(t *testing.T) { + // Primary fails with 429 (NOT in the no-fallback list) + primary := &failingProvider{id: "primary/rate-limited", err: errors.New("429 too many requests")} + + // Fallback should be tried since 429 is not in no_fallback_status_codes + successStream := newStreamBuilder(). + AddContent("Fallback success"). + AddStopWithUsage(10, 5). + Build() + fallback := &mockProvider{id: "fallback/success", stream: successStream} + + root := agent.New("root", "test", + agent.WithModel(primary), + agent.WithFallbackModel(fallback), + agent.WithFallbackRetries(0), + agent.WithNoFallbackStatusCodes([]int{401, 403}), // Only 401 and 403 block fallback + ) + + tm := team.New(team.WithAgents(root)) + rt, err := NewLocalRuntime(tm, + WithSessionCompaction(false), + WithModelStore(mockModelStore{}), + WithModelSwitcherConfig(&ModelSwitcherConfig{ + ModelsGateway: "https://gateway.example.com", + }), + ) + require.NoError(t, err) + + sess := session.New(session.WithUserMessage("test")) + sess.Title = "Gateway Allows Fallback Test" + + events := rt.RunStream(t.Context(), sess) + + var gotFallbackContent bool + for ev := range events { + if choice, ok := ev.(*AgentChoiceEvent); ok { + if choice.Content == "Fallback success" { + gotFallbackContent = true + } + } + } + + assert.True(t, gotFallbackContent, "should receive fallback content since 429 is not in no-fallback list") + }) +} + +func TestGatewayNoFallbackStatusCodes_NoEffectWithoutGateway(t *testing.T) { + synctest.Test(t, func(t *testing.T) { + // Primary fails with 401 + primary := &failingProvider{id: "primary/auth-fail", err: errors.New("401 unauthorized")} + + // Even though 401 is in no_fallback_status_codes, without a gateway + // the fallback chain should proceed normally + successStream 
:= newStreamBuilder(). + AddContent("Fallback success without gateway"). + AddStopWithUsage(10, 5). + Build() + fallback := &mockProvider{id: "fallback/success", stream: successStream} + + root := agent.New("root", "test", + agent.WithModel(primary), + agent.WithFallbackModel(fallback), + agent.WithFallbackRetries(0), + agent.WithNoFallbackStatusCodes([]int{401, 403}), + ) + + tm := team.New(team.WithAgents(root)) + // No WithModelSwitcherConfig — no gateway + rt, err := NewLocalRuntime(tm, WithSessionCompaction(false), WithModelStore(mockModelStore{})) + require.NoError(t, err) + + sess := session.New(session.WithUserMessage("test")) + sess.Title = "No Gateway Test" + + events := rt.RunStream(t.Context(), sess) + + var gotFallbackContent bool + for ev := range events { + if choice, ok := ev.(*AgentChoiceEvent); ok { + if choice.Content == "Fallback success without gateway" { + gotFallbackContent = true + } + } + } + + assert.True(t, gotFallbackContent, "fallback should proceed normally without a gateway") + }) +} + // Verify interface compliance var ( _ provider.Provider = (*mockProvider)(nil) diff --git a/pkg/teamloader/teamloader.go b/pkg/teamloader/teamloader.go index 8cd0e3c9b..abd5a96da 100644 --- a/pkg/teamloader/teamloader.go +++ b/pkg/teamloader/teamloader.go @@ -216,6 +216,9 @@ func LoadWithConfig(ctx context.Context, agentSource config.Source, runConfig *c agent.WithFallbackRetries(agentConfig.GetFallbackRetries()), agent.WithFallbackCooldown(agentConfig.GetFallbackCooldown()), ) + if codes := agentConfig.GetNoFallbackStatusCodes(); len(codes) > 0 { + opts = append(opts, agent.WithNoFallbackStatusCodes(codes)) + } } agentTools, warnings := getToolsForAgent(ctx, &agentConfig, parentDir, runConfig, loadOpts.toolsetRegistry)