Skip to content
This repository was archived by the owner on Feb 21, 2025. It is now read-only.
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 63 additions & 18 deletions src/caps.rs
Original file line number Diff line number Diff line change
@@ -1,15 +1,16 @@
use std::path::PathBuf;
use hyper::StatusCode;
use serde::Deserialize;
use serde::Serialize;
use serde_json::Value;
use std::collections::HashMap;
use std::fs::File;
use std::io::Read;
use std::path::PathBuf;
use std::sync::Arc;
use std::sync::RwLock as StdRwLock;
use serde::Deserialize;
use serde::Serialize;
use serde_json::Value;
use tokio::sync::RwLock as ARwLock;
use url::Url;
use tracing::{error, info, warn};
use url::Url;

use crate::custom_error::ScratchError;
use crate::global_context::{try_load_caps_quickly_if_not_present, GlobalContext};
Expand Down Expand Up @@ -308,15 +309,21 @@ async fn load_caps_buf_from_file(
async fn load_caps_buf_from_url(
cmdline: crate::global_context::CommandLine,
gcx: Arc<ARwLock<GlobalContext>>,
) -> Result<(String, String), String> {
) -> Result<(String, String), ScratchError> {
let mut buffer = String::new();
let mut caps_urls: Vec<String> = Vec::new();
if cmdline.address_url.to_lowercase() == "refact" {
caps_urls.push("https://inference.smallcloud.ai/coding_assistant_caps.json".to_string());
} else {
let base_url = Url::parse(&cmdline.address_url.clone()).map_err(|_| "failed to parse address url (1)".to_string())?;
let joined_url = base_url.join(&CAPS_FILENAME).map_err(|_| "failed to parse address url (2)".to_string())?;
let joined_url_fallback = base_url.join(&CAPS_FILENAME_FALLBACK).map_err(|_| "failed to parse address url (2)".to_string())?;
let base_url = Url::parse(&cmdline.address_url.clone()).map_err(|_| {
ScratchError::new_internal("failed to parse address url (1)".to_string())
})?;
let joined_url = base_url.join(&CAPS_FILENAME).map_err(|_| {
ScratchError::new_internal("failed to parse address url (2)".to_string())
})?;
let joined_url_fallback = base_url.join(&CAPS_FILENAME_FALLBACK).map_err(|_| {
ScratchError::new_internal("failed to parse address url (2)".to_string())
})?;
caps_urls.push(joined_url.to_string());
caps_urls.push(joined_url_fallback.to_string());
}
Expand All @@ -332,11 +339,25 @@ async fn load_caps_buf_from_url(
let mut status: u16 = 0;
for url in caps_urls.iter() {
info!("fetching caps from {}", url);
let response = http_client.get(url).headers(headers.clone()).send().await.map_err(|e| format!("{}", e))?;
let response = http_client
.get(url)
.headers(headers.clone())
.send()
.await
.map_err(|e| {
ScratchError::new(
e.status()
.unwrap_or(reqwest::StatusCode::INTERNAL_SERVER_ERROR)
.as_u16()
.try_into()
.unwrap_or(StatusCode::INTERNAL_SERVER_ERROR),
e.to_string(),
)
})?;
status = response.status().as_u16();
buffer = match response.text().await {
Ok(v) => v,
Err(_) => continue
Err(_) => continue,
};

if status == 200 {
Expand All @@ -349,18 +370,40 @@ async fn load_caps_buf_from_url(
let response_json: serde_json::Result<Value> = serde_json::from_str(&buffer);
return if let Ok(response_json) = response_json {
if let Some(detail) = response_json.get("detail") {
Err(detail.as_str().unwrap().to_string())
Err(ScratchError::new(
status
.try_into()
.unwrap_or(StatusCode::INTERNAL_SERVER_ERROR),
detail.as_str().unwrap().to_string(),
))
} else {
Err(format!("cannot fetch caps, status={}", status))
Err(ScratchError::new(
status
.try_into()
.unwrap_or(StatusCode::INTERNAL_SERVER_ERROR),
format!("cannot fetch caps, status={}", status),
))
}
} else {
Err(format!("cannot fetch caps, status={}", status))
Err(ScratchError::new(
status
.try_into()
.unwrap_or(StatusCode::INTERNAL_SERVER_ERROR),
format!("cannot fetch caps, status={}", status),
))
};
}

let caps_url: String = match caps_urls.get(0) {
Some(u) => u.clone(),
None => return Err("caps_url is none".to_string())
None => {
return Err(ScratchError::new(
status
.try_into()
.unwrap_or(StatusCode::INTERNAL_SERVER_ERROR),
"caps_url is none".to_string(),
))
}
};

Ok((buffer, caps_url))
Expand All @@ -369,15 +412,17 @@ async fn load_caps_buf_from_url(
pub async fn load_caps(
cmdline: crate::global_context::CommandLine,
gcx: Arc<ARwLock<GlobalContext>>,
) -> Result<Arc<StdRwLock<CodeAssistantCaps>>, String> {
) -> Result<Arc<StdRwLock<CodeAssistantCaps>>, ScratchError> {
let mut caps_url = cmdline.address_url.clone();
let buf: String;
if caps_url.to_lowercase() == "refact" || caps_url.starts_with("http") {
(buf, caps_url) = load_caps_buf_from_url(cmdline, gcx).await?
} else {
(buf, caps_url) = load_caps_buf_from_file(cmdline, gcx).await?
(buf, caps_url) = load_caps_buf_from_file(cmdline, gcx)
.await
.map_err(ScratchError::new_internal)?
}
load_caps_from_buf(&buf, &caps_url)
load_caps_from_buf(&buf, &caps_url).map_err(ScratchError::new_internal)
}

pub fn strip_model_from_finetune(model: &String) -> String {
Expand Down
15 changes: 12 additions & 3 deletions src/custom_error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,14 +49,23 @@ impl ScratchError {
}
}

/// This is a helper function to create a new [`ScratchError`]
/// with `status_code` = `INTERNAL_SERVER_ERROR`
pub fn new_internal(message: String) -> Self {
ScratchError {
status_code: StatusCode::INTERNAL_SERVER_ERROR,
message,
telemetry_skip: false,
}
}

pub fn to_response(&self) -> Response<Body> {
let body = json!({"detail": self.message}).to_string();
error!("client will see {}", body);
let response = Response::builder()
Response::builder()
.status(self.status_code)
.header("Content-Type", "application/json")
.body(Body::from(body))
.unwrap();
response
.unwrap()
}
}
23 changes: 16 additions & 7 deletions src/global_context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ pub struct GlobalContext {
pub config_dir: PathBuf,
pub caps: Option<Arc<StdRwLock<CodeAssistantCaps>>>,
pub caps_reading_lock: Arc<AMutex<bool>>,
pub caps_last_error: String,
pub caps_last_error: Option<ScratchError>,
pub caps_last_attempted_ts: u64,
pub tokenizer_map: HashMap< String, Arc<StdRwLock<Tokenizer>>>,
pub tokenizer_download_lock: Arc<AMutex<bool>>,
Expand Down Expand Up @@ -229,7 +229,10 @@ pub async fn try_load_caps_quickly_if_not_present(
}
if caps_last_attempted_ts + CAPS_RELOAD_BACKOFF > now {
let gcx_locked = gcx.write().await;
return Err(ScratchError::new(StatusCode::INTERNAL_SERVER_ERROR, gcx_locked.caps_last_error.clone()));
info!("Returning caps error from cache");
return Err(gcx_locked.caps_last_error.clone().unwrap_or(
ScratchError::new_internal("Expected cached error, but none found".to_string()),
));
}
}

Expand All @@ -244,13 +247,19 @@ pub async fn try_load_caps_quickly_if_not_present(
match caps_result {
Ok(caps) => {
gcx_locked.caps = Some(caps.clone());
gcx_locked.caps_last_error = "".to_string();
gcx_locked.caps_last_error = None;
Ok(caps)
},
}
Err(e) => {
error!("caps fetch failed: {:?}", e);
gcx_locked.caps_last_error = format!("caps fetch failed: {}", e);
return Err(ScratchError::new(StatusCode::INTERNAL_SERVER_ERROR, gcx_locked.caps_last_error.clone()));
gcx_locked.caps_last_error = Some(ScratchError::new(
e.status_code,
format!("caps fetch failed: {}", e.message),
));
Err(gcx_locked
.caps_last_error
.clone()
.expect("The previous line is assigning it"))
}
}
}
Expand Down Expand Up @@ -355,7 +364,7 @@ pub async fn create_global_context(
config_dir,
caps: None,
caps_reading_lock: Arc::new(AMutex::<bool>::new(false)),
caps_last_error: String::new(),
caps_last_error: None,
caps_last_attempted_ts: 0,
tokenizer_map: HashMap::new(),
tokenizer_download_lock: Arc::new(AMutex::<bool>::new(false)),
Expand Down
19 changes: 6 additions & 13 deletions src/http/routers/v1/caps.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
use std::sync::Arc;
use tokio::sync::RwLock as ARwLock;

use axum::Extension;
use axum::response::Result;
use hyper::{Body, Response, StatusCode};
use axum::{response::Result, Extension};
use hyper::{Body, Response};

use crate::custom_error::ScratchError;
use crate::global_context::GlobalContext;
Expand All @@ -25,16 +24,10 @@ pub async fn handle_v1_caps(
Extension(global_context): Extension<Arc<ARwLock<GlobalContext>>>,
_: hyper::body::Bytes,
) -> Result<Response<Body>, ScratchError> {
let caps_result = crate::global_context::try_load_caps_quickly_if_not_present(
global_context.clone(),
0,
).await;
let caps_arc = match caps_result {
Ok(x) => x,
Err(e) => {
return Err(ScratchError::new(StatusCode::SERVICE_UNAVAILABLE, format!("{}", e)));
}
};
let caps_arc =
crate::global_context::try_load_caps_quickly_if_not_present(global_context.clone(), 0)
.await?;

let caps_locked = caps_arc.read().unwrap();
let body = serde_json::to_string_pretty(&*caps_locked).unwrap();
let response = Response::builder()
Expand Down