diff --git a/src/caps.rs b/src/caps.rs index 8d07f98ac..10fffbb40 100644 --- a/src/caps.rs +++ b/src/caps.rs @@ -1,15 +1,16 @@ -use std::path::PathBuf; +use hyper::StatusCode; +use serde::Deserialize; +use serde::Serialize; +use serde_json::Value; use std::collections::HashMap; use std::fs::File; use std::io::Read; +use std::path::PathBuf; use std::sync::Arc; use std::sync::RwLock as StdRwLock; -use serde::Deserialize; -use serde::Serialize; -use serde_json::Value; use tokio::sync::RwLock as ARwLock; -use url::Url; use tracing::{error, info, warn}; +use url::Url; use crate::custom_error::ScratchError; use crate::global_context::{try_load_caps_quickly_if_not_present, GlobalContext}; @@ -308,15 +309,21 @@ async fn load_caps_buf_from_file( async fn load_caps_buf_from_url( cmdline: crate::global_context::CommandLine, gcx: Arc>, -) -> Result<(String, String), String> { +) -> Result<(String, String), ScratchError> { let mut buffer = String::new(); let mut caps_urls: Vec = Vec::new(); if cmdline.address_url.to_lowercase() == "refact" { caps_urls.push("https://inference.smallcloud.ai/coding_assistant_caps.json".to_string()); } else { - let base_url = Url::parse(&cmdline.address_url.clone()).map_err(|_| "failed to parse address url (1)".to_string())?; - let joined_url = base_url.join(&CAPS_FILENAME).map_err(|_| "failed to parse address url (2)".to_string())?; - let joined_url_fallback = base_url.join(&CAPS_FILENAME_FALLBACK).map_err(|_| "failed to parse address url (2)".to_string())?; + let base_url = Url::parse(&cmdline.address_url.clone()).map_err(|_| { + ScratchError::new_internal("failed to parse address url (1)".to_string()) + })?; + let joined_url = base_url.join(&CAPS_FILENAME).map_err(|_| { + ScratchError::new_internal("failed to parse address url (2)".to_string()) + })?; + let joined_url_fallback = base_url.join(&CAPS_FILENAME_FALLBACK).map_err(|_| { + ScratchError::new_internal("failed to parse address url (2)".to_string()) + })?; caps_urls.push(joined_url.to_string()); caps_urls.push(joined_url_fallback.to_string()); } @@ -332,11 +339,25 @@ async fn load_caps_buf_from_url( let mut status: u16 = 0; for url in caps_urls.iter() { info!("fetching caps from {}", url); - let response = http_client.get(url).headers(headers.clone()).send().await.map_err(|e| format!("{}", e))?; + let response = http_client + .get(url) + .headers(headers.clone()) + .send() + .await + .map_err(|e| { + ScratchError::new( + e.status() + .unwrap_or(reqwest::StatusCode::INTERNAL_SERVER_ERROR) + .as_u16() + .try_into() + .unwrap_or(StatusCode::INTERNAL_SERVER_ERROR), + e.to_string(), + ) + })?; status = response.status().as_u16(); buffer = match response.text().await { Ok(v) => v, - Err(_) => continue + Err(_) => continue, }; if status == 200 { @@ -349,18 +370,40 @@ async fn load_caps_buf_from_url( let response_json: serde_json::Result = serde_json::from_str(&buffer); return if let Ok(response_json) = response_json { if let Some(detail) = response_json.get("detail") { - Err(detail.as_str().unwrap().to_string()) + Err(ScratchError::new( + status + .try_into() + .unwrap_or(StatusCode::INTERNAL_SERVER_ERROR), + detail.as_str().unwrap().to_string(), + )) } else { - Err(format!("cannot fetch caps, status={}", status)) + Err(ScratchError::new( + status + .try_into() + .unwrap_or(StatusCode::INTERNAL_SERVER_ERROR), + format!("cannot fetch caps, status={}", status), + )) } } else { - Err(format!("cannot fetch caps, status={}", status)) + Err(ScratchError::new( + status + .try_into() + .unwrap_or(StatusCode::INTERNAL_SERVER_ERROR), + format!("cannot fetch caps, status={}", status), + )) }; } let caps_url: String = match caps_urls.get(0) { Some(u) => u.clone(), - None => return Err("caps_url is none".to_string()) + None => { + return Err(ScratchError::new( + status + .try_into() + .unwrap_or(StatusCode::INTERNAL_SERVER_ERROR), + "caps_url is none".to_string(), + )) + } }; Ok((buffer, caps_url)) @@ -369,15 +412,17 @@ async fn load_caps_buf_from_url( pub async fn load_caps( cmdline: crate::global_context::CommandLine, gcx: Arc>, -) -> Result>, String> { +) -> Result>, ScratchError> { let mut caps_url = cmdline.address_url.clone(); let buf: String; if caps_url.to_lowercase() == "refact" || caps_url.starts_with("http") { (buf, caps_url) = load_caps_buf_from_url(cmdline, gcx).await? } else { - (buf, caps_url) = load_caps_buf_from_file(cmdline, gcx).await? + (buf, caps_url) = load_caps_buf_from_file(cmdline, gcx) + .await + .map_err(ScratchError::new_internal)? } - load_caps_from_buf(&buf, &caps_url) + load_caps_from_buf(&buf, &caps_url).map_err(ScratchError::new_internal) } pub fn strip_model_from_finetune(model: &String) -> String { diff --git a/src/custom_error.rs b/src/custom_error.rs index d537be3da..e5f36ada8 100644 --- a/src/custom_error.rs +++ b/src/custom_error.rs @@ -49,14 +49,23 @@ impl ScratchError { } } + /// This is a helper function to create a new [`ScratchError`] + /// with `status_code` = `INTERNAL_SERVER_ERROR` + pub fn new_internal(message: String) -> Self { + ScratchError { + status_code: StatusCode::INTERNAL_SERVER_ERROR, + message, + telemetry_skip: false, + } + } + pub fn to_response(&self) -> Response { let body = json!({"detail": self.message}).to_string(); error!("client will see {}", body); - let response = Response::builder() + Response::builder() .status(self.status_code) .header("Content-Type", "application/json") .body(Body::from(body)) - .unwrap(); - response + .unwrap() } } diff --git a/src/global_context.rs b/src/global_context.rs index fa374c8be..a619b7c40 100644 --- a/src/global_context.rs +++ b/src/global_context.rs @@ -146,7 +146,7 @@ pub struct GlobalContext { pub config_dir: PathBuf, pub caps: Option>>, pub caps_reading_lock: Arc>, - pub caps_last_error: String, + pub caps_last_error: Option, pub caps_last_attempted_ts: u64, pub tokenizer_map: HashMap< String, Arc>>, pub tokenizer_download_lock: Arc>, @@ -229,7 +229,10 @@ pub async fn try_load_caps_quickly_if_not_present( } if caps_last_attempted_ts + CAPS_RELOAD_BACKOFF > now { let gcx_locked = gcx.write().await; - return Err(ScratchError::new(StatusCode::INTERNAL_SERVER_ERROR, gcx_locked.caps_last_error.clone())); + info!("Returning caps error from cache"); + return Err(gcx_locked.caps_last_error.clone().unwrap_or( + ScratchError::new_internal("Expected cached error, but none found".to_string()), + )); } } @@ -244,13 +247,19 @@ pub async fn try_load_caps_quickly_if_not_present( match caps_result { Ok(caps) => { gcx_locked.caps = Some(caps.clone()); - gcx_locked.caps_last_error = "".to_string(); + gcx_locked.caps_last_error = None; Ok(caps) - }, + } Err(e) => { error!("caps fetch failed: {:?}", e); - gcx_locked.caps_last_error = format!("caps fetch failed: {}", e); - return Err(ScratchError::new(StatusCode::INTERNAL_SERVER_ERROR, gcx_locked.caps_last_error.clone())); + gcx_locked.caps_last_error = Some(ScratchError::new( + e.status_code, + format!("caps fetch failed: {}", e.message), + )); + Err(gcx_locked + .caps_last_error + .clone() + .expect("The previous line is assigning it")) } } } @@ -355,7 +364,7 @@ pub async fn create_global_context( config_dir, caps: None, caps_reading_lock: Arc::new(AMutex::::new(false)), - caps_last_error: String::new(), + caps_last_error: None, caps_last_attempted_ts: 0, tokenizer_map: HashMap::new(), tokenizer_download_lock: Arc::new(AMutex::::new(false)), diff --git a/src/http/routers/v1/caps.rs b/src/http/routers/v1/caps.rs index aa1352471..e6d646b91 100644 --- a/src/http/routers/v1/caps.rs +++ b/src/http/routers/v1/caps.rs @@ -1,9 +1,8 @@ use std::sync::Arc; use tokio::sync::RwLock as ARwLock; -use axum::Extension; -use axum::response::Result; -use hyper::{Body, Response, StatusCode}; +use axum::{response::Result, Extension}; +use hyper::{Body, Response}; use crate::custom_error::ScratchError; use crate::global_context::GlobalContext; @@ -25,16 +24,10 @@ pub async fn handle_v1_caps( Extension(global_context): Extension>>, _: hyper::body::Bytes, ) -> Result, ScratchError> { - let caps_result = crate::global_context::try_load_caps_quickly_if_not_present( - global_context.clone(), - 0, - ).await; - let caps_arc = match caps_result { - Ok(x) => x, - Err(e) => { - return Err(ScratchError::new(StatusCode::SERVICE_UNAVAILABLE, format!("{}", e))); - } - }; + let caps_arc = + crate::global_context::try_load_caps_quickly_if_not_present(global_context.clone(), 0) + .await?; + let caps_locked = caps_arc.read().unwrap(); let body = serde_json::to_string_pretty(&*caps_locked).unwrap(); let response = Response::builder()