diff --git a/.claude/settings.local.json b/.claude/settings.local.json new file mode 100644 index 0000000..0bf060e --- /dev/null +++ b/.claude/settings.local.json @@ -0,0 +1,10 @@ +{ + "permissions": { + "allow": [ + "Bash(du -sh:*)", + "Bash(du:*)" + ], + "deny": [], + "ask": [] + } +} diff --git a/.github/workflows/testing.yml b/.github/workflows/testing.yml index 45d1436..7ea5775 100644 --- a/.github/workflows/testing.yml +++ b/.github/workflows/testing.yml @@ -13,7 +13,7 @@ jobs: fail-fast: false matrix: # match ubuntu-22.04 - python-version: ['3.10'] + python-version: ["3.9", "3.10", "3.11"] env: OS: ubuntu-22.04 @@ -22,10 +22,26 @@ jobs: steps: # Checkout and env setup - - name: Checkout code - uses: nschloe/action-cached-lfs-checkout@v1.1.3 + - uses: actions/checkout@v5 with: - exclude: "scoutbot/*/models/pytorch/" + lfs: false # Skip LFS files - models are fetched from CDN during tests + fetch-depth: 0 + + # Free up disk space on GitHub runner + - name: Free Disk Space + run: | + echo "Disk usage before cleanup:" + df -h + + # Remove unnecessary large packages + sudo rm -rf /usr/share/dotnet + sudo rm -rf /opt/ghc + sudo rm -rf "/usr/local/share/boost" + sudo rm -rf "$AGENT_TOOLSDIRECTORY" + sudo apt-get clean + + echo "Disk usage after cleanup:" + df -h - name: Set up Python ${{ matrix.python-version }} uses: actions/setup-python@v5 diff --git a/docs/environment.rst b/docs/environment.rst index 50e8cea..1cac696 100644 --- a/docs/environment.rst +++ b/docs/environment.rst @@ -1,22 +1,22 @@ Environment Variables --------------------- -The Scoutbot API and CLI have two environment variables (envars) that allow you to configure global settings +The Scoutbot API and CLI have environment variables (envars) that allow you to configure global settings and configurations. - ``CONFIG`` (default: mvp) - The configuration setting for which machine lerning models to use. + The configuration setting for which machine learning models to use. 
Must be one of ``phase1`` or ``mvp``, or their respective aliases as ``old`` or ``new``. - ``WIC_CONFIG`` (default: not set) - The configuration setting for which machine lerning models to use with the WIC. + The configuration setting for which machine learning models to use with the WIC. Must be one of ``phase1`` or ``mvp``, or their respective aliases as ``old`` or ``new``. Defaults to the value of the ``CONFIG`` environment variable. - ``LOC_CONFIG`` (default: not set) - The configuration setting for which machine lerning models to use with the LOC. + The configuration setting for which machine learning models to use with the LOC. Must be one of ``phase1`` or ``mvp``, or their respective aliases as ``old`` or ``new``. Defaults to the value of the ``CONFIG`` environment variable. - ``AGG_CONFIG`` (default: not set) - The configuration setting for which machine lerning models to use with the AGG. + The configuration setting for which machine learning models to use with the AGG. Must be one of ``phase1`` or ``mvp``, or their respective aliases as ``old`` or ``new``. Defaults to the value of the ``CONFIG`` environment variable. - ``WIC_BATCH_SIZE`` (default: 160) @@ -34,3 +34,23 @@ and configurations. A verbosity flag that can be set to turn on debug logging. Defaults to "not set", which translates to no debug logging. Setting this value to anything will turn on debug logging (e.g., ``VERBOSE=1``). + - ``SCOUTBOT_MODEL_URL`` (default: https://wildbookiarepository.azureedge.net/models) + The base URL or path for accessing ONNX model files. This can be: + + * An HTTP/HTTPS URL (e.g., ``https://example.com/models``) + * A local file path (e.g., ``/opt/models`` or ``C:\models``) + * A network path (e.g., ``//server/share/models`` or ``\\server\share\models``) + + When using file paths, models will be copied to a local cache for consistency. + This allows organizations to host models on their own infrastructure, CDN, or + local/network storage. 
+ - ``SCOUTBOT_DATA_URL`` (default: https://wildbookiarepository.azureedge.net/data) + The base URL or path for downloading test data files. + + As with SCOUTBOT_MODEL_URL, this can be: + * An HTTP/HTTPS URL or + * A local file path or + * A network path + + This allows organizations to host test data on their own infrastructure, CDN, or + local/network storage. diff --git a/docs/onnx.rst b/docs/onnx.rst index ad67516..8f77704 100644 --- a/docs/onnx.rst +++ b/docs/onnx.rst @@ -7,6 +7,11 @@ All of the machine learning models are hosted on GitHub as LFS files. The two m however need those files downloaded to the local machine prior to running inference. These models are hosted on a separate CDN for convenient access and can be fetched by running the following functions: +.. note:: + + The model download URL can be configured via the ``SCOUTBOT_MODEL_URL`` environment variable. + See :doc:`environment` for details. + - :meth:`scoutbot.wic.fetch` - :meth:`scoutbot.loc.fetch` diff --git a/scoutbot/__init__.py b/scoutbot/__init__.py index b08779a..a716744 100644 --- a/scoutbot/__init__.py +++ b/scoutbot/__init__.py @@ -1,6 +1,6 @@ # -*- coding: utf-8 -*- ''' -The above components must be run in the correct order, but ScoutbBot also offers a single pipeline. +The above components must be run in the correct order, but Scoutbot also offers a single pipeline. 
All of the ML models can be pre-downloaded and fetched in a single call to :func:`scoutbot.fetch` and the unified pipeline -- which uses the 4 components correctly -- can be run by the function @@ -49,8 +49,12 @@ # nms_thresh=agg_nms_thresh, # Optional override of config ) ''' +import os import cv2 +import shutil from os.path import exists +from urllib.parse import urlparse +from pathlib import Path import pooch import utool as ut @@ -60,18 +64,115 @@ log = utils.init_logging() QUIET = not utils.VERBOSE +# Base URLs for downloading models and data (can be overridden via environment variables) +MODEL_BASE_URL = (os.getenv('SCOUTBOT_MODEL_URL') or 'https://wildbookiarepository.azureedge.net/models').rstrip('/') +DATA_BASE_URL = (os.getenv('SCOUTBOT_DATA_URL') or 'https://wildbookiarepository.azureedge.net/data').rstrip('/') + +# Validate MODEL_BASE_URL +parsed = urlparse(MODEL_BASE_URL) +if parsed.scheme in ('http', 'https', 'ftp'): + # It's a URL, no further validation needed + pass +elif parsed.scheme and parsed.scheme not in ('file', ''): + raise ValueError(f"Unrecognized scheme in SCOUTBOT_MODEL_URL: {parsed.scheme}") +else: + # It's a path - check if it exists + model_path = Path(MODEL_BASE_URL) + if not model_path.exists(): + raise FileNotFoundError(f"SCOUTBOT_MODEL_URL path does not exist: {MODEL_BASE_URL}") + +# Validate DATA_BASE_URL +parsed = urlparse(DATA_BASE_URL) +if parsed.scheme in ('http', 'https', 'ftp'): + # It's a URL, no further validation needed + pass +elif parsed.scheme and parsed.scheme not in ('file', ''): + raise ValueError(f"Unrecognized scheme in SCOUTBOT_DATA_URL: {parsed.scheme}") +else: + # It's a path - check if it exists + data_path = Path(DATA_BASE_URL) + if not data_path.exists(): + raise FileNotFoundError(f"SCOUTBOT_DATA_URL path does not exist: {DATA_BASE_URL}") from scoutbot import agg, loc, tile, wic, tile_batched # NOQA -from scoutbot.loc import CONFIGS as LOC_CONFIGS # NOQA +from scoutbot.loc import CONFIGS as LOC_CONFIGS # NOQA # 
from tile_batched.models import Yolov8DetectionModel # from tile_batched import get_sliced_prediction_batched -VERSION = '0.1.18' +VERSION = '2.4.2' version = VERSION __version__ = VERSION +def get_resource_from_source(resource_name, resource_hash, source_base, resource_type='model', use_cache=True): + """ + Retrieve a resource (model or data) from URL, local path, or network path. + + Args: + resource_name: Name of the resource file + resource_hash: Expected hash of the resource (used for URL downloads) + source_base: Base URL or path (from MODEL_BASE_URL or DATA_BASE_URL) + resource_type: Type of resource ('model' or 'data') for cache organization + use_cache: Whether to use cached version for URLs + + Returns: + str: Path to the resource file + + Raises: + FileNotFoundError: If the resource file doesn't exist at the specified path + """ + # Parse the source to determine if it's a URL or path + parsed = urlparse(source_base) + + if parsed.scheme in ('http', 'https', 'ftp'): + # It's a URL - use pooch as before + resource_url = f'{source_base}/{resource_name}' + return pooch.retrieve( + url=resource_url, + known_hash=resource_hash, + progressbar=not QUIET, + ) + else: + # It's a local or network path + if source_base.startswith('file://'): + source_base = source_base[7:] + + source_path = Path(source_base) / resource_name + + if not source_path.exists(): + raise FileNotFoundError(f"{resource_type.capitalize()} file not found: {source_path}") + + if use_cache: + # Copy to cache directory to maintain consistency + cache_dir = Path(pooch.os_cache(f"scoutbot/{resource_type}s")) + cache_dir.mkdir(parents=True, exist_ok=True) + cache_path = cache_dir / resource_name + + if source_path.resolve() == cache_path.resolve(): + return str(source_path) + + # Only copy if not already cached or file has changed + if not cache_path.exists() or os.path.getmtime(source_path) > os.path.getmtime(cache_path): + log.debug(f"Copying {resource_type} from {source_path} to cache 
{cache_path}") + shutil.copy2(source_path, cache_path) + + return str(cache_path) + else: + # Use directly from source + return str(source_path) + + +def get_model_from_source(model_name, model_hash, source_base, use_cache=True): + """Backward compatibility wrapper for get_resource_from_source""" + return get_resource_from_source(model_name, model_hash, source_base, 'model', use_cache) + + +def get_data_from_source(data_name, data_hash, source_base, use_cache=True): + """Convenience wrapper for get_resource_from_source for data files""" + return get_resource_from_source(data_name, data_hash, source_base, 'data', use_cache) + + def fetch(pull=False, config=None): """ Fetch the WIC and Localizer ONNX model files from a CDN if they do not exist locally. @@ -99,15 +200,15 @@ def fetch(pull=False, config=None): def pipeline( - filepath, - config=None, - backend_device='cuda:0', - wic_thresh=wic.CONFIGS[None]['thresh'], - loc_thresh=loc.CONFIGS[None]['thresh'], - loc_nms_thresh=loc.CONFIGS[None]['nms'], - agg_thresh=agg.CONFIGS[None]['thresh'], - agg_nms_thresh=agg.CONFIGS[None]['nms'], - clean=True, + filepath, + config=None, + backend_device='cuda:0', + wic_thresh=wic.CONFIGS[None]['thresh'], + loc_thresh=loc.CONFIGS[None]['thresh'], + loc_nms_thresh=loc.CONFIGS[None]['nms'], + agg_thresh=agg.CONFIGS[None]['thresh'], + agg_nms_thresh=agg.CONFIGS[None]['nms'], + clean=True, ): """ Run the ML pipeline on a given image filepath and return the detections @@ -193,17 +294,17 @@ def pipeline( def pipeline_v3( - filepath, - config, - batched_detection_model=None, - backend_device='cuda:0', - loc_thresh=0.45, - slice_height=512, - slice_width=512, - overlap_height_ratio=0.25, - overlap_width_ratio=0.25, - perform_standard_pred=False, - postprocess_class_agnostic=True, + filepath, + config, + batched_detection_model=None, + backend_device='cuda:0', + loc_thresh=0.45, + slice_height=512, + slice_width=512, + overlap_height_ratio=0.25, + overlap_width_ratio=0.25, + 
perform_standard_pred=False, + postprocess_class_agnostic=True, ): """ Run the ML pipeline on a given image filepath and return the detections @@ -274,15 +375,15 @@ def pipeline_v3( def batch( - filepaths, - config=None, - backend_device='cuda:0', - wic_thresh=wic.CONFIGS[None]['thresh'], - loc_thresh=loc.CONFIGS[None]['thresh'], - loc_nms_thresh=loc.CONFIGS[None]['nms'], - agg_thresh=agg.CONFIGS[None]['thresh'], - agg_nms_thresh=agg.CONFIGS[None]['nms'], - clean=True, + filepaths, + config=None, + backend_device='cuda:0', + wic_thresh=wic.CONFIGS[None]['thresh'], + loc_thresh=loc.CONFIGS[None]['thresh'], + loc_nms_thresh=loc.CONFIGS[None]['nms'], + agg_thresh=agg.CONFIGS[None]['thresh'], + agg_nms_thresh=agg.CONFIGS[None]['nms'], + clean=True, ): """ Run the ML pipeline on a given batch of image filepaths and return the detections @@ -381,7 +482,7 @@ def batch( assert len(loc_tile_grids) == len(loc_outputs) for filepath, loc_tile_grid, loc_output in zip( - loc_tile_img_filepaths, loc_tile_grids, loc_outputs + loc_tile_img_filepaths, loc_tile_grids, loc_outputs ): batch[filepath]['loc']['grids'].append(loc_tile_grid) batch[filepath]['loc']['outputs'].append(loc_output) @@ -420,16 +521,16 @@ def batch( def batch_v3( - filepaths, - config, - backend_device, - loc_thresh=0.45, - slice_height=512, - slice_width=512, - overlap_height_ratio=0.25, - overlap_width_ratio=0.25, - perform_standard_pred=False, - postprocess_class_agnostic=True, + filepaths, + config, + backend_device, + loc_thresh=0.45, + slice_height=512, + slice_width=512, + overlap_height_ratio=0.25, + overlap_width_ratio=0.25, + perform_standard_pred=False, + postprocess_class_agnostic=True, ): yolov8_model_path = loc.fetch(config=config) @@ -490,10 +591,11 @@ def example(): '786a940b062a90961f409539292f09144c3dbdbc6b6faa64c3e764d63d55c988' # NOQA ) - img_filepath = pooch.retrieve( - url=f'https://wildbookiarepository.azureedge.net/data/{TEST_IMAGE}', - known_hash=TEST_IMAGE_HASH, - progressbar=True, + 
img_filepath = get_data_from_source( + data_name=TEST_IMAGE, + data_hash=TEST_IMAGE_HASH, + source_base=DATA_BASE_URL, + use_cache=True ) assert exists(img_filepath) diff --git a/scoutbot/loc/__init__.py b/scoutbot/loc/__init__.py index e29d388..9e7f46c 100644 --- a/scoutbot/loc/__init__.py +++ b/scoutbot/loc/__init__.py @@ -15,13 +15,12 @@ import cv2 import numpy as np import onnxruntime as ort -import pooch import torch import torchvision import tqdm import utool as ut -from scoutbot import QUIET, log +from scoutbot import log, utils from scoutbot.loc.transforms import ( Compose, GetBoundingBoxes, @@ -144,13 +143,13 @@ def fetch(pull=False, config=DEFAULT_CONFIG): """ - Fetch the Localizer ONNX model file from a CDN if it does not exist locally. + Fetch the Localizer ONNX model file from a CDN, local path, or network path. This function will throw an AssertionError if the download fails or the file otherwise does not exists locally on disk. Args: - pull (bool, optional): If :obj:`True`, force using the downloaded versions + pull (bool, optional): If :obj:`True`, force using the downloaded/copied versions stored in the local system's cache. Defaults to :obj:`False`. config (str or None, optional): the configuration to use, one of ``phase1`` or ``mvp``. Defaults to :obj:`None`. @@ -160,6 +159,7 @@ def fetch(pull=False, config=DEFAULT_CONFIG): Raises: AssertionError: If the model cannot be fetched. + FileNotFoundError: If the model file doesn't exist at the specified path. 
""" if config is None: config = DEFAULT_CONFIG @@ -168,15 +168,22 @@ def fetch(pull=False, config=DEFAULT_CONFIG): model_path = CONFIGS[config]['path'] model_hash = CONFIGS[config]['hash'] - if not pull and exists(model_path): - onnx_model = model_path - else: - onnx_model = pooch.retrieve( - url=f'https://wildbookiarepository.azureedge.net/models/{model_name}', - known_hash=model_hash, - progressbar=not QUIET, - ) - assert exists(onnx_model) + if exists(model_path) and not pull: + if utils.check_file_integrity(model_path, model_hash): + return model_path + else: + log.warning(f"Local model found at {model_path} but hash mismatch. Attempting fetch...") + + # Import the utility function from parent module + from scoutbot import MODEL_BASE_URL, get_model_from_source + + onnx_model = get_model_from_source( + model_name=model_name, + model_hash=model_hash, + source_base=MODEL_BASE_URL, + use_cache=True + ) + assert exists(onnx_model) log.debug(f'LOC Model: {onnx_model}') @@ -256,6 +263,9 @@ def predict(gen): """ log.debug('Running LOC inference') + # Import QUIET from parent module + from scoutbot import QUIET + ort_sessions = {} for chunk, sizes, trim, config in tqdm.tqdm(gen, disable=QUIET): diff --git a/scoutbot/utils.py b/scoutbot/utils.py index b7bd234..0d05358 100644 --- a/scoutbot/utils.py +++ b/scoutbot/utils.py @@ -4,6 +4,7 @@ ''' import logging import os +import hashlib from logging.handlers import TimedRotatingFileHandler DAYS = 21 @@ -11,6 +12,26 @@ DEFAULT_LOG_LEVEL = logging.DEBUG if VERBOSE else logging.INFO +def check_file_integrity(filepath, expected_hash): + """ + Verifies that the file at filepath matches the expected sha256 hash. + Returns True if matches, False otherwise. 
+ """ + if not os.path.exists(filepath): + return False + + if expected_hash is None: + return True + + sha256_hash = hashlib.sha256() + with open(filepath, "rb") as f: + # Read and update hash string value in blocks of 4K + for byte_block in iter(lambda: f.read(4096), b""): + sha256_hash.update(byte_block) + + return sha256_hash.hexdigest() == expected_hash + + def init_logging(): """ Setup Python's built in logging functionality with on-disk logging, and prettier logging with Rich diff --git a/scoutbot/wic/__init__.py b/scoutbot/wic/__init__.py index 902dd1d..6895826 100644 --- a/scoutbot/wic/__init__.py +++ b/scoutbot/wic/__init__.py @@ -14,11 +14,10 @@ import numpy as np import onnxruntime as ort -import pooch import torch import tqdm -from scoutbot import QUIET, log +from scoutbot import log, utils from scoutbot.wic.dataloader import ( # NOQA BATCH_SIZE, INPUT_SIZE, @@ -28,7 +27,6 @@ PWD = Path(__file__).absolute().parent - DEFAULT_CONFIG = os.getenv('WIC_CONFIG', os.getenv('CONFIG', 'mvp')).strip().lower() CONFIGS = { 'phase1': { @@ -54,13 +52,13 @@ def fetch(pull=False, config=DEFAULT_CONFIG): """ - Fetch the WIC ONNX model file from a CDN if it does not exist locally. + Fetch the WIC ONNX model file from a CDN, local path, or network path. This function will throw an AssertionError if the download fails or the file otherwise does not exists locally on disk. Args: - pull (bool, optional): If :obj:`True`, force using the downloaded versions + pull (bool, optional): If :obj:`True`, force using the downloaded/copied versions stored in the local system's cache. Defaults to :obj:`False`. config (str or None, optional): the configuration to use, one of ``phase1`` or ``mvp``. Defaults to :obj:`None`. @@ -70,6 +68,7 @@ def fetch(pull=False, config=DEFAULT_CONFIG): Raises: AssertionError: If the model cannot be fetched. + FileNotFoundError: If the model file doesn't exist at the specified path. 
""" if config is None: config = DEFAULT_CONFIG @@ -78,16 +77,22 @@ def fetch(pull=False, config=DEFAULT_CONFIG): model_path = CONFIGS[config]['path'] model_hash = CONFIGS[config]['hash'] - if not pull and exists(model_path): - onnx_model = model_path - else: - onnx_model = pooch.retrieve( - url=f'https://wildbookiarepository.azureedge.net/models/{model_name}', - known_hash=model_hash, - progressbar=not QUIET, - ) - assert exists(onnx_model) + if exists(model_path) and not pull: + if utils.check_file_integrity(model_path, model_hash): + return model_path + else: + log.warning(f"Local model found at {model_path} but hash mismatch. Attempting fetch...") + # Proceed to download logic below + + from scoutbot import MODEL_BASE_URL, get_model_from_source + onnx_model = get_model_from_source( + model_name=model_name, + model_hash=model_hash, + source_base=MODEL_BASE_URL, + use_cache=True + ) + assert exists(onnx_model) log.debug(f'WIC Model: {onnx_model}') return onnx_model @@ -155,6 +160,9 @@ def predict(gen): """ log.debug('Running WIC inference') + # Import QUIET from parent module + from scoutbot import QUIET + ort_sessions = {} for chunk, config in tqdm.tqdm(gen, disable=QUIET):