2 changes: 1 addition & 1 deletion pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

[project]
name = "lean-explore"
version = "1.0.0"
version = "1.1.0"
authors = [
{ name = "Justin Asher", email = "justinchadwickasher@gmail.com" },
]
235 changes: 126 additions & 109 deletions src/lean_explore/cli/data_commands.py
@@ -3,42 +3,26 @@
"""Manages local Lean Explore data toolchains.

Provides CLI commands to download, install, and clean data files (database,
FAISS index, etc.) from remote storage using Pooch for checksums and caching.
FAISS index, BM25 indexes, etc.) from remote storage.
"""

import logging
import shutil
from typing import TypedDict
from pathlib import Path

import pooch
import requests
import typer
from rich.console import Console
from rich.progress import (
BarColumn,
DownloadColumn,
Progress,
TextColumn,
TransferSpeedColumn,
)

from lean_explore.config import Config


class ManifestFileEntry(TypedDict):
"""A file entry in the manifest's toolchain version."""

remote_name: str
local_name: str
sha256: str


class ToolchainVersionInfo(TypedDict):
"""Version information for a specific toolchain in the manifest."""

assets_base_path_r2: str
files: list[ManifestFileEntry]


class Manifest(TypedDict):
"""Remote data manifest structure."""

default_toolchain: str
toolchains: dict[str, ToolchainVersionInfo]

logger = logging.getLogger(__name__)

app = typer.Typer(
@@ -48,64 +32,79 @@ class Manifest(TypedDict):
no_args_is_help=True,
)

# Files required for the search engine (relative to version directory)
REQUIRED_FILES: list[str] = [
"lean_explore.db",
"informalization_faiss.index",
"informalization_faiss_ids_map.json",
"bm25_ids_map.json",
]

# BM25 index directories and their contents
BM25_DIRECTORIES: dict[str, list[str]] = {
"bm25_name_raw": [
"data.csc.index.npy",
"indices.csc.index.npy",
"indptr.csc.index.npy",
"nonoccurrence_array.index.npy",
"params.index.json",
"vocab.index.json",
],
"bm25_name_spaced": [
"data.csc.index.npy",
"indices.csc.index.npy",
"indptr.csc.index.npy",
"nonoccurrence_array.index.npy",
"params.index.json",
"vocab.index.json",
],
}
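
The two constants above fully describe the on-disk layout the installer produces under a version's cache directory. As a quick illustration (not part of this diff), a small helper could enumerate every path a complete install is expected to contain; Config.CACHE_DIRECTORY comes from lean_explore.config as imported above.

# Illustrative sketch only: derive the full expected file list for one
# cached toolchain version from REQUIRED_FILES and BM25_DIRECTORIES.
def _expected_toolchain_paths(version_directory: Path) -> list[Path]:
    """Return all file paths a fully installed toolchain should provide."""
    paths = [version_directory / filename for filename in REQUIRED_FILES]
    for directory_name, directory_files in BM25_DIRECTORIES.items():
        paths.extend(
            version_directory / directory_name / filename
            for filename in directory_files
        )
    return paths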


def _get_console() -> Console:
"""Create a Rich console instance for output."""
return Console()


def _fetch_manifest() -> Manifest | None:
"""Fetches the remote data manifest.
def _fetch_latest_version() -> str:
"""Fetch the latest version identifier from remote storage.

Returns:
The manifest dictionary, or None if fetch fails.
The version string (e.g., "20260127_103630").

Raises:
ValueError: If the latest version cannot be fetched.
"""
console = _get_console()
latest_url = f"{Config.R2_ASSETS_BASE_URL}/assets/latest.txt"
try:
response = requests.get(Config.MANIFEST_URL, timeout=10)
response = requests.get(latest_url, timeout=10)
response.raise_for_status()
return response.json()
return response.text.strip()
except requests.exceptions.RequestException as error:
logger.error("Failed to fetch manifest: %s", error)
console.print(f"[bold red]Error fetching manifest: {error}[/bold red]")
return None
logger.error("Failed to fetch latest version: %s", error)
raise ValueError(f"Failed to fetch latest version: {error}") from error


def _resolve_version(manifest: Manifest, version: str | None) -> str:
"""Resolves the version string to an actual toolchain version.
def _download_file(url: str, destination: Path, progress: Progress) -> None:
"""Download a file with progress tracking.

Args:
manifest: The manifest dictionary containing toolchain information.
version: The requested version, or None/"stable" for default.

Returns:
The resolved version string.

Raises:
ValueError: If the version cannot be resolved.
url: The URL to download from.
destination: The local path to save the file.
progress: Rich progress instance for tracking.
"""
if not version or version.lower() == "stable":
resolved = manifest.get("default_toolchain")
if not resolved:
raise ValueError("No default_toolchain specified in manifest")
return resolved
return version
destination.parent.mkdir(parents=True, exist_ok=True)

response = requests.get(url, stream=True, timeout=300)
response.raise_for_status()

def _build_file_registry(version_info: ToolchainVersionInfo) -> dict[str, str]:
"""Builds a Pooch registry from version info.
total_size = int(response.headers.get("content-length", 0))
task_id = progress.add_task(destination.name, total=total_size)

Args:
version_info: The version information from the manifest.

Returns:
A dictionary mapping remote filenames to SHA256 checksums.
"""
return {
file_entry["remote_name"]: f"sha256:{file_entry['sha256']}"
for file_entry in version_info.get("files", [])
if file_entry.get("remote_name") and file_entry.get("sha256")
}
with open(destination, "wb") as file:
for chunk in response.iter_content(chunk_size=8192):
file.write(chunk)
progress.update(task_id, advance=len(chunk))
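
For reference, a minimal usage sketch of this helper (the URL and destination below are placeholders, not real project assets), wired to the same Rich progress columns that _install_toolchain configures further down:

# Hypothetical usage sketch of _download_file; the example.com URL and the
# /tmp path are stand-ins for illustration only.
def _example_single_download() -> None:
    example_url = "https://example.com/assets/lean_explore.db"  # placeholder
    example_destination = Path("/tmp/lean_explore_example.db")  # placeholder
    with Progress(
        TextColumn("[bold blue]{task.description}"),
        BarColumn(),
        DownloadColumn(),
        TransferSpeedColumn(),
        console=_get_console(),
    ) as progress:
        _download_file(example_url, example_destination, progress)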


def _write_active_version(version: str) -> None:
@@ -139,53 +138,64 @@ def _cleanup_old_versions(current_version: str) -> None:


def _install_toolchain(version: str | None = None) -> None:
"""Installs the data toolchain for the specified version.
"""Install the data toolchain for the specified version.

Downloads and verifies all required data files (database, FAISS index, etc.)
using Pooch. Files are automatically decompressed and cached locally.
After successful installation, sets this version as the active version.
Downloads all required data files (database, FAISS index, BM25 indexes)
from remote storage. After successful installation, sets this version
as the active version and cleans up old versions.

Args:
version: The version to install. If None, uses the default version.
version: The version to install. If None, fetches the latest version.

Raises:
ValueError: If manifest fetch fails or version is not found.
ValueError: If version fetch fails or download errors occur.
"""
console = _get_console()

manifest = _fetch_manifest()
if not manifest:
raise ValueError("Failed to fetch manifest")

resolved_version = _resolve_version(manifest, version)
version_info = manifest.get("toolchains", {}).get(resolved_version)
if not version_info:
available = list(manifest.get("toolchains", {}).keys())
raise ValueError(
f"Version '{resolved_version}' not found. Available: {available}"
)

registry = _build_file_registry(version_info)
base_path = version_info.get("assets_base_path_r2", "")
base_url = f"{Config.R2_ASSETS_BASE_URL}/{base_path}/"

file_downloader = pooch.create(
path=Config.CACHE_DIRECTORY / resolved_version,
base_url=base_url,
registry=registry,
)

# Download and decompress each file
for file_entry in version_info.get("files", []):
remote_name = file_entry.get("remote_name")
local_name = file_entry.get("local_name")
if remote_name and local_name:
logger.info("Downloading %s -> %s", remote_name, local_name)
file_downloader.fetch(
remote_name, processor=pooch.Decompress(name=local_name)
)

# Set this version as the active version and clean up old versions
if version:
resolved_version = version
else:
console.print("Fetching latest version...")
resolved_version = _fetch_latest_version()

console.print(f"Installing version: [bold]{resolved_version}[/bold]")

base_url = f"{Config.R2_ASSETS_BASE_URL}/assets/{resolved_version}"
cache_path = Config.CACHE_DIRECTORY / resolved_version

# Build list of all files to download
files_to_download: list[tuple[str, Path]] = []

for filename in REQUIRED_FILES:
url = f"{base_url}/{filename}"
destination = cache_path / filename
files_to_download.append((url, destination))

for directory_name, directory_files in BM25_DIRECTORIES.items():
for filename in directory_files:
url = f"{base_url}/{directory_name}/{filename}"
destination = cache_path / directory_name / filename
files_to_download.append((url, destination))

# Download all files with progress
with Progress(
TextColumn("[bold blue]{task.description}"),
BarColumn(),
DownloadColumn(),
TransferSpeedColumn(),
console=console,
) as progress:
for url, destination in files_to_download:
if destination.exists():
logger.info("Skipping existing file: %s", destination.name)
continue
try:
_download_file(url, destination, progress)
except requests.exceptions.RequestException as error:
logger.error("Failed to download %s: %s", url, error)
raise ValueError(f"Failed to download {url}: {error}") from error

# Set this version as active and clean up old versions
_write_active_version(resolved_version)
_cleanup_old_versions(resolved_version)
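
As a usage note (not part of the diff), the version argument decides whether the assets/latest.txt lookup happens at all; the timestamp below is just the example string used in the docstrings and help text.

# Hedged usage sketch: pin a specific toolchain, or resolve the latest one.
_install_toolchain("20260127_103630")  # example version; skips the latest.txt lookup
_install_toolchain(None)               # queries assets/latest.txt, then installs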

@@ -208,29 +218,36 @@ def fetch(
None,
"--version",
"-v",
help="Version to install (e.g., '0.1.0'). Defaults to stable/latest.",
help="Version to install (e.g., '20260127_103630'). Defaults to latest.",
),
) -> None:
"""Fetches and installs the data toolchain from the remote repository.
"""Fetch and install the data toolchain from remote storage.

Downloads the database, FAISS index, and other required data files.
Files are verified with SHA256 checksums and automatically decompressed.
Downloads the database, FAISS index, and BM25 indexes required for
local search. Automatically cleans up old cached versions.
"""
_install_toolchain(version)
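
A hedged end-to-end sketch using Typer's test runner; it assumes this command is registered on the app above under the name "fetch" (the decorator sits outside this hunk) and reuses the example version string from the help text.

# Hypothetical invocation via Typer's CliRunner; the command name "fetch"
# and the version string are assumptions for illustration.
from typer.testing import CliRunner

runner = CliRunner()
result = runner.invoke(app, ["fetch", "--version", "20260127_103630"])
assert result.exit_code == 0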


@app.command("clean")
def clean_data_toolchains() -> None:
"""Removes all downloaded local data toolchains."""
"""Remove all downloaded local data toolchains."""
console = _get_console()

if not Config.CACHE_DIRECTORY.exists():
cache_exists = Config.CACHE_DIRECTORY.exists()
version_file = Config.CACHE_DIRECTORY.parent / "active_version"
version_exists = version_file.exists()

if not cache_exists and not version_exists:
console.print("[yellow]No local data found to clean.[/yellow]")
return

if typer.confirm("Delete all cached data?", default=False, abort=True):
try:
shutil.rmtree(Config.CACHE_DIRECTORY)
if cache_exists:
shutil.rmtree(Config.CACHE_DIRECTORY)
if version_exists:
version_file.unlink()
console.print("[green]Data cache cleared.[/green]")
except OSError as error:
logger.error("Failed to clean cache directory: %s", error)