From 874bf453482ee754d90afe9c0492a0c1e9b94586 Mon Sep 17 00:00:00 2001 From: Hubert Nowak Date: Sun, 5 Oct 2025 09:38:43 +0200 Subject: [PATCH] Handle multiple files in parallel --- backend/app/background_tasks.py | 44 +++++++++++++++++---------------- backend/app/transcribe.py | 4 +-- backend/app/upload_api.py | 24 +++++++++--------- backend/requirements.txt | 3 ++- 4 files changed, 40 insertions(+), 35 deletions(-) diff --git a/backend/app/background_tasks.py b/backend/app/background_tasks.py index afc7f44..3deae52 100644 --- a/backend/app/background_tasks.py +++ b/backend/app/background_tasks.py @@ -9,9 +9,11 @@ import requests import string import json +import httpx MODEL_SIZE = os.getenv("WHISPER_MODEL", "base") + def load_whisper(): preferred = os.getenv("WHISPER_DEVICE") # cuda | metal | cpu (optional) @@ -39,10 +41,11 @@ def load_whisper(): raise RuntimeError(f"Whisper init failed. Last error: {last_err}") + whisper_model = load_whisper() -def process_file(original_path: str, patch_duration_sec: int, overlap_sec: int): +async def process_file(original_path: str, patch_duration_sec: int, overlap_sec: int): # Import db session inside background task with tempfile.TemporaryDirectory() as tmpdir: # Work on a copy @@ -57,29 +60,31 @@ def process_file(original_path: str, patch_duration_sec: int, overlap_sec: int): audio_path = working_path # Split audio - patches = split_audio_to_patches(str(audio_path), patch_duration_sec, overlap_sec) + patches = await split_audio_to_patches(str(audio_path), patch_duration_sec, overlap_sec) # Transcribe - all_results = transcribe_patches(patches, whisper_model) + all_results = await transcribe_patches(patches, whisper_model) return all_results -def send_to_llm(transcribed_text: str, default_definitions: list = None, positive_examples: list = None, negative_examples: list = None): +async def send_to_llm(transcribed_text: str, default_definitions: list = None, positive_examples: list = None, negative_examples: list = None): """Send transcribed text to LLM for analysis""" - response = requests.post( - "http://localhost:8001/detect", - json={ - "transcription": transcribed_text, - "default_definitions": default_definitions or [], - "positive_examples": positive_examples or [], - "negative_examples": negative_examples or [] - } - ) + async with httpx.AsyncClient() as client: + response = await client.post( + "http://localhost:8001/detect", + json={ + "transcription": transcribed_text, + "default_definitions": default_definitions or [], + "positive_examples": positive_examples or [], + "negative_examples": negative_examples or [] + }, + timeout=1800 + ) - result = response.json() + result = response.json() - return result + return result def normalize_word(word): @@ -118,17 +123,15 @@ def find_matching_spans(transcribed_patches: dict, llm_spans: list): return processed_spans -def main_background_function(job_id: str, original_path: str, patch_duration_sec: int, overlap_sec: int, db: Session, default_definitions: list = None, positive_examples: list = None, negative_examples: list = None): - +async def main_background_function(job_id: str, original_path: str, patch_duration_sec: int, overlap_sec: int, db: Session, default_definitions: list = None, positive_examples: list = None, negative_examples: list = None): job = db.get(Job, job_id) job.status = "transcribing" db.commit() db.expire_all() - transcribed_patches = process_file(original_path, patch_duration_sec, overlap_sec) + transcribed_patches = await process_file(original_path, patch_duration_sec, overlap_sec) print(f"Got {len(transcribed_patches)} batches") - cleaned_list = [] @@ -149,11 +152,10 @@ def main_background_function(job_id: str, original_path: str, patch_duration_sec for i, batch in enumerate(cleaned_list): print(f"Evaluating {i + 1}/{len(transcribed_patches)} ") - result_from_llm = send_to_llm(batch, default_definitions, positive_examples, negative_examples) + result_from_llm = await send_to_llm(batch, default_definitions, positive_examples, negative_examples) llm_spans = result_from_llm["spans"] processed_spans = find_matching_spans(transcribed_patches[i], llm_spans) all_processed_spans.extend(processed_spans) - # Read the final transcribed text from file with open(Path(original_path).with_suffix('.txt'), 'r', encoding='utf-8') as f: diff --git a/backend/app/transcribe.py b/backend/app/transcribe.py index 4d9fc71..f6e5bf9 100644 --- a/backend/app/transcribe.py +++ b/backend/app/transcribe.py @@ -20,7 +20,7 @@ def convert_video_to_audio(video_path: str, output_audio_path: str) -> str: raise -def split_audio_to_patches(audio_path: str, patch_duration_sec: int = 120, overlap_sec: int = 30): +async def split_audio_to_patches(audio_path: str, patch_duration_sec: int = 120, overlap_sec: int = 30): print(f"[INFO] Loading audio for patching: {audio_path}") y, sr = librosa.load(audio_path, sr=None) total_duration = librosa.get_duration(y=y, sr=sr) @@ -42,7 +42,7 @@ def split_audio_to_patches(audio_path: str, patch_duration_sec: int = 120, overl return patches -def transcribe_patches(patches, model): +async def transcribe_patches(patches, model): all_results = [] for i, patch_path in enumerate(patches): print(f"[INFO] Transcribing patch {i}: {patch_path}") diff --git a/backend/app/upload_api.py b/backend/app/upload_api.py index 3f3a211..d515c1b 100644 --- a/backend/app/upload_api.py +++ b/backend/app/upload_api.py @@ -9,6 +9,7 @@ from typing import List, Optional from app.models import Job, Batch from app.background_tasks import main_background_function +import asyncio router = routing.APIRouter() @@ -51,7 +52,7 @@ def extract_audio_from_zip(zip_file: UploadFile, extract_dir: Path) -> List[Path @router.post("") async def upload_batch( - background_tasks: BackgroundTasks, + # background_tasks: BackgroundTasks, name: str = Form(...), description: Optional[str] = Form(None), default_definitions: str = Form("[]"), @@ -141,16 +142,17 @@ async def upload_batch( db.refresh(job) # Launch background task for each job - background_tasks.add_task( - main_background_function, - str(job.id), - str(file_path), - patch_duration_sec, - overlap_sec, - db, - default_defs_list, - positive_examples_list, - negative_examples_list + asyncio.create_task( + main_background_function( + str(job.id), + str(file_path), + patch_duration_sec, + overlap_sec, + db, + default_defs_list, + positive_examples_list, + negative_examples_list + ) ) return {"batch_id": batch.id} diff --git a/backend/requirements.txt b/backend/requirements.txt index baaf241..2d88db8 100644 --- a/backend/requirements.txt +++ b/backend/requirements.txt @@ -12,4 +12,5 @@ python-multipart ffmpeg-python==0.2.0 pydub==0.25.1 librosa -soundfile \ No newline at end of file +soundfile +httpx \ No newline at end of file