diff --git a/bert/dataset.py b/bert/dataset.py index 2e83be3..5af86af 100644 --- a/bert/dataset.py +++ b/bert/dataset.py @@ -13,7 +13,7 @@ from torchtext.data.utils import get_tokenizer -device = torch.device("cuda" if torch.cuda.is_available() else "cpu") +device = torch.device("cuda" if torch.cuda.is_available() else ("mps" if torch.backends.mps.is_available() else "cpu")) class IMDBBertDataset(Dataset): @@ -194,7 +194,7 @@ def _select_false_nsp_sentences(self, sentences: typing.List[str]): return sentences[sentence_index], sentences[next_sentence_index] def _preprocess_sentence(self, sentence: typing.List[str], should_mask: bool = True): - inverse_token_mask = None + inverse_token_mask = [] if should_mask: sentence, inverse_token_mask = self._mask_sentence(sentence) sentence, inverse_token_mask = self._pad_sentence([self.CLS] + sentence, [True] + inverse_token_mask) diff --git a/bert/model.py b/bert/model.py index 7f952b7..99dd8e6 100644 --- a/bert/model.py +++ b/bert/model.py @@ -4,7 +4,7 @@ import torch.nn.functional as f -device = torch.device("cuda" if torch.cuda.is_available() else "cpu") +device = torch.device("cuda" if torch.cuda.is_available() else ("mps" if torch.backends.mps.is_available() else "cpu")) class JointEmbedding(nn.Module): @@ -11,7 +11,7 @@ from bert.dataset import IMDBBertDataset from bert.model import BERT -device = torch.device("cuda" if torch.cuda.is_available() else "cpu") +device = torch.device("cuda" if torch.cuda.is_available() else ("mps" if torch.backends.mps.is_available() else "cpu")) def percentage(batch_size: int, max_index: int, current_index: int): diff --git a/main.py b/main.py index a55a425..926202f 100644 --- a/main.py +++ b/main.py @@ -21,7 +21,7 @@ timestamp = datetime.datetime.utcnow().timestamp() LOG_DIR = BASE_DIR.joinpath(f'data/logs/bert_experiment_{timestamp}') -device = torch.device("cuda" if torch.cuda.is_available() else "cpu") +device = torch.device("cuda" if torch.cuda.is_available() else ("mps" if torch.backends.mps.is_available() else "cpu")) if torch.cuda.is_available(): torch.cuda.empty_cache()