Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 9 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,14 @@ The embedding model is a smaller model.

```shell
cd models
git clone https://huggingface.co/distilbert-base-uncased
git clone https://huggingface.co/answerdotai/ModernBERT-base
```
Set `BERT_CHECKPOINT` to override the default model path if needed.
`BERT_DEVICE` can be set to "cpu" to force CPU embeddings.

Generate embeddings from the command line with:
```shell
python scripts/embed_cli.py "Hello world"
```

Parakeet ASR models will be loaded on demand and placed in the huggingface cache.
Expand Down Expand Up @@ -273,7 +280,7 @@ clone from huggingface

```
cd models
git clone https://huggingface.co/distilbert-base-uncased
git clone https://huggingface.co/answerdotai/ModernBERT-base
```

### maintenance
Expand Down
2 changes: 1 addition & 1 deletion owners.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ cloudflared tunnel --url localhost:9080 --name textaudio
# Speed

Linking models to another drive can be a good idea if you have a faster SSD, as swapping models is slow.
sudo ln -s $HOME/code/20-questions/models/distilbert-base-uncased /models/distilbert-base-uncased
sudo ln -s $HOME/code/20-questions/models/ModernBERT-base /models/ModernBERT-base


# setup tunnel
Expand Down
33 changes: 20 additions & 13 deletions questions/bert_embed.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from transformers import AutoModel, AutoTokenizer
import os
import torch.nn as nn
import torch

Expand Down Expand Up @@ -35,26 +36,32 @@ def forward(self, x, attention_mask=None):
return mean_pool


# Path to the embedding checkpoint; the BERT_CHECKPOINT env var overrides the default.
checkpoint = os.environ.get("BERT_CHECKPOINT", "models/ModernBERT-base")

# BERT_DEVICE (e.g. "cpu") forces the embedding device, overriding autodetection.
device_override = os.environ.get("BERT_DEVICE")
if device_override:
    DEVICE = device_override

# Cached model instance, populated lazily by get_modernbert().
modernbert = None

def get_modernbert():
    """Return the process-wide ModernBERT feature extractor, loading it on first use.

    The model is cached in the module-global ``modernbert``. All one-time setup
    (eval mode, bf16 conversion, device placement, CPU warnings) happens inside
    the first-load branch, so repeat calls return the prepared instance without
    redundant ``.to()`` calls or re-logged warnings.
    """
    global modernbert
    # Explicit None check: truthiness of an nn.Module is not a reliable signal.
    if modernbert is None:
        with log_time("bert load"):
            modernbert = FeatureExtractModel(checkpoint, freeze=True)
        # Inference only: disable dropout etc. once at load time.
        modernbert.eval()
        if DEVICE == "cuda":
            # bf16 halves memory and speeds up inference on supported GPUs.
            with log_time("bert to bf16"):
                modernbert.bfloat16()
            with log_time("bert to gpu"):
                modernbert.to(DEVICE)
        elif DEVICE == "cpu":
            logger.error("no GPU available, performance may be very slow")
            logger.error("consider using a GPU or many fast CPUs if you need to do this")
            modernbert.to(DEVICE)
    return modernbert

tokenizer = AutoTokenizer.from_pretrained(checkpoint)

Expand All @@ -75,7 +82,7 @@ def get_bert_embeddings(sentences, model_cache):
"""
returns 768 size tensor
"""
distilbert = model_cache.add_or_get("distilbert", get_distilbert)
modernbert = model_cache.add_or_get("modernbert", get_modernbert)
final_embeddings = list()
all_embeddings = []

Expand All @@ -90,7 +97,7 @@ def get_bert_embeddings(sentences, model_cache):
return_attention_mask=True,
padding=True,
)
embeddings = distilbert(tokens)
embeddings = modernbert(tokens)
final_embeddings.extend(embeddings)
all_embeddings = torch.stack(final_embeddings)
return all_embeddings.cpu().float().numpy().tolist()
Expand Down
17 changes: 17 additions & 0 deletions scripts/embed_cli.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
import sys
from questions.bert_embed import get_bert_embeddings_fast
from questions.inference_server.model_cache import ModelCache


