diff --git a/libmultilabel/nn/attentionxml.py b/libmultilabel/nn/attentionxml.py
index 16c02b7c..a138a68e 100644
--- a/libmultilabel/nn/attentionxml.py
+++ b/libmultilabel/nn/attentionxml.py
@@ -287,7 +287,7 @@ def fit(self, datasets):
         logger.info(f"Finish training level 0")
 
         logger.info(f"Best model loaded from {best_model_path}")
-        model_0 = Model.load_from_checkpoint(best_model_path)
+        model_0 = Model.load_from_checkpoint(best_model_path, weights_only=False)
 
         logger.info(
             f"Predicting clusters by level-0 model. We then select {self.beam_width} clusters for each instance and "
@@ -422,11 +422,13 @@ def test(self, dataset):
         model_0 = Model.load_from_checkpoint(
             self.get_best_model_path(level=0),
             save_k_predictions=self.beam_width,
+            weights_only=False,
         )
         model_1 = PLTModel.load_from_checkpoint(
             self.get_best_model_path(level=1),
             save_k_predictions=self.save_k_predictions,
             metrics=self.metrics,
+            weights_only=False,
         )
 
         word_dict_path = os.path.join(os.path.dirname(self.get_best_model_path(level=1)), self.WORD_DICT_NAME)
diff --git a/libmultilabel/nn/networks/bert.py b/libmultilabel/nn/networks/bert.py
index ab7cc221..8d88ebbb 100644
--- a/libmultilabel/nn/networks/bert.py
+++ b/libmultilabel/nn/networks/bert.py
@@ -34,7 +34,6 @@ def __init__(
             hidden_dropout_prob=encoder_hidden_dropout,
             attention_probs_dropout_prob=encoder_attention_dropout,
             classifier_dropout=post_encoder_dropout,
-            torchscript=True,
         )
 
     def forward(self, input):
diff --git a/libmultilabel/nn/networks/bert_attention.py b/libmultilabel/nn/networks/bert_attention.py
index 078bc207..c820fee8 100644
--- a/libmultilabel/nn/networks/bert_attention.py
+++ b/libmultilabel/nn/networks/bert_attention.py
@@ -40,7 +40,6 @@ def __init__(
         self.lm = AutoModel.from_pretrained(
             lm_weight,
-            torchscript=True,
             hidden_dropout_prob=encoder_hidden_dropout,
             attention_probs_dropout_prob=encoder_attention_dropout,
         )
 
diff --git a/torch_trainer.py b/torch_trainer.py
index fba9f68c..8b90a706 100644
--- a/torch_trainer.py
+++ b/torch_trainer.py
@@ -150,7 +150,7 @@ def _setup_model(
         if checkpoint_path is not None:
             logging.info(f"Loading model from `{checkpoint_path}` with the previously saved hyper-parameter...")
-            self.model = Model.load_from_checkpoint(checkpoint_path, log_path=log_path)
+            self.model = Model.load_from_checkpoint(checkpoint_path, log_path=log_path, weights_only=False)
             word_dict_path = os.path.join(os.path.dirname(checkpoint_path), self.WORD_DICT_NAME)
             if os.path.exists(word_dict_path):
                 with open(word_dict_path, "rb") as f: