diff --git a/server/src/ollama_client.rs b/server/src/ollama_client.rs
index 46ddfe0..52c6ee0 100644
--- a/server/src/ollama_client.rs
+++ b/server/src/ollama_client.rs
@@ -84,6 +84,7 @@ pub async fn send_to_ollama(
     prompt: String,
     model: String,
     ollama_host: String,
+    num_predict: Option<i32>,
 ) -> Result>, reqwest::Error> {
     tracing::info!("Sending images to Ollama: {:?}", images_batch.len());
     let images_b64: Vec<String> = images_batch
@@ -96,6 +97,7 @@ pub async fn send_to_ollama(
         "images": images_b64,
         "model": model.clone(),
         "stream": true,
+        "num_predict": num_predict.unwrap_or(16),
     });

     let url = format!("{}/api/generate", ollama_host);
diff --git a/server/src/stream.rs b/server/src/stream.rs
index b7d8bbe..463cae8 100644
--- a/server/src/stream.rs
+++ b/server/src/stream.rs
@@ -2,6 +2,7 @@ use actix_web::{rt, web, Error, HttpRequest, HttpResponse};
 use actix_ws::Message;
 use futures::{select, FutureExt};
 use futures_util::StreamExt as _;
+use serde_json::Value;
 use tokio::time::{self, Duration, Instant};

 use crate::{config, ollama_client::{send_to_ollama}};
