From b22c454977d808a1ce26c1625acc6efd1407fc9a Mon Sep 17 00:00:00 2001
From: Eleven Liu
Date: Thu, 14 Aug 2025 00:56:57 +0800
Subject: [PATCH 1/5] backup runnable draft before checkout

---
 libmultilabel/nn/attentionxml.py |  14 +--
 libmultilabel/nn/model.py        |   3 -
 libmultilabel/nn/nn_utils.py     |   3 -
 torch_trainer.py                 | 148 ++++++++++++++-----------------
 4 files changed, 76 insertions(+), 92 deletions(-)

diff --git a/libmultilabel/nn/attentionxml.py b/libmultilabel/nn/attentionxml.py
index 747f1b05..97e0113a 100644
--- a/libmultilabel/nn/attentionxml.py
+++ b/libmultilabel/nn/attentionxml.py
@@ -1,6 +1,7 @@
 from __future__ import annotations
 
 import logging
+import pickle
 from functools import partial
 from pathlib import Path
 from typing import Generator, Sequence, Optional
@@ -261,7 +262,6 @@ def fit(self, datasets):
             model_name="AttentionXML_0",
             network_config=self.network_config,
             classes=clusters,
-            word_dict=self.word_dict,
             embed_vecs=self.embed_vecs,
             init_weight=self.init_weight,
             log_path=self.log_path,
@@ -380,7 +380,6 @@ def fit(self, datasets):
 
         model_1 = PLTModel(
             classes=self.classes,
-            word_dict=self.word_dict,
             network=network,
             log_path=self.log_path,
             learning_rate=self.learning_rate,
@@ -427,7 +426,13 @@ def test(self, dataset):
             save_k_predictions=self.save_k_predictions,
             metrics=self.metrics,
         )
-        self.word_dict = model_1.word_dict
+        # self.word_dict = model_1.word_dict
+        # TBD: similar to the one in torch_trainer
+        import os
+        metadata_path = os.path.join(os.path.dirname(self.get_best_model_path(level=1)), "word_dict.pkl")
+        if os.path.exists(metadata_path):
+            with open(metadata_path, "rb") as f:
+                self.word_dict = pickle.load(f)
         classes = model_1.classes
 
         test_x = self.reformat_text(dataset)
@@ -519,7 +524,6 @@ class PLTModel(Model):
     def __init__(
         self,
         classes,
-        word_dict,
         network,
         loss_function="binary_cross_entropy_with_logits",
         log_path=None,
@@ -527,7 +531,7 @@ def __init__(
     ):
         super().__init__(
             classes=classes,
-            word_dict=word_dict,
+            # word_dict=word_dict,
             network=network,
             loss_function=loss_function,
             log_path=log_path,
diff --git a/libmultilabel/nn/model.py b/libmultilabel/nn/model.py
index f7f76439..37fd7559 100644
--- a/libmultilabel/nn/model.py
+++ b/libmultilabel/nn/model.py
@@ -181,7 +181,6 @@ class Model(MultiLabelModel):
 
     Args:
         classes (list): List of class names.
-        word_dict (dict): A dictionary for mapping tokens to indices.
        network (nn.Module): Network (i.e., CAML, KimCNN, or XMLCNN).
        loss_function (str, optional): Loss function name (i.e., binary_cross_entropy_with_logits, cross_entropy).
            Defaults to 'binary_cross_entropy_with_logits'.
@@ -191,7 +190,6 @@ def __init__(
         self,
         classes,
-        word_dict,
         network,
         loss_function="binary_cross_entropy_with_logits",
         log_path=None,
@@ -201,7 +199,6 @@ def __init__(
         self.save_hyperparameters(
             ignore=["log_path"]
         )  # If log_path is saved, loading the checkpoint will cause an error since each experiment has unique log_path (result_dir).
-        self.word_dict = word_dict
         self.classes = classes
         self.network = network
         self.configure_loss_function(loss_function)
diff --git a/libmultilabel/nn/nn_utils.py b/libmultilabel/nn/nn_utils.py
index f9107d01..f8e0ff1f 100644
--- a/libmultilabel/nn/nn_utils.py
+++ b/libmultilabel/nn/nn_utils.py
@@ -37,7 +37,6 @@ def init_model(
     model_name,
     network_config,
     classes,
-    word_dict=None,
     embed_vecs=None,
     init_weight=None,
     log_path=None,
@@ -61,7 +60,6 @@ def init_model(
        model_name (str): Model to be used such as KimCNN.
        network_config (dict): Configuration for defining the network.
        classes (list): List of class names.
-       word_dict (dict, optional): A dictionary for mapping tokens to indices. Defaults to None.
        embed_vecs (torch.Tensor, optional): The pre-trained word vectors of shape (vocab_size, embed_dim).
            Defaults to None.
        init_weight (str): Weight initialization method from `torch.nn.init`.
@@ -98,7 +96,6 @@ def init_model(
 
     model = Model(
         classes=classes,
-        word_dict=word_dict,
         network=network,
         log_path=log_path,
         learning_rate=learning_rate,
diff --git a/torch_trainer.py b/torch_trainer.py
index a7f0641d..fd60bb68 100644
--- a/torch_trainer.py
+++ b/torch_trainer.py
@@ -1,5 +1,6 @@
 import logging
 import os
+import pickle
 
 import numpy as np
 from lightning.pytorch.callbacks import ModelCheckpoint
@@ -38,12 +39,18 @@ def __init__(
         self.checkpoint_dir = config.checkpoint_dir
         self.log_path = config.log_path
         os.makedirs(self.checkpoint_dir, exist_ok=True)
+        self.metadata_path = os.path.join(config.checkpoint_dir, "word_dict.pkl")  # TBD
 
         # Set up seed & device
         set_seed(seed=config.seed)
         self.device = init_device(use_cpu=config.cpu)
         self.config = config
 
+        # Set up meta data
+        self.embed_vecs = embed_vecs
+        self.word_dict = word_dict
+        self.classes = classes
+
         # Load pretrained tokenizer for dataset loader
         self.tokenizer = None
         tokenize_text = "lm_weight" not in config.network_config
@@ -64,69 +71,39 @@ def __init__(
         self.datasets = datasets
         self.config.multiclass = is_multiclass_dataset(self.datasets["train"] + self.datasets.get("val", list()))
-
-        if self.config.model_name.lower() == "attentionxml":
-            # Note that AttentionXML produces two models. checkpoint_path directs to model_1
-            if config.checkpoint_path is None:
-                if self.config.embed_file is not None:
-                    logging.info("Load word dictionary ")
-                    word_dict, embed_vecs = data_utils.load_or_build_text_dict(
-                        dataset=self.datasets["train"] + self.datasets["val"],
-                        vocab_file=config.vocab_file,
-                        min_vocab_freq=config.min_vocab_freq,
-                        embed_file=config.embed_file,
-                        silent=config.silent,
-                        normalize_embed=config.normalize_embed,
-                        embed_cache_dir=config.embed_cache_dir,
-                    )
-
-                if not classes:
-                    classes = data_utils.load_or_build_label(
-                        self.datasets, self.config.label_file, self.config.include_test_labels
-                    )
-
-                if self.config.early_stopping_metric not in self.config.monitor_metrics:
-                    logging.warning(
-                        f"{self.config.early_stopping_metric} is not in `monitor_metrics`. "
-                        f"Add {self.config.early_stopping_metric} to `monitor_metrics`."
-                    )
-                    self.config.monitor_metrics += [self.config.early_stopping_metric]
-
-                if self.config.val_metric not in self.config.monitor_metrics:
-                    logging.warn(
-                        f"{self.config.val_metric} is not in `monitor_metrics`. "
-                        f"Add {self.config.val_metric} to `monitor_metrics`."
-                    )
-                    self.config.monitor_metrics += [self.config.val_metric]
-            self.trainer = PLTTrainer(self.config, classes=classes, embed_vecs=embed_vecs, word_dict=word_dict)
-            return
+        self.config.is_attentionxml = self.config.model_name.lower() == "attentionxml"
 
         self._setup_model(
+            # TBD: check why we need this setting?
+            datasets=self.datasets["train"] + self.datasets["val"] if self.config.is_attentionxml else self.datasets["train"],
             classes=classes,
-            word_dict=word_dict,
             embed_vecs=embed_vecs,
             log_path=self.log_path,
             checkpoint_path=config.checkpoint_path,
         )
-        self.trainer = init_trainer(
-            checkpoint_dir=self.checkpoint_dir,
-            epochs=config.epochs,
-            patience=config.patience,
-            early_stopping_metric=config.early_stopping_metric,
-            val_metric=config.val_metric,
-            silent=config.silent,
-            use_cpu=config.cpu,
-            limit_train_batches=config.limit_train_batches,
-            limit_val_batches=config.limit_val_batches,
-            limit_test_batches=config.limit_test_batches,
-            save_checkpoints=save_checkpoints,
-        )
-        callbacks = [callback for callback in self.trainer.callbacks if isinstance(callback, ModelCheckpoint)]
-        self.checkpoint_callback = callbacks[0] if callbacks else None
+
+        if self.config.is_attentionxml:
+            self.trainer = PLTTrainer(self.config, classes=self.classes, embed_vecs=self.embed_vecs, word_dict=self.word_dict)
+        else:
+            self.trainer = init_trainer(
+                checkpoint_dir=self.checkpoint_dir,
+                epochs=config.epochs,
+                patience=config.patience,
+                early_stopping_metric=config.early_stopping_metric,
+                val_metric=config.val_metric,
+                silent=config.silent,
+                use_cpu=config.cpu,
+                limit_train_batches=config.limit_train_batches,
+                limit_val_batches=config.limit_val_batches,
+                limit_test_batches=config.limit_test_batches,
+                save_checkpoints=save_checkpoints,
+            )
+            callbacks = [callback for callback in self.trainer.callbacks if isinstance(callback, ModelCheckpoint)]
+            self.checkpoint_callback = callbacks[0] if callbacks else None
 
     def _setup_model(
         self,
+        datasets: dict = None,
         classes: list = None,
-        word_dict: dict = None,
         embed_vecs=None,
         log_path: str = None,
         checkpoint_path: str = None,
@@ -135,8 +112,8 @@
         Otherwise, initialize model from scratch.
 
         Args:
+            datasets (dict, optional): Datasets for training, validation, and test. Defaults to None.
            classes(list): List of class names.
-           word_dict (dict, optional): A dictionary for mapping tokens to indices. Defaults to None.
            embed_vecs (torch.Tensor): The pre-trained word vectors of shape (vocab_size, embed_dim).
            log_path (str): Path to the log file. The log file contains the validation results for each epoch and the test results.
                If the `log_path` is None, no performance results will be logged.
            checkpoint_path (str): The checkpoint to warm-up with.
        """
        if checkpoint_path is not None:
            logging.info(f"Loading model from `{checkpoint_path}` with the previously saved hyper-parameter...")
            self.model = Model.load_from_checkpoint(checkpoint_path, log_path=log_path)
+           # TBD: pseudo code
+           metadata_path = os.path.join(os.path.dirname(checkpoint_path), "word_dict.pkl")
+           if os.path.exists(metadata_path):
+               with open(metadata_path, "rb") as f:
+                   self.word_dict = pickle.load(f)
        else:
            logging.info("Initialize model from scratch.")
            if self.config.embed_file is not None:
-               logging.info("Load word dictionary ")
-               word_dict, embed_vecs = data_utils.load_or_build_text_dict(
-                   dataset=self.datasets["train"],
+               logging.info(f"Load and cache the word dictionary into {self.metadata_path}")
+               self.word_dict, self.embed_vecs = data_utils.load_or_build_text_dict(
+                   dataset=datasets,
                    vocab_file=self.config.vocab_file,
                    min_vocab_freq=self.config.min_vocab_freq,
                    embed_file=self.config.embed_file,
@@ -162,8 +144,11 @@
                    normalize_embed=self.config.normalize_embed,
                    embed_cache_dir=self.config.embed_cache_dir,
                )
+               with open(self.metadata_path, "wb") as f:
+                   pickle.dump(self.word_dict, f)
+
            if not classes:
-               classes = data_utils.load_or_build_label(
+               self.classes = data_utils.load_or_build_label(
                    self.datasets, self.config.label_file, self.config.include_test_labels
                )
 
@@ -181,28 +166,29 @@
            )
            self.config.monitor_metrics += [self.config.val_metric]
 
-       self.model = init_model(
-           model_name=self.config.model_name,
-           network_config=dict(self.config.network_config),
-           classes=classes,
-           word_dict=word_dict,
-           embed_vecs=embed_vecs,
-           init_weight=self.config.init_weight,
-           log_path=log_path,
-           learning_rate=self.config.learning_rate,
-           optimizer=self.config.optimizer,
-           momentum=self.config.momentum,
-           weight_decay=self.config.weight_decay,
-           lr_scheduler=self.config.lr_scheduler,
-           scheduler_config=self.config.scheduler_config,
-           val_metric=self.config.val_metric,
-           metric_threshold=self.config.metric_threshold,
-           monitor_metrics=self.config.monitor_metrics,
-           multiclass=self.config.multiclass,
-           loss_function=self.config.loss_function,
-           silent=self.config.silent,
-           save_k_predictions=self.config.save_k_predictions,
-       )
+       # TBD: not for attention xml
+       if not self.config.is_attentionxml:
+           self.model = init_model(
+               model_name=self.config.model_name,
+               network_config=dict(self.config.network_config),
+               classes=classes,
+               embed_vecs=embed_vecs,
+               init_weight=self.config.init_weight,
+               log_path=log_path,
+               learning_rate=self.config.learning_rate,
+               optimizer=self.config.optimizer,
+               momentum=self.config.momentum,
+               weight_decay=self.config.weight_decay,
+               lr_scheduler=self.config.lr_scheduler,
+               scheduler_config=self.config.scheduler_config,
+               val_metric=self.config.val_metric,
+               metric_threshold=self.config.metric_threshold,
+               monitor_metrics=self.config.monitor_metrics,
+               multiclass=self.config.multiclass,
+               loss_function=self.config.loss_function,
+               silent=self.config.silent,
+               save_k_predictions=self.config.save_k_predictions,
+           )
 
    def _get_dataset_loader(self, split, shuffle=False):
        """Get dataset loader.
@@ -222,7 +208,7 @@
            batch_size=self.config.batch_size if split == "train" else self.config.eval_batch_size,
            shuffle=shuffle,
            data_workers=self.config.data_workers,
-           word_dict=self.model.word_dict,
+           word_dict=self.word_dict,
            tokenizer=self.tokenizer,
            add_special_tokens=self.config.add_special_tokens,
        )

From c3d140420d909911aa2bf3bc5eb729c55cf7ad3b Mon Sep 17 00:00:00 2001
From: Eleven Liu
Date: Thu, 14 Aug 2025 10:22:01 +0800
Subject: [PATCH 2/5] Update torch_trainer (classes, embed_vecs)

---
 torch_trainer.py | 12 +++---------
 1 file changed, 3 insertions(+), 9 deletions(-)

diff --git a/torch_trainer.py b/torch_trainer.py
index fd60bb68..dc10b96c 100644
--- a/torch_trainer.py
+++ b/torch_trainer.py
@@ -75,8 +75,6 @@ def __init__(
         self._setup_model(
             # TBD: check why we need this setting?
             datasets=self.datasets["train"] + self.datasets["val"] if self.config.is_attentionxml else self.datasets["train"],
-            classes=classes,
-            embed_vecs=embed_vecs,
             log_path=self.log_path,
             checkpoint_path=config.checkpoint_path,
         )
@@ -103,8 +101,6 @@ def _setup_model(
         self,
         datasets: dict = None,
-        classes: list = None,
-        embed_vecs=None,
         log_path: str = None,
         checkpoint_path: str = None,
@@ -113,8 +109,6 @@
         Args:
             datasets (dict, optional): Datasets for training, validation, and test. Defaults to None.
-            classes(list): List of class names.
-            embed_vecs (torch.Tensor): The pre-trained word vectors of shape (vocab_size, embed_dim).
             log_path (str): Path to the log file. The log file contains the validation results for each epoch and the test results.
                 If the `log_path` is None, no performance results will be logged.
@@ -147,7 +141,7 @@
             with open(self.metadata_path, "wb") as f:
                 pickle.dump(self.word_dict, f)
 
-        if not classes:
+        if not self.classes:
             self.classes = data_utils.load_or_build_label(
                 self.datasets, self.config.label_file, self.config.include_test_labels
             )
@@ -171,8 +165,8 @@
             self.model = init_model(
                 model_name=self.config.model_name,
                 network_config=dict(self.config.network_config),
-                classes=classes,
-                embed_vecs=embed_vecs,
+                classes=self.classes,
+                embed_vecs=self.embed_vecs,
                 init_weight=self.config.init_weight,
                 log_path=log_path,

From 01e06518ce4eebba439d1642459a4588ada6c844 Mon Sep 17 00:00:00 2001
From: Eleven Liu
Date: Sat, 16 Aug 2025 17:08:14 +0800
Subject: [PATCH 3/5] (1) Left attentionxml arch to next PR.
 (2) finish TBD in code

---
 libmultilabel/nn/attentionxml.py |  13 ++-
 torch_trainer.py                 | 146 ++++++++++++++++++-------------
 2 files changed, 92 insertions(+), 67 deletions(-)

diff --git a/libmultilabel/nn/attentionxml.py b/libmultilabel/nn/attentionxml.py
index 97e0113a..1fe76c99 100644
--- a/libmultilabel/nn/attentionxml.py
+++ b/libmultilabel/nn/attentionxml.py
@@ -1,6 +1,7 @@
 from __future__ import annotations
 
 import logging
+import os
 import pickle
 from functools import partial
 from pathlib import Path
 from typing import Generator, Sequence, Optional
@@ -34,6 +35,7 @@ class PLTTrainer:
 
     CHECKPOINT_NAME = "model_"
+    WORD_DICT_NAME = "word_dict.pickle"
 
     def __init__(
         self,
@@ -426,12 +428,10 @@ def test(self, dataset):
             save_k_predictions=self.save_k_predictions,
             metrics=self.metrics,
         )
-        # self.word_dict = model_1.word_dict
-        # TBD: similar to the one in torch_trainer
-        import os
-        metadata_path = os.path.join(os.path.dirname(self.get_best_model_path(level=1)), "word_dict.pkl")
-        if os.path.exists(metadata_path):
-            with open(metadata_path, "rb") as f:
+
+        word_dict_path = os.path.join(os.path.dirname(self.get_best_model_path(level=1)), self.WORD_DICT_NAME)
+        if os.path.exists(word_dict_path):
+            with open(word_dict_path, "rb") as f:
                 self.word_dict = pickle.load(f)
         classes = model_1.classes
 
         test_x = self.reformat_text(dataset)
@@ -531,7 +531,6 @@ def __init__(
     ):
         super().__init__(
             classes=classes,
-            # word_dict=word_dict,
             network=network,
             loss_function=loss_function,
             log_path=log_path,
diff --git a/torch_trainer.py b/torch_trainer.py
index dc10b96c..bd0ee2f3 100644
--- a/torch_trainer.py
+++ b/torch_trainer.py
@@ -25,6 +25,7 @@ class TorchTrainer:
         save_checkpoints (bool, optional): Whether to save the last and the best checkpoint or not. Defaults to True.
     """
+    WORD_DICT_NAME = "word_dict.pickle"
 
     def __init__(
         self,
@@ -39,14 +40,13 @@ def __init__(
         self.checkpoint_dir = config.checkpoint_dir
         self.log_path = config.log_path
         os.makedirs(self.checkpoint_dir, exist_ok=True)
-        self.metadata_path = os.path.join(config.checkpoint_dir, "word_dict.pkl")  # TBD
 
         # Set up seed & device
         set_seed(seed=config.seed)
         self.device = init_device(use_cpu=config.cpu)
         self.config = config
 
-        # Set up meta data
+        # Set dataset meta info
         self.embed_vecs = embed_vecs
         self.word_dict = word_dict
         self.classes = classes
@@ -71,36 +71,65 @@ def __init__(
         self.datasets = datasets
         self.config.multiclass = is_multiclass_dataset(self.datasets["train"] + self.datasets.get("val", list()))
-        self.config.is_attentionxml = self.config.model_name.lower() == "attentionxml"
-        self._setup_model(
-            # TBD: check why we need this setting?
-            datasets=self.datasets["train"] + self.datasets["val"] if self.config.is_attentionxml else self.datasets["train"],
-            log_path=self.log_path,
-            checkpoint_path=config.checkpoint_path,
-        )
 
-        if self.config.is_attentionxml:
-            self.trainer = PLTTrainer(self.config, classes=self.classes, embed_vecs=self.embed_vecs, word_dict=self.word_dict)
-        else:
-            self.trainer = init_trainer(
-                checkpoint_dir=self.checkpoint_dir,
-                epochs=config.epochs,
-                patience=config.patience,
-                early_stopping_metric=config.early_stopping_metric,
-                val_metric=config.val_metric,
-                silent=config.silent,
-                use_cpu=config.cpu,
-                limit_train_batches=config.limit_train_batches,
-                limit_val_batches=config.limit_val_batches,
-                limit_test_batches=config.limit_test_batches,
-                save_checkpoints=save_checkpoints,
-            )
-            callbacks = [callback for callback in self.trainer.callbacks if isinstance(callback, ModelCheckpoint)]
-            self.checkpoint_callback = callbacks[0] if callbacks else None
+        if self.config.model_name.lower() == "attentionxml":
+            # Note that AttentionXML produces two models. checkpoint_path directs to model_1
+            if config.checkpoint_path is None:
+                if self.config.embed_file is not None:
+                    word_dict_path = os.path.join(self.checkpoint_dir, self.WORD_DICT_NAME)
+                    logging.info(f"Load and cache the word dictionary into {word_dict_path}.")
+                    self.word_dict, self.embed_vecs = data_utils.load_or_build_text_dict(
+                        dataset=self.datasets["train"] + self.datasets["val"],
+                        vocab_file=config.vocab_file,
+                        min_vocab_freq=config.min_vocab_freq,
+                        embed_file=config.embed_file,
+                        silent=config.silent,
+                        normalize_embed=config.normalize_embed,
+                        embed_cache_dir=config.embed_cache_dir,
+                    )
+                    with open(word_dict_path, "wb") as f:
+                        pickle.dump(self.word_dict, f)
+
+                if not self.classes:
+                    self.classes = data_utils.load_or_build_label(
+                        self.datasets, self.config.label_file, self.config.include_test_labels
+                    )
+
+                if self.config.early_stopping_metric not in self.config.monitor_metrics:
+                    logging.warning(
+                        f"{self.config.early_stopping_metric} is not in `monitor_metrics`. "
+                        f"Add {self.config.early_stopping_metric} to `monitor_metrics`."
+                    )
+                    self.config.monitor_metrics += [self.config.early_stopping_metric]
+
+                if self.config.val_metric not in self.config.monitor_metrics:
+                    logging.warn(
+                        f"{self.config.val_metric} is not in `monitor_metrics`. "
+                        f"Add {self.config.val_metric} to `monitor_metrics`."
+                    )
+                    self.config.monitor_metrics += [self.config.val_metric]
+            self.trainer = PLTTrainer(self.config, classes=self.classes, embed_vecs=self.embed_vecs, word_dict=self.word_dict)
+            return
+
+        self._setup_model(log_path=self.log_path, checkpoint_path=config.checkpoint_path)
+        self.trainer = init_trainer(
+            checkpoint_dir=self.checkpoint_dir,
+            epochs=config.epochs,
+            patience=config.patience,
+            early_stopping_metric=config.early_stopping_metric,
+            val_metric=config.val_metric,
+            silent=config.silent,
+            use_cpu=config.cpu,
+            limit_train_batches=config.limit_train_batches,
+            limit_val_batches=config.limit_val_batches,
+            limit_test_batches=config.limit_test_batches,
+            save_checkpoints=save_checkpoints,
+        )
+        callbacks = [callback for callback in self.trainer.callbacks if isinstance(callback, ModelCheckpoint)]
+        self.checkpoint_callback = callbacks[0] if callbacks else None
 
     def _setup_model(
         self,
-        datasets: dict = None,
         log_path: str = None,
         checkpoint_path: str = None,
@@ -108,7 +137,6 @@ def _setup_model(
         Otherwise, initialize model from scratch.
 
         Args:
-            datasets (dict, optional): Datasets for training, validation, and test. Defaults to None.
             log_path (str): Path to the log file. The log file contains the validation results for each epoch and the test results.
                 If the `log_path` is None, no performance results will be logged.
@@ -144,21 +144,21 @@
         if "checkpoint_path" in self.config and self.config.checkpoint_path is not None:
             checkpoint_path = self.config.checkpoint_path
-
+
         if checkpoint_path is not None:
             logging.info(f"Loading model from `{checkpoint_path}` with the previously saved hyper-parameter...")
             self.model = Model.load_from_checkpoint(checkpoint_path, log_path=log_path)
-            # TBD: pseudo code
-            metadata_path = os.path.join(os.path.dirname(checkpoint_path), "word_dict.pkl")
-            if os.path.exists(metadata_path):
-                with open(metadata_path, "rb") as f:
+            word_dict_path = os.path.join(os.path.dirname(checkpoint_path), self.WORD_DICT_NAME)
+            if os.path.exists(word_dict_path):
+                with open(word_dict_path, "rb") as f:
                     self.word_dict = pickle.load(f)
         else:
             logging.info("Initialize model from scratch.")
             if self.config.embed_file is not None:
-                logging.info(f"Load and cache the word dictionary into {self.metadata_path}")
+                word_dict_path = os.path.join(self.checkpoint_dir, self.WORD_DICT_NAME)
+                logging.info(f"Load and cache the word dictionary into {word_dict_path}.")
                 self.word_dict, self.embed_vecs = data_utils.load_or_build_text_dict(
-                    dataset=datasets,
+                    dataset=self.datasets["train"],
                     vocab_file=self.config.vocab_file,
                     min_vocab_freq=self.config.min_vocab_freq,
                     embed_file=self.config.embed_file,
                     silent=self.config.silent,
                     normalize_embed=self.config.normalize_embed,
                     embed_cache_dir=self.config.embed_cache_dir,
                 )
-                with open(self.metadata_path, "wb") as f:
+                with open(word_dict_path, "wb") as f:
                     pickle.dump(self.word_dict, f)
 
             if not self.classes:
@@ -160,29 +188,27 @@
             )
             self.config.monitor_metrics += [self.config.val_metric]
 
-        # TBD: not for attention xml
-        if not self.config.is_attentionxml:
-            self.model = init_model(
-                model_name=self.config.model_name,
-                network_config=dict(self.config.network_config),
-                classes=self.classes,
-                embed_vecs=self.embed_vecs,
-                init_weight=self.config.init_weight,
-                log_path=log_path,
-                learning_rate=self.config.learning_rate,
-                optimizer=self.config.optimizer,
-                momentum=self.config.momentum,
-                weight_decay=self.config.weight_decay,
-                lr_scheduler=self.config.lr_scheduler,
-                scheduler_config=self.config.scheduler_config,
-                val_metric=self.config.val_metric,
-                metric_threshold=self.config.metric_threshold,
-                monitor_metrics=self.config.monitor_metrics,
-                multiclass=self.config.multiclass,
-                loss_function=self.config.loss_function,
-                silent=self.config.silent,
-                save_k_predictions=self.config.save_k_predictions,
-            )
+        self.model = init_model(
+            model_name=self.config.model_name,
+            network_config=dict(self.config.network_config),
+            classes=self.classes,
+            embed_vecs=self.embed_vecs,
+            init_weight=self.config.init_weight,
+            log_path=log_path,
+            learning_rate=self.config.learning_rate,
+            optimizer=self.config.optimizer,
+            momentum=self.config.momentum,
+            weight_decay=self.config.weight_decay,
+            lr_scheduler=self.config.lr_scheduler,
+            scheduler_config=self.config.scheduler_config,
+            val_metric=self.config.val_metric,
+            metric_threshold=self.config.metric_threshold,
+            monitor_metrics=self.config.monitor_metrics,
+            multiclass=self.config.multiclass,
+            loss_function=self.config.loss_function,
+            silent=self.config.silent,
+            save_k_predictions=self.config.save_k_predictions,
+        )
 
     def _get_dataset_loader(self, split, shuffle=False):
         """Get dataset loader.

From c01dedb3ddbce17e1896aefc9e78075e4b2b5640 Mon Sep 17 00:00:00 2001
From: Eleven Liu
Date: Sat, 16 Aug 2025 19:15:09 +0800
Subject: [PATCH 4/5] Update documents and test components.

---
 docs/examples/plot_KimCNN_quickstart.py | 3 +--
 docs/examples/plot_bert_quickstart.py   | 1 -
 tests/nn/components.py                  | 4 ++--
 3 files changed, 3 insertions(+), 5 deletions(-)

diff --git a/docs/examples/plot_KimCNN_quickstart.py b/docs/examples/plot_KimCNN_quickstart.py
index 49ae1f0d..64990b0d 100644
--- a/docs/examples/plot_KimCNN_quickstart.py
+++ b/docs/examples/plot_KimCNN_quickstart.py
@@ -56,7 +56,6 @@
     model_name=model_name,
     network_config=network_config,
     classes=classes,
-    word_dict=word_dict,
     embed_vecs=embed_vecs,
     learning_rate=learning_rate,
     monitor_metrics=["Micro-F1", "Macro-F1", "P@1", "P@3", "P@5"],
@@ -66,7 +65,7 @@
 # * ``model_name`` leads ``init_model`` function to find a network model.
 # * ``network_config`` contains the configurations of a network model.
 # * ``classes`` is the label set of the data.
-# * ``init_weight``, ``word_dict`` and ``embed_vecs`` are not used on a bert-base model, so we can ignore them.
+# * ``embed_vecs`` is the pre-trained word vectors.
 # * ``moniter_metrics`` includes metrics you would like to track.
 #
 #
diff --git a/docs/examples/plot_bert_quickstart.py b/docs/examples/plot_bert_quickstart.py
index f71bb813..562aa97a 100644
--- a/docs/examples/plot_bert_quickstart.py
+++ b/docs/examples/plot_bert_quickstart.py
@@ -70,7 +70,6 @@
 # * ``model_name`` leads ``init_model`` function to find a network model.
 # * ``network_config`` contains the configurations of a network model.
 # * ``classes`` is the label set of the data.
-# * ``init_weight``, ``word_dict`` and ``embed_vecs`` are not used on a bert-base model, so we can ignore them.
 # * ``moniter_metrics`` includes metrics you would like to track.
 #
 #
diff --git a/tests/nn/components.py b/tests/nn/components.py
index bcfbcd68..c747b0cc 100644
--- a/tests/nn/components.py
+++ b/tests/nn/components.py
@@ -20,7 +20,7 @@ def get_name(self):
         return "token_to_id"
 
     def get_from_trainer(self, trainer):
-        return trainer.model.word_dict
+        return trainer.word_dict
 
     def compare(self, a, b):
         return a == b
@@ -34,7 +34,7 @@ def get_name(self):
         return "embed_vecs"
 
     def get_from_trainer(self, trainer):
-        return trainer.model.embed_vecs
+        return trainer.embed_vecs
 
     def compare(self, a, b):
         return (a == b).all()

From 33e407f39748453a67a9fbe221b5b4b0adb1558b Mon Sep 17 00:00:00 2001
From: Eleven Liu
Date: Sat, 16 Aug 2025 20:33:27 +0800
Subject: [PATCH 5/5] Apply black formatter.

---
 libmultilabel/nn/attentionxml.py | 10 ++++++----
 libmultilabel/nn/model.py        |  9 +--------
 torch_trainer.py                 | 11 +++++++----
 3 files changed, 14 insertions(+), 16 deletions(-)

diff --git a/libmultilabel/nn/attentionxml.py b/libmultilabel/nn/attentionxml.py
index 1fe76c99..16c02b7c 100644
--- a/libmultilabel/nn/attentionxml.py
+++ b/libmultilabel/nn/attentionxml.py
@@ -428,7 +428,7 @@ def test(self, dataset):
             save_k_predictions=self.save_k_predictions,
             metrics=self.metrics,
         )
-
+
         word_dict_path = os.path.join(os.path.dirname(self.get_best_model_path(level=1)), self.WORD_DICT_NAME)
         if os.path.exists(word_dict_path):
             with open(word_dict_path, "rb") as f:
                 self.word_dict = pickle.load(f)
@@ -494,9 +494,11 @@ def reformat_text(self, dataset):
         # Convert words to numbers according to their indices in word_dict. Then pad each instance to a certain length.
         encoded_text = list(
             map(
-                lambda text: torch.tensor([self.word_dict.get(word, self.word_dict[UNK]) for word in text], dtype=torch.int64)
-                if text
-                else torch.tensor([self.word_dict[UNK]], dtype=torch.int64),
+                lambda text: (
+                    torch.tensor([self.word_dict.get(word, self.word_dict[UNK]) for word in text], dtype=torch.int64)
+                    if text
+                    else torch.tensor([self.word_dict[UNK]], dtype=torch.int64)
+                ),
                 [instance["text"][: self.max_seq_length] for instance in dataset],
             )
         )
diff --git a/libmultilabel/nn/model.py b/libmultilabel/nn/model.py
index 37fd7559..aa0853e6 100644
--- a/libmultilabel/nn/model.py
+++ b/libmultilabel/nn/model.py
@@ -187,14 +187,7 @@ class Model(MultiLabelModel):
         log_path (str): Path to a directory holding the log files and models.
     """
 
-    def __init__(
-        self,
-        classes,
-        network,
-        loss_function="binary_cross_entropy_with_logits",
-        log_path=None,
-        **kwargs
-    ):
+    def __init__(self, classes, network, loss_function="binary_cross_entropy_with_logits", log_path=None, **kwargs):
         super().__init__(num_classes=len(classes), log_path=log_path, **kwargs)
         self.save_hyperparameters(
             ignore=["log_path"]
diff --git a/torch_trainer.py b/torch_trainer.py
index bd0ee2f3..fba9f68c 100644
--- a/torch_trainer.py
+++ b/torch_trainer.py
@@ -25,6 +25,7 @@ class TorchTrainer:
         save_checkpoints (bool, optional): Whether to save the last and the best checkpoint or not. Defaults to True.
     """
+
     WORD_DICT_NAME = "word_dict.pickle"
 
     def __init__(
         self,
@@ -87,7 +88,7 @@ def __init__(
                         normalize_embed=config.normalize_embed,
                         embed_cache_dir=config.embed_cache_dir,
                     )
-                    with open(word_dict_path, "wb") as f:
+                    with open(word_dict_path, "wb") as f:
                         pickle.dump(self.word_dict, f)
 
                 if not self.classes:
@@ -108,9 +109,11 @@ def __init__(
                         f"Add {self.config.val_metric} to `monitor_metrics`."
                     )
                     self.config.monitor_metrics += [self.config.val_metric]
-            self.trainer = PLTTrainer(self.config, classes=self.classes, embed_vecs=self.embed_vecs, word_dict=self.word_dict)
+            self.trainer = PLTTrainer(
+                self.config, classes=self.classes, embed_vecs=self.embed_vecs, word_dict=self.word_dict
+            )
             return
-
+
         self._setup_model(log_path=self.log_path, checkpoint_path=config.checkpoint_path)
         self.trainer = init_trainer(
             checkpoint_dir=self.checkpoint_dir,
@@ -144,7 +147,7 @@ def _setup_model(
         if "checkpoint_path" in self.config and self.config.checkpoint_path is not None:
             checkpoint_path = self.config.checkpoint_path
-
+
         if checkpoint_path is not None:
             logging.info(f"Loading model from `{checkpoint_path}` with the previously saved hyper-parameter...")
             self.model = Model.load_from_checkpoint(checkpoint_path, log_path=log_path)
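
Note (not part of the patch series): the five patches above replace the `word_dict` that used to be stored on the Lightning model with a `word_dict.pickle` file cached beside the checkpoints, written at training time and reloaded from the checkpoint's directory at test time. A minimal standalone sketch of that save/load round trip follows; the directory, dictionary contents, and helper names are hypothetical, while the `word_dict.pickle` file name and the `os.path.dirname(checkpoint_path)` lookup mirror what the patches do.

import os
import pickle
from typing import Optional

WORD_DICT_NAME = "word_dict.pickle"  # same file name the patches introduce


def save_word_dict(word_dict: dict, checkpoint_dir: str) -> str:
    # Cache the token-to-index mapping next to the model checkpoints.
    os.makedirs(checkpoint_dir, exist_ok=True)
    path = os.path.join(checkpoint_dir, WORD_DICT_NAME)
    with open(path, "wb") as f:
        pickle.dump(word_dict, f)
    return path


def load_word_dict(checkpoint_path: str) -> Optional[dict]:
    # Reload the cached mapping from the directory that holds a checkpoint, if it exists.
    path = os.path.join(os.path.dirname(checkpoint_path), WORD_DICT_NAME)
    if not os.path.exists(path):
        return None
    with open(path, "rb") as f:
        return pickle.load(f)


if __name__ == "__main__":
    # Hypothetical usage: cache during training, reload when testing from a checkpoint.
    ckpt_dir = "runs/example"
    save_word_dict({"<unk>": 0, "hello": 1}, ckpt_dir)
    print(load_word_dict(os.path.join(ckpt_dir, "best_model.ckpt")))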