Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions .github/workflows/python-tests.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
name: Unit Tests
on:
  push:
  pull_request:

jobs:
  tests:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v3
      - uses: actions/setup-python@v4
        with:
          python-version: '3.10'
      - name: Install uv
        # The PyPI package is "uv" (published by Astral) — "astral-uv" does not exist.
        run: python -m pip install uv
      - name: Install dependencies
        # --system: install into the runner's interpreter; without it uv
        # refuses to install because no virtualenv is active on CI runners.
        run: |
          uv pip install --system --no-compile -r questions/inference_server/requirements.txt
          uv pip install --system --no-compile -r dev-requirements.txt
      - name: Run unit tests
        run: pytest tests/unit
6 changes: 3 additions & 3 deletions README.md
100755 → 100644
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ Text Generator is a system for;
* Serving AI APIs via swapping in AI Networks.
* Using data enrichment (OCR, crawling, image analysis to make prompt engineering easier)
* Generating speech and text.
* Understanding text and speech (speech to text with whisper).
* Understanding text and speech (speech to text with NVIDIA Parakeet).

Text generator can be used via API or self hosted.

Expand Down Expand Up @@ -150,7 +150,7 @@ cd models
git clone https://huggingface.co/distilbert-base-uncased
```

whisper and STT models will be loaded on demand and placed in the huggingface cache.
Parakeet ASR models will be loaded on demand and placed in the huggingface cache.


#### Run
Expand Down Expand Up @@ -216,7 +216,7 @@ PYTHONPATH=$HOME/code/20-questions:$HOME/code/20-questions/OFA:$HOME/code/20-que
Then go to localhost:9080/docs to use the API
#### run audio server only

Just the whisper speech to text part.
Just the Parakeet speech to text part.
This isn't required as the inference server automatically balances these requests
```shell
PYTHONPATH=$(pwd):$(pwd)/OFA GOOGLE_APPLICATION_CREDENTIALS=secrets/google-credentials.json gunicorn -k uvicorn.workers.UvicornWorker -b :9080 audio_server.audio_server:app --timeout 180000 --workers 1
Expand Down
50 changes: 23 additions & 27 deletions questions/inference_server/inference_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,8 +88,7 @@
version="1",
)

import whisper
import numpy as np
import nemo.collections.asr as nemo_asr

MODEL_CACHE = ModelCache()

Expand Down Expand Up @@ -345,21 +344,17 @@ def validate_generate_params(generate_params):


def load_audio_model():
    """Load (once) and return the Parakeet ASR model for speech to text.

    Caches the model in the module-level ``audio_model`` global, so repeated
    calls return the already-loaded instance. The model is frozen (inference
    only) and moved to GPU when one is available.

    Returns:
        The loaded ``nemo_asr.models.ASRModel`` instance.
    """
    global audio_model
    with log_time("load parakeet model"):
        if not audio_model:
            # NOTE: NeMo's from_pretrained has no ``download_root`` kwarg;
            # checkpoints are cached under NEMO_CACHE_DIR — set that env var
            # to redirect the cache to a fast SSD.
            audio_model = nemo_asr.models.ASRModel.from_pretrained(
                model_name="nvidia/parakeet-tdt-0.6b-v2"
            )
            # Inference only: freeze parameters (no grads, eval mode).
            audio_model.freeze()
            device = "cuda" if torch.cuda.is_available() else "cpu"
            audio_model = audio_model.to(device)
            logger.info(f"Model loaded on {device}")
    return audio_model


Expand Down Expand Up @@ -412,20 +407,21 @@ def fast_audio_extract_inference(audio_params: AudioParamsOrAudioFile):
audio_bytes = response.content

with torch.inference_mode():
opts = transcribe_options # dict(beam_size=5, best_of=5)
if audio_params.translate_to_english:
opts = translate_options
# write to /dev/shm ... assume mp3
tmp_file = NamedTemporaryFile(dir="/dev/shm", delete=True, suffix=".mp3")
tmp_file = NamedTemporaryFile(dir="/dev/shm", delete=True, suffix=".wav")
tmp_file.write(audio_bytes)
result = audio_model.transcribe(tmp_file.name, **opts)

# clean data
nemo_result = audio_model.transcribe([tmp_file.name], timestamps=True)[0]
tmp_file.close()
for segment in result["segments"]:
del segment["tokens"]
result["text"] = result["text"].strip()
return result

segments = [
{
"start": seg["start"],
"end": seg["end"],
"text": seg.get("segment", "").strip(),
}
for seg in nemo_result.timestamp["segment"]
]

return {"text": nemo_result.text.strip(), "segments": segments}



Expand Down
2 changes: 1 addition & 1 deletion questions/inference_server/requirements.in
100755 → 100644
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ pypdf2>=3.0.0
pytesseract>=0.3.10
setuptools>=69.0.0

whisper>=1.1.0
nemo_toolkit[asr]>=2.2.0
ffmpeg-python>=0.2.0
tensorboard>=2.15.0
# download video/audio extract
Expand Down
3 changes: 1 addition & 2 deletions questions/inference_server/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -547,7 +547,6 @@ six==1.17.0
# -r requirements.in
# python-dateutil
# tensorboard
# whisper
sniffio==1.3.1
# via
# anthropic
Expand Down Expand Up @@ -648,7 +647,7 @@ websockets==12.0
# via gradio-client
werkzeug==3.1.2
# via tensorboard
whisper==1.1.10
nemo-toolkit[asr]==2.2.0
# via -r requirements.in
wrapt==1.17.0
# via -r requirements.in
Expand Down
20 changes: 20 additions & 0 deletions tests/unit/test_audio_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
import builtins
from unittest import mock

import questions.inference_server.inference_server as server


def test_load_audio_model(monkeypatch):
    """load_audio_model builds the Parakeet model once and moves it to a device."""
    fake_model = mock.MagicMock()
    # ``model.to(device)`` must hand back the same object: load_audio_model
    # rebinds ``audio_model = audio_model.to(device)``, so without this the
    # function returns the MagicMock produced by ``.to()`` and the identity
    # assertion below would fail.
    fake_model.to.return_value = fake_model
    monkeypatch.setattr(
        server.nemo_asr.models.ASRModel,
        "from_pretrained",
        mock.MagicMock(return_value=fake_model),
    )
    # Reset the module-level cache so the load path actually runs.
    server.audio_model = None
    model = server.load_audio_model()
    assert model is fake_model
    from_pretrained = server.nemo_asr.models.ASRModel.from_pretrained
    from_pretrained.assert_called_once()
    # Assert only on the kwarg we care about, not the exact signature, so the
    # test stays valid if optional kwargs are added or removed.
    assert (
        from_pretrained.call_args.kwargs["model_name"]
        == "nvidia/parakeet-tdt-0.6b-v2"
    )
    fake_model.to.assert_called()
Loading