Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,16 @@ __pycache__/
*.egg-info/
dist/
build/
*.pkl
*.npy

# Large binary files
embeddings.npy
faiss_index/
data/
qdrant_storage/
logs/
checkpoints/

# Jupyter
.ipynb_checkpoints/
Expand Down
5,154 changes: 5,154 additions & 0 deletions results/test_queries_results.json

Large diffs are not rendered by default.

145 changes: 145 additions & 0 deletions submission/bm25_index.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
# =============================================================================
# bm25_index.py — BM25-Okapi sparse retriever (rank_bm25)
# =============================================================================
"""
Builds a BM25-Okapi index over the corpus and caches it as a pickle so
tokenisation is not repeated on subsequent runs.

Public API
----------
from bm25_index import BM25Index

idx = BM25Index()
idx.build(documents, doc_ids, ck)
results = idx.search(query, top_k=50)
# results → list of {"doc_id": str, "score": float, "corpus_idx": int}
"""

from __future__ import annotations

import pickle
from pathlib import Path

import numpy as np
from rank_bm25 import BM25Okapi
from tqdm import tqdm

from checkpoint import Checkpoint
from config import BM25_CACHE_PATH, SPARSE_TOP_K
from logger import get_logger

log = get_logger(__name__)

_CACHE = Path(BM25_CACHE_PATH)


def _tokenise(text: str) -> list[str]:
"""Whitespace tokenisation (fast; good enough for BM25 on English text)."""
return text.lower().split()


class BM25Index:
    """
    Wrapper around BM25Okapi that adds checkpoint-aware build/load and a
    consistent search interface.

    The expensive step (tokenising the corpus) runs once in :meth:`build`;
    the resulting index is pickled to ``BM25_CACHE_PATH`` so later runs load
    it from disk instead of re-tokenising.
    """

    def __init__(self) -> None:
        # Populated by build() or _load(); search() refuses to run before then.
        self._bm25: BM25Okapi | None = None
        self._doc_ids: list[str] = []

    def build(
        self,
        documents: list[str],
        doc_ids: list[str],
        ck: Checkpoint,
    ) -> None:
        """
        Tokenise all documents and build the BM25 index.

        If the checkpoint phase 'bm25_indexed' is already done and the pickle
        cache exists, the index is loaded from disk without re-tokenising.

        Parameters
        ----------
        documents : list[str] — title-prepended document texts
        doc_ids : list[str] — original document IDs (same order as documents)
        ck : Checkpoint — shared pipeline checkpoint

        Raises
        ------
        ValueError
            If *documents* and *doc_ids* differ in length. A silent mismatch
            here would attach the wrong ID to every search result (or raise
            an opaque IndexError much later, inside search()).
        """
        if len(documents) != len(doc_ids):
            raise ValueError(
                f"documents ({len(documents)}) and doc_ids ({len(doc_ids)}) "
                "must have the same length"
            )

        self._doc_ids = doc_ids

        if ck.done("bm25_indexed") and _CACHE.exists():
            log.info("Phase 'bm25_indexed' done — loading BM25 from %s", _CACHE)
            # _load() also restores the doc_ids that were pickled with the
            # index, keeping IDs consistent with the cached corpus order.
            self._load()
            return

        log.info("=== Phase: bm25_indexed ===")
        log.info("Tokenising %d documents for BM25 …", len(documents))

        tokenised = [_tokenise(doc) for doc in tqdm(documents, desc="BM25 tokenise")]

        avg_len = np.mean([len(t) for t in tokenised])
        log.debug("Avg token count per document: %.1f", avg_len)

        log.info("Building BM25Okapi index …")
        self._bm25 = BM25Okapi(tokenised)

        # Save BEFORE marking done, so a crash in between re-runs the phase
        # rather than leaving 'done' pointing at a missing cache file.
        self._save()
        ck.mark_done("bm25_indexed", bm25_corpus_size=len(documents))
        log.info("BM25 index built and cached → %s", _CACHE)

    def _save(self) -> None:
        """Pickle the index together with its doc_ids (they must stay in sync)."""
        _CACHE.parent.mkdir(parents=True, exist_ok=True)
        payload = {"bm25": self._bm25, "doc_ids": self._doc_ids}
        with open(_CACHE, "wb") as fh:
            pickle.dump(payload, fh, protocol=pickle.HIGHEST_PROTOCOL)
        log.debug("BM25 index saved to %s", _CACHE)

    def _load(self) -> None:
        """Restore index + doc_ids from the pickle cache (trusted local file)."""
        with open(_CACHE, "rb") as fh:
            payload = pickle.load(fh)
        self._bm25 = payload["bm25"]
        self._doc_ids = payload["doc_ids"]
        log.info(
            "BM25 index loaded corpus_size=%d",
            len(self._doc_ids),
        )

    def search(self, query: str, top_k: int = SPARSE_TOP_K) -> list[dict]:
        """
        Score all documents against *query* and return the top-k.

        Parameters
        ----------
        query : str — raw query string (tokenised internally)
        top_k : int — number of candidates to return

        Returns
        -------
        list of dicts:
            {"doc_id": str, "score": float, "corpus_idx": int}
        Sorted by score descending.

        Raises
        ------
        RuntimeError
            If called before build().
        """
        if self._bm25 is None:
            raise RuntimeError("BM25Index.build() must be called before search()")

        tokens = _tokenise(query)
        log.debug("BM25 search: query_tokens=%s top_k=%d", tokens, top_k)

        scores = self._bm25.get_scores(tokens)  # shape (N,)
        # Full argsort keeps ordering deterministic; fine at corpus scale here.
        top_indices = np.argsort(scores)[::-1][:top_k].tolist()

        results = [
            {
                "doc_id": self._doc_ids[idx],
                "score": float(scores[idx]),
                "corpus_idx": idx,
            }
            for idx in top_indices
        ]
        log.debug(
            "BM25 returned %d hits (top score=%.4f)",
            len(results),
            results[0]["score"] if results else 0.0,
        )
        return results
