From 64fbc3ef3eb428642adaa0875d93360915bfa3e3 Mon Sep 17 00:00:00 2001 From: Lee Penkman Date: Sun, 18 May 2025 17:34:15 +1200 Subject: [PATCH] Switch to NVIDIA Parakeet ASR --- .github/workflows/python-tests.yml | 21 ++++++++ README.md | 6 +-- .../inference_server/inference_server.py | 50 +++++++++---------- questions/inference_server/requirements.in | 2 +- questions/inference_server/requirements.txt | 3 +- tests/unit/test_audio_model.py | 20 ++++++++ 6 files changed, 71 insertions(+), 31 deletions(-) create mode 100644 .github/workflows/python-tests.yml mode change 100755 => 100644 README.md mode change 100755 => 100644 questions/inference_server/requirements.in create mode 100644 tests/unit/test_audio_model.py diff --git a/.github/workflows/python-tests.yml b/.github/workflows/python-tests.yml new file mode 100644 index 0000000..17b71a7 --- /dev/null +++ b/.github/workflows/python-tests.yml @@ -0,0 +1,21 @@ +name: Unit Tests +on: + push: + pull_request: + +jobs: + tests: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v3 + - uses: actions/setup-python@v4 + with: + python-version: '3.10' + - name: Install uv + run: python -m pip install uv + - name: Install dependencies + run: | + uv pip install --system -r questions/inference_server/requirements.txt + uv pip install --system -r dev-requirements.txt + - name: Run unit tests + run: pytest tests/unit diff --git a/README.md b/README.md old mode 100755 new mode 100644 index e100adf..46c4140 --- a/README.md +++ b/README.md @@ -7,7 +7,7 @@ Text Generator is a system for; * Serving AI APIs via swapping in AI Networks. * Using data enrichment (OCR, crawling, image analysis to make prompt engineering easier) * Generating speech and text. -* Understanding text and speech (speech to text with whisper). +* Understanding text and speech (speech to text with NVIDIA Parakeet). Text generator can be used via API or self hosted. 
@@ -128,7 +128,7 @@ cd models git clone https://huggingface.co/distilbert-base-uncased ``` -whisper and STT models will be loaded on demand and placed in the huggingface cache. +Parakeet ASR models will be loaded on demand and placed in the huggingface cache. #### Run @@ -194,7 +194,7 @@ PYTHONPATH=$HOME/code/20-questions:$HOME/code/20-questions/OFA:$HOME/code/20-que Then go to localhost:9080/docs to use the API #### run audio server only -Just the whisper speech to text part. +Just the Parakeet speech to text part. This isn't required as the inference server automatically balances these requests ```shell PYTHONPATH=$(pwd):$(pwd)/OFA GOOGLE_APPLICATION_CREDENTIALS=secrets/google-credentials.json gunicorn -k uvicorn.workers.UvicornWorker -b :9080 audio_server.audio_server:app --timeout 180000 --workers 1 diff --git a/questions/inference_server/inference_server.py b/questions/inference_server/inference_server.py index 0f49f2d..88028b6 100644 --- a/questions/inference_server/inference_server.py +++ b/questions/inference_server/inference_server.py @@ -69,8 +69,7 @@ version="1", ) -import whisper -import numpy as np +import nemo.collections.asr as nemo_asr MODEL_CACHE = ModelCache() @@ -259,20 +258,20 @@ def validate_generate_params(generate_params): audio_model = None + + def load_audio_model(): + """Load and return the Parakeet ASR model for speech to text.""" global audio_model - # about 10s - with log_time("load whisper model"): - # audio_model = whisper.load_model("large") # todo specify download_root to a fast ssd + with log_time("load parakeet model"): if not audio_model: - audio_model = whisper.load_model("medium", download_root="models") # todo specify download_root to a fast ssd - audio_model.eval() - audio_model = audio_model.to("cuda") - logger.info( - f"Model is {'multilingual' if audio_model.is_multilingual else 'English-only'} " - f"and has {sum(np.prod(p.shape) for p in audio_model.parameters()):,} parameters." 
- f"and is on {audio_model.device}" - ) + audio_model = nemo_asr.models.ASRModel.from_pretrained( + model_name="nvidia/parakeet-tdt-0.6b-v2" + ) + audio_model.freeze() + device = "cuda" if torch.cuda.is_available() else "cpu" + audio_model = audio_model.to(device) + logger.info(f"Model loaded on {device}") return audio_model @@ -321,20 +320,21 @@ def fast_audio_extract_inference(audio_params: AudioParamsOrAudioFile): audio_bytes = response.content with torch.inference_mode(): - opts = transcribe_options # dict(beam_size=5, best_of=5) - if audio_params.translate_to_english: - opts = translate_options - # write to /dev/shm ... assume mp3 - tmp_file = NamedTemporaryFile(dir="/dev/shm", delete=True, suffix=".mp3") + tmp_file = NamedTemporaryFile(dir="/dev/shm", delete=True, suffix=".wav") tmp_file.write(audio_bytes) - result = audio_model.transcribe(tmp_file.name, **opts) - - # clean data + nemo_result = audio_model.transcribe([tmp_file.name], timestamps=True)[0] tmp_file.close() - for segment in result['segments']: - del segment['tokens'] - result['text'] = result['text'].strip() - return result + + segments = [ + { + "start": seg["start"], + "end": seg["end"], + "text": seg.get("segment", "").strip(), + } + for seg in nemo_result.timestamp["segment"] + ] + + return {"text": nemo_result.text.strip(), "segments": segments} def srt_format_timestamp(seconds: float): diff --git a/questions/inference_server/requirements.in b/questions/inference_server/requirements.in old mode 100755 new mode 100644 index e57f754..ca89d88 --- a/questions/inference_server/requirements.in +++ b/questions/inference_server/requirements.in @@ -75,7 +75,7 @@ pypdf2>=3.0.0 pytesseract>=0.3.10 setuptools>=69.0.0 -whisper>=1.1.0 +nemo_toolkit[asr]>=2.2.0 ffmpeg-python>=0.2.0 tensorboard>=2.15.0 # download video/audio extract diff --git a/questions/inference_server/requirements.txt b/questions/inference_server/requirements.txt index d2fed44..0b5950d 100644 --- 
a/questions/inference_server/requirements.txt +++ b/questions/inference_server/requirements.txt @@ -548,7 +548,6 @@ six==1.17.0 # -r requirements.in # python-dateutil # tensorboard - # whisper sniffio==1.3.1 # via # anthropic @@ -649,7 +648,7 @@ websockets==12.0 # via gradio-client werkzeug==3.1.2 # via tensorboard -whisper==1.1.10 +nemo-toolkit[asr]==2.2.0 # via -r requirements.in wrapt==1.17.0 # via -r requirements.in diff --git a/tests/unit/test_audio_model.py b/tests/unit/test_audio_model.py new file mode 100644 index 0000000..84fc6e5 --- /dev/null +++ b/tests/unit/test_audio_model.py @@ -0,0 +1,20 @@ +import builtins +from unittest import mock + +import questions.inference_server.inference_server as server + + +def test_load_audio_model(monkeypatch): + fake_model = mock.MagicMock() + monkeypatch.setattr( + server.nemo_asr.models.ASRModel, + "from_pretrained", + mock.MagicMock(return_value=fake_model), + ) + server.audio_model = None + model = server.load_audio_model() + assert model is fake_model + server.nemo_asr.models.ASRModel.from_pretrained.assert_called_once_with( + model_name="nvidia/parakeet-tdt-0.6b-v2" + ) + fake_model.to.assert_called()