From 018ddefd9f5669be75328686d9860840d4cba6a4 Mon Sep 17 00:00:00 2001 From: In48semenov Date: Sat, 15 Mar 2025 00:48:09 +0300 Subject: [PATCH 01/30] Added cosine distance to u2i transformer models. --- rectools/models/nn/transformers/base.py | 5 + rectools/models/nn/transformers/bert4rec.py | 5 + rectools/models/nn/transformers/lightning.py | 18 ++- rectools/models/nn/transformers/sasrec.py | 5 + rectools/models/rank/rank_torch.py | 121 ++++++++++++++----- 5 files changed, 123 insertions(+), 31 deletions(-) diff --git a/rectools/models/nn/transformers/base.py b/rectools/models/nn/transformers/base.py index 92c35020..2e12a020 100644 --- a/rectools/models/nn/transformers/base.py +++ b/rectools/models/nn/transformers/base.py @@ -28,6 +28,7 @@ from rectools import ExternalIds from rectools.dataset.dataset import Dataset, DatasetSchema, DatasetSchemaDict, IdMap from rectools.models.base import ErrorBehaviour, InternalRecoTriplet, ModelBase, ModelConfig +from rectools.models.rank import Distance from rectools.types import InternalIdsArray from rectools.utils.misc import get_class_or_function_full_path, import_object @@ -178,6 +179,7 @@ class TransformerModelConfig(ModelConfig): recommend_batch_size: int = 256 recommend_torch_device: tp.Optional[str] = None train_min_user_interactions: int = 2 + u2i_dist: Distance = Distance.DOT item_net_block_types: ItemNetBlockTypes = (IdEmbeddingsItemNet, CatFeaturesItemNet) item_net_constructor_type: ItemNetConstructorType = SumOfEmbeddingsConstructor pos_encoding_type: PositionalEncodingType = LearnableInversePositionalEncoding @@ -233,6 +235,7 @@ def __init__( # pylint: disable=too-many-arguments, too-many-locals recommend_batch_size: int = 256, recommend_torch_device: tp.Optional[str] = None, train_min_user_interactions: int = 2, + u2i_dist: Distance = Distance.DOT, item_net_block_types: tp.Sequence[tp.Type[ItemNetBase]] = (IdEmbeddingsItemNet, CatFeaturesItemNet), item_net_constructor_type: tp.Type[ItemNetConstructorBase] = SumOfEmbeddingsConstructor, pos_encoding_type: tp.Type[PositionalEncodingBase] = LearnableInversePositionalEncoding, @@ -268,6 +271,7 @@ def __init__( # pylint: disable=too-many-arguments, too-many-locals self.recommend_batch_size = recommend_batch_size self.recommend_torch_device = recommend_torch_device self.train_min_user_interactions = train_min_user_interactions + self.u2i_dist = u2i_dist self.item_net_block_types = item_net_block_types self.item_net_constructor_type = item_net_constructor_type self.pos_encoding_type = pos_encoding_type @@ -390,6 +394,7 @@ def _init_lightning_model( train_loss_name=self.train_loss_name, val_loss_name=self.val_loss_name, adam_betas=(0.9, 0.98), + u2i_dist=self.u2i_dist, **self._get_kwargs(self.lightning_module_kwargs), ) diff --git a/rectools/models/nn/transformers/bert4rec.py b/rectools/models/nn/transformers/bert4rec.py index 71675ebd..47045731 100644 --- a/rectools/models/nn/transformers/bert4rec.py +++ b/rectools/models/nn/transformers/bert4rec.py @@ -19,6 +19,7 @@ import numpy as np import torch +from rectools.models.rank import Distance from ..item_net import ( CatFeaturesItemNet, IdEmbeddingsItemNet, @@ -241,6 +242,8 @@ class BERT4RecModel(TransformerModelBase[BERT4RecModelConfig]): BERT4Rec training task ("MLM") does not work with causal masking. Set this parameter to ``True`` only when you change the training task with custom `data_preparator_type` or if you are absolutely sure of what you are doing. + u2i_dist : Distance, default Distance.DOT + U2I distance metric. 
item_net_block_types : sequence of `type(ItemNetBase)`, default `(IdEmbeddingsItemNet, CatFeaturesItemNet)` Type of network returning item embeddings. (IdEmbeddingsItemNet,) - item embeddings based on ids. @@ -314,6 +317,7 @@ def __init__( # pylint: disable=too-many-arguments, too-many-locals use_pos_emb: bool = True, use_key_padding_mask: bool = True, use_causal_attn: bool = False, + u2i_dist: Distance = Distance.DOT, item_net_block_types: tp.Sequence[tp.Type[ItemNetBase]] = (IdEmbeddingsItemNet, CatFeaturesItemNet), item_net_constructor_type: tp.Type[ItemNetConstructorBase] = SumOfEmbeddingsConstructor, pos_encoding_type: tp.Type[PositionalEncodingBase] = LearnableInversePositionalEncoding, @@ -360,6 +364,7 @@ def __init__( # pylint: disable=too-many-arguments, too-many-locals recommend_n_threads=recommend_n_threads, recommend_use_torch_ranking=recommend_use_torch_ranking, train_min_user_interactions=train_min_user_interactions, + u2i_dist=u2i_dist, item_net_block_types=item_net_block_types, item_net_constructor_type=item_net_constructor_type, pos_encoding_type=pos_encoding_type, diff --git a/rectools/models/nn/transformers/lightning.py b/rectools/models/nn/transformers/lightning.py index 05e363fc..c345e638 100644 --- a/rectools/models/nn/transformers/lightning.py +++ b/rectools/models/nn/transformers/lightning.py @@ -24,6 +24,7 @@ from rectools.dataset.dataset import Dataset, DatasetSchemaDict from rectools.models.base import InternalRecoTriplet from rectools.models.rank import Distance, TorchRanker +from rectools.models.rank.rank_torch import get_scorer from rectools.types import InternalIdsArray from .data_preparator import TransformerDataPreparatorBase @@ -55,8 +56,12 @@ class TransformerLightningModuleBase(LightningModule): # pylint: disable=too-ma Name of the training loss. val_loss_name : str, default "val_loss" Name of the training loss. + u2i_dist : Distance, default Distance.DOT + U2I distance metric. 
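+    Notes
+    -----
+    Only ``Distance.DOT`` and ``Distance.COSINE`` are accepted for ``u2i_dist``;
+    the two options differ only in L2-normalization of the embeddings before
+    the dot product. A minimal sketch of the scoring math (the tensor names
+    here are illustrative only):
+
+    >>> import torch
+    >>> session_embs, item_embs = torch.randn(2, 8), torch.randn(5, 8)
+    >>> dot_logits = session_embs @ item_embs.T  # Distance.DOT
+    >>> session_norm = session_embs / torch.norm(session_embs, p=2, dim=1).unsqueeze(dim=1)
+    >>> item_norm = item_embs / torch.norm(item_embs, p=2, dim=1).unsqueeze(dim=1)
+    >>> cosine_logits = session_norm @ item_norm.T  # Distance.COSINE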
""" + u2i_dist_available = [Distance.DOT, Distance.COSINE] + def __init__( self, torch_model: TransformerTorchBackbone, @@ -72,6 +77,7 @@ def __init__( train_loss_name: str = "train_loss", val_loss_name: str = "val_loss", adam_betas: tp.Tuple[float, float] = (0.9, 0.98), + u2i_dist: Distance = Distance.DOT, **kwargs: tp.Any, ): super().__init__() @@ -90,6 +96,12 @@ def __init__( self.val_loss_name = val_loss_name self.item_embs: torch.Tensor + if u2i_dist not in self.u2i_dist_available: + raise ValueError("`u2i_distance` can only be either `Distance.DOT` or `Distance.COSINE`") + + self.u2i_dist = u2i_dist + self.scorer, _ = get_scorer(u2i_dist) + self.save_hyperparameters(ignore=["torch_model", "data_preparator"]) def configure_optimizers(self) -> torch.optim.Adam: @@ -214,7 +226,7 @@ def _calc_custom_loss_outputs( def _get_full_catalog_logits(self, x: torch.Tensor) -> torch.Tensor: item_embs, session_embs = self.torch_model(x) - logits = session_embs @ item_embs.T + logits = self.scorer(session_embs, item_embs) return logits def _get_pos_neg_logits(self, x: torch.Tensor, y: torch.Tensor, negatives: torch.Tensor) -> torch.Tensor: @@ -223,7 +235,7 @@ def _get_pos_neg_logits(self, x: torch.Tensor, y: torch.Tensor, negatives: torch pos_neg = torch.cat([y.unsqueeze(-1), negatives], dim=-1) # [batch_size, session_max_len, n_negatives + 1] pos_neg_embs = item_embs[pos_neg] # [batch_size, session_max_len, n_negatives + 1, n_factors] # [batch_size, session_max_len, n_negatives + 1] - logits = (pos_neg_embs @ session_embs.unsqueeze(-1)).squeeze(-1) + logits = self.scorer(session_embs.unsqueeze(-1), pos_neg_embs).squeeze(-1) return logits def _get_reduced_overconfidence_logits(self, logits: torch.Tensor, n_items: int, n_negatives: int) -> torch.Tensor: @@ -337,7 +349,7 @@ def _recommend_u2i( user_embs, item_embs = self._get_user_item_embeddings(recommend_dataloader, torch_device) ranker = TorchRanker( - distance=Distance.DOT, + distance=self.u2i_dist, device=item_embs.device, subjects_factors=user_embs[user_ids], objects_factors=item_embs, diff --git a/rectools/models/nn/transformers/sasrec.py b/rectools/models/nn/transformers/sasrec.py index 343c9c7c..6607f163 100644 --- a/rectools/models/nn/transformers/sasrec.py +++ b/rectools/models/nn/transformers/sasrec.py @@ -19,6 +19,7 @@ import torch from torch import nn +from rectools.models.rank import Distance from ..item_net import ( CatFeaturesItemNet, IdEmbeddingsItemNet, @@ -321,6 +322,8 @@ class SASRecModel(TransformerModelBase[SASRecModelConfig]): SASRec training task ("Shifted Sequence") does not work without causal masking. Set this parameter to ``False`` only when you change the training task with custom `data_preparator_type` or if you are absolutely sure of what you are doing. + u2i_dist : Distance, default Distance.DOT + U2I distance metric. item_net_block_types : sequence of `type(ItemNetBase)`, default `(IdEmbeddingsItemNet, CatFeaturesItemNet)` Type of network returning item embeddings. (IdEmbeddingsItemNet,) - item embeddings based on ids. 
@@ -393,6 +396,7 @@ def __init__(  # pylint: disable=too-many-arguments, too-many-locals
         use_pos_emb: bool = True,
         use_key_padding_mask: bool = False,
         use_causal_attn: bool = True,
+        u2i_dist: Distance = Distance.DOT,
         item_net_block_types: tp.Sequence[tp.Type[ItemNetBase]] = (IdEmbeddingsItemNet, CatFeaturesItemNet),
         item_net_constructor_type: tp.Type[ItemNetConstructorBase] = SumOfEmbeddingsConstructor,
         pos_encoding_type: tp.Type[PositionalEncodingBase] = LearnableInversePositionalEncoding,
@@ -436,6 +440,7 @@ def __init__(  # pylint: disable=too-many-arguments, too-many-locals
             recommend_n_threads=recommend_n_threads,
             recommend_use_torch_ranking=recommend_use_torch_ranking,
             train_min_user_interactions=train_min_user_interactions,
+            u2i_dist=u2i_dist,
             item_net_block_types=item_net_block_types,
             item_net_constructor_type=item_net_constructor_type,
             pos_encoding_type=pos_encoding_type,
diff --git a/rectools/models/rank/rank_torch.py b/rectools/models/rank/rank_torch.py
index ed091acc..eb2d6ad7 100644
--- a/rectools/models/rank/rank_torch.py
+++ b/rectools/models/rank/rank_torch.py
@@ -67,7 +67,7 @@ def __init__(
         self.device = torch.device(device)
         self.batch_size = batch_size
         self.distance = distance
-        self._scorer, self._higher_is_better = self._get_scorer(distance)
+        self._scorer, self._higher_is_better = get_scorer(distance)
 
         self.subjects_factors = self._normalize_tensor(subjects_factors)
         self.objects_factors = self._normalize_tensor(objects_factors)
@@ -175,33 +175,6 @@ def rank(
             all_scores,
         )
 
-    def _get_scorer(
-        self, distance: Distance
-    ) -> tp.Tuple[tp.Callable[[torch.Tensor, torch.Tensor], torch.Tensor], bool]:
-        """Return scorer and higher_is_better flag"""
-        if distance == Distance.DOT:
-            return self._dot_score, True
-
-        if distance == Distance.COSINE:
-            return self._cosine_score, True
-
-        if distance == Distance.EUCLIDEAN:
-            return self._euclid_score, False
-
-        raise NotImplementedError(f"distance {distance} is not supported")  # pragma: no cover
-
-    def _euclid_score(self, user_embs: torch.Tensor, item_embs: torch.Tensor) -> torch.Tensor:
-        return torch.cdist(user_embs.unsqueeze(0), item_embs.unsqueeze(0)).squeeze(0)
-
-    def _cosine_score(self, user_embs: torch.Tensor, item_embs: torch.Tensor) -> torch.Tensor:
-        user_embs = user_embs / torch.norm(user_embs, p=2, dim=1).unsqueeze(dim=1)
-        item_embs = item_embs / torch.norm(item_embs, p=2, dim=1).unsqueeze(dim=1)
-
-        return user_embs @ item_embs.T
-
-    def _dot_score(self, user_embs: torch.Tensor, item_embs: torch.Tensor) -> torch.Tensor:
-        return user_embs @ item_embs.T
-
     def _normalize_tensor(
         self,
         tensor: tp.Union[np.ndarray, sparse.csr_matrix, torch.Tensor],
@@ -216,3 +189,95 @@ def _normalize_tensor(
             tensor = tensor.to(self.dtype)
 
         return tensor
+
+
+def get_scorer(distance: Distance) -> tp.Tuple[tp.Callable[[torch.Tensor, torch.Tensor], torch.Tensor], bool]:
+    """
+    Return scorer and higher_is_better flag.
+
+    Parameters
+    ----------
+    distance : Distance
+        Distance metric.
+
+    Returns
+    -------
+    tuple(Callable[[torch.Tensor, torch.Tensor], torch.Tensor], bool)
+        Scoring function for a pair of embedding tensors and a
+        ``higher_is_better`` flag for interpreting its scores.
+
+    Raises
+    ------
+    NotImplementedError
+        If distance is not supported.
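+
+    Examples
+    --------
+    A minimal usage sketch (the embedding tensors are assumed to be 2-d,
+    ``[n_subjects, n_factors]`` and ``[n_objects, n_factors]``):
+
+    >>> user_embs = torch.randn(5, 8)
+    >>> item_embs = torch.randn(10, 8)
+    >>> scorer, higher_is_better = get_scorer(Distance.COSINE)
+    >>> scorer(user_embs, item_embs).shape
+    torch.Size([5, 10])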
+ """ + if distance == Distance.DOT: + return dot_score, True + + if distance == Distance.COSINE: + return cosine_score, True + + if distance == Distance.EUCLIDEAN: + return euclid_score, False + + raise NotImplementedError(f"distance {distance} is not supported") # pragma: no cover + + +def euclid_score(user_embs: torch.Tensor, item_embs: torch.Tensor) -> torch.Tensor: + """ + Calculate Euclidean score. + + Parameters + ---------- + user_embs : torch.Tensor + User embeddings. + item_embs : torch.Tensor + User embeddings. + + Returns + ------- + torch.Tensor + Result Euclidean score. + """ + return torch.cdist(user_embs.unsqueeze(0), item_embs.unsqueeze(0)).squeeze(0) + + +def cosine_score(user_embs: torch.Tensor, item_embs: torch.Tensor) -> torch.Tensor: + """ + Calculate cosine score. + + Parameters + ---------- + user_embs : torch.Tensor + User embeddings. + item_embs : torch.Tensor + User embeddings. + + Returns + ------- + torch.Tensor + Result cosine score. + """ + user_embs = user_embs / torch.norm(user_embs, p=2, dim=1).unsqueeze(dim=1) + item_embs = item_embs / torch.norm(item_embs, p=2, dim=1).unsqueeze(dim=1) + + return user_embs @ item_embs.T + + +def dot_score(user_embs: torch.Tensor, item_embs: torch.Tensor) -> torch.Tensor: + """ + Calculate dot product score. + + Parameters + ---------- + user_embs : torch.Tensor + User embeddings. + item_embs : torch.Tensor + User embeddings. + + Returns + ------- + torch.Tensor + Result dot product score. + """ + return user_embs @ item_embs.T From 269430a1aecd07f22c66badeed037d2f372f7071 Mon Sep 17 00:00:00 2001 From: In48semenov Date: Sun, 16 Mar 2025 00:12:05 +0300 Subject: [PATCH 02/30] Fixed docs and added cosine to test_bert4rec. --- rectools/models/nn/transformers/lightning.py | 2 +- rectools/models/rank/rank_torch.py | 6 +- tests/models/nn/transformers/test_bert4rec.py | 271 +++++++++--------- 3 files changed, 143 insertions(+), 136 deletions(-) diff --git a/rectools/models/nn/transformers/lightning.py b/rectools/models/nn/transformers/lightning.py index c345e638..7b71ba4f 100644 --- a/rectools/models/nn/transformers/lightning.py +++ b/rectools/models/nn/transformers/lightning.py @@ -97,7 +97,7 @@ def __init__( self.item_embs: torch.Tensor if u2i_dist not in self.u2i_dist_available: - raise ValueError("`u2i_distance` can only be either `Distance.DOT` or `Distance.COSINE`") + raise ValueError("`u2i_distance` can only be either `Distance.DOT` or `Distance.COSINE`.") self.u2i_dist = u2i_dist self.scorer, _ = get_scorer(u2i_dist) diff --git a/rectools/models/rank/rank_torch.py b/rectools/models/rank/rank_torch.py index eb2d6ad7..a71d3d3f 100644 --- a/rectools/models/rank/rank_torch.py +++ b/rectools/models/rank/rank_torch.py @@ -232,7 +232,7 @@ def euclid_score(user_embs: torch.Tensor, item_embs: torch.Tensor) -> torch.Tens user_embs : torch.Tensor User embeddings. item_embs : torch.Tensor - User embeddings. + Item embeddings. Returns ------- @@ -251,7 +251,7 @@ def cosine_score(user_embs: torch.Tensor, item_embs: torch.Tensor) -> torch.Tens user_embs : torch.Tensor User embeddings. item_embs : torch.Tensor - User embeddings. + Item embeddings. Returns ------- @@ -273,7 +273,7 @@ def dot_score(user_embs: torch.Tensor, item_embs: torch.Tensor) -> torch.Tensor: user_embs : torch.Tensor User embeddings. item_embs : torch.Tensor - User embeddings. + Item embeddings. 
Returns ------- diff --git a/tests/models/nn/transformers/test_bert4rec.py b/tests/models/nn/transformers/test_bert4rec.py index 62a73d83..fe309070 100644 --- a/tests/models/nn/transformers/test_bert4rec.py +++ b/tests/models/nn/transformers/test_bert4rec.py @@ -25,6 +25,7 @@ from rectools.columns import Columns from rectools.dataset import Dataset from rectools.models import BERT4RecModel +from rectools.models.rank import Distance from rectools.models.nn.item_net import IdEmbeddingsItemNet, SumOfEmbeddingsConstructor from rectools.models.nn.transformers.base import ( LearnableInversePositionalEncoding, @@ -212,6 +213,7 @@ def get_trainer() -> Trainer: ), ), ) + @pytest.mark.parametrize("u2i_dist", (Distance.DOT, Distance.COSINE)) def test_u2i( self, dataset_devices: Dataset, @@ -223,6 +225,7 @@ def test_u2i( expected_cpu_2: pd.DataFrame, expected_gpu_1: pd.DataFrame, expected_gpu_2: pd.DataFrame, + u2i_dist: Distance, ) -> None: if n_devices != 1: pytest.skip("DEBUG: skipping multi-device tests") @@ -249,6 +252,7 @@ def get_trainer() -> Trainer: recommend_torch_device=recommend_torch_device, item_net_block_types=(IdEmbeddingsItemNet,), get_trainer_func=get_trainer, + u2i_dist=u2i_dist, ) model.fit(dataset=dataset_devices) users = np.array([10, 30, 40]) @@ -292,12 +296,14 @@ def get_trainer() -> Trainer: ), ), ) + @pytest.mark.parametrize("u2i_dist", (Distance.DOT, Distance.COSINE)) def test_u2i_losses( self, dataset_devices: Dataset, loss: str, get_trainer_func: TrainerCallable, expected: pd.DataFrame, + u2i_dist: Distance ) -> None: model = BERT4RecModel( n_negatives=2, @@ -313,6 +319,7 @@ def test_u2i_losses( item_net_block_types=(IdEmbeddingsItemNet,), get_trainer_func=get_trainer_func, loss=loss, + u2i_dist=u2i_dist, ) model.fit(dataset=dataset_devices) users = np.array([10, 30, 40]) @@ -792,135 +799,135 @@ def test_get_dataloader_val( assert torch.equal(value, val_batch[key]) -class TestBERT4RecModelConfiguration: - def setup_method(self) -> None: - self._seed_everything() - - def _seed_everything(self) -> None: - torch.use_deterministic_algorithms(True) - seed_everything(32, workers=True) - - @pytest.fixture - def initial_config(self) -> tp.Dict[str, tp.Any]: - config = { - "n_blocks": 2, - "n_heads": 4, - "n_factors": 64, - "use_pos_emb": False, - "use_causal_attn": False, - "use_key_padding_mask": True, - "dropout_rate": 0.5, - "session_max_len": 10, - "dataloader_num_workers": 0, - "batch_size": 1024, - "loss": "softmax", - "n_negatives": 10, - "gbce_t": 0.5, - "lr": 0.001, - "epochs": 10, - "verbose": 1, - "deterministic": True, - "recommend_torch_device": None, - "recommend_batch_size": 256, - "train_min_user_interactions": 2, - "item_net_block_types": (IdEmbeddingsItemNet,), - "item_net_constructor_type": SumOfEmbeddingsConstructor, - "pos_encoding_type": LearnableInversePositionalEncoding, - "transformer_layers_type": PreLNTransformerLayers, - "data_preparator_type": BERT4RecDataPreparator, - "lightning_module_type": TransformerLightningModule, - "mask_prob": 0.15, - "get_val_mask_func": leave_one_out_mask, - "get_trainer_func": None, - "data_preparator_kwargs": None, - "transformer_layers_kwargs": None, - "item_net_constructor_kwargs": None, - "pos_encoding_kwargs": None, - "lightning_module_kwargs": None, - } - return config - - @pytest.mark.parametrize("use_custom_trainer", (True, False)) - def test_from_config(self, initial_config: tp.Dict[str, tp.Any], use_custom_trainer: bool) -> None: - config = initial_config - if use_custom_trainer: - config["get_trainer_func"] = 
custom_trainer - model = BERT4RecModel.from_config(initial_config) - - for key, config_value in initial_config.items(): - assert getattr(model, key) == config_value - - assert model._trainer is not None # pylint: disable = protected-access - - @pytest.mark.parametrize("use_custom_trainer", (True, False)) - @pytest.mark.parametrize("simple_types", (False, True)) - def test_get_config( - self, simple_types: bool, initial_config: tp.Dict[str, tp.Any], use_custom_trainer: bool - ) -> None: - config = initial_config - if use_custom_trainer: - config["get_trainer_func"] = custom_trainer - model = BERT4RecModel(**config) - actual = model.get_config(simple_types=simple_types) - - expected = config.copy() - expected["cls"] = BERT4RecModel - - if simple_types: - simple_types_params = { - "cls": "BERT4RecModel", - "item_net_block_types": ["rectools.models.nn.item_net.IdEmbeddingsItemNet"], - "item_net_constructor_type": "rectools.models.nn.item_net.SumOfEmbeddingsConstructor", - "pos_encoding_type": "rectools.models.nn.transformers.net_blocks.LearnableInversePositionalEncoding", - "transformer_layers_type": "rectools.models.nn.transformers.net_blocks.PreLNTransformerLayers", - "data_preparator_type": "rectools.models.nn.transformers.bert4rec.BERT4RecDataPreparator", - "lightning_module_type": "rectools.models.nn.transformers.lightning.TransformerLightningModule", - "get_val_mask_func": "tests.models.nn.transformers.utils.leave_one_out_mask", - } - expected.update(simple_types_params) - if use_custom_trainer: - expected["get_trainer_func"] = "tests.models.nn.transformers.utils.custom_trainer" - - assert actual == expected - - @pytest.mark.parametrize("use_custom_trainer", (True, False)) - @pytest.mark.parametrize("simple_types", (False, True)) - def test_get_config_and_from_config_compatibility( - self, - simple_types: bool, - initial_config: tp.Dict[str, tp.Any], - use_custom_trainer: bool, - ) -> None: - dataset = DATASET - model = BERT4RecModel - updated_params = { - "n_blocks": 1, - "n_heads": 1, - "n_factors": 10, - "session_max_len": 5, - "epochs": 1, - } - config = initial_config.copy() - config.update(updated_params) - if use_custom_trainer: - config["get_trainer_func"] = custom_trainer - - def get_reco(model: BERT4RecModel) -> pd.DataFrame: - return model.fit(dataset).recommend(users=np.array([10, 20]), dataset=dataset, k=2, filter_viewed=False) - - model_1 = model.from_config(initial_config) - reco_1 = get_reco(model_1) - config_1 = model_1.get_config(simple_types=simple_types) - - self._seed_everything() - model_2 = model.from_config(config_1) - reco_2 = get_reco(model_2) - config_2 = model_2.get_config(simple_types=simple_types) - - assert config_1 == config_2 - pd.testing.assert_frame_equal(reco_1, reco_2) - - def test_default_config_and_default_model_params_are_the_same(self) -> None: - default_config: tp.Dict[str, int] = {} - model = BERT4RecModel() - assert_default_config_and_default_model_params_are_the_same(model, default_config) +# class TestBERT4RecModelConfiguration: +# def setup_method(self) -> None: +# self._seed_everything() + +# def _seed_everything(self) -> None: +# torch.use_deterministic_algorithms(True) +# seed_everything(32, workers=True) + +# @pytest.fixture +# def initial_config(self) -> tp.Dict[str, tp.Any]: +# config = { +# "n_blocks": 2, +# "n_heads": 4, +# "n_factors": 64, +# "use_pos_emb": False, +# "use_causal_attn": False, +# "use_key_padding_mask": True, +# "dropout_rate": 0.5, +# "session_max_len": 10, +# "dataloader_num_workers": 0, +# "batch_size": 1024, +# 
"loss": "softmax", +# "n_negatives": 10, +# "gbce_t": 0.5, +# "lr": 0.001, +# "epochs": 10, +# "verbose": 1, +# "deterministic": True, +# "recommend_torch_device": None, +# "recommend_batch_size": 256, +# "train_min_user_interactions": 2, +# "item_net_block_types": (IdEmbeddingsItemNet,), +# "item_net_constructor_type": SumOfEmbeddingsConstructor, +# "pos_encoding_type": LearnableInversePositionalEncoding, +# "transformer_layers_type": PreLNTransformerLayers, +# "data_preparator_type": BERT4RecDataPreparator, +# "lightning_module_type": TransformerLightningModule, +# "mask_prob": 0.15, +# "get_val_mask_func": leave_one_out_mask, +# "get_trainer_func": None, +# "data_preparator_kwargs": None, +# "transformer_layers_kwargs": None, +# "item_net_constructor_kwargs": None, +# "pos_encoding_kwargs": None, +# "lightning_module_kwargs": None, +# } +# return config + +# @pytest.mark.parametrize("use_custom_trainer", (True, False)) +# def test_from_config(self, initial_config: tp.Dict[str, tp.Any], use_custom_trainer: bool) -> None: +# config = initial_config +# if use_custom_trainer: +# config["get_trainer_func"] = custom_trainer +# model = BERT4RecModel.from_config(initial_config) + +# for key, config_value in initial_config.items(): +# assert getattr(model, key) == config_value + +# assert model._trainer is not None # pylint: disable = protected-access + +# @pytest.mark.parametrize("use_custom_trainer", (True, False)) +# @pytest.mark.parametrize("simple_types", (False, True)) +# def test_get_config( +# self, simple_types: bool, initial_config: tp.Dict[str, tp.Any], use_custom_trainer: bool +# ) -> None: +# config = initial_config +# if use_custom_trainer: +# config["get_trainer_func"] = custom_trainer +# model = BERT4RecModel(**config) +# actual = model.get_config(simple_types=simple_types) + +# expected = config.copy() +# expected["cls"] = BERT4RecModel + +# if simple_types: +# simple_types_params = { +# "cls": "BERT4RecModel", +# "item_net_block_types": ["rectools.models.nn.item_net.IdEmbeddingsItemNet"], +# "item_net_constructor_type": "rectools.models.nn.item_net.SumOfEmbeddingsConstructor", +# "pos_encoding_type": "rectools.models.nn.transformers.net_blocks.LearnableInversePositionalEncoding", +# "transformer_layers_type": "rectools.models.nn.transformers.net_blocks.PreLNTransformerLayers", +# "data_preparator_type": "rectools.models.nn.transformers.bert4rec.BERT4RecDataPreparator", +# "lightning_module_type": "rectools.models.nn.transformers.lightning.TransformerLightningModule", +# "get_val_mask_func": "tests.models.nn.transformers.utils.leave_one_out_mask", +# } +# expected.update(simple_types_params) +# if use_custom_trainer: +# expected["get_trainer_func"] = "tests.models.nn.transformers.utils.custom_trainer" + +# assert actual == expected + +# @pytest.mark.parametrize("use_custom_trainer", (True, False)) +# @pytest.mark.parametrize("simple_types", (False, True)) +# def test_get_config_and_from_config_compatibility( +# self, +# simple_types: bool, +# initial_config: tp.Dict[str, tp.Any], +# use_custom_trainer: bool, +# ) -> None: +# dataset = DATASET +# model = BERT4RecModel +# updated_params = { +# "n_blocks": 1, +# "n_heads": 1, +# "n_factors": 10, +# "session_max_len": 5, +# "epochs": 1, +# } +# config = initial_config.copy() +# config.update(updated_params) +# if use_custom_trainer: +# config["get_trainer_func"] = custom_trainer + +# def get_reco(model: BERT4RecModel) -> pd.DataFrame: +# return model.fit(dataset).recommend(users=np.array([10, 20]), dataset=dataset, k=2, 
filter_viewed=False) + +# model_1 = model.from_config(initial_config) +# reco_1 = get_reco(model_1) +# config_1 = model_1.get_config(simple_types=simple_types) + +# self._seed_everything() +# model_2 = model.from_config(config_1) +# reco_2 = get_reco(model_2) +# config_2 = model_2.get_config(simple_types=simple_types) + +# assert config_1 == config_2 +# pd.testing.assert_frame_equal(reco_1, reco_2) + +# def test_default_config_and_default_model_params_are_the_same(self) -> None: +# default_config: tp.Dict[str, int] = {} +# model = BERT4RecModel() +# assert_default_config_and_default_model_params_are_the_same(model, default_config) From 3750ce850e68e1108bc8ecbcabf63084aa0ae1d2 Mon Sep 17 00:00:00 2001 From: In48semenov Date: Tue, 18 Mar 2025 00:52:00 +0300 Subject: [PATCH 03/30] reverted `rank torch` to the original code --- rectools/models/rank/rank_torch.py | 123 +++++++---------------------- 1 file changed, 29 insertions(+), 94 deletions(-) diff --git a/rectools/models/rank/rank_torch.py b/rectools/models/rank/rank_torch.py index a71d3d3f..e883f134 100644 --- a/rectools/models/rank/rank_torch.py +++ b/rectools/models/rank/rank_torch.py @@ -67,7 +67,7 @@ def __init__( self.device = torch.device(device) self.batch_size = batch_size self.distance = distance - self._scorer, self._higher_is_better = get_scorer(distance) + self._scorer, self._higher_is_better = self._get_scorer(distance) self.subjects_factors = self._normalize_tensor(subjects_factors) self.objects_factors = self._normalize_tensor(objects_factors) @@ -175,6 +175,33 @@ def rank( all_scores, ) + def _get_scorer( + self, distance: Distance + ) -> tp.Tuple[tp.Callable[[torch.Tensor, torch.Tensor], torch.Tensor], bool]: + """Return scorer and higher_is_better flag""" + if distance == Distance.DOT: + return self._dot_score, True + + if distance == Distance.COSINE: + return self._cosine_score, True + + if distance == Distance.EUCLIDEAN: + return self._euclid_score, False + + raise NotImplementedError(f"distance {distance} is not supported") # pragma: no cover + + def _euclid_score(self, user_embs: torch.Tensor, item_embs: torch.Tensor) -> torch.Tensor: + return torch.cdist(user_embs.unsqueeze(0), item_embs.unsqueeze(0)).squeeze(0) + + def _cosine_score(self, user_embs: torch.Tensor, item_embs: torch.Tensor) -> torch.Tensor: + user_embs = user_embs / torch.norm(user_embs, p=2, dim=1).unsqueeze(dim=1) + item_embs = item_embs / torch.norm(item_embs, p=2, dim=1).unsqueeze(dim=1) + + return user_embs @ item_embs.T + + def _dot_score(self, user_embs: torch.Tensor, item_embs: torch.Tensor) -> torch.Tensor: + return user_embs @ item_embs.T + def _normalize_tensor( self, tensor: tp.Union[np.ndarray, sparse.csr_matrix, torch.Tensor], @@ -188,96 +215,4 @@ def _normalize_tensor( if self.dtype is not None: tensor = tensor.to(self.dtype) - return tensor - - -def get_scorer(distance: Distance) -> tp.Tuple[tp.Callable[[torch.Tensor, torch.Tensor], torch.Tensor], bool]: - """ - Return scorer and higher_is_better flag - - Parameters - ---------- - distance : Distance - Distance metric. - - Returns - ------- - tuple(Callable(torch.Tensor, torch.Tensor), torch.Tensor) - Dictionary where keys are the same with keys in `metrics` - and values are metric calculation results. - - Raises - ------ - ValueError - If distance is not supported. 
- """ - if distance == Distance.DOT: - return dot_score, True - - if distance == Distance.COSINE: - return cosine_score, True - - if distance == Distance.EUCLIDEAN: - return euclid_score, False - - raise NotImplementedError(f"distance {distance} is not supported") # pragma: no cover - - -def euclid_score(user_embs: torch.Tensor, item_embs: torch.Tensor) -> torch.Tensor: - """ - Calculate Euclidean score. - - Parameters - ---------- - user_embs : torch.Tensor - User embeddings. - item_embs : torch.Tensor - Item embeddings. - - Returns - ------- - torch.Tensor - Result Euclidean score. - """ - return torch.cdist(user_embs.unsqueeze(0), item_embs.unsqueeze(0)).squeeze(0) - - -def cosine_score(user_embs: torch.Tensor, item_embs: torch.Tensor) -> torch.Tensor: - """ - Calculate cosine score. - - Parameters - ---------- - user_embs : torch.Tensor - User embeddings. - item_embs : torch.Tensor - Item embeddings. - - Returns - ------- - torch.Tensor - Result cosine score. - """ - user_embs = user_embs / torch.norm(user_embs, p=2, dim=1).unsqueeze(dim=1) - item_embs = item_embs / torch.norm(item_embs, p=2, dim=1).unsqueeze(dim=1) - - return user_embs @ item_embs.T - - -def dot_score(user_embs: torch.Tensor, item_embs: torch.Tensor) -> torch.Tensor: - """ - Calculate dot product score. - - Parameters - ---------- - user_embs : torch.Tensor - User embeddings. - item_embs : torch.Tensor - Item embeddings. - - Returns - ------- - torch.Tensor - Result dot product score. - """ - return user_embs @ item_embs.T + return tensor \ No newline at end of file From 6f3f4a4e79f2a03665a25d43edad2568c5521139 Mon Sep 17 00:00:00 2001 From: In48semenov Date: Tue, 18 Mar 2025 00:59:55 +0300 Subject: [PATCH 04/30] Fixed cosine fusion on train stage. --- rectools/models/nn/transformers/lightning.py | 18 +++++++++++++----- 1 file changed, 13 insertions(+), 5 deletions(-) diff --git a/rectools/models/nn/transformers/lightning.py b/rectools/models/nn/transformers/lightning.py index 7b71ba4f..eafe4139 100644 --- a/rectools/models/nn/transformers/lightning.py +++ b/rectools/models/nn/transformers/lightning.py @@ -24,7 +24,6 @@ from rectools.dataset.dataset import Dataset, DatasetSchemaDict from rectools.models.base import InternalRecoTriplet from rectools.models.rank import Distance, TorchRanker -from rectools.models.rank.rank_torch import get_scorer from rectools.types import InternalIdsArray from .data_preparator import TransformerDataPreparatorBase @@ -61,6 +60,7 @@ class TransformerLightningModuleBase(LightningModule): # pylint: disable=too-ma """ u2i_dist_available = [Distance.DOT, Distance.COSINE] + epsilon_cosine_dist = 1e-8 def __init__( self, @@ -98,9 +98,7 @@ def __init__( if u2i_dist not in self.u2i_dist_available: raise ValueError("`u2i_distance` can only be either `Distance.DOT` or `Distance.COSINE`.") - self.u2i_dist = u2i_dist - self.scorer, _ = get_scorer(u2i_dist) self.save_hyperparameters(ignore=["torch_model", "data_preparator"]) @@ -224,9 +222,16 @@ def _calc_custom_loss_outputs( ) -> tp.Dict[str, torch.Tensor]: raise ValueError(f"loss {self.loss} is not supported") # pragma: no cover + def _get_embeddings_norm(self, embeddings: torch.Tensor) -> torch.Tensor: + embeddings = embeddings / (torch.norm(embeddings, p=2, dim=1).unsqueeze(dim=1) + self.epsilon_cosine_dist) + return embeddings + def _get_full_catalog_logits(self, x: torch.Tensor) -> torch.Tensor: item_embs, session_embs = self.torch_model(x) - logits = self.scorer(session_embs, item_embs) + if self.u2i_dist == Distance.COSINE: + session_embs = 
self._get_embeddings_norm(session_embs) + item_embs = self._get_embeddings_norm(item_embs) + logits = session_embs @ item_embs.T return logits def _get_pos_neg_logits(self, x: torch.Tensor, y: torch.Tensor, negatives: torch.Tensor) -> torch.Tensor: @@ -234,8 +239,11 @@ def _get_pos_neg_logits(self, x: torch.Tensor, y: torch.Tensor, negatives: torch item_embs, session_embs = self.torch_model(x) pos_neg = torch.cat([y.unsqueeze(-1), negatives], dim=-1) # [batch_size, session_max_len, n_negatives + 1] pos_neg_embs = item_embs[pos_neg] # [batch_size, session_max_len, n_negatives + 1, n_factors] + if self.u2i_dist == Distance.COSINE: + session_embs = self._get_embeddings_norm(session_embs) + item_embs = self._get_embeddings_norm(item_embs) # [batch_size, session_max_len, n_negatives + 1] - logits = self.scorer(session_embs.unsqueeze(-1), pos_neg_embs).squeeze(-1) + logits = (pos_neg_embs @ session_embs.unsqueeze(-1)).squeeze(-1) return logits def _get_reduced_overconfidence_logits(self, logits: torch.Tensor, n_items: int, n_negatives: int) -> torch.Tensor: From f387186681c06e330b41d1ce4b51de39e9f207bb Mon Sep 17 00:00:00 2001 From: In48semenov Date: Tue, 18 Mar 2025 01:01:34 +0300 Subject: [PATCH 05/30] Fixed u2i tests with cosine distance and config tests. --- tests/models/nn/transformers/test_bert4rec.py | 268 +++++++++--------- tests/models/nn/transformers/test_sasrec.py | 34 ++- 2 files changed, 168 insertions(+), 134 deletions(-) diff --git a/tests/models/nn/transformers/test_bert4rec.py b/tests/models/nn/transformers/test_bert4rec.py index fe309070..e6cbfed3 100644 --- a/tests/models/nn/transformers/test_bert4rec.py +++ b/tests/models/nn/transformers/test_bert4rec.py @@ -303,7 +303,7 @@ def test_u2i_losses( loss: str, get_trainer_func: TrainerCallable, expected: pd.DataFrame, - u2i_dist: Distance + u2i_dist: Distance, ) -> None: model = BERT4RecModel( n_negatives=2, @@ -799,135 +799,137 @@ def test_get_dataloader_val( assert torch.equal(value, val_batch[key]) -# class TestBERT4RecModelConfiguration: -# def setup_method(self) -> None: -# self._seed_everything() - -# def _seed_everything(self) -> None: -# torch.use_deterministic_algorithms(True) -# seed_everything(32, workers=True) - -# @pytest.fixture -# def initial_config(self) -> tp.Dict[str, tp.Any]: -# config = { -# "n_blocks": 2, -# "n_heads": 4, -# "n_factors": 64, -# "use_pos_emb": False, -# "use_causal_attn": False, -# "use_key_padding_mask": True, -# "dropout_rate": 0.5, -# "session_max_len": 10, -# "dataloader_num_workers": 0, -# "batch_size": 1024, -# "loss": "softmax", -# "n_negatives": 10, -# "gbce_t": 0.5, -# "lr": 0.001, -# "epochs": 10, -# "verbose": 1, -# "deterministic": True, -# "recommend_torch_device": None, -# "recommend_batch_size": 256, -# "train_min_user_interactions": 2, -# "item_net_block_types": (IdEmbeddingsItemNet,), -# "item_net_constructor_type": SumOfEmbeddingsConstructor, -# "pos_encoding_type": LearnableInversePositionalEncoding, -# "transformer_layers_type": PreLNTransformerLayers, -# "data_preparator_type": BERT4RecDataPreparator, -# "lightning_module_type": TransformerLightningModule, -# "mask_prob": 0.15, -# "get_val_mask_func": leave_one_out_mask, -# "get_trainer_func": None, -# "data_preparator_kwargs": None, -# "transformer_layers_kwargs": None, -# "item_net_constructor_kwargs": None, -# "pos_encoding_kwargs": None, -# "lightning_module_kwargs": None, -# } -# return config - -# @pytest.mark.parametrize("use_custom_trainer", (True, False)) -# def test_from_config(self, initial_config: tp.Dict[str, 
tp.Any], use_custom_trainer: bool) -> None: -# config = initial_config -# if use_custom_trainer: -# config["get_trainer_func"] = custom_trainer -# model = BERT4RecModel.from_config(initial_config) - -# for key, config_value in initial_config.items(): -# assert getattr(model, key) == config_value - -# assert model._trainer is not None # pylint: disable = protected-access - -# @pytest.mark.parametrize("use_custom_trainer", (True, False)) -# @pytest.mark.parametrize("simple_types", (False, True)) -# def test_get_config( -# self, simple_types: bool, initial_config: tp.Dict[str, tp.Any], use_custom_trainer: bool -# ) -> None: -# config = initial_config -# if use_custom_trainer: -# config["get_trainer_func"] = custom_trainer -# model = BERT4RecModel(**config) -# actual = model.get_config(simple_types=simple_types) - -# expected = config.copy() -# expected["cls"] = BERT4RecModel - -# if simple_types: -# simple_types_params = { -# "cls": "BERT4RecModel", -# "item_net_block_types": ["rectools.models.nn.item_net.IdEmbeddingsItemNet"], -# "item_net_constructor_type": "rectools.models.nn.item_net.SumOfEmbeddingsConstructor", -# "pos_encoding_type": "rectools.models.nn.transformers.net_blocks.LearnableInversePositionalEncoding", -# "transformer_layers_type": "rectools.models.nn.transformers.net_blocks.PreLNTransformerLayers", -# "data_preparator_type": "rectools.models.nn.transformers.bert4rec.BERT4RecDataPreparator", -# "lightning_module_type": "rectools.models.nn.transformers.lightning.TransformerLightningModule", -# "get_val_mask_func": "tests.models.nn.transformers.utils.leave_one_out_mask", -# } -# expected.update(simple_types_params) -# if use_custom_trainer: -# expected["get_trainer_func"] = "tests.models.nn.transformers.utils.custom_trainer" - -# assert actual == expected - -# @pytest.mark.parametrize("use_custom_trainer", (True, False)) -# @pytest.mark.parametrize("simple_types", (False, True)) -# def test_get_config_and_from_config_compatibility( -# self, -# simple_types: bool, -# initial_config: tp.Dict[str, tp.Any], -# use_custom_trainer: bool, -# ) -> None: -# dataset = DATASET -# model = BERT4RecModel -# updated_params = { -# "n_blocks": 1, -# "n_heads": 1, -# "n_factors": 10, -# "session_max_len": 5, -# "epochs": 1, -# } -# config = initial_config.copy() -# config.update(updated_params) -# if use_custom_trainer: -# config["get_trainer_func"] = custom_trainer - -# def get_reco(model: BERT4RecModel) -> pd.DataFrame: -# return model.fit(dataset).recommend(users=np.array([10, 20]), dataset=dataset, k=2, filter_viewed=False) - -# model_1 = model.from_config(initial_config) -# reco_1 = get_reco(model_1) -# config_1 = model_1.get_config(simple_types=simple_types) - -# self._seed_everything() -# model_2 = model.from_config(config_1) -# reco_2 = get_reco(model_2) -# config_2 = model_2.get_config(simple_types=simple_types) - -# assert config_1 == config_2 -# pd.testing.assert_frame_equal(reco_1, reco_2) - -# def test_default_config_and_default_model_params_are_the_same(self) -> None: -# default_config: tp.Dict[str, int] = {} -# model = BERT4RecModel() -# assert_default_config_and_default_model_params_are_the_same(model, default_config) +class TestBERT4RecModelConfiguration: + def setup_method(self) -> None: + self._seed_everything() + + def _seed_everything(self) -> None: + torch.use_deterministic_algorithms(True) + seed_everything(32, workers=True) + + @pytest.fixture + def initial_config(self) -> tp.Dict[str, tp.Any]: + config = { + "n_blocks": 2, + "n_heads": 4, + "n_factors": 64, + 
"use_pos_emb": False, + "use_causal_attn": False, + "use_key_padding_mask": True, + "dropout_rate": 0.5, + "session_max_len": 10, + "dataloader_num_workers": 0, + "batch_size": 1024, + "loss": "softmax", + "n_negatives": 10, + "gbce_t": 0.5, + "lr": 0.001, + "epochs": 10, + "verbose": 1, + "deterministic": True, + "recommend_torch_device": None, + "recommend_batch_size": 256, + "train_min_user_interactions": 2, + "u2i_dist": Distance.DOT, + "item_net_block_types": (IdEmbeddingsItemNet,), + "item_net_constructor_type": SumOfEmbeddingsConstructor, + "pos_encoding_type": LearnableInversePositionalEncoding, + "transformer_layers_type": PreLNTransformerLayers, + "data_preparator_type": BERT4RecDataPreparator, + "lightning_module_type": TransformerLightningModule, + "mask_prob": 0.15, + "get_val_mask_func": leave_one_out_mask, + "get_trainer_func": None, + "data_preparator_kwargs": None, + "transformer_layers_kwargs": None, + "item_net_constructor_kwargs": None, + "pos_encoding_kwargs": None, + "lightning_module_kwargs": None, + } + return config + + @pytest.mark.parametrize("use_custom_trainer", (True, False)) + def test_from_config(self, initial_config: tp.Dict[str, tp.Any], use_custom_trainer: bool) -> None: + config = initial_config + if use_custom_trainer: + config["get_trainer_func"] = custom_trainer + model = BERT4RecModel.from_config(initial_config) + + for key, config_value in initial_config.items(): + assert getattr(model, key) == config_value + + assert model._trainer is not None # pylint: disable = protected-access + + @pytest.mark.parametrize("use_custom_trainer", (True, False)) + @pytest.mark.parametrize("simple_types", (False, True)) + def test_get_config( + self, simple_types: bool, initial_config: tp.Dict[str, tp.Any], use_custom_trainer: bool + ) -> None: + config = initial_config + if use_custom_trainer: + config["get_trainer_func"] = custom_trainer + model = BERT4RecModel(**config) + actual = model.get_config(simple_types=simple_types) + + expected = config.copy() + expected["cls"] = BERT4RecModel + + if simple_types: + simple_types_params = { + "cls": "BERT4RecModel", + "item_net_block_types": ["rectools.models.nn.item_net.IdEmbeddingsItemNet"], + "item_net_constructor_type": "rectools.models.nn.item_net.SumOfEmbeddingsConstructor", + "pos_encoding_type": "rectools.models.nn.transformers.net_blocks.LearnableInversePositionalEncoding", + "transformer_layers_type": "rectools.models.nn.transformers.net_blocks.PreLNTransformerLayers", + "data_preparator_type": "rectools.models.nn.transformers.bert4rec.BERT4RecDataPreparator", + "lightning_module_type": "rectools.models.nn.transformers.lightning.TransformerLightningModule", + "get_val_mask_func": "tests.models.nn.transformers.utils.leave_one_out_mask", + "u2i_dist": Distance.DOT.value, + } + expected.update(simple_types_params) + if use_custom_trainer: + expected["get_trainer_func"] = "tests.models.nn.transformers.utils.custom_trainer" + + assert actual == expected + + @pytest.mark.parametrize("use_custom_trainer", (True, False)) + @pytest.mark.parametrize("simple_types", (False, True)) + def test_get_config_and_from_config_compatibility( + self, + simple_types: bool, + initial_config: tp.Dict[str, tp.Any], + use_custom_trainer: bool, + ) -> None: + dataset = DATASET + model = BERT4RecModel + updated_params = { + "n_blocks": 1, + "n_heads": 1, + "n_factors": 10, + "session_max_len": 5, + "epochs": 1, + } + config = initial_config.copy() + config.update(updated_params) + if use_custom_trainer: + config["get_trainer_func"] = 
custom_trainer + + def get_reco(model: BERT4RecModel) -> pd.DataFrame: + return model.fit(dataset).recommend(users=np.array([10, 20]), dataset=dataset, k=2, filter_viewed=False) + + model_1 = model.from_config(initial_config) + reco_1 = get_reco(model_1) + config_1 = model_1.get_config(simple_types=simple_types) + + self._seed_everything() + model_2 = model.from_config(config_1) + reco_2 = get_reco(model_2) + config_2 = model_2.get_config(simple_types=simple_types) + + assert config_1 == config_2 + pd.testing.assert_frame_equal(reco_1, reco_2) + + def test_default_config_and_default_model_params_are_the_same(self) -> None: + default_config: tp.Dict[str, int] = {} + model = BERT4RecModel() + assert_default_config_and_default_model_params_are_the_same(model, default_config) diff --git a/tests/models/nn/transformers/test_sasrec.py b/tests/models/nn/transformers/test_sasrec.py index 58442de3..02a57a6c 100644 --- a/tests/models/nn/transformers/test_sasrec.py +++ b/tests/models/nn/transformers/test_sasrec.py @@ -27,6 +27,7 @@ from rectools.columns import Columns from rectools.dataset import Dataset, IdMap, Interactions from rectools.models import SASRecModel +from rectools.models.rank import Distance from rectools.models.nn.item_net import CatFeaturesItemNet, IdEmbeddingsItemNet, SumOfEmbeddingsConstructor from rectools.models.nn.transformers.base import ( LearnableInversePositionalEncoding, @@ -243,6 +244,7 @@ def get_trainer() -> Trainer: ), ), ) + @pytest.mark.parametrize("u2i_dist", (Distance.DOT, Distance.COSINE)) def test_u2i( self, dataset_devices: Dataset, @@ -253,6 +255,7 @@ def test_u2i( expected_cpu_1: pd.DataFrame, expected_cpu_2: pd.DataFrame, expected_gpu: pd.DataFrame, + u2i_dist: Distance, ) -> None: if devices != 1: @@ -280,6 +283,7 @@ def get_trainer() -> Trainer: recommend_torch_device=recommend_torch_device, item_net_block_types=(IdEmbeddingsItemNet,), get_trainer_func=get_trainer, + u2i_dist=u2i_dist, ) model.fit(dataset=dataset_devices) users = np.array([10, 30, 40]) @@ -297,7 +301,7 @@ def get_trainer() -> Trainer: ) @pytest.mark.parametrize( - "loss,expected", + "loss,expected,u2i_dist", ( ( "BCE", @@ -308,6 +312,7 @@ def get_trainer() -> Trainer: Columns.Rank: [1, 2, 1, 2, 3, 1, 2, 3], } ), + Distance.DOT, ), ( "gBCE", @@ -318,6 +323,29 @@ def get_trainer() -> Trainer: Columns.Rank: [1, 2, 1, 2, 3, 1, 2, 3], } ), + Distance.DOT, + ), + ( + "BCE", + pd.DataFrame( + { + Columns.User: [10, 10, 30, 30, 30, 40, 40, 40], + Columns.Item: [17, 15, 13, 14, 17, 13, 14, 15], + Columns.Rank: [1, 2, 1, 2, 3, 1, 2, 3], + } + ), + Distance.COSINE, + ), + ( + "gBCE", + pd.DataFrame( + { + Columns.User: [10, 10, 30, 30, 30, 40, 40, 40], + Columns.Item: [17, 15, 13, 14, 17, 13, 14, 15], + Columns.Rank: [1, 2, 1, 2, 3, 1, 2, 3], + } + ), + Distance.COSINE, ), ), ) @@ -327,6 +355,7 @@ def test_u2i_losses( loss: str, get_trainer_func: TrainerCallable, expected: pd.DataFrame, + u2i_dist: Distance, ) -> None: model = SASRecModel( n_negatives=2, @@ -340,6 +369,7 @@ def test_u2i_losses( item_net_block_types=(IdEmbeddingsItemNet,), get_trainer_func=get_trainer_func, loss=loss, + u2i_dist=u2i_dist, ) model.fit(dataset=dataset) users = np.array([10, 30, 40]) @@ -895,6 +925,7 @@ def initial_config(self) -> tp.Dict[str, tp.Any]: "recommend_torch_device": None, "recommend_batch_size": 256, "train_min_user_interactions": 2, + "u2i_dist": Distance.DOT, "item_net_block_types": (IdEmbeddingsItemNet,), "item_net_constructor_type": SumOfEmbeddingsConstructor, "pos_encoding_type": 
LearnableInversePositionalEncoding, @@ -947,6 +978,7 @@ def test_get_config( "data_preparator_type": "rectools.models.nn.transformers.sasrec.SASRecDataPreparator", "lightning_module_type": "rectools.models.nn.transformers.lightning.TransformerLightningModule", "get_val_mask_func": "tests.models.nn.transformers.utils.leave_one_out_mask", + "u2i_dist": Distance.DOT.value, } expected.update(simple_types_params) if use_custom_trainer: From f5cd7737ca3acaab8f6ea0c5732522c68e8c8a96 Mon Sep 17 00:00:00 2001 From: In48semenov Date: Tue, 18 Mar 2025 10:13:47 +0300 Subject: [PATCH 06/30] Fixed code style. --- rectools/models/nn/transformers/bert4rec.py | 1 + rectools/models/nn/transformers/sasrec.py | 1 + rectools/models/rank/rank_torch.py | 2 +- tests/models/nn/transformers/test_bert4rec.py | 2 +- tests/models/nn/transformers/test_sasrec.py | 2 +- 5 files changed, 5 insertions(+), 3 deletions(-) diff --git a/rectools/models/nn/transformers/bert4rec.py b/rectools/models/nn/transformers/bert4rec.py index 47045731..e45063a5 100644 --- a/rectools/models/nn/transformers/bert4rec.py +++ b/rectools/models/nn/transformers/bert4rec.py @@ -20,6 +20,7 @@ import torch from rectools.models.rank import Distance + from ..item_net import ( CatFeaturesItemNet, IdEmbeddingsItemNet, diff --git a/rectools/models/nn/transformers/sasrec.py b/rectools/models/nn/transformers/sasrec.py index 6607f163..fb0367d6 100644 --- a/rectools/models/nn/transformers/sasrec.py +++ b/rectools/models/nn/transformers/sasrec.py @@ -20,6 +20,7 @@ from torch import nn from rectools.models.rank import Distance + from ..item_net import ( CatFeaturesItemNet, IdEmbeddingsItemNet, diff --git a/rectools/models/rank/rank_torch.py b/rectools/models/rank/rank_torch.py index e883f134..ed091acc 100644 --- a/rectools/models/rank/rank_torch.py +++ b/rectools/models/rank/rank_torch.py @@ -215,4 +215,4 @@ def _normalize_tensor( if self.dtype is not None: tensor = tensor.to(self.dtype) - return tensor \ No newline at end of file + return tensor diff --git a/tests/models/nn/transformers/test_bert4rec.py b/tests/models/nn/transformers/test_bert4rec.py index e6cbfed3..b875d2e9 100644 --- a/tests/models/nn/transformers/test_bert4rec.py +++ b/tests/models/nn/transformers/test_bert4rec.py @@ -25,7 +25,6 @@ from rectools.columns import Columns from rectools.dataset import Dataset from rectools.models import BERT4RecModel -from rectools.models.rank import Distance from rectools.models.nn.item_net import IdEmbeddingsItemNet, SumOfEmbeddingsConstructor from rectools.models.nn.transformers.base import ( LearnableInversePositionalEncoding, @@ -34,6 +33,7 @@ TransformerLightningModule, ) from rectools.models.nn.transformers.bert4rec import MASKING_VALUE, BERT4RecDataPreparator, ValMaskCallable +from rectools.models.rank import Distance from tests.models.data import DATASET from tests.models.utils import ( assert_default_config_and_default_model_params_are_the_same, diff --git a/tests/models/nn/transformers/test_sasrec.py b/tests/models/nn/transformers/test_sasrec.py index 02a57a6c..60d3b9e6 100644 --- a/tests/models/nn/transformers/test_sasrec.py +++ b/tests/models/nn/transformers/test_sasrec.py @@ -27,7 +27,6 @@ from rectools.columns import Columns from rectools.dataset import Dataset, IdMap, Interactions from rectools.models import SASRecModel -from rectools.models.rank import Distance from rectools.models.nn.item_net import CatFeaturesItemNet, IdEmbeddingsItemNet, SumOfEmbeddingsConstructor from rectools.models.nn.transformers.base import ( 
LearnableInversePositionalEncoding, @@ -36,6 +35,7 @@ TransformerTorchBackbone, ) from rectools.models.nn.transformers.sasrec import SASRecDataPreparator, SASRecTransformerLayers +from rectools.models.rank import Distance from tests.models.data import DATASET from tests.models.utils import ( assert_default_config_and_default_model_params_are_the_same, From bfd4c63c081096fbb8f3f493d4849f252f0682b4 Mon Sep 17 00:00:00 2001 From: In48semenov Date: Tue, 18 Mar 2025 10:34:08 +0300 Subject: [PATCH 07/30] Added test with raises u2i_dist. --- tests/models/nn/transformers/test_base.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/tests/models/nn/transformers/test_base.py b/tests/models/nn/transformers/test_base.py index bc61bea5..b8aef5c7 100644 --- a/tests/models/nn/transformers/test_base.py +++ b/tests/models/nn/transformers/test_base.py @@ -28,6 +28,7 @@ from rectools.models import BERT4RecModel, SASRecModel, load_model from rectools.models.nn.item_net import CatFeaturesItemNet, IdEmbeddingsItemNet from rectools.models.nn.transformers.base import TransformerModelBase +from rectools.models.rank import Distance from tests.models.data import INTERACTIONS from tests.models.utils import assert_save_load_do_not_change_model @@ -317,3 +318,10 @@ def test_log_metrics( actual_columns = list(pd.read_csv(metrics_path).columns) assert actual_columns == expected_columns + + @pytest.mark.parametrize("model_cls", (SASRecModel, BERT4RecModel)) + def test_raises_when_incorrect_u2i_dist(self, model_cls: tp.Type[TransformerModelBase], dataset: Dataset) -> None: + model_config = {"u2i_dist": Distance.EUCLIDEAN} + with pytest.raises(ValueError): + model = model_cls.from_config(model_config) + model.fit(dataset=dataset) From 2548c643139b71bf1c473cdb3c179eebbf1a684a Mon Sep 17 00:00:00 2001 From: In48semenov Date: Sat, 22 Mar 2025 23:30:28 +0300 Subject: [PATCH 08/30] Added similarity module base and distance. 
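
The new `similarity.py` (48 lines) is not reproduced in this excerpt. Below is
a minimal sketch consistent with how the module is used in the diffs that
follow: the class names `SimilarityModuleBase` and `SimilarityModuleDistance`,
the `dist` attribute read by `TorchRanker`, and the `loss_type` keyword passed
from `_init_similarity_model` all come from this patch series; the default
distance, the constructor signatures and the shape handling are assumptions.

    import typing as tp

    import torch
    from torch import nn

    from rectools.models.rank import Distance


    class SimilarityModuleBase(nn.Module):
        """Map session embeddings and item embeddings to similarity scores."""

        def __init__(self, loss_type: str = "softmax", **kwargs: tp.Any) -> None:
            super().__init__()
            self.loss_type = loss_type

        def forward(self, session_embs: torch.Tensor, item_embs: torch.Tensor) -> torch.Tensor:
            raise NotImplementedError()


    class SimilarityModuleDistance(SimilarityModuleBase):
        """Dot-product or cosine scoring, with `dist` exposed for ranking."""

        epsilon_cosine_dist: float = 1e-8

        def __init__(self, loss_type: str = "softmax", dist: Distance = Distance.DOT, **kwargs: tp.Any) -> None:
            super().__init__(loss_type=loss_type, **kwargs)
            self.dist = dist

        def _normalize(self, embs: torch.Tensor) -> torch.Tensor:
            # L2-normalize over the embedding axis, guarding zero vectors (as in PATCH 04)
            return embs / (torch.norm(embs, p=2, dim=-1, keepdim=True) + self.epsilon_cosine_dist)

        def forward(self, session_embs: torch.Tensor, item_embs: torch.Tensor) -> torch.Tensor:
            if self.dist == Distance.COSINE:
                session_embs = self._normalize(session_embs)
                item_embs = self._normalize(item_embs)
            if item_embs.ndim == 2:
                # full-catalog logits: [..., n_factors] x [n_items, n_factors] -> [..., n_items]
                return session_embs @ item_embs.T
            # sampled positives/negatives: [b, len, n+1, f] x [b, len, f] -> [b, len, n+1]
            return (item_embs @ session_embs.unsqueeze(-1)).squeeze(-1)
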
---
 rectools/models/nn/transformers/base.py       | 28 ++++++---
 rectools/models/nn/transformers/bert4rec.py   |  9 ++--
 rectools/models/nn/transformers/lightning.py  | 33 ++++++-------
 rectools/models/nn/transformers/sasrec.py     |  9 ++--
 rectools/models/nn/transformers/similarity.py | 48 +++++++++++++++++++
 5 files changed, 94 insertions(+), 33 deletions(-)
 create mode 100644 rectools/models/nn/transformers/similarity.py

diff --git a/rectools/models/nn/transformers/base.py b/rectools/models/nn/transformers/base.py
index 2e12a020..984c9779 100644
--- a/rectools/models/nn/transformers/base.py
+++ b/rectools/models/nn/transformers/base.py
@@ -28,7 +28,6 @@
 from rectools import ExternalIds
 from rectools.dataset.dataset import Dataset, DatasetSchema, DatasetSchemaDict, IdMap
 from rectools.models.base import ErrorBehaviour, InternalRecoTriplet, ModelBase, ModelConfig
-from rectools.models.rank import Distance
 from rectools.types import InternalIdsArray
 from rectools.utils.misc import get_class_or_function_full_path, import_object
 
@@ -47,6 +46,7 @@
     PreLNTransformerLayers,
     TransformerLayersBase,
 )
+from .similarity import SimilarityModuleBase
 from .torch_backbone import TransformerTorchBackbone
 
 InitKwargs = tp.Dict[str, tp.Any]
@@ -98,6 +98,16 @@ def _serialize_type_sequence(obj: tp.Sequence[tp.Type]) -> tp.Tuple[str, ...]:
     ),
 ]
 
+SimilarityModuleType = tpe.Annotated[
+    tp.Type[SimilarityModuleBase],
+    BeforeValidator(_get_class_obj),
+    PlainSerializer(
+        func=get_class_or_function_full_path,
+        return_type=str,
+        when_used="json",
+    ),
+]
+
 TransformerDataPreparatorType = tpe.Annotated[
     tp.Type[TransformerDataPreparatorBase],
     BeforeValidator(_get_class_obj),
@@ -178,13 +188,13 @@ class TransformerModelConfig(ModelConfig):
     deterministic: bool = False
     recommend_batch_size: int = 256
     recommend_torch_device: tp.Optional[str] = None
     train_min_user_interactions: int = 2
-    u2i_dist: Distance = Distance.DOT
     item_net_block_types: ItemNetBlockTypes = (IdEmbeddingsItemNet, CatFeaturesItemNet)
     item_net_constructor_type: ItemNetConstructorType = SumOfEmbeddingsConstructor
     pos_encoding_type: PositionalEncodingType = LearnableInversePositionalEncoding
     transformer_layers_type: TransformerLayersType = PreLNTransformerLayers
     lightning_module_type: TransformerLightningModuleType = TransformerLightningModule
+    similarity_module_type: SimilarityModuleType = SimilarityModuleBase
     get_val_mask_func: tp.Optional[ValMaskCallableSerialized] = None
     get_trainer_func: tp.Optional[TrainerCallableSerialized] = None
     data_preparator_kwargs: tp.Optional[InitKwargs] = None
@@ -235,11 +245,11 @@ def __init__(  # pylint: disable=too-many-arguments, too-many-locals
         recommend_batch_size: int = 256,
         recommend_torch_device: tp.Optional[str] = None,
         train_min_user_interactions: int = 2,
-        u2i_dist: Distance = Distance.DOT,
         item_net_block_types: tp.Sequence[tp.Type[ItemNetBase]] = (IdEmbeddingsItemNet, CatFeaturesItemNet),
         item_net_constructor_type: tp.Type[ItemNetConstructorBase] = SumOfEmbeddingsConstructor,
         pos_encoding_type: tp.Type[PositionalEncodingBase] = LearnableInversePositionalEncoding,
         lightning_module_type: tp.Type[TransformerLightningModuleBase] = TransformerLightningModule,
+        similarity_module_type: tp.Type[SimilarityModuleBase] = SimilarityModuleBase,
         get_val_mask_func: tp.Optional[ValMaskCallable] = None,
         get_trainer_func: tp.Optional[TrainerCallable] = None,
         data_preparator_kwargs: tp.Optional[InitKwargs] = None,
@@ -247,6 +257,7 @@ def __init__(  # pylint: disable=too-many-arguments, too-many-locals
         item_net_constructor_kwargs: tp.Optional[InitKwargs] = None,
         pos_encoding_kwargs: tp.Optional[InitKwargs] = None,
         lightning_module_kwargs: tp.Optional[InitKwargs] = None,
+        similarity_module_kwargs: tp.Optional[InitKwargs] = None,
         **kwargs: tp.Any,
     ) -> None:
         super().__init__(verbose=verbose)
@@ -271,7 +282,7 @@ def __init__(  # pylint: disable=too-many-arguments, too-many-locals
         self.recommend_batch_size = recommend_batch_size
         self.recommend_torch_device = recommend_torch_device
         self.train_min_user_interactions = train_min_user_interactions
-        self.u2i_dist = u2i_dist
+        self.similarity_module_type = similarity_module_type
         self.item_net_block_types = item_net_block_types
         self.item_net_constructor_type = item_net_constructor_type
         self.pos_encoding_type = pos_encoding_type
@@ -283,6 +294,7 @@ def __init__(  # pylint: disable=too-many-arguments, too-many-locals
         self.item_net_constructor_kwargs = item_net_constructor_kwargs
         self.pos_encoding_kwargs = pos_encoding_kwargs
         self.lightning_module_kwargs = lightning_module_kwargs
+        self.similarity_module_kwargs = similarity_module_kwargs
 
         self._init_data_preparator()
         self._init_trainer()
@@ -373,6 +385,9 @@ def _init_torch_model(self, item_model: ItemNetBase) -> TransformerTorchBackbone
             use_key_padding_mask=self.use_key_padding_mask,
         )
 
+    def _init_similarity_model(self) -> SimilarityModuleBase:
+        return self.similarity_module_type(loss_type=self.loss, **self._get_kwargs(self.similarity_module_kwargs))
+
     def _init_lightning_model(
         self,
         torch_model: TransformerTorchBackbone,
@@ -380,6 +395,7 @@ def _init_lightning_model(
         item_external_ids: ExternalIds,
         model_config: tp.Dict[str, tp.Any],
     ) -> None:
+        similarity_model = self._init_similarity_model()
         self.lightning_model = self.lightning_module_type(
             torch_model=torch_model,
             dataset_schema=dataset_schema,
@@ -394,7 +410,7 @@ def _init_lightning_model(
             train_loss_name=self.train_loss_name,
             val_loss_name=self.val_loss_name,
             adam_betas=(0.9, 0.98),
-            u2i_dist=self.u2i_dist,
+            similarity_model=similarity_model,
             **self._get_kwargs(self.lightning_module_kwargs),
         )
 
diff --git a/rectools/models/nn/transformers/bert4rec.py b/rectools/models/nn/transformers/bert4rec.py
index e45063a5..f72313eb 100644
--- a/rectools/models/nn/transformers/bert4rec.py
+++ b/rectools/models/nn/transformers/bert4rec.py
@@ -19,8 +19,6 @@
 import numpy as np
 import torch
 
-from rectools.models.rank import Distance
-
 from ..item_net import (
     CatFeaturesItemNet,
     IdEmbeddingsItemNet,
@@ -44,6 +42,7 @@
     PreLNTransformerLayers,
     TransformerLayersBase,
 )
+from .similarity import SimilarityModuleBase
@@ -317,13 +315,13 @@ def __init__(  # pylint: disable=too-many-arguments, too-many-locals
         use_pos_emb: bool = True,
         use_key_padding_mask: bool = True,
         use_causal_attn: bool = False,
-        u2i_dist: Distance = Distance.DOT,
         item_net_block_types: tp.Sequence[tp.Type[ItemNetBase]] = (IdEmbeddingsItemNet, CatFeaturesItemNet),
         item_net_constructor_type: tp.Type[ItemNetConstructorBase] = SumOfEmbeddingsConstructor,
         pos_encoding_type: tp.Type[PositionalEncodingBase] = LearnableInversePositionalEncoding,
         transformer_layers_type: tp.Type[TransformerLayersBase] = PreLNTransformerLayers,
         data_preparator_type: tp.Type[TransformerDataPreparatorBase] = BERT4RecDataPreparator,
         lightning_module_type: tp.Type[TransformerLightningModuleBase] = TransformerLightningModule,
+        similarity_module_type: tp.Type[SimilarityModuleBase] = SimilarityModuleBase,
         get_val_mask_func: tp.Optional[ValMaskCallable] = None,
get_trainer_func: tp.Optional[TrainerCallable] = None, recommend_batch_size: int = 256, @@ -337,6 +336,7 @@ def __init__( # pylint: disable=too-many-arguments, too-many-locals item_net_constructor_kwargs: tp.Optional[InitKwargs] = None, pos_encoding_kwargs: tp.Optional[InitKwargs] = None, lightning_module_kwargs: tp.Optional[InitKwargs] = None, + similarity_module_kwargs: tp.Optional[InitKwargs] = None, ): self.mask_prob = mask_prob @@ -365,7 +365,7 @@ def __init__( # pylint: disable=too-many-arguments, too-many-locals recommend_n_threads=recommend_n_threads, recommend_use_torch_ranking=recommend_use_torch_ranking, train_min_user_interactions=train_min_user_interactions, - u2i_dist=u2i_dist, + similarity_module_type=similarity_module_type, item_net_block_types=item_net_block_types, item_net_constructor_type=item_net_constructor_type, pos_encoding_type=pos_encoding_type, @@ -378,6 +378,7 @@ def __init__( # pylint: disable=too-many-arguments, too-many-locals item_net_constructor_kwargs=item_net_constructor_kwargs, pos_encoding_kwargs=pos_encoding_kwargs, lightning_module_kwargs=lightning_module_kwargs, + similarity_module_kwargs=similarity_module_kwargs, ) def _init_data_preparator(self) -> None: diff --git a/rectools/models/nn/transformers/lightning.py b/rectools/models/nn/transformers/lightning.py index eafe4139..34ec58df 100644 --- a/rectools/models/nn/transformers/lightning.py +++ b/rectools/models/nn/transformers/lightning.py @@ -27,6 +27,7 @@ from rectools.types import InternalIdsArray from .data_preparator import TransformerDataPreparatorBase +from .similarity import SimilarityModuleBase, SimilarityModuleDistance from .torch_backbone import TransformerTorchBackbone # #### -------------- Lightning Base Model -------------- #### # @@ -73,11 +74,11 @@ def __init__( lr: float, gbce_t: float, loss: str, + similarity_model: SimilarityModuleBase, verbose: int = 0, train_loss_name: str = "train_loss", val_loss_name: str = "val_loss", adam_betas: tp.Tuple[float, float] = (0.9, 0.98), - u2i_dist: Distance = Distance.DOT, **kwargs: tp.Any, ): super().__init__() @@ -94,12 +95,9 @@ def __init__( self.verbose = verbose self.train_loss_name = train_loss_name self.val_loss_name = val_loss_name + self.similarity_model = similarity_model self.item_embs: torch.Tensor - if u2i_dist not in self.u2i_dist_available: - raise ValueError("`u2i_distance` can only be either `Distance.DOT` or `Distance.COSINE`.") - self.u2i_dist = u2i_dist - self.save_hyperparameters(ignore=["torch_model", "data_preparator"]) def configure_optimizers(self) -> torch.optim.Adam: @@ -228,10 +226,7 @@ def _get_embeddings_norm(self, embeddings: torch.Tensor) -> torch.Tensor: def _get_full_catalog_logits(self, x: torch.Tensor) -> torch.Tensor: item_embs, session_embs = self.torch_model(x) - if self.u2i_dist == Distance.COSINE: - session_embs = self._get_embeddings_norm(session_embs) - item_embs = self._get_embeddings_norm(item_embs) - logits = session_embs @ item_embs.T + logits = self.similarity_model(session_embs, item_embs) return logits def _get_pos_neg_logits(self, x: torch.Tensor, y: torch.Tensor, negatives: torch.Tensor) -> torch.Tensor: @@ -239,11 +234,8 @@ def _get_pos_neg_logits(self, x: torch.Tensor, y: torch.Tensor, negatives: torch item_embs, session_embs = self.torch_model(x) pos_neg = torch.cat([y.unsqueeze(-1), negatives], dim=-1) # [batch_size, session_max_len, n_negatives + 1] pos_neg_embs = item_embs[pos_neg] # [batch_size, session_max_len, n_negatives + 1, n_factors] - if self.u2i_dist == Distance.COSINE: - 
session_embs = self._get_embeddings_norm(session_embs) - item_embs = self._get_embeddings_norm(item_embs) # [batch_size, session_max_len, n_negatives + 1] - logits = (pos_neg_embs @ session_embs.unsqueeze(-1)).squeeze(-1) + logits = self.similarity_model(session_embs, pos_neg_embs) return logits def _get_reduced_overconfidence_logits(self, logits: torch.Tensor, n_items: int, n_negatives: int) -> torch.Tensor: @@ -356,12 +348,15 @@ def _recommend_u2i( user_embs, item_embs = self._get_user_item_embeddings(recommend_dataloader, torch_device) - ranker = TorchRanker( - distance=self.u2i_dist, - device=item_embs.device, - subjects_factors=user_embs[user_ids], - objects_factors=item_embs, - ) + if isinstance(self.similarity_model, SimilarityModuleDistance): + ranker = TorchRanker( + distance=self.similarity_model.dist, + device=item_embs.device, + subjects_factors=user_embs[user_ids], + objects_factors=item_embs, + ) + else: + raise NotImplementedError() user_ids_indices, all_reco_ids, all_scores = ranker.rank( subject_ids=np.arange(len(user_ids)), # n_rec_users diff --git a/rectools/models/nn/transformers/sasrec.py b/rectools/models/nn/transformers/sasrec.py index fb0367d6..457b9c0a 100644 --- a/rectools/models/nn/transformers/sasrec.py +++ b/rectools/models/nn/transformers/sasrec.py @@ -19,8 +19,6 @@ import torch from torch import nn -from rectools.models.rank import Distance - from ..item_net import ( CatFeaturesItemNet, IdEmbeddingsItemNet, @@ -46,6 +44,7 @@ PositionalEncodingBase, TransformerLayersBase, ) +from .similarity import SimilarityModuleBase class SASRecDataPreparator(TransformerDataPreparatorBase): @@ -397,13 +396,13 @@ def __init__( # pylint: disable=too-many-arguments, too-many-locals use_pos_emb: bool = True, use_key_padding_mask: bool = False, use_causal_attn: bool = True, - u2i_dist: Distance = Distance.DOT, item_net_block_types: tp.Sequence[tp.Type[ItemNetBase]] = (IdEmbeddingsItemNet, CatFeaturesItemNet), item_net_constructor_type: tp.Type[ItemNetConstructorBase] = SumOfEmbeddingsConstructor, pos_encoding_type: tp.Type[PositionalEncodingBase] = LearnableInversePositionalEncoding, transformer_layers_type: tp.Type[TransformerLayersBase] = SASRecTransformerLayers, # SASRec authors net data_preparator_type: tp.Type[TransformerDataPreparatorBase] = SASRecDataPreparator, lightning_module_type: tp.Type[TransformerLightningModuleBase] = TransformerLightningModule, + similarity_module_type: tp.Type[SimilarityModuleBase] = SimilarityModuleBase, get_val_mask_func: tp.Optional[ValMaskCallable] = None, get_trainer_func: tp.Optional[TrainerCallable] = None, recommend_batch_size: int = 256, @@ -415,6 +414,7 @@ def __init__( # pylint: disable=too-many-arguments, too-many-locals item_net_constructor_kwargs: tp.Optional[InitKwargs] = None, pos_encoding_kwargs: tp.Optional[InitKwargs] = None, lightning_module_kwargs: tp.Optional[InitKwargs] = None, + similarity_module_kwargs: tp.Optional[InitKwargs] = None, ): super().__init__( transformer_layers_type=transformer_layers_type, @@ -441,7 +441,7 @@ def __init__( # pylint: disable=too-many-arguments, too-many-locals recommend_n_threads=recommend_n_threads, recommend_use_torch_ranking=recommend_use_torch_ranking, train_min_user_interactions=train_min_user_interactions, - u2i_dist=u2i_dist, + similarity_module_type=similarity_module_type, item_net_block_types=item_net_block_types, item_net_constructor_type=item_net_constructor_type, pos_encoding_type=pos_encoding_type, @@ -453,4 +453,5 @@ def __init__( # pylint: disable=too-many-arguments, 
too-many-locals item_net_constructor_kwargs=item_net_constructor_kwargs, pos_encoding_kwargs=pos_encoding_kwargs, lightning_module_kwargs=lightning_module_kwargs, + similarity_module_kwargs=similarity_module_kwargs, ) diff --git a/rectools/models/nn/transformers/similarity.py b/rectools/models/nn/transformers/similarity.py new file mode 100644 index 00000000..4e2c3655 --- /dev/null +++ b/rectools/models/nn/transformers/similarity.py @@ -0,0 +1,48 @@ +import typing as tp + +import torch +import torch.nn as nn + +from rectools.models.rank import Distance + + +class SimilarityModuleBase(nn.Module): + + def __init__(self, loss_type: str, *args: tp.Any, **kwargs: tp.Any) -> None: + self.loss_type = loss_type + + def forward(self, session_embs: torch.Tensor, item_embs: torch.Tensor) -> torch.Tensor: + raise NotImplementedError() + + +class SimilarityModuleDistance(SimilarityModuleBase): + + dist_available: tp.List[Distance] = [Distance.DOT, Distance.COSINE] + epsilon_cosine_dist: float = 1e-8 + + def __init__(self, loss_type: str, dist: Distance = Distance.DOT) -> None: + if dist not in self.dist_available: + raise ValueError("`dist` can only be either `Distance.DOT` or `Distance.COSINE`.") + + self.dist = dist + self.loss_type = loss_type + + def _get_embeddings_norm(self, embeddings: torch.Tensor) -> torch.Tensor: + embeddings = embeddings / (torch.norm(embeddings, p=2, dim=1).unsqueeze(dim=1) + self.epsilon_cosine_dist) + return embeddings + + def _calc_custom_score(self, session_embs: torch.Tensor, item_embs: torch.Tensor) -> torch.Tensor: # TODO + raise ValueError(f"loss {self.loss} is not supported in `DistanceSimilarity`") # pragma: no cover + + def forward(self, session_embs: torch.Tensor, item_embs: torch.Tensor) -> torch.Tensor: + if self.dist == Distance.COSINE: + session_embs = self._get_embeddings_norm(session_embs) + item_embs = self._get_embeddings_norm(item_embs) + + if self.loss_type == "softmax": + scores = session_embs @ item_embs.T + elif self.loss_type in ["BCE", "gBCE"]: + scores = (item_embs @ session_embs.unsqueeze(-1)).squeeze(-1) + else: # TODO: think about it + scores = self._calc_custom_score(session_embs, item_embs) # pragma: no cover + return scores From fe607ae735f86912d90b08d46d8f94b4f35d9325 Mon Sep 17 00:00:00 2001 From: In48semenov Date: Tue, 25 Mar 2025 18:34:43 +0300 Subject: [PATCH 09/30] Added Similarity Module to torch model. 
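This patch moves logit computation into the torch backbone: the backbone now owns the similarity module and returns (item_embs, session_embs, logits) from its forward pass, and u2i ranking at inference time is delegated to the similarity module as well.

A minimal usage sketch of the module as introduced in this patch (illustrative only, not part of the diff; tensor sizes are made-up placeholders, shapes follow the comments in the code):

    import torch

    from rectools.models.nn.transformers.similarity import SimilarityModuleDistance

    sim = SimilarityModuleDistance(dist="dot")  # "cosine" L2-normalizes embeddings first
    session_embs = torch.randn(32, 50, 64)  # [batch_size, session_max_len, n_factors]
    item_embs = torch.randn(1000, 64)       # [n_items + n_item_extra_tokens, n_factors]

    # Full-catalog logits, as used by the "softmax" loss:
    logits = sim(session_embs, item_embs)  # [32, 50, 1000]

    # Logits only for sampled positives/negatives, as used by "BCE"/"gBCE":
    pos_neg = torch.randint(0, 1000, (32, 50, 4))
    pos_neg_logits = sim(session_embs, item_embs, item_ids=pos_neg)  # [32, 50, 4]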
--- rectools/models/nn/transformers/base.py | 15 +-- rectools/models/nn/transformers/bert4rec.py | 7 +- rectools/models/nn/transformers/lightning.py | 61 +++-------- rectools/models/nn/transformers/sasrec.py | 7 +- rectools/models/nn/transformers/similarity.py | 102 ++++++++++++++---- .../models/nn/transformers/torch_backbone.py | 42 +++++++- tests/models/nn/transformers/test_base.py | 19 +++- tests/models/nn/transformers/test_bert4rec.py | 27 +++-- tests/models/nn/transformers/test_sasrec.py | 44 +++++--- 9 files changed, 216 insertions(+), 108 deletions(-) diff --git a/rectools/models/nn/transformers/base.py b/rectools/models/nn/transformers/base.py index 984c9779..27192738 100644 --- a/rectools/models/nn/transformers/base.py +++ b/rectools/models/nn/transformers/base.py @@ -188,13 +188,13 @@ class TransformerModelConfig(ModelConfig): deterministic: bool = False recommend_batch_size: int = 256 recommend_torch_device: tp.Optional[str] = None - train_min_user_interactions: int = (2,) + train_min_user_interactions: int = 2 item_net_block_types: ItemNetBlockTypes = (IdEmbeddingsItemNet, CatFeaturesItemNet) item_net_constructor_type: ItemNetConstructorType = SumOfEmbeddingsConstructor pos_encoding_type: PositionalEncodingType = LearnableInversePositionalEncoding transformer_layers_type: TransformerLayersType = PreLNTransformerLayers lightning_module_type: TransformerLightningModuleType = TransformerLightningModule - similarity_module_type: SimilarityModuleType = (SimilarityModuleBase,) + similarity_module_type: SimilarityModuleType = SimilarityModuleBase get_val_mask_func: tp.Optional[ValMaskCallableSerialized] = None get_trainer_func: tp.Optional[TrainerCallableSerialized] = None data_preparator_kwargs: tp.Optional[InitKwargs] = None @@ -202,6 +202,7 @@ class TransformerModelConfig(ModelConfig): item_net_constructor_kwargs: tp.Optional[InitKwargs] = None pos_encoding_kwargs: tp.Optional[InitKwargs] = None lightning_module_kwargs: tp.Optional[InitKwargs] = None + similarity_module_kwargs: tp.Optional[InitKwargs] = None TransformerModelConfig_T = tp.TypeVar("TransformerModelConfig_T", bound=TransformerModelConfig) @@ -372,22 +373,24 @@ def _init_transformer_layers(self) -> TransformerLayersBase: **self._get_kwargs(self.transformer_layers_kwargs), ) + def _init_similarity_module(self) -> SimilarityModuleBase: + return self.similarity_module_type(**self._get_kwargs(self.similarity_module_kwargs)) + def _init_torch_model(self, item_model: ItemNetBase) -> TransformerTorchBackbone: pos_encoding_layer = self._init_pos_encoding_layer() transformer_layers = self._init_transformer_layers() + similarity_module = self._init_similarity_module() return TransformerTorchBackbone( n_heads=self.n_heads, dropout_rate=self.dropout_rate, item_model=item_model, pos_encoding_layer=pos_encoding_layer, transformer_layers=transformer_layers, + similarity_module=similarity_module, use_causal_attn=self.use_causal_attn, use_key_padding_mask=self.use_key_padding_mask, ) - def _init_similarity_model(self) -> SimilarityModuleBase: - return self.similarity_module_type(loss_type=self.loss, **self.similarity_module_kwargs) - def _init_lightning_model( self, torch_model: TransformerTorchBackbone, @@ -395,7 +398,6 @@ def _init_lightning_model( item_external_ids: ExternalIds, model_config: tp.Dict[str, tp.Any], ) -> None: - similarity_model = self._init_similarity_model() self.lightning_model = self.lightning_module_type( torch_model=torch_model, dataset_schema=dataset_schema, @@ -410,7 +412,6 @@ def _init_lightning_model( 
train_loss_name=self.train_loss_name, val_loss_name=self.val_loss_name, adam_betas=(0.9, 0.98), - similarity_model=similarity_model, **self._get_kwargs(self.lightning_module_kwargs), ) diff --git a/rectools/models/nn/transformers/bert4rec.py b/rectools/models/nn/transformers/bert4rec.py index f72313eb..f88b7833 100644 --- a/rectools/models/nn/transformers/bert4rec.py +++ b/rectools/models/nn/transformers/bert4rec.py @@ -242,8 +242,6 @@ class BERT4RecModel(TransformerModelBase[BERT4RecModelConfig]): BERT4Rec training task ("MLM") does not work with causal masking. Set this parameter to ``True`` only when you change the training task with custom `data_preparator_type` or if you are absolutely sure of what you are doing. - u2i_dist : Distance, default Distance.DOT - U2I distance metric. item_net_block_types : sequence of `type(ItemNetBase)`, default `(IdEmbeddingsItemNet, CatFeaturesItemNet)` Type of network returning item embeddings. (IdEmbeddingsItemNet,) - item embeddings based on ids. @@ -259,6 +257,8 @@ class BERT4RecModel(TransformerModelBase[BERT4RecModelConfig]): Type of data preparator used for dataset processing and dataloader creation. lightning_module_type : type(TransformerLightningModuleBase), default `TransformerLightningModule` Type of lightning module defining training procedure. + similarity_module_type : type(SimilarityModuleBase), default `SimilarityModuleBase` + Type of similarity module. get_val_mask_func : Callable, default ``None`` Function to get validation mask. get_trainer_func : Callable, default ``None`` @@ -292,6 +292,9 @@ class BERT4RecModel(TransformerModelBase[BERT4RecModelConfig]): lightning_module_kwargs: optional(dict), default ``None`` Additional keyword arguments to pass during `lightning_module_type` initialization. Make sure all dict values have JSON serializable types. + similarity_module_kwargs: optional(dict), default ``None`` + Additional keyword arguments to pass during `similarity_module_type` initialization. + Make sure all dict values have JSON serializable types. """ config_class = BERT4RecModelConfig diff --git a/rectools/models/nn/transformers/lightning.py b/rectools/models/nn/transformers/lightning.py index 34ec58df..fc8316e0 100644 --- a/rectools/models/nn/transformers/lightning.py +++ b/rectools/models/nn/transformers/lightning.py @@ -15,7 +15,6 @@ import typing as tp from collections.abc import Hashable -import numpy as np import torch from pytorch_lightning import LightningModule from torch.utils.data import DataLoader @@ -27,7 +26,6 @@ from rectools.types import InternalIdsArray from .data_preparator import TransformerDataPreparatorBase -from .similarity import SimilarityModuleBase, SimilarityModuleDistance from .torch_backbone import TransformerTorchBackbone # #### -------------- Lightning Base Model -------------- #### # @@ -56,8 +54,6 @@ class TransformerLightningModuleBase(LightningModule): # pylint: disable=too-ma Name of the training loss. val_loss_name : str, default "val_loss" Name of the training loss. - u2i_dist : Distance, default Distance.DOT - U2I distance metric. 
""" u2i_dist_available = [Distance.DOT, Distance.COSINE] @@ -74,7 +70,6 @@ def __init__( lr: float, gbce_t: float, loss: str, - similarity_model: SimilarityModuleBase, verbose: int = 0, train_loss_name: str = "train_loss", val_loss_name: str = "val_loss", @@ -95,7 +90,6 @@ def __init__( self.verbose = verbose self.train_loss_name = train_loss_name self.val_loss_name = val_loss_name - self.similarity_model = similarity_model self.item_embs: torch.Tensor self.save_hyperparameters(ignore=["torch_model", "data_preparator"]) @@ -157,15 +151,17 @@ def training_step(self, batch: tp.Dict[str, torch.Tensor], batch_idx: int) -> to """Training step.""" x, y, w = batch["x"], batch["y"], batch["yw"] if self.loss == "softmax": - logits = self._get_full_catalog_logits(x) + _, _, logits = self.torch_model(sessions=x) loss = self._calc_softmax_loss(logits, y, w) elif self.loss == "BCE": negatives = batch["negatives"] - logits = self._get_pos_neg_logits(x, y, negatives) + pos_neg = torch.cat([y.unsqueeze(-1), negatives], dim=-1) # [batch_size, session_max_len, n_negatives + 1] + _, _, logits = self.torch_model(sessions=x, item_ids=pos_neg) loss = self._calc_bce_loss(logits, y, w) elif self.loss == "gBCE": negatives = batch["negatives"] - logits = self._get_pos_neg_logits(x, y, negatives) + pos_neg = torch.cat([y.unsqueeze(-1), negatives], dim=-1) # [batch_size, session_max_len, n_negatives + 1] + _, _, logits = self.torch_model(sessions=x, item_ids=pos_neg) loss = self._calc_gbce_loss(logits, y, w, negatives) else: loss = self._calc_custom_loss(batch, batch_idx) @@ -196,17 +192,19 @@ def validation_step(self, batch: tp.Dict[str, torch.Tensor], batch_idx: int) -> x, y, w = batch["x"], batch["y"], batch["yw"] outputs = {} if self.loss == "softmax": - logits = self._get_full_catalog_logits(x)[:, -1:, :] + _, _, logits = self.torch_model(sessions=x, last_n_items=1) outputs["loss"] = self._calc_softmax_loss(logits, y, w) outputs["logits"] = logits.squeeze() elif self.loss == "BCE": negatives = batch["negatives"] - pos_neg_logits = self._get_pos_neg_logits(x, y, negatives)[:, -1:, :] + pos_neg = torch.cat([y.unsqueeze(-1), negatives], dim=-1) # [batch_size, session_max_len, n_negatives + 1] + _, _, pos_neg_logits = self.torch_model(sessions=x, item_ids=pos_neg, last_n_items=1) outputs["loss"] = self._calc_bce_loss(pos_neg_logits, y, w) outputs["pos_neg_logits"] = pos_neg_logits.squeeze() elif self.loss == "gBCE": negatives = batch["negatives"] - pos_neg_logits = self._get_pos_neg_logits(x, y, negatives)[:, -1:, :] + pos_neg = torch.cat([y.unsqueeze(-1), negatives], dim=-1) # [batch_size, session_max_len, n_negatives + 1] + _, _, pos_neg_logits = self.torch_model(sessions=x, item_ids=pos_neg, last_n_items=1) outputs["loss"] = self._calc_gbce_loss(pos_neg_logits, y, w, negatives) outputs["pos_neg_logits"] = pos_neg_logits.squeeze() else: @@ -220,24 +218,6 @@ def _calc_custom_loss_outputs( ) -> tp.Dict[str, torch.Tensor]: raise ValueError(f"loss {self.loss} is not supported") # pragma: no cover - def _get_embeddings_norm(self, embeddings: torch.Tensor) -> torch.Tensor: - embeddings = embeddings / (torch.norm(embeddings, p=2, dim=1).unsqueeze(dim=1) + self.epsilon_cosine_dist) - return embeddings - - def _get_full_catalog_logits(self, x: torch.Tensor) -> torch.Tensor: - item_embs, session_embs = self.torch_model(x) - logits = self.similarity_model(session_embs, item_embs) - return logits - - def _get_pos_neg_logits(self, x: torch.Tensor, y: torch.Tensor, negatives: torch.Tensor) -> torch.Tensor: - # [n_items + 
n_item_extra_tokens, n_factors], [batch_size, session_max_len, n_factors] - item_embs, session_embs = self.torch_model(x) - pos_neg = torch.cat([y.unsqueeze(-1), negatives], dim=-1) # [batch_size, session_max_len, n_negatives + 1] - pos_neg_embs = item_embs[pos_neg] # [batch_size, session_max_len, n_negatives + 1, n_factors] - # [batch_size, session_max_len, n_negatives + 1] - logits = self.similarity_model(session_embs, pos_neg_embs) - return logits - def _get_reduced_overconfidence_logits(self, logits: torch.Tensor, n_items: int, n_negatives: int) -> torch.Tensor: # https://arxiv.org/pdf/2308.07192.pdf @@ -348,23 +328,14 @@ def _recommend_u2i( user_embs, item_embs = self._get_user_item_embeddings(recommend_dataloader, torch_device) - if isinstance(self.similarity_model, SimilarityModuleDistance): - ranker = TorchRanker( - distance=self.similarity_model.dist, - device=item_embs.device, - subjects_factors=user_embs[user_ids], - objects_factors=item_embs, - ) - else: - raise NotImplementedError() - - user_ids_indices, all_reco_ids, all_scores = ranker.rank( - subject_ids=np.arange(len(user_ids)), # n_rec_users + all_user_ids, all_reco_ids, all_scores = self.torch_model._recommend_u2i( + user_embs=user_embs, + item_embs=item_embs, + user_ids=user_ids, k=k, - filter_pairs_csr=ui_csr_for_filter, # [n_rec_users x n_items + n_item_extra_tokens] - sorted_object_whitelist=sorted_item_ids_to_recommend, # model_internal + sorted_item_ids_to_recommend=sorted_item_ids_to_recommend, + ui_csr_for_filter=ui_csr_for_filter, ) - all_user_ids = user_ids[user_ids_indices] return all_user_ids, all_reco_ids, all_scores def _recommend_i2i( diff --git a/rectools/models/nn/transformers/sasrec.py b/rectools/models/nn/transformers/sasrec.py index 457b9c0a..2b35c994 100644 --- a/rectools/models/nn/transformers/sasrec.py +++ b/rectools/models/nn/transformers/sasrec.py @@ -322,8 +322,6 @@ class SASRecModel(TransformerModelBase[SASRecModelConfig]): SASRec training task ("Shifted Sequence") does not work without causal masking. Set this parameter to ``False`` only when you change the training task with custom `data_preparator_type` or if you are absolutely sure of what you are doing. - u2i_dist : Distance, default Distance.DOT - U2I distance metric. item_net_block_types : sequence of `type(ItemNetBase)`, default `(IdEmbeddingsItemNet, CatFeaturesItemNet)` Type of network returning item embeddings. (IdEmbeddingsItemNet,) - item embeddings based on ids. @@ -339,6 +337,8 @@ class SASRecModel(TransformerModelBase[SASRecModelConfig]): Type of data preparator used for dataset processing and dataloader creation. lightning_module_type : type(TransformerLightningModuleBase), default `TransformerLightningModule` Type of lightning module defining training procedure. + similarity_module_type : type(SimilarityModuleBase), default `SimilarityModuleBase` + Type of similarity module. get_val_mask_func : Callable, default ``None`` Function to get validation mask. get_trainer_func : Callable, default ``None`` @@ -372,6 +372,9 @@ class SASRecModel(TransformerModelBase[SASRecModelConfig]): lightning_module_kwargs: optional(dict), default ``None`` Additional keyword arguments to pass during `lightning_module_type` initialization. Make sure all dict values have JSON serializable types. + similarity_module_kwargs: optional(dict), default ``None`` + Additional keyword arguments to pass during `similarity_module_type` initialization. + Make sure all dict values have JSON serializable types. 
""" config_class = SASRecModelConfig diff --git a/rectools/models/nn/transformers/similarity.py b/rectools/models/nn/transformers/similarity.py index 4e2c3655..e0b41cf6 100644 --- a/rectools/models/nn/transformers/similarity.py +++ b/rectools/models/nn/transformers/similarity.py @@ -1,48 +1,110 @@ import typing as tp +import numpy as np import torch import torch.nn as nn +from scipy import sparse -from rectools.models.rank import Distance +from rectools.models.base import InternalRecoTriplet +from rectools.models.rank import Distance, TorchRanker +from rectools.types import InternalIdsArray class SimilarityModuleBase(nn.Module): + """Similarity module base.""" - def __init__(self, loss_type: str, *args: tp.Any, **kwargs: tp.Any) -> None: - self.loss_type = loss_type + def _get_full_catalog_logits(self, session_embs: torch.Tensor, item_embs: torch.Tensor) -> torch.Tensor: + raise NotImplementedError() + + def _get_pos_neg_logits( + self, session_embs: torch.Tensor, item_embs: torch.Tensor, item_ids: torch.Tensor + ) -> torch.Tensor: + raise NotImplementedError() + + def forward( + self, session_embs: torch.Tensor, item_embs: torch.Tensor, item_ids: tp.Optional[torch.Tensor] = None + ) -> torch.Tensor: + """Forward pass to get logits.""" + raise NotImplementedError() - def forward(self, session_embs: torch.Tensor, item_embs: torch.Tensor) -> torch.Tensor: + def _recommend_u2i( + self, + user_embs: torch.Tensor, + item_embs: torch.Tensor, + user_ids: InternalIdsArray, + k: int, + sorted_item_ids_to_recommend: InternalIdsArray, + ui_csr_for_filter: tp.Optional[sparse.csr_matrix], + ) -> InternalRecoTriplet: + """Recommend to users.""" raise NotImplementedError() class SimilarityModuleDistance(SimilarityModuleBase): + """Similarity module distance.""" - dist_available: tp.List[Distance] = [Distance.DOT, Distance.COSINE] + dist_available: tp.List[str] = ["dot", "cosine"] + # dist_available_values: tp.List[int] = [Distance.DOT.value, Distance.COSINE.value] epsilon_cosine_dist: float = 1e-8 - def __init__(self, loss_type: str, dist: Distance = Distance.DOT) -> None: + def __init__(self, dist: str = "dot") -> None: + super().__init__() if dist not in self.dist_available: - raise ValueError("`dist` can only be either `Distance.DOT` or `Distance.COSINE`.") + raise ValueError("`dist` can only be either `dot` or `cosine`.") self.dist = dist - self.loss_type = loss_type + + def _get_full_catalog_logits(self, session_embs: torch.Tensor, item_embs: torch.Tensor) -> torch.Tensor: + logits = session_embs @ item_embs.T + return logits + + def _get_pos_neg_logits( + self, session_embs: torch.Tensor, item_embs: torch.Tensor, item_ids: torch.Tensor + ) -> torch.Tensor: + pos_neg_embs = item_embs[item_ids] # [batch_size, session_max_len, len(item_ids), n_factors] + # [batch_size, session_max_len,len(item_ids)] + logits = (pos_neg_embs @ session_embs.unsqueeze(-1)).squeeze(-1) + return logits def _get_embeddings_norm(self, embeddings: torch.Tensor) -> torch.Tensor: embeddings = embeddings / (torch.norm(embeddings, p=2, dim=1).unsqueeze(dim=1) + self.epsilon_cosine_dist) return embeddings - def _calc_custom_score(self, session_embs: torch.Tensor, item_embs: torch.Tensor) -> torch.Tensor: # TODO - raise ValueError(f"loss {self.loss} is not supported in `DistanceSimilarity`") # pragma: no cover - - def forward(self, session_embs: torch.Tensor, item_embs: torch.Tensor) -> torch.Tensor: - if self.dist == Distance.COSINE: + def forward( + self, session_embs: torch.Tensor, item_embs: torch.Tensor, item_ids: 
tp.Optional[torch.Tensor] = None + ) -> torch.Tensor: + """Forward pass to get logits.""" + if self.dist == "cosine": session_embs = self._get_embeddings_norm(session_embs) item_embs = self._get_embeddings_norm(item_embs) - if self.loss_type == "softmax": - scores = session_embs @ item_embs.T - elif self.loss_type in ["BCE", "gBCE"]: - scores = (item_embs @ session_embs.unsqueeze(-1)).squeeze(-1) - else: # TODO: think about it - scores = self._calc_custom_score(session_embs, item_embs) # pragma: no cover - return scores + if item_ids is None: + logits = self._get_full_catalog_logits(session_embs, item_embs) + else: + logits = self._get_pos_neg_logits(session_embs, item_embs, item_ids) + return logits + + def _recommend_u2i( + self, + user_embs: torch.Tensor, + item_embs: torch.Tensor, + user_ids: InternalIdsArray, + k: int, + sorted_item_ids_to_recommend: InternalIdsArray, + ui_csr_for_filter: tp.Optional[sparse.csr_matrix], + ) -> InternalRecoTriplet: + """Recommend to users.""" + ranker = TorchRanker( + distance=Distance.DOT if self.dist == "dot" else Distance.COSINE, + device=item_embs.device, + subjects_factors=user_embs[user_ids], + objects_factors=item_embs, + ) + user_ids_indices, all_reco_ids, all_scores = ranker.rank( + subject_ids=np.arange(len(user_ids)), # n_rec_users + k=k, + filter_pairs_csr=ui_csr_for_filter, # [n_rec_users x n_items + n_item_extra_tokens] + sorted_object_whitelist=sorted_item_ids_to_recommend, # model_internal + ) + all_user_ids = user_ids[user_ids_indices] + return all_user_ids, all_reco_ids, all_scores diff --git a/rectools/models/nn/transformers/torch_backbone.py b/rectools/models/nn/transformers/torch_backbone.py index e302ded8..4169ce2b 100644 --- a/rectools/models/nn/transformers/torch_backbone.py +++ b/rectools/models/nn/transformers/torch_backbone.py @@ -15,9 +15,14 @@ import typing as tp import torch +from scipy import sparse + +from rectools.models.base import InternalRecoTriplet +from rectools.types import InternalIdsArray from ..item_net import ItemNetBase from .net_blocks import PositionalEncodingBase, TransformerLayersBase +from .similarity import SimilarityModuleBase class TransformerTorchBackbone(torch.nn.Module): @@ -36,6 +41,8 @@ class TransformerTorchBackbone(torch.nn.Module): Positional encoding layer. transformer_layers : TransformerLayersBase Transformer layers. + similarity_module : SimilarityModuleBase + Similarity module. use_causal_attn : bool, default True If ``True``, causal mask is used in multi-head self-attention. 
use_key_padding_mask : bool, default False @@ -49,6 +56,7 @@ def __init__( item_model: ItemNetBase, pos_encoding_layer: PositionalEncodingBase, transformer_layers: TransformerLayersBase, + similarity_module: SimilarityModuleBase, use_causal_attn: bool = True, use_key_padding_mask: bool = False, ) -> None: @@ -58,6 +66,7 @@ def __init__( self.pos_encoding_layer = pos_encoding_layer self.emb_dropout = torch.nn.Dropout(dropout_rate) self.transformer_layers = transformer_layers + self.similarity_module = similarity_module self.use_causal_attn = use_causal_attn self.use_key_padding_mask = use_key_padding_mask self.n_heads = n_heads @@ -157,8 +166,10 @@ def encode_sessions(self, sessions: torch.Tensor, item_embs: torch.Tensor) -> to def forward( self, - sessions: torch.Tensor, # [batch_size, session_max_len] - ) -> tp.Tuple[torch.Tensor, torch.Tensor]: + sessions: torch.Tensor, # [batch_size, session_max_len], + item_ids: tp.Optional[torch.Tensor] = None, + last_n_items: tp.Optional[int] = None, + ) -> tp.Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """ Forward pass to get item and session embeddings. Get item embeddings. @@ -168,11 +179,34 @@ def forward( ---------- sessions : torch.Tensor User sessions in the form of sequences of items ids. + TODO Returns ------- - (torch.Tensor, torch.Tensor) + (torch.Tensor, torch.Tensor, torch.Tensor) """ item_embs = self.item_model.get_all_embeddings() # [n_items + n_item_extra_tokens, n_factors] session_embs = self.encode_sessions(sessions, item_embs) # [batch_size, session_max_len, n_factors] - return item_embs, session_embs + logits = self.similarity_module(session_embs, item_embs, item_ids) + if last_n_items: + logits = logits[:, -last_n_items:, :] + return item_embs, session_embs, logits + + def _recommend_u2i( + self, + user_embs: torch.Tensor, + item_embs: torch.Tensor, + user_ids: InternalIdsArray, + k: int, + sorted_item_ids_to_recommend: InternalIdsArray, + ui_csr_for_filter: tp.Optional[sparse.csr_matrix], + ) -> InternalRecoTriplet: + """Recommend to users.""" + return self.similarity_module._recommend_u2i( + user_embs=user_embs, + item_embs=item_embs, + user_ids=user_ids, + k=k, + sorted_item_ids_to_recommend=sorted_item_ids_to_recommend, + ui_csr_for_filter=ui_csr_for_filter, + ) diff --git a/tests/models/nn/transformers/test_base.py b/tests/models/nn/transformers/test_base.py index b8aef5c7..5fa01465 100644 --- a/tests/models/nn/transformers/test_base.py +++ b/tests/models/nn/transformers/test_base.py @@ -28,7 +28,7 @@ from rectools.models import BERT4RecModel, SASRecModel, load_model from rectools.models.nn.item_net import CatFeaturesItemNet, IdEmbeddingsItemNet from rectools.models.nn.transformers.base import TransformerModelBase -from rectools.models.rank import Distance +from rectools.models.nn.transformers.similarity import SimilarityModuleDistance from tests.models.data import INTERACTIONS from tests.models.utils import assert_save_load_do_not_change_model @@ -113,6 +113,7 @@ def test_save_load_for_unfitted_model( config = { "deterministic": True, "item_net_block_types": (IdEmbeddingsItemNet, CatFeaturesItemNet), + "similarity_module_type": SimilarityModuleDistance, } if not default_trainer: config["get_trainer_func"] = custom_trainer @@ -151,6 +152,7 @@ def test_save_load_for_fitted_model( config = { "deterministic": True, "item_net_block_types": (IdEmbeddingsItemNet, CatFeaturesItemNet), + "similarity_module_type": SimilarityModuleDistance, } if not default_trainer: config["get_trainer_func"] = custom_trainer @@ -172,6 +174,7 @@ 
def test_load_from_checkpoint( "deterministic": True, "item_net_block_types": (IdEmbeddingsItemNet, CatFeaturesItemNet), "get_trainer_func": custom_trainer_ckpt, + "similarity_module_type": SimilarityModuleDistance, } ) dataset = request.getfixturevalue(test_dataset) @@ -198,6 +201,7 @@ def test_raises_when_save_model_loaded_from_checkpoint( "deterministic": True, "item_net_block_types": (IdEmbeddingsItemNet, CatFeaturesItemNet), "get_trainer_func": custom_trainer_ckpt, + "similarity_module_type": SimilarityModuleDistance, } ) model.fit(dataset) @@ -222,6 +226,7 @@ def test_load_weights_from_checkpoint( "deterministic": True, "item_net_block_types": (IdEmbeddingsItemNet, CatFeaturesItemNet), "get_trainer_func": custom_trainer_multiple_ckpt, + "similarity_module_type": SimilarityModuleDistance, } ) model.fit(dataset) @@ -247,6 +252,7 @@ def test_raises_when_load_weights_from_checkpoint_not_fitted_model( "deterministic": True, "item_net_block_types": (IdEmbeddingsItemNet, CatFeaturesItemNet), "get_trainer_func": custom_trainer_ckpt, + "similarity_module_type": SimilarityModuleDistance, } ) model.fit(dataset) @@ -260,6 +266,7 @@ def test_raises_when_load_weights_from_checkpoint_not_fitted_model( "deterministic": True, "item_net_block_types": (IdEmbeddingsItemNet, CatFeaturesItemNet), "get_trainer_func": custom_trainer_ckpt, + "similarity_module_type": SimilarityModuleDistance, } ) with pytest.raises(RuntimeError): @@ -302,6 +309,7 @@ def test_log_metrics( "verbose": verbose, "get_val_mask_func": get_val_mask_func, "loss": loss, + "similarity_module_type": SimilarityModuleDistance, } ) model._trainer = trainer # pylint: disable=protected-access @@ -320,8 +328,13 @@ def test_log_metrics( assert actual_columns == expected_columns @pytest.mark.parametrize("model_cls", (SASRecModel, BERT4RecModel)) - def test_raises_when_incorrect_u2i_dist(self, model_cls: tp.Type[TransformerModelBase], dataset: Dataset) -> None: - model_config = {"u2i_dist": Distance.EUCLIDEAN} + def test_raises_when_incorrect_similarity_dist( + self, model_cls: tp.Type[TransformerModelBase], dataset: Dataset + ) -> None: + model_config = { + "similarity_module_type": SimilarityModuleDistance, + "similarity_module_kwargs": {"dist": "euclidean"}, + } with pytest.raises(ValueError): model = model_cls.from_config(model_config) model.fit(dataset=dataset) diff --git a/tests/models/nn/transformers/test_bert4rec.py b/tests/models/nn/transformers/test_bert4rec.py index b875d2e9..7a75590d 100644 --- a/tests/models/nn/transformers/test_bert4rec.py +++ b/tests/models/nn/transformers/test_bert4rec.py @@ -33,7 +33,7 @@ TransformerLightningModule, ) from rectools.models.nn.transformers.bert4rec import MASKING_VALUE, BERT4RecDataPreparator, ValMaskCallable -from rectools.models.rank import Distance +from rectools.models.nn.transformers.similarity import SimilarityModuleDistance from tests.models.data import DATASET from tests.models.utils import ( assert_default_config_and_default_model_params_are_the_same, @@ -213,7 +213,7 @@ def get_trainer() -> Trainer: ), ), ) - @pytest.mark.parametrize("u2i_dist", (Distance.DOT, Distance.COSINE)) + @pytest.mark.parametrize("u2i_dist", ("dot", "cosine")) def test_u2i( self, dataset_devices: Dataset, @@ -225,7 +225,7 @@ def test_u2i( expected_cpu_2: pd.DataFrame, expected_gpu_1: pd.DataFrame, expected_gpu_2: pd.DataFrame, - u2i_dist: Distance, + u2i_dist: str, ) -> None: if n_devices != 1: pytest.skip("DEBUG: skipping multi-device tests") @@ -252,7 +252,8 @@ def get_trainer() -> Trainer: 
recommend_torch_device=recommend_torch_device, item_net_block_types=(IdEmbeddingsItemNet,), get_trainer_func=get_trainer, - u2i_dist=u2i_dist, + similarity_module_type=SimilarityModuleDistance, + similarity_module_kwargs={"dist": u2i_dist}, ) model.fit(dataset=dataset_devices) users = np.array([10, 30, 40]) @@ -296,14 +297,14 @@ def get_trainer() -> Trainer: ), ), ) - @pytest.mark.parametrize("u2i_dist", (Distance.DOT, Distance.COSINE)) + @pytest.mark.parametrize("u2i_dist", ("dot", "cosine")) def test_u2i_losses( self, dataset_devices: Dataset, loss: str, get_trainer_func: TrainerCallable, expected: pd.DataFrame, - u2i_dist: Distance, + u2i_dist: str, ) -> None: model = BERT4RecModel( n_negatives=2, @@ -319,7 +320,8 @@ def test_u2i_losses( item_net_block_types=(IdEmbeddingsItemNet,), get_trainer_func=get_trainer_func, loss=loss, - u2i_dist=u2i_dist, + similarity_module_type=SimilarityModuleDistance, + similarity_module_kwargs={"dist": u2i_dist}, ) model.fit(dataset=dataset_devices) users = np.array([10, 30, 40]) @@ -373,6 +375,7 @@ def test_with_whitelist( deterministic=True, item_net_block_types=(IdEmbeddingsItemNet,), get_trainer_func=get_trainer_func, + similarity_module_type=SimilarityModuleDistance, ) model.fit(dataset=dataset_devices) users = np.array([10, 30, 40]) @@ -447,6 +450,7 @@ def test_i2i( deterministic=True, item_net_block_types=(IdEmbeddingsItemNet,), get_trainer_func=get_trainer_func, + similarity_module_type=SimilarityModuleDistance, ) model.fit(dataset=dataset) target_items = np.array([12, 14, 17]) @@ -473,6 +477,7 @@ def test_second_fit_refits_model(self, dataset_hot_users_items: Dataset) -> None deterministic=True, item_net_block_types=(IdEmbeddingsItemNet,), get_trainer_func=custom_trainer, + similarity_module_type=SimilarityModuleDistance, ) assert_second_fit_refits_model(model, dataset_hot_users_items, pre_fit_callback=self._seed_everything) @@ -515,6 +520,7 @@ def test_recommend_for_cold_user_with_hot_item( deterministic=True, item_net_block_types=(IdEmbeddingsItemNet,), get_trainer_func=get_trainer_func, + similarity_module_type=SimilarityModuleDistance, ) model.fit(dataset=dataset_devices) users = np.array([20]) @@ -592,6 +598,7 @@ def _collate_fn_train( get_trainer_func=get_trainer_func, data_preparator_type=NextActionDataPreparator, data_preparator_kwargs={"n_last_targets": 1}, + similarity_module_type=SimilarityModuleDistance, ) model.fit(dataset=dataset_devices) @@ -830,13 +837,13 @@ def initial_config(self) -> tp.Dict[str, tp.Any]: "recommend_torch_device": None, "recommend_batch_size": 256, "train_min_user_interactions": 2, - "u2i_dist": Distance.DOT, "item_net_block_types": (IdEmbeddingsItemNet,), "item_net_constructor_type": SumOfEmbeddingsConstructor, "pos_encoding_type": LearnableInversePositionalEncoding, "transformer_layers_type": PreLNTransformerLayers, "data_preparator_type": BERT4RecDataPreparator, "lightning_module_type": TransformerLightningModule, + "similarity_module_type": SimilarityModuleDistance, "mask_prob": 0.15, "get_val_mask_func": leave_one_out_mask, "get_trainer_func": None, @@ -845,6 +852,7 @@ def initial_config(self) -> tp.Dict[str, tp.Any]: "item_net_constructor_kwargs": None, "pos_encoding_kwargs": None, "lightning_module_kwargs": None, + "similarity_module_kwargs": {"dist": "dot"}, } return config @@ -884,7 +892,8 @@ def test_get_config( "data_preparator_type": "rectools.models.nn.transformers.bert4rec.BERT4RecDataPreparator", "lightning_module_type": "rectools.models.nn.transformers.lightning.TransformerLightningModule", 
"get_val_mask_func": "tests.models.nn.transformers.utils.leave_one_out_mask", - "u2i_dist": Distance.DOT.value, + "similarity_module_type": "rectools.models.nn.transformers.similarity.SimilarityModuleDistance", + "similarity_module_kwargs": {"dist": "dot"}, } expected.update(simple_types_params) if use_custom_trainer: diff --git a/tests/models/nn/transformers/test_sasrec.py b/tests/models/nn/transformers/test_sasrec.py index 60d3b9e6..a950e8da 100644 --- a/tests/models/nn/transformers/test_sasrec.py +++ b/tests/models/nn/transformers/test_sasrec.py @@ -35,7 +35,7 @@ TransformerTorchBackbone, ) from rectools.models.nn.transformers.sasrec import SASRecDataPreparator, SASRecTransformerLayers -from rectools.models.rank import Distance +from rectools.models.nn.transformers.similarity import SimilarityModuleDistance from tests.models.data import DATASET from tests.models.utils import ( assert_default_config_and_default_model_params_are_the_same, @@ -244,7 +244,7 @@ def get_trainer() -> Trainer: ), ), ) - @pytest.mark.parametrize("u2i_dist", (Distance.DOT, Distance.COSINE)) + @pytest.mark.parametrize("u2i_dist", ("dot", "cosine")) def test_u2i( self, dataset_devices: Dataset, @@ -255,7 +255,7 @@ def test_u2i( expected_cpu_1: pd.DataFrame, expected_cpu_2: pd.DataFrame, expected_gpu: pd.DataFrame, - u2i_dist: Distance, + u2i_dist: str, ) -> None: if devices != 1: @@ -283,7 +283,8 @@ def get_trainer() -> Trainer: recommend_torch_device=recommend_torch_device, item_net_block_types=(IdEmbeddingsItemNet,), get_trainer_func=get_trainer, - u2i_dist=u2i_dist, + similarity_module_type=SimilarityModuleDistance, + similarity_module_kwargs={"dist": u2i_dist}, ) model.fit(dataset=dataset_devices) users = np.array([10, 30, 40]) @@ -312,7 +313,7 @@ def get_trainer() -> Trainer: Columns.Rank: [1, 2, 1, 2, 3, 1, 2, 3], } ), - Distance.DOT, + "dot", ), ( "gBCE", @@ -323,7 +324,7 @@ def get_trainer() -> Trainer: Columns.Rank: [1, 2, 1, 2, 3, 1, 2, 3], } ), - Distance.DOT, + "dot", ), ( "BCE", @@ -334,7 +335,7 @@ def get_trainer() -> Trainer: Columns.Rank: [1, 2, 1, 2, 3, 1, 2, 3], } ), - Distance.COSINE, + "cosine", ), ( "gBCE", @@ -345,7 +346,7 @@ def get_trainer() -> Trainer: Columns.Rank: [1, 2, 1, 2, 3, 1, 2, 3], } ), - Distance.COSINE, + "cosine", ), ), ) @@ -355,7 +356,7 @@ def test_u2i_losses( loss: str, get_trainer_func: TrainerCallable, expected: pd.DataFrame, - u2i_dist: Distance, + u2i_dist: str, ) -> None: model = SASRecModel( n_negatives=2, @@ -369,7 +370,8 @@ def test_u2i_losses( item_net_block_types=(IdEmbeddingsItemNet,), get_trainer_func=get_trainer_func, loss=loss, - u2i_dist=u2i_dist, + similarity_module_type=SimilarityModuleDistance, + similarity_module_kwargs={"dist": u2i_dist}, ) model.fit(dataset=dataset) users = np.array([10, 30, 40]) @@ -410,6 +412,7 @@ def test_u2i_with_key_and_attn_masks( item_net_block_types=(IdEmbeddingsItemNet,), get_trainer_func=get_trainer_func, use_key_padding_mask=True, + similarity_module_type=SimilarityModuleDistance, ) model.fit(dataset=dataset) users = np.array([10, 30, 40]) @@ -450,6 +453,7 @@ def test_u2i_with_item_features( item_net_block_types=(IdEmbeddingsItemNet, CatFeaturesItemNet), get_trainer_func=get_trainer_func, use_key_padding_mask=True, + similarity_module_type=SimilarityModuleDistance, ) model.fit(dataset=dataset_item_features) users = np.array([10, 30, 40]) @@ -502,6 +506,7 @@ def test_with_whitelist( deterministic=True, item_net_block_types=(IdEmbeddingsItemNet,), get_trainer_func=get_trainer_func, + 
similarity_module_type=SimilarityModuleDistance, ) model.fit(dataset=dataset) users = np.array([10, 30, 40]) @@ -575,6 +580,7 @@ def test_i2i( deterministic=True, item_net_block_types=(IdEmbeddingsItemNet,), get_trainer_func=get_trainer_func, + similarity_module_type=SimilarityModuleDistance, ) model.fit(dataset=dataset) target_items = np.array([12, 14, 17]) @@ -601,6 +607,7 @@ def test_second_fit_refits_model(self, dataset_hot_users_items: Dataset) -> None deterministic=True, item_net_block_types=(IdEmbeddingsItemNet,), get_trainer_func=custom_trainer, + similarity_module_type=SimilarityModuleDistance, ) assert_second_fit_refits_model(model, dataset_hot_users_items, pre_fit_callback=self._seed_everything) @@ -642,6 +649,7 @@ def test_recommend_for_cold_user_with_hot_item( deterministic=True, item_net_block_types=(IdEmbeddingsItemNet,), get_trainer_func=get_trainer_func, + similarity_module_type=SimilarityModuleDistance, ) model.fit(dataset=dataset) users = np.array([20]) @@ -695,6 +703,7 @@ def test_warn_when_hot_user_has_cold_items_in_recommend( deterministic=True, item_net_block_types=(IdEmbeddingsItemNet,), get_trainer_func=get_trainer_func, + similarity_module_type=SimilarityModuleDistance, ) model.fit(dataset=dataset) users = np.array([10, 20, 50]) @@ -718,12 +727,12 @@ def test_warn_when_hot_user_has_cold_items_in_recommend( ) def test_raises_when_loss_is_not_supported(self, dataset: Dataset) -> None: - model = SASRecModel(loss="gbce") + model = SASRecModel(loss="gbce", similarity_module_type=SimilarityModuleDistance) with pytest.raises(ValueError): model.fit(dataset=dataset) def test_torch_model(self, dataset: Dataset) -> None: - model = SASRecModel() + model = SASRecModel(similarity_module_type=SimilarityModuleDistance) model.fit(dataset) assert isinstance(model.torch_model, TransformerTorchBackbone) @@ -925,13 +934,13 @@ def initial_config(self) -> tp.Dict[str, tp.Any]: "recommend_torch_device": None, "recommend_batch_size": 256, "train_min_user_interactions": 2, - "u2i_dist": Distance.DOT, "item_net_block_types": (IdEmbeddingsItemNet,), "item_net_constructor_type": SumOfEmbeddingsConstructor, "pos_encoding_type": LearnableInversePositionalEncoding, "transformer_layers_type": SASRecTransformerLayers, "data_preparator_type": SASRecDataPreparator, "lightning_module_type": TransformerLightningModule, + "similarity_module_type": SimilarityModuleDistance, "get_val_mask_func": leave_one_out_mask, "get_trainer_func": None, "data_preparator_kwargs": None, @@ -939,6 +948,7 @@ def initial_config(self) -> tp.Dict[str, tp.Any]: "item_net_constructor_kwargs": None, "pos_encoding_kwargs": None, "lightning_module_kwargs": None, + "similarity_module_kwargs": {"dist": "dot"}, } return config @@ -978,7 +988,8 @@ def test_get_config( "data_preparator_type": "rectools.models.nn.transformers.sasrec.SASRecDataPreparator", "lightning_module_type": "rectools.models.nn.transformers.lightning.TransformerLightningModule", "get_val_mask_func": "tests.models.nn.transformers.utils.leave_one_out_mask", - "u2i_dist": Distance.DOT.value, + "similarity_module_type": "rectools.models.nn.transformers.similarity.SimilarityModuleDistance", + "similarity_module_kwargs": {"dist": "dot"}, } expected.update(simple_types_params) if use_custom_trainer: @@ -1013,12 +1024,13 @@ def get_reco(model: SASRecModel) -> pd.DataFrame: model_1 = model.from_config(initial_config) reco_1 = get_reco(model_1) - config_1 = model_1.get_config(simple_types=simple_types) + config_1 = model_1.get_config(mode="pydantic", simple_types=False) + 
print(f"CONFIG 1: {config_1}") self._seed_everything() model_2 = model.from_config(config_1) reco_2 = get_reco(model_2) - config_2 = model_2.get_config(simple_types=simple_types) + config_2 = model_2.get_config(mode="pydantic", simple_types=False) assert config_1 == config_2 pd.testing.assert_frame_equal(reco_1, reco_2) From 846ba465c54377e9d6691b4ca99bc6b7b9c8851b Mon Sep 17 00:00:00 2001 From: In48semenov Date: Thu, 27 Mar 2025 12:13:30 +0300 Subject: [PATCH 10/30] Renamed similaty classes, added properties for requiers negs, fixed comments. --- rectools/models/nn/transformers/base.py | 6 +- rectools/models/nn/transformers/bert4rec.py | 6 +- rectools/models/nn/transformers/lightning.py | 58 +++++++++++++------ rectools/models/nn/transformers/sasrec.py | 6 +- rectools/models/nn/transformers/similarity.py | 23 ++++---- .../models/nn/transformers/torch_backbone.py | 35 ++--------- tests/models/nn/transformers/test_base.py | 22 +++---- tests/models/nn/transformers/test_bert4rec.py | 27 +++++---- tests/models/nn/transformers/test_sasrec.py | 40 ++++++------- 9 files changed, 108 insertions(+), 115 deletions(-) diff --git a/rectools/models/nn/transformers/base.py b/rectools/models/nn/transformers/base.py index 27192738..775bbbcf 100644 --- a/rectools/models/nn/transformers/base.py +++ b/rectools/models/nn/transformers/base.py @@ -46,7 +46,7 @@ PreLNTransformerLayers, TransformerLayersBase, ) -from .similarity import SimilarityModuleBase +from .similarity import DistanceSimilarityModule, SimilarityModuleBase from .torch_backbone import TransformerTorchBackbone InitKwargs = tp.Dict[str, tp.Any] @@ -194,7 +194,7 @@ class TransformerModelConfig(ModelConfig): pos_encoding_type: PositionalEncodingType = LearnableInversePositionalEncoding transformer_layers_type: TransformerLayersType = PreLNTransformerLayers lightning_module_type: TransformerLightningModuleType = TransformerLightningModule - similarity_module_type: SimilarityModuleType = SimilarityModuleBase + similarity_module_type: SimilarityModuleType = DistanceSimilarityModule get_val_mask_func: tp.Optional[ValMaskCallableSerialized] = None get_trainer_func: tp.Optional[TrainerCallableSerialized] = None data_preparator_kwargs: tp.Optional[InitKwargs] = None @@ -250,7 +250,7 @@ def __init__( # pylint: disable=too-many-arguments, too-many-locals item_net_constructor_type: tp.Type[ItemNetConstructorBase] = SumOfEmbeddingsConstructor, pos_encoding_type: tp.Type[PositionalEncodingBase] = LearnableInversePositionalEncoding, lightning_module_type: tp.Type[TransformerLightningModuleBase] = TransformerLightningModule, - similarity_module_type: tp.Type[SimilarityModuleBase] = SimilarityModuleBase, + similarity_module_type: tp.Type[SimilarityModuleBase] = DistanceSimilarityModule, get_val_mask_func: tp.Optional[ValMaskCallable] = None, get_trainer_func: tp.Optional[TrainerCallable] = None, data_preparator_kwargs: tp.Optional[InitKwargs] = None, diff --git a/rectools/models/nn/transformers/bert4rec.py b/rectools/models/nn/transformers/bert4rec.py index f88b7833..c2f2d814 100644 --- a/rectools/models/nn/transformers/bert4rec.py +++ b/rectools/models/nn/transformers/bert4rec.py @@ -44,7 +44,7 @@ PreLNTransformerLayers, TransformerLayersBase, ) -from .similarity import SimilarityModuleBase +from .similarity import DistanceSimilarityModule, SimilarityModuleBase class BERT4RecDataPreparator(TransformerDataPreparatorBase): @@ -257,7 +257,7 @@ class BERT4RecModel(TransformerModelBase[BERT4RecModelConfig]): Type of data preparator used for dataset processing and 
dataloader creation. lightning_module_type : type(TransformerLightningModuleBase), default `TransformerLightningModule` Type of lightning module defining training procedure. - similarity_module_type : type(SimilarityModuleBase), default `SimilarityModuleBase` + similarity_module_type : type(SimilarityModuleBase), default `DistanceSimilarityModule` Type of similarity module. get_val_mask_func : Callable, default ``None`` Function to get validation mask. @@ -326,7 +326,7 @@ def __init__( # pylint: disable=too-many-arguments, too-many-locals transformer_layers_type: tp.Type[TransformerLayersBase] = PreLNTransformerLayers, data_preparator_type: tp.Type[TransformerDataPreparatorBase] = BERT4RecDataPreparator, lightning_module_type: tp.Type[TransformerLightningModuleBase] = TransformerLightningModule, - similarity_module_type: tp.Type[SimilarityModuleBase] = SimilarityModuleBase, + similarity_module_type: tp.Type[SimilarityModuleBase] = DistanceSimilarityModule, get_val_mask_func: tp.Optional[ValMaskCallable] = None, get_trainer_func: tp.Optional[TrainerCallable] = None, recommend_batch_size: int = 256, diff --git a/rectools/models/nn/transformers/lightning.py b/rectools/models/nn/transformers/lightning.py index fc8316e0..5949abed 100644 --- a/rectools/models/nn/transformers/lightning.py +++ b/rectools/models/nn/transformers/lightning.py @@ -94,6 +94,20 @@ def __init__( self.save_hyperparameters(ignore=["torch_model", "data_preparator"]) + @property + def requires_negatives(self) -> tp.Optional[bool]: + """Indicator for determining the need for negatives for loss functions.""" + if self.loss == "softmax": + return False + + if self.loss == "BCE": + return True + + if self.loss == "gBCE": + return True + + return None + def configure_optimizers(self) -> torch.optim.Adam: """Choose what optimizers and learning-rate schedulers to use in optimization""" optimizer = torch.optim.Adam(self.torch_model.parameters(), lr=self.lr, betas=self.adam_betas) @@ -150,18 +164,20 @@ def on_train_start(self) -> None: def training_step(self, batch: tp.Dict[str, torch.Tensor], batch_idx: int) -> torch.Tensor: """Training step.""" x, y, w = batch["x"], batch["y"], batch["yw"] + + if self.requires_negatives: + negatives = batch["negatives"] + # [batch_size, session_max_len, n_negatives + 1] + pos_neg = torch.cat([y.unsqueeze(-1), negatives], dim=-1) + logits = self.torch_model(sessions=x, item_ids=pos_neg) + elif self.requires_negatives is not None: + logits = self.torch_model(sessions=x) + if self.loss == "softmax": - _, _, logits = self.torch_model(sessions=x) loss = self._calc_softmax_loss(logits, y, w) elif self.loss == "BCE": - negatives = batch["negatives"] - pos_neg = torch.cat([y.unsqueeze(-1), negatives], dim=-1) # [batch_size, session_max_len, n_negatives + 1] - _, _, logits = self.torch_model(sessions=x, item_ids=pos_neg) loss = self._calc_bce_loss(logits, y, w) elif self.loss == "gBCE": - negatives = batch["negatives"] - pos_neg = torch.cat([y.unsqueeze(-1), negatives], dim=-1) # [batch_size, session_max_len, n_negatives + 1] - _, _, logits = self.torch_model(sessions=x, item_ids=pos_neg) loss = self._calc_gbce_loss(logits, y, w, negatives) else: loss = self._calc_custom_loss(batch, batch_idx) @@ -190,23 +206,29 @@ def validation_step(self, batch: tp.Dict[str, torch.Tensor], batch_idx: int) -> # y: [batch_size, 1] # yw: [batch_size, 1] x, y, w = batch["x"], batch["y"], batch["yw"] + + if self.requires_negatives: + negatives = batch["negatives"] + # [batch_size, session_max_len, n_negatives + 1] + pos_neg = 
torch.cat([y.unsqueeze(-1), negatives], dim=-1) + logits = self.torch_model(sessions=x, item_ids=pos_neg) + elif self.requires_negatives is not None: + logits = self.torch_model(sessions=x) + + if self.requires_negatives is not None: + logits = logits[:, -1:, :] + outputs = {} if self.loss == "softmax": - _, _, logits = self.torch_model(sessions=x, last_n_items=1) outputs["loss"] = self._calc_softmax_loss(logits, y, w) outputs["logits"] = logits.squeeze() elif self.loss == "BCE": negatives = batch["negatives"] - pos_neg = torch.cat([y.unsqueeze(-1), negatives], dim=-1) # [batch_size, session_max_len, n_negatives + 1] - _, _, pos_neg_logits = self.torch_model(sessions=x, item_ids=pos_neg, last_n_items=1) - outputs["loss"] = self._calc_bce_loss(pos_neg_logits, y, w) - outputs["pos_neg_logits"] = pos_neg_logits.squeeze() + outputs["loss"] = self._calc_bce_loss(logits, y, w) + outputs["pos_neg_logits"] = logits.squeeze() elif self.loss == "gBCE": - negatives = batch["negatives"] - pos_neg = torch.cat([y.unsqueeze(-1), negatives], dim=-1) # [batch_size, session_max_len, n_negatives + 1] - _, _, pos_neg_logits = self.torch_model(sessions=x, item_ids=pos_neg, last_n_items=1) - outputs["loss"] = self._calc_gbce_loss(pos_neg_logits, y, w, negatives) - outputs["pos_neg_logits"] = pos_neg_logits.squeeze() + outputs["loss"] = self._calc_gbce_loss(logits, y, w, negatives) + outputs["pos_neg_logits"] = logits.squeeze() else: outputs = self._calc_custom_loss_outputs(batch, batch_idx) # pragma: no cover @@ -328,7 +350,7 @@ def _recommend_u2i( user_embs, item_embs = self._get_user_item_embeddings(recommend_dataloader, torch_device) - all_user_ids, all_reco_ids, all_scores = self.torch_model._recommend_u2i( + all_user_ids, all_reco_ids, all_scores = self.torch_model.similarity_module._recommend_u2i( user_embs=user_embs, item_embs=item_embs, user_ids=user_ids, diff --git a/rectools/models/nn/transformers/sasrec.py b/rectools/models/nn/transformers/sasrec.py index 2b35c994..4bc36907 100644 --- a/rectools/models/nn/transformers/sasrec.py +++ b/rectools/models/nn/transformers/sasrec.py @@ -44,7 +44,7 @@ PositionalEncodingBase, TransformerLayersBase, ) -from .similarity import SimilarityModuleBase +from .similarity import DistanceSimilarityModule, SimilarityModuleBase class SASRecDataPreparator(TransformerDataPreparatorBase): @@ -337,7 +337,7 @@ class SASRecModel(TransformerModelBase[SASRecModelConfig]): Type of data preparator used for dataset processing and dataloader creation. lightning_module_type : type(TransformerLightningModuleBase), default `TransformerLightningModule` Type of lightning module defining training procedure. - similarity_module_type : type(SimilarityModuleBase), default `SimilarityModuleBase` + similarity_module_type : type(SimilarityModuleBase), default `DistanceSimilarityModule` Type of similarity module. get_val_mask_func : Callable, default ``None`` Function to get validation mask. 
@@ -405,7 +405,7 @@ def __init__( # pylint: disable=too-many-arguments, too-many-locals transformer_layers_type: tp.Type[TransformerLayersBase] = SASRecTransformerLayers, # SASRec authors net data_preparator_type: tp.Type[TransformerDataPreparatorBase] = SASRecDataPreparator, lightning_module_type: tp.Type[TransformerLightningModuleBase] = TransformerLightningModule, - similarity_module_type: tp.Type[SimilarityModuleBase] = SimilarityModuleBase, + similarity_module_type: tp.Type[SimilarityModuleBase] = DistanceSimilarityModule, get_val_mask_func: tp.Optional[ValMaskCallable] = None, get_trainer_func: tp.Optional[TrainerCallable] = None, recommend_batch_size: int = 256, diff --git a/rectools/models/nn/transformers/similarity.py b/rectools/models/nn/transformers/similarity.py index e0b41cf6..0f73aef6 100644 --- a/rectools/models/nn/transformers/similarity.py +++ b/rectools/models/nn/transformers/similarity.py @@ -2,7 +2,6 @@ import numpy as np import torch -import torch.nn as nn from scipy import sparse from rectools.models.base import InternalRecoTriplet @@ -10,7 +9,7 @@ from rectools.types import InternalIdsArray -class SimilarityModuleBase(nn.Module): +class SimilarityModuleBase(torch.nn.Module): """Similarity module base.""" def _get_full_catalog_logits(self, session_embs: torch.Tensor, item_embs: torch.Tensor) -> torch.Tensor: @@ -40,19 +39,18 @@ def _recommend_u2i( raise NotImplementedError() -class SimilarityModuleDistance(SimilarityModuleBase): - """Similarity module distance.""" +class DistanceSimilarityModule(SimilarityModuleBase): + """Distandce similarity module.""" dist_available: tp.List[str] = ["dot", "cosine"] - # dist_available_values: tp.List[int] = [Distance.DOT.value, Distance.COSINE.value] - epsilon_cosine_dist: float = 1e-8 + epsilon_cosine_dist: torch.Tensor = torch.tensor([1e-8]) - def __init__(self, dist: str = "dot") -> None: + def __init__(self, distance: str = "dot") -> None: super().__init__() - if dist not in self.dist_available: + if distance not in self.dist_available: raise ValueError("`dist` can only be either `dot` or `cosine`.") - self.dist = dist + self.distance = distance def _get_full_catalog_logits(self, session_embs: torch.Tensor, item_embs: torch.Tensor) -> torch.Tensor: logits = session_embs @ item_embs.T @@ -67,14 +65,15 @@ def _get_pos_neg_logits( return logits def _get_embeddings_norm(self, embeddings: torch.Tensor) -> torch.Tensor: - embeddings = embeddings / (torch.norm(embeddings, p=2, dim=1).unsqueeze(dim=1) + self.epsilon_cosine_dist) + embedding_norm = torch.norm(embeddings, p=2, dim=1).unsqueeze(dim=1) + embeddings = embeddings / torch.max(embedding_norm, self.epsilon_cosine_dist) return embeddings def forward( self, session_embs: torch.Tensor, item_embs: torch.Tensor, item_ids: tp.Optional[torch.Tensor] = None ) -> torch.Tensor: """Forward pass to get logits.""" - if self.dist == "cosine": + if self.distance == "cosine": session_embs = self._get_embeddings_norm(session_embs) item_embs = self._get_embeddings_norm(item_embs) @@ -95,7 +94,7 @@ def _recommend_u2i( ) -> InternalRecoTriplet: """Recommend to users.""" ranker = TorchRanker( - distance=Distance.DOT if self.dist == "dot" else Distance.COSINE, + distance=Distance.DOT if self.distance == "dot" else Distance.COSINE, device=item_embs.device, subjects_factors=user_embs[user_ids], objects_factors=item_embs, diff --git a/rectools/models/nn/transformers/torch_backbone.py b/rectools/models/nn/transformers/torch_backbone.py index 4169ce2b..ea55962f 100644 --- 
a/rectools/models/nn/transformers/torch_backbone.py +++ b/rectools/models/nn/transformers/torch_backbone.py @@ -15,10 +15,6 @@ import typing as tp import torch -from scipy import sparse - -from rectools.models.base import InternalRecoTriplet -from rectools.types import InternalIdsArray from ..item_net import ItemNetBase from .net_blocks import PositionalEncodingBase, TransformerLayersBase @@ -168,8 +164,7 @@ def forward( self, sessions: torch.Tensor, # [batch_size, session_max_len], item_ids: tp.Optional[torch.Tensor] = None, - last_n_items: tp.Optional[int] = None, - ) -> tp.Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + ) -> torch.Tensor: """ Forward pass to get item and session embeddings. Get item embeddings. @@ -179,34 +174,14 @@ def forward( ---------- sessions : torch.Tensor User sessions in the form of sequences of items ids. - TODO + item_ids : optional(torch.Tensor), default ``None`` + Defined item ids for similarity calculation. Returns ------- - (torch.Tensor, torch.Tensor, torch.Tensor) + torch.Tensor """ item_embs = self.item_model.get_all_embeddings() # [n_items + n_item_extra_tokens, n_factors] session_embs = self.encode_sessions(sessions, item_embs) # [batch_size, session_max_len, n_factors] logits = self.similarity_module(session_embs, item_embs, item_ids) - if last_n_items: - logits = logits[:, -last_n_items:, :] - return item_embs, session_embs, logits - - def _recommend_u2i( - self, - user_embs: torch.Tensor, - item_embs: torch.Tensor, - user_ids: InternalIdsArray, - k: int, - sorted_item_ids_to_recommend: InternalIdsArray, - ui_csr_for_filter: tp.Optional[sparse.csr_matrix], - ) -> InternalRecoTriplet: - """Recommend to users.""" - return self.similarity_module._recommend_u2i( - user_embs=user_embs, - item_embs=item_embs, - user_ids=user_ids, - k=k, - sorted_item_ids_to_recommend=sorted_item_ids_to_recommend, - ui_csr_for_filter=ui_csr_for_filter, - ) + return logits diff --git a/tests/models/nn/transformers/test_base.py b/tests/models/nn/transformers/test_base.py index 5fa01465..16ef53df 100644 --- a/tests/models/nn/transformers/test_base.py +++ b/tests/models/nn/transformers/test_base.py @@ -28,7 +28,7 @@ from rectools.models import BERT4RecModel, SASRecModel, load_model from rectools.models.nn.item_net import CatFeaturesItemNet, IdEmbeddingsItemNet from rectools.models.nn.transformers.base import TransformerModelBase -from rectools.models.nn.transformers.similarity import SimilarityModuleDistance +from rectools.models.nn.transformers.similarity import DistanceSimilarityModule from tests.models.data import INTERACTIONS from tests.models.utils import assert_save_load_do_not_change_model @@ -113,7 +113,7 @@ def test_save_load_for_unfitted_model( config = { "deterministic": True, "item_net_block_types": (IdEmbeddingsItemNet, CatFeaturesItemNet), - "similarity_module_type": SimilarityModuleDistance, + "similarity_module_type": DistanceSimilarityModule, } if not default_trainer: config["get_trainer_func"] = custom_trainer @@ -152,7 +152,7 @@ def test_save_load_for_fitted_model( config = { "deterministic": True, "item_net_block_types": (IdEmbeddingsItemNet, CatFeaturesItemNet), - "similarity_module_type": SimilarityModuleDistance, + "similarity_module_type": DistanceSimilarityModule, } if not default_trainer: config["get_trainer_func"] = custom_trainer @@ -174,7 +174,7 @@ def test_load_from_checkpoint( "deterministic": True, "item_net_block_types": (IdEmbeddingsItemNet, CatFeaturesItemNet), "get_trainer_func": custom_trainer_ckpt, - "similarity_module_type": 
SimilarityModuleDistance, + "similarity_module_type": DistanceSimilarityModule, } ) dataset = request.getfixturevalue(test_dataset) @@ -201,7 +201,7 @@ def test_raises_when_save_model_loaded_from_checkpoint( "deterministic": True, "item_net_block_types": (IdEmbeddingsItemNet, CatFeaturesItemNet), "get_trainer_func": custom_trainer_ckpt, - "similarity_module_type": SimilarityModuleDistance, + "similarity_module_type": DistanceSimilarityModule, } ) model.fit(dataset) @@ -226,7 +226,7 @@ def test_load_weights_from_checkpoint( "deterministic": True, "item_net_block_types": (IdEmbeddingsItemNet, CatFeaturesItemNet), "get_trainer_func": custom_trainer_multiple_ckpt, - "similarity_module_type": SimilarityModuleDistance, + "similarity_module_type": DistanceSimilarityModule, } ) model.fit(dataset) @@ -252,7 +252,7 @@ def test_raises_when_load_weights_from_checkpoint_not_fitted_model( "deterministic": True, "item_net_block_types": (IdEmbeddingsItemNet, CatFeaturesItemNet), "get_trainer_func": custom_trainer_ckpt, - "similarity_module_type": SimilarityModuleDistance, + "similarity_module_type": DistanceSimilarityModule, } ) model.fit(dataset) @@ -266,7 +266,7 @@ def test_raises_when_load_weights_from_checkpoint_not_fitted_model( "deterministic": True, "item_net_block_types": (IdEmbeddingsItemNet, CatFeaturesItemNet), "get_trainer_func": custom_trainer_ckpt, - "similarity_module_type": SimilarityModuleDistance, + "similarity_module_type": DistanceSimilarityModule, } ) with pytest.raises(RuntimeError): @@ -309,7 +309,7 @@ def test_log_metrics( "verbose": verbose, "get_val_mask_func": get_val_mask_func, "loss": loss, - "similarity_module_type": SimilarityModuleDistance, + "similarity_module_type": DistanceSimilarityModule, } ) model._trainer = trainer # pylint: disable=protected-access @@ -332,8 +332,8 @@ def test_raises_when_incorrect_similarity_dist( self, model_cls: tp.Type[TransformerModelBase], dataset: Dataset ) -> None: model_config = { - "similarity_module_type": SimilarityModuleDistance, - "similarity_module_kwargs": {"dist": "euclidean"}, + "similarity_module_type": DistanceSimilarityModule, + "similarity_module_kwargs": {"distance": "euclidean"}, } with pytest.raises(ValueError): model = model_cls.from_config(model_config) diff --git a/tests/models/nn/transformers/test_bert4rec.py b/tests/models/nn/transformers/test_bert4rec.py index 7a75590d..7480d1d6 100644 --- a/tests/models/nn/transformers/test_bert4rec.py +++ b/tests/models/nn/transformers/test_bert4rec.py @@ -33,7 +33,7 @@ TransformerLightningModule, ) from rectools.models.nn.transformers.bert4rec import MASKING_VALUE, BERT4RecDataPreparator, ValMaskCallable -from rectools.models.nn.transformers.similarity import SimilarityModuleDistance +from rectools.models.nn.transformers.similarity import DistanceSimilarityModule from tests.models.data import DATASET from tests.models.utils import ( assert_default_config_and_default_model_params_are_the_same, @@ -252,8 +252,8 @@ def get_trainer() -> Trainer: recommend_torch_device=recommend_torch_device, item_net_block_types=(IdEmbeddingsItemNet,), get_trainer_func=get_trainer, - similarity_module_type=SimilarityModuleDistance, - similarity_module_kwargs={"dist": u2i_dist}, + similarity_module_type=DistanceSimilarityModule, + similarity_module_kwargs={"distance": u2i_dist}, ) model.fit(dataset=dataset_devices) users = np.array([10, 30, 40]) @@ -320,8 +320,8 @@ def test_u2i_losses( item_net_block_types=(IdEmbeddingsItemNet,), get_trainer_func=get_trainer_func, loss=loss, - 
similarity_module_type=SimilarityModuleDistance, - similarity_module_kwargs={"dist": u2i_dist}, + similarity_module_type=DistanceSimilarityModule, + similarity_module_kwargs={"distance": u2i_dist}, ) model.fit(dataset=dataset_devices) users = np.array([10, 30, 40]) @@ -375,7 +375,7 @@ def test_with_whitelist( deterministic=True, item_net_block_types=(IdEmbeddingsItemNet,), get_trainer_func=get_trainer_func, - similarity_module_type=SimilarityModuleDistance, + similarity_module_type=DistanceSimilarityModule, ) model.fit(dataset=dataset_devices) users = np.array([10, 30, 40]) @@ -450,7 +450,7 @@ def test_i2i( deterministic=True, item_net_block_types=(IdEmbeddingsItemNet,), get_trainer_func=get_trainer_func, - similarity_module_type=SimilarityModuleDistance, + similarity_module_type=DistanceSimilarityModule, ) model.fit(dataset=dataset) target_items = np.array([12, 14, 17]) @@ -477,7 +477,7 @@ def test_second_fit_refits_model(self, dataset_hot_users_items: Dataset) -> None deterministic=True, item_net_block_types=(IdEmbeddingsItemNet,), get_trainer_func=custom_trainer, - similarity_module_type=SimilarityModuleDistance, + similarity_module_type=DistanceSimilarityModule, ) assert_second_fit_refits_model(model, dataset_hot_users_items, pre_fit_callback=self._seed_everything) @@ -520,7 +520,7 @@ def test_recommend_for_cold_user_with_hot_item( deterministic=True, item_net_block_types=(IdEmbeddingsItemNet,), get_trainer_func=get_trainer_func, - similarity_module_type=SimilarityModuleDistance, + similarity_module_type=DistanceSimilarityModule, ) model.fit(dataset=dataset_devices) users = np.array([20]) @@ -598,7 +598,7 @@ def _collate_fn_train( get_trainer_func=get_trainer_func, data_preparator_type=NextActionDataPreparator, data_preparator_kwargs={"n_last_targets": 1}, - similarity_module_type=SimilarityModuleDistance, + similarity_module_type=DistanceSimilarityModule, ) model.fit(dataset=dataset_devices) @@ -843,7 +843,7 @@ def initial_config(self) -> tp.Dict[str, tp.Any]: "transformer_layers_type": PreLNTransformerLayers, "data_preparator_type": BERT4RecDataPreparator, "lightning_module_type": TransformerLightningModule, - "similarity_module_type": SimilarityModuleDistance, + "similarity_module_type": DistanceSimilarityModule, "mask_prob": 0.15, "get_val_mask_func": leave_one_out_mask, "get_trainer_func": None, @@ -852,7 +852,7 @@ def initial_config(self) -> tp.Dict[str, tp.Any]: "item_net_constructor_kwargs": None, "pos_encoding_kwargs": None, "lightning_module_kwargs": None, - "similarity_module_kwargs": {"dist": "dot"}, + "similarity_module_kwargs": None, } return config @@ -892,8 +892,7 @@ def test_get_config( "data_preparator_type": "rectools.models.nn.transformers.bert4rec.BERT4RecDataPreparator", "lightning_module_type": "rectools.models.nn.transformers.lightning.TransformerLightningModule", "get_val_mask_func": "tests.models.nn.transformers.utils.leave_one_out_mask", - "similarity_module_type": "rectools.models.nn.transformers.similarity.SimilarityModuleDistance", - "similarity_module_kwargs": {"dist": "dot"}, + "similarity_module_type": "rectools.models.nn.transformers.similarity.DistanceSimilarityModule", } expected.update(simple_types_params) if use_custom_trainer: diff --git a/tests/models/nn/transformers/test_sasrec.py b/tests/models/nn/transformers/test_sasrec.py index a950e8da..189d4e53 100644 --- a/tests/models/nn/transformers/test_sasrec.py +++ b/tests/models/nn/transformers/test_sasrec.py @@ -35,7 +35,7 @@ TransformerTorchBackbone, ) from rectools.models.nn.transformers.sasrec 
import SASRecDataPreparator, SASRecTransformerLayers -from rectools.models.nn.transformers.similarity import SimilarityModuleDistance +from rectools.models.nn.transformers.similarity import DistanceSimilarityModule from tests.models.data import DATASET from tests.models.utils import ( assert_default_config_and_default_model_params_are_the_same, @@ -283,8 +283,8 @@ def get_trainer() -> Trainer: recommend_torch_device=recommend_torch_device, item_net_block_types=(IdEmbeddingsItemNet,), get_trainer_func=get_trainer, - similarity_module_type=SimilarityModuleDistance, - similarity_module_kwargs={"dist": u2i_dist}, + similarity_module_type=DistanceSimilarityModule, + similarity_module_kwargs={"distance": u2i_dist}, ) model.fit(dataset=dataset_devices) users = np.array([10, 30, 40]) @@ -370,8 +370,8 @@ def test_u2i_losses( item_net_block_types=(IdEmbeddingsItemNet,), get_trainer_func=get_trainer_func, loss=loss, - similarity_module_type=SimilarityModuleDistance, - similarity_module_kwargs={"dist": u2i_dist}, + similarity_module_type=DistanceSimilarityModule, + similarity_module_kwargs={"distance": u2i_dist}, ) model.fit(dataset=dataset) users = np.array([10, 30, 40]) @@ -412,7 +412,7 @@ def test_u2i_with_key_and_attn_masks( item_net_block_types=(IdEmbeddingsItemNet,), get_trainer_func=get_trainer_func, use_key_padding_mask=True, - similarity_module_type=SimilarityModuleDistance, + similarity_module_type=DistanceSimilarityModule, ) model.fit(dataset=dataset) users = np.array([10, 30, 40]) @@ -453,7 +453,7 @@ def test_u2i_with_item_features( item_net_block_types=(IdEmbeddingsItemNet, CatFeaturesItemNet), get_trainer_func=get_trainer_func, use_key_padding_mask=True, - similarity_module_type=SimilarityModuleDistance, + similarity_module_type=DistanceSimilarityModule, ) model.fit(dataset=dataset_item_features) users = np.array([10, 30, 40]) @@ -506,7 +506,7 @@ def test_with_whitelist( deterministic=True, item_net_block_types=(IdEmbeddingsItemNet,), get_trainer_func=get_trainer_func, - similarity_module_type=SimilarityModuleDistance, + similarity_module_type=DistanceSimilarityModule, ) model.fit(dataset=dataset) users = np.array([10, 30, 40]) @@ -580,7 +580,7 @@ def test_i2i( deterministic=True, item_net_block_types=(IdEmbeddingsItemNet,), get_trainer_func=get_trainer_func, - similarity_module_type=SimilarityModuleDistance, + similarity_module_type=DistanceSimilarityModule, ) model.fit(dataset=dataset) target_items = np.array([12, 14, 17]) @@ -607,7 +607,7 @@ def test_second_fit_refits_model(self, dataset_hot_users_items: Dataset) -> None deterministic=True, item_net_block_types=(IdEmbeddingsItemNet,), get_trainer_func=custom_trainer, - similarity_module_type=SimilarityModuleDistance, + similarity_module_type=DistanceSimilarityModule, ) assert_second_fit_refits_model(model, dataset_hot_users_items, pre_fit_callback=self._seed_everything) @@ -649,7 +649,7 @@ def test_recommend_for_cold_user_with_hot_item( deterministic=True, item_net_block_types=(IdEmbeddingsItemNet,), get_trainer_func=get_trainer_func, - similarity_module_type=SimilarityModuleDistance, + similarity_module_type=DistanceSimilarityModule, ) model.fit(dataset=dataset) users = np.array([20]) @@ -703,7 +703,7 @@ def test_warn_when_hot_user_has_cold_items_in_recommend( deterministic=True, item_net_block_types=(IdEmbeddingsItemNet,), get_trainer_func=get_trainer_func, - similarity_module_type=SimilarityModuleDistance, + similarity_module_type=DistanceSimilarityModule, ) model.fit(dataset=dataset) users = np.array([10, 20, 50]) @@ -727,12 
+727,12 @@ def test_warn_when_hot_user_has_cold_items_in_recommend( ) def test_raises_when_loss_is_not_supported(self, dataset: Dataset) -> None: - model = SASRecModel(loss="gbce", similarity_module_type=SimilarityModuleDistance) + model = SASRecModel(loss="gbce", similarity_module_type=DistanceSimilarityModule) with pytest.raises(ValueError): model.fit(dataset=dataset) def test_torch_model(self, dataset: Dataset) -> None: - model = SASRecModel(similarity_module_type=SimilarityModuleDistance) + model = SASRecModel(similarity_module_type=DistanceSimilarityModule) model.fit(dataset) assert isinstance(model.torch_model, TransformerTorchBackbone) @@ -940,7 +940,7 @@ def initial_config(self) -> tp.Dict[str, tp.Any]: "transformer_layers_type": SASRecTransformerLayers, "data_preparator_type": SASRecDataPreparator, "lightning_module_type": TransformerLightningModule, - "similarity_module_type": SimilarityModuleDistance, + "similarity_module_type": DistanceSimilarityModule, "get_val_mask_func": leave_one_out_mask, "get_trainer_func": None, "data_preparator_kwargs": None, @@ -948,7 +948,7 @@ def initial_config(self) -> tp.Dict[str, tp.Any]: "item_net_constructor_kwargs": None, "pos_encoding_kwargs": None, "lightning_module_kwargs": None, - "similarity_module_kwargs": {"dist": "dot"}, + "similarity_module_kwargs": None, } return config @@ -988,8 +988,7 @@ def test_get_config( "data_preparator_type": "rectools.models.nn.transformers.sasrec.SASRecDataPreparator", "lightning_module_type": "rectools.models.nn.transformers.lightning.TransformerLightningModule", "get_val_mask_func": "tests.models.nn.transformers.utils.leave_one_out_mask", - "similarity_module_type": "rectools.models.nn.transformers.similarity.SimilarityModuleDistance", - "similarity_module_kwargs": {"dist": "dot"}, + "similarity_module_type": "rectools.models.nn.transformers.similarity.DistanceSimilarityModule", } expected.update(simple_types_params) if use_custom_trainer: @@ -1024,13 +1023,12 @@ def get_reco(model: SASRecModel) -> pd.DataFrame: model_1 = model.from_config(initial_config) reco_1 = get_reco(model_1) - config_1 = model_1.get_config(mode="pydantic", simple_types=False) - print(f"CONFIG 1: {config_1}") + config_1 = model_1.get_config(simple_types=simple_types) self._seed_everything() model_2 = model.from_config(config_1) reco_2 = get_reco(model_2) - config_2 = model_2.get_config(mode="pydantic", simple_types=False) + config_2 = model_2.get_config(simple_types=simple_types) assert config_1 == config_2 pd.testing.assert_frame_equal(reco_1, reco_2) From 7e002979fc39775ec41d6c55db035507b2addd36 Mon Sep 17 00:00:00 2001 From: In48semenov Date: Thu, 27 Mar 2025 12:13:51 +0300 Subject: [PATCH 11/30] Added epsilon for cosine similarity. --- rectools/models/rank/rank_torch.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/rectools/models/rank/rank_torch.py b/rectools/models/rank/rank_torch.py index ed091acc..12ccca19 100644 --- a/rectools/models/rank/rank_torch.py +++ b/rectools/models/rank/rank_torch.py @@ -54,6 +54,8 @@ class TorchRanker: Conversion is skipped if provided dtype is ``None``.
""" + epsilon_cosine_dist: torch.Tensor = torch.tensor([1e-8]) + def __init__( self, distance: Distance, @@ -194,8 +196,8 @@ def _euclid_score(self, user_embs: torch.Tensor, item_embs: torch.Tensor) -> tor return torch.cdist(user_embs.unsqueeze(0), item_embs.unsqueeze(0)).squeeze(0) def _cosine_score(self, user_embs: torch.Tensor, item_embs: torch.Tensor) -> torch.Tensor: - user_embs = user_embs / torch.norm(user_embs, p=2, dim=1).unsqueeze(dim=1) - item_embs = item_embs / torch.norm(item_embs, p=2, dim=1).unsqueeze(dim=1) + user_embs = user_embs / torch.max(torch.norm(user_embs, p=2, dim=1).unsqueeze(dim=1), self.epsilon_cosine_dist) + item_embs = item_embs / torch.max(torch.norm(item_embs, p=2, dim=1).unsqueeze(dim=1), self.epsilon_cosine_dist) return user_embs @ item_embs.T From 274d1c9fb2ce51d64824eb6ff955513d37f31ffd Mon Sep 17 00:00:00 2001 From: In48semenov Date: Thu, 27 Mar 2025 12:32:16 +0300 Subject: [PATCH 12/30] Added pylint disable. --- rectools/models/nn/transformers/lightning.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/rectools/models/nn/transformers/lightning.py b/rectools/models/nn/transformers/lightning.py index 5949abed..5e6d6ed6 100644 --- a/rectools/models/nn/transformers/lightning.py +++ b/rectools/models/nn/transformers/lightning.py @@ -350,13 +350,15 @@ def _recommend_u2i( user_embs, item_embs = self._get_user_item_embeddings(recommend_dataloader, torch_device) - all_user_ids, all_reco_ids, all_scores = self.torch_model.similarity_module._recommend_u2i( - user_embs=user_embs, - item_embs=item_embs, - user_ids=user_ids, - k=k, - sorted_item_ids_to_recommend=sorted_item_ids_to_recommend, - ui_csr_for_filter=ui_csr_for_filter, + all_user_ids, all_reco_ids, all_scores = ( + self.torch_model.similarity_module._recommend_u2i( # pylint: disable=protected-access + user_embs=user_embs, + item_embs=item_embs, + user_ids=user_ids, + k=k, + sorted_item_ids_to_recommend=sorted_item_ids_to_recommend, + ui_csr_for_filter=ui_csr_for_filter, + ) ) return all_user_ids, all_reco_ids, all_scores From d033e6d42dc4873738dd2c2b99d5ddd6160bda03 Mon Sep 17 00:00:00 2001 From: In48semenov Date: Thu, 27 Mar 2025 13:28:50 +0300 Subject: [PATCH 13/30] Changed Enum Distance. 
--- rectools/models/nn/transformers/similarity.py | 8 ++++---- rectools/models/rank/rank.py | 8 ++++---- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/rectools/models/nn/transformers/similarity.py b/rectools/models/nn/transformers/similarity.py index 0f73aef6..b43727cf 100644 --- a/rectools/models/nn/transformers/similarity.py +++ b/rectools/models/nn/transformers/similarity.py @@ -42,7 +42,7 @@ def _recommend_u2i( class DistanceSimilarityModule(SimilarityModuleBase): """Distandce similarity module.""" - dist_available: tp.List[str] = ["dot", "cosine"] + dist_available: tp.List[str] = [Distance.DOT, Distance.COSINE] epsilon_cosine_dist: torch.Tensor = torch.tensor([1e-8]) def __init__(self, distance: str = "dot") -> None: @@ -50,7 +50,7 @@ def __init__(self, distance: str = "dot") -> None: if distance not in self.dist_available: raise ValueError("`dist` can only be either `dot` or `cosine`.") - self.distance = distance + self.distance = Distance(distance) def _get_full_catalog_logits(self, session_embs: torch.Tensor, item_embs: torch.Tensor) -> torch.Tensor: logits = session_embs @ item_embs.T @@ -73,7 +73,7 @@ def forward( self, session_embs: torch.Tensor, item_embs: torch.Tensor, item_ids: tp.Optional[torch.Tensor] = None ) -> torch.Tensor: """Forward pass to get logits.""" - if self.distance == "cosine": + if self.distance == Distance.COSINE: session_embs = self._get_embeddings_norm(session_embs) item_embs = self._get_embeddings_norm(item_embs) @@ -94,7 +94,7 @@ def _recommend_u2i( ) -> InternalRecoTriplet: """Recommend to users.""" ranker = TorchRanker( - distance=Distance.DOT if self.distance == "dot" else Distance.COSINE, + distance=self.distance, device=item_embs.device, subjects_factors=user_embs[user_ids], objects_factors=item_embs, diff --git a/rectools/models/rank/rank.py b/rectools/models/rank/rank.py index ab79f80d..10c217f3 100644 --- a/rectools/models/rank/rank.py +++ b/rectools/models/rank/rank.py @@ -22,12 +22,12 @@ from rectools.types import InternalIdsArray -class Distance(Enum): +class Distance(str, Enum): """Distance metric""" - DOT = 1 # Bigger value means closer vectors - COSINE = 2 # Bigger value means closer vectors - EUCLIDEAN = 3 # Smaller value means closer vectors + DOT = "dot" # Bigger value means closer vectors + COSINE = "cosine" # Bigger value means closer vectors + EUCLIDEAN = "euclidean" # Smaller value means closer vectors class Ranker(tp.Protocol): From d434456c91fc98dfad2ac20e4c46b89c822a94aa Mon Sep 17 00:00:00 2001 From: In48semenov Date: Thu, 27 Mar 2025 17:03:14 +0300 Subject: [PATCH 14/30] Removed unnecessary kwargs in test_base.py in transformers. 
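Since `DistanceSimilarityModule` is already wired in as the default `similarity_module_type` earlier in this series, passing it explicitly in the tests adds no coverage. A hypothetical check of the equivalence this cleanup assumes (not part of the diff):

    from rectools.models import SASRecModel
    from rectools.models.nn.transformers.similarity import DistanceSimilarityModule

    explicit = SASRecModel(similarity_module_type=DistanceSimilarityModule)
    implicit = SASRecModel()  # relies on the default introduced earlier in this series

    # If the default truly matches, both models serialize to the same config.
    assert explicit.get_config() == implicit.get_config()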
--- rectools/models/nn/transformers/lightning.py | 1 - rectools/models/nn/transformers/similarity.py | 6 ++---- tests/models/nn/transformers/test_base.py | 10 ---------- 3 files changed, 2 insertions(+), 15 deletions(-) diff --git a/rectools/models/nn/transformers/lightning.py b/rectools/models/nn/transformers/lightning.py index 5e6d6ed6..5626fd88 100644 --- a/rectools/models/nn/transformers/lightning.py +++ b/rectools/models/nn/transformers/lightning.py @@ -223,7 +223,6 @@ def validation_step(self, batch: tp.Dict[str, torch.Tensor], batch_idx: int) -> outputs["loss"] = self._calc_softmax_loss(logits, y, w) outputs["logits"] = logits.squeeze() elif self.loss == "BCE": - negatives = batch["negatives"] outputs["loss"] = self._calc_bce_loss(logits, y, w) outputs["pos_neg_logits"] = logits.squeeze() elif self.loss == "gBCE": diff --git a/rectools/models/nn/transformers/similarity.py b/rectools/models/nn/transformers/similarity.py index b43727cf..6b23bcc6 100644 --- a/rectools/models/nn/transformers/similarity.py +++ b/rectools/models/nn/transformers/similarity.py @@ -78,10 +78,8 @@ def forward( item_embs = self._get_embeddings_norm(item_embs) if item_ids is None: - logits = self._get_full_catalog_logits(session_embs, item_embs) - else: - logits = self._get_pos_neg_logits(session_embs, item_embs, item_ids) - return logits + return self._get_full_catalog_logits(session_embs, item_embs) + return self._get_pos_neg_logits(session_embs, item_embs, item_ids) def _recommend_u2i( self, diff --git a/tests/models/nn/transformers/test_base.py b/tests/models/nn/transformers/test_base.py index 16ef53df..3666cfd9 100644 --- a/tests/models/nn/transformers/test_base.py +++ b/tests/models/nn/transformers/test_base.py @@ -28,7 +28,6 @@ from rectools.models import BERT4RecModel, SASRecModel, load_model from rectools.models.nn.item_net import CatFeaturesItemNet, IdEmbeddingsItemNet from rectools.models.nn.transformers.base import TransformerModelBase -from rectools.models.nn.transformers.similarity import DistanceSimilarityModule from tests.models.data import INTERACTIONS from tests.models.utils import assert_save_load_do_not_change_model @@ -113,7 +112,6 @@ def test_save_load_for_unfitted_model( config = { "deterministic": True, "item_net_block_types": (IdEmbeddingsItemNet, CatFeaturesItemNet), - "similarity_module_type": DistanceSimilarityModule, } if not default_trainer: config["get_trainer_func"] = custom_trainer @@ -152,7 +150,6 @@ def test_save_load_for_fitted_model( config = { "deterministic": True, "item_net_block_types": (IdEmbeddingsItemNet, CatFeaturesItemNet), - "similarity_module_type": DistanceSimilarityModule, } if not default_trainer: config["get_trainer_func"] = custom_trainer @@ -174,7 +171,6 @@ def test_load_from_checkpoint( "deterministic": True, "item_net_block_types": (IdEmbeddingsItemNet, CatFeaturesItemNet), "get_trainer_func": custom_trainer_ckpt, - "similarity_module_type": DistanceSimilarityModule, } ) dataset = request.getfixturevalue(test_dataset) @@ -201,7 +197,6 @@ def test_raises_when_save_model_loaded_from_checkpoint( "deterministic": True, "item_net_block_types": (IdEmbeddingsItemNet, CatFeaturesItemNet), "get_trainer_func": custom_trainer_ckpt, - "similarity_module_type": DistanceSimilarityModule, } ) model.fit(dataset) @@ -226,7 +221,6 @@ def test_load_weights_from_checkpoint( "deterministic": True, "item_net_block_types": (IdEmbeddingsItemNet, CatFeaturesItemNet), "get_trainer_func": custom_trainer_multiple_ckpt, - "similarity_module_type": DistanceSimilarityModule, } ) 
model.fit(dataset) @@ -252,7 +246,6 @@ def test_raises_when_load_weights_from_checkpoint_not_fitted_model( "deterministic": True, "item_net_block_types": (IdEmbeddingsItemNet, CatFeaturesItemNet), "get_trainer_func": custom_trainer_ckpt, - "similarity_module_type": DistanceSimilarityModule, } ) model.fit(dataset) @@ -266,7 +259,6 @@ def test_raises_when_load_weights_from_checkpoint_not_fitted_model( "deterministic": True, "item_net_block_types": (IdEmbeddingsItemNet, CatFeaturesItemNet), "get_trainer_func": custom_trainer_ckpt, - "similarity_module_type": DistanceSimilarityModule, } ) with pytest.raises(RuntimeError): @@ -309,7 +301,6 @@ def test_log_metrics( "verbose": verbose, "get_val_mask_func": get_val_mask_func, "loss": loss, - "similarity_module_type": DistanceSimilarityModule, } ) model._trainer = trainer # pylint: disable=protected-access @@ -332,7 +323,6 @@ def test_raises_when_incorrect_similarity_dist( self, model_cls: tp.Type[TransformerModelBase], dataset: Dataset ) -> None: model_config = { - "similarity_module_type": DistanceSimilarityModule, "similarity_module_kwargs": {"distance": "euclidean"}, } with pytest.raises(ValueError): From bc61feb81db565288660e7343ae0ea925083d626 Mon Sep 17 00:00:00 2001 From: In48semenov Date: Thu, 27 Mar 2025 17:16:21 +0300 Subject: [PATCH 15/30] Removed kwargs `item_net_block_types` from test_base.py in transformers. --- tests/models/nn/transformers/test_base.py | 16 ++-------------- 1 file changed, 2 insertions(+), 14 deletions(-) diff --git a/tests/models/nn/transformers/test_base.py b/tests/models/nn/transformers/test_base.py index 3666cfd9..a29d8f24 100644 --- a/tests/models/nn/transformers/test_base.py +++ b/tests/models/nn/transformers/test_base.py @@ -26,7 +26,6 @@ from rectools import Columns from rectools.dataset import Dataset from rectools.models import BERT4RecModel, SASRecModel, load_model -from rectools.models.nn.item_net import CatFeaturesItemNet, IdEmbeddingsItemNet from rectools.models.nn.transformers.base import TransformerModelBase from tests.models.data import INTERACTIONS from tests.models.utils import assert_save_load_do_not_change_model @@ -109,10 +108,7 @@ def test_save_load_for_unfitted_model( dataset: Dataset, default_trainer: bool, ) -> None: - config = { - "deterministic": True, - "item_net_block_types": (IdEmbeddingsItemNet, CatFeaturesItemNet), - } + config: tp.Dict[str, tp.Any] = {"deterministic": True} if not default_trainer: config["get_trainer_func"] = custom_trainer model = model_cls.from_config(config) @@ -147,10 +143,7 @@ def test_save_load_for_fitted_model( dataset_item_features: Dataset, default_trainer: bool, ) -> None: - config = { - "deterministic": True, - "item_net_block_types": (IdEmbeddingsItemNet, CatFeaturesItemNet), - } + config: tp.Dict[str, tp.Any] = {"deterministic": True} if not default_trainer: config["get_trainer_func"] = custom_trainer model = model_cls.from_config(config) @@ -169,7 +162,6 @@ def test_load_from_checkpoint( model = model_cls.from_config( { "deterministic": True, - "item_net_block_types": (IdEmbeddingsItemNet, CatFeaturesItemNet), "get_trainer_func": custom_trainer_ckpt, } ) @@ -195,7 +187,6 @@ def test_raises_when_save_model_loaded_from_checkpoint( model = model_cls.from_config( { "deterministic": True, - "item_net_block_types": (IdEmbeddingsItemNet, CatFeaturesItemNet), "get_trainer_func": custom_trainer_ckpt, } ) @@ -219,7 +210,6 @@ def test_load_weights_from_checkpoint( model = model_cls.from_config( { "deterministic": True, - "item_net_block_types": 
(IdEmbeddingsItemNet, CatFeaturesItemNet), "get_trainer_func": custom_trainer_multiple_ckpt, } ) @@ -244,7 +234,6 @@ def test_raises_when_load_weights_from_checkpoint_not_fitted_model( model = model_cls.from_config( { "deterministic": True, - "item_net_block_types": (IdEmbeddingsItemNet, CatFeaturesItemNet), "get_trainer_func": custom_trainer_ckpt, } ) @@ -257,7 +246,6 @@ def test_raises_when_load_weights_from_checkpoint_not_fitted_model( model_unfitted = model_cls.from_config( { "deterministic": True, - "item_net_block_types": (IdEmbeddingsItemNet, CatFeaturesItemNet), "get_trainer_func": custom_trainer_ckpt, } ) From 57213fd10c873c9d48d6a9b975f01a49cb70fbd9 Mon Sep 17 00:00:00 2001 From: In48semenov Date: Thu, 27 Mar 2025 23:50:00 +0300 Subject: [PATCH 16/30] Put filter train interactions as separate method. --- .../models/nn/transformers/data_preparator.py | 22 ++++++++++++------- 1 file changed, 14 insertions(+), 8 deletions(-) diff --git a/rectools/models/nn/transformers/data_preparator.py b/rectools/models/nn/transformers/data_preparator.py index 2b2a899e..4fbe6b2a 100644 --- a/rectools/models/nn/transformers/data_preparator.py +++ b/rectools/models/nn/transformers/data_preparator.py @@ -168,6 +168,19 @@ def _process_features_for_id_map( full_feature_values = np.vstack([extra_token_feature_values, sorted_features.values]) return DenseFeatures.from_iterables(values=full_feature_values, names=raw_features.names) + def _filter_train_interactions(self, interactions: pd.DataFrame) -> pd.DataFrame: + """Filter train interactions""" + train_interactions = interactions.copy() + user_stats = train_interactions[Columns.User].value_counts() + users = user_stats[user_stats >= self.train_min_user_interactions].index + train_interactions = train_interactions[(train_interactions[Columns.User].isin(users))] + train_interactions = ( + train_interactions.sort_values(Columns.Datetime, kind="stable") + .groupby(Columns.User, sort=False) + .tail(self.session_max_len + self.train_session_max_len_addition) + ) + return train_interactions + def process_dataset_train(self, dataset: Dataset) -> None: """Process train dataset and save data.""" raw_interactions = dataset.get_raw_interactions() @@ -179,14 +192,7 @@ def process_dataset_train(self, dataset: Dataset) -> None: interactions = raw_interactions[~val_mask] # Filter train interactions - user_stats = interactions[Columns.User].value_counts() - users = user_stats[user_stats >= self.train_min_user_interactions].index - interactions = interactions[(interactions[Columns.User].isin(users))] - interactions = ( - interactions.sort_values(Columns.Datetime, kind="stable") - .groupby(Columns.User, sort=False) - .tail(self.session_max_len + self.train_session_max_len_addition) - ) + interactions = self._filter_train_interactions(interactions) # Prepare id maps user_id_map = IdMap.from_values(interactions[Columns.User].values) From 0d006b5d33a89d8b1569c1477b00587a2b35f528 Mon Sep 17 00:00:00 2001 From: In48semenov Date: Thu, 27 Mar 2025 23:56:19 +0300 Subject: [PATCH 17/30] Fixed docs. 
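(Docstring punctuation for the helper extracted in the previous commit.) For context, `_filter_train_interactions` keeps only users with at least `train_min_user_interactions` events and truncates each remaining history to the `session_max_len + train_session_max_len_addition` most recent interactions. A toy pandas sketch of that logic on made-up data (column names and thresholds are stand-ins for the preparator's attributes):

    import pandas as pd

    interactions = pd.DataFrame(
        {
            "user_id": [1, 1, 1, 2, 3, 3],
            "item_id": [10, 11, 12, 10, 11, 12],
            "datetime": pd.to_datetime(
                ["2021-01-01", "2021-01-02", "2021-01-03", "2021-01-01", "2021-01-01", "2021-01-02"]
            ),
        }
    )
    min_user_interactions, max_hist_len = 2, 2  # stand-ins for the real attributes

    user_stats = interactions["user_id"].value_counts()
    users = user_stats[user_stats >= min_user_interactions].index
    filtered = interactions[interactions["user_id"].isin(users)]  # user 2 has 1 event -> dropped
    filtered = (
        filtered.sort_values("datetime", kind="stable")
        .groupby("user_id", sort=False)
        .tail(max_hist_len)  # keep each user's most recent events
    )
    assert sorted(filtered["user_id"].unique()) == [1, 3]
    assert len(filtered) == 4  # two most recent events for user 1, both events for user 3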
--- rectools/models/nn/transformers/data_preparator.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/rectools/models/nn/transformers/data_preparator.py b/rectools/models/nn/transformers/data_preparator.py index 4fbe6b2a..2ea6a0b1 100644 --- a/rectools/models/nn/transformers/data_preparator.py +++ b/rectools/models/nn/transformers/data_preparator.py @@ -169,7 +169,7 @@ def _process_features_for_id_map( return DenseFeatures.from_iterables(values=full_feature_values, names=raw_features.names) def _filter_train_interactions(self, interactions: pd.DataFrame) -> pd.DataFrame: - """Filter train interactions""" + """Filter train interactions.""" train_interactions = interactions.copy() user_stats = train_interactions[Columns.User].value_counts() users = user_stats[user_stats >= self.train_min_user_interactions].index From 325e8c7d46ce68dbe9fe480a804144d832e79c12 Mon Sep 17 00:00:00 2001 From: In48semenov Date: Sat, 29 Mar 2025 17:56:37 +0300 Subject: [PATCH 18/30] Fixed docs and removed df copy in filter interactions method. --- rectools/models/nn/transformers/data_preparator.py | 3 +-- rectools/models/nn/transformers/similarity.py | 2 +- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/rectools/models/nn/transformers/data_preparator.py b/rectools/models/nn/transformers/data_preparator.py index 2ea6a0b1..275396db 100644 --- a/rectools/models/nn/transformers/data_preparator.py +++ b/rectools/models/nn/transformers/data_preparator.py @@ -168,9 +168,8 @@ def _process_features_for_id_map( full_feature_values = np.vstack([extra_token_feature_values, sorted_features.values]) return DenseFeatures.from_iterables(values=full_feature_values, names=raw_features.names) - def _filter_train_interactions(self, interactions: pd.DataFrame) -> pd.DataFrame: + def _filter_train_interactions(self, train_interactions: pd.DataFrame) -> pd.DataFrame: """Filter train interactions.""" - train_interactions = interactions.copy() user_stats = train_interactions[Columns.User].value_counts() users = user_stats[user_stats >= self.train_min_user_interactions].index train_interactions = train_interactions[(train_interactions[Columns.User].isin(users))] diff --git a/rectools/models/nn/transformers/similarity.py b/rectools/models/nn/transformers/similarity.py index 6b23bcc6..0ed4e44d 100644 --- a/rectools/models/nn/transformers/similarity.py +++ b/rectools/models/nn/transformers/similarity.py @@ -40,7 +40,7 @@ def _recommend_u2i( class DistanceSimilarityModule(SimilarityModuleBase): - """Distandce similarity module.""" + """Distance similarity module.""" dist_available: tp.List[str] = [Distance.DOT, Distance.COSINE] epsilon_cosine_dist: torch.Tensor = torch.tensor([1e-8]) From 3ac5755052dffd9449e822e7ae53b874cbcb7903 Mon Sep 17 00:00:00 2001 From: In48semenov Date: Sat, 29 Mar 2025 18:50:51 +0300 Subject: [PATCH 19/30] Added method get_loss_calculator, made default value of n_negatives 0.
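This replaces the duplicated `if self.loss == ...` chains in `training_step` and `validation_step` with a single resolution done once in `__init__`: `get_loss_calculator` returns a `(loss_fn, requires_negatives)` pair, and `requires_negatives` tells the new `get_batch_logits` whether to score only the positive plus sampled negatives or the full catalog. An equivalent dict-based sketch of the added method (simplified signature, same behaviour as the if-chain in the diff):

    def get_loss_calculator(self):
        """Map the configured loss name to (loss_fn, requires_negatives)."""
        dispatch = {
            "softmax": (self._calc_softmax_loss, False),  # scores the full catalog
            "BCE": (self._calc_bce_loss, True),           # scores positive + sampled negatives
            "gBCE": (self._calc_gbce_loss, True),
        }
        # Unknown losses fall through to (None, None) and are routed to the
        # _calc_custom_loss / _calc_custom_loss_outputs hooks instead.
        return dispatch.get(self.loss, (None, None))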
--- rectools/models/nn/transformers/base.py | 2 +- rectools/models/nn/transformers/bert4rec.py | 10 +- .../models/nn/transformers/data_preparator.py | 2 +- rectools/models/nn/transformers/lightning.py | 193 ++++++++---------- rectools/models/nn/transformers/sasrec.py | 4 +- tests/dataset/test_torch_dataset.py | 10 +- 6 files changed, 101 insertions(+), 120 deletions(-) diff --git a/rectools/models/nn/transformers/base.py b/rectools/models/nn/transformers/base.py index 775bbbcf..0c76639f 100644 --- a/rectools/models/nn/transformers/base.py +++ b/rectools/models/nn/transformers/base.py @@ -317,7 +317,7 @@ def _init_data_preparator(self) -> None: batch_size=self.batch_size, dataloader_num_workers=self.dataloader_num_workers, train_min_user_interactions=self.train_min_user_interactions, - n_negatives=self.n_negatives if self.loss != "softmax" else None, + n_negatives=self.n_negatives if self.loss != "softmax" else 0, get_val_mask_func=self.get_val_mask_func, shuffle_train=True, **self._get_kwargs(self.data_preparator_kwargs), diff --git a/rectools/models/nn/transformers/bert4rec.py b/rectools/models/nn/transformers/bert4rec.py index c2f2d814..d5e09023 100644 --- a/rectools/models/nn/transformers/bert4rec.py +++ b/rectools/models/nn/transformers/bert4rec.py @@ -56,7 +56,7 @@ class BERT4RecDataPreparator(TransformerDataPreparatorBase): def __init__( self, session_max_len: int, - n_negatives: tp.Optional[int], + n_negatives: int, batch_size: int, dataloader_num_workers: int, train_min_user_interactions: int, @@ -105,7 +105,7 @@ def _collate_fn_train( Get target by replacing session elements with a MASK token with probability `mask_prob`. Truncate each session and target from right to keep `session_max_len` last items. Do left padding until `session_max_len` is reached. - If `n_negatives` is not None, generate negative items from uniform distribution. + If `n_negatives` is greater than 0, generate negative items from uniform distribution. 
""" batch_size = len(batch) x = np.zeros((batch_size, self.session_max_len)) @@ -118,7 +118,7 @@ def _collate_fn_train( yw[i, -len(ses) :] = ses_weights # ses_weights: [session_len] -> yw[i]: [session_max_len] batch_dict = {"x": torch.LongTensor(x), "y": torch.LongTensor(y), "yw": torch.FloatTensor(yw)} - if self.n_negatives is not None: + if self.n_negatives > 0: negatives = torch.randint( low=self.n_item_extra_tokens, high=self.item_id_map.size, @@ -146,7 +146,7 @@ def _collate_fn_val(self, batch: List[Tuple[List[int], List[float]]]) -> Dict[st yw[i, -1:] = ses_weights[target_idx] # yw[i]: [1] batch_dict = {"x": torch.LongTensor(x), "y": torch.LongTensor(y), "yw": torch.FloatTensor(yw)} - if self.n_negatives is not None: + if self.n_negatives > 0: negatives = torch.randint( low=self.n_item_extra_tokens, high=self.item_id_map.size, @@ -387,7 +387,7 @@ def __init__( # pylint: disable=too-many-arguments, too-many-locals def _init_data_preparator(self) -> None: self.data_preparator: TransformerDataPreparatorBase = self.data_preparator_type( session_max_len=self.session_max_len, - n_negatives=self.n_negatives if self.loss != "softmax" else None, + n_negatives=self.n_negatives if self.loss != "softmax" else 0, batch_size=self.batch_size, dataloader_num_workers=self.dataloader_num_workers, train_min_user_interactions=self.train_min_user_interactions, diff --git a/rectools/models/nn/transformers/data_preparator.py b/rectools/models/nn/transformers/data_preparator.py index 275396db..ee9d364f 100644 --- a/rectools/models/nn/transformers/data_preparator.py +++ b/rectools/models/nn/transformers/data_preparator.py @@ -119,7 +119,7 @@ def __init__( dataloader_num_workers: int, shuffle_train: bool = True, train_min_user_interactions: int = 2, - n_negatives: tp.Optional[int] = None, + n_negatives: int = 0, get_val_mask_func: tp.Optional[tp.Callable] = None, **kwargs: tp.Any, ) -> None: diff --git a/rectools/models/nn/transformers/lightning.py b/rectools/models/nn/transformers/lightning.py index 5626fd88..c1996414 100644 --- a/rectools/models/nn/transformers/lightning.py +++ b/rectools/models/nn/transformers/lightning.py @@ -85,6 +85,7 @@ def __init__( self.data_preparator = data_preparator self.lr = lr self.loss = loss + self.loss_calculator, self.requires_negatives = self.get_loss_calculator() self.adam_betas = adam_betas self.gbce_t = gbce_t self.verbose = verbose @@ -94,19 +95,75 @@ def __init__( self.save_hyperparameters(ignore=["torch_model", "data_preparator"]) - @property - def requires_negatives(self) -> tp.Optional[bool]: - """Indicator for determining the need for negatives for loss functions.""" + def get_loss_calculator(self) -> tp.Tuple[tp.Optional[tp.Callable], tp.Optional[bool]]: + """Return loss calculator and indicator for determining the need for negatives for loss functions.""" if self.loss == "softmax": - return False + return self._calc_softmax_loss, False if self.loss == "BCE": - return True + return self._calc_bce_loss, True if self.loss == "gBCE": - return True + return self._calc_gbce_loss, True - return None + return None, None + + @classmethod + def _calc_softmax_loss(cls, logits: torch.Tensor, y: torch.Tensor, w: torch.Tensor) -> torch.Tensor: + # We are using CrossEntropyLoss with a multi-dimensional case + + # Logits must be passed in form of [batch_size, n_items + n_item_extra_tokens, session_max_len], + # where n_items + n_item_extra_tokens is number of classes + + # Target label indexes must be passed in a form of [batch_size, session_max_len] + # (`0` index for "PAD" ix 
excluded from loss) + + # Loss output will have a shape of [batch_size, session_max_len] + # and will have zeros for every `0` target label + loss = torch.nn.functional.cross_entropy( + logits.transpose(1, 2), y, ignore_index=0, reduction="none" + ) # [batch_size, session_max_len] + loss = loss * w + n = (loss > 0).to(loss.dtype) + loss = torch.sum(loss) / torch.sum(n) + return loss + + def _get_reduced_overconfidence_logits(self, logits: torch.Tensor, n_items: int) -> torch.Tensor: + # https://arxiv.org/pdf/2308.07192.pdf + + dtype = torch.float64 # for consistency with the original implementation + alpha = self.data_preparator.n_negatives / (n_items - 1) # sampling rate + beta = alpha * (self.gbce_t * (1 - 1 / alpha) + 1 / alpha) + + pos_logits = logits[:, :, 0:1].to(dtype) + neg_logits = logits[:, :, 1:].to(dtype) + + epsilon = 1e-10 + pos_probs = torch.clamp(torch.sigmoid(pos_logits), epsilon, 1 - epsilon) + pos_probs_adjusted = torch.clamp(pos_probs.pow(-beta), 1 + epsilon, torch.finfo(dtype).max) + pos_probs_adjusted = torch.clamp(torch.div(1, (pos_probs_adjusted - 1)), epsilon, torch.finfo(dtype).max) + pos_logits_transformed = torch.log(pos_probs_adjusted) + logits = torch.cat([pos_logits_transformed, neg_logits], dim=-1) + return logits + + @classmethod + def _calc_bce_loss(cls, logits: torch.Tensor, y: torch.Tensor, w: torch.Tensor) -> torch.Tensor: + mask = y != 0 + target = torch.zeros_like(logits) + target[:, :, 0] = 1 + + loss = torch.nn.functional.binary_cross_entropy_with_logits( + logits, target, reduction="none" + ) # [batch_size, session_max_len, n_negatives + 1] + loss = loss.mean(-1) * mask * w # [batch_size, session_max_len] + loss = torch.sum(loss) / torch.sum(mask) + return loss + + def _calc_gbce_loss(self, logits: torch.Tensor, y: torch.Tensor, w: torch.Tensor) -> torch.Tensor: + n_actual_items = self.torch_model.item_model.n_items - len(self.item_extra_tokens) + logits = self._get_reduced_overconfidence_logits(logits, n_actual_items) + loss = self._calc_bce_loss(logits, y, w) + return loss def configure_optimizers(self) -> torch.optim.Adam: """Choose what optimizers and learning-rate schedulers to use in optimization""" optimizer = torch.optim.Adam(self.torch_model.parameters(), lr=self.lr, betas=self.adam_betas) @@ -161,29 +218,27 @@ def on_train_start(self) -> None: """Initialize parameters with values from Xavier normal distribution.""" self._xavier_normal_init() - def training_step(self, batch: tp.Dict[str, torch.Tensor], batch_idx: int) -> torch.Tensor: - """Training step.""" - x, y, w = batch["x"], batch["y"], batch["yw"] - + def get_batch_logits(self, batch: tp.Dict[str, torch.Tensor]) -> torch.Tensor: + """Get batch logits.""" + x = batch["x"] # x: [batch_size, session_max_len] if self.requires_negatives: - negatives = batch["negatives"] - # [batch_size, session_max_len, n_negatives + 1] + y, negatives = batch["y"], batch["negatives"] pos_neg = torch.cat([y.unsqueeze(-1), negatives], dim=-1) logits = self.torch_model(sessions=x, item_ids=pos_neg) - elif self.requires_negatives is not None: + else: logits = self.torch_model(sessions=x) + return logits - if self.loss == "softmax": - loss = self._calc_softmax_loss(logits, y, w) - elif self.loss == "BCE": - loss = self._calc_bce_loss(logits, y, w) - elif self.loss == "gBCE": - loss = self._calc_gbce_loss(logits, y, w, negatives) + def training_step(self, batch: tp.Dict[str, torch.Tensor], batch_idx: int) -> torch.Tensor: + """Training step.""" + if self.loss_calculator is not None: + y, w = batch["y"], batch["yw"] + logits = self.get_batch_logits(batch) + loss = self.loss_calculator(logits, y, w) + else:
loss = self._calc_custom_loss(batch, batch_idx) self.log(self.train_loss_name, loss, on_step=False, on_epoch=True, prog_bar=self.verbose > 0) - return loss def _calc_custom_loss(self, batch: tp.Dict[str, torch.Tensor], batch_idx: int) -> torch.Tensor: @@ -202,32 +257,18 @@ def on_validation_end(self) -> None: def validation_step(self, batch: tp.Dict[str, torch.Tensor], batch_idx: int) -> tp.Dict[str, torch.Tensor]: """Validate step.""" - # x: [batch_size, session_max_len] - # y: [batch_size, 1] - # yw: [batch_size, 1] - x, y, w = batch["x"], batch["y"], batch["yw"] - - if self.requires_negatives: - negatives = batch["negatives"] - # [batch_size, session_max_len, n_negatives + 1] - pos_neg = torch.cat([y.unsqueeze(-1), negatives], dim=-1) - logits = self.torch_model(sessions=x, item_ids=pos_neg) - elif self.requires_negatives is not None: - logits = self.torch_model(sessions=x) - - if self.requires_negatives is not None: + if self.loss_calculator is not None: + # y: [batch_size, 1] + # yw: [batch_size, 1] + y, w = batch["y"], batch["yw"] + logits = self.get_batch_logits(batch) logits = logits[:, -1:, :] - - outputs = {} - if self.loss == "softmax": - outputs["loss"] = self._calc_softmax_loss(logits, y, w) - outputs["logits"] = logits.squeeze() - elif self.loss == "BCE": - outputs["loss"] = self._calc_bce_loss(logits, y, w) - outputs["pos_neg_logits"] = logits.squeeze() - elif self.loss == "gBCE": - outputs["loss"] = self._calc_gbce_loss(logits, y, w, negatives) - outputs["pos_neg_logits"] = logits.squeeze() + loss = self.loss_calculator(logits, y, w) + type_logits = "pos_neg_logits" if self.requires_negatives else "logits" + outputs = { + "loss": loss, + type_logits: logits, + } else: outputs = self._calc_custom_loss_outputs(batch, batch_idx) # pragma: no cover @@ -239,66 +280,6 @@ def _calc_custom_loss_outputs( ) -> tp.Dict[str, torch.Tensor]: raise ValueError(f"loss {self.loss} is not supported") # pragma: no cover - def _get_reduced_overconfidence_logits(self, logits: torch.Tensor, n_items: int, n_negatives: int) -> torch.Tensor: - # https://arxiv.org/pdf/2308.07192.pdf - - dtype = torch.float64 # for consistency with the original implementation - alpha = n_negatives / (n_items - 1) # sampling rate - beta = alpha * (self.gbce_t * (1 - 1 / alpha) + 1 / alpha) - - pos_logits = logits[:, :, 0:1].to(dtype) - neg_logits = logits[:, :, 1:].to(dtype) - - epsilon = 1e-10 - pos_probs = torch.clamp(torch.sigmoid(pos_logits), epsilon, 1 - epsilon) - pos_probs_adjusted = torch.clamp(pos_probs.pow(-beta), 1 + epsilon, torch.finfo(dtype).max) - pos_probs_adjusted = torch.clamp(torch.div(1, (pos_probs_adjusted - 1)), epsilon, torch.finfo(dtype).max) - pos_logits_transformed = torch.log(pos_probs_adjusted) - logits = torch.cat([pos_logits_transformed, neg_logits], dim=-1) - return logits - - @classmethod - def _calc_softmax_loss(cls, logits: torch.Tensor, y: torch.Tensor, w: torch.Tensor) -> torch.Tensor: - # We are using CrossEntropyLoss with a multi-dimensional case - - # Logits must be passed in form of [batch_size, n_items + n_item_extra_tokens, session_max_len], - # where n_items + n_item_extra_tokens is number of classes - - # Target label indexes must be passed in a form of [batch_size, session_max_len] - # (`0` index for "PAD" ix excluded from loss) - - # Loss output will have a shape of [batch_size, session_max_len] - # and will have zeros for every `0` target label - loss = torch.nn.functional.cross_entropy( - logits.transpose(1, 2), y, ignore_index=0, reduction="none" - ) # [batch_size, 
session_max_len] - loss = loss * w - n = (loss > 0).to(loss.dtype) - loss = torch.sum(loss) / torch.sum(n) - return loss - - @classmethod - def _calc_bce_loss(cls, logits: torch.Tensor, y: torch.Tensor, w: torch.Tensor) -> torch.Tensor: - mask = y != 0 - target = torch.zeros_like(logits) - target[:, :, 0] = 1 - - loss = torch.nn.functional.binary_cross_entropy_with_logits( - logits, target, reduction="none" - ) # [batch_size, session_max_len, n_negatives + 1] - loss = loss.mean(-1) * mask * w # [batch_size, session_max_len] - loss = torch.sum(loss) / torch.sum(mask) - return loss - - def _calc_gbce_loss( - self, logits: torch.Tensor, y: torch.Tensor, w: torch.Tensor, negatives: torch.Tensor - ) -> torch.Tensor: - n_actual_items = self.torch_model.item_model.n_items - len(self.item_extra_tokens) - n_negatives = negatives.shape[2] - logits = self._get_reduced_overconfidence_logits(logits, n_actual_items, n_negatives) - loss = self._calc_bce_loss(logits, y, w) - return loss - def _xavier_normal_init(self) -> None: for _, param in self.torch_model.named_parameters(): if param.data.dim() > 1: diff --git a/rectools/models/nn/transformers/sasrec.py b/rectools/models/nn/transformers/sasrec.py index 4bc36907..d86f63d9 100644 --- a/rectools/models/nn/transformers/sasrec.py +++ b/rectools/models/nn/transformers/sasrec.py @@ -71,7 +71,7 @@ def _collate_fn_train( yw[i, -len(ses) + 1 :] = ses_weights[1:] # ses_weights: [session_len] -> yw[i]: [session_max_len] batch_dict = {"x": torch.LongTensor(x), "y": torch.LongTensor(y), "yw": torch.FloatTensor(yw)} - if self.n_negatives is not None: + if self.n_negatives > 0: negatives = torch.randint( low=self.n_item_extra_tokens, high=self.item_id_map.size, @@ -97,7 +97,7 @@ def _collate_fn_val(self, batch: List[Tuple[List[int], List[float]]]) -> Dict[st yw[i, -1:] = ses_weights[target_idx] # yw[i]: [1] batch_dict = {"x": torch.LongTensor(x), "y": torch.LongTensor(y), "yw": torch.FloatTensor(yw)} - if self.n_negatives is not None: + if self.n_negatives > 0: negatives = torch.randint( low=self.n_item_extra_tokens, high=self.item_id_map.size, diff --git a/tests/dataset/test_torch_dataset.py b/tests/dataset/test_torch_dataset.py index 0881a5be..56c64563 100644 --- a/tests/dataset/test_torch_dataset.py +++ b/tests/dataset/test_torch_dataset.py @@ -155,8 +155,8 @@ def test_getitem_reconstructs_users(self, dataset: Dataset) -> None: all_user_features.append(user_features.view(1, -1)) all_interactions.append(interactions.view(1, -1)) - all_user_features = torch.cat(all_user_features, 0).numpy() # type: ignore - all_interactions = torch.cat(all_interactions, 0).numpy() # type: ignore + all_user_features = torch.cat(all_user_features, 0).numpy() + all_interactions = torch.cat(all_interactions, 0).numpy() ui_matrix = dataset.get_user_item_matrix().toarray() assert np.allclose(all_user_features, dataset.user_features.get_sparse().toarray()) # type: ignore @@ -198,8 +198,8 @@ def test_getitem_reconstructs_users(self, dataset: Dataset) -> None: all_user_features.append(user_features.view(1, -1)) all_interactions.append(interactions.view(1, -1)) - all_user_features = torch.cat(all_user_features, 0).numpy() # type: ignore - all_interactions = torch.cat(all_interactions, 0).numpy() # type: ignore + all_user_features = torch.cat(all_user_features, 0).numpy() + all_interactions = torch.cat(all_interactions, 0).numpy() ui_matrix = dataset.get_user_item_matrix().toarray() assert np.allclose(all_user_features, dataset.user_features.get_sparse().toarray()) # type: ignore @@ -236,7 
+236,7 @@ def test_getitem_reconstructs_items(self, dataset: Dataset) -> None: item_features = items_dataset[idx] all_item_features.append(item_features.view(1, -1)) - all_item_features = torch.cat(all_item_features, 0).numpy() # type: ignore + all_item_features = torch.cat(all_item_features, 0).numpy() assert np.allclose(all_item_features, dataset.item_features.get_sparse().toarray()) # type: ignore def test_raises_attribute_error(self, dataset_no_features: Dataset) -> None: From c43091843df3da65c5386c69c7926172f71a1f0c Mon Sep 17 00:00:00 2001 From: In48semenov Date: Sat, 29 Mar 2025 19:55:38 +0300 Subject: [PATCH 20/30] Returned Optional value for n_negatives. --- rectools/models/nn/transformers/base.py | 2 +- rectools/models/nn/transformers/bert4rec.py | 10 +++++----- rectools/models/nn/transformers/data_preparator.py | 2 +- rectools/models/nn/transformers/lightning.py | 6 +++++- rectools/models/nn/transformers/sasrec.py | 4 ++-- 5 files changed, 14 insertions(+), 10 deletions(-) diff --git a/rectools/models/nn/transformers/base.py b/rectools/models/nn/transformers/base.py index 0c76639f..775bbbcf 100644 --- a/rectools/models/nn/transformers/base.py +++ b/rectools/models/nn/transformers/base.py @@ -317,7 +317,7 @@ def _init_data_preparator(self) -> None: batch_size=self.batch_size, dataloader_num_workers=self.dataloader_num_workers, train_min_user_interactions=self.train_min_user_interactions, - n_negatives=self.n_negatives if self.loss != "softmax" else 0, + n_negatives=self.n_negatives if self.loss != "softmax" else None, get_val_mask_func=self.get_val_mask_func, shuffle_train=True, **self._get_kwargs(self.data_preparator_kwargs), diff --git a/rectools/models/nn/transformers/bert4rec.py b/rectools/models/nn/transformers/bert4rec.py index d5e09023..c2f2d814 100644 --- a/rectools/models/nn/transformers/bert4rec.py +++ b/rectools/models/nn/transformers/bert4rec.py @@ -56,7 +56,7 @@ class BERT4RecDataPreparator(TransformerDataPreparatorBase): def __init__( self, session_max_len: int, - n_negatives: int, + n_negatives: tp.Optional[int], batch_size: int, dataloader_num_workers: int, train_min_user_interactions: int, @@ -105,7 +105,7 @@ def _collate_fn_train( Get target by replacing session elements with a MASK token with probability `mask_prob`. Truncate each session and target from right to keep `session_max_len` last items. Do left padding until `session_max_len` is reached. - If `n_negatives` is greater than 0, generate negative items from uniform distribution. + If `n_negatives` is not None, generate negative items from uniform distribution. 
""" batch_size = len(batch) x = np.zeros((batch_size, self.session_max_len)) @@ -118,7 +118,7 @@ def _collate_fn_train( yw[i, -len(ses) :] = ses_weights # ses_weights: [session_len] -> yw[i]: [session_max_len] batch_dict = {"x": torch.LongTensor(x), "y": torch.LongTensor(y), "yw": torch.FloatTensor(yw)} - if self.n_negatives > 0: + if self.n_negatives is not None: negatives = torch.randint( low=self.n_item_extra_tokens, high=self.item_id_map.size, @@ -146,7 +146,7 @@ def _collate_fn_val(self, batch: List[Tuple[List[int], List[float]]]) -> Dict[st yw[i, -1:] = ses_weights[target_idx] # yw[i]: [1] batch_dict = {"x": torch.LongTensor(x), "y": torch.LongTensor(y), "yw": torch.FloatTensor(yw)} - if self.n_negatives > 0: + if self.n_negatives is not None: negatives = torch.randint( low=self.n_item_extra_tokens, high=self.item_id_map.size, @@ -387,7 +387,7 @@ def __init__( # pylint: disable=too-many-arguments, too-many-locals def _init_data_preparator(self) -> None: self.data_preparator: TransformerDataPreparatorBase = self.data_preparator_type( session_max_len=self.session_max_len, - n_negatives=self.n_negatives if self.loss != "softmax" else 0, + n_negatives=self.n_negatives if self.loss != "softmax" else None, batch_size=self.batch_size, dataloader_num_workers=self.dataloader_num_workers, train_min_user_interactions=self.train_min_user_interactions, diff --git a/rectools/models/nn/transformers/data_preparator.py b/rectools/models/nn/transformers/data_preparator.py index ee9d364f..275396db 100644 --- a/rectools/models/nn/transformers/data_preparator.py +++ b/rectools/models/nn/transformers/data_preparator.py @@ -119,7 +119,7 @@ def __init__( dataloader_num_workers: int, shuffle_train: bool = True, train_min_user_interactions: int = 2, - n_negatives: int = 0, + n_negatives: tp.Optional[int] = None, get_val_mask_func: tp.Optional[tp.Callable] = None, **kwargs: tp.Any, ) -> None: diff --git a/rectools/models/nn/transformers/lightning.py b/rectools/models/nn/transformers/lightning.py index c1996414..8c9a0a3e 100644 --- a/rectools/models/nn/transformers/lightning.py +++ b/rectools/models/nn/transformers/lightning.py @@ -132,7 +132,11 @@ def _get_reduced_overconfidence_logits(self, logits: torch.Tensor, n_items: int) # https://arxiv.org/pdf/2308.07192.pdf dtype = torch.float64 # for consistency with the original implementation - alpha = self.data_preparator.n_negatives / (n_items - 1) # sampling rate + n_negatives = self.data_preparator.n_negatives + if n_negatives is not None: + alpha = n_negatives / (n_items - 1) # sampling rate + else: + raise ValueError("`n_negatives` is not defined. 
Please ensure that `n_negatives` is set.") beta = alpha * (self.gbce_t * (1 - 1 / alpha) + 1 / alpha) pos_logits = logits[:, :, 0:1].to(dtype) diff --git a/rectools/models/nn/transformers/sasrec.py b/rectools/models/nn/transformers/sasrec.py index d86f63d9..4bc36907 100644 --- a/rectools/models/nn/transformers/sasrec.py +++ b/rectools/models/nn/transformers/sasrec.py @@ -71,7 +71,7 @@ def _collate_fn_train( yw[i, -len(ses) + 1 :] = ses_weights[1:] # ses_weights: [session_len] -> yw[i]: [session_max_len] batch_dict = {"x": torch.LongTensor(x), "y": torch.LongTensor(y), "yw": torch.FloatTensor(yw)} - if self.n_negatives > 0: + if self.n_negatives is not None: negatives = torch.randint( low=self.n_item_extra_tokens, high=self.item_id_map.size, @@ -97,7 +97,7 @@ def _collate_fn_val(self, batch: List[Tuple[List[int], List[float]]]) -> Dict[st yw[i, -1:] = ses_weights[target_idx] # yw[i]: [1] batch_dict = {"x": torch.LongTensor(x), "y": torch.LongTensor(y), "yw": torch.FloatTensor(yw)} - if self.n_negatives > 0: + if self.n_negatives is not None: negatives = torch.randint( low=self.n_item_extra_tokens, high=self.item_id_map.size, From 2c7128777995636cb47eb4bae60a5ee2e9b2359b Mon Sep 17 00:00:00 2001 From: In48semenov Date: Sat, 29 Mar 2025 19:57:51 +0300 Subject: [PATCH 21/30] n_negatives is Optional[int] --- tests/models/nn/transformers/test_bert4rec.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/models/nn/transformers/test_bert4rec.py b/tests/models/nn/transformers/test_bert4rec.py index 7480d1d6..a8121c3d 100644 --- a/tests/models/nn/transformers/test_bert4rec.py +++ b/tests/models/nn/transformers/test_bert4rec.py @@ -541,7 +541,7 @@ class NextActionDataPreparator(BERT4RecDataPreparator): def __init__( self, session_max_len: int, - n_negatives: tp.Optional[int], + n_negatives: int, batch_size: int, dataloader_num_workers: int, train_min_user_interactions: int, From 113dba4ea072c05c90206dfbc4e385d99d42112f Mon Sep 17 00:00:00 2001 From: In48semenov Date: Sat, 29 Mar 2025 19:59:39 +0300 Subject: [PATCH 22/30] Returned # type: ignore to test_torch_dataset.py --- tests/dataset/test_torch_dataset.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/tests/dataset/test_torch_dataset.py b/tests/dataset/test_torch_dataset.py index 56c64563..0881a5be 100644 --- a/tests/dataset/test_torch_dataset.py +++ b/tests/dataset/test_torch_dataset.py @@ -155,8 +155,8 @@ def test_getitem_reconstructs_users(self, dataset: Dataset) -> None: all_user_features.append(user_features.view(1, -1)) all_interactions.append(interactions.view(1, -1)) - all_user_features = torch.cat(all_user_features, 0).numpy() - all_interactions = torch.cat(all_interactions, 0).numpy() + all_user_features = torch.cat(all_user_features, 0).numpy() # type: ignore + all_interactions = torch.cat(all_interactions, 0).numpy() # type: ignore ui_matrix = dataset.get_user_item_matrix().toarray() assert np.allclose(all_user_features, dataset.user_features.get_sparse().toarray()) # type: ignore @@ -198,8 +198,8 @@ def test_getitem_reconstructs_users(self, dataset: Dataset) -> None: all_user_features.append(user_features.view(1, -1)) all_interactions.append(interactions.view(1, -1)) - all_user_features = torch.cat(all_user_features, 0).numpy() - all_interactions = torch.cat(all_interactions, 0).numpy() + all_user_features = torch.cat(all_user_features, 0).numpy() # type: ignore + all_interactions = torch.cat(all_interactions, 0).numpy() # type: ignore ui_matrix = 
dataset.get_user_item_matrix().toarray() assert np.allclose(all_user_features, dataset.user_features.get_sparse().toarray()) # type: ignore @@ -236,7 +236,7 @@ def test_getitem_reconstructs_items(self, dataset: Dataset) -> None: item_features = items_dataset[idx] all_item_features.append(item_features.view(1, -1)) - all_item_features = torch.cat(all_item_features, 0).numpy() + all_item_features = torch.cat(all_item_features, 0).numpy() # type: ignore assert np.allclose(all_item_features, dataset.item_features.get_sparse().toarray()) # type: ignore def test_raises_attribute_error(self, dataset_no_features: Dataset) -> None: From e1c40b51ab15090cb1b4792cc2d14f450423549c Mon Sep 17 00:00:00 2001 From: In48semenov Date: Sat, 29 Mar 2025 20:07:34 +0300 Subject: [PATCH 23/30] Added `pragma: no cover` to `_get_reduced_overconfidence_logits` --- rectools/models/nn/transformers/lightning.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/rectools/models/nn/transformers/lightning.py b/rectools/models/nn/transformers/lightning.py index 8c9a0a3e..f0dbc1b0 100644 --- a/rectools/models/nn/transformers/lightning.py +++ b/rectools/models/nn/transformers/lightning.py @@ -136,7 +136,9 @@ def _get_reduced_overconfidence_logits(self, logits: torch.Tensor, n_items: int) if n_negatives is not None: alpha = n_negatives / (n_items - 1) # sampling rate else: - raise ValueError("`n_negatives` is not defined. Please ensure that `n_negatives` is set.") + raise ValueError( + "`n_negatives` is not defined. Please ensure that `n_negatives` is set." + ) # pragma: no cover beta = alpha * (self.gbce_t * (1 - 1 / alpha) + 1 / alpha) pos_logits = logits[:, :, 0:1].to(dtype) From f1583c5127de16760fef237733181c919f2a3a64 Mon Sep 17 00:00:00 2001 From: In48semenov Date: Sat, 29 Mar 2025 22:40:37 +0300 Subject: [PATCH 24/30] Added annotation for get_loss_calculator. --- rectools/models/nn/transformers/lightning.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/rectools/models/nn/transformers/lightning.py b/rectools/models/nn/transformers/lightning.py index f0dbc1b0..9d3525cc 100644 --- a/rectools/models/nn/transformers/lightning.py +++ b/rectools/models/nn/transformers/lightning.py @@ -95,7 +95,11 @@ def __init__( self.save_hyperparameters(ignore=["torch_model", "data_preparator"]) - def get_loss_calculator(self) -> tp.Tuple[tp.Optional[tp.Callable], tp.Optional[bool]]: + def get_loss_calculator( + self, + ) -> tp.Tuple[ + tp.Optional[tp.Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor]], tp.Optional[bool] + ]: """Return loss calculator and indicator for determining the need for negatives for loss functions.""" if self.loss == "softmax": return self._calc_softmax_loss, False From da6d12256271e7f4216c37fbbbff624016c1d3c0 Mon Sep 17 00:00:00 2001 From: In48semenov Date: Sun, 30 Mar 2025 14:04:43 +0300 Subject: [PATCH 25/30] Separated logic `get_loss_calculator` and `requires_negatives`. 
---
 rectools/models/nn/transformers/base.py | 3 +-
 rectools/models/nn/transformers/lightning.py | 32 +++++++++++++-------
 2 files changed, 23 insertions(+), 12 deletions(-)

diff --git a/rectools/models/nn/transformers/base.py b/rectools/models/nn/transformers/base.py
index 775bbbcf..b8b1208e 100644
--- a/rectools/models/nn/transformers/base.py
+++ b/rectools/models/nn/transformers/base.py
@@ -312,12 +312,13 @@ def _get_kwargs(actual_kwargs: tp.Optional[InitKwargs]) -> InitKwargs:
         return kwargs
 
     def _init_data_preparator(self) -> None:
+        requires_negatives = self.lightning_module_type.requires_negatives(self.loss)
         self.data_preparator = self.data_preparator_type(
             session_max_len=self.session_max_len,
             batch_size=self.batch_size,
             dataloader_num_workers=self.dataloader_num_workers,
             train_min_user_interactions=self.train_min_user_interactions,
-            n_negatives=self.n_negatives if self.loss != "softmax" else None,
+            n_negatives=self.n_negatives if requires_negatives else None,
             get_val_mask_func=self.get_val_mask_func,
             shuffle_train=True,
             **self._get_kwargs(self.data_preparator_kwargs),
diff --git a/rectools/models/nn/transformers/lightning.py b/rectools/models/nn/transformers/lightning.py
index 9d3525cc..5bfb660d 100644
--- a/rectools/models/nn/transformers/lightning.py
+++ b/rectools/models/nn/transformers/lightning.py
@@ -85,7 +85,8 @@ def __init__(
         self.data_preparator = data_preparator
         self.lr = lr
         self.loss = loss
-        self.loss_calculator, self.requires_negatives = self.get_loss_calculator()
+        self.loss_calculator = self.get_loss_calculator()
+        self._requires_negatives = self.requires_negatives(loss)
         self.adam_betas = adam_betas
         self.gbce_t = gbce_t
         self.verbose = verbose
@@ -95,22 +96,31 @@ def __init__(
 
         self.save_hyperparameters(ignore=["torch_model", "data_preparator"])
 
+    @staticmethod
+    def requires_negatives(loss: str) -> tp.Optional[bool]:
+        """Return flag for determining the need for negatives."""
+        if loss == "softmax":
+            return False
+
+        if loss in ["BCE", "gBCE"]:
+            return True
+
+        return None
+
     def get_loss_calculator(
         self,
-    ) -> tp.Tuple[
-        tp.Optional[tp.Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor]], tp.Optional[bool]
-    ]:
-        """Return loss calculator and indicator for determining the need for negatives for loss functions."""
+    ) -> tp.Optional[tp.Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor]]:
+        """Return loss calculator."""
         if self.loss == "softmax":
-            return self._calc_softmax_loss, False
+            return self._calc_softmax_loss
 
         if self.loss == "BCE":
-            return self._calc_bce_loss, True
+            return self._calc_bce_loss
 
         if self.loss == "gBCE":
-            return self._calc_gbce_loss, True
+            return self._calc_gbce_loss
 
-        return None, None
+        return None
 
     @classmethod
     def _calc_softmax_loss(cls, logits: torch.Tensor, y: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
@@ -231,7 +241,7 @@ def on_train_start(self) -> None:
     def get_batch_logits(self, batch: tp.Dict[str, torch.Tensor]) -> torch.Tensor:
         """Get batch logits."""
         x = batch["x"]  # x: [batch_size, session_max_len]
-        if self.requires_negatives:
+        if self._requires_negatives:
             y, negatives = batch["y"], batch["negatives"]
             pos_neg = torch.cat([y.unsqueeze(-1), negatives], dim=-1)
             logits = self.torch_model(sessions=x, item_ids=pos_neg)
@@ -274,7 +284,7 @@ def validation_step(self, batch: tp.Dict[str, torch.Tensor], batch_idx: int) ->
         logits = self.get_batch_logits(batch)
         logits = logits[:, -1:, :]
         loss = self.loss_calculator(logits, y, w)
-        type_logits = "pos_neg_logits" if self.requires_negatives else "logits"
+        type_logits = "pos_neg_logits" if self._requires_negatives else "logits"
         outputs = {
             "loss": loss,
             type_logits: logits,

From 606432b823017ee7a647d5d0b82ad58bd988dbb6 Mon Sep 17 00:00:00 2001
From: In48semenov
Date: Mon, 31 Mar 2025 13:57:33 +0300
Subject: [PATCH 26/30] Added device for epsilon cosine dist.

---
 rectools/models/nn/transformers/similarity.py | 1 +
 rectools/models/rank/rank_torch.py | 2 ++
 2 files changed, 3 insertions(+)

diff --git a/rectools/models/nn/transformers/similarity.py b/rectools/models/nn/transformers/similarity.py
index 0ed4e44d..66e17ebd 100644
--- a/rectools/models/nn/transformers/similarity.py
+++ b/rectools/models/nn/transformers/similarity.py
@@ -65,6 +65,7 @@ def _get_pos_neg_logits(
         return logits
 
     def _get_embeddings_norm(self, embeddings: torch.Tensor) -> torch.Tensor:
+        self.epsilon_cosine_dist.to(embeddings.device)
         embedding_norm = torch.norm(embeddings, p=2, dim=1).unsqueeze(dim=1)
         embeddings = embeddings / torch.max(embedding_norm, self.epsilon_cosine_dist)
         return embeddings
diff --git a/rectools/models/rank/rank_torch.py b/rectools/models/rank/rank_torch.py
index 12ccca19..dcd33080 100644
--- a/rectools/models/rank/rank_torch.py
+++ b/rectools/models/rank/rank_torch.py
@@ -74,6 +74,8 @@ def __init__(
         self.subjects_factors = self._normalize_tensor(subjects_factors)
         self.objects_factors = self._normalize_tensor(objects_factors)
 
+        self.epsilon_cosine_dist.to(self.device)
+
     def rank(
         self,
         subject_ids: InternalIds,

From 9ac6150a5f097a056547d4f8591584d8f01672e1 Mon Sep 17 00:00:00 2001
From: In48semenov
Date: Mon, 31 Mar 2025 13:59:37 +0300
Subject: [PATCH 27/30] Added copyright.

---
 rectools/models/nn/transformers/similarity.py | 14 ++++++++++++++
 1 file changed, 14 insertions(+)

diff --git a/rectools/models/nn/transformers/similarity.py b/rectools/models/nn/transformers/similarity.py
index 66e17ebd..93cdf65f 100644
--- a/rectools/models/nn/transformers/similarity.py
+++ b/rectools/models/nn/transformers/similarity.py
@@ -1,3 +1,17 @@
+# Copyright 2025 MTS (Mobile Telesystems)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 import typing as tp
 
 import numpy as np

From cf66a58d2e3346bb01286ce4f0372cc6365d2c11 Mon Sep 17 00:00:00 2001
From: In48semenov
Date: Mon, 31 Mar 2025 14:52:45 +0300
Subject: [PATCH 28/30] Updated CHANGELOG.md

---
 CHANGELOG.md | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 97483034..bba9d340 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -6,6 +6,9 @@
 The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
 and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
 
+### Added
+- `SimilarityModuleBase`, `DistanceSimilarityModule`, similarity module to `TransformerTorchBackbone`, parameters to transformer-based models `similarity_module_type`, `similarity_module_kwargs` ([#272](https://github.com/MobileTeleSystems/RecTools/pull/272))
+
 ## [0.12.0] - 24.02.2025
 
 ### Added

From 5ed7605abcee2ab4c7acb3c185d5f52bd4e88a94 Mon Sep 17 00:00:00 2001
From: In48semenov
Date: Mon, 31 Mar 2025 16:11:20 +0300
Subject: [PATCH 29/30] Made `epsilon_cosine_dist` a Parameter.

---
 rectools/models/nn/transformers/similarity.py | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/rectools/models/nn/transformers/similarity.py b/rectools/models/nn/transformers/similarity.py
index 93cdf65f..5d99ad3f 100644
--- a/rectools/models/nn/transformers/similarity.py
+++ b/rectools/models/nn/transformers/similarity.py
@@ -57,7 +57,7 @@ class DistanceSimilarityModule(SimilarityModuleBase):
     """Distance similarity module."""
 
     dist_available: tp.List[str] = [Distance.DOT, Distance.COSINE]
-    epsilon_cosine_dist: torch.Tensor = torch.tensor([1e-8])
+    epsilon_cosine_dist: torch.nn.Parameter = torch.nn.Parameter(torch.Tensor([1e-8]), requires_grad=False)
 
     def __init__(self, distance: str = "dot") -> None:
         super().__init__()
@@ -79,7 +79,6 @@ def _get_pos_neg_logits(
         return logits
 
     def _get_embeddings_norm(self, embeddings: torch.Tensor) -> torch.Tensor:
-        self.epsilon_cosine_dist.to(embeddings.device)
         embedding_norm = torch.norm(embeddings, p=2, dim=1).unsqueeze(dim=1)
         embeddings = embeddings / torch.max(embedding_norm, self.epsilon_cosine_dist)
         return embeddings

From ccffb1689c939d1f18d687bc72ce2f5d518d9986 Mon Sep 17 00:00:00 2001
From: In48semenov
Date: Mon, 31 Mar 2025 16:33:52 +0300
Subject: [PATCH 30/30] Fixed epsilon_cosine_dist device

---
 rectools/models/nn/transformers/similarity.py | 4 ++--
 rectools/models/rank/rank_torch.py | 10 ++++++----
 2 files changed, 8 insertions(+), 6 deletions(-)

diff --git a/rectools/models/nn/transformers/similarity.py b/rectools/models/nn/transformers/similarity.py
index 5d99ad3f..ec37ba14 100644
--- a/rectools/models/nn/transformers/similarity.py
+++ b/rectools/models/nn/transformers/similarity.py
@@ -57,7 +57,7 @@ class DistanceSimilarityModule(SimilarityModuleBase):
     """Distance similarity module."""
 
     dist_available: tp.List[str] = [Distance.DOT, Distance.COSINE]
-    epsilon_cosine_dist: torch.nn.Parameter = torch.nn.Parameter(torch.Tensor([1e-8]), requires_grad=False)
+    epsilon_cosine_dist: torch.Tensor = torch.tensor([1e-8])
 
     def __init__(self, distance: str = "dot") -> None:
         super().__init__()
@@ -80,7 +80,7 @@ def _get_pos_neg_logits(
 
     def _get_embeddings_norm(self, embeddings: torch.Tensor) -> torch.Tensor:
         embedding_norm = torch.norm(embeddings, p=2, dim=1).unsqueeze(dim=1)
-        embeddings = embeddings / torch.max(embedding_norm, self.epsilon_cosine_dist)
+        embeddings = embeddings / torch.max(embedding_norm, self.epsilon_cosine_dist.to(embeddings))
         return embeddings
 
     def forward(
diff --git a/rectools/models/rank/rank_torch.py b/rectools/models/rank/rank_torch.py
index dcd33080..6a0c1f2c 100644
--- a/rectools/models/rank/rank_torch.py
+++ b/rectools/models/rank/rank_torch.py
@@ -74,8 +74,6 @@ def __init__(
         self.subjects_factors = self._normalize_tensor(subjects_factors)
         self.objects_factors = self._normalize_tensor(objects_factors)
 
-        self.epsilon_cosine_dist.to(self.device)
-
     def rank(
         self,
         subject_ids: InternalIds,
@@ -198,8 +196,12 @@ def _euclid_score(self, user_embs: torch.Tensor, item_embs: torch.Tensor) -> tor
         return torch.cdist(user_embs.unsqueeze(0), item_embs.unsqueeze(0)).squeeze(0)
 
     def _cosine_score(self, user_embs: torch.Tensor, item_embs: torch.Tensor) -> torch.Tensor:
-        user_embs = user_embs / torch.max(torch.norm(user_embs, p=2, dim=1).unsqueeze(dim=1), self.epsilon_cosine_dist)
-        item_embs = item_embs / torch.max(torch.norm(item_embs, p=2, dim=1).unsqueeze(dim=1), self.epsilon_cosine_dist)
+        user_embs = user_embs / torch.max(
+            torch.norm(user_embs, p=2, dim=1).unsqueeze(dim=1), self.epsilon_cosine_dist.to(user_embs)
+        )
+        item_embs = item_embs / torch.max(
+            torch.norm(item_embs, p=2, dim=1).unsqueeze(dim=1), self.epsilon_cosine_dist.to(user_embs)
+        )
         return user_embs @ item_embs.T
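For reference, the cosine path that the series converges on is epsilon-clamped L2 normalization of each embedding matrix followed by a plain dot product. Below is a minimal standalone sketch of that computation; the name `cosine_scores` and the shape check are illustrative, not part of the RecTools API. Note the patch converts epsilon with `.to(user_embs)` for both operands, which is equivalent to converting per operand as long as users and items share a device and dtype:

    import torch

    EPS = torch.tensor([1e-8])  # lower bound for row norms, mirrors epsilon_cosine_dist

    def cosine_scores(user_embs: torch.Tensor, item_embs: torch.Tensor) -> torch.Tensor:
        # Clamp each row norm from below so all-zero embeddings do not divide by zero;
        # Tensor.to(other) matches the operand's device and dtype, as in PATCH 30.
        user_norm = torch.norm(user_embs, p=2, dim=1).unsqueeze(dim=1)
        item_norm = torch.norm(item_embs, p=2, dim=1).unsqueeze(dim=1)
        users = user_embs / torch.max(user_norm, EPS.to(user_embs))
        items = item_embs / torch.max(item_norm, EPS.to(item_embs))
        return users @ items.T  # [n_users, n_items] matrix of cosine similarities

    # e.g. cosine_scores(torch.randn(4, 8), torch.randn(10, 8)).shape == (4, 10)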
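The device-handling churn across PATCH 26, 29 and 30 comes down to `torch.Tensor.to(...)` returning a new tensor rather than converting in place: the bare `self.epsilon_cosine_dist.to(...)` statements introduced in PATCH 26 discard their result. Registering the constant as a non-trainable `Parameter` (PATCH 29) would let `Module.to()` move it with the model, while PATCH 30 instead keeps a plain tensor and converts at the point of use. A two-line illustration of the pitfall:

    import torch

    eps = torch.tensor([1e-8])
    eps.to(torch.float64)        # no effect on `eps`: the converted copy is discarded
    eps = eps.to(torch.float64)  # effective: rebind the name to the converted tensor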
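On the loss side, the `alpha` and `beta` terms guarded by the `n_negatives is not None` check in PATCH 20 implement the overconfidence correction from gSASRec (arXiv 2308.07192, cited in the hunk). The following sketch reproduces only the two formulas visible in the diff, not the full logit transformation, and can be used to sanity-check a `gbce_t` setting:

    def gbce_beta(n_negatives: int, n_items: int, gbce_t: float) -> float:
        # Sampling rate: share of the real item catalogue drawn as negatives.
        alpha = n_negatives / (n_items - 1)
        # Calibration term: beta == 1 at gbce_t == 0 and beta == alpha at gbce_t == 1.
        return alpha * (gbce_t * (1 - 1 / alpha) + 1 / alpha)

    # e.g. gbce_beta(n_negatives=1, n_items=1001, gbce_t=0.5) ≈ 0.5005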