Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions crates/forge_api/src/api.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@ pub trait API: Sync + Send {
/// Provides a list of models available in the current environment
async fn get_models(&self) -> Result<Vec<Model>>;

/// Provides models from all configured providers. Providers that fail to
/// return models are silently skipped.
/// Provides models from all configured providers.
/// Returns an error if any configured provider fails to return models.
async fn get_all_provider_models(&self) -> Result<Vec<ProviderModels>>;

/// Provides a list of agents available in the current environment
Expand Down
28 changes: 14 additions & 14 deletions crates/forge_app/src/app.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use std::sync::Arc;

use anyhow::Result;
use anyhow::{Context, Result};
use chrono::Local;
use forge_domain::*;
use forge_stream::MpscStream;
Expand Down Expand Up @@ -260,8 +260,8 @@ impl<S: Services> ForgeApp<S> {
/// Gets available models from all configured providers concurrently.
///
/// Returns a list of `ProviderModels` for each configured provider.
/// All providers are queried in parallel; providers that fail to
/// return models are silently skipped.
/// All providers are queried in parallel and the first provider error is
/// returned to the caller.
pub async fn get_all_provider_models(&self) -> Result<Vec<ProviderModels>> {
let all_providers = self.services.get_all_providers().await?;

Expand All @@ -277,20 +277,20 @@ impl<S: Services> ForgeApp<S> {
.provider_auth_service()
.refresh_provider_credential(provider)
.await
.ok()?;
let models = services.models(refreshed).await.ok()?;
Some(ProviderModels { provider_id, models })
.with_context(|| {
format!("Failed to refresh credentials for provider '{provider_id}'")
})?;
let models = services.models(refreshed).await.with_context(|| {
format!("Failed to fetch models for provider '{provider_id}'")
})?;

Ok(ProviderModels { provider_id, models })
}
})
.collect();

// Execute all provider fetches concurrently and collect successful results
let results = futures::future::join_all(futures)
.await
.into_iter()
.flatten()
.collect();

Ok(results)
// Execute all provider fetches concurrently and fail fast on errors so
// callers such as login/model selection can surface the root cause.
futures::future::try_join_all(futures).await
}
}
56 changes: 34 additions & 22 deletions crates/forge_services/src/provider_auth.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use std::sync::Arc;
use std::time::Duration;

use anyhow::Context;
use forge_app::{AuthStrategy, ProviderAuthService, StrategyFactory};
use forge_domain::{
AuthContextRequest, AuthContextResponse, AuthMethod, Provider, ProviderId, ProviderRepository,
Expand Down Expand Up @@ -136,7 +137,8 @@ where
/// Checks if credential needs refresh (5 minute buffer before expiry),
/// iterates through provider's auth methods, and attempts to refresh.
/// Returns the provider with updated credentials, or original if refresh
/// fails or isn't needed.
/// isn't needed. Returns an error if refresh is needed but fails for all
/// configured auth methods.
async fn refresh_provider_credential(
&self,
mut provider: Provider<url::Url>,
Expand All @@ -146,6 +148,8 @@ where
let buffer = chrono::Duration::minutes(5);

if credential.needs_refresh(buffer) {
let mut last_error: Option<anyhow::Error> = None;

// Iterate through auth methods and try to refresh
for auth_method in &provider.auth_methods {
match auth_method {
Expand All @@ -169,37 +173,45 @@ where
};

// Create strategy and refresh credential
if let Ok(strategy) = self.infra.create_auth_strategy(
let strategy = match self.infra.create_auth_strategy(
provider.id.clone(),
auth_method.clone(),
required_params,
) {
match strategy.refresh(&existing_credential).await {
Ok(refreshed) => {
// Store refreshed credential
if self
.infra
.upsert_credential(refreshed.clone())
.await
.is_err()
{
continue;
}

// Update provider with refreshed credential
provider.credential = Some(refreshed);
break; // Success, stop trying other methods
}
Err(_) => {
// If refresh fails, continue with
// existing credentials
}
Ok(s) => s,
Err(err) => {
last_error = Some(err);
continue;
}
};

match strategy.refresh(&existing_credential).await {
Ok(refreshed) => {
// Store refreshed credential
self.infra.upsert_credential(refreshed.clone()).await?;

// Update provider with refreshed credential
provider.credential = Some(refreshed);
return Ok(provider);
}
Err(err) => {
last_error = Some(err);
}
}
}
_ => {}
}
}

// If we got here, all auth methods failed
if let Some(err) = last_error {
return Err(err).with_context(|| {
format!(
"Failed to refresh credentials for provider '{}'",
provider.id
)
});
}
}
}

Expand Down
Loading