From 3a987f2551af99a46bb21d97f9f6571e15982966 Mon Sep 17 00:00:00 2001
From: GardevoirX
Date: Thu, 20 Nov 2025 14:24:49 +0100
Subject: [PATCH 1/4] Replace `torch.nn.Module` with `metatensor.torch.learn.nn.Module`

---
 src/metatrain/deprecated/nanopet/modules/attention.py | 3 ++-
 src/metatrain/deprecated/nanopet/modules/encoder.py | 3 ++-
 src/metatrain/deprecated/nanopet/modules/feedforward.py | 3 ++-
 src/metatrain/deprecated/nanopet/modules/transformer.py | 5 +++--
 src/metatrain/experimental/flashmd/modules/additive.py | 3 ++-
 src/metatrain/experimental/flashmd/modules/encoder.py | 3 ++-
 src/metatrain/gap/model.py | 7 ++++---
 src/metatrain/pet/modules/transformer.py | 7 ++++---
 src/metatrain/pet/modules/utilities.py | 3 ++-
 src/metatrain/soap_bpnn/model.py | 5 +++--
 src/metatrain/utils/additive/_base_composition.py | 3 ++-
 src/metatrain/utils/additive/composition.py | 5 +++--
 src/metatrain/utils/additive/zbl.py | 3 ++-
 src/metatrain/utils/long_range.py | 5 +++--
 src/metatrain/utils/scaler/_base_scaler.py | 3 ++-
 src/metatrain/utils/scaler/scaler.py | 3 ++-
 16 files changed, 40 insertions(+), 24 deletions(-)

diff --git a/src/metatrain/deprecated/nanopet/modules/attention.py b/src/metatrain/deprecated/nanopet/modules/attention.py
index 0a58adfaeb..04207a7dcb 100644
--- a/src/metatrain/deprecated/nanopet/modules/attention.py
+++ b/src/metatrain/deprecated/nanopet/modules/attention.py
@@ -1,7 +1,8 @@
 import torch
+from metatensor.torch.learn.nn import Module
-class AttentionBlock(torch.nn.Module):
+class AttentionBlock(Module):
     """
     A single transformer attention block. We are not using the
     MultiHeadAttention module from torch.nn because we need to apply a
diff --git a/src/metatrain/deprecated/nanopet/modules/encoder.py b/src/metatrain/deprecated/nanopet/modules/encoder.py
index 7e634a2d55..af6dffc4da 100644
--- a/src/metatrain/deprecated/nanopet/modules/encoder.py
+++ b/src/metatrain/deprecated/nanopet/modules/encoder.py
@@ -1,9 +1,10 @@
 from typing import Dict
 import torch
+from metatensor.torch.learn.nn import Module
-class Encoder(torch.nn.Module):
+class Encoder(Module):
     """
     An encoder of edges.
     It generates a fixed-size representation of the interatomic vector,
     the chemical element of the center and the chemical
diff --git a/src/metatrain/deprecated/nanopet/modules/feedforward.py b/src/metatrain/deprecated/nanopet/modules/feedforward.py
index 03d11b682a..8258108c52 100644
--- a/src/metatrain/deprecated/nanopet/modules/feedforward.py
+++ b/src/metatrain/deprecated/nanopet/modules/feedforward.py
@@ -1,7 +1,8 @@
 import torch
+from metatensor.torch.learn.nn import Module
-class FeedForwardBlock(torch.nn.Module):
+class FeedForwardBlock(Module):
     """A single transformer feed forward block."""
     def __init__(
diff --git a/src/metatrain/deprecated/nanopet/modules/transformer.py b/src/metatrain/deprecated/nanopet/modules/transformer.py
index 110989f90f..5aebd69154 100644
--- a/src/metatrain/deprecated/nanopet/modules/transformer.py
+++ b/src/metatrain/deprecated/nanopet/modules/transformer.py
@@ -1,10 +1,11 @@
 import torch
+from metatensor.torch.learn.nn import Module
 from .attention import AttentionBlock
 from .feedforward import FeedForwardBlock
-class TransformerLayer(torch.nn.Module):
+class TransformerLayer(Module):
     """A single transformer layer."""
     def __init__(
@@ -40,7 +41,7 @@ def forward(
         return output
-class Transformer(torch.nn.Module):
+class Transformer(Module):
     """A transformer model."""
     def __init__(
diff --git a/src/metatrain/experimental/flashmd/modules/additive.py b/src/metatrain/experimental/flashmd/modules/additive.py
index 4a083dc83b..1386728908 100644
--- a/src/metatrain/experimental/flashmd/modules/additive.py
+++ b/src/metatrain/experimental/flashmd/modules/additive.py
@@ -2,6 +2,7 @@
 import metatensor.torch as mts
 import torch
+from metatensor.torch.learn.nn import Module
 from metatensor.torch import Labels, TensorBlock, TensorMap
 from metatomic.torch import ModelOutput, NeighborListOptions, System
 from pydantic import TypeAdapter
@@ -14,7 +15,7 @@ class PositionAdditiveHypers(TypedDict):
     also_momenta: bool
-class PositionAdditive(torch.nn.Module):
+class PositionAdditive(Module):
     """
     A simple additive model that adds the positions of the system to any
     outputs that is either "positions" or one of its variants.
diff --git a/src/metatrain/experimental/flashmd/modules/encoder.py b/src/metatrain/experimental/flashmd/modules/encoder.py
index d8d0335c46..100921f026 100644
--- a/src/metatrain/experimental/flashmd/modules/encoder.py
+++ b/src/metatrain/experimental/flashmd/modules/encoder.py
@@ -1,7 +1,8 @@
 import torch
+from metatensor.torch.learn.nn import Module
-class NodeEncoder(torch.nn.Module):
+class NodeEncoder(Module):
     """
     An encoder of edges.
     It generates a fixed-size representation of the interatomic vector,
     the chemical element of the center and the chemical
diff --git a/src/metatrain/gap/model.py b/src/metatrain/gap/model.py
index a8709cdfb6..8a26f93170 100644
--- a/src/metatrain/gap/model.py
+++ b/src/metatrain/gap/model.py
@@ -5,6 +5,7 @@
 import numpy as np
 import scipy
 import torch
+from metatensor.torch.learn.nn import Module
 from metatensor.torch import Labels, TensorBlock, TensorMap
 from metatomic.torch import (
     AtomisticModel,
@@ -400,7 +401,7 @@ def predict(
         return KTM @ self._weights
-class AggregateKernel(torch.nn.Module):
+class AggregateKernel(Module):
     """
     A kernel that aggregates values in a kernel over :param aggregate_names:
     using the sum as aggregate function
@@ -458,7 +459,7 @@ def compute_kernel(self, tensor1: TensorMap, tensor2: TensorMap) -> TensorMap:
         return mts.pow(mts.dot(tensor1, tensor2), self._degree)
-class TorchAggregateKernel(torch.nn.Module):
+class TorchAggregateKernel(Module):
     """
     A kernel that aggregates values in a kernel over :param aggregate_names:
     using the sum as aggregate function
@@ -797,7 +798,7 @@ def export_torch_script_model(self) -> "TorchSubsetofRegressors":
         )
-class TorchSubsetofRegressors(torch.nn.Module):
+class TorchSubsetofRegressors(Module):
     def __init__(
         self,
         weights: TensorMap,
diff --git a/src/metatrain/pet/modules/transformer.py b/src/metatrain/pet/modules/transformer.py
index 86af2f50fb..78cf76889c 100644
--- a/src/metatrain/pet/modules/transformer.py
+++ b/src/metatrain/pet/modules/transformer.py
@@ -2,6 +2,7 @@
 import torch
 import torch.nn.functional as F
+from metatensor.torch.learn.nn import Module
 from torch import nn
 from .utilities import DummyModule
@@ -105,7 +106,7 @@ def forward(
         return x
-class TransformerLayer(torch.nn.Module):
+class TransformerLayer(Module):
     """
     Single layer of a Transformer.
@@ -247,7 +248,7 @@ def forward(
         return node_embeddings, edge_embeddings
-class Transformer(torch.nn.Module):
+class Transformer(Module):
     """
     Transformer implementation.
@@ -338,7 +339,7 @@ def forward(
         return node_embeddings, edge_embeddings
-class CartesianTransformer(torch.nn.Module):
+class CartesianTransformer(Module):
     """
     Cartesian Transformer implementation for handling 3D coordinates.
diff --git a/src/metatrain/pet/modules/utilities.py b/src/metatrain/pet/modules/utilities.py
index 89ab5df033..d212a61239 100644
--- a/src/metatrain/pet/modules/utilities.py
+++ b/src/metatrain/pet/modules/utilities.py
@@ -1,4 +1,5 @@
 import torch
+from metatensor.torch.learn.nn import Module
 def cutoff_func(grid: torch.Tensor, r_cut: float, delta: float) -> torch.Tensor:
@@ -20,7 +21,7 @@ def cutoff_func(grid: torch.Tensor, r_cut: float, delta: float) -> torch.Tensor:
     return f
-class DummyModule(torch.nn.Module):
+class DummyModule(Module):
     """Dummy torch module to make torchscript happy.
     This model should never be run"""
diff --git a/src/metatrain/soap_bpnn/model.py b/src/metatrain/soap_bpnn/model.py
index 94b1c8aa4b..f28e00d0d7 100644
--- a/src/metatrain/soap_bpnn/model.py
+++ b/src/metatrain/soap_bpnn/model.py
@@ -3,6 +3,7 @@
 import metatensor.torch as mts
 import torch
+from metatensor.torch.learn.nn import Module
 from metatensor.torch import Labels, TensorBlock, TensorMap
 from metatensor.torch.learn.nn import Linear as LinearMap
 from metatensor.torch.learn.nn import ModuleMap
@@ -32,7 +33,7 @@
 from .modules.tensor_basis import TensorBasis
-class Identity(torch.nn.Module):
+class Identity(Module):
     def __init__(self) -> None:
         super().__init__()
@@ -48,7 +49,7 @@ def __init__(self, atomic_types: List[int], hypers: dict) -> None:
         # Build a neural network for each species
         nns_per_species = []
         for _ in atomic_types:
-            module_list: List[torch.nn.Module] = []
+            module_list: List[Module] = []
             for _ in range(hypers["num_hidden_layers"]):
                 if len(module_list) == 0:
                     module_list.append(
diff --git a/src/metatrain/utils/additive/_base_composition.py b/src/metatrain/utils/additive/_base_composition.py
index 93c054dddd..28245195f5 100644
--- a/src/metatrain/utils/additive/_base_composition.py
+++ b/src/metatrain/utils/additive/_base_composition.py
@@ -10,13 +10,14 @@
 import metatensor.torch as mts
 import torch
 from metatensor.torch import Labels, LabelsEntry, TensorBlock, TensorMap
+from metatensor.torch.learn.nn import Module
 from metatomic.torch import ModelOutput, System
 FixedCompositionWeights = dict[str, float | dict[int, float]]
-class BaseCompositionModel(torch.nn.Module):
+class BaseCompositionModel(Module):
     """
     Fits a composition model for a dict of targets.
diff --git a/src/metatrain/utils/additive/composition.py b/src/metatrain/utils/additive/composition.py
index 95320e3732..3fd1ad422f 100644
--- a/src/metatrain/utils/additive/composition.py
+++ b/src/metatrain/utils/additive/composition.py
@@ -4,6 +4,7 @@
 import metatensor.torch as mts
 import torch
 from metatensor.torch import Labels, TensorBlock, TensorMap
+from metatensor.torch.learn.nn import Module
 from metatomic.torch import ModelOutput, NeighborListOptions, System
 from torch.utils.data import DataLoader, DistributedSampler
@@ -24,7 +25,7 @@
 from .remove import remove_additive
-class CompositionModel(torch.nn.Module):
+class CompositionModel(Module):
     """
     A simple model that calculates the per-species contributions to targets
     based on the stoichiometry in a system.
@@ -176,7 +177,7 @@ def _get_dataloader(
     def train_model(
         self,
         datasets: List[Union[Dataset, torch.utils.data.Subset]],
-        additive_models: List[torch.nn.Module],
+        additive_models: List[Module],
         batch_size: int,
         is_distributed: bool,
         fixed_weights: Optional[FixedCompositionWeights] = None,
diff --git a/src/metatrain/utils/additive/zbl.py b/src/metatrain/utils/additive/zbl.py
index e0bb6b2be6..fe4dbc8603 100644
--- a/src/metatrain/utils/additive/zbl.py
+++ b/src/metatrain/utils/additive/zbl.py
@@ -5,13 +5,14 @@
 import torch
 from ase.data import covalent_radii
 from metatensor.torch import Labels, TensorBlock, TensorMap
+from metatensor.torch.learn.nn import Module
 from metatomic.torch import ModelOutput, NeighborListOptions, System
 from ..data import DatasetInfo, TargetInfo
 from ..sum_over_atoms import sum_over_atoms
-class ZBL(torch.nn.Module):
+class ZBL(Module):
     """
     A simple model for short-range repulsive interactions.
diff --git a/src/metatrain/utils/long_range.py b/src/metatrain/utils/long_range.py
index 1e73c14bff..59d47db736 100644
--- a/src/metatrain/utils/long_range.py
+++ b/src/metatrain/utils/long_range.py
@@ -3,6 +3,7 @@
 # with default values is not allowed by mypy.
 import torch
 from metatomic.torch import System
+from metatensor.torch.learn.nn import Module
 from typing_extensions import TypedDict
 from metatrain.utils.neighbor_lists import NeighborListOptions
@@ -25,7 +26,7 @@ class LongRangeHypers(TypedDict):
     """Number of grid points for interpolation (for PME only)"""
-class LongRangeFeaturizer(torch.nn.Module):
+class LongRangeFeaturizer(Module):
     """A class to compute long-range features starting from short-range features.
     :param hypers: Dictionary containing the hyperparameters for the long-range
@@ -191,7 +192,7 @@ def forward(
         return torch.concatenate(long_range_features)
-class DummyLongRangeFeaturizer(torch.nn.Module):
+class DummyLongRangeFeaturizer(Module):
     # a dummy class for torchscript
     def __init__(self) -> None:
         super().__init__()
diff --git a/src/metatrain/utils/scaler/_base_scaler.py b/src/metatrain/utils/scaler/_base_scaler.py
index 76a1bfac96..f9592f2413 100644
--- a/src/metatrain/utils/scaler/_base_scaler.py
+++ b/src/metatrain/utils/scaler/_base_scaler.py
@@ -8,13 +8,14 @@
 import torch
 from metatensor.torch import Labels, TensorBlock, TensorMap
+from metatensor.torch.learn.nn import Module
 from metatomic.torch import System
 FixedScalerWeights = dict[str, Union[float, dict[int, float]]]
-class BaseScaler(torch.nn.Module):
+class BaseScaler(Module):
     """
     Fits a scaler for a dict of targets. Scales are computed as the per-property
     (and therefore per-block) standard deviations. By default, the scales are
     also computed
diff --git a/src/metatrain/utils/scaler/scaler.py b/src/metatrain/utils/scaler/scaler.py
index 1aa564064d..bea2614731 100644
--- a/src/metatrain/utils/scaler/scaler.py
+++ b/src/metatrain/utils/scaler/scaler.py
@@ -2,6 +2,7 @@
 import metatensor.torch as mts
 import torch
+from metatensor.torch.learn.nn import Module
 from metatensor.torch import Labels, TensorBlock, TensorMap
 from metatomic.torch import ModelOutput, System
 from torch.utils.data import DataLoader, DistributedSampler
@@ -19,7 +20,7 @@
 from ._base_scaler import BaseScaler, FixedScalerWeights
-class Scaler(torch.nn.Module):
+class Scaler(Module):
     """
     Placeholder docs.

From 4768300d652c5d098b5c54105ef7e293077ba472 Mon Sep 17 00:00:00 2001
From: GardevoirX
Date: Fri, 21 Nov 2025 11:53:08 +0100
Subject: [PATCH 2/4] Update a little bit

---
 src/metatrain/experimental/flashmd/modules/additive.py | 2 +-
 src/metatrain/gap/model.py | 2 +-
 src/metatrain/soap_bpnn/model.py | 3 +--
 src/metatrain/utils/long_range.py | 2 +-
 src/metatrain/utils/scaler/scaler.py | 2 +-
 5 files changed, 5 insertions(+), 6 deletions(-)

diff --git a/src/metatrain/experimental/flashmd/modules/additive.py b/src/metatrain/experimental/flashmd/modules/additive.py
index 1386728908..6106fe4eaf 100644
--- a/src/metatrain/experimental/flashmd/modules/additive.py
+++ b/src/metatrain/experimental/flashmd/modules/additive.py
@@ -2,8 +2,8 @@
 import metatensor.torch as mts
 import torch
-from metatensor.torch.learn.nn import Module
 from metatensor.torch import Labels, TensorBlock, TensorMap
+from metatensor.torch.learn.nn import Module
 from metatomic.torch import ModelOutput, NeighborListOptions, System
 from pydantic import TypeAdapter
 from typing_extensions import TypedDict
diff --git a/src/metatrain/gap/model.py b/src/metatrain/gap/model.py
index 8a26f93170..5e4b4659fe 100644
--- a/src/metatrain/gap/model.py
+++ b/src/metatrain/gap/model.py
@@ -5,8 +5,8 @@
 import numpy as np
 import scipy
 import torch
-from metatensor.torch.learn.nn import Module
 from metatensor.torch import Labels, TensorBlock, TensorMap
+from metatensor.torch.learn.nn import Module
 from metatomic.torch import (
     AtomisticModel,
     ModelCapabilities,
diff --git a/src/metatrain/soap_bpnn/model.py b/src/metatrain/soap_bpnn/model.py
index f28e00d0d7..981886d188 100644
--- a/src/metatrain/soap_bpnn/model.py
+++ b/src/metatrain/soap_bpnn/model.py
@@ -3,10 +3,9 @@
 import metatensor.torch as mts
 import torch
-from metatensor.torch.learn.nn import Module
 from metatensor.torch import Labels, TensorBlock, TensorMap
 from metatensor.torch.learn.nn import Linear as LinearMap
-from metatensor.torch.learn.nn import ModuleMap
+from metatensor.torch.learn.nn import Module, ModuleMap
 from metatensor.torch.operations._add import _add_block_block
 from metatomic.torch import (
     AtomisticModel,
diff --git a/src/metatrain/utils/long_range.py b/src/metatrain/utils/long_range.py
index 59d47db736..6ad51905ca 100644
--- a/src/metatrain/utils/long_range.py
+++ b/src/metatrain/utils/long_range.py
@@ -2,8 +2,8 @@
 # We ignore misc errors in this file because TypedDict
 # with default values is not allowed by mypy.
 import torch
-from metatomic.torch import System
 from metatensor.torch.learn.nn import Module
+from metatomic.torch import System
 from typing_extensions import TypedDict
 from metatrain.utils.neighbor_lists import NeighborListOptions
diff --git a/src/metatrain/utils/scaler/scaler.py b/src/metatrain/utils/scaler/scaler.py
index bea2614731..8c74b8e079 100644
--- a/src/metatrain/utils/scaler/scaler.py
+++ b/src/metatrain/utils/scaler/scaler.py
@@ -2,8 +2,8 @@
 import metatensor.torch as mts
 import torch
-from metatensor.torch.learn.nn import Module
 from metatensor.torch import Labels, TensorBlock, TensorMap
+from metatensor.torch.learn.nn import Module
 from metatomic.torch import ModelOutput, System
 from torch.utils.data import DataLoader, DistributedSampler

From b828f4d4d47f0f0c09c3a26c15249fcca835e446 Mon Sep 17 00:00:00 2001
From: GardevoirX
Date: Fri, 21 Nov 2025 12:05:57 +0100
Subject: [PATCH 3/4] Remove `to` related to properties of `self`

---
 src/metatrain/gap/model.py | 3 ---
 .../utils/additive/_base_composition.py | 21 ----------------
 src/metatrain/utils/additive/composition.py | 17 -------------
 src/metatrain/utils/scaler/_base_scaler.py | 25 ++-----------------
 4 files changed, 2 insertions(+), 64 deletions(-)

diff --git a/src/metatrain/gap/model.py b/src/metatrain/gap/model.py
index 5e4b4659fe..489351c61f 100644
--- a/src/metatrain/gap/model.py
+++ b/src/metatrain/gap/model.py
@@ -823,9 +823,6 @@ def forward(self, T: TensorMap) -> TensorMap:
         :return: TensorMap with the predictions
         """
-        # move weights and X_pseudo to the same device as T
-        self._weights = self._weights.to(T.device)
-        self._X_pseudo = self._X_pseudo.to(T.device)
         k_tm = self._kernel(T, self._X_pseudo, are_pseudo_points=(False, True))
         return mts.dot(k_tm, self._weights)
diff --git a/src/metatrain/utils/additive/_base_composition.py b/src/metatrain/utils/additive/_base_composition.py
index 28245195f5..edf0a8cbb2 100644
--- a/src/metatrain/utils/additive/_base_composition.py
+++ b/src/metatrain/utils/additive/_base_composition.py
@@ -231,7 +231,6 @@ def accumulate(
         device = systems[0].positions.device
         dtype = systems[0].positions.dtype
-        self._sync_device_dtype(device, dtype)
         # check that the systems contain no unexpected atom types
         for system in systems:
@@ -409,8 +408,6 @@ def forward(
         """
         device = systems[0].positions.device
-        dtype = systems[0].positions.dtype
-        self._sync_device_dtype(device, dtype)
         # Build the sample labels that are required
         _, sample_labels = _get_system_indices_and_labels(systems, device)
@@ -519,24 +516,6 @@ def _compute_X_per_atom(
         )
         return one_hot_encoding.to(dtype)
-    def _sync_device_dtype(self, device: torch.device, dtype: torch.dtype) -> None:
-        # manually move the TensorMap dicts:
-
-        self.atomic_types = self.atomic_types.to(device=device)
-        self.type_to_index = self.type_to_index.to(device=device)
-        self.XTX = {
-            target_name: tm.to(device=device, dtype=dtype)
-            for target_name, tm in self.XTX.items()
-        }
-        self.XTY = {
-            target_name: tm.to(device=device, dtype=dtype)
-            for target_name, tm in self.XTY.items()
-        }
-        self.weights = {
-            target_name: tm.to(device=device, dtype=dtype)
-            for target_name, tm in self.weights.items()
-        }
-
 def _include_key(key: LabelsEntry) -> bool:
     """
diff --git a/src/metatrain/utils/additive/composition.py b/src/metatrain/utils/additive/composition.py
index 3fd1ad422f..b94b531bfd 100644
--- a/src/metatrain/utils/additive/composition.py
+++ b/src/metatrain/utils/additive/composition.py
@@ -346,10 +346,6 @@ def forward(
         :raises ValueError: If no weights have
            been computed or if `outputs` keys contain unsupported keys.
         """
-        dtype = systems[0].positions.dtype
-        device = systems[0].positions.device
-
-        self.weights_to(device, dtype)
         for output_name in outputs.keys():
             if output_name not in self.outputs:
@@ -416,19 +412,6 @@ def _add_output(self, target_name: str, target_info: TargetInfo) -> None:
                 mts.save_buffer(mts.make_contiguous(fake_weights)),
             )
-    def weights_to(self, device: torch.device, dtype: torch.dtype) -> None:
-        if len(self.model.weights) != 0:
-            if self.model.weights[list(self.model.weights.keys())[0]].device != device:
-                self.model.weights = {
-                    k: v.to(device) for k, v in self.model.weights.items()
-                }
-            if self.model.weights[list(self.model.weights.keys())[0]].dtype != dtype:
-                self.model.weights = {
-                    k: v.to(dtype) for k, v in self.model.weights.items()
-                }
-
-        self.model._sync_device_dtype(device, dtype)
-
     @staticmethod
     def is_valid_target(target_name: str, target_info: TargetInfo) -> bool:
         """Finds if a ``TargetInfo`` object is compatible with a composition model.
diff --git a/src/metatrain/utils/scaler/_base_scaler.py b/src/metatrain/utils/scaler/_base_scaler.py
index f9592f2413..0d56e6bdc0 100644
--- a/src/metatrain/utils/scaler/_base_scaler.py
+++ b/src/metatrain/utils/scaler/_base_scaler.py
@@ -171,7 +171,6 @@ def accumulate(
         device = list(targets.values())[0][0].values.device
         dtype = list(targets.values())[0][0].values.dtype
-        self._sync_device_dtype(device, dtype)
         # accumulate
         for target_name, target in targets.items():
@@ -313,7 +312,7 @@ def fit(
             blocks.append(block)
         self.scales[target_name] = TensorMap(
-            self.Y2[target_name].keys.to(device=scale_vals_type.device),
+            self.Y2[target_name].keys,
             blocks,
         )
@@ -341,8 +340,6 @@ def forward(
         """
         device = list(outputs.values())[0][0].values.device
-        dtype = list(outputs.values())[0][0].values.dtype
-        self._sync_device_dtype(device, dtype)
         # Build the scaled outputs for each output
         predictions: Dict[str, TensorMap] = {}
@@ -533,24 +530,6 @@ def _apply_fixed_weights(
             )
         self.scales[target_name] = TensorMap(
-            self.Y2[target_name].keys.to(device=block.values.device),
+            self.Y2[target_name].keys,
             [block],
         )
-
-    def _sync_device_dtype(self, device: torch.device, dtype: torch.dtype) -> None:
-        # manually move the TensorMap dicts:
-
-        self.atomic_types = self.atomic_types.to(device=device)
-        self.type_to_index = self.type_to_index.to(device=device)
-        self.N = {
-            target_name: tm.to(device=device, dtype=dtype)
-            for target_name, tm in self.N.items()
-        }
-        self.Y2 = {
-            target_name: tm.to(device=device, dtype=dtype)
-            for target_name, tm in self.Y2.items()
-        }
-        self.scales = {
-            target_name: tm.to(device=device, dtype=dtype)
-            for target_name, tm in self.scales.items()
-        }

From 07b3f5ef33c7b5561a6664135f2738225cdb2063 Mon Sep 17 00:00:00 2001
From: GardevoirX
Date: Fri, 21 Nov 2025 12:24:38 +0100
Subject: [PATCH 4/4] Remove more and break more tests

---
 src/metatrain/deprecated/nanopet/model.py | 4 ----
 src/metatrain/deprecated/nanopet/trainer.py | 2 --
 src/metatrain/experimental/flashmd/model.py | 4 ----
 .../experimental/flashmd/tests/test_torchscript.py | 1 -
 src/metatrain/experimental/flashmd/trainer.py | 2 --
 src/metatrain/gap/model.py | 4 ----
 src/metatrain/llpr/model.py | 10 ----------
 src/metatrain/pet/model.py | 4 ----
 src/metatrain/pet/trainer.py | 2 --
 src/metatrain/soap_bpnn/model.py | 4 ----
 src/metatrain/soap_bpnn/trainer.py | 2 --
 src/metatrain/utils/scaler/scaler.py | 2 --
 12 files changed, 41 deletions(-)

diff --git a/src/metatrain/deprecated/nanopet/model.py b/src/metatrain/deprecated/nanopet/model.py
index 1d3071f060..3c3067d2ea 100644
--- a/src/metatrain/deprecated/nanopet/model.py
+++ b/src/metatrain/deprecated/nanopet/model.py
@@ -613,10 +613,6 @@ def export(self, metadata: Optional[ModelMetadata] = None) -> AtomisticModel:
         # Move dataset info to CPU so that it can be saved
         self.dataset_info = self.dataset_info.to(device="cpu")
-        # Additionally, the composition model contains some `TensorMap`s that cannot
-        # be registered correctly with Pytorch. This funciton moves them:
-        self.additive_models[0].weights_to(torch.device("cpu"), torch.float64)
-
         interaction_ranges = [self.hypers["num_gnn_layers"] * self.hypers["cutoff"]]
         for additive_model in self.additive_models:
             if hasattr(additive_model, "cutoff_radius"):
diff --git a/src/metatrain/deprecated/nanopet/trainer.py b/src/metatrain/deprecated/nanopet/trainer.py
index 6a73df8bf0..53bb5cc6e8 100644
--- a/src/metatrain/deprecated/nanopet/trainer.py
+++ b/src/metatrain/deprecated/nanopet/trainer.py
@@ -147,12 +147,10 @@ def train(
         # Extract additive models and scaler and move them to CPU/float64 so they
         # can be used in the collate function
-        model.additive_models[0].weights_to(device="cpu", dtype=torch.float64)
         additive_models = copy.deepcopy(
             model.additive_models.to(dtype=torch.float64, device="cpu")
         )
         model.additive_models.to(device)
-        model.additive_models[0].weights_to(device=device, dtype=torch.float64)
         model.scaler.scales_to(device="cpu", dtype=torch.float64)
         scaler = copy.deepcopy(model.scaler.to(dtype=torch.float64, device="cpu"))
         model.scaler.to(device)
diff --git a/src/metatrain/experimental/flashmd/model.py b/src/metatrain/experimental/flashmd/model.py
index 71c8847aff..adcae13f57 100644
--- a/src/metatrain/experimental/flashmd/model.py
+++ b/src/metatrain/experimental/flashmd/model.py
@@ -1176,10 +1176,6 @@ def export(self, metadata: Optional[ModelMetadata] = None) -> AtomisticModel:
         # float64
         self.to(dtype)
-        # Additionally, the composition model contains some `TensorMap`s that cannot
-        # be registered correctly with Pytorch. This function moves them:
-        self.additive_models[0].weights_to(torch.device("cpu"), torch.float64)
-
         interaction_ranges = [self.num_gnn_layers * self.cutoff]
         for additive_model in self.additive_models:
             if hasattr(additive_model, "cutoff_radius"):
diff --git a/src/metatrain/experimental/flashmd/tests/test_torchscript.py b/src/metatrain/experimental/flashmd/tests/test_torchscript.py
index bf80007a33..fda852dff4 100644
--- a/src/metatrain/experimental/flashmd/tests/test_torchscript.py
+++ b/src/metatrain/experimental/flashmd/tests/test_torchscript.py
@@ -131,7 +131,6 @@ def test_torchscript_save_load(tmpdir):
     )
     model = FlashMD(MODEL_HYPERS, dataset_info)
     model.to(torch.float64)
-    model.additive_models[0].weights_to(device="cpu", dtype=torch.float64)
     model.scaler.scales_to(device="cpu", dtype=torch.float64)
     with tmpdir.as_cwd():
diff --git a/src/metatrain/experimental/flashmd/trainer.py b/src/metatrain/experimental/flashmd/trainer.py
index 1e5a46287b..50e528b853 100644
--- a/src/metatrain/experimental/flashmd/trainer.py
+++ b/src/metatrain/experimental/flashmd/trainer.py
@@ -234,12 +234,10 @@ def train(
         # Extract additive models and scaler and move them to CPU/float64 so they
         # can be used in the collate function
-        model.additive_models[0].weights_to(device="cpu", dtype=torch.float64)
         additive_models = copy.deepcopy(
             model.additive_models.to(dtype=torch.float64, device="cpu")
         )
         model.additive_models.to(device)
-        model.additive_models[0].weights_to(device=device, dtype=torch.float64)
         model.scaler.scales_to(device="cpu", dtype=torch.float64)
         scaler = copy.deepcopy(model.scaler.to(dtype=torch.float64, device="cpu"))
         model.scaler.to(device)
diff --git a/src/metatrain/gap/model.py b/src/metatrain/gap/model.py
index 489351c61f..56819b1558 100644
--- a/src/metatrain/gap/model.py
+++ b/src/metatrain/gap/model.py
@@ -286,10 +286,6 @@ def export(self, metadata: Optional[ModelMetadata] = None) -> AtomisticModel:
                 interaction_ranges.append(additive_model.cutoff_radius)
         interaction_range = max(interaction_ranges)
-        # Additionally, the composition model contains some `TensorMap`s that cannot
-        # be registered correctly with Pytorch. This funciton moves them:
-        self.additive_models[0].weights_to(torch.device("cpu"), torch.float64)
-
         capabilities = ModelCapabilities(
             outputs=self.outputs,
             atomic_types=sorted(self.dataset_info.atomic_types),
diff --git a/src/metatrain/llpr/model.py b/src/metatrain/llpr/model.py
index 6d7281d7e9..f52e58e342 100644
--- a/src/metatrain/llpr/model.py
+++ b/src/metatrain/llpr/model.py
@@ -758,16 +758,6 @@ def export(self, metadata: Optional[ModelMetadata] = None) -> AtomisticModel:
         # float64
         self.to(dtype)
-        # Additionally, the composition model contains some `TensorMap`s that cannot
-        # be registered correctly with Pytorch. This function moves them:
-        try:
-            self.model.additive_models[0]._move_weights_to_device_and_dtype(
-                torch.device("cpu"), torch.float64
-            )
-        except Exception:
-            # no weights to move
-            pass
-
         metadata = merge_metadata(
             merge_metadata(self.__default_metadata__, metadata),
             self.model.export().metadata(),
diff --git a/src/metatrain/pet/model.py b/src/metatrain/pet/model.py
index 4ff06183a0..ef8ff9f373 100644
--- a/src/metatrain/pet/model.py
+++ b/src/metatrain/pet/model.py
@@ -1161,10 +1161,6 @@ def export(self, metadata: Optional[ModelMetadata] = None) -> AtomisticModel:
         # float64
         self.to(dtype)
-        # Additionally, the composition model contains some `TensorMap`s that cannot
-        # be registered correctly with Pytorch. This function moves them:
-        self.additive_models[0].weights_to(torch.device("cpu"), torch.float64)
-
         interaction_ranges = [self.num_gnn_layers * self.cutoff]
         for additive_model in self.additive_models:
             if hasattr(additive_model, "cutoff_radius"):
diff --git a/src/metatrain/pet/trainer.py b/src/metatrain/pet/trainer.py
index d5fa635f2c..d603b6def7 100644
--- a/src/metatrain/pet/trainer.py
+++ b/src/metatrain/pet/trainer.py
@@ -207,12 +207,10 @@ def train(
         # Extract additive models and scaler and move them to CPU/float64 so they
         # can be used in the collate function
-        model.additive_models[0].weights_to(device="cpu", dtype=torch.float64)
         additive_models = copy.deepcopy(
             model.additive_models.to(dtype=torch.float64, device="cpu")
         )
         model.additive_models.to(device)
-        model.additive_models[0].weights_to(device=device, dtype=torch.float64)
         model.scaler.scales_to(device="cpu", dtype=torch.float64)
         scaler = copy.deepcopy(model.scaler.to(dtype=torch.float64, device="cpu"))
         model.scaler.to(device)
diff --git a/src/metatrain/soap_bpnn/model.py b/src/metatrain/soap_bpnn/model.py
index 981886d188..4df0bb68d6 100644
--- a/src/metatrain/soap_bpnn/model.py
+++ b/src/metatrain/soap_bpnn/model.py
@@ -946,10 +946,6 @@ def export(self, metadata: Optional[ModelMetadata] = None) -> AtomisticModel:
         # float64
         self.to(dtype)
-        # Additionally, the composition model contains some `TensorMap`s that cannot
-        # be registered correctly with Pytorch. This funciton moves them:
-        self.additive_models[0].weights_to(torch.device("cpu"), torch.float64)
-
         interaction_ranges = [self.hypers["soap"]["cutoff"]["radius"]]
         for additive_model in self.additive_models:
             if hasattr(additive_model, "cutoff_radius"):
diff --git a/src/metatrain/soap_bpnn/trainer.py b/src/metatrain/soap_bpnn/trainer.py
index f405f8346e..55eb55b7c7 100644
--- a/src/metatrain/soap_bpnn/trainer.py
+++ b/src/metatrain/soap_bpnn/trainer.py
@@ -184,12 +184,10 @@ def train(
         # Extract additive models and scaler and move them to CPU/float64 so they
         # can be used in the collate function
-        model.additive_models[0].weights_to(device="cpu", dtype=torch.float64)
         additive_models = copy.deepcopy(
             model.additive_models.to(dtype=torch.float64, device="cpu")
         )
         model.additive_models.to(device)
-        model.additive_models[0].weights_to(device=device, dtype=torch.float64)
         model.scaler.scales_to(device="cpu", dtype=torch.float64)
         scaler = copy.deepcopy(model.scaler.to(dtype=torch.float64, device="cpu"))
         model.scaler.to(device)
diff --git a/src/metatrain/utils/scaler/scaler.py b/src/metatrain/utils/scaler/scaler.py
index 8c74b8e079..5c304cb38a 100644
--- a/src/metatrain/utils/scaler/scaler.py
+++ b/src/metatrain/utils/scaler/scaler.py
@@ -343,8 +343,6 @@ def scales_to(self, device: torch.device, dtype: torch.dtype) -> None:
                 k: v.to(dtype) for k, v in self.model.scales.items()
             }
-        self.model._sync_device_dtype(device, dtype)
-
     def sync_tensor_maps(self) -> None:
         # Reload the scales of the (old) targets, which are not stored in the model
        # state_dict, from the buffers
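
A minimal sketch of the pattern this patch series applies, for illustration only (it is not part of the patches). It assumes, as the diffs suggest, that `metatensor.torch.learn.nn.Module` can hold metatensor objects such as `Labels` or `TensorMap` as plain attributes, which is what allows the later patches to drop the manual `weights_to` / `_sync_device_dtype` device-syncing helpers. The class name and attributes below are hypothetical.

import torch
from metatensor.torch import Labels
from metatensor.torch.learn.nn import Module  # instead of torch.nn.Module


class ExampleAdditive(Module):
    """Hypothetical module following the pattern used in the patches."""

    def __init__(self, atomic_types: list) -> None:
        super().__init__()
        # metatensor data stored directly on the module; with torch.nn.Module,
        # attributes like this needed manual helpers (e.g. weights_to) to be
        # moved between devices and dtypes
        self.atomic_types = Labels(
            "atomic_type", torch.tensor(atomic_types).reshape(-1, 1)
        )
        # ordinary torch submodules still work as usual
        self.linear = torch.nn.Linear(4, 1)


model = ExampleAdditive([1, 6, 8])
# with the metatensor Module base class, a single .to() call is expected to
# handle both the torch parameters and the metatensor attributes (a no-op here)
model = model.to(device="cpu")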