From 4992b1380837ecffa73fe68af10e4cfbe7a77fd5 Mon Sep 17 00:00:00 2001 From: Lee Penkman Date: Tue, 8 Jul 2025 16:01:55 +1200 Subject: [PATCH] Parameterize embedding path and add CLI --- README.md | 11 +++++-- owners.md | 2 +- questions/bert_embed.py | 33 +++++++++++-------- scripts/embed_cli.py | 17 ++++++++++ tests/unit/test_modernbert_embed.py | 49 +++++++++++++++++++++++++++++ 5 files changed, 96 insertions(+), 16 deletions(-) create mode 100644 scripts/embed_cli.py create mode 100644 tests/unit/test_modernbert_embed.py diff --git a/README.md b/README.md index a2b6047..77b681c 100644 --- a/README.md +++ b/README.md @@ -147,7 +147,14 @@ The embedding model is a smaller model. ```shell cd models -git clone https://huggingface.co/distilbert-base-uncased +git clone https://huggingface.co/answerdotai/ModernBERT-base +``` +Set `BERT_CHECKPOINT` to override the default model path if needed. +`BERT_DEVICE` can be set to "cpu" to force CPU embeddings. + +Generate embeddings from the command line with: +```shell +python scripts/embed_cli.py "Hello world" ``` Parakeet ASR models will be loaded on demand and placed in the huggingface cache. 
@@ -273,7 +280,7 @@ clone from huggingface ``` cd models -git clone https://huggingface.co/distilbert-base-uncased +git clone https://huggingface.co/answerdotai/ModernBERT-base ``` ### maintenence diff --git a/owners.md b/owners.md index cab452a..a3ea2b8 100644 --- a/owners.md +++ b/owners.md @@ -14,7 +14,7 @@ cloudflared tunnel --url localhost:9080 --name textaudio # Speed link models to another drive can be a good idea if you have a faster SSD as swapping models is slow -sudo ln -s $HOME/code/20-questions/models/distilbert-base-uncased /models/distilbert-base-uncased +sudo ln -s $HOME/code/20-questions/models/ModernBERT-base /models/ModernBERT-base # setup tunnel diff --git a/questions/bert_embed.py b/questions/bert_embed.py index 2f48366..643eee3 100644 --- a/questions/bert_embed.py +++ b/questions/bert_embed.py @@ -1,4 +1,5 @@ from transformers import AutoModel, AutoTokenizer +import os import torch.nn as nn import torch @@ -35,26 +36,32 @@ def forward(self, x, attention_mask=None): return mean_pool -checkpoint = "models/distilbert-base-uncased" -distilbert = None -def get_distilbert(): - global distilbert +checkpoint = os.getenv("BERT_CHECKPOINT", "models/ModernBERT-base") +device_override = os.getenv("BERT_DEVICE") +if device_override: + DEVICE = device_override + +modernbert = None + + +def get_modernbert(): + global modernbert with log_time("bert load"): - if not distilbert: - distilbert = FeatureExtractModel(checkpoint, freeze=True) + if not modernbert: + modernbert = FeatureExtractModel(checkpoint, freeze=True) - distilbert.eval() + modernbert.eval() if DEVICE == "cuda": with log_time("bert to bf16"): - distilbert.bfloat16() + modernbert.bfloat16() with log_time("bert to gpu"): - distilbert.to(DEVICE) + modernbert.to(DEVICE) elif DEVICE == "cpu": logger.error("no GPU available, performance may be very slow") logger.error("consider using a GPU or many fast CPUs if you need to do this") - distilbert.to(DEVICE) - return distilbert + modernbert.to(DEVICE) + 
return modernbert tokenizer = AutoTokenizer.from_pretrained(checkpoint) @@ -75,7 +82,7 @@ def get_bert_embeddings(sentences, model_cache): """ returns 768 size tensor """ - distilbert = model_cache.add_or_get("distilbert", get_distilbert) + modernbert = model_cache.add_or_get("modernbert", get_modernbert) final_embeddings = list() all_embeddings = [] @@ -90,7 +97,7 @@ def get_bert_embeddings(sentences, model_cache): return_attention_mask=True, padding=True, ) - embeddings = distilbert(tokens) + embeddings = modernbert(tokens) final_embeddings.extend(embeddings) all_embeddings = torch.stack(final_embeddings) return all_embeddings.cpu().float().numpy().tolist() diff --git a/scripts/embed_cli.py b/scripts/embed_cli.py new file mode 100644 index 0000000..f7b2c42 --- /dev/null +++ b/scripts/embed_cli.py @@ -0,0 +1,17 @@ +import sys +from questions.bert_embed import get_bert_embeddings_fast +from questions.inference_server.model_cache import ModelCache + + +def main() -> None: + if len(sys.argv) < 2: + print("Usage: embed_cli.py <sentence1> [sentence2 ...]") + return + cache = ModelCache() + embeddings = get_bert_embeddings_fast(sys.argv[1:], cache) + for emb in embeddings: + print(emb) + + +if __name__ == "__main__": + main() diff --git a/tests/unit/test_modernbert_embed.py b/tests/unit/test_modernbert_embed.py new file mode 100644 index 0000000..d6f4f84 --- /dev/null +++ b/tests/unit/test_modernbert_embed.py @@ -0,0 +1,49 @@ +import importlib +import sys +import types +import pytest + +pytest.skip("torch not available in this environment", allow_module_level=True) +
+class DummyCache: + def __init__(self): + self.model = None + def add_or_get(self, name, fn): + if self.model is None: + self.model = fn() + return self.model + +def _fake_model(*args, **kwargs): + class DummyModel(torch.nn.Module): + config = types.SimpleNamespace(hidden_size=8) + def forward(self, input_ids, attention_mask, return_dict=True): + batch = input_ids.size(0) + return
types.SimpleNamespace(last_hidden_state=torch.zeros((batch, 1, 8))) + return DummyModel() + +class FakeTokens(dict): + def to(self, device): + return self + + +def _fake_tokenizer(text, truncation=None, return_tensors=None, return_attention_mask=None, padding=None): + return FakeTokens({"input_ids": torch.tensor([[1]]), "attention_mask": torch.tensor([[1]])}) + +def test_get_modernbert_uses_checkpoint(monkeypatch): + calls = {} + monkeypatch.setattr("transformers.AutoModel.from_pretrained", lambda cp: (_fake_model(), calls.setdefault("cp", cp))[0]) + monkeypatch.setattr("transformers.AutoTokenizer.from_pretrained", lambda cp: _fake_tokenizer) + bert_embed = importlib.reload(sys.modules.get("questions.bert_embed")) if "questions.bert_embed" in sys.modules else importlib.import_module("questions.bert_embed") + bert_embed.modernbert = None + model = bert_embed.get_modernbert() + assert calls["cp"] == bert_embed.checkpoint + assert model is bert_embed.modernbert + + +def test_get_bert_embeddings_fast(monkeypatch): + monkeypatch.setattr("transformers.AutoModel.from_pretrained", lambda cp: _fake_model()) + monkeypatch.setattr("transformers.AutoTokenizer.from_pretrained", lambda cp: _fake_tokenizer) + bert_embed = importlib.reload(sys.modules.get("questions.bert_embed")) if "questions.bert_embed" in sys.modules else importlib.import_module("questions.bert_embed") + bert_embed.modernbert = None + result = bert_embed.get_bert_embeddings_fast(["hello"], DummyCache()) + assert isinstance(result, list) and isinstance(result[0], list)