diff --git a/CHANGELOG.md b/CHANGELOG.md
index 4815ac75..8292b9d2 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -8,8 +8,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 ## Unreleased
 
 ### Added
-- `SimilarityModuleBase`, `DistanceSimilarityModule`, similarity module to `TransformerTorchBackbone`, parameters to transformer-based models `similarity_module_type`, `similarity_module_kwargs` ([#272](https://github.com/MobileTeleSystems/RecTools/pull/272))
+- `SimilarityModuleBase`, `DistanceSimilarityModule`, similarity module to `TransformerTorchBackbone` parameters to transformer-based models `similarity_module_type`, `similarity_module_kwargs` ([#272](https://github.com/MobileTeleSystems/RecTools/pull/272))
 - `out_dim` property to `IdEmbeddingsItemNet`, `CatFeaturesItemNet` and `SumOfEmbeddingsConstructor` ([#276](https://github.com/MobileTeleSystems/RecTools/pull/276))
+- `TransformerBackboneBase`, `backbone_type` and `backbone_kwargs` parameters to transformer-based models ([#277](https://github.com/MobileTeleSystems/RecTools/pull/277))
 - `sampled_softmax` loss option for transformer models ([#274](https://github.com/MobileTeleSystems/RecTools/pull/274))
 
 ## [0.12.0] - 24.02.2025
diff --git a/examples/tutorials/transformers_advanced_training_guide.ipynb b/examples/tutorials/transformers_advanced_training_guide.ipynb
index cb1e243a..f3b0a60d 100644
--- a/examples/tutorials/transformers_advanced_training_guide.ipynb
+++ b/examples/tutorials/transformers_advanced_training_guide.ipynb
@@ -412,15 +412,15 @@
      "name": "stdout",
      "output_type": "stream",
      "text": [
-      "epoch,step,train_loss,val_loss\r",
+      "epoch,step,train_loss,val_loss\r\n",
       "\r\n",
-      "0,1,,22.365339279174805\r",
+      "0,1,,22.365339279174805\r\n",
       "\r\n",
-      "0,1,22.38391876220703,\r",
+      "0,1,22.38391876220703,\r\n",
       "\r\n",
-      "1,3,,22.189851760864258\r",
+      "1,3,,22.189851760864258\r\n",
       "\r\n",
-      "1,3,22.898216247558594,\r",
+      "1,3,22.898216247558594,\r\n",
       "\r\n"
      ]
     }
@@ -526,23 +526,23 @@
      "name": "stdout",
      "output_type": "stream",
      "text": [
-      "epoch,step,train_loss,val_loss\r",
+      "epoch,step,train_loss,val_loss\r\n",
       "\r\n",
-      "0,1,,22.343637466430664\r",
+      "0,1,,22.343637466430664\r\n",
       "\r\n",
-      "0,1,22.36273765563965,\r",
+      "0,1,22.36273765563965,\r\n",
       "\r\n",
-      "1,3,,22.159835815429688\r",
+      "1,3,,22.159835815429688\r\n",
       "\r\n",
-      "1,3,22.33755874633789,\r",
+      "1,3,22.33755874633789,\r\n",
       "\r\n",
-      "2,5,,21.94308853149414\r",
+      "2,5,,21.94308853149414\r\n",
       "\r\n",
-      "2,5,22.244243621826172,\r",
+      "2,5,22.244243621826172,\r\n",
       "\r\n",
-      "3,7,,21.702259063720703\r",
+      "3,7,,21.702259063720703\r\n",
       "\r\n",
-      "3,7,22.196012496948242,\r",
+      "3,7,22.196012496948242,\r\n",
       "\r\n"
      ]
     }
@@ -898,7 +898,7 @@
     "    ) -> None:\n",
     "        logits = outputs[\"logits\"]\n",
     "        if logits is None:\n",
-    "            logits = pl_module.torch_model.encode_sessions(batch[\"x\"], pl_module.item_embs)[:, -1, :]\n",
+    "            logits = pl_module.torch_model.encode_sessions(batch, pl_module.item_embs)[:, -1, :]\n",
     "        _, sorted_batch_recos = logits.topk(k=self.top_k)\n",
     "\n",
     "        batch_recos = sorted_batch_recos.tolist()\n",
@@ -2039,9 +2039,9 @@
  ],
  "metadata": {
   "kernelspec": {
-   "display_name": "rectools",
+   "display_name": ".venv",
    "language": "python",
-   "name": "rectools"
+   "name": "python3"
   },
   "language_info": {
    "codemirror_mode": {
@@ -2053,7 +2053,7 @@
    "name": "python",
    "nbconvert_exporter": "python",
    "pygments_lexer": "ipython3",
-   "version": "3.9.12"
+   "version": "3.10.13"
   }
  },
 "nbformat": 4,
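The notebook fix above tracks the new `encode_sessions` contract introduced in this PR: the backbone now receives the whole batch dict and reads `batch["x"]` itself. A minimal sketch of the updated callback logic, assuming `pl_module` exposes `torch_model` and `item_embs` as in the tutorial (the helper name is hypothetical):

```python
import typing as tp

import torch


def get_batch_recos(pl_module: tp.Any, batch: tp.Dict[str, torch.Tensor], top_k: int) -> tp.List[tp.List[int]]:
    # encode_sessions now takes the full batch dict instead of batch["x"]
    logits = pl_module.torch_model.encode_sessions(batch, pl_module.item_embs)[:, -1, :]
    _, sorted_batch_recos = logits.topk(k=top_k)
    return sorted_batch_recos.tolist()
```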
diff --git a/rectools/models/nn/transformers/base.py b/rectools/models/nn/transformers/base.py
index b8b1208e..fec4a99b 100644
--- a/rectools/models/nn/transformers/base.py
+++ b/rectools/models/nn/transformers/base.py
@@ -47,7 +47,7 @@
     TransformerLayersBase,
 )
 from .similarity import DistanceSimilarityModule, SimilarityModuleBase
-from .torch_backbone import TransformerTorchBackbone
+from .torch_backbone import TransformerBackboneBase, TransformerTorchBackbone
 
 InitKwargs = tp.Dict[str, tp.Any]
 
@@ -108,6 +108,16 @@ def _serialize_type_sequence(obj: tp.Sequence[tp.Type]) -> tp.Tuple[str, ...]:
     ),
 ]
 
+TransformerBackboneType = tpe.Annotated[
+    tp.Type[TransformerBackboneBase],
+    BeforeValidator(_get_class_obj),
+    PlainSerializer(
+        func=get_class_or_function_full_path,
+        return_type=str,
+        when_used="json",
+    ),
+]
+
 TransformerDataPreparatorType = tpe.Annotated[
     tp.Type[TransformerDataPreparatorBase],
     BeforeValidator(_get_class_obj),
@@ -195,6 +205,7 @@ class TransformerModelConfig(ModelConfig):
     transformer_layers_type: TransformerLayersType = PreLNTransformerLayers
     lightning_module_type: TransformerLightningModuleType = TransformerLightningModule
     similarity_module_type: SimilarityModuleType = DistanceSimilarityModule
+    backbone_type: TransformerBackboneType = TransformerTorchBackbone
     get_val_mask_func: tp.Optional[ValMaskCallableSerialized] = None
     get_trainer_func: tp.Optional[TrainerCallableSerialized] = None
     data_preparator_kwargs: tp.Optional[InitKwargs] = None
@@ -203,6 +214,7 @@ class TransformerModelConfig(ModelConfig):
     pos_encoding_kwargs: tp.Optional[InitKwargs] = None
     lightning_module_kwargs: tp.Optional[InitKwargs] = None
     similarity_module_kwargs: tp.Optional[InitKwargs] = None
+    backbone_kwargs: tp.Optional[InitKwargs] = None
 
 
 TransformerModelConfig_T = tp.TypeVar("TransformerModelConfig_T", bound=TransformerModelConfig)
@@ -251,6 +263,7 @@ def __init__(  # pylint: disable=too-many-arguments, too-many-locals
         pos_encoding_type: tp.Type[PositionalEncodingBase] = LearnableInversePositionalEncoding,
         lightning_module_type: tp.Type[TransformerLightningModuleBase] = TransformerLightningModule,
         similarity_module_type: tp.Type[SimilarityModuleBase] = DistanceSimilarityModule,
+        backbone_type: tp.Type[TransformerBackboneBase] = TransformerTorchBackbone,
         get_val_mask_func: tp.Optional[ValMaskCallable] = None,
         get_trainer_func: tp.Optional[TrainerCallable] = None,
         data_preparator_kwargs: tp.Optional[InitKwargs] = None,
@@ -259,6 +272,7 @@ def __init__(  # pylint: disable=too-many-arguments, too-many-locals
         pos_encoding_kwargs: tp.Optional[InitKwargs] = None,
         lightning_module_kwargs: tp.Optional[InitKwargs] = None,
         similarity_module_kwargs: tp.Optional[InitKwargs] = None,
+        backbone_kwargs: tp.Optional[InitKwargs] = None,
         **kwargs: tp.Any,
     ) -> None:
         super().__init__(verbose=verbose)
@@ -288,6 +302,7 @@ def __init__(  # pylint: disable=too-many-arguments, too-many-locals
         self.item_net_constructor_type = item_net_constructor_type
         self.pos_encoding_type = pos_encoding_type
         self.lightning_module_type = lightning_module_type
+        self.backbone_type = backbone_type
         self.get_val_mask_func = get_val_mask_func
         self.get_trainer_func = get_trainer_func
         self.data_preparator_kwargs = data_preparator_kwargs
@@ -296,6 +311,7 @@ def __init__(  # pylint: disable=too-many-arguments, too-many-locals
         self.pos_encoding_kwargs = pos_encoding_kwargs
         self.lightning_module_kwargs = lightning_module_kwargs
         self.similarity_module_kwargs = similarity_module_kwargs
+        self.backbone_kwargs = backbone_kwargs
 
         self._init_data_preparator()
         self._init_trainer()
@@ -377,11 +393,11 @@ def _init_transformer_layers(self) -> TransformerLayersBase:
     def _init_similarity_module(self) -> SimilarityModuleBase:
         return self.similarity_module_type(**self._get_kwargs(self.similarity_module_kwargs))
 
-    def _init_torch_model(self, item_model: ItemNetBase) -> TransformerTorchBackbone:
+    def _init_torch_model(self, item_model: ItemNetBase) -> TransformerBackboneBase:
         pos_encoding_layer = self._init_pos_encoding_layer()
         transformer_layers = self._init_transformer_layers()
         similarity_module = self._init_similarity_module()
-        return TransformerTorchBackbone(
+        return self.backbone_type(
             n_heads=self.n_heads,
             dropout_rate=self.dropout_rate,
             item_model=item_model,
@@ -390,11 +406,12 @@ def _init_torch_model(self, item_model: ItemNetBase) -> TransformerTorchBackbone
             similarity_module=similarity_module,
             use_causal_attn=self.use_causal_attn,
             use_key_padding_mask=self.use_key_padding_mask,
+            **self._get_kwargs(self.backbone_kwargs),
         )
 
     def _init_lightning_model(
         self,
-        torch_model: TransformerTorchBackbone,
+        torch_model: TransformerBackboneBase,
         dataset_schema: DatasetSchemaDict,
         item_external_ids: ExternalIds,
         model_config: tp.Dict[str, tp.Any],
@@ -490,7 +507,7 @@ def _recommend_i2i(
         )
 
     @property
-    def torch_model(self) -> TransformerTorchBackbone:
+    def torch_model(self) -> TransformerBackboneBase:
         """Pytorch model."""
         return self.lightning_model.torch_model
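With the plumbing above, the backbone class is selected per model instance and extra constructor arguments flow through `**self._get_kwargs(self.backbone_kwargs)`. A usage sketch under stated assumptions: `MyBackbone` and its `my_flag` parameter are hypothetical, not part of this PR.

```python
from rectools.models import SASRecModel
from rectools.models.nn.transformers.torch_backbone import TransformerTorchBackbone


class MyBackbone(TransformerTorchBackbone):
    """Hypothetical backbone accepting one extra constructor flag."""

    def __init__(self, *args: object, my_flag: bool = False, **kwargs: object) -> None:
        super().__init__(*args, **kwargs)
        self.my_flag = my_flag


# backbone_kwargs entries are forwarded to MyBackbone.__init__;
# keep values JSON-serializable so configs stay serializable.
model = SASRecModel(backbone_type=MyBackbone, backbone_kwargs={"my_flag": True})
```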
diff --git a/rectools/models/nn/transformers/bert4rec.py b/rectools/models/nn/transformers/bert4rec.py
index 44eb9cff..1b3ea6ed 100644
--- a/rectools/models/nn/transformers/bert4rec.py
+++ b/rectools/models/nn/transformers/bert4rec.py
@@ -45,6 +45,7 @@
     TransformerLayersBase,
 )
 from .similarity import DistanceSimilarityModule, SimilarityModuleBase
+from .torch_backbone import TransformerBackboneBase, TransformerTorchBackbone
 
 
 class BERT4RecDataPreparator(TransformerDataPreparatorBase):
@@ -259,6 +260,8 @@ class BERT4RecModel(TransformerModelBase[BERT4RecModelConfig]):
         Type of lightning module defining training procedure.
     similarity_module_type : type(SimilarityModuleBase), default `DistanceSimilarityModule`
         Type of similarity module.
+    backbone_type : type(TransformerBackboneBase), default `TransformerTorchBackbone`
+        Type of torch backbone.
     get_val_mask_func : Callable, default ``None``
         Function to get validation mask.
     get_trainer_func : Callable, default ``None``
@@ -295,6 +298,9 @@ class BERT4RecModel(TransformerModelBase[BERT4RecModelConfig]):
     similarity_module_kwargs: optional(dict), default ``None``
         Additional keyword arguments to pass during `similarity_module_type` initialization.
         Make sure all dict values have JSON serializable types.
+    backbone_kwargs: optional(dict), default ``None``
+        Additional keyword arguments to pass during `backbone_type` initialization.
+        Make sure all dict values have JSON serializable types.
     """
 
     config_class = BERT4RecModelConfig
@@ -327,6 +333,7 @@ def __init__(  # pylint: disable=too-many-arguments, too-many-locals
         data_preparator_type: tp.Type[TransformerDataPreparatorBase] = BERT4RecDataPreparator,
         lightning_module_type: tp.Type[TransformerLightningModuleBase] = TransformerLightningModule,
         similarity_module_type: tp.Type[SimilarityModuleBase] = DistanceSimilarityModule,
+        backbone_type: tp.Type[TransformerBackboneBase] = TransformerTorchBackbone,
         get_val_mask_func: tp.Optional[ValMaskCallable] = None,
         get_trainer_func: tp.Optional[TrainerCallable] = None,
         recommend_batch_size: int = 256,
@@ -340,6 +347,7 @@ def __init__(  # pylint: disable=too-many-arguments, too-many-locals
         pos_encoding_kwargs: tp.Optional[InitKwargs] = None,
         lightning_module_kwargs: tp.Optional[InitKwargs] = None,
         similarity_module_kwargs: tp.Optional[InitKwargs] = None,
+        backbone_kwargs: tp.Optional[InitKwargs] = None,
     ):
         self.mask_prob = mask_prob
 
@@ -373,6 +381,7 @@ def __init__(  # pylint: disable=too-many-arguments, too-many-locals
             item_net_constructor_type=item_net_constructor_type,
             pos_encoding_type=pos_encoding_type,
             lightning_module_type=lightning_module_type,
+            backbone_type=backbone_type,
             get_val_mask_func=get_val_mask_func,
             get_trainer_func=get_trainer_func,
             data_preparator_kwargs=data_preparator_kwargs,
@@ -382,6 +391,7 @@ def __init__(  # pylint: disable=too-many-arguments, too-many-locals
             pos_encoding_kwargs=pos_encoding_kwargs,
             lightning_module_kwargs=lightning_module_kwargs,
             similarity_module_kwargs=similarity_module_kwargs,
+            backbone_kwargs=backbone_kwargs,
         )
 
     def _init_data_preparator(self) -> None:
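Since `backbone_type` is part of the model config, it round-trips through serialization like the other type parameters. A sketch, assuming the default backbone; with `simple_types=True` the class serializes to its import path, mirroring the expectations in the updated tests below:

```python
from rectools.models import BERT4RecModel

model = BERT4RecModel()
config = model.get_config(simple_types=True)
# the class is serialized to its full import path
assert config["backbone_type"] == "rectools.models.nn.transformers.torch_backbone.TransformerTorchBackbone"
```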
""" config_class = BERT4RecModelConfig @@ -327,6 +333,7 @@ def __init__( # pylint: disable=too-many-arguments, too-many-locals data_preparator_type: tp.Type[TransformerDataPreparatorBase] = BERT4RecDataPreparator, lightning_module_type: tp.Type[TransformerLightningModuleBase] = TransformerLightningModule, similarity_module_type: tp.Type[SimilarityModuleBase] = DistanceSimilarityModule, + backbone_type: tp.Type[TransformerBackboneBase] = TransformerTorchBackbone, get_val_mask_func: tp.Optional[ValMaskCallable] = None, get_trainer_func: tp.Optional[TrainerCallable] = None, recommend_batch_size: int = 256, @@ -340,6 +347,7 @@ def __init__( # pylint: disable=too-many-arguments, too-many-locals pos_encoding_kwargs: tp.Optional[InitKwargs] = None, lightning_module_kwargs: tp.Optional[InitKwargs] = None, similarity_module_kwargs: tp.Optional[InitKwargs] = None, + backbone_kwargs: tp.Optional[InitKwargs] = None, ): self.mask_prob = mask_prob @@ -373,6 +381,7 @@ def __init__( # pylint: disable=too-many-arguments, too-many-locals item_net_constructor_type=item_net_constructor_type, pos_encoding_type=pos_encoding_type, lightning_module_type=lightning_module_type, + backbone_type=backbone_type, get_val_mask_func=get_val_mask_func, get_trainer_func=get_trainer_func, data_preparator_kwargs=data_preparator_kwargs, @@ -382,6 +391,7 @@ def __init__( # pylint: disable=too-many-arguments, too-many-locals pos_encoding_kwargs=pos_encoding_kwargs, lightning_module_kwargs=lightning_module_kwargs, similarity_module_kwargs=similarity_module_kwargs, + backbone_kwargs=backbone_kwargs, ) def _init_data_preparator(self) -> None: diff --git a/rectools/models/nn/transformers/lightning.py b/rectools/models/nn/transformers/lightning.py index d99f3556..74365726 100644 --- a/rectools/models/nn/transformers/lightning.py +++ b/rectools/models/nn/transformers/lightning.py @@ -26,7 +26,7 @@ from rectools.types import InternalIdsArray from .data_preparator import TransformerDataPreparatorBase -from .torch_backbone import TransformerTorchBackbone +from .torch_backbone import TransformerBackboneBase # #### -------------- Lightning Base Model -------------- #### # @@ -38,7 +38,7 @@ class TransformerLightningModuleBase(LightningModule): # pylint: disable=too-ma Parameters ---------- - torch_model : TransformerTorchBackbone + torch_model : TransformerBackboneBase Torch model to make recommendations. lr : float Learning rate. 
diff --git a/rectools/models/nn/transformers/sasrec.py b/rectools/models/nn/transformers/sasrec.py
index b24f9079..a99d0dd2 100644
--- a/rectools/models/nn/transformers/sasrec.py
+++ b/rectools/models/nn/transformers/sasrec.py
@@ -45,6 +45,7 @@
     TransformerLayersBase,
 )
 from .similarity import DistanceSimilarityModule, SimilarityModuleBase
+from .torch_backbone import TransformerBackboneBase, TransformerTorchBackbone
 
 
 class SASRecDataPreparator(TransformerDataPreparatorBase):
@@ -339,6 +340,8 @@ class SASRecModel(TransformerModelBase[SASRecModelConfig]):
         Type of lightning module defining training procedure.
     similarity_module_type : type(SimilarityModuleBase), default `DistanceSimilarityModule`
         Type of similarity module.
+    backbone_type : type(TransformerBackboneBase), default `TransformerTorchBackbone`
+        Type of torch backbone.
     get_val_mask_func : Callable, default ``None``
         Function to get validation mask.
     get_trainer_func : Callable, default ``None``
@@ -375,6 +378,9 @@ class SASRecModel(TransformerModelBase[SASRecModelConfig]):
     similarity_module_kwargs: optional(dict), default ``None``
         Additional keyword arguments to pass during `similarity_module_type` initialization.
         Make sure all dict values have JSON serializable types.
+    backbone_kwargs: optional(dict), default ``None``
+        Additional keyword arguments to pass during `backbone_type` initialization.
+        Make sure all dict values have JSON serializable types.
     """
 
     config_class = SASRecModelConfig
@@ -406,6 +412,7 @@ def __init__(  # pylint: disable=too-many-arguments, too-many-locals
         data_preparator_type: tp.Type[TransformerDataPreparatorBase] = SASRecDataPreparator,
         lightning_module_type: tp.Type[TransformerLightningModuleBase] = TransformerLightningModule,
         similarity_module_type: tp.Type[SimilarityModuleBase] = DistanceSimilarityModule,
+        backbone_type: tp.Type[TransformerBackboneBase] = TransformerTorchBackbone,
         get_val_mask_func: tp.Optional[ValMaskCallable] = None,
         get_trainer_func: tp.Optional[TrainerCallable] = None,
         recommend_batch_size: int = 256,
@@ -418,6 +425,7 @@ def __init__(  # pylint: disable=too-many-arguments, too-many-locals
         pos_encoding_kwargs: tp.Optional[InitKwargs] = None,
         lightning_module_kwargs: tp.Optional[InitKwargs] = None,
         similarity_module_kwargs: tp.Optional[InitKwargs] = None,
+        backbone_kwargs: tp.Optional[InitKwargs] = None,
     ):
         super().__init__(
             transformer_layers_type=transformer_layers_type,
@@ -449,6 +457,7 @@ def __init__(  # pylint: disable=too-many-arguments, too-many-locals
             item_net_constructor_type=item_net_constructor_type,
             pos_encoding_type=pos_encoding_type,
             lightning_module_type=lightning_module_type,
+            backbone_type=backbone_type,
             get_val_mask_func=get_val_mask_func,
             get_trainer_func=get_trainer_func,
             data_preparator_kwargs=data_preparator_kwargs,
@@ -457,4 +466,5 @@ def __init__(  # pylint: disable=too-many-arguments, too-many-locals
             pos_encoding_kwargs=pos_encoding_kwargs,
             lightning_module_kwargs=lightning_module_kwargs,
             similarity_module_kwargs=similarity_module_kwargs,
+            backbone_kwargs=backbone_kwargs,
         )
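Because the new parameters are part of the config schema, they can also be supplied as JSON-serializable values. A sketch, assuming import-path strings are accepted for type fields as with the other `*_type` parameters:

```python
from rectools.models import SASRecModel

model = SASRecModel.from_config(
    {
        "backbone_type": "rectools.models.nn.transformers.torch_backbone.TransformerTorchBackbone",
        "backbone_kwargs": None,
    }
)
```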
diff --git a/rectools/models/nn/transformers/similarity.py b/rectools/models/nn/transformers/similarity.py
index ec37ba14..da1ac615 100644
--- a/rectools/models/nn/transformers/similarity.py
+++ b/rectools/models/nn/transformers/similarity.py
@@ -24,18 +24,21 @@
 
 
 class SimilarityModuleBase(torch.nn.Module):
-    """Similarity module base."""
+    """Base class for similarity module."""
 
     def _get_full_catalog_logits(self, session_embs: torch.Tensor, item_embs: torch.Tensor) -> torch.Tensor:
         raise NotImplementedError()
 
     def _get_pos_neg_logits(
-        self, session_embs: torch.Tensor, item_embs: torch.Tensor, item_ids: torch.Tensor
+        self, session_embs: torch.Tensor, item_embs: torch.Tensor, candidate_item_ids: torch.Tensor
     ) -> torch.Tensor:
         raise NotImplementedError()
 
     def forward(
-        self, session_embs: torch.Tensor, item_embs: torch.Tensor, item_ids: tp.Optional[torch.Tensor] = None
+        self,
+        session_embs: torch.Tensor,
+        item_embs: torch.Tensor,
+        candidate_item_ids: tp.Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         """Forward pass to get logits."""
         raise NotImplementedError()
@@ -71,9 +74,10 @@ def _get_full_catalog_logits(self, session_embs: torch.Tensor, item_embs: torch.
         return logits
 
     def _get_pos_neg_logits(
-        self, session_embs: torch.Tensor, item_embs: torch.Tensor, item_ids: torch.Tensor
+        self, session_embs: torch.Tensor, item_embs: torch.Tensor, candidate_item_ids: torch.Tensor
     ) -> torch.Tensor:
-        pos_neg_embs = item_embs[item_ids]  # [batch_size, session_max_len, len(item_ids), n_factors]
+        # [batch_size, session_max_len, len(candidate_item_ids), n_factors]
+        pos_neg_embs = item_embs[candidate_item_ids]
         # [batch_size, session_max_len,len(item_ids)]
         logits = (pos_neg_embs @ session_embs.unsqueeze(-1)).squeeze(-1)
         return logits
@@ -84,16 +88,19 @@ def _get_embeddings_norm(self, embeddings: torch.Tensor) -> torch.Tensor:
         return embeddings
 
     def forward(
-        self, session_embs: torch.Tensor, item_embs: torch.Tensor, item_ids: tp.Optional[torch.Tensor] = None
+        self,
+        session_embs: torch.Tensor,
+        item_embs: torch.Tensor,
+        candidate_item_ids: tp.Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         """Forward pass to get logits."""
         if self.distance == Distance.COSINE:
             session_embs = self._get_embeddings_norm(session_embs)
             item_embs = self._get_embeddings_norm(item_embs)
 
-        if item_ids is None:
+        if candidate_item_ids is None:
             return self._get_full_catalog_logits(session_embs, item_embs)
-        return self._get_pos_neg_logits(session_embs, item_embs, item_ids)
+        return self._get_pos_neg_logits(session_embs, item_embs, candidate_item_ids)
 
     def _recommend_u2i(
         self,
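A shape sketch of the candidate-logits math above in plain torch (illustrative only; the module itself additionally handles cosine normalization):

```python
import torch

batch_size, session_max_len, n_cand, n_factors = 2, 3, 4, 8
session_embs = torch.randn(batch_size, session_max_len, n_factors)
item_embs = torch.randn(100, n_factors)
candidate_item_ids = torch.randint(0, 100, (batch_size, session_max_len, n_cand))

# Gather candidate embeddings: [batch_size, session_max_len, n_cand, n_factors]
pos_neg_embs = item_embs[candidate_item_ids]
# Dot product per position: [batch_size, session_max_len, n_cand]
logits = (pos_neg_embs @ session_embs.unsqueeze(-1)).squeeze(-1)
assert logits.shape == (batch_size, session_max_len, n_cand)
```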
diff --git a/rectools/models/nn/transformers/torch_backbone.py b/rectools/models/nn/transformers/torch_backbone.py
index ea55962f..6a78dc94 100644
--- a/rectools/models/nn/transformers/torch_backbone.py
+++ b/rectools/models/nn/transformers/torch_backbone.py
@@ -21,7 +21,101 @@
 from .similarity import SimilarityModuleBase
 
 
-class TransformerTorchBackbone(torch.nn.Module):
+class TransformerBackboneBase(torch.nn.Module):
+    """Base class for transformer torch backbone."""
+
+    def __init__(
+        self,
+        n_heads: int,
+        dropout_rate: float,
+        item_model: ItemNetBase,
+        pos_encoding_layer: PositionalEncodingBase,
+        transformer_layers: TransformerLayersBase,
+        similarity_module: SimilarityModuleBase,
+        use_causal_attn: bool = True,
+        use_key_padding_mask: bool = False,
+        **kwargs: tp.Any,
+    ) -> None:
+        """
+        Initialize transformer torch backbone.
+
+        Parameters
+        ----------
+        n_heads : int
+            Number of attention heads.
+        dropout_rate : float
+            Probability of a hidden unit to be zeroed.
+        item_model : ItemNetBase
+            Network for item embeddings.
+        pos_encoding_layer : PositionalEncodingBase
+            Positional encoding layer.
+        transformer_layers : TransformerLayersBase
+            Transformer layers.
+        similarity_module : SimilarityModuleBase
+            Similarity module.
+        use_causal_attn : bool, default True
+            If ``True``, causal mask is used in multi-head self-attention.
+        use_key_padding_mask : bool, default False
+            If ``True``, key padding mask is used in multi-head self-attention.
+        **kwargs : Any
+            Additional keyword arguments for future extensions.
+        """
+        super().__init__()
+
+        self.item_model = item_model
+        self.pos_encoding_layer = pos_encoding_layer
+        self.emb_dropout = torch.nn.Dropout(dropout_rate)
+        self.transformer_layers = transformer_layers
+        self.similarity_module = similarity_module
+        self.use_causal_attn = use_causal_attn
+        self.use_key_padding_mask = use_key_padding_mask
+        self.n_heads = n_heads
+
+    def encode_sessions(self, batch: tp.Dict[str, torch.Tensor], item_embs: torch.Tensor) -> torch.Tensor:
+        """
+        Pass user history through item embeddings.
+        Add positional encoding.
+        Pass history through transformer blocks.
+
+        Parameters
+        ----------
+        batch : Dict[str, torch.Tensor]
+            Dictionary containing user sessions data.
+        item_embs : torch.Tensor
+            Item embeddings.
+
+        Returns
+        -------
+        torch.Tensor. [batch_size, session_max_len, n_factors]
+            Encoded session embeddings.
+        """
+        raise NotImplementedError()
+
+    def forward(
+        self,
+        batch: tp.Dict[str, torch.Tensor],  # batch["x"]: [batch_size, session_max_len]
+        candidate_item_ids: tp.Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        """
+        Forward pass to get item and session embeddings.
+        Get item embeddings.
+        Pass user sessions through transformer blocks.
+
+        Parameters
+        ----------
+        batch : Dict[str, torch.Tensor]
+            Dictionary containing user sessions data, with "x" key containing session tensor.
+        candidate_item_ids : optional(torch.Tensor), default ``None``
+            Defined item ids for similarity calculation.
+
+        Returns
+        -------
+        torch.Tensor
+        """
+        raise NotImplementedError()
+
+
+class TransformerTorchBackbone(TransformerBackboneBase):
     """
     Torch model for encoding user sessions based on transformer architecture.
 
@@ -43,6 +137,8 @@ class TransformerTorchBackbone(torch.nn.Module):
         If ``True``, causal mask is used in multi-head self-attention.
     use_key_padding_mask : bool, default False
         If ``True``, key padding mask is used in multi-head self-attention.
+    **kwargs : Any
+        Additional keyword arguments for future extensions.
     """
 
     def __init__(
@@ -55,17 +151,19 @@ def __init__(
         similarity_module: SimilarityModuleBase,
         use_causal_attn: bool = True,
         use_key_padding_mask: bool = False,
+        **kwargs: tp.Any,
     ) -> None:
-        super().__init__()
-
-        self.item_model = item_model
-        self.pos_encoding_layer = pos_encoding_layer
-        self.emb_dropout = torch.nn.Dropout(dropout_rate)
-        self.transformer_layers = transformer_layers
-        self.similarity_module = similarity_module
-        self.use_causal_attn = use_causal_attn
-        self.use_key_padding_mask = use_key_padding_mask
-        self.n_heads = n_heads
+        super().__init__(
+            n_heads=n_heads,
+            dropout_rate=dropout_rate,
+            item_model=item_model,
+            pos_encoding_layer=pos_encoding_layer,
+            transformer_layers=transformer_layers,
+            similarity_module=similarity_module,
+            use_causal_attn=use_causal_attn,
+            use_key_padding_mask=use_key_padding_mask,
+            **kwargs,
+        )
 
     @staticmethod
     def _convert_mask_to_float(mask: torch.Tensor, query: torch.Tensor) -> torch.Tensor:
@@ -119,7 +217,7 @@ def _merge_masks(
             torch.diagonal(res, dim1=1, dim2=2).zero_()
         return res
 
-    def encode_sessions(self, sessions: torch.Tensor, item_embs: torch.Tensor) -> torch.Tensor:
+    def encode_sessions(self, batch: tp.Dict[str, torch.Tensor], item_embs: torch.Tensor) -> torch.Tensor:
         """
         Pass user history through item embeddings.
         Add positional encoding.
@@ -127,8 +225,8 @@ def encode_sessions(self, sessions: torch.Tensor, item_embs: torch.Tensor) -> to
 
         Parameters
         ----------
-        sessions : torch.Tensor
-            User sessions in the form of sequences of items ids.
+        batch : Dict[str, torch.Tensor]
+            Dictionary containing user sessions data.
         item_embs : torch.Tensor
             Item embeddings.
 
@@ -137,6 +235,7 @@ def encode_sessions(self, sessions: torch.Tensor, item_embs: torch.Tensor) -> to
         torch.Tensor. [batch_size, session_max_len, n_factors]
             Encoded session embeddings.
         """
+        sessions = batch["x"]  # [batch_size, session_max_len]
         session_max_len = sessions.shape[1]
         attn_mask = None
         key_padding_mask = None
@@ -162,8 +261,8 @@ def encode_sessions(self, sessions: torch.Tensor, item_embs: torch.Tensor) -> to
 
     def forward(
         self,
-        sessions: torch.Tensor,  # [batch_size, session_max_len],
-        item_ids: tp.Optional[torch.Tensor] = None,
+        batch: tp.Dict[str, torch.Tensor],  # batch["x"]: [batch_size, session_max_len]
+        candidate_item_ids: tp.Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         """
         Forward pass to get item and session embeddings.
@@ -172,9 +271,9 @@ def forward(
 
         Parameters
         ----------
-        sessions : torch.Tensor
-            User sessions in the form of sequences of items ids.
-        item_ids : optional(torch.Tensor), default ``None``
+        batch : Dict[str, torch.Tensor]
+            Dictionary containing user sessions data, with "x" key containing session tensor.
+        candidate_item_ids : optional(torch.Tensor), default ``None``
            Defined item ids for similarity calculation.
 
         Returns
@@ -182,6 +281,6 @@ def forward(
         torch.Tensor
         """
         item_embs = self.item_model.get_all_embeddings()  # [n_items + n_item_extra_tokens, n_factors]
-        session_embs = self.encode_sessions(sessions, item_embs)  # [batch_size, session_max_len, n_factors]
-        logits = self.similarity_module(session_embs, item_embs, item_ids)
+        session_embs = self.encode_sessions(batch, item_embs)  # [batch_size, session_max_len, n_factors]
+        logits = self.similarity_module(session_embs, item_embs, candidate_item_ids)
        return logits
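`TransformerBackboneBase` defines the contract custom backbones must satisfy: `encode_sessions` and `forward` over a batch dict. A hypothetical subclass sketch (the extra `batch["timestamps"]` key and the class itself are illustrative, not part of this PR):

```python
import typing as tp

import torch

from rectools.models.nn.transformers.torch_backbone import TransformerTorchBackbone


class TimeAwareBackbone(TransformerTorchBackbone):
    """Example backbone that could consume a hypothetical batch["timestamps"] key."""

    def encode_sessions(self, batch: tp.Dict[str, torch.Tensor], item_embs: torch.Tensor) -> torch.Tensor:
        # Custom backbones receive the full batch dict, so extra per-token
        # features prepared by a custom data preparator are available here.
        session_embs = super().encode_sessions(batch, item_embs)
        return session_embs
```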
diff --git a/tests/models/nn/transformers/test_bert4rec.py b/tests/models/nn/transformers/test_bert4rec.py
index 1c15c453..c2d2c3d6 100644
--- a/tests/models/nn/transformers/test_bert4rec.py
+++ b/tests/models/nn/transformers/test_bert4rec.py
@@ -34,6 +34,7 @@
 )
 from rectools.models.nn.transformers.bert4rec import MASKING_VALUE, BERT4RecDataPreparator, ValMaskCallable
 from rectools.models.nn.transformers.similarity import DistanceSimilarityModule
+from rectools.models.nn.transformers.torch_backbone import TransformerTorchBackbone
 from tests.models.data import DATASET
 from tests.models.utils import (
     assert_default_config_and_default_model_params_are_the_same,
@@ -854,6 +855,7 @@ def initial_config(self) -> tp.Dict[str, tp.Any]:
             "data_preparator_type": BERT4RecDataPreparator,
             "lightning_module_type": TransformerLightningModule,
             "similarity_module_type": DistanceSimilarityModule,
+            "backbone_type": TransformerTorchBackbone,
             "mask_prob": 0.15,
             "get_val_mask_func": leave_one_out_mask,
             "get_trainer_func": None,
@@ -863,6 +865,7 @@ def initial_config(self) -> tp.Dict[str, tp.Any]:
             "pos_encoding_kwargs": None,
             "lightning_module_kwargs": None,
             "similarity_module_kwargs": None,
+            "backbone_kwargs": None,
         }
         return config
 
@@ -903,6 +906,7 @@ def test_get_config(
             "lightning_module_type": "rectools.models.nn.transformers.lightning.TransformerLightningModule",
             "get_val_mask_func": "tests.models.nn.transformers.utils.leave_one_out_mask",
             "similarity_module_type": "rectools.models.nn.transformers.similarity.DistanceSimilarityModule",
+            "backbone_type": "rectools.models.nn.transformers.torch_backbone.TransformerTorchBackbone",
         }
         expected.update(simple_types_params)
         if use_custom_trainer:
""" + sessions = batch["x"] # [batch_size, session_max_len] session_max_len = sessions.shape[1] attn_mask = None key_padding_mask = None @@ -162,8 +261,8 @@ def encode_sessions(self, sessions: torch.Tensor, item_embs: torch.Tensor) -> to def forward( self, - sessions: torch.Tensor, # [batch_size, session_max_len], - item_ids: tp.Optional[torch.Tensor] = None, + batch: tp.Dict[str, torch.Tensor], # batch["x"]: [batch_size, session_max_len] + candidate_item_ids: tp.Optional[torch.Tensor] = None, ) -> torch.Tensor: """ Forward pass to get item and session embeddings. @@ -172,9 +271,9 @@ def forward( Parameters ---------- - sessions : torch.Tensor - User sessions in the form of sequences of items ids. - item_ids : optional(torch.Tensor), default ``None`` + batch : Dict[str, torch.Tensor] + Dictionary containing user sessions data, with "x" key containing session tensor. + candidate_item_ids : optional(torch.Tensor), default ``None`` Defined item ids for similarity calculation. Returns @@ -182,6 +281,6 @@ def forward( torch.Tensor """ item_embs = self.item_model.get_all_embeddings() # [n_items + n_item_extra_tokens, n_factors] - session_embs = self.encode_sessions(sessions, item_embs) # [batch_size, session_max_len, n_factors] - logits = self.similarity_module(session_embs, item_embs, item_ids) + session_embs = self.encode_sessions(batch, item_embs) # [batch_size, session_max_len, n_factors] + logits = self.similarity_module(session_embs, item_embs, candidate_item_ids) return logits diff --git a/tests/models/nn/transformers/test_bert4rec.py b/tests/models/nn/transformers/test_bert4rec.py index 1c15c453..c2d2c3d6 100644 --- a/tests/models/nn/transformers/test_bert4rec.py +++ b/tests/models/nn/transformers/test_bert4rec.py @@ -34,6 +34,7 @@ ) from rectools.models.nn.transformers.bert4rec import MASKING_VALUE, BERT4RecDataPreparator, ValMaskCallable from rectools.models.nn.transformers.similarity import DistanceSimilarityModule +from rectools.models.nn.transformers.torch_backbone import TransformerTorchBackbone from tests.models.data import DATASET from tests.models.utils import ( assert_default_config_and_default_model_params_are_the_same, @@ -854,6 +855,7 @@ def initial_config(self) -> tp.Dict[str, tp.Any]: "data_preparator_type": BERT4RecDataPreparator, "lightning_module_type": TransformerLightningModule, "similarity_module_type": DistanceSimilarityModule, + "backbone_type": TransformerTorchBackbone, "mask_prob": 0.15, "get_val_mask_func": leave_one_out_mask, "get_trainer_func": None, @@ -863,6 +865,7 @@ def initial_config(self) -> tp.Dict[str, tp.Any]: "pos_encoding_kwargs": None, "lightning_module_kwargs": None, "similarity_module_kwargs": None, + "backbone_kwargs": None, } return config @@ -903,6 +906,7 @@ def test_get_config( "lightning_module_type": "rectools.models.nn.transformers.lightning.TransformerLightningModule", "get_val_mask_func": "tests.models.nn.transformers.utils.leave_one_out_mask", "similarity_module_type": "rectools.models.nn.transformers.similarity.DistanceSimilarityModule", + "backbone_type": "rectools.models.nn.transformers.torch_backbone.TransformerTorchBackbone", } expected.update(simple_types_params) if use_custom_trainer: diff --git a/tests/models/nn/transformers/test_sasrec.py b/tests/models/nn/transformers/test_sasrec.py index 385e3197..de5605bb 100644 --- a/tests/models/nn/transformers/test_sasrec.py +++ b/tests/models/nn/transformers/test_sasrec.py @@ -32,10 +32,10 @@ LearnableInversePositionalEncoding, TrainerCallable, TransformerLightningModule, - 
TransformerTorchBackbone, ) from rectools.models.nn.transformers.sasrec import SASRecDataPreparator, SASRecTransformerLayers from rectools.models.nn.transformers.similarity import DistanceSimilarityModule +from rectools.models.nn.transformers.torch_backbone import TransformerTorchBackbone from tests.models.data import DATASET from tests.models.utils import ( assert_default_config_and_default_model_params_are_the_same, @@ -963,6 +963,7 @@ def initial_config(self) -> tp.Dict[str, tp.Any]: "data_preparator_type": SASRecDataPreparator, "lightning_module_type": TransformerLightningModule, "similarity_module_type": DistanceSimilarityModule, + "backbone_type": TransformerTorchBackbone, "get_val_mask_func": leave_one_out_mask, "get_trainer_func": None, "data_preparator_kwargs": None, @@ -971,6 +972,7 @@ def initial_config(self) -> tp.Dict[str, tp.Any]: "pos_encoding_kwargs": None, "lightning_module_kwargs": None, "similarity_module_kwargs": None, + "backbone_kwargs": None, } return config @@ -1011,6 +1013,7 @@ def test_get_config( "lightning_module_type": "rectools.models.nn.transformers.lightning.TransformerLightningModule", "get_val_mask_func": "tests.models.nn.transformers.utils.leave_one_out_mask", "similarity_module_type": "rectools.models.nn.transformers.similarity.DistanceSimilarityModule", + "backbone_type": "rectools.models.nn.transformers.torch_backbone.TransformerTorchBackbone", } expected.update(simple_types_params) if use_custom_trainer: