From 387ebebf40d89dd55d765f88d5fba33fd740d804 Mon Sep 17 00:00:00 2001 From: Niels Date: Tue, 7 Jan 2025 19:24:58 +0100 Subject: [PATCH] Add mixin --- hubconf.py | 30 ++++-------------------------- 1 file changed, 4 insertions(+), 26 deletions(-) diff --git a/hubconf.py b/hubconf.py index 4827414..46d4fe0 100644 --- a/hubconf.py +++ b/hubconf.py @@ -2,20 +2,21 @@ PyTorch Hub configuration for AnySat model. """ -import os import sys -import torch import torch.nn as nn from pathlib import Path import warnings +from huggingface_hub import PyTorchModelHubMixin + REPO_ROOT = Path(__file__).parent if str(REPO_ROOT / "src") not in sys.path: sys.path.append(str(REPO_ROOT / "src")) dependencies = ['torch'] -class AnySat(nn.Module): +class AnySat(nn.Module, PyTorchModelHubMixin, repo_url="https://github.com/gastruc/AnySat", pipeline_tag="image-feature-extraction", + tags=["image-feature-extraction"], license="mit"): """ AnySat: Earth Observation Model for Any Resolutions, Scales, and Modalities @@ -80,29 +81,6 @@ def __init__(self, model_size='base', flash_attn=True, **kwargs): if device is not None: self.model = self.model.to(device) - @classmethod - def from_pretrained(cls, model_size='base', **kwargs): - """ - Create a pretrained AnySat model - - Args: - model_size (str): Model size - 'tiny', 'small', or 'base' - **kwargs: Additional arguments passed to the constructor - """ - model = cls(model_size=model_size, **kwargs) - - checkpoint_urls = { - 'base': 'https://huggingface.co/g-astruc/AnySat/resolve/main/models/AnySat.pth', - # 'small': 'https://huggingface.co/gastruc/anysat/resolve/main/anysat_small_geoplex.pth', COMING SOON - # 'tiny': 'https://huggingface.co/gastruc/anysat/resolve/main/anysat_tiny_geoplex.pth' COMING SOON - } - - checkpoint_url = checkpoint_urls[model_size] - state_dict = torch.hub.load_state_dict_from_url(checkpoint_url, progress=True)['state_dict'] - - model.model.load_state_dict(state_dict) - return model - def forward(self, x, patch_size, output='patch', **kwargs): assert output in ['patch', 'tile', 'dense', 'all'], "Output must be one of 'patch', 'tile', 'dense', 'all'" sizes = {}