diff --git a/src/metatrain/deprecated/nanopet/model.py b/src/metatrain/deprecated/nanopet/model.py
index 1d3071f060..3c3067d2ea 100644
--- a/src/metatrain/deprecated/nanopet/model.py
+++ b/src/metatrain/deprecated/nanopet/model.py
@@ -613,10 +613,6 @@ def export(self, metadata: Optional[ModelMetadata] = None) -> AtomisticModel:
         # Move dataset info to CPU so that it can be saved
         self.dataset_info = self.dataset_info.to(device="cpu")
 
-        # Additionally, the composition model contains some `TensorMap`s that cannot
-        # be registered correctly with Pytorch. This funciton moves them:
-        self.additive_models[0].weights_to(torch.device("cpu"), torch.float64)
-
         interaction_ranges = [self.hypers["num_gnn_layers"] * self.hypers["cutoff"]]
         for additive_model in self.additive_models:
             if hasattr(additive_model, "cutoff_radius"):
diff --git a/src/metatrain/deprecated/nanopet/modules/attention.py b/src/metatrain/deprecated/nanopet/modules/attention.py
index 0a58adfaeb..04207a7dcb 100644
--- a/src/metatrain/deprecated/nanopet/modules/attention.py
+++ b/src/metatrain/deprecated/nanopet/modules/attention.py
@@ -1,7 +1,8 @@
 import torch
+from metatensor.torch.learn.nn import Module
 
 
-class AttentionBlock(torch.nn.Module):
+class AttentionBlock(Module):
     """
     A single transformer attention block. We are not using the
     MultiHeadAttention module from torch.nn because we need to apply a
diff --git a/src/metatrain/deprecated/nanopet/modules/encoder.py b/src/metatrain/deprecated/nanopet/modules/encoder.py
index 7e634a2d55..af6dffc4da 100644
--- a/src/metatrain/deprecated/nanopet/modules/encoder.py
+++ b/src/metatrain/deprecated/nanopet/modules/encoder.py
@@ -1,9 +1,10 @@
 from typing import Dict
 
 import torch
+from metatensor.torch.learn.nn import Module
 
 
-class Encoder(torch.nn.Module):
+class Encoder(Module):
     """
     An encoder of edges. It generates a fixed-size representation of the
     interatomic vector, the chemical element of the center and the chemical
diff --git a/src/metatrain/deprecated/nanopet/modules/feedforward.py b/src/metatrain/deprecated/nanopet/modules/feedforward.py
index 03d11b682a..8258108c52 100644
--- a/src/metatrain/deprecated/nanopet/modules/feedforward.py
+++ b/src/metatrain/deprecated/nanopet/modules/feedforward.py
@@ -1,7 +1,8 @@
 import torch
+from metatensor.torch.learn.nn import Module
 
 
-class FeedForwardBlock(torch.nn.Module):
+class FeedForwardBlock(Module):
     """A single transformer feed forward block."""
 
     def __init__(
diff --git a/src/metatrain/deprecated/nanopet/modules/transformer.py b/src/metatrain/deprecated/nanopet/modules/transformer.py
index 110989f90f..5aebd69154 100644
--- a/src/metatrain/deprecated/nanopet/modules/transformer.py
+++ b/src/metatrain/deprecated/nanopet/modules/transformer.py
@@ -1,10 +1,11 @@
 import torch
+from metatensor.torch.learn.nn import Module
 
 from .attention import AttentionBlock
 from .feedforward import FeedForwardBlock
 
 
-class TransformerLayer(torch.nn.Module):
+class TransformerLayer(Module):
     """A single transformer layer."""
 
     def __init__(
@@ -40,7 +41,7 @@ def forward(
         return output
 
 
-class Transformer(torch.nn.Module):
+class Transformer(Module):
     """A transformer model."""
 
     def __init__(
diff --git a/src/metatrain/deprecated/nanopet/trainer.py b/src/metatrain/deprecated/nanopet/trainer.py
index 6a73df8bf0..53bb5cc6e8 100644
--- a/src/metatrain/deprecated/nanopet/trainer.py
+++ b/src/metatrain/deprecated/nanopet/trainer.py
@@ -147,12 +147,10 @@ def train(
 
         # Extract additive models and scaler and move them to CPU/float64 so they
         # can be used in the collate function
-        model.additive_models[0].weights_to(device="cpu", dtype=torch.float64)
         additive_models = copy.deepcopy(
             model.additive_models.to(dtype=torch.float64, device="cpu")
         )
         model.additive_models.to(device)
-        model.additive_models[0].weights_to(device=device, dtype=torch.float64)
         model.scaler.scales_to(device="cpu", dtype=torch.float64)
         scaler = copy.deepcopy(model.scaler.to(dtype=torch.float64, device="cpu"))
         model.scaler.to(device)
diff --git a/src/metatrain/experimental/flashmd/model.py b/src/metatrain/experimental/flashmd/model.py
index 71c8847aff..adcae13f57 100644
--- a/src/metatrain/experimental/flashmd/model.py
+++ b/src/metatrain/experimental/flashmd/model.py
@@ -1176,10 +1176,6 @@ def export(self, metadata: Optional[ModelMetadata] = None) -> AtomisticModel:
         # float64
         self.to(dtype)
 
-        # Additionally, the composition model contains some `TensorMap`s that cannot
-        # be registered correctly with Pytorch. This function moves them:
-        self.additive_models[0].weights_to(torch.device("cpu"), torch.float64)
-
         interaction_ranges = [self.num_gnn_layers * self.cutoff]
         for additive_model in self.additive_models:
             if hasattr(additive_model, "cutoff_radius"):
diff --git a/src/metatrain/experimental/flashmd/modules/additive.py b/src/metatrain/experimental/flashmd/modules/additive.py
index 4a083dc83b..6106fe4eaf 100644
--- a/src/metatrain/experimental/flashmd/modules/additive.py
+++ b/src/metatrain/experimental/flashmd/modules/additive.py
@@ -3,6 +3,7 @@
 import metatensor.torch as mts
 import torch
 from metatensor.torch import Labels, TensorBlock, TensorMap
+from metatensor.torch.learn.nn import Module
 from metatomic.torch import ModelOutput, NeighborListOptions, System
 from pydantic import TypeAdapter
 from typing_extensions import TypedDict
@@ -14,7 +15,7 @@ class PositionAdditiveHypers(TypedDict):
     also_momenta: bool
 
 
-class PositionAdditive(torch.nn.Module):
+class PositionAdditive(Module):
     """
     A simple additive model that adds the positions of the system to any outputs
     that is either "positions" or one of its variants.
diff --git a/src/metatrain/experimental/flashmd/modules/encoder.py b/src/metatrain/experimental/flashmd/modules/encoder.py
index d8d0335c46..100921f026 100644
--- a/src/metatrain/experimental/flashmd/modules/encoder.py
+++ b/src/metatrain/experimental/flashmd/modules/encoder.py
@@ -1,7 +1,8 @@
 import torch
+from metatensor.torch.learn.nn import Module
 
 
-class NodeEncoder(torch.nn.Module):
+class NodeEncoder(Module):
     """
     An encoder of edges. It generates a fixed-size representation of the
     interatomic vector, the chemical element of the center and the chemical
diff --git a/src/metatrain/experimental/flashmd/tests/test_torchscript.py b/src/metatrain/experimental/flashmd/tests/test_torchscript.py
index bf80007a33..fda852dff4 100644
--- a/src/metatrain/experimental/flashmd/tests/test_torchscript.py
+++ b/src/metatrain/experimental/flashmd/tests/test_torchscript.py
@@ -131,7 +131,6 @@ def test_torchscript_save_load(tmpdir):
     )
     model = FlashMD(MODEL_HYPERS, dataset_info)
     model.to(torch.float64)
-    model.additive_models[0].weights_to(device="cpu", dtype=torch.float64)
     model.scaler.scales_to(device="cpu", dtype=torch.float64)
 
     with tmpdir.as_cwd():
diff --git a/src/metatrain/experimental/flashmd/trainer.py b/src/metatrain/experimental/flashmd/trainer.py
index 1e5a46287b..50e528b853 100644
--- a/src/metatrain/experimental/flashmd/trainer.py
+++ b/src/metatrain/experimental/flashmd/trainer.py
@@ -234,12 +234,10 @@ def train(
 
         # Extract additive models and scaler and move them to CPU/float64 so they
         # can be used in the collate function
-        model.additive_models[0].weights_to(device="cpu", dtype=torch.float64)
         additive_models = copy.deepcopy(
             model.additive_models.to(dtype=torch.float64, device="cpu")
         )
         model.additive_models.to(device)
-        model.additive_models[0].weights_to(device=device, dtype=torch.float64)
         model.scaler.scales_to(device="cpu", dtype=torch.float64)
         scaler = copy.deepcopy(model.scaler.to(dtype=torch.float64, device="cpu"))
         model.scaler.to(device)
diff --git a/src/metatrain/gap/model.py b/src/metatrain/gap/model.py
index a8709cdfb6..56819b1558 100644
--- a/src/metatrain/gap/model.py
+++ b/src/metatrain/gap/model.py
@@ -6,6 +6,7 @@
 import scipy
 import torch
 from metatensor.torch import Labels, TensorBlock, TensorMap
+from metatensor.torch.learn.nn import Module
 from metatomic.torch import (
     AtomisticModel,
     ModelCapabilities,
@@ -285,10 +286,6 @@ def export(self, metadata: Optional[ModelMetadata] = None) -> AtomisticModel:
                 interaction_ranges.append(additive_model.cutoff_radius)
         interaction_range = max(interaction_ranges)
 
-        # Additionally, the composition model contains some `TensorMap`s that cannot
-        # be registered correctly with Pytorch. This funciton moves them:
-        self.additive_models[0].weights_to(torch.device("cpu"), torch.float64)
-
         capabilities = ModelCapabilities(
             outputs=self.outputs,
             atomic_types=sorted(self.dataset_info.atomic_types),
@@ -400,7 +397,7 @@ def predict(
         return KTM @ self._weights
 
 
-class AggregateKernel(torch.nn.Module):
+class AggregateKernel(Module):
     """
     A kernel that aggregates values in a kernel over :param aggregate_names: using
     the sum as aggregate function
@@ -458,7 +455,7 @@ def compute_kernel(self, tensor1: TensorMap, tensor2: TensorMap) -> TensorMap:
         return mts.pow(mts.dot(tensor1, tensor2), self._degree)
 
 
-class TorchAggregateKernel(torch.nn.Module):
+class TorchAggregateKernel(Module):
     """
     A kernel that aggregates values in a kernel over :param aggregate_names: using
     the sum as aggregate function
@@ -797,7 +794,7 @@ def export_torch_script_model(self) -> "TorchSubsetofRegressors":
         )
 
 
-class TorchSubsetofRegressors(torch.nn.Module):
+class TorchSubsetofRegressors(Module):
     def __init__(
         self,
         weights: TensorMap,
@@ -822,9 +819,6 @@ def forward(self, T: TensorMap) -> TensorMap:
 
         :return: TensorMap with the predictions
         """
-        # move weights and X_pseudo to the same device as T
-        self._weights = self._weights.to(T.device)
-        self._X_pseudo = self._X_pseudo.to(T.device)
         k_tm = self._kernel(T, self._X_pseudo, are_pseudo_points=(False, True))
 
         return mts.dot(k_tm, self._weights)
diff --git a/src/metatrain/llpr/model.py b/src/metatrain/llpr/model.py
index 6d7281d7e9..f52e58e342 100644
--- a/src/metatrain/llpr/model.py
+++ b/src/metatrain/llpr/model.py
@@ -758,16 +758,6 @@ def export(self, metadata: Optional[ModelMetadata] = None) -> AtomisticModel:
         # float64
         self.to(dtype)
 
-        # Additionally, the composition model contains some `TensorMap`s that cannot
-        # be registered correctly with Pytorch. This function moves them:
-        try:
-            self.model.additive_models[0]._move_weights_to_device_and_dtype(
-                torch.device("cpu"), torch.float64
-            )
-        except Exception:
-            # no weights to move
-            pass
-
         metadata = merge_metadata(
             merge_metadata(self.__default_metadata__, metadata),
             self.model.export().metadata(),
diff --git a/src/metatrain/pet/model.py b/src/metatrain/pet/model.py
index 4ff06183a0..ef8ff9f373 100644
--- a/src/metatrain/pet/model.py
+++ b/src/metatrain/pet/model.py
@@ -1161,10 +1161,6 @@ def export(self, metadata: Optional[ModelMetadata] = None) -> AtomisticModel:
         # float64
         self.to(dtype)
 
-        # Additionally, the composition model contains some `TensorMap`s that cannot
-        # be registered correctly with Pytorch. This function moves them:
-        self.additive_models[0].weights_to(torch.device("cpu"), torch.float64)
-
         interaction_ranges = [self.num_gnn_layers * self.cutoff]
         for additive_model in self.additive_models:
             if hasattr(additive_model, "cutoff_radius"):
diff --git a/src/metatrain/pet/modules/transformer.py b/src/metatrain/pet/modules/transformer.py
index 86af2f50fb..78cf76889c 100644
--- a/src/metatrain/pet/modules/transformer.py
+++ b/src/metatrain/pet/modules/transformer.py
@@ -2,6 +2,7 @@
 
 import torch
 import torch.nn.functional as F
+from metatensor.torch.learn.nn import Module
 from torch import nn
 
 from .utilities import DummyModule
@@ -105,7 +106,7 @@ def forward(
         return x
 
 
-class TransformerLayer(torch.nn.Module):
+class TransformerLayer(Module):
     """
     Single layer of a Transformer.
 
@@ -247,7 +248,7 @@ def forward(
         return node_embeddings, edge_embeddings
 
 
-class Transformer(torch.nn.Module):
+class Transformer(Module):
     """
     Transformer implementation.
 
@@ -338,7 +339,7 @@ def forward(
         return node_embeddings, edge_embeddings
 
 
-class CartesianTransformer(torch.nn.Module):
+class CartesianTransformer(Module):
     """
     Cartesian Transformer implementation for handling 3D coordinates.
 
diff --git a/src/metatrain/pet/modules/utilities.py b/src/metatrain/pet/modules/utilities.py
index 89ab5df033..d212a61239 100644
--- a/src/metatrain/pet/modules/utilities.py
+++ b/src/metatrain/pet/modules/utilities.py
@@ -1,4 +1,5 @@
 import torch
+from metatensor.torch.learn.nn import Module
 
 
 def cutoff_func(grid: torch.Tensor, r_cut: float, delta: float) -> torch.Tensor:
@@ -20,7 +21,7 @@ def cutoff_func(grid: torch.Tensor, r_cut: float, delta: float) -> torch.Tensor:
     return f
 
 
-class DummyModule(torch.nn.Module):
+class DummyModule(Module):
     """Dummy torch module to make torchscript happy.
 
     This model should never be run"""
diff --git a/src/metatrain/pet/trainer.py b/src/metatrain/pet/trainer.py
index d5fa635f2c..d603b6def7 100644
--- a/src/metatrain/pet/trainer.py
+++ b/src/metatrain/pet/trainer.py
@@ -207,12 +207,10 @@ def train(
 
         # Extract additive models and scaler and move them to CPU/float64 so they
         # can be used in the collate function
-        model.additive_models[0].weights_to(device="cpu", dtype=torch.float64)
         additive_models = copy.deepcopy(
             model.additive_models.to(dtype=torch.float64, device="cpu")
         )
         model.additive_models.to(device)
-        model.additive_models[0].weights_to(device=device, dtype=torch.float64)
         model.scaler.scales_to(device="cpu", dtype=torch.float64)
         scaler = copy.deepcopy(model.scaler.to(dtype=torch.float64, device="cpu"))
         model.scaler.to(device)
diff --git a/src/metatrain/soap_bpnn/model.py b/src/metatrain/soap_bpnn/model.py
index 94b1c8aa4b..4df0bb68d6 100644
--- a/src/metatrain/soap_bpnn/model.py
+++ b/src/metatrain/soap_bpnn/model.py
@@ -5,7 +5,7 @@
 import torch
 from metatensor.torch import Labels, TensorBlock, TensorMap
 from metatensor.torch.learn.nn import Linear as LinearMap
-from metatensor.torch.learn.nn import ModuleMap
+from metatensor.torch.learn.nn import Module, ModuleMap
 from metatensor.torch.operations._add import _add_block_block
 from metatomic.torch import (
     AtomisticModel,
     ModelCapabilities,
@@ -32,7 +32,7 @@
 from .modules.tensor_basis import TensorBasis
 
 
-class Identity(torch.nn.Module):
+class Identity(Module):
     def __init__(self) -> None:
         super().__init__()
 
@@ -48,7 +48,7 @@ def __init__(self, atomic_types: List[int], hypers: dict) -> None:
         # Build a neural network for each species
         nns_per_species = []
         for _ in atomic_types:
-            module_list: List[torch.nn.Module] = []
+            module_list: List[Module] = []
             for _ in range(hypers["num_hidden_layers"]):
                 if len(module_list) == 0:
                     module_list.append(
@@ -946,10 +946,6 @@ def export(self, metadata: Optional[ModelMetadata] = None) -> AtomisticModel:
         # float64
         self.to(dtype)
 
-        # Additionally, the composition model contains some `TensorMap`s that cannot
-        # be registered correctly with Pytorch. This funciton moves them:
-        self.additive_models[0].weights_to(torch.device("cpu"), torch.float64)
-
         interaction_ranges = [self.hypers["soap"]["cutoff"]["radius"]]
         for additive_model in self.additive_models:
             if hasattr(additive_model, "cutoff_radius"):
diff --git a/src/metatrain/soap_bpnn/trainer.py b/src/metatrain/soap_bpnn/trainer.py
index f405f8346e..55eb55b7c7 100644
--- a/src/metatrain/soap_bpnn/trainer.py
+++ b/src/metatrain/soap_bpnn/trainer.py
@@ -184,12 +184,10 @@ def train(
 
         # Extract additive models and scaler and move them to CPU/float64 so they
         # can be used in the collate function
-        model.additive_models[0].weights_to(device="cpu", dtype=torch.float64)
         additive_models = copy.deepcopy(
             model.additive_models.to(dtype=torch.float64, device="cpu")
         )
         model.additive_models.to(device)
-        model.additive_models[0].weights_to(device=device, dtype=torch.float64)
         model.scaler.scales_to(device="cpu", dtype=torch.float64)
         scaler = copy.deepcopy(model.scaler.to(dtype=torch.float64, device="cpu"))
         model.scaler.to(device)
diff --git a/src/metatrain/utils/additive/_base_composition.py b/src/metatrain/utils/additive/_base_composition.py
index 93c054dddd..edf0a8cbb2 100644
--- a/src/metatrain/utils/additive/_base_composition.py
+++ b/src/metatrain/utils/additive/_base_composition.py
@@ -10,13 +10,14 @@
 import metatensor.torch as mts
 import torch
 from metatensor.torch import Labels, LabelsEntry, TensorBlock, TensorMap
+from metatensor.torch.learn.nn import Module
 from metatomic.torch import ModelOutput, System
 
 
 FixedCompositionWeights = dict[str, float | dict[int, float]]
 
 
-class BaseCompositionModel(torch.nn.Module):
+class BaseCompositionModel(Module):
     """
     Fits a composition model for a dict of targets.
 
@@ -230,7 +231,6 @@ def accumulate(
 
         device = systems[0].positions.device
         dtype = systems[0].positions.dtype
-        self._sync_device_dtype(device, dtype)
 
         # check that the systems contain no unexpected atom types
         for system in systems:
@@ -408,8 +408,6 @@ def forward(
         """
 
         device = systems[0].positions.device
-        dtype = systems[0].positions.dtype
-        self._sync_device_dtype(device, dtype)
 
         # Build the sample labels that are required
         _, sample_labels = _get_system_indices_and_labels(systems, device)
@@ -518,24 +516,6 @@ def _compute_X_per_atom(
         )
         return one_hot_encoding.to(dtype)
 
-    def _sync_device_dtype(self, device: torch.device, dtype: torch.dtype) -> None:
-        # manually move the TensorMap dicts:
-
-        self.atomic_types = self.atomic_types.to(device=device)
-        self.type_to_index = self.type_to_index.to(device=device)
-        self.XTX = {
-            target_name: tm.to(device=device, dtype=dtype)
-            for target_name, tm in self.XTX.items()
-        }
-        self.XTY = {
-            target_name: tm.to(device=device, dtype=dtype)
-            for target_name, tm in self.XTY.items()
-        }
-        self.weights = {
-            target_name: tm.to(device=device, dtype=dtype)
-            for target_name, tm in self.weights.items()
-        }
-
 
 def _include_key(key: LabelsEntry) -> bool:
     """
diff --git a/src/metatrain/utils/additive/composition.py b/src/metatrain/utils/additive/composition.py
index 95320e3732..b94b531bfd 100644
--- a/src/metatrain/utils/additive/composition.py
+++ b/src/metatrain/utils/additive/composition.py
@@ -4,6 +4,7 @@
 import metatensor.torch as mts
 import torch
 from metatensor.torch import Labels, TensorBlock, TensorMap
+from metatensor.torch.learn.nn import Module
 from metatomic.torch import ModelOutput, NeighborListOptions, System
 from torch.utils.data import DataLoader, DistributedSampler
 
@@ -24,7 +25,7 @@
 from .remove import remove_additive
 
 
-class CompositionModel(torch.nn.Module):
+class CompositionModel(Module):
     """
     A simple model that calculates the per-species contributions to targets
     based on the stoichiometry in a system.
@@ -176,7 +177,7 @@ def _get_dataloader(
     def train_model(
         self,
        datasets: List[Union[Dataset, torch.utils.data.Subset]],
-        additive_models: List[torch.nn.Module],
+        additive_models: List[Module],
         batch_size: int,
         is_distributed: bool,
         fixed_weights: Optional[FixedCompositionWeights] = None,
@@ -345,10 +346,6 @@ def forward(
         :raises ValueError: If no weights have been computed or if `outputs` keys
             contain unsupported keys.
         """
-        dtype = systems[0].positions.dtype
-        device = systems[0].positions.device
-
-        self.weights_to(device, dtype)
 
         for output_name in outputs.keys():
             if output_name not in self.outputs:
@@ -415,19 +412,6 @@ def _add_output(self, target_name: str, target_info: TargetInfo) -> None:
             mts.save_buffer(mts.make_contiguous(fake_weights)),
         )
 
-    def weights_to(self, device: torch.device, dtype: torch.dtype) -> None:
-        if len(self.model.weights) != 0:
-            if self.model.weights[list(self.model.weights.keys())[0]].device != device:
-                self.model.weights = {
-                    k: v.to(device) for k, v in self.model.weights.items()
-                }
-            if self.model.weights[list(self.model.weights.keys())[0]].dtype != dtype:
-                self.model.weights = {
-                    k: v.to(dtype) for k, v in self.model.weights.items()
-                }
-
-        self.model._sync_device_dtype(device, dtype)
-
     @staticmethod
     def is_valid_target(target_name: str, target_info: TargetInfo) -> bool:
         """Finds if a ``TargetInfo`` object is compatible with a composition model.
diff --git a/src/metatrain/utils/additive/zbl.py b/src/metatrain/utils/additive/zbl.py
index e0bb6b2be6..fe4dbc8603 100644
--- a/src/metatrain/utils/additive/zbl.py
+++ b/src/metatrain/utils/additive/zbl.py
@@ -5,13 +5,14 @@
 import torch
 from ase.data import covalent_radii
 from metatensor.torch import Labels, TensorBlock, TensorMap
+from metatensor.torch.learn.nn import Module
 from metatomic.torch import ModelOutput, NeighborListOptions, System
 
 from ..data import DatasetInfo, TargetInfo
 from ..sum_over_atoms import sum_over_atoms
 
 
-class ZBL(torch.nn.Module):
+class ZBL(Module):
     """
     A simple model for short-range repulsive interactions.
 
diff --git a/src/metatrain/utils/long_range.py b/src/metatrain/utils/long_range.py
index 1e73c14bff..6ad51905ca 100644
--- a/src/metatrain/utils/long_range.py
+++ b/src/metatrain/utils/long_range.py
@@ -2,6 +2,7 @@
 # We ignore misc errors in this file because TypedDict
 # with default values is not allowed by mypy.
 import torch
+from metatensor.torch.learn.nn import Module
 from metatomic.torch import System
 from typing_extensions import TypedDict
 
@@ -25,7 +26,7 @@ class LongRangeHypers(TypedDict):
     """Number of grid points for interpolation (for PME only)"""
 
 
-class LongRangeFeaturizer(torch.nn.Module):
+class LongRangeFeaturizer(Module):
     """A class to compute long-range features starting from short-range features.
 
     :param hypers: Dictionary containing the hyperparameters for the long-range
@@ -191,7 +192,7 @@ def forward(
         return torch.concatenate(long_range_features)
 
 
-class DummyLongRangeFeaturizer(torch.nn.Module):
+class DummyLongRangeFeaturizer(Module):
     # a dummy class for torchscript
     def __init__(self) -> None:
         super().__init__()
diff --git a/src/metatrain/utils/scaler/_base_scaler.py b/src/metatrain/utils/scaler/_base_scaler.py
index 76a1bfac96..0d56e6bdc0 100644
--- a/src/metatrain/utils/scaler/_base_scaler.py
+++ b/src/metatrain/utils/scaler/_base_scaler.py
@@ -8,13 +8,14 @@
 
 import torch
 from metatensor.torch import Labels, TensorBlock, TensorMap
+from metatensor.torch.learn.nn import Module
 from metatomic.torch import System
 
 
 FixedScalerWeights = dict[str, Union[float, dict[int, float]]]
 
 
-class BaseScaler(torch.nn.Module):
+class BaseScaler(Module):
     """
     Fits a scaler for a dict of targets. Scales are computed as the per-property (and
     therefore per-block) standard deviations. By default, the scales are also computed
@@ -170,7 +171,6 @@ def accumulate(
 
         device = list(targets.values())[0][0].values.device
         dtype = list(targets.values())[0][0].values.dtype
-        self._sync_device_dtype(device, dtype)
 
         # accumulate
         for target_name, target in targets.items():
@@ -312,7 +312,7 @@ def fit(
             blocks.append(block)
 
         self.scales[target_name] = TensorMap(
-            self.Y2[target_name].keys.to(device=scale_vals_type.device),
+            self.Y2[target_name].keys,
             blocks,
         )
 
@@ -340,8 +340,6 @@ def forward(
         """
 
         device = list(outputs.values())[0][0].values.device
-        dtype = list(outputs.values())[0][0].values.dtype
-        self._sync_device_dtype(device, dtype)
 
         # Build the scaled outputs for each output
         predictions: Dict[str, TensorMap] = {}
@@ -532,24 +530,6 @@ def _apply_fixed_weights(
             )
 
             self.scales[target_name] = TensorMap(
-                self.Y2[target_name].keys.to(device=block.values.device),
+                self.Y2[target_name].keys,
                 [block],
             )
-
-    def _sync_device_dtype(self, device: torch.device, dtype: torch.dtype) -> None:
-        # manually move the TensorMap dicts:
-
-        self.atomic_types = self.atomic_types.to(device=device)
-        self.type_to_index = self.type_to_index.to(device=device)
-        self.N = {
-            target_name: tm.to(device=device, dtype=dtype)
-            for target_name, tm in self.N.items()
-        }
-        self.Y2 = {
-            target_name: tm.to(device=device, dtype=dtype)
-            for target_name, tm in self.Y2.items()
-        }
-        self.scales = {
-            target_name: tm.to(device=device, dtype=dtype)
-            for target_name, tm in self.scales.items()
-        }
diff --git a/src/metatrain/utils/scaler/scaler.py b/src/metatrain/utils/scaler/scaler.py
index 1aa564064d..5c304cb38a 100644
--- a/src/metatrain/utils/scaler/scaler.py
+++ b/src/metatrain/utils/scaler/scaler.py
@@ -3,6 +3,7 @@
 import metatensor.torch as mts
 import torch
 from metatensor.torch import Labels, TensorBlock, TensorMap
+from metatensor.torch.learn.nn import Module
 from metatomic.torch import ModelOutput, System
 from torch.utils.data import DataLoader, DistributedSampler
 
@@ -19,7 +20,7 @@
 from ._base_scaler import BaseScaler, FixedScalerWeights
 
 
-class Scaler(torch.nn.Module):
+class Scaler(Module):
     """
     Placeholder docs.
 
@@ -342,8 +343,6 @@ def scales_to(self, device: torch.device, dtype: torch.dtype) -> None:
                 k: v.to(dtype) for k, v in self.model.scales.items()
             }
 
-        self.model._sync_device_dtype(device, dtype)
-
     def sync_tensor_maps(self) -> None:
         # Reload the scales of the (old) targets, which are not stored in the model
        # state_dict, from the buffers
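Below, as an illustration only (it is not part of the diff), is a minimal sketch of the pattern this change applies throughout the code base: subclassing `metatensor.torch.learn.nn.Module` instead of `torch.nn.Module`. It assumes, as the removal of the `weights_to()` and `_sync_device_dtype()` helpers suggests, that this base class tracks `TensorMap` attributes so that a plain `.to()` call moves them together with the usual parameters and buffers; the class and attribute names here are hypothetical.

```python
import torch
from metatensor.torch import Labels, TensorBlock, TensorMap
from metatensor.torch.learn.nn import Module  # instead of torch.nn.Module


class ExampleAdditive(Module):
    """Toy module holding a TensorMap attribute (illustrative only)."""

    def __init__(self) -> None:
        super().__init__()
        # With torch.nn.Module, a TensorMap attribute is invisible to .to(),
        # which is why the removed helpers had to move it by hand.
        self.weights = TensorMap(
            Labels.single(),
            [
                TensorBlock(
                    values=torch.zeros(1, 1, dtype=torch.float64),
                    samples=Labels.range("sample", 1),
                    components=[],
                    properties=Labels.range("property", 1),
                )
            ],
        )


model = ExampleAdditive()
# Assumed behavior of the metatensor Module: this also moves `model.weights`,
# making the manual device/dtype synchronization in the old code unnecessary.
model.to(device="cpu", dtype=torch.float64)
```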