diff --git a/xanesnet/utils/gaussian.py b/xanesnet/utils/gaussian.py index 4d4546e3..0286c8ba 100644 --- a/xanesnet/utils/gaussian.py +++ b/xanesnet/utils/gaussian.py @@ -64,7 +64,7 @@ def __init__( self.register_buffer("centers", centers) def synthesize(self, coeffs: torch.Tensor): - return coeffs @ self.Phi.T + return coeffs @ self.Phi.to(coeffs.device).T def build_ridge_operator(Phi: Tensor, lam: float = 1e-2) -> Tensor: