From 955ffbafef1e233fed3045f4c7d56ec7ad17ac22 Mon Sep 17 00:00:00 2001 From: Sacha Perry-Fagant Date: Mon, 6 Jan 2025 12:49:08 -0500 Subject: [PATCH] Add support for mps DEVICE type --- src/score_models/plot_utils.py | 2 +- src/score_models/toy_distributions.py | 2 +- src/score_models/utils.py | 3 +-- 3 files changed, 3 insertions(+), 4 deletions(-) diff --git a/src/score_models/plot_utils.py b/src/score_models/plot_utils.py index 9bd8879..d17e7d4 100644 --- a/src/score_models/plot_utils.py +++ b/src/score_models/plot_utils.py @@ -11,7 +11,7 @@ from scipy.special import logsumexp from matplotlib.colors import Normalize -DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") +DEVICE = torch.device("cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu") # plt.style.use('dark_background') plt.style.use('science') diff --git a/src/score_models/toy_distributions.py b/src/score_models/toy_distributions.py index a7acc55..5733598 100644 --- a/src/score_models/toy_distributions.py +++ b/src/score_models/toy_distributions.py @@ -3,7 +3,7 @@ from torch.distributions import constraints from torch import distributions as tfd -DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") +DEVICE = torch.device("cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu") def swiss_roll(modes=128, size=0.1, width=0.1, spread=0.7, device=DEVICE) -> tfd.Distribution: """ diff --git a/src/score_models/utils.py b/src/score_models/utils.py index 4b48e78..44eea95 100644 --- a/src/score_models/utils.py +++ b/src/score_models/utils.py @@ -3,8 +3,7 @@ import torch.nn as nn DTYPE = torch.float32 -DEVICE = torch.device('cuda:0' if torch.cuda.is_available() else "cpu") - +DEVICE = torch.device('cuda:0' if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu") def get_norm_layer(norm_type='instance'): """Return a normalization layer