diff --git a/model/misc.py b/model/misc.py index 43b84990..d0758359 100644 --- a/model/misc.py +++ b/model/misc.py @@ -53,11 +53,10 @@ def get_root_logger(logger_name='basicsr', log_level=logging.INFO, log_file=None return logger -IS_HIGH_VERSION = [int(m) for m in list(re.findall(r"^([0-9]+)\.([0-9]+)\.([0-9]+)([^0-9][a-zA-Z0-9]*)?(\+git.*)?$",\ - torch.__version__)[0][:3])] >= [1, 12, 0] +IS_MPS_AWARE = hasattr(torch.backends, 'mps') def gpu_is_available(): - if IS_HIGH_VERSION: + if IS_MPS_AWARE: if torch.backends.mps.is_available(): return True return True if torch.cuda.is_available() and torch.backends.cudnn.is_available() else False @@ -70,7 +69,7 @@ def get_device(gpu_id=None): else: raise TypeError('Input should be int value.') - if IS_HIGH_VERSION: + if IS_MPS_AWARE: if torch.backends.mps.is_available(): return torch.device('mps'+gpu_str) return torch.device('cuda'+gpu_str if torch.cuda.is_available() and torch.backends.cudnn.is_available() else 'cpu')