diff --git a/crates/forge_api/src/api.rs b/crates/forge_api/src/api.rs index 2a58651881..d8b384e6cc 100644 --- a/crates/forge_api/src/api.rs +++ b/crates/forge_api/src/api.rs @@ -173,6 +173,22 @@ pub trait API: Sync + Send { /// suggestion generation). async fn set_suggest_config(&self, config: forge_domain::SuggestConfig) -> anyhow::Result<()>; + /// Gets the per-agent model configuration for a specific agent. + async fn get_agent_model_config( + &self, + agent_id: &AgentId, + ) -> anyhow::Result>; + + /// Sets the per-agent model configuration for a specific agent. + async fn set_agent_model_config( + &self, + agent_id: AgentId, + config: forge_domain::AgentModelConfig, + ) -> anyhow::Result<()>; + + /// Clears the per-agent model configuration, reverting to global defaults. + async fn clear_agent_model_config(&self, agent_id: AgentId) -> anyhow::Result<()>; + /// Refresh MCP caches by fetching fresh data async fn reload_mcp(&self) -> Result<()>; diff --git a/crates/forge_api/src/forge_api.rs b/crates/forge_api/src/forge_api.rs index a2dc7847a7..31a2ddc3d8 100644 --- a/crates/forge_api/src/forge_api.rs +++ b/crates/forge_api/src/forge_api.rs @@ -300,6 +300,31 @@ impl< self.services.set_suggest_config(config).await } + async fn get_agent_model_config( + &self, + agent_id: &AgentId, + ) -> anyhow::Result> { + self.services.get_agent_model_config(agent_id).await + } + + async fn set_agent_model_config( + &self, + agent_id: AgentId, + config: AgentModelConfig, + ) -> anyhow::Result<()> { + let result = self.services.set_agent_model_config(agent_id, config).await; + // Invalidate agent cache so the new model is picked up + let _ = self.services.reload_agents().await; + result + } + + async fn clear_agent_model_config(&self, agent_id: AgentId) -> anyhow::Result<()> { + let result = self.services.clear_agent_model_config(agent_id).await; + // Invalidate agent cache + let _ = self.services.reload_agents().await; + result + } + async fn get_login_info(&self) -> Result> { self.services.auth_service().get_auth_token().await } diff --git a/crates/forge_app/src/command_generator.rs b/crates/forge_app/src/command_generator.rs index dd04b2b5cb..cf45aca320 100644 --- a/crates/forge_app/src/command_generator.rs +++ b/crates/forge_app/src/command_generator.rs @@ -268,6 +268,25 @@ mod tests { async fn set_suggest_config(&self, _config: forge_domain::SuggestConfig) -> Result<()> { Ok(()) } + + async fn get_agent_model_config( + &self, + _agent_id: &forge_domain::AgentId, + ) -> Result> { + Ok(None) + } + + async fn set_agent_model_config( + &self, + _agent_id: forge_domain::AgentId, + _config: forge_domain::AgentModelConfig, + ) -> Result<()> { + Ok(()) + } + + async fn clear_agent_model_config(&self, _agent_id: forge_domain::AgentId) -> Result<()> { + Ok(()) + } } #[tokio::test] diff --git a/crates/forge_app/src/services.rs b/crates/forge_app/src/services.rs index 0e1525af5f..0311a9cef2 100644 --- a/crates/forge_app/src/services.rs +++ b/crates/forge_app/src/services.rs @@ -227,6 +227,23 @@ pub trait AppConfigService: Send + Sync { /// Sets the suggest configuration (provider and model for command /// suggestion generation). async fn set_suggest_config(&self, config: forge_domain::SuggestConfig) -> anyhow::Result<()>; + + /// Gets the per-agent model configuration for a specific agent. + async fn get_agent_model_config( + &self, + agent_id: &forge_domain::AgentId, + ) -> anyhow::Result>; + + /// Sets the per-agent model configuration for a specific agent. + async fn set_agent_model_config( + &self, + agent_id: forge_domain::AgentId, + config: forge_domain::AgentModelConfig, + ) -> anyhow::Result<()>; + + /// Clears the per-agent model configuration, reverting to global defaults. + async fn clear_agent_model_config(&self, agent_id: forge_domain::AgentId) + -> anyhow::Result<()>; } #[async_trait::async_trait] @@ -1060,6 +1077,32 @@ impl AppConfigService for I { async fn set_suggest_config(&self, config: forge_domain::SuggestConfig) -> anyhow::Result<()> { self.config_service().set_suggest_config(config).await } + + async fn get_agent_model_config( + &self, + agent_id: &forge_domain::AgentId, + ) -> anyhow::Result> { + self.config_service().get_agent_model_config(agent_id).await + } + + async fn set_agent_model_config( + &self, + agent_id: forge_domain::AgentId, + config: forge_domain::AgentModelConfig, + ) -> anyhow::Result<()> { + self.config_service() + .set_agent_model_config(agent_id, config) + .await + } + + async fn clear_agent_model_config( + &self, + agent_id: forge_domain::AgentId, + ) -> anyhow::Result<()> { + self.config_service() + .clear_agent_model_config(agent_id) + .await + } } #[async_trait::async_trait] diff --git a/crates/forge_config/src/config.rs b/crates/forge_config/src/config.rs index 37862df654..176e2f2dd1 100644 --- a/crates/forge_config/src/config.rs +++ b/crates/forge_config/src/config.rs @@ -77,6 +77,15 @@ pub struct ForgeConfig { /// Provider and model to use for shell command suggestion generation #[serde(default)] pub suggest: Option, + /// Provider and model override for the forge agent + #[serde(default, skip_serializing_if = "Option::is_none")] + pub forge_model: Option, + /// Provider and model override for the sage agent + #[serde(default, skip_serializing_if = "Option::is_none")] + pub sage_model: Option, + /// Provider and model override for the muse agent + #[serde(default, skip_serializing_if = "Option::is_none")] + pub muse_model: Option, /// API key for Forge authentication #[serde(default)] pub api_key: Option, diff --git a/crates/forge_domain/src/app_config.rs b/crates/forge_domain/src/app_config.rs index 886df2730c..b0fb4a26f7 100644 --- a/crates/forge_domain/src/app_config.rs +++ b/crates/forge_domain/src/app_config.rs @@ -1,9 +1,10 @@ use std::collections::HashMap; use derive_more::From; +use derive_setters::Setters; use serde::{Deserialize, Serialize}; -use crate::{CommitConfig, ModelId, ProviderId, SuggestConfig}; +use crate::{AgentId, CommitConfig, ModelId, ProviderId, SuggestConfig}; #[derive(Deserialize)] #[serde(rename_all = "camelCase")] @@ -13,6 +14,19 @@ pub struct InitAuth { pub token: String, } +/// Per-agent model and provider configuration. +/// +/// Allows overriding the default provider and model for a specific agent +/// (e.g., forge, sage, muse). Both fields must be specified together. +#[derive(Debug, Clone, Serialize, Deserialize, Setters, PartialEq)] +#[setters(into)] +pub struct AgentModelConfig { + /// Provider ID to use for this agent. + pub provider: ProviderId, + /// Model ID to use for this agent. + pub model: ModelId, +} + #[derive(Default, Clone, Debug, PartialEq)] pub struct AppConfig { pub key_info: Option, @@ -20,6 +34,9 @@ pub struct AppConfig { pub model: HashMap, pub commit: Option, pub suggest: Option, + /// Per-agent model overrides. When set, the agent will use the specified + /// provider and model instead of the global defaults. + pub agent_models: HashMap, } #[derive(Clone, Serialize, Deserialize, From, Debug, PartialEq)] @@ -53,4 +70,8 @@ pub enum AppConfigOperation { SetCommitConfig(CommitConfig), /// Set the shell-command suggestion configuration. SetSuggestConfig(SuggestConfig), + /// Set the model and provider for a specific agent. + SetAgentModel(AgentId, AgentModelConfig), + /// Clear the per-agent model override, reverting to global defaults. + ClearAgentModel(AgentId), } diff --git a/crates/forge_main/src/built_in_commands.json b/crates/forge_main/src/built_in_commands.json index b57f5b0dc8..40eb89da8a 100644 --- a/crates/forge_main/src/built_in_commands.json +++ b/crates/forge_main/src/built_in_commands.json @@ -35,6 +35,18 @@ "command": "config", "description": "List current configuration values" }, + { + "command": "config-forge-model", + "description": "Set the model for the forge agent [alias: cfm]" + }, + { + "command": "config-sage-model", + "description": "Set the model for the sage agent [alias: csgm]" + }, + { + "command": "config-muse-model", + "description": "Set the model for the muse agent [alias: cmm]" + }, { "command": "config-edit", "description": "Open the global forge config file (~/forge/.forge.toml) in an editor [alias: ce]" diff --git a/crates/forge_main/src/cli.rs b/crates/forge_main/src/cli.rs index 13d6845afc..840107fd20 100644 --- a/crates/forge_main/src/cli.rs +++ b/crates/forge_main/src/cli.rs @@ -551,6 +551,15 @@ pub enum ConfigSetField { /// Model ID to use for command suggestion generation. model: ModelId, }, + /// Set the provider and model for a specific agent (forge, sage, or muse). + AgentModel { + /// Agent ID (forge, sage, or muse). + agent: AgentId, + /// Provider ID to use for this agent. + provider: ProviderId, + /// Model ID to use for this agent. + model: ModelId, + }, } /// Type-safe subcommands for `forge config get`. @@ -564,6 +573,11 @@ pub enum ConfigGetField { Commit, /// Get the command suggestion generation config. Suggest, + /// Get the per-agent model configuration for a specific agent. + AgentModel { + /// Agent ID (forge, sage, or muse). + agent: AgentId, + }, } /// Command group for conversation management. diff --git a/crates/forge_main/src/ui.rs b/crates/forge_main/src/ui.rs index 9d8ec6a0a0..01c58ccfa7 100644 --- a/crates/forge_main/src/ui.rs +++ b/crates/forge_main/src/ui.rs @@ -1330,10 +1330,47 @@ impl A + Send + Sync> UI { .map(|c| c.model.as_str().to_string()) .unwrap_or_else(|| markers::EMPTY.to_string()); + // Per-agent model overrides + let forge_config = self + .api + .get_agent_model_config(&AgentId::FORGE) + .await + .ok() + .flatten(); + let forge_model = forge_config + .as_ref() + .map(|c| format!("{} ({})", c.model.as_str(), c.provider)) + .unwrap_or_else(|| markers::EMPTY.to_string()); + + let sage_config = self + .api + .get_agent_model_config(&AgentId::SAGE) + .await + .ok() + .flatten(); + let sage_model = sage_config + .as_ref() + .map(|c| format!("{} ({})", c.model.as_str(), c.provider)) + .unwrap_or_else(|| markers::EMPTY.to_string()); + + let muse_config = self + .api + .get_agent_model_config(&AgentId::MUSE) + .await + .ok() + .flatten(); + let muse_model = muse_config + .as_ref() + .map(|c| format!("{} ({})", c.model.as_str(), c.provider)) + .unwrap_or_else(|| markers::EMPTY.to_string()); + let info = Info::new() .add_title("CONFIGURATION") .add_key_value("Default Model", model) .add_key_value("Default Provider", provider) + .add_key_value("Forge Model", forge_model) + .add_key_value("Sage Model", sage_model) + .add_key_value("Muse Model", muse_model) .add_key_value("Commit Provider", commit_provider) .add_key_value("Commit Model", commit_model) .add_key_value("Suggest Provider", suggest_provider) @@ -3444,6 +3481,20 @@ impl A + Send + Sync> UI { format!("is now the suggest model for provider '{provider}'"), ))?; } + ConfigSetField::AgentModel { agent, provider, model } => { + // Validate provider exists and model belongs to that specific provider + let validated_model = self.validate_model(model.as_str(), Some(&provider)).await?; + let agent_config = forge_domain::AgentModelConfig { + provider: provider.clone(), + model: validated_model.clone(), + }; + self.api + .set_agent_model_config(agent.clone(), agent_config) + .await?; + self.writeln_title(TitleFormat::action(validated_model.as_str()).sub_title( + format!("is now the model for agent '{agent}' (provider: '{provider}')"), + ))?; + } } Ok(()) @@ -3505,6 +3556,16 @@ impl A + Send + Sync> UI { None => self.writeln("Suggest: Not set")?, } } + ConfigGetField::AgentModel { agent } => { + let agent_config = self.api.get_agent_model_config(&agent).await?; + match agent_config { + Some(config) => { + self.writeln(config.provider.as_ref())?; + self.writeln(config.model.as_str().to_string())?; + } + None => self.writeln(format!("Agent model for '{agent}': Not set"))?, + } + } } Ok(()) diff --git a/crates/forge_repo/src/app_config.rs b/crates/forge_repo/src/app_config.rs index 89061537b9..fc6e8b8ca9 100644 --- a/crates/forge_repo/src/app_config.rs +++ b/crates/forge_repo/src/app_config.rs @@ -2,8 +2,8 @@ use std::sync::Arc; use forge_config::{ConfigReader, ForgeConfig, ModelConfig}; use forge_domain::{ - AppConfig, AppConfigOperation, AppConfigRepository, CommitConfig, LoginInfo, ModelId, - ProviderId, SuggestConfig, + AgentId, AgentModelConfig, AppConfig, AppConfigOperation, AppConfigRepository, CommitConfig, + LoginInfo, ModelId, ProviderId, SuggestConfig, }; use tokio::sync::Mutex; use tracing::{debug, error}; @@ -50,7 +50,23 @@ fn forge_config_to_app_config(fc: ForgeConfig) -> AppConfig { }) }); - AppConfig { key_info, provider, model, commit, suggest } + // Build per-agent model overrides + let mut agent_models = std::collections::HashMap::new(); + for (agent_id, mc_opt) in [ + (AgentId::FORGE, fc.forge_model), + (AgentId::SAGE, fc.sage_model), + (AgentId::MUSE, fc.muse_model), + ] { + if let Some(mc) = mc_opt + && let (Some(pid), Some(mid)) = (mc.provider_id, mc.model_id) { + agent_models.insert( + agent_id, + AgentModelConfig { provider: ProviderId::from(pid), model: ModelId::new(mid) }, + ); + } + } + + AppConfig { key_info, provider, model, commit, suggest, agent_models } } /// Applies a single [`AppConfigOperation`] directly onto a [`ForgeConfig`] @@ -106,6 +122,33 @@ fn apply_op(op: AppConfigOperation, fc: &mut ForgeConfig) { .model_id(suggest.model.to_string()), ); } + AppConfigOperation::SetAgentModel(agent_id, config) => { + let mc = Some( + ModelConfig::default() + .provider_id(config.provider.as_ref().to_string()) + .model_id(config.model.to_string()), + ); + match agent_id.as_str() { + "forge" => fc.forge_model = mc, + "sage" => fc.sage_model = mc, + "muse" => fc.muse_model = mc, + _ => { + // For custom agents, we currently only support the built-in + // three. This could be extended with a HashMap in ForgeConfig + // if needed. + tracing::warn!( + agent = agent_id.as_str(), + "Per-agent model config is only supported for forge, sage, and muse" + ); + } + } + } + AppConfigOperation::ClearAgentModel(agent_id) => match agent_id.as_str() { + "forge" => fc.forge_model = None, + "sage" => fc.sage_model = None, + "muse" => fc.muse_model = None, + _ => {} + }, } } diff --git a/crates/forge_services/src/agent_registry.rs b/crates/forge_services/src/agent_registry.rs index 256e003574..38cfc6419e 100644 --- a/crates/forge_services/src/agent_registry.rs +++ b/crates/forge_services/src/agent_registry.rs @@ -90,8 +90,15 @@ impl ForgeAgentRe // Convert definitions to runtime agents and populate map for def in agent_defs { - let agent = + let mut agent = Agent::from_agent_def(def, default_provider.id.clone(), default_model.clone()); + + // Apply per-agent model override if configured + if let Some(agent_config) = app_config.agent_models.get(&agent.id) { + agent.provider = agent_config.provider.clone(); + agent.model = agent_config.model.clone(); + } + agents_map.insert(agent.id.as_str().to_string(), agent); } diff --git a/crates/forge_services/src/app_config.rs b/crates/forge_services/src/app_config.rs index d4737c27e6..26c6c548ed 100644 --- a/crates/forge_services/src/app_config.rs +++ b/crates/forge_services/src/app_config.rs @@ -100,6 +100,31 @@ impl AppConfigService self.update(AppConfigOperation::SetSuggestConfig(suggest_config)) .await } + + async fn get_agent_model_config( + &self, + agent_id: &forge_domain::AgentId, + ) -> anyhow::Result> { + let config = self.infra.get_app_config().await?; + Ok(config.agent_models.get(agent_id).cloned()) + } + + async fn set_agent_model_config( + &self, + agent_id: forge_domain::AgentId, + config: forge_domain::AgentModelConfig, + ) -> anyhow::Result<()> { + self.update(AppConfigOperation::SetAgentModel(agent_id, config)) + .await + } + + async fn clear_agent_model_config( + &self, + agent_id: forge_domain::AgentId, + ) -> anyhow::Result<()> { + self.update(AppConfigOperation::ClearAgentModel(agent_id)) + .await + } } #[cfg(test)] @@ -201,6 +226,12 @@ mod tests { } AppConfigOperation::SetCommitConfig(commit) => config.commit = Some(commit), AppConfigOperation::SetSuggestConfig(suggest) => config.suggest = Some(suggest), + AppConfigOperation::SetAgentModel(agent_id, agent_config) => { + config.agent_models.insert(agent_id, agent_config); + } + AppConfigOperation::ClearAgentModel(agent_id) => { + config.agent_models.remove(&agent_id); + } } } Ok(()) diff --git a/shell-plugin/lib/actions/config.zsh b/shell-plugin/lib/actions/config.zsh index a05a331b72..29781e1a90 100644 --- a/shell-plugin/lib/actions/config.zsh +++ b/shell-plugin/lib/actions/config.zsh @@ -239,6 +239,61 @@ function _forge_action_suggest_model() { ) } +# Helper: Select model for a specific agent (forge, sage, or muse). +# Calls `forge config set agent-model ` on selection. +# Arguments: +# $1 agent_id - The agent to configure (forge, sage, or muse) +# $2 input_text - Optional pre-fill query for fzf +function _forge_action_agent_model() { + local agent_id="$1" + local input_text="$2" + ( + echo + # config get agent-model outputs two lines: provider_id (raw) then model_id + local agent_output current_agent_model current_agent_provider + agent_output=$(_forge_exec config get agent-model "$agent_id" 2>/dev/null) + current_agent_provider=$(echo "$agent_output" | head -n 1) + current_agent_model=$(echo "$agent_output" | tail -n 1) + + # If output contains "Not set", clear the values + if [[ "$agent_output" == *"Not set"* ]]; then + current_agent_provider="" + current_agent_model="" + fi + + local prompt_label="${agent_id:u} Model" + local selected + # provider_id from config get agent-model is the raw id, matching porcelain field 4 + selected=$(_forge_pick_model "${prompt_label} ❯ " "$current_agent_model" "$input_text" "$current_agent_provider" 4) + + if [[ -n "$selected" ]]; then + # Field 1 = model_id (raw), field 4 = provider_id (raw) + local model_id provider_id + read -r model_id provider_id <<<$(echo "$selected" | awk -F ' +' '{print $1, $4}') + + model_id=${model_id//[[:space:]]/} + provider_id=${provider_id//[[:space:]]/} + + _forge_exec config set agent-model "$agent_id" "$provider_id" "$model_id" + fi + ) +} + +# Action handler: Select model for the forge agent +function _forge_action_forge_model() { + _forge_action_agent_model "forge" "$1" +} + +# Action handler: Select model for the sage agent +function _forge_action_sage_model() { + _forge_action_agent_model "sage" "$1" +} + +# Action handler: Select model for the muse agent +function _forge_action_muse_model() { + _forge_action_agent_model "muse" "$1" +} + # Action handler: Sync workspace for codebase search function _forge_action_sync() { echo diff --git a/shell-plugin/lib/dispatcher.zsh b/shell-plugin/lib/dispatcher.zsh index e162f28dcf..3054a2b8b2 100644 --- a/shell-plugin/lib/dispatcher.zsh +++ b/shell-plugin/lib/dispatcher.zsh @@ -184,6 +184,15 @@ function forge-accept-line() { config-suggest-model|csm) _forge_action_suggest_model "$input_text" ;; + config-forge-model|cfm) + _forge_action_forge_model "$input_text" + ;; + config-sage-model|csgm) + _forge_action_sage_model "$input_text" + ;; + config-muse-model|cmm) + _forge_action_muse_model "$input_text" + ;; tools|t) _forge_action_tools ;;