def main() -> None:
    """Embed each command-line argument and print one embedding per line.

    Usage: embed_cli.py <sentence1> [sentence2 ...]

    Raises:
        SystemExit: with status 1 when no sentences are supplied, so shell
            callers and pipelines can detect the usage error.
    """
    if len(sys.argv) < 2:
        # Usage errors belong on stderr with a non-zero exit, not stdout + 0.
        print("Usage: embed_cli.py <sentence1> [sentence2 ...]", file=sys.stderr)
        raise SystemExit(1)
    cache = ModelCache()
    embeddings = get_bert_embeddings_fast(sys.argv[1:], cache)
    for emb in embeddings:
        print(emb)


if __name__ == "__main__":
    main()
49 changes: 49 additions & 0 deletions tests/unit/test_modernbert_embed.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
import importlib
import sys
import types

import pytest

# The fakes below construct torch tensors and modules, so bind torch here;
# importorskip skips the whole module cleanly where torch is unavailable
# instead of skipping unconditionally (which disabled these tests everywhere
# and hid the missing `import torch`).
torch = pytest.importorskip("torch")

class DummyCache:
    """Minimal stand-in for the project's ModelCache: memoizes one model."""

    def __init__(self):
        self.model = None

    def add_or_get(self, name, fn):
        # Reuse the cached instance when present; build it exactly once otherwise.
        if self.model is not None:
            return self.model
        self.model = fn()
        return self.model

def _fake_model(*args, **kwargs):
class DummyModel(torch.nn.Module):
config = types.SimpleNamespace(hidden_size=8)
def forward(self, input_ids, attention_mask, return_dict=True):
batch = input_ids.size(0)
return types.SimpleNamespace(last_hidden_state=torch.zeros((batch, 1, 8)))
return DummyModel()

class FakeTokens(dict):
    """Tokenizer-output stand-in: a dict whose .to(device) is a no-op."""

    def to(self, device):
        # Device placement is irrelevant for the fakes — return self untouched.
        return self


def _fake_tokenizer(text, truncation=None, return_tensors=None, return_attention_mask=None, padding=None):
    """Stand-in tokenizer: always yields a single-token batch of size one."""
    input_ids = torch.tensor([[1]])
    attention_mask = torch.tensor([[1]])
    return FakeTokens({"input_ids": input_ids, "attention_mask": attention_mask})

def test_get_modernbert_uses_checkpoint(monkeypatch):
    """get_modernbert() should load whatever checkpoint the module is configured with."""
    calls = {}
    # Record the checkpoint path handed to from_pretrained while returning a fake model;
    # setdefault captures only the first call's path.
    monkeypatch.setattr("transformers.AutoModel.from_pretrained", lambda cp: (_fake_model(), calls.setdefault("cp", cp))[0])
    monkeypatch.setattr("transformers.AutoTokenizer.from_pretrained", lambda cp: _fake_tokenizer)
    # Reload (or freshly import) so the module-level from_pretrained call hits the patches.
    bert_embed = importlib.reload(sys.modules.get("questions.bert_embed")) if "questions.bert_embed" in sys.modules else importlib.import_module("questions.bert_embed")
    # Clear the cached singleton so get_modernbert() actually performs a load.
    bert_embed.modernbert = None
    model = bert_embed.get_modernbert()
    assert calls["cp"] == bert_embed.checkpoint
    assert model is bert_embed.modernbert


def test_get_bert_embeddings_fast(monkeypatch):
    """Embedding one sentence through the fakes should yield a list of per-sentence lists."""
    monkeypatch.setattr("transformers.AutoModel.from_pretrained", lambda cp: _fake_model())
    monkeypatch.setattr("transformers.AutoTokenizer.from_pretrained", lambda cp: _fake_tokenizer)
    # Reload (or freshly import) so the module picks up the patched transformers calls.
    bert_embed = importlib.reload(sys.modules.get("questions.bert_embed")) if "questions.bert_embed" in sys.modules else importlib.import_module("questions.bert_embed")
    # Reset the cached singleton so the DummyCache path drives model creation.
    bert_embed.modernbert = None
    result = bert_embed.get_bert_embeddings_fast(["hello"], DummyCache())
    assert isinstance(result, list) and isinstance(result[0], list)
Loading