diff --git a/.github/workflows/python-tests.yml b/.github/workflows/python-tests.yml new file mode 100644 index 0000000..17b71a7 --- /dev/null +++ b/.github/workflows/python-tests.yml @@ -0,0 +1,21 @@ +name: Unit Tests +on: + push: + pull_request: + +jobs: + tests: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v3 + - uses: actions/setup-python@v4 + with: + python-version: '3.10' + - name: Install uv + run: python -m pip install uv + - name: Install dependencies + run: | + uv pip install --system --no-compile -r questions/inference_server/requirements.txt + uv pip install --system --no-compile -r dev-requirements.txt + - name: Run unit tests + run: pytest tests/unit diff --git a/README.md b/README.md old mode 100755 new mode 100644 index 5f7edde..a2b6047 --- a/README.md +++ b/README.md @@ -7,7 +7,7 @@ Text Generator is a system for; * Serving AI APIs via swapping in AI Networks. * Using data enrichment (OCR, crawling, image analysis to make prompt engineering easier) * Generating speech and text. -* Understanding text and speech (speech to text with whisper). +* Understanding text and speech (speech to text with NVIDIA Parakeet). Text generator can be used via API or self hosted. @@ -150,7 +150,7 @@ cd models git clone https://huggingface.co/distilbert-base-uncased ``` -whisper and STT models will be loaded on demand and placed in the huggingface cache. +Parakeet ASR models will be loaded on demand and placed in the huggingface cache. #### Run @@ -216,7 +216,7 @@ PYTHONPATH=$HOME/code/20-questions:$HOME/code/20-questions/OFA:$HOME/code/20-que Then go to localhost:9080/docs to use the API #### run audio server only -Just the whisper speech to text part. +Just the Parakeet speech to text part. 
This isn't required as the inference server automatically balances these requests ```shell PYTHONPATH=$(pwd):$(pwd)/OFA GOOGLE_APPLICATION_CREDENTIALS=secrets/google-credentials.json gunicorn -k uvicorn.workers.UvicornWorker -b :9080 audio_server.audio_server:app --timeout 180000 --workers 1 diff --git a/questions/inference_server/inference_server.py b/questions/inference_server/inference_server.py index 88278a8..c6a8736 100644 --- a/questions/inference_server/inference_server.py +++ b/questions/inference_server/inference_server.py @@ -88,8 +88,7 @@ version="1", ) -import whisper -import numpy as np +import nemo.collections.asr as nemo_asr MODEL_CACHE = ModelCache() @@ -345,21 +344,17 @@ def validate_generate_params(generate_params): def load_audio_model(): + """Load and return the Parakeet ASR model for speech to text.""" global audio_model - # about 10s - with log_time("load whisper model"): - # audio_model = whisper.load_model("large") # todo specify download_root to a fast ssd + with log_time("load parakeet model"): if not audio_model: - audio_model = whisper.load_model( - "medium", download_root="models" - ) # todo specify download_root to a fast ssd - audio_model.eval() - audio_model = audio_model.to("cuda") - logger.info( - f"Model is {'multilingual' if audio_model.is_multilingual else 'English-only'} " - f"and has {sum(np.prod(p.shape) for p in audio_model.parameters()):,} parameters." 
- f"and is on {audio_model.device}" - ) + audio_model = nemo_asr.models.ASRModel.from_pretrained( + model_name="nvidia/parakeet-tdt-0.6b-v2" + ) + audio_model.freeze() + device = "cuda" if torch.cuda.is_available() else "cpu" + audio_model = audio_model.to(device) + logger.info(f"Model loaded on {device}") return audio_model @@ -412,20 +407,21 @@ def fast_audio_extract_inference(audio_params: AudioParamsOrAudioFile): audio_bytes = response.content with torch.inference_mode(): - opts = transcribe_options # dict(beam_size=5, best_of=5) - if audio_params.translate_to_english: - opts = translate_options - # write to /dev/shm ... assume mp3 - tmp_file = NamedTemporaryFile(dir="/dev/shm", delete=True, suffix=".mp3") + tmp_file = NamedTemporaryFile(dir="/dev/shm", delete=True, suffix=".wav") tmp_file.write(audio_bytes) - result = audio_model.transcribe(tmp_file.name, **opts) - - # clean data + nemo_result = audio_model.transcribe([tmp_file.name], timestamps=True)[0] tmp_file.close() - for segment in result["segments"]: - del segment["tokens"] - result["text"] = result["text"].strip() - return result + + segments = [ + { + "start": seg["start"], + "end": seg["end"], + "text": seg.get("segment", "").strip(), + } + for seg in nemo_result.timestamp["segment"] + ] + + return {"text": nemo_result.text.strip(), "segments": segments} diff --git a/questions/inference_server/requirements.in b/questions/inference_server/requirements.in old mode 100755 new mode 100644 index 80abb76..6ce6518 --- a/questions/inference_server/requirements.in +++ b/questions/inference_server/requirements.in @@ -73,7 +73,7 @@ pypdf2>=3.0.0 pytesseract>=0.3.10 setuptools>=69.0.0 -whisper>=1.1.0 +nemo_toolkit[asr]>=2.2.0 ffmpeg-python>=0.2.0 tensorboard>=2.15.0 # download video/audio extract diff --git a/questions/inference_server/requirements.txt b/questions/inference_server/requirements.txt index 07c584f..22a3815 --- a/questions/inference_server/requirements.txt +++ 
b/questions/inference_server/requirements.txt @@ -547,7 +547,6 @@ six==1.17.0 # -r requirements.in # python-dateutil # tensorboard - # whisper sniffio==1.3.1 # via # anthropic @@ -648,7 +647,7 @@ websockets==12.0 # via gradio-client werkzeug==3.1.2 # via tensorboard -whisper==1.1.10 +nemo-toolkit[asr]==2.2.0 # via -r requirements.in wrapt==1.17.0 # via -r requirements.in diff --git a/tests/unit/test_audio_model.py b/tests/unit/test_audio_model.py new file mode 100644 index 0000000..84fc6e5 --- /dev/null +++ b/tests/unit/test_audio_model.py @@ -0,0 +1,19 @@ +from unittest import mock + +import questions.inference_server.inference_server as server + + +def test_load_audio_model(monkeypatch): + fake_model = mock.MagicMock() + monkeypatch.setattr( + server.nemo_asr.models.ASRModel, + "from_pretrained", + mock.MagicMock(return_value=fake_model), + ) + server.audio_model = None + model = server.load_audio_model() + assert model is fake_model + server.nemo_asr.models.ASRModel.from_pretrained.assert_called_once_with( + model_name="nvidia/parakeet-tdt-0.6b-v2" + ) + fake_model.to.assert_called()