
Commit fdba8a8

Merge pull request #22 from lightonai/accelerate-search-and-reduce-memory-usage
Reduce memory usage of Fast-Plaid
2 parents 4348cdd + 455f87d commit fdba8a8

23 files changed: +1331 -503 lines changed
.github/workflows/publish-280.yaml

Lines changed: 1 addition & 1 deletion

@@ -3,7 +3,7 @@ on:
 
 jobs:
   build_wheels:
-    name: Build wheels for Python ${{ matrix.python-version }} on ${{ matrix.os }} - 271
+    name: Build wheels for Python ${{ matrix.python-version }} on ${{ matrix.os }} - 280
     runs-on: ${{ matrix.os }}
     strategy:
       fail-fast: false

.github/workflows/publish-290.yaml

Lines changed: 88 additions & 0 deletions

@@ -0,0 +1,88 @@
+on:
+  workflow_dispatch:
+
+jobs:
+  build_wheels:
+    name: Build wheels for Python ${{ matrix.python-version }} on ${{ matrix.os }} - 290
+    runs-on: ${{ matrix.os }}
+    strategy:
+      fail-fast: false
+      matrix:
+        os: [ubuntu-latest, macos-latest, windows-latest]
+        python-version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
+
+    steps:
+      - name: Checkout code
+        uses: actions/checkout@v4
+
+      - name: Prepare pyproject.toml
+        shell: bash
+        run: |
+          echo "Deleting existing pyproject.toml (if any)..."
+          rm -f pyproject.toml
+          echo "Using ci-290.toml as pyproject.toml..."
+          cp ci-290.toml pyproject.toml
+          echo "File preparation complete."
+
+      - name: Set up Python ${{ matrix.python-version }}
+        uses: actions/setup-python@v5
+        with:
+          python-version: ${{ matrix.python-version }}
+
+      - name: Install Rust
+        uses: dtolnay/rust-toolchain@stable
+
+      - name: Prepare Python Version for CIBW_BUILD
+        id: prepare_python_version
+        shell: bash
+        run: |
+          PYTHON_VERSION_NO_DOT=$(echo "${{ matrix.python-version }}" | tr -d '.')
+          echo "PYTHON_VERSION_NO_DOT=$PYTHON_VERSION_NO_DOT" >> $GITHUB_OUTPUT
+
+      - name: Build wheels
+        uses: pypa/cibuildwheel@v2.23.3
+        env:
+          CIBW_BUILD: "cp${{ steps.prepare_python_version.outputs.PYTHON_VERSION_NO_DOT }}-*"
+          CIBW_SKIP: "*-manylinux_i686 *-musllinux_* *-win32"
+          CIBW_BUILD_VERBOSITY: 1
+          LIBTORCH_BYPASS_VERSION_CHECK: 1
+          CIBW_MANYLINUX_X86_64_IMAGE: manylinux_2_28
+          CIBW_MANYLINUX_ARM64_IMAGE: manylinux_2_28
+          CIBW_ENVIRONMENT: |
+            PIP_EXTRA_INDEX_URL="https://download.pytorch.org/whl/cpu/torch_stable.html"
+          CIBW_PIP_ARGS: --no-cache-dir
+          CIBW_BEFORE_BUILD: "pip install torch==2.9.0 numpy maturin delvewheel"
+          CIBW_REPAIR_WHEEL_COMMAND_LINUX: >
+            LD_LIBRARY_PATH=$(python -c 'import torch, os; print(os.path.join(os.path.dirname(torch.__file__), "lib"))'):$LD_LIBRARY_PATH auditwheel repair -w {dest_dir} {wheel} --exclude libtorch.so --exclude libtorch_cpu.so --exclude libtorch_python.so
+          CIBW_REPAIR_WHEEL_COMMAND_MACOS: >
+            DYLD_LIBRARY_PATH=$(python -c 'import torch, os; print(os.path.join(os.path.dirname(torch.__file__), "lib"))') delocate-wheel -w {dest_dir} -v {wheel} --exclude libtorch.dylib --exclude libtorch_cpu.dylib --exclude libtorch_python.dylib
+          CIBW_REPAIR_WHEEL_COMMAND_WINDOWS: >-
+            FOR /F "usebackq tokens=*" %i IN (`python -c "import torch, os; print(os.path.join(os.path.dirname(torch.__file__), 'lib'))"`) DO (set "PATH=%i;%PATH%" && delvewheel repair -w {dest_dir} {wheel} --no-dll torch.dll --no-dll torch_cpu.dll --no-dll torch_python.dll)
+
+      - name: Upload wheels to artifact
+        uses: actions/upload-artifact@v4
+        with:
+          name: wheels-${{ matrix.os }}-py${{ matrix.python-version }}
+          path: ./wheelhouse/*.whl
+
+  publish:
+    name: Publish 290 to PyPI
+    needs: build_wheels
+    runs-on: ubuntu-latest
+    permissions:
+      id-token: write
+
+    steps:
+      - name: Download all wheels
+        uses: actions/download-artifact@v4
+        with:
+          pattern: wheels-*-py*
+          path: dist
+          merge-multiple: true
+
+      - name: Publish to PyPI
+        uses: pypa/gh-action-pypi-publish@v1.12.4
+        with:
+          user: __token__
+          password: ${{ secrets.PYPI_API_TOKEN }}
+          skip-existing: true

.gitignore

Lines changed: 2 additions & 1 deletion

@@ -156,4 +156,5 @@ quora/
 nq/
 dbpedia-entity/
 hotpotqa/
-msmarco/
+msmarco/
+index/

Cargo.toml

Lines changed: 2 additions & 2 deletions

@@ -1,6 +1,6 @@
 [package]
 name = "fast_plaid_rust"
-version = "1.2.4"
+version = "1.2.5"
 edition = "2021"
 build = "build.rs"
 
@@ -18,7 +18,7 @@ serde_json = "1.0.140"
 libc = "0.2.172"
 parking_lot = "0.12.3"
 once_cell = "1.21.3"
-indicatif = "0.17.11"
+indicatif = "0.18.2"
 pyo3 = { version = "0.24.2", features = ["extension-module"] }
 pyo3-tch = "0.20.0"
 rand = "0.9.1"

Makefile

Lines changed: 3 additions & 3 deletions

@@ -1,16 +1,16 @@
 lint:
 	cargo clean
-	uv pip install torch==2.8.0
+	uv pip install torch==2.9.0
 	uv run --extra dev pre-commit run --files python/**/**/**.py
 
 install:
 	cargo clean
-	uv pip install torch==2.8.0
+	uv pip install torch==2.9.0
 	uv pip install -e ".[dev]"
 
 test:
 	cargo clean
 	uv run tests/test.py
 
 evaluate:
-	uv run test.py
+	uv run benchmark/benchmark.py

README.md

Lines changed: 24 additions & 9 deletions

@@ -44,10 +44,11 @@ FastPlaid is available in multiple versions to support different PyTorch versions.
 
 | FastPlaid Version | PyTorch Version | Installation Command                |
 | ----------------- | --------------- | ----------------------------------- |
-| 1.2.4.280         | 2.8.0           | `pip install fast-plaid==1.2.4.280` |
-| 1.2.4.271         | 2.7.1           | `pip install fast-plaid==1.2.4.271` |
-| 1.2.4.270         | 2.7.0           | `pip install fast-plaid==1.2.4.270` |
-| 1.2.4.260         | 2.6.0           | `pip install fast-plaid==1.2.4.260` |
+| 1.2.5.290         | 2.9.0           | `pip install fast-plaid==1.2.5.290` |
+| 1.2.5.280         | 2.8.0           | `pip install fast-plaid==1.2.5.280` |
+| 1.2.5.271         | 2.7.1           | `pip install fast-plaid==1.2.5.271` |
+| 1.2.5.270         | 2.7.0           | `pip install fast-plaid==1.2.5.270` |
+| 1.2.5.260         | 2.6.0           | `pip install fast-plaid==1.2.5.260` |
 
 ### Adding FastPlaid as a Dependency
 
@@ -56,23 +57,23 @@ You can add FastPlaid to your project dependencies with version ranges to ensure
 **For requirements.txt:**
 
 ```
-fast-plaid>=1.2.4.260,<=1.2.4.280
+fast-plaid>=1.2.5.260,<=1.2.5.290
 ```
 
 **For pyproject.toml:**
 
 ```toml
 [project]
 dependencies = [
-    "fast-plaid>=1.2.4.260,<=1.2.4.280"
+    "fast-plaid>=1.2.5.260,<=1.2.5.290"
 ]
 ```
 
 **For setup.py:**
 
 ```python
 install_requires=[
-    "fast-plaid>=1.2.4.260,<=1.2.4.280"
+    "fast-plaid>=1.2.5.260,<=1.2.5.290"
 ]
 ```
 
@@ -316,6 +317,7 @@ class FastPlaid:
     self,
     index: str,
     device: str | list[str] | None = None,
+    preload_index: bool = True,
 ) -> None:
 ```
 
@@ -331,6 +333,11 @@ device: str | list[str] | None = None
 - Can be a list of device strings (e.g., ["cuda:0", "cuda:1"]).
 - If multiple GPUs are specified and available, multiprocessing is automatically set up for parallel execution.
   Remember to include your code within an `if __name__ == "__main__":` block for proper multiprocessing behavior.
+
+preload_index: bool = True (optional)
+    If `True`, the index is loaded into memory upon initialization, which can
+    speed up the first search by "warming up" the index. If `False`, the index
+    is loaded at search time and unloaded afterward.
 ```
 
 ### Creating an Index
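
A minimal usage sketch for the new `preload_index` flag, assuming an index has already been built at `index/` (the path and tensor shapes here are illustrative, not from this commit):

```python
import torch

from fast_plaid import search

# Default: keep the index resident in memory for the fastest first search.
fast_plaid = search.FastPlaid(index="index", preload_index=True)

# Lower memory footprint: load the index per search and unload it afterward.
fast_plaid_low_mem = search.FastPlaid(index="index", preload_index=False)

# Illustrative queries: 8 queries, 32 tokens each, 128-dim embeddings.
queries_embeddings = torch.randn(8, 32, 128)
results = fast_plaid_low_mem.search(queries_embeddings=queries_embeddings, top_k=10)
```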
@@ -345,6 +352,7 @@ The **`create` method** builds the multi-vector index from your document embeddings.
     max_points_per_centroid: int = 256,
     nbits: int = 4,
     n_samples_kmeans: int | None = None,
+    batch_size: int = 25_000,
     seed: int = 42,
     use_triton_kmeans: bool | None = None,
     metadata: list[dict[str, Any]] | None = None,
@@ -376,6 +384,9 @@ n_samples_kmeans: int | None = None (optional)
     clustering quality. If you have a large dataset, you might want to set this to a
     smaller value to speed up the indexing process and save some memory.
 
+batch_size: int = 25_000 (optional)
+    Batch size for processing embeddings during index creation.
+
 seed: int = 42 (optional)
     Seed for the random number generator used in index creation.
     Setting this ensures reproducible results across multiple runs.
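
A hedged sketch of index creation with the new `batch_size` parameter; the corpus shapes are illustrative, and the `documents_embeddings` keyword is assumed to mirror the `update` signature shown below:

```python
import torch

from fast_plaid import search

fast_plaid = search.FastPlaid(index="index")

# Illustrative corpus: 100 documents, 300 tokens each, 128-dim embeddings.
documents_embeddings = [torch.randn(300, 128) for _ in range(100)]

# batch_size bounds how many embeddings are processed at once, trading
# indexing speed against peak memory.
fast_plaid.create(
    documents_embeddings=documents_embeddings,
    batch_size=25_000,
)
```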
@@ -402,6 +413,7 @@ The **`update` method** provides an efficient way to add new documents to an existing index.
     self,
     documents_embeddings: list[torch.Tensor] | torch.Tensor,
     metadata: list[dict[str, Any]] | None = None,
+    batch_size: int = 25_000,
 ) -> "FastPlaid":
 ```
 
@@ -416,6 +428,9 @@ metadata: list[dict[str, Any]] | None = None
     Each dictionary can contain arbitrary key-value pairs that you want to associate with the document.
     If provided, the length of this list must match the number of new documents being added.
     The metadata will be stored in a SQLite database within the index directory for filtering during searches.
+
+batch_size: int = 25_000 (optional)
+    Batch size for processing embeddings during the update.
 ```
 
 ### Searching the Index
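
A matching sketch for `update`, again with illustrative shapes and hypothetical metadata values:

```python
import torch

from fast_plaid import search

fast_plaid = search.FastPlaid(index="index")

# Ten new documents with illustrative shapes and arbitrary metadata.
new_embeddings = [torch.randn(300, 128) for _ in range(10)]
metadata = [{"source": "update-batch"} for _ in new_embeddings]

# New documents are appended to the existing index, with embeddings
# processed in batches of up to batch_size.
fast_plaid.update(
    documents_embeddings=new_embeddings,
    metadata=metadata,
    batch_size=25_000,
)
```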
@@ -427,7 +442,7 @@ The **`search` method** lets you query the created index with your query embeddings.
     self,
     queries_embeddings: torch.Tensor | list[torch.Tensor],
     top_k: int = 10,
-    batch_size: int = 1 << 18,
+    batch_size: int = 25_000,
     n_full_scores: int = 4096,
     n_ivf_probe: int = 8,
     show_progress: bool = True,
@@ -444,7 +459,7 @@ queries_embeddings: torch.Tensor | list[torch.Tensor]
 top_k: int = 10 (optional)
     The number of top-scoring documents to retrieve for each query.
 
-batch_size: int = 1 << 18 (optional)
+batch_size: int = 25_000 (optional)
     The internal batch size used for processing queries.
     A larger batch size might improve throughput on powerful GPUs but can consume more memory.
 

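A sketch of a search call spelling out the new default `batch_size` (25_000, down from 1 << 18 = 262_144) alongside the other tuning knobs; shapes are illustrative:

```python
import torch

from fast_plaid import search

fast_plaid = search.FastPlaid(index="index")

queries_embeddings = torch.randn(1_000, 32, 128)  # illustrative query tensor

# The lower default batch_size caps memory use; raise it on large-memory
# GPUs for throughput, or lower it further if you hit out-of-memory errors.
scores = fast_plaid.search(
    queries_embeddings=queries_embeddings,
    top_k=10,
    batch_size=25_000,
    n_full_scores=4096,
    n_ivf_probe=8,
    show_progress=True,
)
```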
benchmark/benchmark.py

Lines changed: 5 additions & 70 deletions

@@ -9,6 +9,8 @@
 from fast_plaid import evaluation, search
 from pylate import models
 
+print("Torch version:", torch.__version__)
+
 parser = argparse.ArgumentParser(
     description="Run Fast-PLAiD evaluation on a BEIR dataset."
 )
@@ -88,11 +90,11 @@
 
 large_queries_embeddings = torch.cat(
     ([queries_embeddings] * ((1000 // queries_embeddings.shape[0]) + 1))[:1000]
-)
+).to("cpu")
 
 print(f"🔍 50_000 queries on {dataset_name}...")
 start_search = time.time()
-_ = index.search(queries_embeddings=large_queries_embeddings)
+_ = index.search(queries_embeddings=large_queries_embeddings, top_k=10, n_full_scores=4096, n_ivf_probe=8)
 end_search = time.time()
 heavy_search_time = end_search - start_search
 queries_per_second = large_queries_embeddings.shape[0] / heavy_search_time
@@ -139,71 +141,4 @@
 with open(output_filepath, "w") as f:
     json.dump(output_data, f, indent=4)
 
-print(f"🎉 Finished evaluation for dataset: {dataset_name}\n")
-
-# Pylate
-
-from pylate import evaluation, indexes, retrieve
-
-index = indexes.PLAID(
-    override=True,
-    index_name=f"{dataset_name}_pylate",
-    embedding_size=96,
-    nbits=4,
-)
-
-retriever = retrieve.ColBERT(index=index)
-
-start = time.time()
-index.add_documents(
-    documents_ids=[document["id"] for document in documents],
-    documents_embeddings=documents_embeddings,
-)
-end = time.time()
-indexing_time = end - start
-print(f"🏗️ Pylate index on {dataset_name}: {end - start:.2f} seconds")
-
-start = time.time()
-scores = retriever.retrieve(queries_embeddings=queries_embeddings, k=20)
-end = time.time()
-search_time = end - start
-print(f"🔍 Pylate search on {dataset_name}: {search_time:.2f} seconds")
-
-
-start = time.time()
-_ = retriever.retrieve(queries_embeddings=large_queries_embeddings, k=20)
-end = time.time()
-heavy_search_time = end - start
-queries_per_second = large_queries_embeddings.shape[0] / heavy_search_time
-
-for (query_id, query), query_scores in zip(queries.items(), scores):
-    for score in query_scores:
-        if score["id"] == query_id:
-            # Remove the query_id from the score
-            query_scores.remove(score)
-
-evaluation_scores = evaluation.evaluate(
-    scores=scores,
-    qrels=qrels,
-    queries=list(queries.values()),
-    metrics=["map", "ndcg@10", "ndcg@100", "recall@10", "recall@100"],
-)
-
-print(f"\n--- 📈 Final Scores for {dataset_name} (Pylate) ---")
-print(evaluation_scores)
-
-output_data = {
-    "dataset": dataset_name,
-    "indexing": round(indexing_time, 3),
-    "search": round(search_time, 3),
-    "qps": round(queries_per_second, 2),
-    "size": len(documents),
-    "queries": num_queries,
-    "scores": evaluation_scores,
-}
-
-output_filepath = os.path.join(output_dir, f"{dataset_name}_pylate.json")
-with open(output_filepath, "w") as f:
-    json.dump(output_data, f, indent=4)
-
-print(f"💾 Exporting Pylate results to {output_filepath}")
+print(f"🎉 Finished evaluation for dataset: {dataset_name}\n")
