Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions crates/forge_api/src/api.rs
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,22 @@ pub trait API: Sync + Send {
/// suggestion generation).
async fn set_suggest_config(&self, config: forge_domain::SuggestConfig) -> anyhow::Result<()>;

/// Gets the per-agent model configuration for a specific agent.
async fn get_agent_model_config(
&self,
agent_id: &AgentId,
) -> anyhow::Result<Option<forge_domain::AgentModelConfig>>;

/// Sets the per-agent model configuration for a specific agent.
async fn set_agent_model_config(
&self,
agent_id: AgentId,
config: forge_domain::AgentModelConfig,
) -> anyhow::Result<()>;

/// Clears the per-agent model configuration, reverting to global defaults.
async fn clear_agent_model_config(&self, agent_id: AgentId) -> anyhow::Result<()>;

/// Refresh MCP caches by fetching fresh data
async fn reload_mcp(&self) -> Result<()>;

Expand Down
25 changes: 25 additions & 0 deletions crates/forge_api/src/forge_api.rs
Original file line number Diff line number Diff line change
Expand Up @@ -300,6 +300,31 @@ impl<
self.services.set_suggest_config(config).await
}

async fn get_agent_model_config(
&self,
agent_id: &AgentId,
) -> anyhow::Result<Option<AgentModelConfig>> {
self.services.get_agent_model_config(agent_id).await
}

async fn set_agent_model_config(
&self,
agent_id: AgentId,
config: AgentModelConfig,
) -> anyhow::Result<()> {
let result = self.services.set_agent_model_config(agent_id, config).await;
// Invalidate agent cache so the new model is picked up
let _ = self.services.reload_agents().await;
result
}

async fn clear_agent_model_config(&self, agent_id: AgentId) -> anyhow::Result<()> {
let result = self.services.clear_agent_model_config(agent_id).await;
// Invalidate agent cache
let _ = self.services.reload_agents().await;
result
}

async fn get_login_info(&self) -> Result<Option<LoginInfo>> {
self.services.auth_service().get_auth_token().await
}
Expand Down
19 changes: 19 additions & 0 deletions crates/forge_app/src/command_generator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -268,6 +268,25 @@ mod tests {
async fn set_suggest_config(&self, _config: forge_domain::SuggestConfig) -> Result<()> {
Ok(())
}

async fn get_agent_model_config(
&self,
_agent_id: &forge_domain::AgentId,
) -> Result<Option<forge_domain::AgentModelConfig>> {
Ok(None)
}

async fn set_agent_model_config(
&self,
_agent_id: forge_domain::AgentId,
_config: forge_domain::AgentModelConfig,
) -> Result<()> {
Ok(())
}

async fn clear_agent_model_config(&self, _agent_id: forge_domain::AgentId) -> Result<()> {
Ok(())
}
}

#[tokio::test]
Expand Down
43 changes: 43 additions & 0 deletions crates/forge_app/src/services.rs
Original file line number Diff line number Diff line change
Expand Up @@ -227,6 +227,23 @@ pub trait AppConfigService: Send + Sync {
/// Sets the suggest configuration (provider and model for command
/// suggestion generation).
async fn set_suggest_config(&self, config: forge_domain::SuggestConfig) -> anyhow::Result<()>;

/// Gets the per-agent model configuration for a specific agent.
async fn get_agent_model_config(
&self,
agent_id: &forge_domain::AgentId,
) -> anyhow::Result<Option<forge_domain::AgentModelConfig>>;

/// Sets the per-agent model configuration for a specific agent.
async fn set_agent_model_config(
&self,
agent_id: forge_domain::AgentId,
config: forge_domain::AgentModelConfig,
) -> anyhow::Result<()>;

/// Clears the per-agent model configuration, reverting to global defaults.
async fn clear_agent_model_config(&self, agent_id: forge_domain::AgentId)
-> anyhow::Result<()>;
}

