Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions server/src/ollama_client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ pub async fn send_to_ollama(
prompt: String,
model: String,
ollama_host: String,
num_predict: Option<i32>,
) -> Result<mpsc::Receiver<Result<Bytes, reqwest::Error>>, reqwest::Error> {
tracing::info!("Sending images to Ollama: {:?}", images_batch.len());
let images_b64: Vec<String> = images_batch
Expand All @@ -96,6 +97,14 @@ pub async fn send_to_ollama(
"images": images_b64,
"model": model.clone(),
"stream": true,
"num_predict": num_predict.unwrap_or(16),
});

let url = format!("{}/api/generate", ollama_host);
Expand Down
34 changes: 23 additions & 11 deletions server/src/stream.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use actix_web::{rt, web, Error, HttpRequest, HttpResponse};
use actix_ws::Message;
use futures::{select, FutureExt};
use futures_util::StreamExt as _;
use serde_json::Value;
use tokio::time::{self, Duration, Instant};

use crate::{config, ollama_client::{send_to_ollama}};
Expand All @@ -12,10 +13,11 @@ async fn proxy_ollama_response(
prompt: String,
model: String,
ollama_host: String,
num_predict: i32,
) {
let mut session_clone = session.clone();
rt::spawn(async move {
match send_to_ollama(images_batch, prompt, model, ollama_host).await {
match send_to_ollama(images_batch, prompt, model, ollama_host, Some(num_predict)).await {
Ok(mut rx) => {
while let Some(res) = rx.next().await {
match res {
Expand Down Expand Up @@ -43,17 +45,17 @@ async fn handle_binary(
images: &mut Vec<Vec<u8>>,
bin: bytes::Bytes,
session: &mut actix_ws::Session,
last_prompt: &Option<String>,
last_prompt: &Option<(String, i32)>,
model: String,
ollama_host: String,
image_batch_size: usize,
) {
images.push(bin.to_vec());

if images.len() >= image_batch_size {
if let Some(prompt) = last_prompt.clone() {
if let Some((prompt, num_predict)) = last_prompt.clone() {
let images_batch = std::mem::take(images);
proxy_ollama_response(session, images_batch, prompt, model, ollama_host).await;
proxy_ollama_response(session, images_batch, prompt, model, ollama_host, num_predict).await;
} else {
let _ = session.text("No prompt received for image batch".to_string()).await;
}
Expand All @@ -64,16 +66,24 @@ async fn handle_text(
session: &mut actix_ws::Session,
images: &mut Vec<Vec<u8>>,
text: String,
last_prompt: &mut Option<String>,
last_prompt: &mut Option<(String, i32)>,
model: String,
ollama_host: String,
) {
*last_prompt = Some(text.clone());
let (prompt, num_predict) = if let Ok(value) = serde_json::from_str::<Value>(&text) {
let prompt = value["prompt"].as_str().unwrap_or(&text).to_string();
let num_predict = value["num_predict"].as_i64().map(|n| n as i32).unwrap_or(16);
(prompt, num_predict)
} else {
(text.clone(), 16)
};

*last_prompt = Some((prompt.clone(), num_predict));

// If prompt updated, send the images to Ollama
if !images.is_empty() {
let images_batch = std::mem::take(images);
proxy_ollama_response(session, images_batch, text, model, ollama_host).await;
proxy_ollama_response(session, images_batch, prompt, model, ollama_host, num_predict).await;
}
}

Expand All @@ -92,7 +102,7 @@ pub async fn ws_index(req: HttpRequest, stream: web::Payload, vlm_config: web::D

rt::spawn(async move {
let mut images: Vec<Vec<u8>> = Vec::new();
let mut last_prompt: Option<String> = None;
let mut last_prompt: Option<(String, i32)> = None;

let mut inference_interval = time::interval_at(Instant::now() + Duration::from_secs(30), Duration::from_secs(10));
inference_interval.set_missed_tick_behavior(time::MissedTickBehavior::Skip);
Expand Down Expand Up @@ -141,9 +151,11 @@ pub async fn ws_index(req: HttpRequest, stream: web::Payload, vlm_config: web::D
tracing::error!("Error sending ping: {:?}", e);
break;
}
if let Some(prompt) = last_prompt.clone() {
tracing::info!("Inference interval fired: running handle_text with last prompt");
handle_text(&mut session, &mut images, prompt.clone(), &mut last_prompt, model.clone(), ollama_host.clone()).await;
if let Some((prompt, num_predict)) = last_prompt.clone() {
if !images.is_empty() {
let images_batch = std::mem::take(&mut images);
proxy_ollama_response(&mut session, images_batch, prompt, model.clone(), ollama_host.clone(), num_predict).await;
}
}
}
}
Expand Down
Binary file added worker/__pycache__/vlm.cpython-311.pyc
Binary file not shown.
21 changes: 20 additions & 1 deletion worker/vlm.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,18 +74,25 @@ def parse_image_timestamp(image_path):
return match.group('ts')
return ""

def run_inference(vlm_prompt, prompt, image_paths):
def run_inference(vlm_prompt, prompt, image_paths, temperature: float = 0.2, num_predict: int = 16):
start_image = None
end_image = None

# Get model names from environment variables
vlm_model = os.environ.get("VLM_MODEL", "llava:7b")
llm_model = os.environ.get("LLM_MODEL", "llama3:latest")

logger.info("Using VLM model: " + vlm_model + " with temperature=" + str(temperature) + " and num_predict=" + str(num_predict))
logger.info("Using LLM model: " + llm_model)

# Ensure models are available
ensure_model_available(vlm_model)
ensure_model_available(llm_model)

Expand All @@ -98,6 +105,13 @@ def run_inference(vlm_prompt, prompt, image_paths):
model=vlm_model,
prompt=vlm_prompt,
images=[image_path],
options={
"temperature": float(temperature),
"num_predict": num_predict
},
)
results += '"' + parse_image_id(image_path) + '",' + '"' + parse_image_timestamp(image_path) + '",' + '"' + res.response + '"\n'
logger.info("Inference output: " + str(res))
Expand Down Expand Up @@ -183,7 +197,12 @@ def run(conn, job, input_dir, output_dir):
return

try:
temperature = float(inputs.get('temperature', os.environ.get("VLM_TEMPERATURE", 0.2)))
num_predict = int(inputs.get('num_predict', os.environ.get("VLM_NUM_PREDICT", 16)))
results = run_inference(inputs['vlm_prompt'], inputs['prompt'], image_paths, temperature=temperature, num_predict=num_predict)
except Exception as e:
logger.error("Error processing job", extra={"job_id": job['id'], "error": str(e)})
err = {
Expand Down