
Commit efb2b00

refactor: use default device for embeddings
1 parent 08c5df0 · commit efb2b00

2 files changed: 3 additions, 4 deletions


mostlyai/qa/_sampling.py

Lines changed: 1 addition & 2 deletions
@@ -32,7 +32,6 @@
 import numpy as np
 import pandas as pd
 import pyarrow as pa
-import torch
 
 from mostlyai.qa._common import (
     CTX_COLUMN_PREFIX,
@@ -243,7 +242,7 @@ def calculate_embeddings(
 ) -> np.ndarray:
     t0 = time.time()
     # load embedder
-    embedder = load_embedder(device="cuda" if torch.cuda.is_available() else "cpu")
+    embedder = load_embedder()
     # split into buckets for calculating embeddings to avoid memory issues and report continuous progress
     steps = progress_to - progress_from if progress_to is not None and progress_from is not None else 1
     buckets = np.array_split(strings, steps)
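For context, the surrounding code in calculate_embeddings splits the input strings into buckets so the embedder processes them incrementally, which bounds peak memory and lets progress be reported after each bucket. Below is a minimal sketch of that pattern, assuming an already-loaded SentenceTransformer embedder; the function name embed_in_buckets and the print-based progress reporting are illustrative, not the repository's actual progress API.

# Illustrative sketch of the bucketed embedding pattern (not the repo's exact code).
import numpy as np

def embed_in_buckets(embedder, strings, steps=10):
    # split the workload into roughly equal buckets to bound memory usage
    # and allow progress reporting after each bucket
    buckets = np.array_split(np.asarray(strings, dtype=object), steps)
    chunks = []
    for i, bucket in enumerate(buckets, start=1):
        if len(bucket) == 0:
            continue  # np.array_split can yield empty buckets for small inputs
        chunks.append(embedder.encode(bucket.tolist(), show_progress_bar=False))
        print(f"embedded bucket {i}/{steps}")
    return np.concatenate(chunks, axis=0)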

mostlyai/qa/assets/__init__.py

Lines changed: 2 additions & 2 deletions
@@ -37,8 +37,8 @@ def load_tokenizer():
     return GPT2Tokenizer.from_pretrained(_MODULE_DIR / "tokenizers" / "transformers" / "gpt2")
 
 
-def load_embedder(device: str):
+def load_embedder():
     from sentence_transformers import SentenceTransformer
 
     path = _MODULE_DIR / "embedders" / "sentence-transformers" / "all-MiniLM-L6-v2"
-    return SentenceTransformer(str(path), local_files_only=True, device=device)
+    return SentenceTransformer(str(path), local_files_only=True)
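With the device argument gone, load_embedder relies on SentenceTransformer's default device resolution: when no device is passed, the library selects an available accelerator itself (CUDA when present, otherwise CPU, and in recent versions also Apple's MPS backend), so the manual torch.cuda.is_available() check is no longer needed. A minimal before/after sketch follows, assuming a local copy of the all-MiniLM-L6-v2 model; MODEL_PATH is a placeholder, not the package's real asset path.

# Sketch comparing explicit vs. default device selection for the embedder.
import torch
from sentence_transformers import SentenceTransformer

MODEL_PATH = "embedders/sentence-transformers/all-MiniLM-L6-v2"  # placeholder path

# before: device chosen by hand, limited to CUDA-or-CPU
explicit = SentenceTransformer(
    MODEL_PATH,
    local_files_only=True,
    device="cuda" if torch.cuda.is_available() else "cpu",
)

# after: omitting device lets sentence-transformers resolve it on its own
default = SentenceTransformer(MODEL_PATH, local_files_only=True)

# both are torch modules, so the chosen device can be inspected directly
print(next(explicit.parameters()).device)
print(next(default.parameters()).device)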
