diff --git a/docs/examples/plot_KimCNN_quickstart.py b/docs/examples/plot_KimCNN_quickstart.py
index 49ae1f0d..64990b0d 100644
--- a/docs/examples/plot_KimCNN_quickstart.py
+++ b/docs/examples/plot_KimCNN_quickstart.py
@@ -56,7 +56,6 @@
     model_name=model_name,
     network_config=network_config,
     classes=classes,
-    word_dict=word_dict,
     embed_vecs=embed_vecs,
     learning_rate=learning_rate,
     monitor_metrics=["Micro-F1", "Macro-F1", "P@1", "P@3", "P@5"],
@@ -66,7 +65,7 @@
 # * ``model_name`` leads ``init_model`` function to find a network model.
 # * ``network_config`` contains the configurations of a network model.
 # * ``classes`` is the label set of the data.
-# * ``init_weight``, ``word_dict`` and ``embed_vecs`` are not used on a bert-base model, so we can ignore them.
+# * ``embed_vecs`` contains the pre-trained word vectors.
 # * ``moniter_metrics`` includes metrics you would like to track.
 #
 #
diff --git a/docs/examples/plot_bert_quickstart.py b/docs/examples/plot_bert_quickstart.py
index f71bb813..562aa97a 100644
--- a/docs/examples/plot_bert_quickstart.py
+++ b/docs/examples/plot_bert_quickstart.py
@@ -70,7 +70,6 @@
 # * ``model_name`` leads ``init_model`` function to find a network model.
 # * ``network_config`` contains the configurations of a network model.
 # * ``classes`` is the label set of the data.
-# * ``init_weight``, ``word_dict`` and ``embed_vecs`` are not used on a bert-base model, so we can ignore them.
 # * ``moniter_metrics`` includes metrics you would like to track.
 #
 #
diff --git a/libmultilabel/nn/attentionxml.py b/libmultilabel/nn/attentionxml.py
index 747f1b05..16c02b7c 100644
--- a/libmultilabel/nn/attentionxml.py
+++ b/libmultilabel/nn/attentionxml.py
@@ -1,6 +1,8 @@
 from __future__ import annotations
 
 import logging
+import os
+import pickle
 from functools import partial
 from pathlib import Path
 from typing import Generator, Sequence, Optional
@@ -33,6 +35,7 @@ class PLTTrainer:
 
 
     CHECKPOINT_NAME = "model_"
+    WORD_DICT_NAME = "word_dict.pickle"
 
     def __init__(
         self,
@@ -261,7 +264,6 @@ def fit(self, datasets):
                 model_name="AttentionXML_0",
                 network_config=self.network_config,
                 classes=clusters,
-                word_dict=self.word_dict,
                 embed_vecs=self.embed_vecs,
                 init_weight=self.init_weight,
                 log_path=self.log_path,
@@ -380,7 +382,6 @@ def fit(self, datasets):
 
             model_1 = PLTModel(
                 classes=self.classes,
-                word_dict=self.word_dict,
                 network=network,
                 log_path=self.log_path,
                 learning_rate=self.learning_rate,
@@ -427,7 +428,11 @@ def test(self, dataset):
             save_k_predictions=self.save_k_predictions,
             metrics=self.metrics,
         )
-        self.word_dict = model_1.word_dict
+
+        word_dict_path = os.path.join(os.path.dirname(self.get_best_model_path(level=1)), self.WORD_DICT_NAME)
+        if os.path.exists(word_dict_path):
+            with open(word_dict_path, "rb") as f:
+                self.word_dict = pickle.load(f)
         classes = model_1.classes
 
         test_x = self.reformat_text(dataset)
@@ -489,9 +494,11 @@ def reformat_text(self, dataset):
         # Convert words to numbers according to their indices in word_dict. Then pad each instance to a certain length.
         encoded_text = list(
             map(
-                lambda text: torch.tensor([self.word_dict.get(word, self.word_dict[UNK]) for word in text], dtype=torch.int64)
-                if text
-                else torch.tensor([self.word_dict[UNK]], dtype=torch.int64),
+                lambda text: (
+                    torch.tensor([self.word_dict.get(word, self.word_dict[UNK]) for word in text], dtype=torch.int64)
+                    if text
+                    else torch.tensor([self.word_dict[UNK]], dtype=torch.int64)
+                ),
                 [instance["text"][: self.max_seq_length] for instance in dataset],
             )
         )
@@ -519,7 +526,6 @@ class PLTModel(Model):
     def __init__(
         self,
         classes,
-        word_dict,
         network,
         loss_function="binary_cross_entropy_with_logits",
         log_path=None,
@@ -527,7 +533,6 @@ def __init__(
     ):
         super().__init__(
             classes=classes,
-            word_dict=word_dict,
             network=network,
             loss_function=loss_function,
             log_path=log_path,
diff --git a/libmultilabel/nn/model.py b/libmultilabel/nn/model.py
index f7f76439..aa0853e6 100644
--- a/libmultilabel/nn/model.py
+++ b/libmultilabel/nn/model.py
@@ -181,27 +181,17 @@ class Model(MultiLabelModel):
 
     Args:
         classes (list): List of class names.
-        word_dict (dict): A dictionary for mapping tokens to indices.
         network (nn.Module): Network (i.e., CAML, KimCNN, or XMLCNN).
         loss_function (str, optional): Loss function name (i.e., binary_cross_entropy_with_logits,
             cross_entropy). Defaults to 'binary_cross_entropy_with_logits'.
         log_path (str): Path to a directory holding the log files and models.
     """
 
-    def __init__(
-        self,
-        classes,
-        word_dict,
-        network,
-        loss_function="binary_cross_entropy_with_logits",
-        log_path=None,
-        **kwargs
-    ):
+    def __init__(self, classes, network, loss_function="binary_cross_entropy_with_logits", log_path=None, **kwargs):
         super().__init__(num_classes=len(classes), log_path=log_path, **kwargs)
         self.save_hyperparameters(
             ignore=["log_path"]
         )  # If log_path is saved, loading the checkpoint will cause an error since each experiment has unique log_path (result_dir).
-        self.word_dict = word_dict
         self.classes = classes
         self.network = network
         self.configure_loss_function(loss_function)
diff --git a/libmultilabel/nn/nn_utils.py b/libmultilabel/nn/nn_utils.py
index f9107d01..f8e0ff1f 100644
--- a/libmultilabel/nn/nn_utils.py
+++ b/libmultilabel/nn/nn_utils.py
@@ -37,7 +37,6 @@ def init_model(
     model_name,
     network_config,
     classes,
-    word_dict=None,
     embed_vecs=None,
     init_weight=None,
     log_path=None,
@@ -61,7 +60,6 @@ def init_model(
         model_name (str): Model to be used such as KimCNN.
         network_config (dict): Configuration for defining the network.
         classes (list): List of class names.
-        word_dict (dict, optional): A dictionary for mapping tokens to indices. Defaults to None.
         embed_vecs (torch.Tensor, optional): The pre-trained word vectors of shape (vocab_size, embed_dim).
             Defaults to None.
         init_weight (str): Weight initialization method from `torch.nn.init`.
@@ -98,7 +96,6 @@ def init_model(
 
     model = Model(
         classes=classes,
-        word_dict=word_dict,
         network=network,
         log_path=log_path,
         learning_rate=learning_rate,
diff --git a/tests/nn/components.py b/tests/nn/components.py
index bcfbcd68..c747b0cc 100644
--- a/tests/nn/components.py
+++ b/tests/nn/components.py
@@ -20,7 +20,7 @@ def get_name(self):
         return "token_to_id"
 
     def get_from_trainer(self, trainer):
-        return trainer.model.word_dict
+        return trainer.word_dict
 
     def compare(self, a, b):
         return a == b
@@ -34,7 +34,7 @@ def get_name(self):
         return "embed_vecs"
 
     def get_from_trainer(self, trainer):
-        return trainer.model.embed_vecs
+        return trainer.embed_vecs
 
     def compare(self, a, b):
         return (a == b).all()
diff --git a/torch_trainer.py b/torch_trainer.py
index a7f0641d..fba9f68c 100644
--- a/torch_trainer.py
+++ b/torch_trainer.py
@@ -1,5 +1,6 @@
 import logging
 import os
+import pickle
 
 import numpy as np
 from lightning.pytorch.callbacks import ModelCheckpoint
@@ -25,6 +26,8 @@ class TorchTrainer:
         Defaults to True.
     """
 
+    WORD_DICT_NAME = "word_dict.pickle"
+
     def __init__(
         self,
         config: dict,
@@ -44,6 +47,11 @@ def __init__(
         self.device = init_device(use_cpu=config.cpu)
         self.config = config
 
+        # Set dataset meta info
+        self.embed_vecs = embed_vecs
+        self.word_dict = word_dict
+        self.classes = classes
+
         # Load pretrained tokenizer for dataset loader
         self.tokenizer = None
         tokenize_text = "lm_weight" not in config.network_config
@@ -69,8 +77,9 @@ def __init__(
             # Note that AttentionXML produces two models. checkpoint_path directs to model_1
             if config.checkpoint_path is None:
                 if self.config.embed_file is not None:
-                    logging.info("Load word dictionary ")
-                    word_dict, embed_vecs = data_utils.load_or_build_text_dict(
+                    word_dict_path = os.path.join(self.checkpoint_dir, self.WORD_DICT_NAME)
+                    logging.info(f"Load and cache the word dictionary into {word_dict_path}.")
+                    self.word_dict, self.embed_vecs = data_utils.load_or_build_text_dict(
                         dataset=self.datasets["train"] + self.datasets["val"],
                         vocab_file=config.vocab_file,
                         min_vocab_freq=config.min_vocab_freq,
@@ -79,9 +88,11 @@ def __init__(
                         normalize_embed=config.normalize_embed,
                         embed_cache_dir=config.embed_cache_dir,
                     )
+                    with open(word_dict_path, "wb") as f:
+                        pickle.dump(self.word_dict, f)
 
-                if not classes:
-                    classes = data_utils.load_or_build_label(
+                if not self.classes:
+                    self.classes = data_utils.load_or_build_label(
                         self.datasets, self.config.label_file, self.config.include_test_labels
                     )
 
@@ -98,15 +109,12 @@ def __init__(
                     f"Add {self.config.val_metric} to `monitor_metrics`."
                 )
                 self.config.monitor_metrics += [self.config.val_metric]
-            self.trainer = PLTTrainer(self.config, classes=classes, embed_vecs=embed_vecs, word_dict=word_dict)
+            self.trainer = PLTTrainer(
+                self.config, classes=self.classes, embed_vecs=self.embed_vecs, word_dict=self.word_dict
+            )
             return
-        self._setup_model(
-            classes=classes,
-            word_dict=word_dict,
-            embed_vecs=embed_vecs,
-            log_path=self.log_path,
-            checkpoint_path=config.checkpoint_path,
-        )
+
+        self._setup_model(log_path=self.log_path, checkpoint_path=config.checkpoint_path)
         self.trainer = init_trainer(
             checkpoint_dir=self.checkpoint_dir,
             epochs=config.epochs,
@@ -125,9 +133,6 @@ def __init__(
 
     def _setup_model(
         self,
-        classes: list = None,
-        word_dict: dict = None,
-        embed_vecs=None,
         log_path: str = None,
         checkpoint_path: str = None,
     ):
@@ -135,9 +140,6 @@ def _setup_model(
         Otherwise, initialize model from scratch.
 
         Args:
-            classes(list): List of class names.
-            word_dict (dict, optional): A dictionary for mapping tokens to indices. Defaults to None.
-            embed_vecs (torch.Tensor): The pre-trained word vectors of shape (vocab_size, embed_dim).
             log_path (str): Path to the log file. The log file contains the validation results for
                 each epoch and the test results. If the `log_path` is None, no performance results
                 will be logged.
@@ -149,11 +151,16 @@ def _setup_model(
         if checkpoint_path is not None:
             logging.info(f"Loading model from `{checkpoint_path}` with the previously saved hyper-parameter...")
             self.model = Model.load_from_checkpoint(checkpoint_path, log_path=log_path)
+            word_dict_path = os.path.join(os.path.dirname(checkpoint_path), self.WORD_DICT_NAME)
+            if os.path.exists(word_dict_path):
+                with open(word_dict_path, "rb") as f:
+                    self.word_dict = pickle.load(f)
         else:
             logging.info("Initialize model from scratch.")
             if self.config.embed_file is not None:
-                logging.info("Load word dictionary ")
-                word_dict, embed_vecs = data_utils.load_or_build_text_dict(
+                word_dict_path = os.path.join(self.checkpoint_dir, self.WORD_DICT_NAME)
+                logging.info(f"Load and cache the word dictionary into {word_dict_path}.")
+                self.word_dict, self.embed_vecs = data_utils.load_or_build_text_dict(
                     dataset=self.datasets["train"],
                     vocab_file=self.config.vocab_file,
                     min_vocab_freq=self.config.min_vocab_freq,
@@ -162,8 +169,11 @@ def _setup_model(
                     normalize_embed=self.config.normalize_embed,
                     embed_cache_dir=self.config.embed_cache_dir,
                 )
+                with open(word_dict_path, "wb") as f:
+                    pickle.dump(self.word_dict, f)
 
-            if not classes:
-                classes = data_utils.load_or_build_label(
+            if not self.classes:
+                self.classes = data_utils.load_or_build_label(
                     self.datasets, self.config.label_file, self.config.include_test_labels
                 )
 
@@ -184,9 +194,8 @@ def _setup_model(
             self.model = init_model(
                 model_name=self.config.model_name,
                 network_config=dict(self.config.network_config),
-                classes=classes,
-                word_dict=word_dict,
-                embed_vecs=embed_vecs,
+                classes=self.classes,
+                embed_vecs=self.embed_vecs,
                 init_weight=self.config.init_weight,
                 log_path=log_path,
                 learning_rate=self.config.learning_rate,
@@ -222,7 +231,7 @@ def _get_dataset_loader(self, split, shuffle=False):
             batch_size=self.config.batch_size if split == "train" else self.config.eval_batch_size,
             shuffle=shuffle,
             data_workers=self.config.data_workers,
-            word_dict=self.model.word_dict,
+            word_dict=self.word_dict,
             tokenizer=self.tokenizer,
             add_special_tokens=self.config.add_special_tokens,
         )
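
Note on usage after this change: `word_dict` is no longer a constructor argument of `Model`/`init_model` and is no longer stored inside the Lightning checkpoint; the trainer owns it and caches it as `word_dict.pickle` next to the saved checkpoints. The snippet below is a minimal sketch (not part of the diff) of how downstream code could reload both pieces; the checkpoint path is hypothetical, while the file layout and the `Model.load_from_checkpoint` call follow the code above.

    # Minimal sketch, assuming a checkpoint directory produced after this change.
    import os
    import pickle

    from libmultilabel.nn.model import Model

    WORD_DICT_NAME = "word_dict.pickle"  # same constant as TorchTrainer/PLTTrainer

    checkpoint_path = "runs/example/checkpoints/best_model.ckpt"  # hypothetical path

    # The checkpoint no longer stores word_dict, so this loads only the network and hyper-parameters.
    model = Model.load_from_checkpoint(checkpoint_path)

    # The trainer dumped the vocabulary next to the checkpoint; read it back if it exists.
    word_dict_path = os.path.join(os.path.dirname(checkpoint_path), WORD_DICT_NAME)
    word_dict = None  # e.g. BERT-style models never build a word dictionary
    if os.path.exists(word_dict_path):
        with open(word_dict_path, "rb") as f:
            word_dict = pickle.load(f)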