diff --git a/pyproject.toml b/pyproject.toml index aaa714c..888645e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "lean-explore" -version = "1.0.0" +version = "1.1.0" authors = [ { name = "Justin Asher", email = "justinchadwickasher@gmail.com" }, ] diff --git a/src/lean_explore/cli/data_commands.py b/src/lean_explore/cli/data_commands.py index e48a2f0..0624957 100644 --- a/src/lean_explore/cli/data_commands.py +++ b/src/lean_explore/cli/data_commands.py @@ -3,42 +3,26 @@ """Manages local Lean Explore data toolchains. Provides CLI commands to download, install, and clean data files (database, -FAISS index, etc.) from remote storage using Pooch for checksums and caching. +FAISS index, BM25 indexes, etc.) from remote storage. """ import logging import shutil -from typing import TypedDict +from pathlib import Path -import pooch import requests import typer from rich.console import Console +from rich.progress import ( + BarColumn, + DownloadColumn, + Progress, + TextColumn, + TransferSpeedColumn, +) from lean_explore.config import Config - -class ManifestFileEntry(TypedDict): - """A file entry in the manifest's toolchain version.""" - - remote_name: str - local_name: str - sha256: str - - -class ToolchainVersionInfo(TypedDict): - """Version information for a specific toolchain in the manifest.""" - - assets_base_path_r2: str - files: list[ManifestFileEntry] - - -class Manifest(TypedDict): - """Remote data manifest structure.""" - - default_toolchain: str - toolchains: dict[str, ToolchainVersionInfo] - logger = logging.getLogger(__name__) app = typer.Typer( @@ -48,64 +32,79 @@ class Manifest(TypedDict): no_args_is_help=True, ) +# Files required for the search engine (relative to version directory) +REQUIRED_FILES: list[str] = [ + "lean_explore.db", + "informalization_faiss.index", + "informalization_faiss_ids_map.json", + "bm25_ids_map.json", +] + +# BM25 index directories and their contents +BM25_DIRECTORIES: dict[str, list[str]] = { + "bm25_name_raw": [ + "data.csc.index.npy", + "indices.csc.index.npy", + "indptr.csc.index.npy", + "nonoccurrence_array.index.npy", + "params.index.json", + "vocab.index.json", + ], + "bm25_name_spaced": [ + "data.csc.index.npy", + "indices.csc.index.npy", + "indptr.csc.index.npy", + "nonoccurrence_array.index.npy", + "params.index.json", + "vocab.index.json", + ], +} + def _get_console() -> Console: """Create a Rich console instance for output.""" return Console() -def _fetch_manifest() -> Manifest | None: - """Fetches the remote data manifest. +def _fetch_latest_version() -> str: + """Fetch the latest version identifier from remote storage. Returns: - The manifest dictionary, or None if fetch fails. + The version string (e.g., "20260127_103630"). + + Raises: + ValueError: If the latest version cannot be fetched. """ - console = _get_console() + latest_url = f"{Config.R2_ASSETS_BASE_URL}/assets/latest.txt" try: - response = requests.get(Config.MANIFEST_URL, timeout=10) + response = requests.get(latest_url, timeout=10) response.raise_for_status() - return response.json() + return response.text.strip() except requests.exceptions.RequestException as error: - logger.error("Failed to fetch manifest: %s", error) - console.print(f"[bold red]Error fetching manifest: {error}[/bold red]") - return None + logger.error("Failed to fetch latest version: %s", error) + raise ValueError(f"Failed to fetch latest version: {error}") from error -def _resolve_version(manifest: Manifest, version: str | None) -> str: - """Resolves the version string to an actual toolchain version. +def _download_file(url: str, destination: Path, progress: Progress) -> None: + """Download a file with progress tracking. Args: - manifest: The manifest dictionary containing toolchain information. - version: The requested version, or None/"stable" for default. - - Returns: - The resolved version string. - - Raises: - ValueError: If the version cannot be resolved. + url: The URL to download from. + destination: The local path to save the file. + progress: Rich progress instance for tracking. """ - if not version or version.lower() == "stable": - resolved = manifest.get("default_toolchain") - if not resolved: - raise ValueError("No default_toolchain specified in manifest") - return resolved - return version + destination.parent.mkdir(parents=True, exist_ok=True) + response = requests.get(url, stream=True, timeout=300) + response.raise_for_status() -def _build_file_registry(version_info: ToolchainVersionInfo) -> dict[str, str]: - """Builds a Pooch registry from version info. + total_size = int(response.headers.get("content-length", 0)) + task_id = progress.add_task(destination.name, total=total_size) - Args: - version_info: The version information from the manifest. - - Returns: - A dictionary mapping remote filenames to SHA256 checksums. - """ - return { - file_entry["remote_name"]: f"sha256:{file_entry['sha256']}" - for file_entry in version_info.get("files", []) - if file_entry.get("remote_name") and file_entry.get("sha256") - } + with open(destination, "wb") as file: + for chunk in response.iter_content(chunk_size=8192): + file.write(chunk) + progress.update(task_id, advance=len(chunk)) def _write_active_version(version: str) -> None: @@ -139,53 +138,64 @@ def _cleanup_old_versions(current_version: str) -> None: def _install_toolchain(version: str | None = None) -> None: - """Installs the data toolchain for the specified version. + """Install the data toolchain for the specified version. - Downloads and verifies all required data files (database, FAISS index, etc.) - using Pooch. Files are automatically decompressed and cached locally. - After successful installation, sets this version as the active version. + Downloads all required data files (database, FAISS index, BM25 indexes) + from remote storage. After successful installation, sets this version + as the active version and cleans up old versions. Args: - version: The version to install. If None, uses the default version. + version: The version to install. If None, fetches the latest version. Raises: - ValueError: If manifest fetch fails or version is not found. + ValueError: If version fetch fails or download errors occur. """ console = _get_console() - manifest = _fetch_manifest() - if not manifest: - raise ValueError("Failed to fetch manifest") - - resolved_version = _resolve_version(manifest, version) - version_info = manifest.get("toolchains", {}).get(resolved_version) - if not version_info: - available = list(manifest.get("toolchains", {}).keys()) - raise ValueError( - f"Version '{resolved_version}' not found. Available: {available}" - ) - - registry = _build_file_registry(version_info) - base_path = version_info.get("assets_base_path_r2", "") - base_url = f"{Config.R2_ASSETS_BASE_URL}/{base_path}/" - - file_downloader = pooch.create( - path=Config.CACHE_DIRECTORY / resolved_version, - base_url=base_url, - registry=registry, - ) - - # Download and decompress each file - for file_entry in version_info.get("files", []): - remote_name = file_entry.get("remote_name") - local_name = file_entry.get("local_name") - if remote_name and local_name: - logger.info("Downloading %s -> %s", remote_name, local_name) - file_downloader.fetch( - remote_name, processor=pooch.Decompress(name=local_name) - ) - - # Set this version as the active version and clean up old versions + if version: + resolved_version = version + else: + console.print("Fetching latest version...") + resolved_version = _fetch_latest_version() + + console.print(f"Installing version: [bold]{resolved_version}[/bold]") + + base_url = f"{Config.R2_ASSETS_BASE_URL}/assets/{resolved_version}" + cache_path = Config.CACHE_DIRECTORY / resolved_version + + # Build list of all files to download + files_to_download: list[tuple[str, Path]] = [] + + for filename in REQUIRED_FILES: + url = f"{base_url}/{filename}" + destination = cache_path / filename + files_to_download.append((url, destination)) + + for directory_name, directory_files in BM25_DIRECTORIES.items(): + for filename in directory_files: + url = f"{base_url}/{directory_name}/{filename}" + destination = cache_path / directory_name / filename + files_to_download.append((url, destination)) + + # Download all files with progress + with Progress( + TextColumn("[bold blue]{task.description}"), + BarColumn(), + DownloadColumn(), + TransferSpeedColumn(), + console=console, + ) as progress: + for url, destination in files_to_download: + if destination.exists(): + logger.info("Skipping existing file: %s", destination.name) + continue + try: + _download_file(url, destination, progress) + except requests.exceptions.RequestException as error: + logger.error("Failed to download %s: %s", url, error) + raise ValueError(f"Failed to download {url}: {error}") from error + + # Set this version as active and clean up old versions _write_active_version(resolved_version) _cleanup_old_versions(resolved_version) @@ -208,29 +218,36 @@ def fetch( None, "--version", "-v", - help="Version to install (e.g., '0.1.0'). Defaults to stable/latest.", + help="Version to install (e.g., '20260127_103630'). Defaults to latest.", ), ) -> None: - """Fetches and installs the data toolchain from the remote repository. + """Fetch and install the data toolchain from remote storage. - Downloads the database, FAISS index, and other required data files. - Files are verified with SHA256 checksums and automatically decompressed. + Downloads the database, FAISS index, and BM25 indexes required for + local search. Automatically cleans up old cached versions. """ _install_toolchain(version) @app.command("clean") def clean_data_toolchains() -> None: - """Removes all downloaded local data toolchains.""" + """Remove all downloaded local data toolchains.""" console = _get_console() - if not Config.CACHE_DIRECTORY.exists(): + cache_exists = Config.CACHE_DIRECTORY.exists() + version_file = Config.CACHE_DIRECTORY.parent / "active_version" + version_exists = version_file.exists() + + if not cache_exists and not version_exists: console.print("[yellow]No local data found to clean.[/yellow]") return if typer.confirm("Delete all cached data?", default=False, abort=True): try: - shutil.rmtree(Config.CACHE_DIRECTORY) + if cache_exists: + shutil.rmtree(Config.CACHE_DIRECTORY) + if version_exists: + version_file.unlink() console.print("[green]Data cache cleared.[/green]") except OSError as error: logger.error("Failed to clean cache directory: %s", error) diff --git a/tests/cli/data_commands_test.py b/tests/cli/data_commands_test.py index 5074925..fed368c 100644 --- a/tests/cli/data_commands_test.py +++ b/tests/cli/data_commands_test.py @@ -11,11 +11,11 @@ from typer.testing import CliRunner from lean_explore.cli.data_commands import ( - _build_file_registry, - _fetch_manifest, + _cleanup_old_versions, + _fetch_latest_version, _get_console, _install_toolchain, - _resolve_version, + _write_active_version, app, ) @@ -31,202 +31,192 @@ def test_get_console_returns_console(self): assert console is not None -class TestFetchManifest: - """Tests for the _fetch_manifest function.""" +class TestFetchLatestVersion: + """Tests for the _fetch_latest_version function.""" - def test_fetch_manifest_success(self): - """Test successful manifest fetch.""" + def test_fetch_latest_version_success(self): + """Test successful latest version fetch.""" mock_response = MagicMock() - mock_response.json.return_value = { - "default_toolchain": "0.1.0", - "toolchains": {"0.1.0": {}}, - } + mock_response.text = "20260127_103630\n" mock_response.raise_for_status = MagicMock() with patch("requests.get", return_value=mock_response): - result = _fetch_manifest() - assert result is not None - assert result["default_toolchain"] == "0.1.0" + result = _fetch_latest_version() + assert result == "20260127_103630" - def test_fetch_manifest_network_error(self): - """Test manifest fetch with network error.""" + def test_fetch_latest_version_strips_whitespace(self): + """Test that version string is stripped of whitespace.""" + mock_response = MagicMock() + mock_response.text = " 20260127_103630 \n" + mock_response.raise_for_status = MagicMock() + + with patch("requests.get", return_value=mock_response): + result = _fetch_latest_version() + assert result == "20260127_103630" + + def test_fetch_latest_version_network_error(self): + """Test latest version fetch with network error.""" import requests with patch("requests.get", side_effect=requests.exceptions.ConnectionError()): - result = _fetch_manifest() - assert result is None + with pytest.raises(ValueError, match="Failed to fetch latest version"): + _fetch_latest_version() - def test_fetch_manifest_http_error(self): - """Test manifest fetch with HTTP error.""" + def test_fetch_latest_version_http_error(self): + """Test latest version fetch with HTTP error.""" import requests mock_response = MagicMock() mock_response.raise_for_status.side_effect = requests.exceptions.HTTPError() with patch("requests.get", return_value=mock_response): - result = _fetch_manifest() - assert result is None + with pytest.raises(ValueError, match="Failed to fetch latest version"): + _fetch_latest_version() - def test_fetch_manifest_timeout(self): - """Test manifest fetch with timeout.""" + def test_fetch_latest_version_timeout(self): + """Test latest version fetch with timeout.""" import requests with patch("requests.get", side_effect=requests.exceptions.Timeout()): - result = _fetch_manifest() - assert result is None - - -class TestResolveVersion: - """Tests for the _resolve_version function.""" - - def test_resolve_version_none_uses_default(self): - """Test that None version resolves to default.""" - manifest = {"default_toolchain": "1.0.0"} - result = _resolve_version(manifest, None) - assert result == "1.0.0" - - def test_resolve_version_stable_uses_default(self): - """Test that 'stable' resolves to default.""" - manifest = {"default_toolchain": "1.0.0"} - result = _resolve_version(manifest, "stable") - assert result == "1.0.0" - - def test_resolve_version_stable_case_insensitive(self): - """Test that 'STABLE' is case insensitive.""" - manifest = {"default_toolchain": "1.0.0"} - result = _resolve_version(manifest, "STABLE") - assert result == "1.0.0" - - def test_resolve_version_specific(self): - """Test that specific version is returned as-is.""" - manifest = {"default_toolchain": "1.0.0"} - result = _resolve_version(manifest, "2.0.0") - assert result == "2.0.0" - - def test_resolve_version_no_default(self): - """Test error when no default and version is None.""" - manifest = {} - with pytest.raises(ValueError, match="No default_toolchain"): - _resolve_version(manifest, None) - - -class TestBuildFileRegistry: - """Tests for the _build_file_registry function.""" - - def test_build_registry_with_valid_files(self): - """Test building registry with valid file entries.""" - version_info = { - "files": [ - {"remote_name": "file1.gz", "sha256": "abc123", "local_name": "file1"}, - {"remote_name": "file2.gz", "sha256": "def456", "local_name": "file2"}, - ] - } - result = _build_file_registry(version_info) - assert result == { - "file1.gz": "sha256:abc123", - "file2.gz": "sha256:def456", - } - - def test_build_registry_skips_incomplete_entries(self): - """Test that entries without remote_name or sha256 are skipped.""" - version_info = { - "files": [ - {"remote_name": "file1.gz", "sha256": "abc123"}, - {"remote_name": "file2.gz"}, # Missing sha256 - {"sha256": "def456"}, # Missing remote_name - ] - } - result = _build_file_registry(version_info) - assert result == {"file1.gz": "sha256:abc123"} - - def test_build_registry_empty_files(self): - """Test building registry with empty files list.""" - version_info = {"files": []} - result = _build_file_registry(version_info) - assert result == {} - - def test_build_registry_no_files_key(self): - """Test building registry when files key is missing.""" - version_info = {} - result = _build_file_registry(version_info) - assert result == {} + with pytest.raises(ValueError, match="Failed to fetch latest version"): + _fetch_latest_version() + + +class TestWriteActiveVersion: + """Tests for the _write_active_version function.""" + + def test_write_active_version_creates_file(self): + """Test that active version is written to file.""" + with tempfile.TemporaryDirectory() as tmpdir: + cache_dir = Path(tmpdir) / "cache" + version_file = Path(tmpdir) / "active_version" + + with patch( + "lean_explore.cli.data_commands.Config.CACHE_DIRECTORY", cache_dir + ): + _write_active_version("20260127_103630") + assert version_file.exists() + assert version_file.read_text() == "20260127_103630" + + def test_write_active_version_overwrites_existing(self): + """Test that active version file is overwritten.""" + with tempfile.TemporaryDirectory() as tmpdir: + cache_dir = Path(tmpdir) / "cache" + version_file = Path(tmpdir) / "active_version" + version_file.write_text("old_version") + + with patch( + "lean_explore.cli.data_commands.Config.CACHE_DIRECTORY", cache_dir + ): + _write_active_version("new_version") + assert version_file.read_text() == "new_version" + + +class TestCleanupOldVersions: + """Tests for the _cleanup_old_versions function.""" + + def test_cleanup_removes_old_versions(self): + """Test that old version directories are removed.""" + with tempfile.TemporaryDirectory() as tmpdir: + cache_dir = Path(tmpdir) + old_version = cache_dir / "old_version" + current_version = cache_dir / "current_version" + old_version.mkdir() + current_version.mkdir() + (old_version / "file.txt").touch() + (current_version / "file.txt").touch() + + with patch( + "lean_explore.cli.data_commands.Config.CACHE_DIRECTORY", cache_dir + ): + _cleanup_old_versions("current_version") + assert not old_version.exists() + assert current_version.exists() + + def test_cleanup_handles_nonexistent_cache(self): + """Test cleanup when cache directory doesn't exist.""" + with tempfile.TemporaryDirectory() as tmpdir: + nonexistent = Path(tmpdir) / "nonexistent" + with patch( + "lean_explore.cli.data_commands.Config.CACHE_DIRECTORY", nonexistent + ): + # Should not raise + _cleanup_old_versions("any_version") class TestInstallToolchain: """Tests for the _install_toolchain function.""" - def test_install_toolchain_manifest_fetch_fails(self): - """Test error when manifest fetch fails.""" + def test_install_toolchain_version_fetch_fails(self): + """Test error when latest version fetch fails.""" with patch( - "lean_explore.cli.data_commands._fetch_manifest", return_value=None + "lean_explore.cli.data_commands._fetch_latest_version", + side_effect=ValueError("Failed to fetch"), ): - with pytest.raises(ValueError, match="Failed to fetch manifest"): + with pytest.raises(ValueError, match="Failed to fetch"): _install_toolchain() - def test_install_toolchain_version_not_found(self): - """Test error when version is not in manifest.""" - manifest = { - "default_toolchain": "1.0.0", - "toolchains": {"1.0.0": {}}, - } - with patch( - "lean_explore.cli.data_commands._fetch_manifest", return_value=manifest - ): - with pytest.raises(ValueError, match="not found"): - _install_toolchain("2.0.0") - - def test_install_toolchain_shows_available_versions(self): - """Test that error message shows available versions.""" - manifest = { - "default_toolchain": "1.0.0", - "toolchains": {"1.0.0": {}, "1.1.0": {}}, - } - with patch( - "lean_explore.cli.data_commands._fetch_manifest", return_value=manifest - ): - with pytest.raises(ValueError) as exc_info: - _install_toolchain("2.0.0") - assert "1.0.0" in str(exc_info.value) or "Available" in str(exc_info.value) + def test_install_toolchain_with_explicit_version(self): + """Test that explicit version skips latest fetch.""" + import requests as requests_module - @pytest.mark.integration - def test_install_toolchain_success(self): - """Test successful toolchain installation.""" - manifest = { - "default_toolchain": "1.0.0", - "toolchains": { - "1.0.0": { - "assets_base_path_r2": "v1", - "files": [ - { - "remote_name": "test.gz", - "sha256": "abc123", - "local_name": "test", - } - ], - } - }, - } - - mock_pooch = MagicMock() - mock_pooch.fetch = MagicMock() + with tempfile.TemporaryDirectory() as tmpdir: + cache_dir = Path(tmpdir) / "cache" + + # Mock the download to always fail so we can verify version is used + with ( + patch( + "lean_explore.cli.data_commands.Config.CACHE_DIRECTORY", cache_dir + ), + patch( + "lean_explore.cli.data_commands._fetch_latest_version" + ) as mock_fetch, + patch("requests.get") as mock_get, + ): + mock_response = MagicMock() + mock_response.raise_for_status.side_effect = ( + requests_module.exceptions.HTTPError("Download fail") + ) + mock_get.return_value = mock_response + + with pytest.raises(ValueError): + _install_toolchain("explicit_version") + # Should not call fetch latest when version is explicit + mock_fetch.assert_not_called() + + @pytest.mark.integration + def test_install_toolchain_downloads_all_files(self): + """Test that all required files are downloaded.""" with tempfile.TemporaryDirectory() as tmpdir: - manifest_patch = patch( - "lean_explore.cli.data_commands._fetch_manifest", - return_value=manifest, - ) - pooch_patch = patch("lean_explore.cli.data_commands.pooch.create") - config_patch = patch( - "lean_explore.cli.data_commands.Config.DATA_DIRECTORY", - Path(tmpdir), - ) - with manifest_patch, pooch_patch as mock_create, config_patch: - mock_create.return_value = mock_pooch + cache_dir = Path(tmpdir) / "cache" - _install_toolchain("1.0.0") + downloaded_urls = [] + + def mock_get(url, **kwargs): + downloaded_urls.append(url) + response = MagicMock() + response.headers = {"content-length": "100"} + response.iter_content = MagicMock(return_value=[b"data"]) + return response + + with ( + patch( + "lean_explore.cli.data_commands.Config.CACHE_DIRECTORY", cache_dir + ), + patch( + "lean_explore.cli.data_commands._fetch_latest_version", + return_value="test_version", + ), + patch("requests.get", side_effect=mock_get), + ): + _install_toolchain() - mock_create.assert_called_once() - mock_pooch.fetch.assert_called_once() + # Verify key files were requested + url_paths = [url.split("/")[-1] for url in downloaded_urls] + assert "lean_explore.db" in url_paths + assert "informalization_faiss.index" in url_paths + assert "bm25_ids_map.json" in url_paths class TestFetchCommand: @@ -240,19 +230,15 @@ def test_fetch_command_help(self): def test_fetch_command_calls_install(self): """Test that fetch command calls _install_toolchain.""" - with patch( - "lean_explore.cli.data_commands._install_toolchain" - ) as mock_install: + with patch("lean_explore.cli.data_commands._install_toolchain") as mock_install: runner.invoke(app, ["fetch"]) mock_install.assert_called_once_with(None) def test_fetch_command_with_version(self): """Test fetch command with specific version.""" - with patch( - "lean_explore.cli.data_commands._install_toolchain" - ) as mock_install: - runner.invoke(app, ["fetch", "--version", "1.0.0"]) - mock_install.assert_called_once_with("1.0.0") + with patch("lean_explore.cli.data_commands._install_toolchain") as mock_install: + runner.invoke(app, ["fetch", "--version", "20260127_103630"]) + mock_install.assert_called_once_with("20260127_103630") class TestCleanCommand: @@ -300,6 +286,21 @@ def test_clean_command_confirmed(self): assert not cache_dir.exists() assert "cleared" in result.output.lower() + def test_clean_command_removes_version_file(self): + """Test that clean also removes the active_version file.""" + with tempfile.TemporaryDirectory() as tmpdir: + cache_dir = Path(tmpdir) / "cache" + cache_dir.mkdir() + version_file = Path(tmpdir) / "active_version" + version_file.write_text("some_version") + + with patch( + "lean_explore.cli.data_commands.Config.CACHE_DIRECTORY", cache_dir + ): + result = runner.invoke(app, ["clean"], input="y\n") + assert not version_file.exists() + assert "cleared" in result.output.lower() + class TestDataApp: """Tests for the data app structure."""