133 changes: 133 additions & 0 deletions submission/checkpoint.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
# =============================================================================
# checkpoint.py — atomic JSON checkpoint manager
# =============================================================================
"""
State file (checkpoints/pipeline_state.json):

{
"phases": {
"datasets_downloaded": true,
"embeddings_computed": false,
...
},
"meta": {
"corpus_size": 12345,
"embed_model": "Qwen/Qwen3-Embedding-0.6B",
"datasets_downloaded_completed_at": "2025-06-10T12:34:56+00:00",
...
}
}

Writes are atomic: we write to a `.tmp` file then rename, so a crash
mid-write never corrupts the checkpoint.

Usage
-----
from checkpoint import Checkpoint

ck = Checkpoint()

if not ck.done("qdrant_indexed"):
do_indexing()
ck.mark_done("qdrant_indexed", corpus_size=12345)

ck.summary() # pretty-print all phase statuses

# Force a phase (and everything after) to re-run:
ck.reset_from("embeddings_computed")
"""

from __future__ import annotations

import json
import os
from datetime import datetime, timezone
from pathlib import Path

from config import CHECKPOINT_FILE, PHASES
from logger import get_logger

log = get_logger(__name__)


class Checkpoint:
    """
    Atomic JSON checkpoint manager.

    State lives at *path* as {"phases": {name: bool}, "meta": {...}}.
    Writes go through a tmp-file + os.replace rename, so a crash mid-write
    never corrupts the state file. A corrupt or hand-mangled file is
    detected on load and replaced with a fresh state (with a warning)
    instead of crashing the whole pipeline.
    """

    def __init__(self, path: Path = CHECKPOINT_FILE) -> None:
        self.path = Path(path)
        self.path.parent.mkdir(parents=True, exist_ok=True)
        self._state = self._load()
        log.debug("Checkpoint loaded from %s", self.path)

    # ── I/O ──────────────────────────────────────────────────────────────────
    def _fresh_state(self) -> dict:
        """Return a brand-new state with every phase pending."""
        return {"phases": {p: False for p in PHASES}, "meta": {}}

    def _load(self) -> dict:
        """
        Load state from disk, or return a fresh state.

        Recovers from a corrupt/truncated/hand-edited file (invalid JSON,
        missing "phases" key, or a non-dict top level) by starting fresh —
        a broken checkpoint must never brick the pipeline it exists to resume.
        """
        if self.path.exists():
            try:
                with open(self.path, encoding="utf-8") as fh:
                    state = json.load(fh)
                if not isinstance(state, dict) or not isinstance(
                    state.get("phases"), dict
                ):
                    raise ValueError("missing or invalid 'phases' section")
            except (json.JSONDecodeError, ValueError, OSError) as err:
                log.warning(
                    "Corrupt checkpoint at %s (%s) — starting fresh",
                    self.path,
                    err,
                )
                return self._fresh_state()

            # Back-fill any phases added to config after the file was created,
            # and guarantee the "meta" section exists for set_meta/mark_done.
            for phase in PHASES:
                state["phases"].setdefault(phase, False)
            state.setdefault("meta", {})
            log.debug("Loaded existing checkpoint: %s", state["phases"])
            return state

        log.info("No checkpoint found — starting fresh at %s", self.path)
        return self._fresh_state()

    def _save(self) -> None:
        """Atomic write: tmp file → rename."""
        tmp = self.path.with_suffix(".tmp")
        with open(tmp, "w", encoding="utf-8") as fh:
            json.dump(self._state, fh, indent=2)
        os.replace(tmp, self.path)  # atomic on POSIX and Windows
        log.debug("Checkpoint saved to %s", self.path)

    # ── Phase control ─────────────────────────────────────────────────────────
    def done(self, phase: str) -> bool:
        """Return True if *phase* has already been completed."""
        result = self._state["phases"].get(phase, False)
        log.debug("Phase '%s' done=%s", phase, result)
        return result

    def mark_done(self, phase: str, **meta) -> None:
        """Mark *phase* as completed and persist. Extra kwargs go into meta."""
        self._state["phases"][phase] = True
        ts = datetime.now(timezone.utc).isoformat()
        self._state["meta"][f"{phase}_completed_at"] = ts
        for k, v in meta.items():
            self._state["meta"][k] = v
        self._save()
        log.info("✓ Phase '%s' marked done (ts=%s)", phase, ts)

    def reset(self, phase: str) -> None:
        """Force a single phase to re-run on next execution."""
        self._state["phases"][phase] = False
        self._save()
        log.info("↺ Phase '%s' reset — will re-run", phase)

    def reset_from(self, phase: str) -> None:
        """Reset *phase* and every phase after it in PHASES order."""
        if phase not in PHASES:
            raise ValueError(f"Unknown phase '{phase}'. Valid: {PHASES}")
        idx = PHASES.index(phase)
        for p in PHASES[idx:]:
            self._state["phases"][p] = False
        self._save()
        log.info("↺ Phases reset from '%s' onwards: %s", phase, PHASES[idx:])

    # ── Metadata ──────────────────────────────────────────────────────────────
    def set_meta(self, key: str, value) -> None:
        """Store an arbitrary key/value in the meta section and persist."""
        self._state["meta"][key] = value
        self._save()
        log.debug("Meta set: %s = %r", key, value)

    def get_meta(self, key: str, default=None):
        """Return a meta value, or *default* if absent."""
        return self._state["meta"].get(key, default)

    # ── Diagnostics ───────────────────────────────────────────────────────────
    def summary(self) -> None:
        """Pretty-print the status (and completion time) of every phase."""
        print("\n── Checkpoint Summary ──────────────────────────────")
        for phase in PHASES:
            ts_key = f"{phase}_completed_at"
            ts = self._state["meta"].get(ts_key, "")
            status = "✓ done " if self._state["phases"].get(phase) else "✗ pending"
            suffix = f" @ {ts}" if ts else ""
            print(f"  {status} {phase}{suffix}")
        print("────────────────────────────────────────────────────\n")
