def download_with_retry(url, destination, timeout=(10, 60)):
    """Download ``url`` to ``destination``, retrying on transient HTTP errors.

    The response body is streamed to a temporary file next to ``destination``
    and moved into place only once the download has completed, so a failed or
    interrupted download never leaves a truncated file at ``destination``.

    Parameters
    ----------
    url : str
        Address of the file to download.
    destination : str or os.PathLike
        Path the downloaded file is written to.
    timeout : float or tuple of float, optional
        Passed to ``requests`` as the ``(connect, read)`` timeout in seconds.
        Without a timeout ``requests`` waits forever on a stalled server.

    Raises
    ------
    requests.exceptions.RequestException
        If the request still fails after all retries, the server answers with
        an error status, or the connection breaks while streaming the body.
    OSError
        If the file cannot be written or moved to ``destination``.
    """

    # Up to 5 retries on 429/5xx responses with exponential backoff. urllib3
    # does not sleep before the first retry, so with backoff_factor=1 the
    # waits are roughly 0s, 2s, 4s, 8s, 16s (a Retry-After header sent by the
    # server takes precedence for 429/503).
    retry_strategy = Retry(
        total=5, backoff_factor=1, status_forcelist=[429, 500, 502, 503, 504]
    )
    adapter = HTTPAdapter(max_retries=retry_strategy)

    target = os.fspath(destination)
    # Per-process suffix so concurrent downloads of the same file do not
    # write into each other's temporary file.
    partial = "{}.{}.part".format(target, os.getpid())

    try:
        # Context managers release the pooled connection even on errors.
        with requests.Session() as session:
            session.mount("https://", adapter)
            session.mount("http://", adapter)

            with session.get(url, stream=True, timeout=timeout) as response:
                response.raise_for_status()

                with open(partial, "wb") as file:
                    for chunk in response.iter_content(chunk_size=8192):
                        if chunk:  # skip empty keep-alive chunks
                            file.write(chunk)

        # Atomic on the same filesystem: readers see either the previous
        # file or the complete new one, never a partial download.
        os.replace(partial, target)
    finally:
        # After a successful os.replace the temporary file no longer exists;
        # anything still present is the leftover of a failed attempt.
        if os.path.exists(partial):
            os.remove(partial)
logging.error( - "Failed to download {} from {}. URL Error: {}".format( - model_version, url, e.reason - ) - ) - raise e - except urllib.error.HTTPError as e: - logging.error( - "Failed to download {} from {}.\ - HTTP Error: {} - {}".format( - model_version, url, e.code, e.reason - ) + "Failed to download {} from {}. Error: {}".format(model_version, url, e) ) raise e except Exception as e: