diff --git a/Cargo.lock b/Cargo.lock index 2628020d..be47fb6f 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2571,6 +2571,7 @@ dependencies = [ "mime_guess", "pingora", "ratatui", + "regex", "reqwest", "reqwest-eventsource", "rust-embed", diff --git a/Cargo.toml b/Cargo.toml index 042e5f5a..86d03cce 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -32,6 +32,7 @@ serde_json = "1.0.140" tokio = { version = "1.45.1", features = ["full"] } tokio-stream = { version = "0.1.17", features = ["sync"] } url = { version = "2.5.4", features = ["serde"] } +regex = "1.11.1" chrono = { version = "0.4.41", optional = true } crossterm = { version = "0.28.1", features = ["event-stream"], optional = true } diff --git a/resources/ts/components/AgentsList.tsx b/resources/ts/components/AgentsList.tsx index 8b41aee9..1c0cbcfb 100644 --- a/resources/ts/components/AgentsList.tsx +++ b/resources/ts/components/AgentsList.tsx @@ -1,5 +1,5 @@ import clsx from "clsx"; -import React, { CSSProperties } from "react"; +import React, { CSSProperties, useState } from "react"; import { type Agent } from "../schemas/Agent"; @@ -9,27 +9,145 @@ import { agentUsage, agentUsage__progress, agentsTable, + sortIndicator, + sortIndicatorAsc, + sortIndicatorDesc, } from "./Dashboard.module.css"; function formatTimestamp(timestamp: number): string { return new Date(timestamp * 1000).toLocaleString(); } +type SortColumn = + | "name" + | "model" + | "issues" + | "llamacppAddr" + | "lastUpdate" + | "idleSlots" + | "processingSlots"; + +function getSortIndicator( + sortConfig: { key: SortColumn; direction: "ascending" | "descending" }, + currentKey: SortColumn +): React.ReactNode { + if (sortConfig.key !== currentKey) { + return null; + } + const className = clsx(sortIndicator, sortConfig.direction === "ascending" ? sortIndicatorAsc : sortIndicatorDesc); + return ( + + {sortConfig.direction === "ascending" ? "↑" : "↓"} + + ); +} + export function AgentsList({ agents }: { agents: Array }) { + const [sortConfig, setSortConfig] = useState<{ + key: SortColumn; + direction: "ascending" | "descending"; + }>({ key: "name", direction: "ascending" }); + + function sortAgents(agents: Array): Array { + const sortableAgents = [...agents]; + sortableAgents.sort(function (a, b) { + const { key, direction } = sortConfig; + + // Helper function to get comparison value based on column type + function getValue(agent: Agent, key: SortColumn): string | number { + switch (key) { + case "name": + return agent.status.agent_name || ""; + case "model": + return agent.status.model || ""; + case "llamacppAddr": + return agent.status.external_llamacpp_addr; + case "lastUpdate": + return agent.last_update.secs_since_epoch; + case "idleSlots": + return agent.status.slots_idle; + case "processingSlots": + return agent.status.slots_processing; + default: + return ""; + } + } + + // Special handling for issues column + if (key === "issues") { + const hasIssuesA = a.status.error !== null; + const hasIssuesB = b.status.error !== null; + if (hasIssuesA !== hasIssuesB) { + return direction === "ascending" ? (hasIssuesA ? 1 : -1) : (hasIssuesA ? -1 : 1); + } + const errorA = a.status.error || ""; + const errorB = b.status.error || ""; + if (errorA < errorB) return direction === "ascending" ? -1 : 1; + if (errorA > errorB) return direction === "ascending" ? 1 : -1; + return 0; + } + + const valueA = getValue(a, key); + const valueB = getValue(b, key); + + // Handle string comparison + if (typeof valueA === "string" && typeof valueB === "string") { + if (valueA < valueB) return direction === "ascending" ? -1 : 1; + if (valueA > valueB) return direction === "ascending" ? 1 : -1; + return 0; + } + + // Handle numeric comparison + if (typeof valueA === "number" && typeof valueB === "number") { + if (valueA < valueB) return direction === "ascending" ? -1 : 1; + if (valueA > valueB) return direction === "ascending" ? 1 : -1; + return 0; + } + + return 0; + }); + return sortableAgents; + } + + function requestSort(key: SortColumn) { + let direction: "ascending" | "descending" = "ascending"; + if (sortConfig.key === key && sortConfig.direction === "ascending") { + direction = "descending"; + } + setSortConfig({ key, direction }); + } + + const sortedAgents = sortAgents(agents); + return ( - - - - - - + + + + + + + - {agents.map(function ({ + {sortedAgents.map(function ({ agent_id, last_update, quarantined_until, @@ -47,13 +165,12 @@ export function AgentsList({ agents }: { agents: Array }) { quarantined_until; return ( - + +
NameIssuesLlama.cpp addressLast updateIdle slotsProcessing slots + Name{getSortIndicator(sortConfig, "name")} + + Model{getSortIndicator(sortConfig, "model")} + + Issues{getSortIndicator(sortConfig, "issues")} + + Llama.cpp address{getSortIndicator(sortConfig, "llamacppAddr")} + + Last update{getSortIndicator(sortConfig, "lastUpdate")} + + Idle slots{getSortIndicator(sortConfig, "idleSlots")} + + Processing slots{getSortIndicator(sortConfig, "processingSlots")} +
{status.agent_name}{status.model} {status.error && ( <> diff --git a/resources/ts/components/Dashboard.module.css b/resources/ts/components/Dashboard.module.css index 86f49b91..85458e4e 100644 --- a/resources/ts/components/Dashboard.module.css +++ b/resources/ts/components/Dashboard.module.css @@ -7,11 +7,29 @@ th { border: 1px solid var(--color-border); padding: var(--spacing-base); + cursor: pointer; p + p { margin-top: var(--spacing-half); } } + + th:hover { + background-color: var(--color-hover); + } +} + +.sortIndicator { + margin-left: var(--spacing-half); + font-size: 0.8em; +} + +.sortIndicatorAsc { + color: var(--color-success); +} + +.sortIndicatorDesc { + color: var(--color-error); } .agentRow.agentRowError { diff --git a/resources/ts/schemas/Agent.ts b/resources/ts/schemas/Agent.ts index 678bd124..c2eea84d 100644 --- a/resources/ts/schemas/Agent.ts +++ b/resources/ts/schemas/Agent.ts @@ -5,6 +5,7 @@ import { StatusUpdateSchema } from "./StatusUpdate"; export const AgentSchema = z .object({ agent_id: z.string(), + model: z.string().nullable(), last_update: z.object({ nanos_since_epoch: z.number(), secs_since_epoch: z.number(), diff --git a/resources/ts/schemas/StatusUpdate.ts b/resources/ts/schemas/StatusUpdate.ts index 7543d925..52747e5b 100644 --- a/resources/ts/schemas/StatusUpdate.ts +++ b/resources/ts/schemas/StatusUpdate.ts @@ -14,6 +14,7 @@ export const StatusUpdateSchema = z is_unexpected_response_status: z.boolean().nullable(), slots_idle: z.number(), slots_processing: z.number(), + model: z.string().nullable(), }) .strict(); diff --git a/src/agent/monitoring_service.rs b/src/agent/monitoring_service.rs index 4837493f..1eb14658 100644 --- a/src/agent/monitoring_service.rs +++ b/src/agent/monitoring_service.rs @@ -23,6 +23,7 @@ pub struct MonitoringService { monitoring_interval: Duration, name: Option, status_update_tx: Sender, + check_model: bool, // Store the check_model flag } impl MonitoringService { @@ -32,6 +33,7 @@ impl MonitoringService { monitoring_interval: Duration, name: Option, status_update_tx: Sender, + check_model: bool, // Include the check_model flag ) -> Result { Ok(MonitoringService { external_llamacpp_addr, @@ -39,6 +41,7 @@ impl MonitoringService { monitoring_interval, name, status_update_tx, + check_model, }) } @@ -50,6 +53,15 @@ impl MonitoringService { .filter(|slot| slot.is_processing) .count(); + let model: Option = if self.check_model { + match self.llamacpp_client.get_model().await { + Ok(model) => model, + Err(_) => None, + } + } else { + Some("".to_string()) + }; + StatusUpdate { agent_name: self.name.to_owned(), error: slots_response.error, @@ -63,6 +75,7 @@ impl MonitoringService { is_unexpected_response_status: slots_response.is_unexpected_response_status, slots_idle: slots_response.slots.len() - slots_processing, slots_processing, + model, } } @@ -109,4 +122,4 @@ impl Service for MonitoringService { fn threads(&self) -> Option { Some(1) } -} +} \ No newline at end of file diff --git a/src/balancer/proxy_service.rs b/src/balancer/proxy_service.rs index cfab2dfd..9eeb1a54 100644 --- a/src/balancer/proxy_service.rs +++ b/src/balancer/proxy_service.rs @@ -6,6 +6,7 @@ use std::time::Duration; use async_trait::async_trait; use bytes::Bytes; use log::error; +use log::info; use pingora::http::RequestHeader; use pingora::proxy::ProxyHttp; use pingora::proxy::Session; @@ -41,6 +42,7 @@ pub struct ProxyService { buffered_request_timeout: Duration, max_buffered_requests: usize, rewrite_host_header: bool, + check_model: bool, slots_endpoint_enable: bool, upstream_peer_pool: Arc, } @@ -48,6 +50,7 @@ pub struct ProxyService { impl ProxyService { pub fn new( rewrite_host_header: bool, + check_model: bool, slots_endpoint_enable: bool, upstream_peer_pool: Arc, buffered_request_timeout: Duration, @@ -55,6 +58,7 @@ impl ProxyService { ) -> Self { Self { rewrite_host_header, + check_model, slots_endpoint_enable, upstream_peer_pool, buffered_request_timeout, @@ -73,6 +77,7 @@ impl ProxyHttp for ProxyService { slot_taken: false, upstream_peer_pool: self.upstream_peer_pool.clone(), uses_slots: false, + requested_model: Some("".to_string()), } } @@ -180,10 +185,108 @@ impl ProxyHttp for ProxyService { } "/chat/completions" => true, "/completion" => true, + "/v1/completions" => true, "/v1/chat/completions" => true, _ => false, }; + info!("upstream_peer - {:?} request | rewrite_host_header? {} check_model? {}", session.req_header().method, self.rewrite_host_header, self.check_model); + + // Check if the request method is POST and the content type is JSON + if self.check_model && ctx.uses_slots { + info!("Checking model..."); + ctx.requested_model = None; + if session.req_header().method == "POST" { + // Check if the content type is application/json + if let Some(content_type) = session.get_header("Content-Type") { + if let Ok(content_type_str) = content_type.to_str() { + if content_type_str.contains("application/json") { + // Enable retry buffering to preserve the request body, reference: https://github.com/cloudflare/pingora/issues/349#issuecomment-2377277028 + session.enable_retry_buffering(); + session.read_body_or_idle(false).await.unwrap().unwrap(); + let request_body = session.get_retry_buffer(); + + if let Some(body_bytes) = request_body { + match std::str::from_utf8(&body_bytes) { + Ok(_) => { + // The bytes are valid UTF-8, proceed as normal + if let Ok(json_value) = serde_json::from_slice::(&body_bytes) { + if let Some(model) = json_value.get("model").and_then(|v| v.as_str()) { + ctx.requested_model = Some(model.to_string()); + info!("Model in request: {:?}", ctx.requested_model); + } + } else { + info!("Failed to parse JSON payload, trying regex extraction"); + let body_str = String::from_utf8_lossy(&body_bytes).to_string(); + let re = regex::Regex::new(r#""model"\s*:\s*["']([^"']*)["']"#).unwrap(); + if let Some(caps) = re.captures(&body_str) { + if let Some(model) = caps.get(1) { + ctx.requested_model = Some(model.as_str().to_string()); + info!("Model via regex: {:?}", ctx.requested_model); + } + } else { + info!("Failed to extract model using regex"); + } + } + }, + Err(e) => { + // Invalid UTF-8 detected. Truncate to the last valid UTF-8 boundary. + let valid_up_to = e.valid_up_to(); + info!("Invalid UTF-8 detected. Truncating from {} bytes to {} bytes.", body_bytes.len(), valid_up_to); + + // Create a new `Bytes` slice containing only the valid UTF-8 part. + let valid_body_bytes = body_bytes.slice(0..valid_up_to); + + // Now proceed with the (truncated) valid_body_bytes + if let Ok(json_value) = serde_json::from_slice::(&valid_body_bytes) { + if let Some(model) = json_value.get("model").and_then(|v| v.as_str()) { + ctx.requested_model = Some(model.to_string()); + info!("Model in request (after truncation): {:?}", ctx.requested_model); + } + } else { + info!("Failed to parse JSON payload (after truncation), trying regex extraction"); + let body_str = String::from_utf8_lossy(&valid_body_bytes).to_string(); + let re = regex::Regex::new(r#""model"\s*:\s*["']([^"']*)["']"#).unwrap(); + if let Some(caps) = re.captures(&body_str) { + if let Some(model) = caps.get(1) { + ctx.requested_model = Some(model.as_str().to_string()); + info!("Model via regex (after truncation): {:?}", ctx.requested_model); + } + } else { + info!("Failed to extract model using regex (after truncation)"); + } + } + } + } + } else { + info!("Request body is None"); + } + } + } + } + } + // abort if model has not been set + if ctx.requested_model == None { + info!("Model missing in request"); + session + .respond_error(pingora::http::StatusCode::BAD_REQUEST.as_u16()) + .await?; + + return Err(Error::new_down(pingora::ErrorType::ConnectRefused)); + } + else if ctx.has_peer_supporting_model() == false { + info!("Model {:?} not supported by upstream", ctx.requested_model); + session + .respond_error(pingora::http::StatusCode::NOT_FOUND.as_u16()) + .await?; + + return Err(Error::new_down(pingora::ErrorType::ConnectRefused)); + } + else { + info!("Model {:?}", ctx.requested_model); + } + } + let peer = tokio::select! { result = async { loop { diff --git a/src/balancer/request_context.rs b/src/balancer/request_context.rs index eb8b1c11..da308717 100644 --- a/src/balancer/request_context.rs +++ b/src/balancer/request_context.rs @@ -2,6 +2,7 @@ use std::sync::Arc; use anyhow::anyhow; use log::error; +use log::info; use pingora::Error; use pingora::Result; @@ -13,6 +14,7 @@ pub struct RequestContext { pub selected_peer: Option, pub upstream_peer_pool: Arc, pub uses_slots: bool, + pub requested_model: Option, } impl RequestContext { @@ -30,16 +32,19 @@ impl RequestContext { } } - pub fn use_best_peer_and_take_slot(&mut self) -> anyhow::Result> { + pub fn use_best_peer_and_take_slot(&mut self, model: Option) -> anyhow::Result> { if let Some(peer) = self.upstream_peer_pool.with_agents_write(|agents| { + let model_str = model.as_deref().unwrap_or(""); for peer in agents.iter_mut() { - if peer.is_usable() { - peer.take_slot()?; + let is_usable = peer.is_usable(); + let is_usable_for_model = peer.is_usable_for_model(model_str); + if is_usable && (model.is_none() || is_usable_for_model) { + info!("Peer {} is usable: {}, usable for model '{}': {}", peer.agent_id, is_usable, model_str, is_usable_for_model); + peer.take_slot()?; return Ok(Some(peer.clone())); } } - Ok(None) })? { self.upstream_peer_pool.restore_integrity()?; @@ -52,11 +57,26 @@ impl RequestContext { } } + pub fn has_peer_supporting_model(&self) -> bool { + let model_str = self.requested_model.as_deref().unwrap_or(""); + match self.upstream_peer_pool.with_agents_read(|agents| { + for peer in agents.iter() { + if peer.supports_model(model_str) { + return Ok(true); + } + } + Ok(false) + }) { + Ok(result) => result, + Err(_) => false, // or handle the error as needed + } + } + pub fn select_upstream_peer(&mut self) -> Result<()> { let result_option_peer = if self.uses_slots && !self.slot_taken { - self.use_best_peer_and_take_slot() + self.use_best_peer_and_take_slot(self.requested_model.clone()) } else { - self.upstream_peer_pool.use_best_peer() + self.upstream_peer_pool.use_best_peer(self.requested_model.clone()) }; self.selected_peer = match result_option_peer { @@ -95,6 +115,7 @@ mod tests { selected_peer: None, upstream_peer_pool, uses_slots: true, + requested_model: Some("llama3".to_string()), } } @@ -105,7 +126,7 @@ mod tests { pool.register_status_update("test_agent", mock_status_update("test_agent", 0, 0))?; - assert!(ctx.use_best_peer_and_take_slot().unwrap().is_none()); + assert!(ctx.use_best_peer_and_take_slot(ctx.requested_model.clone()).unwrap().is_none()); assert!(!ctx.slot_taken); assert_eq!(ctx.selected_peer, None); diff --git a/src/balancer/status_update.rs b/src/balancer/status_update.rs index 6f2c5409..5e811640 100644 --- a/src/balancer/status_update.rs +++ b/src/balancer/status_update.rs @@ -20,6 +20,7 @@ pub struct StatusUpdate { pub is_unexpected_response_status: Option, pub slots_idle: usize, pub slots_processing: usize, + pub model: Option, } impl StatusUpdate { diff --git a/src/balancer/test/mock_status_update.rs b/src/balancer/test/mock_status_update.rs index 07d12a1f..8cc9da41 100644 --- a/src/balancer/test/mock_status_update.rs +++ b/src/balancer/test/mock_status_update.rs @@ -22,5 +22,6 @@ pub fn mock_status_update( is_unexpected_response_status: Some(false), slots_idle, slots_processing, + model: Some("llama3".to_string()), } } diff --git a/src/balancer/upstream_peer.rs b/src/balancer/upstream_peer.rs index a8eaff95..24b52d62 100644 --- a/src/balancer/upstream_peer.rs +++ b/src/balancer/upstream_peer.rs @@ -13,6 +13,7 @@ use crate::balancer::status_update::StatusUpdate; #[derive(Clone, Debug, Eq, Serialize, Deserialize)] pub struct UpstreamPeer { pub agent_id: String, + pub model: Option, pub last_update: SystemTime, pub quarantined_until: Option, pub slots_taken: usize, @@ -24,6 +25,7 @@ impl UpstreamPeer { pub fn new_from_status_update(agent_id: String, status: StatusUpdate) -> Self { Self { agent_id, + model: status.model.clone(), last_update: SystemTime::now(), quarantined_until: None, slots_taken: 0, @@ -36,6 +38,14 @@ impl UpstreamPeer { !self.status.has_issues() && self.status.slots_idle > 0 && self.quarantined_until.is_none() } + pub fn supports_model(&self, requested_model: &str) -> bool { + requested_model.is_empty() || self.model.as_deref() == Some(requested_model) + } + + pub fn is_usable_for_model(&self, requested_model: &str) -> bool { + self.is_usable() && (requested_model.is_empty() || self.model.as_deref() == Some(requested_model)) + } + pub fn release_slot(&mut self) -> Result<()> { if self.slots_taken < 1 { return Err(anyhow!( @@ -59,6 +69,7 @@ impl UpstreamPeer { self.last_update = SystemTime::now(); self.quarantined_until = None; self.slots_taken_since_last_status_update = 0; + self.model = status_update.model.clone(); self.status = status_update; } @@ -110,6 +121,7 @@ mod tests { fn create_test_peer() -> UpstreamPeer { UpstreamPeer { agent_id: "test_agent".to_string(), + model: Some("llama3".to_string()), last_update: SystemTime::now(), quarantined_until: None, slots_taken: 0, @@ -130,6 +142,7 @@ mod tests { is_unexpected_response_status: None, slots_idle: 5, slots_processing: 0, + model: Some("llama3".to_string()), }, } } @@ -177,7 +190,6 @@ mod tests { #[test] fn test_update_status() { let mut peer = create_test_peer(); - let slots: Vec = vec![]; let slots_idle = slots.iter().filter(|slot| !slot.is_processing).count(); @@ -194,6 +206,7 @@ mod tests { is_unexpected_response_status: None, slots_idle, slots_processing: slots.len() - slots_idle, + model: Some("llama3".to_string()) }; peer.update_status(status_update); diff --git a/src/balancer/upstream_peer_pool.rs b/src/balancer/upstream_peer_pool.rs index 2e6924a9..4593bab9 100644 --- a/src/balancer/upstream_peer_pool.rs +++ b/src/balancer/upstream_peer_pool.rs @@ -2,6 +2,7 @@ use std::sync::atomic::AtomicUsize; use std::sync::RwLock; use std::time::Duration; use std::time::SystemTime; +use log::info; use anyhow::anyhow; use anyhow::Result; @@ -159,10 +160,15 @@ impl UpstreamPeerPool { }) } - pub fn use_best_peer(&self) -> Result> { - self.with_agents_read(|agents| { + pub fn use_best_peer(&self, model: Option) -> Result> { + self.with_agents_write(|agents| { for peer in agents.iter() { - if peer.is_usable() { + let model_str = model.as_deref().unwrap_or(""); + let is_usable = peer.is_usable(); + let is_usable_for_model = peer.is_usable_for_model(model_str); + + if is_usable && (model.is_none() || is_usable_for_model) { + info!("Peer {} is usable: {}, usable for model '{}': {}", peer.agent_id, is_usable, model_str, is_usable_for_model); return Ok(Some(peer.clone())); } } @@ -263,7 +269,7 @@ mod tests { pool.register_status_update("test2", mock_status_update("test2", 3, 0))?; pool.register_status_update("test3", mock_status_update("test3", 0, 0))?; - let best_peer = pool.use_best_peer()?.unwrap(); + let best_peer = pool.use_best_peer(None)?.unwrap(); assert_eq!(best_peer.agent_id, "test1"); assert_eq!(best_peer.status.slots_idle, 5); diff --git a/src/cmd/agent.rs b/src/cmd/agent.rs index aef1b711..578f5dbe 100644 --- a/src/cmd/agent.rs +++ b/src/cmd/agent.rs @@ -18,6 +18,7 @@ pub fn handle( management_addr: SocketAddr, monitoring_interval: Duration, name: Option, + check_model: bool, // Include the check_model flag ) -> Result<()> { let (status_update_tx, _status_update_rx) = channel::(1); @@ -29,6 +30,7 @@ pub fn handle( monitoring_interval, name, status_update_tx.clone(), + check_model, // Pass the check_model flag )?; let reporting_service = ReportingService::new(management_addr, status_update_tx)?; @@ -45,4 +47,4 @@ pub fn handle( pingora_server.add_service(monitoring_service); pingora_server.add_service(reporting_service); pingora_server.run_forever(); -} +} \ No newline at end of file diff --git a/src/cmd/balancer.rs b/src/cmd/balancer.rs index dbca8bc4..f5ac3d4d 100644 --- a/src/cmd/balancer.rs +++ b/src/cmd/balancer.rs @@ -2,6 +2,7 @@ use std::net::SocketAddr; use std::sync::Arc; #[cfg(feature = "statsd_reporter")] use std::time::Duration; +use log::info; use anyhow::Result; use pingora::proxy::http_proxy_service; @@ -24,6 +25,7 @@ pub fn handle( metrics_endpoint_enable: bool, reverseproxy_addr: &SocketAddr, rewrite_host_header: bool, + check_model: bool, slots_endpoint_enable: bool, #[cfg(feature = "statsd_reporter")] statsd_addr: Option, #[cfg(feature = "statsd_reporter")] statsd_prefix: String, @@ -45,6 +47,7 @@ pub fn handle( &pingora_server.configuration, ProxyService::new( rewrite_host_header, + check_model, slots_endpoint_enable, upstream_peer_pool.clone(), buffered_request_timeout, @@ -76,5 +79,7 @@ pub fn handle( pingora_server.add_service(statsd_service); } + info!("rewrite_host_header? {} check_model? {} slots_endpoint_enable? {}", rewrite_host_header, check_model, slots_endpoint_enable); + pingora_server.run_forever(); } diff --git a/src/llamacpp/llamacpp_client.rs b/src/llamacpp/llamacpp_client.rs index 742025f4..4e802164 100644 --- a/src/llamacpp/llamacpp_client.rs +++ b/src/llamacpp/llamacpp_client.rs @@ -1,16 +1,19 @@ use std::net::SocketAddr; use std::time::Duration; +use anyhow::anyhow; use anyhow::Result; use reqwest::header; use url::Url; use crate::llamacpp::slot::Slot; use crate::llamacpp::slots_response::SlotsResponse; +use crate::llamacpp::models_response::ModelsResponse; pub struct LlamacppClient { client: reqwest::Client, slots_endpoint_url: String, + models_endpoint_url: String, } impl LlamacppClient { @@ -36,6 +39,7 @@ impl LlamacppClient { Ok(Self { client: builder.build()?, slots_endpoint_url: Url::parse(&format!("http://{addr}/slots"))?.to_string(), + models_endpoint_url: Url::parse(&format!("http://{addr}/v1/models"))?.to_string(), }) } @@ -115,4 +119,40 @@ impl LlamacppClient { }, } } + + pub async fn get_model(&self) -> Result> { + let url = self.models_endpoint_url.to_owned(); + + let response = match self.client.get(url.clone()).send().await { + Ok(resp) => resp, + Err(err) => { + return Err(anyhow!( + "Request to '{}' failed: '{}'; connect issue: {}; decode issue: {}; request issue: {}; status issue: {}; status: {:?}", + url, + err, + err.is_connect(), + err.is_decode(), + err.is_request(), + err.is_status(), + err.status() + )); + } + }; + + match response.status() { + reqwest::StatusCode::OK => { + let models_response: ModelsResponse = response.json().await?; + if let Some(models) = models_response.models { + if models.is_empty() { + Ok(None) + } else { + Ok(models.first().and_then(|m| Some(m.model.clone()))) + } + } else { + Ok(None) + } + }, + _ => Err(anyhow!("Unexpected response status")), + } + } } diff --git a/src/llamacpp/mod.rs b/src/llamacpp/mod.rs index 2b1e1d24..68f5cae8 100644 --- a/src/llamacpp/mod.rs +++ b/src/llamacpp/mod.rs @@ -1,3 +1,4 @@ pub mod llamacpp_client; pub mod slot; pub mod slots_response; +pub mod models_response; diff --git a/src/llamacpp/models_response.rs b/src/llamacpp/models_response.rs new file mode 100644 index 00000000..d7d04b50 --- /dev/null +++ b/src/llamacpp/models_response.rs @@ -0,0 +1,13 @@ +// paddler/src/llamacpp/models_response.rs +use serde::Deserialize; + +#[derive(Debug, Deserialize)] +pub struct ModelsResponse { + pub models: Option>, +} + +#[derive(Debug, Deserialize)] +pub struct Model { + pub model: String, + // Add other fields as needed +} diff --git a/src/main.rs b/src/main.rs index 007721c3..baa24430 100644 --- a/src/main.rs +++ b/src/main.rs @@ -87,6 +87,10 @@ enum Commands { #[arg(long)] /// Name of the agent (optional) name: Option, + + #[arg(long)] + /// Flag whether to check the model served by llama.cpp and reject requests for other models + check_model: bool, }, /// Balances incoming requests to llama.cpp instances and optionally provides a web dashboard Balancer { @@ -134,6 +138,10 @@ enum Commands { /// Enable the slots endpoint (not recommended) slots_endpoint_enable: bool, + #[arg(long)] + /// Flag to check the model served by llama.cpp and reject requests for other models + check_model: bool, + #[cfg(feature = "statsd_reporter")] #[arg(long, value_parser = parse_socket_addr)] /// Address of the statsd server to report metrics to @@ -171,6 +179,7 @@ fn main() -> Result<()> { management_addr, monitoring_interval, name, + check_model, }) => cmd::agent::handle( match external_llamacpp_addr { Some(addr) => addr.to_owned(), @@ -181,6 +190,7 @@ fn main() -> Result<()> { management_addr.to_owned(), monitoring_interval.to_owned(), name.to_owned(), + *check_model ), Some(Commands::Balancer { buffered_request_timeout, @@ -192,6 +202,7 @@ fn main() -> Result<()> { metrics_endpoint_enable, reverseproxy_addr, rewrite_host_header, + check_model, slots_endpoint_enable, #[cfg(feature = "statsd_reporter")] statsd_addr, @@ -213,6 +224,7 @@ fn main() -> Result<()> { metrics_endpoint_enable.to_owned(), reverseproxy_addr, rewrite_host_header.to_owned(), + *check_model, slots_endpoint_enable.to_owned(), #[cfg(feature = "statsd_reporter")] statsd_addr.to_owned(),