55 changes: 55 additions & 0 deletions submission/config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
# =============================================================================
# config.py — central configuration for DevRev hybrid search
# =============================================================================
from __future__ import annotations
import os
from pathlib import Path

# Set via env-var: export HF_TOKEN="hf_..."
# Or pass directly to the pipeline via --hf-token CLI arg.
HF_TOKEN: str | None = os.getenv("HF_TOKEN")

# Embedder: Qwen3-Embedding-0.6B
# • sentence-transformers compatible
# • 1024-d, 32k context, instruction-aware
# • Use prompt_name="query" for queries; plain encode for documents
EMBED_MODEL_ID = "Qwen/Qwen3-Embedding-0.6B"
EMBED_DIM = 1024
EMBED_BATCH_SIZE = 32  # lower to 8-16 if GPU < 8 GB VRAM
EMBED_MAX_LENGTH = 512  # practical cap; model supports 32k but slow on CPU

# Reranker: Qwen3-Reranker-0.6B (matching family, instruction-aware)
RERANKER_MODEL_ID = "Qwen/Qwen3-Reranker-0.6B"
RERANKER_MAX_LEN = 512  # max tokens per (query, doc) pair fed to the reranker

QDRANT_PATH = "qdrant_storage"  # on-disk Qdrant storage directory
QDRANT_COLLECTION = "devrev_kb"  # collection holding the dense vectors

BM25_CACHE_PATH = "checkpoints/bm25_index.pkl"  # pickled BM25 index + doc_ids

# Hybrid-retrieval fan-out / fusion parameters.
DENSE_TOP_K = 50  # candidates pulled from Qdrant per query
SPARSE_TOP_K = 50  # candidates pulled from BM25 per query
RRF_K = 60  # RRF constant (standard; higher → less top-rank emphasis)
RERANK_POOL = 100  # top-N from RRF sent to cross-encoder
FINAL_TOP_K = 10  # results returned per query

HF_DATASET_ID = "devrev/search"  # HuggingFace dataset ID (downloaded to DATA_DIR)
DATA_DIR = Path("data")  # local cache of the downloaded dataset
RESULTS_DIR = Path("results")  # per-query result JSON output

EVAL_K = 10  # Recall@K and Precision@K

CHECKPOINT_FILE = Path("checkpoints/pipeline_state.json")  # see checkpoint.py

# Ordered phases — each is marked done atomically after it completes.
# Re-running the pipeline skips any phase already marked done.
PHASES = [
    "datasets_downloaded",  # HF → local parquet
    "embeddings_computed",  # corpus → .npy cache
    "qdrant_indexed",  # .npy → Qdrant collection
    "bm25_indexed",  # corpus → BM25 pickle
    "test_results_saved",  # test queries → results JSON
]

LOG_DIR = Path("logs")  # per-run log files live here
LOG_LEVEL = "DEBUG"  # DEBUG | INFO | WARNING
Loading
Loading