Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ dependencies = [
"metatensor-torch >=0.7.6,<0.9",
"metatomic-torch >=0.1.2,<0.2",
"vesin",
"requests",
]

readme = "README.md"
Expand Down
39 changes: 25 additions & 14 deletions src/shiftml/ase/calculator.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
import logging
import os
import urllib.request

import numpy as np
import requests
from metatomic.torch import ModelOutput
from metatomic.torch.ase_calculator import MetatomicCalculator
from platformdirs import user_cache_path
from requests.adapters import HTTPAdapter
from urllib3.util.retry import Retry

from shiftml.utils.tensorial import T_sym_np_inv, symmetrize

Expand Down Expand Up @@ -57,6 +59,25 @@ def is_fitted_on(atoms, fitted_species):
)


def download_with_retry(url, destination):
"""Helper function to download data with retries on errors."""

# Retry strategy: wait 1s, 2s, 4s, 8s, 16s on 429/5xx errors
retry_strategy = Retry(
total=5, backoff_factor=1, status_forcelist=[429, 500, 502, 503, 504]
)
session = requests.Session()
session.mount("https://", HTTPAdapter(max_retries=retry_strategy))

# Fetch with automatic retry and error raising
response = session.get(url, stream=True)
response.raise_for_status()

with open(destination, "wb") as file:
for chunk in response.iter_content(chunk_size=8192):
file.write(chunk)


def ShiftML(model_version, force_download=False, device=None):
"""
Initialize the ShiftML calculator
Expand Down Expand Up @@ -247,24 +268,14 @@ def __init__(self, model_version, force_download=False, device=None):
download = True

if download:
urllib.request.urlretrieve(url, model_file)
download_with_retry(url, model_file)
logging.info(
"Downloaded {} and saved to {}".format(model_version, cachedir)
)

except urllib.error.URLError as e:
except requests.exceptions.RequestException as e:
logging.error(
"Failed to download {} from {}. URL Error: {}".format(
model_version, url, e.reason
)
)
raise e
except urllib.error.HTTPError as e:
logging.error(
"Failed to download {} from {}.\
HTTP Error: {} - {}".format(
model_version, url, e.code, e.reason
)
"Failed to download {} from {}. Error: {}".format(model_version, url, e)
)
raise e
except Exception as e:
Expand Down