#[async_trait::async_trait]
Expand Down Expand Up @@ -1060,6 +1077,32 @@ impl<I: Services> AppConfigService for I {
async fn set_suggest_config(&self, config: forge_domain::SuggestConfig) -> anyhow::Result<()> {
self.config_service().set_suggest_config(config).await
}

async fn get_agent_model_config(
&self,
agent_id: &forge_domain::AgentId,
) -> anyhow::Result<Option<forge_domain::AgentModelConfig>> {
self.config_service().get_agent_model_config(agent_id).await
}

async fn set_agent_model_config(
&self,
agent_id: forge_domain::AgentId,
config: forge_domain::AgentModelConfig,
) -> anyhow::Result<()> {
self.config_service()
.set_agent_model_config(agent_id, config)
.await
}

async fn clear_agent_model_config(
&self,
agent_id: forge_domain::AgentId,
) -> anyhow::Result<()> {
self.config_service()
.clear_agent_model_config(agent_id)
.await
}
}

#[async_trait::async_trait]
Expand Down
9 changes: 9 additions & 0 deletions crates/forge_config/src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,15 @@ pub struct ForgeConfig {
/// Provider and model to use for shell command suggestion generation
#[serde(default)]
pub suggest: Option<ModelConfig>,
/// Provider and model override for the forge agent
#[serde(default, skip_serializing_if = "Option::is_none")]
pub forge_model: Option<ModelConfig>,
/// Provider and model override for the sage agent
#[serde(default, skip_serializing_if = "Option::is_none")]
pub sage_model: Option<ModelConfig>,
/// Provider and model override for the muse agent
#[serde(default, skip_serializing_if = "Option::is_none")]
pub muse_model: Option<ModelConfig>,
/// API key for Forge authentication
#[serde(default)]
pub api_key: Option<String>,
Expand Down
23 changes: 22 additions & 1 deletion crates/forge_domain/src/app_config.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
use std::collections::HashMap;

use derive_more::From;
use derive_setters::Setters;
use serde::{Deserialize, Serialize};

use crate::{CommitConfig, ModelId, ProviderId, SuggestConfig};
use crate::{AgentId, CommitConfig, ModelId, ProviderId, SuggestConfig};

#[derive(Deserialize)]
#[serde(rename_all = "camelCase")]
Expand All @@ -13,13 +14,29 @@ pub struct InitAuth {
pub token: String,
}

/// Per-agent model and provider configuration.
///
/// Allows overriding the default provider and model for a specific agent
/// (e.g., forge, sage, muse). Both fields must be specified together.
#[derive(Debug, Clone, Serialize, Deserialize, Setters, PartialEq)]
#[setters(into)]
pub struct AgentModelConfig {
/// Provider ID to use for this agent.
pub provider: ProviderId,
/// Model ID to use for this agent.
pub model: ModelId,
}

#[derive(Default, Clone, Debug, PartialEq)]
pub struct AppConfig {
pub key_info: Option<LoginInfo>,
pub provider: Option<ProviderId>,
pub model: HashMap<ProviderId, ModelId>,
pub commit: Option<CommitConfig>,
pub suggest: Option<SuggestConfig>,
/// Per-agent model overrides. When set, the agent will use the specified
/// provider and model instead of the global defaults.
pub agent_models: HashMap<AgentId, AgentModelConfig>,
}

#[derive(Clone, Serialize, Deserialize, From, Debug, PartialEq)]
Expand Down Expand Up @@ -53,4 +70,8 @@ pub enum AppConfigOperation {
SetCommitConfig(CommitConfig),
/// Set the shell-command suggestion configuration.
SetSuggestConfig(SuggestConfig),
/// Set the model and provider for a specific agent.
SetAgentModel(AgentId, AgentModelConfig),
/// Clear the per-agent model override, reverting to global defaults.
ClearAgentModel(AgentId),
}
12 changes: 12 additions & 0 deletions crates/forge_main/src/built_in_commands.json
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,18 @@
"command": "config",
"description": "List current configuration values"
},
{
"command": "config-forge-model",
"description": "Set the model for the forge agent [alias: cfm]"
},
{
"command": "config-sage-model",
"description": "Set the model for the sage agent [alias: csgm]"
},
{
"command": "config-muse-model",
"description": "Set the model for the muse agent [alias: cmm]"
},
{
"command": "config-edit",
"description": "Open the global forge config file (~/forge/.forge.toml) in an editor [alias: ce]"
Expand Down
14 changes: 14 additions & 0 deletions crates/forge_main/src/cli.rs
Original file line number Diff line number Diff line change
Expand Up @@ -551,6 +551,15 @@ pub enum ConfigSetField {
/// Model ID to use for command suggestion generation.
model: ModelId,
},
/// Set the provider and model for a specific agent (forge, sage, or muse).
AgentModel {
/// Agent ID (forge, sage, or muse).
agent: AgentId,
/// Provider ID to use for this agent.
provider: ProviderId,
/// Model ID to use for this agent.
model: ModelId,
},
}

/// Type-safe subcommands for `forge config get`.
Expand All @@ -564,6 +573,11 @@ pub enum ConfigGetField {
Commit,
/// Get the command suggestion generation config.
Suggest,
/// Get the per-agent model configuration for a specific agent.
AgentModel {
/// Agent ID (forge, sage, or muse).
agent: AgentId,
},
}

/// Command group for conversation management.
Expand Down
61 changes: 61 additions & 0 deletions crates/forge_main/src/ui.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1330,10 +1330,47 @@ impl<A: API + ConsoleWriter + 'static, F: Fn() -> A + Send + Sync> UI<A, F> {
.map(|c| c.model.as_str().to_string())
.unwrap_or_else(|| markers::EMPTY.to_string());

// Per-agent model overrides
let forge_config = self
.api
.get_agent_model_config(&AgentId::FORGE)
.await
.ok()
.flatten();
let forge_model = forge_config
.as_ref()
.map(|c| format!("{} ({})", c.model.as_str(), c.provider))
.unwrap_or_else(|| markers::EMPTY.to_string());

let sage_config = self
.api
.get_agent_model_config(&AgentId::SAGE)
.await
.ok()
.flatten();
let sage_model = sage_config
.as_ref()
.map(|c| format!("{} ({})", c.model.as_str(), c.provider))
.unwrap_or_else(|| markers::EMPTY.to_string());

let muse_config = self
.api
.get_agent_model_config(&AgentId::MUSE)
.await
.ok()
.flatten();
let muse_model = muse_config
.as_ref()
.map(|c| format!("{} ({})", c.model.as_str(), c.provider))
.unwrap_or_else(|| markers::EMPTY.to_string());

let info = Info::new()
.add_title("CONFIGURATION")
.add_key_value("Default Model", model)
.add_key_value("Default Provider", provider)
.add_key_value("Forge Model", forge_model)
.add_key_value("Sage Model", sage_model)
.add_key_value("Muse Model", muse_model)
.add_key_value("Commit Provider", commit_provider)
.add_key_value("Commit Model", commit_model)
.add_key_value("Suggest Provider", suggest_provider)
Expand Down Expand Up @@ -3444,6 +3481,20 @@ impl<A: API + ConsoleWriter + 'static, F: Fn() -> A + Send + Sync> UI<A, F> {
format!("is now the suggest model for provider '{provider}'"),
))?;
}
ConfigSetField::AgentModel { agent, provider, model } => {
// Validate provider exists and model belongs to that specific provider
let validated_model = self.validate_model(model.as_str(), Some(&provider)).await?;
let agent_config = forge_domain::AgentModelConfig {
provider: provider.clone(),
model: validated_model.clone(),
};
self.api
.set_agent_model_config(agent.clone(), agent_config)
.await?;
self.writeln_title(TitleFormat::action(validated_model.as_str()).sub_title(
format!("is now the model for agent '{agent}' (provider: '{provider}')"),
))?;
}
}

Ok(())
Expand Down Expand Up @@ -3505,6 +3556,16 @@ impl<A: API + ConsoleWriter + 'static, F: Fn() -> A + Send + Sync> UI<A, F> {
None => self.writeln("Suggest: Not set")?,
}
}
ConfigGetField::AgentModel { agent } => {
let agent_config = self.api.get_agent_model_config(&agent).await?;
match agent_config {
Some(config) => {
self.writeln(config.provider.as_ref())?;
self.writeln(config.model.as_str().to_string())?;
}
None => self.writeln(format!("Agent model for '{agent}': Not set"))?,
}
}
}

Ok(())
Expand Down
Loading
Loading