From 8ab27d85e57278a93498134c2419fe88e3fbd729 Mon Sep 17 00:00:00 2001 From: "^R[0-9]*NG$" Date: Sat, 3 Feb 2024 20:34:33 +0800 Subject: [PATCH 1/2] fix: fixed inverse token mask concatenation error Newer Python versions do not allow direct concatenation between typing.List and None. The implementation has been corrected by setting `inverse_token_mask` in `_preprocess_sentence` from `None` to `[]`. --- bert/dataset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bert/dataset.py b/bert/dataset.py index 2e83be3..c86b43f 100644 --- a/bert/dataset.py +++ b/bert/dataset.py @@ -194,7 +194,7 @@ def _select_false_nsp_sentences(self, sentences: typing.List[str]): return sentences[sentence_index], sentences[next_sentence_index] def _preprocess_sentence(self, sentence: typing.List[str], should_mask: bool = True): - inverse_token_mask = None + inverse_token_mask = [] if should_mask: sentence, inverse_token_mask = self._mask_sentence(sentence) sentence, inverse_token_mask = self._pad_sentence([self.CLS] + sentence, [True] + inverse_token_mask) From f4b2eb3133af6c4a0262c7a1388a7192713b5714 Mon Sep 17 00:00:00 2001 From: "^R[0-9]*NG$" Date: Sat, 3 Feb 2024 20:35:53 +0800 Subject: [PATCH 2/2] feat: added mps acceleration support for Apple Silicon devices --- bert/dataset.py | 2 +- bert/model.py | 2 +- bert/trainer.py | 2 +- main.py | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/bert/dataset.py b/bert/dataset.py index c86b43f..5af86af 100644 --- a/bert/dataset.py +++ b/bert/dataset.py @@ -13,7 +13,7 @@ from torchtext.data.utils import get_tokenizer -device = torch.device("cuda" if torch.cuda.is_available() else "cpu") +device = torch.device("cuda" if torch.cuda.is_available() else ("mps" if torch.backends.mps.is_available() else "cpu")) class IMDBBertDataset(Dataset): diff --git a/bert/model.py b/bert/model.py index 7f952b7..99dd8e6 100644 --- a/bert/model.py +++ b/bert/model.py @@ -4,7 +4,7 @@ import torch.nn.functional as f -device = 
torch.device("cuda" if torch.cuda.is_available() else "cpu") +device = torch.device("cuda" if torch.cuda.is_available() else ("mps" if torch.backends.mps.is_available() else "cpu")) class JointEmbedding(nn.Module): diff --git a/bert/trainer.py b/bert/trainer.py index 50f4051..1a7ec14 100644 --- a/bert/trainer.py +++ b/bert/trainer.py @@ -11,7 +11,7 @@ from bert.dataset import IMDBBertDataset from bert.model import BERT -device = torch.device("cuda" if torch.cuda.is_available() else "cpu") +device = torch.device("cuda" if torch.cuda.is_available() else ("mps" if torch.backends.mps.is_available() else "cpu")) def percentage(batch_size: int, max_index: int, current_index: int): diff --git a/main.py b/main.py index a55a425..926202f 100644 --- a/main.py +++ b/main.py @@ -21,7 +21,7 @@ timestamp = datetime.datetime.utcnow().timestamp() LOG_DIR = BASE_DIR.joinpath(f'data/logs/bert_experiment_{timestamp}') -device = torch.device("cuda" if torch.cuda.is_available() else "cpu") +device = torch.device("cuda" if torch.cuda.is_available() else ("mps" if torch.backends.mps.is_available() else "cpu")) if torch.cuda.is_available(): torch.cuda.empty_cache()