diff --git a/CHANGELOG.md b/CHANGELOG.md
index 97483034..bba9d340 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -6,6 +6,9 @@
 The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
 and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
 
+### Added
+- `SimilarityModuleBase` and `DistanceSimilarityModule`, similarity module support in `TransformerTorchBackbone`, and new `similarity_module_type` / `similarity_module_kwargs` parameters for transformer-based models ([#272](https://github.com/MobileTeleSystems/RecTools/pull/272))
+
 ## [0.12.0] - 24.02.2025
 
 ### Added
diff --git a/rectools/models/nn/transformers/base.py b/rectools/models/nn/transformers/base.py
index 92c35020..b8b1208e 100644
--- a/rectools/models/nn/transformers/base.py
+++ b/rectools/models/nn/transformers/base.py
@@ -46,6 +46,7 @@
     PreLNTransformerLayers,
     TransformerLayersBase,
 )
+from .similarity import DistanceSimilarityModule, SimilarityModuleBase
 from .torch_backbone import TransformerTorchBackbone
 
 InitKwargs = tp.Dict[str, tp.Any]
@@ -97,6 +98,16 @@ def _serialize_type_sequence(obj: tp.Sequence[tp.Type]) -> tp.Tuple[str, ...]:
     ),
 ]
 
+SimilarityModuleType = tpe.Annotated[
+    tp.Type[SimilarityModuleBase],
+    BeforeValidator(_get_class_obj),
+    PlainSerializer(
+        func=get_class_or_function_full_path,
+        return_type=str,
+        when_used="json",
+    ),
+]
+
 TransformerDataPreparatorType = tpe.Annotated[
     tp.Type[TransformerDataPreparatorBase],
     BeforeValidator(_get_class_obj),
@@ -183,6 +194,7 @@ class TransformerModelConfig(ModelConfig):
     pos_encoding_type: PositionalEncodingType = LearnableInversePositionalEncoding
     transformer_layers_type: TransformerLayersType = PreLNTransformerLayers
     lightning_module_type: TransformerLightningModuleType = TransformerLightningModule
+    similarity_module_type: SimilarityModuleType = DistanceSimilarityModule
     get_val_mask_func: tp.Optional[ValMaskCallableSerialized] = None
     get_trainer_func: tp.Optional[TrainerCallableSerialized] = None
     data_preparator_kwargs: tp.Optional[InitKwargs] = None
@@ -190,6 +202,7 @@
     item_net_constructor_kwargs: tp.Optional[InitKwargs] = None
     pos_encoding_kwargs: tp.Optional[InitKwargs] = None
     lightning_module_kwargs: tp.Optional[InitKwargs] = None
+    similarity_module_kwargs: tp.Optional[InitKwargs] = None
 
 
 TransformerModelConfig_T = tp.TypeVar("TransformerModelConfig_T", bound=TransformerModelConfig)
@@ -237,6 +250,7 @@ def __init__(  # pylint: disable=too-many-arguments, too-many-locals
         item_net_constructor_type: tp.Type[ItemNetConstructorBase] = SumOfEmbeddingsConstructor,
         pos_encoding_type: tp.Type[PositionalEncodingBase] = LearnableInversePositionalEncoding,
         lightning_module_type: tp.Type[TransformerLightningModuleBase] = TransformerLightningModule,
+        similarity_module_type: tp.Type[SimilarityModuleBase] = DistanceSimilarityModule,
         get_val_mask_func: tp.Optional[ValMaskCallable] = None,
         get_trainer_func: tp.Optional[TrainerCallable] = None,
         data_preparator_kwargs: tp.Optional[InitKwargs] = None,
@@ -244,6 +258,7 @@
         item_net_constructor_kwargs: tp.Optional[InitKwargs] = None,
         pos_encoding_kwargs: tp.Optional[InitKwargs] = None,
         lightning_module_kwargs: tp.Optional[InitKwargs] = None,
+        similarity_module_kwargs: tp.Optional[InitKwargs] = None,
         **kwargs: tp.Any,
     ) -> None:
         super().__init__(verbose=verbose)
@@ -268,6 +283,7 @@ def __init__(  # pylint: disable=too-many-arguments, too-many-locals
         self.recommend_batch_size =
recommend_batch_size self.recommend_torch_device = recommend_torch_device self.train_min_user_interactions = train_min_user_interactions + self.similarity_module_type = similarity_module_type self.item_net_block_types = item_net_block_types self.item_net_constructor_type = item_net_constructor_type self.pos_encoding_type = pos_encoding_type @@ -279,6 +295,7 @@ def __init__( # pylint: disable=too-many-arguments, too-many-locals self.item_net_constructor_kwargs = item_net_constructor_kwargs self.pos_encoding_kwargs = pos_encoding_kwargs self.lightning_module_kwargs = lightning_module_kwargs + self.similarity_module_kwargs = similarity_module_kwargs self._init_data_preparator() self._init_trainer() @@ -295,12 +312,13 @@ def _get_kwargs(actual_kwargs: tp.Optional[InitKwargs]) -> InitKwargs: return kwargs def _init_data_preparator(self) -> None: + requires_negatives = self.lightning_module_type.requires_negatives(self.loss) self.data_preparator = self.data_preparator_type( session_max_len=self.session_max_len, batch_size=self.batch_size, dataloader_num_workers=self.dataloader_num_workers, train_min_user_interactions=self.train_min_user_interactions, - n_negatives=self.n_negatives if self.loss != "softmax" else None, + n_negatives=self.n_negatives if requires_negatives else None, get_val_mask_func=self.get_val_mask_func, shuffle_train=True, **self._get_kwargs(self.data_preparator_kwargs), @@ -356,15 +374,20 @@ def _init_transformer_layers(self) -> TransformerLayersBase: **self._get_kwargs(self.transformer_layers_kwargs), ) + def _init_similarity_module(self) -> SimilarityModuleBase: + return self.similarity_module_type(**self._get_kwargs(self.similarity_module_kwargs)) + def _init_torch_model(self, item_model: ItemNetBase) -> TransformerTorchBackbone: pos_encoding_layer = self._init_pos_encoding_layer() transformer_layers = self._init_transformer_layers() + similarity_module = self._init_similarity_module() return TransformerTorchBackbone( n_heads=self.n_heads, dropout_rate=self.dropout_rate, item_model=item_model, pos_encoding_layer=pos_encoding_layer, transformer_layers=transformer_layers, + similarity_module=similarity_module, use_causal_attn=self.use_causal_attn, use_key_padding_mask=self.use_key_padding_mask, ) diff --git a/rectools/models/nn/transformers/bert4rec.py b/rectools/models/nn/transformers/bert4rec.py index 71675ebd..c2f2d814 100644 --- a/rectools/models/nn/transformers/bert4rec.py +++ b/rectools/models/nn/transformers/bert4rec.py @@ -44,6 +44,7 @@ PreLNTransformerLayers, TransformerLayersBase, ) +from .similarity import DistanceSimilarityModule, SimilarityModuleBase class BERT4RecDataPreparator(TransformerDataPreparatorBase): @@ -256,6 +257,8 @@ class BERT4RecModel(TransformerModelBase[BERT4RecModelConfig]): Type of data preparator used for dataset processing and dataloader creation. lightning_module_type : type(TransformerLightningModuleBase), default `TransformerLightningModule` Type of lightning module defining training procedure. + similarity_module_type : type(SimilarityModuleBase), default `DistanceSimilarityModule` + Type of similarity module. get_val_mask_func : Callable, default ``None`` Function to get validation mask. get_trainer_func : Callable, default ``None`` @@ -289,6 +292,9 @@ class BERT4RecModel(TransformerModelBase[BERT4RecModelConfig]): lightning_module_kwargs: optional(dict), default ``None`` Additional keyword arguments to pass during `lightning_module_type` initialization. Make sure all dict values have JSON serializable types. 
+ similarity_module_kwargs: optional(dict), default ``None`` + Additional keyword arguments to pass during `similarity_module_type` initialization. + Make sure all dict values have JSON serializable types. """ config_class = BERT4RecModelConfig @@ -320,6 +326,7 @@ def __init__( # pylint: disable=too-many-arguments, too-many-locals transformer_layers_type: tp.Type[TransformerLayersBase] = PreLNTransformerLayers, data_preparator_type: tp.Type[TransformerDataPreparatorBase] = BERT4RecDataPreparator, lightning_module_type: tp.Type[TransformerLightningModuleBase] = TransformerLightningModule, + similarity_module_type: tp.Type[SimilarityModuleBase] = DistanceSimilarityModule, get_val_mask_func: tp.Optional[ValMaskCallable] = None, get_trainer_func: tp.Optional[TrainerCallable] = None, recommend_batch_size: int = 256, @@ -332,6 +339,7 @@ def __init__( # pylint: disable=too-many-arguments, too-many-locals item_net_constructor_kwargs: tp.Optional[InitKwargs] = None, pos_encoding_kwargs: tp.Optional[InitKwargs] = None, lightning_module_kwargs: tp.Optional[InitKwargs] = None, + similarity_module_kwargs: tp.Optional[InitKwargs] = None, ): self.mask_prob = mask_prob @@ -360,6 +368,7 @@ def __init__( # pylint: disable=too-many-arguments, too-many-locals recommend_n_threads=recommend_n_threads, recommend_use_torch_ranking=recommend_use_torch_ranking, train_min_user_interactions=train_min_user_interactions, + similarity_module_type=similarity_module_type, item_net_block_types=item_net_block_types, item_net_constructor_type=item_net_constructor_type, pos_encoding_type=pos_encoding_type, @@ -372,6 +381,7 @@ def __init__( # pylint: disable=too-many-arguments, too-many-locals item_net_constructor_kwargs=item_net_constructor_kwargs, pos_encoding_kwargs=pos_encoding_kwargs, lightning_module_kwargs=lightning_module_kwargs, + similarity_module_kwargs=similarity_module_kwargs, ) def _init_data_preparator(self) -> None: diff --git a/rectools/models/nn/transformers/data_preparator.py b/rectools/models/nn/transformers/data_preparator.py index 2b2a899e..275396db 100644 --- a/rectools/models/nn/transformers/data_preparator.py +++ b/rectools/models/nn/transformers/data_preparator.py @@ -168,6 +168,18 @@ def _process_features_for_id_map( full_feature_values = np.vstack([extra_token_feature_values, sorted_features.values]) return DenseFeatures.from_iterables(values=full_feature_values, names=raw_features.names) + def _filter_train_interactions(self, train_interactions: pd.DataFrame) -> pd.DataFrame: + """Filter train interactions.""" + user_stats = train_interactions[Columns.User].value_counts() + users = user_stats[user_stats >= self.train_min_user_interactions].index + train_interactions = train_interactions[(train_interactions[Columns.User].isin(users))] + train_interactions = ( + train_interactions.sort_values(Columns.Datetime, kind="stable") + .groupby(Columns.User, sort=False) + .tail(self.session_max_len + self.train_session_max_len_addition) + ) + return train_interactions + def process_dataset_train(self, dataset: Dataset) -> None: """Process train dataset and save data.""" raw_interactions = dataset.get_raw_interactions() @@ -179,14 +191,7 @@ def process_dataset_train(self, dataset: Dataset) -> None: interactions = raw_interactions[~val_mask] # Filter train interactions - user_stats = interactions[Columns.User].value_counts() - users = user_stats[user_stats >= self.train_min_user_interactions].index - interactions = interactions[(interactions[Columns.User].isin(users))] - interactions = ( - 
interactions.sort_values(Columns.Datetime, kind="stable") - .groupby(Columns.User, sort=False) - .tail(self.session_max_len + self.train_session_max_len_addition) - ) + interactions = self._filter_train_interactions(interactions) # Prepare id maps user_id_map = IdMap.from_values(interactions[Columns.User].values) diff --git a/rectools/models/nn/transformers/lightning.py b/rectools/models/nn/transformers/lightning.py index 05e363fc..5bfb660d 100644 --- a/rectools/models/nn/transformers/lightning.py +++ b/rectools/models/nn/transformers/lightning.py @@ -15,7 +15,6 @@ import typing as tp from collections.abc import Hashable -import numpy as np import torch from pytorch_lightning import LightningModule from torch.utils.data import DataLoader @@ -57,6 +56,9 @@ class TransformerLightningModuleBase(LightningModule): # pylint: disable=too-ma Name of the training loss. """ + u2i_dist_available = [Distance.DOT, Distance.COSINE] + epsilon_cosine_dist = 1e-8 + def __init__( self, torch_model: TransformerTorchBackbone, @@ -83,6 +85,8 @@ def __init__( self.data_preparator = data_preparator self.lr = lr self.loss = loss + self.loss_calculator = self.get_loss_calculator() + self._requires_negatives = self.requires_negatives(loss) self.adam_betas = adam_betas self.gbce_t = gbce_t self.verbose = verbose @@ -92,6 +96,95 @@ def __init__( self.save_hyperparameters(ignore=["torch_model", "data_preparator"]) + @staticmethod + def requires_negatives(loss: str) -> tp.Optional[bool]: + """Return flag for determining the need for negatives.""" + if loss == "softmax": + return False + + if loss in ["BCE", "gBCE"]: + return True + + return None + + def get_loss_calculator( + self, + ) -> tp.Optional[tp.Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor]]: + """Return loss calculator.""" + if self.loss == "softmax": + return self._calc_softmax_loss + + if self.loss == "BCE": + return self._calc_bce_loss + + if self.loss == "gBCE": + return self._calc_gbce_loss + + return None + + @classmethod + def _calc_softmax_loss(cls, logits: torch.Tensor, y: torch.Tensor, w: torch.Tensor) -> torch.Tensor: + # We are using CrossEntropyLoss with a multi-dimensional case + + # Logits must be passed in form of [batch_size, n_items + n_item_extra_tokens, session_max_len], + # where n_items + n_item_extra_tokens is number of classes + + # Target label indexes must be passed in a form of [batch_size, session_max_len] + # (`0` index for "PAD" ix excluded from loss) + + # Loss output will have a shape of [batch_size, session_max_len] + # and will have zeros for every `0` target label + loss = torch.nn.functional.cross_entropy( + logits.transpose(1, 2), y, ignore_index=0, reduction="none" + ) # [batch_size, session_max_len] + loss = loss * w + n = (loss > 0).to(loss.dtype) + loss = torch.sum(loss) / torch.sum(n) + return loss + + def _get_reduced_overconfidence_logits(self, logits: torch.Tensor, n_items: int) -> torch.Tensor: + # https://arxiv.org/pdf/2308.07192.pdf + + dtype = torch.float64 # for consistency with the original implementation + n_negatives = self.data_preparator.n_negatives + if n_negatives is not None: + alpha = n_negatives / (n_items - 1) # sampling rate + else: + raise ValueError( + "`n_negatives` is not defined. Please ensure that `n_negatives` is set." 
+            )  # pragma: no cover
+        beta = alpha * (self.gbce_t * (1 - 1 / alpha) + 1 / alpha)
+
+        pos_logits = logits[:, :, 0:1].to(dtype)
+        neg_logits = logits[:, :, 1:].to(dtype)
+
+        epsilon = 1e-10
+        pos_probs = torch.clamp(torch.sigmoid(pos_logits), epsilon, 1 - epsilon)
+        pos_probs_adjusted = torch.clamp(pos_probs.pow(-beta), 1 + epsilon, torch.finfo(dtype).max)
+        pos_probs_adjusted = torch.clamp(torch.div(1, (pos_probs_adjusted - 1)), epsilon, torch.finfo(dtype).max)
+        pos_logits_transformed = torch.log(pos_probs_adjusted)
+        logits = torch.cat([pos_logits_transformed, neg_logits], dim=-1)
+        return logits
+
+    @classmethod
+    def _calc_bce_loss(cls, logits: torch.Tensor, y: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
+        mask = y != 0
+        target = torch.zeros_like(logits)
+        target[:, :, 0] = 1
+
+        loss = torch.nn.functional.binary_cross_entropy_with_logits(
+            logits, target, reduction="none"
+        )  # [batch_size, session_max_len, n_negatives + 1]
+        loss = loss.mean(-1) * mask * w  # [batch_size, session_max_len]
+        loss = torch.sum(loss) / torch.sum(mask)
+        return loss
+
+    def _calc_gbce_loss(self, logits: torch.Tensor, y: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
+        n_actual_items = self.torch_model.item_model.n_items - len(self.item_extra_tokens)
+        logits = self._get_reduced_overconfidence_logits(logits, n_actual_items)
+        loss = self._calc_bce_loss(logits, y, w)
+        return loss
+
     def configure_optimizers(self) -> torch.optim.Adam:
         """Choose what optimizers and learning-rate schedulers to use in optimization"""
         optimizer = torch.optim.Adam(self.torch_model.parameters(), lr=self.lr, betas=self.adam_betas)
@@ -145,25 +238,27 @@ def on_train_start(self) -> None:
         """Initialize parameters with values from Xavier normal distribution."""
         self._xavier_normal_init()
 
+    def get_batch_logits(self, batch: tp.Dict[str, torch.Tensor]) -> torch.Tensor:
+        """Get batch logits."""
+        x = batch["x"]  # x: [batch_size, session_max_len]
+        if self._requires_negatives:
+            y, negatives = batch["y"], batch["negatives"]
+            pos_neg = torch.cat([y.unsqueeze(-1), negatives], dim=-1)
+            logits = self.torch_model(sessions=x, item_ids=pos_neg)
+        else:
+            logits = self.torch_model(sessions=x)
+        return logits
+
     def training_step(self, batch: tp.Dict[str, torch.Tensor], batch_idx: int) -> torch.Tensor:
         """Training step."""
-        x, y, w = batch["x"], batch["y"], batch["yw"]
-        if self.loss == "softmax":
-            logits = self._get_full_catalog_logits(x)
-            loss = self._calc_softmax_loss(logits, y, w)
-        elif self.loss == "BCE":
-            negatives = batch["negatives"]
-            logits = self._get_pos_neg_logits(x, y, negatives)
-            loss = self._calc_bce_loss(logits, y, w)
-        elif self.loss == "gBCE":
-            negatives = batch["negatives"]
-            logits = self._get_pos_neg_logits(x, y, negatives)
-            loss = self._calc_gbce_loss(logits, y, w, negatives)
+        if self.loss_calculator is not None:
+            y, w = batch["y"], batch["yw"]
+            logits = self.get_batch_logits(batch)
+            loss = self.loss_calculator(logits, y, w)
         else:
             loss = self._calc_custom_loss(batch, batch_idx)
 
         self.log(self.train_loss_name, loss, on_step=False, on_epoch=True, prog_bar=self.verbose > 0)
-
         return loss
 
     def _calc_custom_loss(self, batch: tp.Dict[str, torch.Tensor], batch_idx: int) -> torch.Tensor:
@@ -182,25 +277,18 @@ def on_validation_end(self) -> None:
 
     def validation_step(self, batch: tp.Dict[str, torch.Tensor], batch_idx: int) -> tp.Dict[str, torch.Tensor]:
         """Validate step."""
-        # x: [batch_size, session_max_len]
-        # y: [batch_size, 1]
-        # yw: [batch_size, 1]
-        x, y, w = batch["x"], batch["y"], batch["yw"]
-        outputs =
{} - if self.loss == "softmax": - logits = self._get_full_catalog_logits(x)[:, -1:, :] - outputs["loss"] = self._calc_softmax_loss(logits, y, w) - outputs["logits"] = logits.squeeze() - elif self.loss == "BCE": - negatives = batch["negatives"] - pos_neg_logits = self._get_pos_neg_logits(x, y, negatives)[:, -1:, :] - outputs["loss"] = self._calc_bce_loss(pos_neg_logits, y, w) - outputs["pos_neg_logits"] = pos_neg_logits.squeeze() - elif self.loss == "gBCE": - negatives = batch["negatives"] - pos_neg_logits = self._get_pos_neg_logits(x, y, negatives)[:, -1:, :] - outputs["loss"] = self._calc_gbce_loss(pos_neg_logits, y, w, negatives) - outputs["pos_neg_logits"] = pos_neg_logits.squeeze() + if self.loss_calculator is not None: + # y: [batch_size, 1] + # yw: [batch_size, 1] + y, w = batch["y"], batch["yw"] + logits = self.get_batch_logits(batch) + logits = logits[:, -1:, :] + loss = self.loss_calculator(logits, y, w) + type_logits = "pos_neg_logits" if self._requires_negatives else "logits" + outputs = { + "loss": loss, + type_logits: logits, + } else: outputs = self._calc_custom_loss_outputs(batch, batch_idx) # pragma: no cover @@ -212,80 +300,6 @@ def _calc_custom_loss_outputs( ) -> tp.Dict[str, torch.Tensor]: raise ValueError(f"loss {self.loss} is not supported") # pragma: no cover - def _get_full_catalog_logits(self, x: torch.Tensor) -> torch.Tensor: - item_embs, session_embs = self.torch_model(x) - logits = session_embs @ item_embs.T - return logits - - def _get_pos_neg_logits(self, x: torch.Tensor, y: torch.Tensor, negatives: torch.Tensor) -> torch.Tensor: - # [n_items + n_item_extra_tokens, n_factors], [batch_size, session_max_len, n_factors] - item_embs, session_embs = self.torch_model(x) - pos_neg = torch.cat([y.unsqueeze(-1), negatives], dim=-1) # [batch_size, session_max_len, n_negatives + 1] - pos_neg_embs = item_embs[pos_neg] # [batch_size, session_max_len, n_negatives + 1, n_factors] - # [batch_size, session_max_len, n_negatives + 1] - logits = (pos_neg_embs @ session_embs.unsqueeze(-1)).squeeze(-1) - return logits - - def _get_reduced_overconfidence_logits(self, logits: torch.Tensor, n_items: int, n_negatives: int) -> torch.Tensor: - # https://arxiv.org/pdf/2308.07192.pdf - - dtype = torch.float64 # for consistency with the original implementation - alpha = n_negatives / (n_items - 1) # sampling rate - beta = alpha * (self.gbce_t * (1 - 1 / alpha) + 1 / alpha) - - pos_logits = logits[:, :, 0:1].to(dtype) - neg_logits = logits[:, :, 1:].to(dtype) - - epsilon = 1e-10 - pos_probs = torch.clamp(torch.sigmoid(pos_logits), epsilon, 1 - epsilon) - pos_probs_adjusted = torch.clamp(pos_probs.pow(-beta), 1 + epsilon, torch.finfo(dtype).max) - pos_probs_adjusted = torch.clamp(torch.div(1, (pos_probs_adjusted - 1)), epsilon, torch.finfo(dtype).max) - pos_logits_transformed = torch.log(pos_probs_adjusted) - logits = torch.cat([pos_logits_transformed, neg_logits], dim=-1) - return logits - - @classmethod - def _calc_softmax_loss(cls, logits: torch.Tensor, y: torch.Tensor, w: torch.Tensor) -> torch.Tensor: - # We are using CrossEntropyLoss with a multi-dimensional case - - # Logits must be passed in form of [batch_size, n_items + n_item_extra_tokens, session_max_len], - # where n_items + n_item_extra_tokens is number of classes - - # Target label indexes must be passed in a form of [batch_size, session_max_len] - # (`0` index for "PAD" ix excluded from loss) - - # Loss output will have a shape of [batch_size, session_max_len] - # and will have zeros for every `0` target label - loss = 
torch.nn.functional.cross_entropy( - logits.transpose(1, 2), y, ignore_index=0, reduction="none" - ) # [batch_size, session_max_len] - loss = loss * w - n = (loss > 0).to(loss.dtype) - loss = torch.sum(loss) / torch.sum(n) - return loss - - @classmethod - def _calc_bce_loss(cls, logits: torch.Tensor, y: torch.Tensor, w: torch.Tensor) -> torch.Tensor: - mask = y != 0 - target = torch.zeros_like(logits) - target[:, :, 0] = 1 - - loss = torch.nn.functional.binary_cross_entropy_with_logits( - logits, target, reduction="none" - ) # [batch_size, session_max_len, n_negatives + 1] - loss = loss.mean(-1) * mask * w # [batch_size, session_max_len] - loss = torch.sum(loss) / torch.sum(mask) - return loss - - def _calc_gbce_loss( - self, logits: torch.Tensor, y: torch.Tensor, w: torch.Tensor, negatives: torch.Tensor - ) -> torch.Tensor: - n_actual_items = self.torch_model.item_model.n_items - len(self.item_extra_tokens) - n_negatives = negatives.shape[2] - logits = self._get_reduced_overconfidence_logits(logits, n_actual_items, n_negatives) - loss = self._calc_bce_loss(logits, y, w) - return loss - def _xavier_normal_init(self) -> None: for _, param in self.torch_model.named_parameters(): if param.data.dim() > 1: @@ -336,20 +350,16 @@ def _recommend_u2i( user_embs, item_embs = self._get_user_item_embeddings(recommend_dataloader, torch_device) - ranker = TorchRanker( - distance=Distance.DOT, - device=item_embs.device, - subjects_factors=user_embs[user_ids], - objects_factors=item_embs, - ) - - user_ids_indices, all_reco_ids, all_scores = ranker.rank( - subject_ids=np.arange(len(user_ids)), # n_rec_users - k=k, - filter_pairs_csr=ui_csr_for_filter, # [n_rec_users x n_items + n_item_extra_tokens] - sorted_object_whitelist=sorted_item_ids_to_recommend, # model_internal + all_user_ids, all_reco_ids, all_scores = ( + self.torch_model.similarity_module._recommend_u2i( # pylint: disable=protected-access + user_embs=user_embs, + item_embs=item_embs, + user_ids=user_ids, + k=k, + sorted_item_ids_to_recommend=sorted_item_ids_to_recommend, + ui_csr_for_filter=ui_csr_for_filter, + ) ) - all_user_ids = user_ids[user_ids_indices] return all_user_ids, all_reco_ids, all_scores def _recommend_i2i( diff --git a/rectools/models/nn/transformers/sasrec.py b/rectools/models/nn/transformers/sasrec.py index 343c9c7c..4bc36907 100644 --- a/rectools/models/nn/transformers/sasrec.py +++ b/rectools/models/nn/transformers/sasrec.py @@ -44,6 +44,7 @@ PositionalEncodingBase, TransformerLayersBase, ) +from .similarity import DistanceSimilarityModule, SimilarityModuleBase class SASRecDataPreparator(TransformerDataPreparatorBase): @@ -336,6 +337,8 @@ class SASRecModel(TransformerModelBase[SASRecModelConfig]): Type of data preparator used for dataset processing and dataloader creation. lightning_module_type : type(TransformerLightningModuleBase), default `TransformerLightningModule` Type of lightning module defining training procedure. + similarity_module_type : type(SimilarityModuleBase), default `DistanceSimilarityModule` + Type of similarity module. get_val_mask_func : Callable, default ``None`` Function to get validation mask. get_trainer_func : Callable, default ``None`` @@ -369,6 +372,9 @@ class SASRecModel(TransformerModelBase[SASRecModelConfig]): lightning_module_kwargs: optional(dict), default ``None`` Additional keyword arguments to pass during `lightning_module_type` initialization. Make sure all dict values have JSON serializable types. 
+ similarity_module_kwargs: optional(dict), default ``None`` + Additional keyword arguments to pass during `similarity_module_type` initialization. + Make sure all dict values have JSON serializable types. """ config_class = SASRecModelConfig @@ -399,6 +405,7 @@ def __init__( # pylint: disable=too-many-arguments, too-many-locals transformer_layers_type: tp.Type[TransformerLayersBase] = SASRecTransformerLayers, # SASRec authors net data_preparator_type: tp.Type[TransformerDataPreparatorBase] = SASRecDataPreparator, lightning_module_type: tp.Type[TransformerLightningModuleBase] = TransformerLightningModule, + similarity_module_type: tp.Type[SimilarityModuleBase] = DistanceSimilarityModule, get_val_mask_func: tp.Optional[ValMaskCallable] = None, get_trainer_func: tp.Optional[TrainerCallable] = None, recommend_batch_size: int = 256, @@ -410,6 +417,7 @@ def __init__( # pylint: disable=too-many-arguments, too-many-locals item_net_constructor_kwargs: tp.Optional[InitKwargs] = None, pos_encoding_kwargs: tp.Optional[InitKwargs] = None, lightning_module_kwargs: tp.Optional[InitKwargs] = None, + similarity_module_kwargs: tp.Optional[InitKwargs] = None, ): super().__init__( transformer_layers_type=transformer_layers_type, @@ -436,6 +444,7 @@ def __init__( # pylint: disable=too-many-arguments, too-many-locals recommend_n_threads=recommend_n_threads, recommend_use_torch_ranking=recommend_use_torch_ranking, train_min_user_interactions=train_min_user_interactions, + similarity_module_type=similarity_module_type, item_net_block_types=item_net_block_types, item_net_constructor_type=item_net_constructor_type, pos_encoding_type=pos_encoding_type, @@ -447,4 +456,5 @@ def __init__( # pylint: disable=too-many-arguments, too-many-locals item_net_constructor_kwargs=item_net_constructor_kwargs, pos_encoding_kwargs=pos_encoding_kwargs, lightning_module_kwargs=lightning_module_kwargs, + similarity_module_kwargs=similarity_module_kwargs, ) diff --git a/rectools/models/nn/transformers/similarity.py b/rectools/models/nn/transformers/similarity.py new file mode 100644 index 00000000..ec37ba14 --- /dev/null +++ b/rectools/models/nn/transformers/similarity.py @@ -0,0 +1,121 @@ +# Copyright 2025 MTS (Mobile Telesystems) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
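Editor's note: the plumbing above exposes two new knobs on every transformer-based model. A minimal usage sketch, mirroring the updated tests further down in this diff (`DistanceSimilarityModule` and its `distance` kwarg come from the new `similarity.py` module that follows):

```python
from rectools.models import SASRecModel
from rectools.models.nn.transformers.similarity import DistanceSimilarityModule

# Both arguments are optional: DistanceSimilarityModule with dot-product
# similarity is the default for SASRecModel and BERT4RecModel.
model = SASRecModel(
    similarity_module_type=DistanceSimilarityModule,
    similarity_module_kwargs={"distance": "cosine"},
)
```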
+
+import typing as tp
+
+import numpy as np
+import torch
+from scipy import sparse
+
+from rectools.models.base import InternalRecoTriplet
+from rectools.models.rank import Distance, TorchRanker
+from rectools.types import InternalIdsArray
+
+
+class SimilarityModuleBase(torch.nn.Module):
+    """Similarity module base."""
+
+    def _get_full_catalog_logits(self, session_embs: torch.Tensor, item_embs: torch.Tensor) -> torch.Tensor:
+        raise NotImplementedError()
+
+    def _get_pos_neg_logits(
+        self, session_embs: torch.Tensor, item_embs: torch.Tensor, item_ids: torch.Tensor
+    ) -> torch.Tensor:
+        raise NotImplementedError()
+
+    def forward(
+        self, session_embs: torch.Tensor, item_embs: torch.Tensor, item_ids: tp.Optional[torch.Tensor] = None
+    ) -> torch.Tensor:
+        """Forward pass to get logits."""
+        raise NotImplementedError()
+
+    def _recommend_u2i(
+        self,
+        user_embs: torch.Tensor,
+        item_embs: torch.Tensor,
+        user_ids: InternalIdsArray,
+        k: int,
+        sorted_item_ids_to_recommend: InternalIdsArray,
+        ui_csr_for_filter: tp.Optional[sparse.csr_matrix],
+    ) -> InternalRecoTriplet:
+        """Recommend to users."""
+        raise NotImplementedError()
+
+
+class DistanceSimilarityModule(SimilarityModuleBase):
+    """Distance similarity module."""
+
+    dist_available: tp.List[str] = [Distance.DOT, Distance.COSINE]
+    epsilon_cosine_dist: torch.Tensor = torch.tensor([1e-8])
+
+    def __init__(self, distance: str = "dot") -> None:
+        super().__init__()
+        if distance not in self.dist_available:
+            raise ValueError("`distance` can only be either `dot` or `cosine`.")
+
+        self.distance = Distance(distance)
+
+    def _get_full_catalog_logits(self, session_embs: torch.Tensor, item_embs: torch.Tensor) -> torch.Tensor:
+        logits = session_embs @ item_embs.T
+        return logits
+
+    def _get_pos_neg_logits(
+        self, session_embs: torch.Tensor, item_embs: torch.Tensor, item_ids: torch.Tensor
+    ) -> torch.Tensor:
+        pos_neg_embs = item_embs[item_ids]  # [batch_size, session_max_len, len(item_ids), n_factors]
+        # [batch_size, session_max_len, len(item_ids)]
+        logits = (pos_neg_embs @ session_embs.unsqueeze(-1)).squeeze(-1)
+        return logits
+
+    def _get_embeddings_norm(self, embeddings: torch.Tensor) -> torch.Tensor:
+        # Normalize over the embedding (last) dimension so that both 2d item
+        # tensors and 3d session tensors are handled correctly.
+        embedding_norm = torch.norm(embeddings, p=2, dim=-1).unsqueeze(dim=-1)
+        embeddings = embeddings / torch.max(embedding_norm, self.epsilon_cosine_dist.to(embeddings))
+        return embeddings
+
+    def forward(
+        self, session_embs: torch.Tensor, item_embs: torch.Tensor, item_ids: tp.Optional[torch.Tensor] = None
+    ) -> torch.Tensor:
+        """Forward pass to get logits."""
+        if self.distance == Distance.COSINE:
+            session_embs = self._get_embeddings_norm(session_embs)
+            item_embs = self._get_embeddings_norm(item_embs)
+
+        if item_ids is None:
+            return self._get_full_catalog_logits(session_embs, item_embs)
+        return self._get_pos_neg_logits(session_embs, item_embs, item_ids)
+
+    def _recommend_u2i(
+        self,
+        user_embs: torch.Tensor,
+        item_embs: torch.Tensor,
+        user_ids: InternalIdsArray,
+        k: int,
+        sorted_item_ids_to_recommend: InternalIdsArray,
+        ui_csr_for_filter: tp.Optional[sparse.csr_matrix],
+    ) -> InternalRecoTriplet:
+        """Recommend to users."""
+        ranker = TorchRanker(
+            distance=self.distance,
+            device=item_embs.device,
+            subjects_factors=user_embs[user_ids],
+            objects_factors=item_embs,
+        )
+        user_ids_indices, all_reco_ids, all_scores = ranker.rank(
+            subject_ids=np.arange(len(user_ids)),  # n_rec_users
+            k=k,
+            filter_pairs_csr=ui_csr_for_filter,  # [n_rec_users x n_items + n_item_extra_tokens]
+
sorted_object_whitelist=sorted_item_ids_to_recommend, # model_internal + ) + all_user_ids = user_ids[user_ids_indices] + return all_user_ids, all_reco_ids, all_scores diff --git a/rectools/models/nn/transformers/torch_backbone.py b/rectools/models/nn/transformers/torch_backbone.py index e302ded8..ea55962f 100644 --- a/rectools/models/nn/transformers/torch_backbone.py +++ b/rectools/models/nn/transformers/torch_backbone.py @@ -18,6 +18,7 @@ from ..item_net import ItemNetBase from .net_blocks import PositionalEncodingBase, TransformerLayersBase +from .similarity import SimilarityModuleBase class TransformerTorchBackbone(torch.nn.Module): @@ -36,6 +37,8 @@ class TransformerTorchBackbone(torch.nn.Module): Positional encoding layer. transformer_layers : TransformerLayersBase Transformer layers. + similarity_module : SimilarityModuleBase + Similarity module. use_causal_attn : bool, default True If ``True``, causal mask is used in multi-head self-attention. use_key_padding_mask : bool, default False @@ -49,6 +52,7 @@ def __init__( item_model: ItemNetBase, pos_encoding_layer: PositionalEncodingBase, transformer_layers: TransformerLayersBase, + similarity_module: SimilarityModuleBase, use_causal_attn: bool = True, use_key_padding_mask: bool = False, ) -> None: @@ -58,6 +62,7 @@ def __init__( self.pos_encoding_layer = pos_encoding_layer self.emb_dropout = torch.nn.Dropout(dropout_rate) self.transformer_layers = transformer_layers + self.similarity_module = similarity_module self.use_causal_attn = use_causal_attn self.use_key_padding_mask = use_key_padding_mask self.n_heads = n_heads @@ -157,8 +162,9 @@ def encode_sessions(self, sessions: torch.Tensor, item_embs: torch.Tensor) -> to def forward( self, - sessions: torch.Tensor, # [batch_size, session_max_len] - ) -> tp.Tuple[torch.Tensor, torch.Tensor]: + sessions: torch.Tensor, # [batch_size, session_max_len], + item_ids: tp.Optional[torch.Tensor] = None, + ) -> torch.Tensor: """ Forward pass to get item and session embeddings. Get item embeddings. @@ -168,11 +174,14 @@ def forward( ---------- sessions : torch.Tensor User sessions in the form of sequences of items ids. + item_ids : optional(torch.Tensor), default ``None`` + Defined item ids for similarity calculation. 
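Editor's note: because `TransformerTorchBackbone.forward` now delegates logit computation to the injected module, an alternative scoring scheme only needs a `SimilarityModuleBase` subclass. A hypothetical sketch (the class name and `temperature` parameter are illustrative, not part of this PR; `_recommend_u2i` would also have to be overridden, since the lightning module now calls it for ranking at inference time):

```python
import typing as tp

import torch

from rectools.models.nn.transformers.similarity import SimilarityModuleBase


class TemperatureDotSimilarityModule(SimilarityModuleBase):
    """Hypothetical module: dot-product logits divided by a fixed temperature."""

    def __init__(self, temperature: float = 1.0) -> None:
        super().__init__()
        self.temperature = temperature

    def forward(
        self, session_embs: torch.Tensor, item_embs: torch.Tensor, item_ids: tp.Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """Return full-catalog logits when item_ids is None, else logits for the sampled ids."""
        if item_ids is None:
            logits = session_embs @ item_embs.T  # softmax loss path
        else:
            logits = (item_embs[item_ids] @ session_embs.unsqueeze(-1)).squeeze(-1)  # BCE/gBCE path
        return logits / self.temperature
```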
        Returns
        -------
-        (torch.Tensor, torch.Tensor)
+        torch.Tensor
         """
         item_embs = self.item_model.get_all_embeddings()  # [n_items + n_item_extra_tokens, n_factors]
         session_embs = self.encode_sessions(sessions, item_embs)  # [batch_size, session_max_len, n_factors]
-        return item_embs, session_embs
+        logits = self.similarity_module(session_embs, item_embs, item_ids)
+        return logits
diff --git a/rectools/models/rank/rank.py b/rectools/models/rank/rank.py
index ab79f80d..10c217f3 100644
--- a/rectools/models/rank/rank.py
+++ b/rectools/models/rank/rank.py
@@ -22,12 +22,12 @@
 from rectools.types import InternalIdsArray
 
 
-class Distance(Enum):
+class Distance(str, Enum):
     """Distance metric"""
 
-    DOT = 1  # Bigger value means closer vectors
-    COSINE = 2  # Bigger value means closer vectors
-    EUCLIDEAN = 3  # Smaller value means closer vectors
+    DOT = "dot"  # Bigger value means closer vectors
+    COSINE = "cosine"  # Bigger value means closer vectors
+    EUCLIDEAN = "euclidean"  # Smaller value means closer vectors
 
 
 class Ranker(tp.Protocol):
diff --git a/rectools/models/rank/rank_torch.py b/rectools/models/rank/rank_torch.py
index ed091acc..6a0c1f2c 100644
--- a/rectools/models/rank/rank_torch.py
+++ b/rectools/models/rank/rank_torch.py
@@ -54,6 +54,8 @@ class TorchRanker:
         Conversion is skipped if provided dtype is ``None``.
     """
 
+    epsilon_cosine_dist: torch.Tensor = torch.tensor([1e-8])
+
     def __init__(
         self,
         distance: Distance,
@@ -194,8 +196,12 @@ def _euclid_score(self, user_embs: torch.Tensor, item_embs: torch.Tensor) -> tor
         return torch.cdist(user_embs.unsqueeze(0), item_embs.unsqueeze(0)).squeeze(0)
 
     def _cosine_score(self, user_embs: torch.Tensor, item_embs: torch.Tensor) -> torch.Tensor:
-        user_embs = user_embs / torch.norm(user_embs, p=2, dim=1).unsqueeze(dim=1)
-        item_embs = item_embs / torch.norm(item_embs, p=2, dim=1).unsqueeze(dim=1)
+        user_embs = user_embs / torch.max(
+            torch.norm(user_embs, p=2, dim=1).unsqueeze(dim=1), self.epsilon_cosine_dist.to(user_embs)
+        )
+        item_embs = item_embs / torch.max(
+            torch.norm(item_embs, p=2, dim=1).unsqueeze(dim=1), self.epsilon_cosine_dist.to(item_embs)
+        )
         return user_embs @ item_embs.T
 
 
diff --git a/tests/models/nn/transformers/test_base.py b/tests/models/nn/transformers/test_base.py
index bc61bea5..a29d8f24 100644
--- a/tests/models/nn/transformers/test_base.py
+++ b/tests/models/nn/transformers/test_base.py
@@ -26,7 +26,6 @@
 from rectools import Columns
 from rectools.dataset import Dataset
 from rectools.models import BERT4RecModel, SASRecModel, load_model
-from rectools.models.nn.item_net import CatFeaturesItemNet, IdEmbeddingsItemNet
 from rectools.models.nn.transformers.base import TransformerModelBase
 from tests.models.data import INTERACTIONS
 from tests.models.utils import assert_save_load_do_not_change_model
@@ -109,10 +108,7 @@ def test_save_load_for_unfitted_model(
         dataset: Dataset,
         default_trainer: bool,
     ) -> None:
-        config = {
-            "deterministic": True,
-            "item_net_block_types": (IdEmbeddingsItemNet, CatFeaturesItemNet),
-        }
+        config: tp.Dict[str, tp.Any] = {"deterministic": True}
         if not default_trainer:
             config["get_trainer_func"] = custom_trainer
         model = model_cls.from_config(config)
@@ -147,10 +143,7 @@ def test_save_load_for_fitted_model(
         dataset_item_features: Dataset,
         default_trainer: bool,
     ) -> None:
-        config = {
-            "deterministic": True,
-            "item_net_block_types": (IdEmbeddingsItemNet, CatFeaturesItemNet),
-        }
+        config: tp.Dict[str, tp.Any] = {"deterministic": True}
         if not default_trainer:
             config["get_trainer_func"] = custom_trainer
         model =
model_cls.from_config(config) @@ -169,7 +162,6 @@ def test_load_from_checkpoint( model = model_cls.from_config( { "deterministic": True, - "item_net_block_types": (IdEmbeddingsItemNet, CatFeaturesItemNet), "get_trainer_func": custom_trainer_ckpt, } ) @@ -195,7 +187,6 @@ def test_raises_when_save_model_loaded_from_checkpoint( model = model_cls.from_config( { "deterministic": True, - "item_net_block_types": (IdEmbeddingsItemNet, CatFeaturesItemNet), "get_trainer_func": custom_trainer_ckpt, } ) @@ -219,7 +210,6 @@ def test_load_weights_from_checkpoint( model = model_cls.from_config( { "deterministic": True, - "item_net_block_types": (IdEmbeddingsItemNet, CatFeaturesItemNet), "get_trainer_func": custom_trainer_multiple_ckpt, } ) @@ -244,7 +234,6 @@ def test_raises_when_load_weights_from_checkpoint_not_fitted_model( model = model_cls.from_config( { "deterministic": True, - "item_net_block_types": (IdEmbeddingsItemNet, CatFeaturesItemNet), "get_trainer_func": custom_trainer_ckpt, } ) @@ -257,7 +246,6 @@ def test_raises_when_load_weights_from_checkpoint_not_fitted_model( model_unfitted = model_cls.from_config( { "deterministic": True, - "item_net_block_types": (IdEmbeddingsItemNet, CatFeaturesItemNet), "get_trainer_func": custom_trainer_ckpt, } ) @@ -317,3 +305,14 @@ def test_log_metrics( actual_columns = list(pd.read_csv(metrics_path).columns) assert actual_columns == expected_columns + + @pytest.mark.parametrize("model_cls", (SASRecModel, BERT4RecModel)) + def test_raises_when_incorrect_similarity_dist( + self, model_cls: tp.Type[TransformerModelBase], dataset: Dataset + ) -> None: + model_config = { + "similarity_module_kwargs": {"distance": "euclidean"}, + } + with pytest.raises(ValueError): + model = model_cls.from_config(model_config) + model.fit(dataset=dataset) diff --git a/tests/models/nn/transformers/test_bert4rec.py b/tests/models/nn/transformers/test_bert4rec.py index 62a73d83..a8121c3d 100644 --- a/tests/models/nn/transformers/test_bert4rec.py +++ b/tests/models/nn/transformers/test_bert4rec.py @@ -33,6 +33,7 @@ TransformerLightningModule, ) from rectools.models.nn.transformers.bert4rec import MASKING_VALUE, BERT4RecDataPreparator, ValMaskCallable +from rectools.models.nn.transformers.similarity import DistanceSimilarityModule from tests.models.data import DATASET from tests.models.utils import ( assert_default_config_and_default_model_params_are_the_same, @@ -212,6 +213,7 @@ def get_trainer() -> Trainer: ), ), ) + @pytest.mark.parametrize("u2i_dist", ("dot", "cosine")) def test_u2i( self, dataset_devices: Dataset, @@ -223,6 +225,7 @@ def test_u2i( expected_cpu_2: pd.DataFrame, expected_gpu_1: pd.DataFrame, expected_gpu_2: pd.DataFrame, + u2i_dist: str, ) -> None: if n_devices != 1: pytest.skip("DEBUG: skipping multi-device tests") @@ -249,6 +252,8 @@ def get_trainer() -> Trainer: recommend_torch_device=recommend_torch_device, item_net_block_types=(IdEmbeddingsItemNet,), get_trainer_func=get_trainer, + similarity_module_type=DistanceSimilarityModule, + similarity_module_kwargs={"distance": u2i_dist}, ) model.fit(dataset=dataset_devices) users = np.array([10, 30, 40]) @@ -292,12 +297,14 @@ def get_trainer() -> Trainer: ), ), ) + @pytest.mark.parametrize("u2i_dist", ("dot", "cosine")) def test_u2i_losses( self, dataset_devices: Dataset, loss: str, get_trainer_func: TrainerCallable, expected: pd.DataFrame, + u2i_dist: str, ) -> None: model = BERT4RecModel( n_negatives=2, @@ -313,6 +320,8 @@ def test_u2i_losses( item_net_block_types=(IdEmbeddingsItemNet,), 
get_trainer_func=get_trainer_func, loss=loss, + similarity_module_type=DistanceSimilarityModule, + similarity_module_kwargs={"distance": u2i_dist}, ) model.fit(dataset=dataset_devices) users = np.array([10, 30, 40]) @@ -366,6 +375,7 @@ def test_with_whitelist( deterministic=True, item_net_block_types=(IdEmbeddingsItemNet,), get_trainer_func=get_trainer_func, + similarity_module_type=DistanceSimilarityModule, ) model.fit(dataset=dataset_devices) users = np.array([10, 30, 40]) @@ -440,6 +450,7 @@ def test_i2i( deterministic=True, item_net_block_types=(IdEmbeddingsItemNet,), get_trainer_func=get_trainer_func, + similarity_module_type=DistanceSimilarityModule, ) model.fit(dataset=dataset) target_items = np.array([12, 14, 17]) @@ -466,6 +477,7 @@ def test_second_fit_refits_model(self, dataset_hot_users_items: Dataset) -> None deterministic=True, item_net_block_types=(IdEmbeddingsItemNet,), get_trainer_func=custom_trainer, + similarity_module_type=DistanceSimilarityModule, ) assert_second_fit_refits_model(model, dataset_hot_users_items, pre_fit_callback=self._seed_everything) @@ -508,6 +520,7 @@ def test_recommend_for_cold_user_with_hot_item( deterministic=True, item_net_block_types=(IdEmbeddingsItemNet,), get_trainer_func=get_trainer_func, + similarity_module_type=DistanceSimilarityModule, ) model.fit(dataset=dataset_devices) users = np.array([20]) @@ -528,7 +541,7 @@ class NextActionDataPreparator(BERT4RecDataPreparator): def __init__( self, session_max_len: int, - n_negatives: tp.Optional[int], + n_negatives: int, batch_size: int, dataloader_num_workers: int, train_min_user_interactions: int, @@ -585,6 +598,7 @@ def _collate_fn_train( get_trainer_func=get_trainer_func, data_preparator_type=NextActionDataPreparator, data_preparator_kwargs={"n_last_targets": 1}, + similarity_module_type=DistanceSimilarityModule, ) model.fit(dataset=dataset_devices) @@ -829,6 +843,7 @@ def initial_config(self) -> tp.Dict[str, tp.Any]: "transformer_layers_type": PreLNTransformerLayers, "data_preparator_type": BERT4RecDataPreparator, "lightning_module_type": TransformerLightningModule, + "similarity_module_type": DistanceSimilarityModule, "mask_prob": 0.15, "get_val_mask_func": leave_one_out_mask, "get_trainer_func": None, @@ -837,6 +852,7 @@ def initial_config(self) -> tp.Dict[str, tp.Any]: "item_net_constructor_kwargs": None, "pos_encoding_kwargs": None, "lightning_module_kwargs": None, + "similarity_module_kwargs": None, } return config @@ -876,6 +892,7 @@ def test_get_config( "data_preparator_type": "rectools.models.nn.transformers.bert4rec.BERT4RecDataPreparator", "lightning_module_type": "rectools.models.nn.transformers.lightning.TransformerLightningModule", "get_val_mask_func": "tests.models.nn.transformers.utils.leave_one_out_mask", + "similarity_module_type": "rectools.models.nn.transformers.similarity.DistanceSimilarityModule", } expected.update(simple_types_params) if use_custom_trainer: diff --git a/tests/models/nn/transformers/test_sasrec.py b/tests/models/nn/transformers/test_sasrec.py index 58442de3..189d4e53 100644 --- a/tests/models/nn/transformers/test_sasrec.py +++ b/tests/models/nn/transformers/test_sasrec.py @@ -35,6 +35,7 @@ TransformerTorchBackbone, ) from rectools.models.nn.transformers.sasrec import SASRecDataPreparator, SASRecTransformerLayers +from rectools.models.nn.transformers.similarity import DistanceSimilarityModule from tests.models.data import DATASET from tests.models.utils import ( assert_default_config_and_default_model_params_are_the_same, @@ -243,6 +244,7 @@ def 
get_trainer() -> Trainer: ), ), ) + @pytest.mark.parametrize("u2i_dist", ("dot", "cosine")) def test_u2i( self, dataset_devices: Dataset, @@ -253,6 +255,7 @@ def test_u2i( expected_cpu_1: pd.DataFrame, expected_cpu_2: pd.DataFrame, expected_gpu: pd.DataFrame, + u2i_dist: str, ) -> None: if devices != 1: @@ -280,6 +283,8 @@ def get_trainer() -> Trainer: recommend_torch_device=recommend_torch_device, item_net_block_types=(IdEmbeddingsItemNet,), get_trainer_func=get_trainer, + similarity_module_type=DistanceSimilarityModule, + similarity_module_kwargs={"distance": u2i_dist}, ) model.fit(dataset=dataset_devices) users = np.array([10, 30, 40]) @@ -297,7 +302,7 @@ def get_trainer() -> Trainer: ) @pytest.mark.parametrize( - "loss,expected", + "loss,expected,u2i_dist", ( ( "BCE", @@ -308,6 +313,7 @@ def get_trainer() -> Trainer: Columns.Rank: [1, 2, 1, 2, 3, 1, 2, 3], } ), + "dot", ), ( "gBCE", @@ -318,6 +324,29 @@ def get_trainer() -> Trainer: Columns.Rank: [1, 2, 1, 2, 3, 1, 2, 3], } ), + "dot", + ), + ( + "BCE", + pd.DataFrame( + { + Columns.User: [10, 10, 30, 30, 30, 40, 40, 40], + Columns.Item: [17, 15, 13, 14, 17, 13, 14, 15], + Columns.Rank: [1, 2, 1, 2, 3, 1, 2, 3], + } + ), + "cosine", + ), + ( + "gBCE", + pd.DataFrame( + { + Columns.User: [10, 10, 30, 30, 30, 40, 40, 40], + Columns.Item: [17, 15, 13, 14, 17, 13, 14, 15], + Columns.Rank: [1, 2, 1, 2, 3, 1, 2, 3], + } + ), + "cosine", ), ), ) @@ -327,6 +356,7 @@ def test_u2i_losses( loss: str, get_trainer_func: TrainerCallable, expected: pd.DataFrame, + u2i_dist: str, ) -> None: model = SASRecModel( n_negatives=2, @@ -340,6 +370,8 @@ def test_u2i_losses( item_net_block_types=(IdEmbeddingsItemNet,), get_trainer_func=get_trainer_func, loss=loss, + similarity_module_type=DistanceSimilarityModule, + similarity_module_kwargs={"distance": u2i_dist}, ) model.fit(dataset=dataset) users = np.array([10, 30, 40]) @@ -380,6 +412,7 @@ def test_u2i_with_key_and_attn_masks( item_net_block_types=(IdEmbeddingsItemNet,), get_trainer_func=get_trainer_func, use_key_padding_mask=True, + similarity_module_type=DistanceSimilarityModule, ) model.fit(dataset=dataset) users = np.array([10, 30, 40]) @@ -420,6 +453,7 @@ def test_u2i_with_item_features( item_net_block_types=(IdEmbeddingsItemNet, CatFeaturesItemNet), get_trainer_func=get_trainer_func, use_key_padding_mask=True, + similarity_module_type=DistanceSimilarityModule, ) model.fit(dataset=dataset_item_features) users = np.array([10, 30, 40]) @@ -472,6 +506,7 @@ def test_with_whitelist( deterministic=True, item_net_block_types=(IdEmbeddingsItemNet,), get_trainer_func=get_trainer_func, + similarity_module_type=DistanceSimilarityModule, ) model.fit(dataset=dataset) users = np.array([10, 30, 40]) @@ -545,6 +580,7 @@ def test_i2i( deterministic=True, item_net_block_types=(IdEmbeddingsItemNet,), get_trainer_func=get_trainer_func, + similarity_module_type=DistanceSimilarityModule, ) model.fit(dataset=dataset) target_items = np.array([12, 14, 17]) @@ -571,6 +607,7 @@ def test_second_fit_refits_model(self, dataset_hot_users_items: Dataset) -> None deterministic=True, item_net_block_types=(IdEmbeddingsItemNet,), get_trainer_func=custom_trainer, + similarity_module_type=DistanceSimilarityModule, ) assert_second_fit_refits_model(model, dataset_hot_users_items, pre_fit_callback=self._seed_everything) @@ -612,6 +649,7 @@ def test_recommend_for_cold_user_with_hot_item( deterministic=True, item_net_block_types=(IdEmbeddingsItemNet,), get_trainer_func=get_trainer_func, + similarity_module_type=DistanceSimilarityModule, ) 
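Editor's note: the `u2i_dist` parametrizations above exercise both supported distances through full model training. The module's shape contract can also be sanity-checked in isolation; a toy-sized sketch (sizes are arbitrary, assuming the `similarity.py` module added in this diff):

```python
import torch

from rectools.models.nn.transformers.similarity import DistanceSimilarityModule

module = DistanceSimilarityModule(distance="cosine")
session_embs = torch.randn(2, 5, 8)  # [batch_size, session_max_len, n_factors]
item_embs = torch.randn(10, 8)       # [n_items + n_item_extra_tokens, n_factors]

full = module(session_embs, item_embs)  # full-catalog logits, used by the softmax loss
item_ids = torch.randint(0, 10, (2, 5, 3))  # one positive + two negatives per step
sampled = module(session_embs, item_embs, item_ids)  # pos/neg logits, used by BCE/gBCE

assert full.shape == (2, 5, 10)
assert sampled.shape == (2, 5, 3)
```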
model.fit(dataset=dataset) users = np.array([20]) @@ -665,6 +703,7 @@ def test_warn_when_hot_user_has_cold_items_in_recommend( deterministic=True, item_net_block_types=(IdEmbeddingsItemNet,), get_trainer_func=get_trainer_func, + similarity_module_type=DistanceSimilarityModule, ) model.fit(dataset=dataset) users = np.array([10, 20, 50]) @@ -688,12 +727,12 @@ def test_warn_when_hot_user_has_cold_items_in_recommend( ) def test_raises_when_loss_is_not_supported(self, dataset: Dataset) -> None: - model = SASRecModel(loss="gbce") + model = SASRecModel(loss="gbce", similarity_module_type=DistanceSimilarityModule) with pytest.raises(ValueError): model.fit(dataset=dataset) def test_torch_model(self, dataset: Dataset) -> None: - model = SASRecModel() + model = SASRecModel(similarity_module_type=DistanceSimilarityModule) model.fit(dataset) assert isinstance(model.torch_model, TransformerTorchBackbone) @@ -901,6 +940,7 @@ def initial_config(self) -> tp.Dict[str, tp.Any]: "transformer_layers_type": SASRecTransformerLayers, "data_preparator_type": SASRecDataPreparator, "lightning_module_type": TransformerLightningModule, + "similarity_module_type": DistanceSimilarityModule, "get_val_mask_func": leave_one_out_mask, "get_trainer_func": None, "data_preparator_kwargs": None, @@ -908,6 +948,7 @@ def initial_config(self) -> tp.Dict[str, tp.Any]: "item_net_constructor_kwargs": None, "pos_encoding_kwargs": None, "lightning_module_kwargs": None, + "similarity_module_kwargs": None, } return config @@ -947,6 +988,7 @@ def test_get_config( "data_preparator_type": "rectools.models.nn.transformers.sasrec.SASRecDataPreparator", "lightning_module_type": "rectools.models.nn.transformers.lightning.TransformerLightningModule", "get_val_mask_func": "tests.models.nn.transformers.utils.leave_one_out_mask", + "similarity_module_type": "rectools.models.nn.transformers.similarity.DistanceSimilarityModule", } expected.update(simple_types_params) if use_custom_trainer:
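Editor's note: the `Distance(str, Enum)` change earlier in this diff is what lets plain strings such as `{"distance": "cosine"}` in `similarity_module_kwargs` map onto enum members and survive JSON config round-trips. A short illustration:

```python
from rectools.models.rank import Distance

assert Distance("cosine") is Distance.COSINE  # members can be built from config strings
assert Distance.DOT == "dot"  # str mixin: members compare equal to their string values
assert Distance.DOT.value == "dot"
```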