@@ -12,10 +13,11 @@ async fn proxy_ollama_response(
     prompt: String,
     model: String,
     ollama_host: String,
+    num_predict: i32,
 ) {
     let mut session_clone = session.clone();
     rt::spawn(async move {
-        match send_to_ollama(images_batch, prompt, model, ollama_host).await {
+        match send_to_ollama(images_batch, prompt, model, ollama_host, Some(num_predict)).await {
             Ok(mut rx) => {
                 while let Some(res) = rx.next().await {
                     match res {
@@ -43,16 +45,16 @@ async fn handle_binary(
     images: &mut Vec<Vec<u8>>,
     bin: bytes::Bytes,
     session: &mut actix_ws::Session,
-    last_prompt: &Option<String>,
+    last_prompt: &Option<(String, i32)>,
     model: String,
     ollama_host: String,
     image_batch_size: usize,
 ) {
     images.push(bin.to_vec());
     if images.len() >= image_batch_size {
-        if let Some(prompt) = last_prompt.clone() {
+        if let Some((prompt, num_predict)) = last_prompt.clone() {
             let images_batch = std::mem::take(images);
-            proxy_ollama_response(session, images_batch, prompt, model, ollama_host).await;
+            proxy_ollama_response(session, images_batch, prompt, model, ollama_host, num_predict).await;
         } else {
             let _ = session.text("No prompt received for image batch".to_string()).await;
         }
@@ -64,16 +66,24 @@ async fn handle_text(
     session: &mut actix_ws::Session,
     images: &mut Vec<Vec<u8>>,
     text: String,
-    last_prompt: &mut Option<String>,
+    last_prompt: &mut Option<(String, i32)>,
     model: String,
     ollama_host: String,
 ) {
-    *last_prompt = Some(text.clone());
+    let (prompt, num_predict) = if let Ok(value) = serde_json::from_str::<Value>(&text) {
+        let prompt = value["prompt"].as_str().unwrap_or(&text).to_string();
+        let num_predict = value["num_predict"].as_i64().map(|n| n as i32).unwrap_or(16);
+        (prompt, num_predict)
+    } else {
+        (text.clone(), 16)
+    };
+
+    *last_prompt = Some((prompt.clone(), num_predict));

     // If prompt updated, send the images to Ollama
     if !images.is_empty() {
         let images_batch = std::mem::take(images);
-        proxy_ollama_response(session, images_batch, text, model, ollama_host).await;
+        proxy_ollama_response(session, images_batch, prompt, model, ollama_host, num_predict).await;
     }
 }

@@ -92,7 +102,7 @@ pub async fn ws_index(req: HttpRequest, stream: web::Payload, vlm_config: web::D

     rt::spawn(async move {
         let mut images: Vec<Vec<u8>> = Vec::new();
-        let mut last_prompt: Option<String> = None;
+        let mut last_prompt: Option<(String, i32)> = None;
         let mut inference_interval = time::interval_at(Instant::now() + Duration::from_secs(30), Duration::from_secs(10));
         inference_interval.set_missed_tick_behavior(time::MissedTickBehavior::Skip);

@@ -141,9 +151,11 @@ pub async fn ws_index(req: HttpRequest, stream: web::Payload, vlm_config: web::D
                         tracing::error!("Error sending ping: {:?}", e);
                         break;
                     }
-                    if let Some(prompt) = last_prompt.clone() {
-                        tracing::info!("Inference interval fired: running handle_text with last prompt");
-                        handle_text(&mut session, &mut images, prompt.clone(), &mut last_prompt, model.clone(), ollama_host.clone()).await;
+                    if let Some((prompt, num_predict)) = last_prompt.clone() {
+                        if !images.is_empty() {
+                            let images_batch = std::mem::take(&mut images);
+                            proxy_ollama_response(&mut session, images_batch, prompt, model.clone(), ollama_host.clone(), num_predict).await;
+                        }
                     }
                 }
             }
diff --git a/worker/__pycache__/vlm.cpython-311.pyc b/worker/__pycache__/vlm.cpython-311.pyc
new file mode 100644
index 0000000..b7d36f5
Binary files /dev/null and b/worker/__pycache__/vlm.cpython-311.pyc differ
diff --git a/worker/vlm.py b/worker/vlm.py
index 33a666a..e2d022a 100644
--- a/worker/vlm.py
+++ b/worker/vlm.py
@@ -74,17 +74,17 @@ def parse_image_timestamp(image_path):
         return match.group('ts')
     return ""

-def run_inference(vlm_prompt, prompt, image_paths):
+def run_inference(vlm_prompt, prompt, image_paths, temperature: float = 0.2, num_predict: int = 16):
     start_image = None
     end_image = None

     # Get model names from environment variables
     vlm_model = os.environ.get("VLM_MODEL", "llava:7b")
     llm_model = os.environ.get("LLM_MODEL", "llama3:latest")

-    logger.info("Using VLM model: " + vlm_model)
+    logger.info("Using VLM model: " + vlm_model + " with temperature=" + str(temperature) + " and num_predict=" + str(num_predict))
     logger.info("Using LLM model: " + llm_model)

     # Ensure models are available
     ensure_model_available(vlm_model)
     ensure_model_available(llm_model)
@@ -98,6 +98,10 @@ def run_inference(vlm_prompt, prompt, image_paths):
             model=vlm_model,
             prompt=vlm_prompt,
             images=[image_path],
+            options={
+                "temperature": float(temperature),
+                "num_predict": num_predict
+            },
         )
         results += '"' + parse_image_id(image_path) + '",' + '"' + parse_image_timestamp(image_path) + '",' + '"' + res.response + '"\n'
         logger.info("Inference output: " + str(res))
@@ -183,7 +187,9 @@ def run(conn, job, input_dir, output_dir):
         return

     try:
-        results = run_inference(inputs['vlm_prompt'], inputs['prompt'], image_paths)
+        temperature = float(inputs.get('temperature', os.environ.get("VLM_TEMPERATURE", 0.2)))
+        num_predict = int(inputs.get('num_predict', os.environ.get("VLM_NUM_PREDICT", 16)))
+        results = run_inference(inputs['vlm_prompt'], inputs['prompt'], image_paths, temperature=temperature, num_predict=num_predict)
     except Exception as e:
         logger.error("Error processing job", extra={"job_id": job['id'], "error": str(e)})
         err = {