From 0642e84c78837cad97ec8a1b4b404b6e071e3801 Mon Sep 17 00:00:00 2001 From: Daria Tikhonovich Date: Wed, 2 Apr 2025 13:36:06 +0300 Subject: [PATCH 1/7] backbone and similarity --- ...transformers_advanced_training_guide.ipynb | 36 +++++++++---------- rectools/models/nn/transformers/base.py | 15 +++++++- rectools/models/nn/transformers/bert4rec.py | 5 +++ rectools/models/nn/transformers/lightning.py | 8 ++--- rectools/models/nn/transformers/sasrec.py | 5 +++ rectools/models/nn/transformers/similarity.py | 23 +++++++----- .../models/nn/transformers/torch_backbone.py | 21 +++++------ tests/models/nn/transformers/test_bert4rec.py | 5 ++- tests/models/nn/transformers/test_sasrec.py | 4 ++- 9 files changed, 79 insertions(+), 43 deletions(-) diff --git a/examples/tutorials/transformers_advanced_training_guide.ipynb b/examples/tutorials/transformers_advanced_training_guide.ipynb index cb1e243a..f3b0a60d 100644 --- a/examples/tutorials/transformers_advanced_training_guide.ipynb +++ b/examples/tutorials/transformers_advanced_training_guide.ipynb @@ -412,15 +412,15 @@ "name": "stdout", "output_type": "stream", "text": [ - "epoch,step,train_loss,val_loss\r", + "epoch,step,train_loss,val_loss\r\n", "\r\n", - "0,1,,22.365339279174805\r", + "0,1,,22.365339279174805\r\n", "\r\n", - "0,1,22.38391876220703,\r", + "0,1,22.38391876220703,\r\n", "\r\n", - "1,3,,22.189851760864258\r", + "1,3,,22.189851760864258\r\n", "\r\n", - "1,3,22.898216247558594,\r", + "1,3,22.898216247558594,\r\n", "\r\n" ] } @@ -526,23 +526,23 @@ "name": "stdout", "output_type": "stream", "text": [ - "epoch,step,train_loss,val_loss\r", + "epoch,step,train_loss,val_loss\r\n", "\r\n", - "0,1,,22.343637466430664\r", + "0,1,,22.343637466430664\r\n", "\r\n", - "0,1,22.36273765563965,\r", + "0,1,22.36273765563965,\r\n", "\r\n", - "1,3,,22.159835815429688\r", + "1,3,,22.159835815429688\r\n", "\r\n", - "1,3,22.33755874633789,\r", + "1,3,22.33755874633789,\r\n", "\r\n", - "2,5,,21.94308853149414\r", + "2,5,,21.94308853149414\r\n", "\r\n", - "2,5,22.244243621826172,\r", + "2,5,22.244243621826172,\r\n", "\r\n", - "3,7,,21.702259063720703\r", + "3,7,,21.702259063720703\r\n", "\r\n", - "3,7,22.196012496948242,\r", + "3,7,22.196012496948242,\r\n", "\r\n" ] } @@ -898,7 +898,7 @@ " ) -> None:\n", " logits = outputs[\"logits\"]\n", " if logits is None:\n", - " logits = pl_module.torch_model.encode_sessions(batch[\"x\"], pl_module.item_embs)[:, -1, :]\n", + " logits = pl_module.torch_model.encode_sessions(batch, pl_module.item_embs)[:, -1, :]\n", " _, sorted_batch_recos = logits.topk(k=self.top_k)\n", "\n", " batch_recos = sorted_batch_recos.tolist()\n", @@ -2039,9 +2039,9 @@ ], "metadata": { "kernelspec": { - "display_name": "rectools", + "display_name": ".venv", "language": "python", - "name": "rectools" + "name": "python3" }, "language_info": { "codemirror_mode": { @@ -2053,7 +2053,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.9.12" + "version": "3.10.13" } }, "nbformat": 4, diff --git a/rectools/models/nn/transformers/base.py b/rectools/models/nn/transformers/base.py index b8b1208e..0ce25d7e 100644 --- a/rectools/models/nn/transformers/base.py +++ b/rectools/models/nn/transformers/base.py @@ -108,6 +108,16 @@ def _serialize_type_sequence(obj: tp.Sequence[tp.Type]) -> tp.Tuple[str, ...]: ), ] +TransformerTorchBackboneType = tpe.Annotated[ + tp.Type[TransformerTorchBackbone], + BeforeValidator(_get_class_obj), + PlainSerializer( + func=get_class_or_function_full_path, + return_type=str, + when_used="json", + ), +] + TransformerDataPreparatorType = tpe.Annotated[ tp.Type[TransformerDataPreparatorBase], BeforeValidator(_get_class_obj), @@ -195,6 +205,7 @@ class TransformerModelConfig(ModelConfig): transformer_layers_type: TransformerLayersType = PreLNTransformerLayers lightning_module_type: TransformerLightningModuleType = TransformerLightningModule similarity_module_type: SimilarityModuleType = DistanceSimilarityModule + torch_backbone_type: TransformerTorchBackboneType = TransformerTorchBackbone get_val_mask_func: tp.Optional[ValMaskCallableSerialized] = None get_trainer_func: tp.Optional[TrainerCallableSerialized] = None data_preparator_kwargs: tp.Optional[InitKwargs] = None @@ -251,6 +262,7 @@ def __init__( # pylint: disable=too-many-arguments, too-many-locals pos_encoding_type: tp.Type[PositionalEncodingBase] = LearnableInversePositionalEncoding, lightning_module_type: tp.Type[TransformerLightningModuleBase] = TransformerLightningModule, similarity_module_type: tp.Type[SimilarityModuleBase] = DistanceSimilarityModule, + torch_backbone_type: tp.Type[TransformerTorchBackbone] = TransformerTorchBackbone, get_val_mask_func: tp.Optional[ValMaskCallable] = None, get_trainer_func: tp.Optional[TrainerCallable] = None, data_preparator_kwargs: tp.Optional[InitKwargs] = None, @@ -288,6 +300,7 @@ def __init__( # pylint: disable=too-many-arguments, too-many-locals self.item_net_constructor_type = item_net_constructor_type self.pos_encoding_type = pos_encoding_type self.lightning_module_type = lightning_module_type + self.torch_backbone_type = torch_backbone_type self.get_val_mask_func = get_val_mask_func self.get_trainer_func = get_trainer_func self.data_preparator_kwargs = data_preparator_kwargs @@ -381,7 +394,7 @@ def _init_torch_model(self, item_model: ItemNetBase) -> TransformerTorchBackbone pos_encoding_layer = self._init_pos_encoding_layer() transformer_layers = self._init_transformer_layers() similarity_module = self._init_similarity_module() - return TransformerTorchBackbone( + return self.torch_backbone_type( n_heads=self.n_heads, dropout_rate=self.dropout_rate, item_model=item_model, diff --git a/rectools/models/nn/transformers/bert4rec.py b/rectools/models/nn/transformers/bert4rec.py index c2f2d814..fc71d958 100644 --- a/rectools/models/nn/transformers/bert4rec.py +++ b/rectools/models/nn/transformers/bert4rec.py @@ -45,6 +45,7 @@ TransformerLayersBase, ) from .similarity import DistanceSimilarityModule, SimilarityModuleBase +from .torch_backbone import TransformerTorchBackbone class BERT4RecDataPreparator(TransformerDataPreparatorBase): @@ -259,6 +260,8 @@ class BERT4RecModel(TransformerModelBase[BERT4RecModelConfig]): Type of lightning module defining training procedure. similarity_module_type : type(SimilarityModuleBase), default `DistanceSimilarityModule` Type of similarity module. + torch_backbone_type : type(TransformerTorchBackbone), default `TransformerTorchBackbone` + Type of torch backbone. get_val_mask_func : Callable, default ``None`` Function to get validation mask. get_trainer_func : Callable, default ``None`` @@ -327,6 +330,7 @@ def __init__( # pylint: disable=too-many-arguments, too-many-locals data_preparator_type: tp.Type[TransformerDataPreparatorBase] = BERT4RecDataPreparator, lightning_module_type: tp.Type[TransformerLightningModuleBase] = TransformerLightningModule, similarity_module_type: tp.Type[SimilarityModuleBase] = DistanceSimilarityModule, + torch_backbone_type: tp.Type[TransformerTorchBackbone] = TransformerTorchBackbone, get_val_mask_func: tp.Optional[ValMaskCallable] = None, get_trainer_func: tp.Optional[TrainerCallable] = None, recommend_batch_size: int = 256, @@ -373,6 +377,7 @@ def __init__( # pylint: disable=too-many-arguments, too-many-locals item_net_constructor_type=item_net_constructor_type, pos_encoding_type=pos_encoding_type, lightning_module_type=lightning_module_type, + torch_backbone_type=torch_backbone_type, get_val_mask_func=get_val_mask_func, get_trainer_func=get_trainer_func, data_preparator_kwargs=data_preparator_kwargs, diff --git a/rectools/models/nn/transformers/lightning.py b/rectools/models/nn/transformers/lightning.py index 5bfb660d..2d1c043f 100644 --- a/rectools/models/nn/transformers/lightning.py +++ b/rectools/models/nn/transformers/lightning.py @@ -240,13 +240,12 @@ def on_train_start(self) -> None: def get_batch_logits(self, batch: tp.Dict[str, torch.Tensor]) -> torch.Tensor: """Get bacth logits.""" - x = batch["x"] # x: [batch_size, session_max_len] if self._requires_negatives: y, negatives = batch["y"], batch["negatives"] pos_neg = torch.cat([y.unsqueeze(-1), negatives], dim=-1) - logits = self.torch_model(sessions=x, item_ids=pos_neg) + logits = self.torch_model(batch=batch, candidate_item_ids=pos_neg) else: - logits = self.torch_model(sessions=x) + logits = self.torch_model(batch=batch) return logits def training_step(self, batch: tp.Dict[str, torch.Tensor], batch_idx: int) -> torch.Tensor: @@ -328,7 +327,8 @@ def _get_user_item_embeddings( item_embs = self.torch_model.item_model.get_all_embeddings() user_embs = [] for batch in recommend_dataloader: - batch_embs = self.torch_model.encode_sessions(batch["x"].to(device), item_embs)[:, -1, :] + batch = {k: v.to(device) for k, v in batch.items()} + batch_embs = self.torch_model.encode_sessions(batch, item_embs)[:, -1, :] user_embs.append(batch_embs) return torch.cat(user_embs), item_embs diff --git a/rectools/models/nn/transformers/sasrec.py b/rectools/models/nn/transformers/sasrec.py index 4bc36907..001b23ad 100644 --- a/rectools/models/nn/transformers/sasrec.py +++ b/rectools/models/nn/transformers/sasrec.py @@ -45,6 +45,7 @@ TransformerLayersBase, ) from .similarity import DistanceSimilarityModule, SimilarityModuleBase +from .torch_backbone import TransformerTorchBackbone class SASRecDataPreparator(TransformerDataPreparatorBase): @@ -339,6 +340,8 @@ class SASRecModel(TransformerModelBase[SASRecModelConfig]): Type of lightning module defining training procedure. similarity_module_type : type(SimilarityModuleBase), default `DistanceSimilarityModule` Type of similarity module. + torch_backbone_type : type(TransformerTorchBackbone), default `TransformerTorchBackbone` + Type of torch backbone. get_val_mask_func : Callable, default ``None`` Function to get validation mask. get_trainer_func : Callable, default ``None`` @@ -406,6 +409,7 @@ def __init__( # pylint: disable=too-many-arguments, too-many-locals data_preparator_type: tp.Type[TransformerDataPreparatorBase] = SASRecDataPreparator, lightning_module_type: tp.Type[TransformerLightningModuleBase] = TransformerLightningModule, similarity_module_type: tp.Type[SimilarityModuleBase] = DistanceSimilarityModule, + torch_backbone_type: tp.Type[TransformerTorchBackbone] = TransformerTorchBackbone, get_val_mask_func: tp.Optional[ValMaskCallable] = None, get_trainer_func: tp.Optional[TrainerCallable] = None, recommend_batch_size: int = 256, @@ -449,6 +453,7 @@ def __init__( # pylint: disable=too-many-arguments, too-many-locals item_net_constructor_type=item_net_constructor_type, pos_encoding_type=pos_encoding_type, lightning_module_type=lightning_module_type, + torch_backbone_type=torch_backbone_type, get_val_mask_func=get_val_mask_func, get_trainer_func=get_trainer_func, data_preparator_kwargs=data_preparator_kwargs, diff --git a/rectools/models/nn/transformers/similarity.py b/rectools/models/nn/transformers/similarity.py index ec37ba14..da1ac615 100644 --- a/rectools/models/nn/transformers/similarity.py +++ b/rectools/models/nn/transformers/similarity.py @@ -24,18 +24,21 @@ class SimilarityModuleBase(torch.nn.Module): - """Similarity module base.""" + """Base class for similarity module.""" def _get_full_catalog_logits(self, session_embs: torch.Tensor, item_embs: torch.Tensor) -> torch.Tensor: raise NotImplementedError() def _get_pos_neg_logits( - self, session_embs: torch.Tensor, item_embs: torch.Tensor, item_ids: torch.Tensor + self, session_embs: torch.Tensor, item_embs: torch.Tensor, candidate_item_ids: torch.Tensor ) -> torch.Tensor: raise NotImplementedError() def forward( - self, session_embs: torch.Tensor, item_embs: torch.Tensor, item_ids: tp.Optional[torch.Tensor] = None + self, + session_embs: torch.Tensor, + item_embs: torch.Tensor, + candidate_item_ids: tp.Optional[torch.Tensor] = None, ) -> torch.Tensor: """Forward pass to get logits.""" raise NotImplementedError() @@ -71,9 +74,10 @@ def _get_full_catalog_logits(self, session_embs: torch.Tensor, item_embs: torch. return logits def _get_pos_neg_logits( - self, session_embs: torch.Tensor, item_embs: torch.Tensor, item_ids: torch.Tensor + self, session_embs: torch.Tensor, item_embs: torch.Tensor, candidate_item_ids: torch.Tensor ) -> torch.Tensor: - pos_neg_embs = item_embs[item_ids] # [batch_size, session_max_len, len(item_ids), n_factors] + # [batch_size, session_max_len, len(candidate_item_ids), n_factors] + pos_neg_embs = item_embs[candidate_item_ids] # [batch_size, session_max_len,len(item_ids)] logits = (pos_neg_embs @ session_embs.unsqueeze(-1)).squeeze(-1) return logits @@ -84,16 +88,19 @@ def _get_embeddings_norm(self, embeddings: torch.Tensor) -> torch.Tensor: return embeddings def forward( - self, session_embs: torch.Tensor, item_embs: torch.Tensor, item_ids: tp.Optional[torch.Tensor] = None + self, + session_embs: torch.Tensor, + item_embs: torch.Tensor, + candidate_item_ids: tp.Optional[torch.Tensor] = None, ) -> torch.Tensor: """Forward pass to get logits.""" if self.distance == Distance.COSINE: session_embs = self._get_embeddings_norm(session_embs) item_embs = self._get_embeddings_norm(item_embs) - if item_ids is None: + if candidate_item_ids is None: return self._get_full_catalog_logits(session_embs, item_embs) - return self._get_pos_neg_logits(session_embs, item_embs, item_ids) + return self._get_pos_neg_logits(session_embs, item_embs, candidate_item_ids) def _recommend_u2i( self, diff --git a/rectools/models/nn/transformers/torch_backbone.py b/rectools/models/nn/transformers/torch_backbone.py index ea55962f..c29c0189 100644 --- a/rectools/models/nn/transformers/torch_backbone.py +++ b/rectools/models/nn/transformers/torch_backbone.py @@ -119,7 +119,7 @@ def _merge_masks( torch.diagonal(res, dim1=1, dim2=2).zero_() return res - def encode_sessions(self, sessions: torch.Tensor, item_embs: torch.Tensor) -> torch.Tensor: + def encode_sessions(self, batch: tp.Dict[str, torch.Tensor], item_embs: torch.Tensor) -> torch.Tensor: """ Pass user history through item embeddings. Add positional encoding. @@ -127,8 +127,8 @@ def encode_sessions(self, sessions: torch.Tensor, item_embs: torch.Tensor) -> to Parameters ---------- - sessions : torch.Tensor - User sessions in the form of sequences of items ids. + batch : Dict[str, torch.Tensor] + Dictionary containing user sessions data. item_embs : torch.Tensor Item embeddings. @@ -137,6 +137,7 @@ def encode_sessions(self, sessions: torch.Tensor, item_embs: torch.Tensor) -> to torch.Tensor. [batch_size, session_max_len, n_factors] Encoded session embeddings. """ + sessions = batch["x"] # [batch_size, session_max_len] session_max_len = sessions.shape[1] attn_mask = None key_padding_mask = None @@ -162,8 +163,8 @@ def encode_sessions(self, sessions: torch.Tensor, item_embs: torch.Tensor) -> to def forward( self, - sessions: torch.Tensor, # [batch_size, session_max_len], - item_ids: tp.Optional[torch.Tensor] = None, + batch: tp.Dict[str, torch.Tensor], # batch["x"]: [batch_size, session_max_len] + candidate_item_ids: tp.Optional[torch.Tensor] = None, ) -> torch.Tensor: """ Forward pass to get item and session embeddings. @@ -172,9 +173,9 @@ def forward( Parameters ---------- - sessions : torch.Tensor - User sessions in the form of sequences of items ids. - item_ids : optional(torch.Tensor), default ``None`` + batch : Dict[str, torch.Tensor] + Dictionary containing user sessions data, with "x" key containing session tensor. + candidate_item_ids : optional(torch.Tensor), default ``None`` Defined item ids for similarity calculation. Returns @@ -182,6 +183,6 @@ def forward( torch.Tensor """ item_embs = self.item_model.get_all_embeddings() # [n_items + n_item_extra_tokens, n_factors] - session_embs = self.encode_sessions(sessions, item_embs) # [batch_size, session_max_len, n_factors] - logits = self.similarity_module(session_embs, item_embs, item_ids) + session_embs = self.encode_sessions(batch, item_embs) # [batch_size, session_max_len, n_factors] + logits = self.similarity_module(session_embs, item_embs, candidate_item_ids) return logits diff --git a/tests/models/nn/transformers/test_bert4rec.py b/tests/models/nn/transformers/test_bert4rec.py index a8121c3d..ed37b1e1 100644 --- a/tests/models/nn/transformers/test_bert4rec.py +++ b/tests/models/nn/transformers/test_bert4rec.py @@ -32,8 +32,9 @@ TrainerCallable, TransformerLightningModule, ) -from rectools.models.nn.transformers.bert4rec import MASKING_VALUE, BERT4RecDataPreparator, ValMaskCallable +from rectools.models.nn.transformers.bert4rec import MASKING_VALUE, BERT4RecDataPreparator, ValMaskCallable from rectools.models.nn.transformers.similarity import DistanceSimilarityModule +from rectools.models.nn.transformers.torch_backbone import TransformerTorchBackbone from tests.models.data import DATASET from tests.models.utils import ( assert_default_config_and_default_model_params_are_the_same, @@ -844,6 +845,7 @@ def initial_config(self) -> tp.Dict[str, tp.Any]: "data_preparator_type": BERT4RecDataPreparator, "lightning_module_type": TransformerLightningModule, "similarity_module_type": DistanceSimilarityModule, + "torch_backbone_type": TransformerTorchBackbone, "mask_prob": 0.15, "get_val_mask_func": leave_one_out_mask, "get_trainer_func": None, @@ -893,6 +895,7 @@ def test_get_config( "lightning_module_type": "rectools.models.nn.transformers.lightning.TransformerLightningModule", "get_val_mask_func": "tests.models.nn.transformers.utils.leave_one_out_mask", "similarity_module_type": "rectools.models.nn.transformers.similarity.DistanceSimilarityModule", + "torch_backbone_type": "rectools.models.nn.transformers.torch_backbone.TransformerTorchBackbone", } expected.update(simple_types_params) if use_custom_trainer: diff --git a/tests/models/nn/transformers/test_sasrec.py b/tests/models/nn/transformers/test_sasrec.py index 189d4e53..1bb666de 100644 --- a/tests/models/nn/transformers/test_sasrec.py +++ b/tests/models/nn/transformers/test_sasrec.py @@ -32,10 +32,10 @@ LearnableInversePositionalEncoding, TrainerCallable, TransformerLightningModule, - TransformerTorchBackbone, ) from rectools.models.nn.transformers.sasrec import SASRecDataPreparator, SASRecTransformerLayers from rectools.models.nn.transformers.similarity import DistanceSimilarityModule +from rectools.models.nn.transformers.torch_backbone import TransformerTorchBackbone from tests.models.data import DATASET from tests.models.utils import ( assert_default_config_and_default_model_params_are_the_same, @@ -941,6 +941,7 @@ def initial_config(self) -> tp.Dict[str, tp.Any]: "data_preparator_type": SASRecDataPreparator, "lightning_module_type": TransformerLightningModule, "similarity_module_type": DistanceSimilarityModule, + "torch_backbone_type": TransformerTorchBackbone, "get_val_mask_func": leave_one_out_mask, "get_trainer_func": None, "data_preparator_kwargs": None, @@ -989,6 +990,7 @@ def test_get_config( "lightning_module_type": "rectools.models.nn.transformers.lightning.TransformerLightningModule", "get_val_mask_func": "tests.models.nn.transformers.utils.leave_one_out_mask", "similarity_module_type": "rectools.models.nn.transformers.similarity.DistanceSimilarityModule", + "torch_backbone_type": "rectools.models.nn.transformers.torch_backbone.TransformerTorchBackbone", } expected.update(simple_types_params) if use_custom_trainer: From fb5ba64564bacb1801e221365e00fbb81c6bab81 Mon Sep 17 00:00:00 2001 From: Daria Tikhonovich Date: Wed, 2 Apr 2025 13:58:40 +0300 Subject: [PATCH 2/7] linter --- tests/models/nn/transformers/test_bert4rec.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/models/nn/transformers/test_bert4rec.py b/tests/models/nn/transformers/test_bert4rec.py index ed37b1e1..9c62a57f 100644 --- a/tests/models/nn/transformers/test_bert4rec.py +++ b/tests/models/nn/transformers/test_bert4rec.py @@ -32,7 +32,7 @@ TrainerCallable, TransformerLightningModule, ) -from rectools.models.nn.transformers.bert4rec import MASKING_VALUE, BERT4RecDataPreparator, ValMaskCallable +from rectools.models.nn.transformers.bert4rec import MASKING_VALUE, BERT4RecDataPreparator, ValMaskCallable from rectools.models.nn.transformers.similarity import DistanceSimilarityModule from rectools.models.nn.transformers.torch_backbone import TransformerTorchBackbone from tests.models.data import DATASET From 0dea9496c899a754c41c425ef9ba8f38085effa4 Mon Sep 17 00:00:00 2001 From: Daria Tikhonovich Date: Wed, 2 Apr 2025 14:12:32 +0300 Subject: [PATCH 3/7] base class for torch backbone --- rectools/models/nn/transformers/base.py | 10 +- rectools/models/nn/transformers/bert4rec.py | 6 +- rectools/models/nn/transformers/sasrec.py | 6 +- .../models/nn/transformers/torch_backbone.py | 120 ++++++++++++++++-- 4 files changed, 120 insertions(+), 22 deletions(-) diff --git a/rectools/models/nn/transformers/base.py b/rectools/models/nn/transformers/base.py index 0ce25d7e..bf939215 100644 --- a/rectools/models/nn/transformers/base.py +++ b/rectools/models/nn/transformers/base.py @@ -47,7 +47,7 @@ TransformerLayersBase, ) from .similarity import DistanceSimilarityModule, SimilarityModuleBase -from .torch_backbone import TransformerTorchBackbone +from .torch_backbone import TransformerBackboneBase, TransformerTorchBackbone InitKwargs = tp.Dict[str, tp.Any] @@ -108,8 +108,8 @@ def _serialize_type_sequence(obj: tp.Sequence[tp.Type]) -> tp.Tuple[str, ...]: ), ] -TransformerTorchBackboneType = tpe.Annotated[ - tp.Type[TransformerTorchBackbone], +TransformerBackboneType = tpe.Annotated[ + tp.Type[TransformerBackboneBase], BeforeValidator(_get_class_obj), PlainSerializer( func=get_class_or_function_full_path, @@ -205,7 +205,7 @@ class TransformerModelConfig(ModelConfig): transformer_layers_type: TransformerLayersType = PreLNTransformerLayers lightning_module_type: TransformerLightningModuleType = TransformerLightningModule similarity_module_type: SimilarityModuleType = DistanceSimilarityModule - torch_backbone_type: TransformerTorchBackboneType = TransformerTorchBackbone + torch_backbone_type: TransformerBackboneType = TransformerTorchBackbone get_val_mask_func: tp.Optional[ValMaskCallableSerialized] = None get_trainer_func: tp.Optional[TrainerCallableSerialized] = None data_preparator_kwargs: tp.Optional[InitKwargs] = None @@ -262,7 +262,7 @@ def __init__( # pylint: disable=too-many-arguments, too-many-locals pos_encoding_type: tp.Type[PositionalEncodingBase] = LearnableInversePositionalEncoding, lightning_module_type: tp.Type[TransformerLightningModuleBase] = TransformerLightningModule, similarity_module_type: tp.Type[SimilarityModuleBase] = DistanceSimilarityModule, - torch_backbone_type: tp.Type[TransformerTorchBackbone] = TransformerTorchBackbone, + torch_backbone_type: tp.Type[TransformerBackboneBase] = TransformerTorchBackbone, get_val_mask_func: tp.Optional[ValMaskCallable] = None, get_trainer_func: tp.Optional[TrainerCallable] = None, data_preparator_kwargs: tp.Optional[InitKwargs] = None, diff --git a/rectools/models/nn/transformers/bert4rec.py b/rectools/models/nn/transformers/bert4rec.py index fc71d958..a1cf6f0d 100644 --- a/rectools/models/nn/transformers/bert4rec.py +++ b/rectools/models/nn/transformers/bert4rec.py @@ -45,7 +45,7 @@ TransformerLayersBase, ) from .similarity import DistanceSimilarityModule, SimilarityModuleBase -from .torch_backbone import TransformerTorchBackbone +from .torch_backbone import TransformerBackboneBase, TransformerTorchBackbone class BERT4RecDataPreparator(TransformerDataPreparatorBase): @@ -260,7 +260,7 @@ class BERT4RecModel(TransformerModelBase[BERT4RecModelConfig]): Type of lightning module defining training procedure. similarity_module_type : type(SimilarityModuleBase), default `DistanceSimilarityModule` Type of similarity module. - torch_backbone_type : type(TransformerTorchBackbone), default `TransformerTorchBackbone` + torch_backbone_type : type(TransformerBackboneBase), default `TransformerTorchBackbone` Type of torch backbone. get_val_mask_func : Callable, default ``None`` Function to get validation mask. @@ -330,7 +330,7 @@ def __init__( # pylint: disable=too-many-arguments, too-many-locals data_preparator_type: tp.Type[TransformerDataPreparatorBase] = BERT4RecDataPreparator, lightning_module_type: tp.Type[TransformerLightningModuleBase] = TransformerLightningModule, similarity_module_type: tp.Type[SimilarityModuleBase] = DistanceSimilarityModule, - torch_backbone_type: tp.Type[TransformerTorchBackbone] = TransformerTorchBackbone, + torch_backbone_type: tp.Type[TransformerBackboneBase] = TransformerTorchBackbone, get_val_mask_func: tp.Optional[ValMaskCallable] = None, get_trainer_func: tp.Optional[TrainerCallable] = None, recommend_batch_size: int = 256, diff --git a/rectools/models/nn/transformers/sasrec.py b/rectools/models/nn/transformers/sasrec.py index 001b23ad..43c5227b 100644 --- a/rectools/models/nn/transformers/sasrec.py +++ b/rectools/models/nn/transformers/sasrec.py @@ -45,7 +45,7 @@ TransformerLayersBase, ) from .similarity import DistanceSimilarityModule, SimilarityModuleBase -from .torch_backbone import TransformerTorchBackbone +from .torch_backbone import TransformerBackboneBase, TransformerTorchBackbone class SASRecDataPreparator(TransformerDataPreparatorBase): @@ -340,7 +340,7 @@ class SASRecModel(TransformerModelBase[SASRecModelConfig]): Type of lightning module defining training procedure. similarity_module_type : type(SimilarityModuleBase), default `DistanceSimilarityModule` Type of similarity module. - torch_backbone_type : type(TransformerTorchBackbone), default `TransformerTorchBackbone` + torch_backbone_type : type(TransformerBackboneBase), default `TransformerTorchBackbone` Type of torch backbone. get_val_mask_func : Callable, default ``None`` Function to get validation mask. @@ -409,7 +409,7 @@ def __init__( # pylint: disable=too-many-arguments, too-many-locals data_preparator_type: tp.Type[TransformerDataPreparatorBase] = SASRecDataPreparator, lightning_module_type: tp.Type[TransformerLightningModuleBase] = TransformerLightningModule, similarity_module_type: tp.Type[SimilarityModuleBase] = DistanceSimilarityModule, - torch_backbone_type: tp.Type[TransformerTorchBackbone] = TransformerTorchBackbone, + torch_backbone_type: tp.Type[TransformerBackboneBase] = TransformerTorchBackbone, get_val_mask_func: tp.Optional[ValMaskCallable] = None, get_trainer_func: tp.Optional[TrainerCallable] = None, recommend_batch_size: int = 256, diff --git a/rectools/models/nn/transformers/torch_backbone.py b/rectools/models/nn/transformers/torch_backbone.py index c29c0189..6a78dc94 100644 --- a/rectools/models/nn/transformers/torch_backbone.py +++ b/rectools/models/nn/transformers/torch_backbone.py @@ -21,7 +21,101 @@ from .similarity import SimilarityModuleBase -class TransformerTorchBackbone(torch.nn.Module): +class TransformerBackboneBase(torch.nn.Module): + """Base class for transformer torch backbone.""" + + def __init__( + self, + n_heads: int, + dropout_rate: float, + item_model: ItemNetBase, + pos_encoding_layer: PositionalEncodingBase, + transformer_layers: TransformerLayersBase, + similarity_module: SimilarityModuleBase, + use_causal_attn: bool = True, + use_key_padding_mask: bool = False, + **kwargs: tp.Any, + ) -> None: + """ + Initialize transformer torch backbone. + + Parameters + ---------- + n_heads : int + Number of attention heads. + dropout_rate : float + Probability of a hidden unit to be zeroed. + item_model : ItemNetBase + Network for item embeddings. + pos_encoding_layer : PositionalEncodingBase + Positional encoding layer. + transformer_layers : TransformerLayersBase + Transformer layers. + similarity_module : SimilarityModuleBase + Similarity module. + use_causal_attn : bool, default True + If ``True``, causal mask is used in multi-head self-attention. + use_key_padding_mask : bool, default False + If ``True``, key padding mask is used in multi-head self-attention. + **kwargs : Any + Additional keyword arguments for future extensions. + """ + super().__init__() + + self.item_model = item_model + self.pos_encoding_layer = pos_encoding_layer + self.emb_dropout = torch.nn.Dropout(dropout_rate) + self.transformer_layers = transformer_layers + self.similarity_module = similarity_module + self.use_causal_attn = use_causal_attn + self.use_key_padding_mask = use_key_padding_mask + self.n_heads = n_heads + + def encode_sessions(self, batch: tp.Dict[str, torch.Tensor], item_embs: torch.Tensor) -> torch.Tensor: + """ + Pass user history through item embeddings. + Add positional encoding. + Pass history through transformer blocks. + + Parameters + ---------- + batch : Dict[str, torch.Tensor] + Dictionary containing user sessions data. + item_embs : torch.Tensor + Item embeddings. + + Returns + ------- + torch.Tensor. [batch_size, session_max_len, n_factors] + Encoded session embeddings. + """ + raise NotImplementedError() + + def forward( + self, + batch: tp.Dict[str, torch.Tensor], # batch["x"]: [batch_size, session_max_len] + candidate_item_ids: tp.Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """ + Forward pass to get item and session embeddings. + Get item embeddings. + Pass user sessions through transformer blocks. + + Parameters + ---------- + batch : Dict[str, torch.Tensor] + Dictionary containing user sessions data, with "x" key containing session tensor. + candidate_item_ids : optional(torch.Tensor), default ``None`` + Defined item ids for similarity calculation. + + Returns + ------- + torch.Tensor + """ + raise NotImplementedError() + + +class TransformerTorchBackbone(TransformerBackboneBase): """ Torch model for encoding user sessions based on transformer architecture. @@ -43,6 +137,8 @@ class TransformerTorchBackbone(torch.nn.Module): If ``True``, causal mask is used in multi-head self-attention. use_key_padding_mask : bool, default False If ``True``, key padding mask is used in multi-head self-attention. + **kwargs : Any + Additional keyword arguments for future extensions. """ def __init__( @@ -55,17 +151,19 @@ def __init__( similarity_module: SimilarityModuleBase, use_causal_attn: bool = True, use_key_padding_mask: bool = False, + **kwargs: tp.Any, ) -> None: - super().__init__() - - self.item_model = item_model - self.pos_encoding_layer = pos_encoding_layer - self.emb_dropout = torch.nn.Dropout(dropout_rate) - self.transformer_layers = transformer_layers - self.similarity_module = similarity_module - self.use_causal_attn = use_causal_attn - self.use_key_padding_mask = use_key_padding_mask - self.n_heads = n_heads + super().__init__( + n_heads=n_heads, + dropout_rate=dropout_rate, + item_model=item_model, + pos_encoding_layer=pos_encoding_layer, + transformer_layers=transformer_layers, + similarity_module=similarity_module, + use_causal_attn=use_causal_attn, + use_key_padding_mask=use_key_padding_mask, + **kwargs, + ) @staticmethod def _convert_mask_to_float(mask: torch.Tensor, query: torch.Tensor) -> torch.Tensor: From 798266421837208e68407c87da537f1e0693f49a Mon Sep 17 00:00:00 2001 From: Daria Tikhonovich Date: Wed, 2 Apr 2025 14:31:03 +0300 Subject: [PATCH 4/7] backbone kwargs --- rectools/models/nn/transformers/base.py | 53 +++++++++++++++++-- rectools/models/nn/transformers/bert4rec.py | 11 ++-- rectools/models/nn/transformers/sasrec.py | 11 ++-- tests/models/nn/transformers/test_bert4rec.py | 5 +- tests/models/nn/transformers/test_sasrec.py | 5 +- 5 files changed, 70 insertions(+), 15 deletions(-) diff --git a/rectools/models/nn/transformers/base.py b/rectools/models/nn/transformers/base.py index bf939215..2243e455 100644 --- a/rectools/models/nn/transformers/base.py +++ b/rectools/models/nn/transformers/base.py @@ -176,7 +176,7 @@ def _serialize_type_sequence(obj: tp.Sequence[tp.Type]) -> tp.Tuple[str, ...]: class TransformerModelConfig(ModelConfig): - """Transformer model base config.""" + """Transformer model base config. data_preparator_type: TransformerDataPreparatorType n_blocks: int = 2 @@ -205,7 +205,7 @@ class TransformerModelConfig(ModelConfig): transformer_layers_type: TransformerLayersType = PreLNTransformerLayers lightning_module_type: TransformerLightningModuleType = TransformerLightningModule similarity_module_type: SimilarityModuleType = DistanceSimilarityModule - torch_backbone_type: TransformerBackboneType = TransformerTorchBackbone + backbone_type: TransformerBackboneType = TransformerTorchBackbone get_val_mask_func: tp.Optional[ValMaskCallableSerialized] = None get_trainer_func: tp.Optional[TrainerCallableSerialized] = None data_preparator_kwargs: tp.Optional[InitKwargs] = None @@ -214,6 +214,46 @@ class TransformerModelConfig(ModelConfig): pos_encoding_kwargs: tp.Optional[InitKwargs] = None lightning_module_kwargs: tp.Optional[InitKwargs] = None similarity_module_kwargs: tp.Optional[InitKwargs] = None + backbone_kwargs: tp.Optional[InitKwargs] = None + """ + + data_preparator_type: TransformerDataPreparatorType + n_blocks: int = 2 + n_heads: int = 4 + n_factors: int = 256 + use_pos_emb: bool = True + use_causal_attn: bool = False + use_key_padding_mask: bool = False + dropout_rate: float = 0.2 + session_max_len: int = 100 + dataloader_num_workers: int = 0 + batch_size: int = 128 + loss: str = "softmax" + n_negatives: int = 1 + gbce_t: float = 0.2 + lr: float = 0.001 + epochs: int = 3 + verbose: int = 0 + deterministic: bool = False + recommend_batch_size: int = 256 + recommend_torch_device: tp.Optional[str] = None + train_min_user_interactions: int = 2 + item_net_block_types: ItemNetBlockTypes = (IdEmbeddingsItemNet, CatFeaturesItemNet) + item_net_constructor_type: ItemNetConstructorType = SumOfEmbeddingsConstructor + pos_encoding_type: PositionalEncodingType = LearnableInversePositionalEncoding + transformer_layers_type: TransformerLayersType = PreLNTransformerLayers + lightning_module_type: TransformerLightningModuleType = TransformerLightningModule + similarity_module_type: SimilarityModuleType = DistanceSimilarityModule + backbone_type: TransformerBackboneType = TransformerTorchBackbone + get_val_mask_func: tp.Optional[ValMaskCallableSerialized] = None + get_trainer_func: tp.Optional[TrainerCallableSerialized] = None + data_preparator_kwargs: tp.Optional[InitKwargs] = None + transformer_layers_kwargs: tp.Optional[InitKwargs] = None + item_net_constructor_kwargs: tp.Optional[InitKwargs] = None + pos_encoding_kwargs: tp.Optional[InitKwargs] = None + lightning_module_kwargs: tp.Optional[InitKwargs] = None + similarity_module_kwargs: tp.Optional[InitKwargs] = None + backbone_kwargs: tp.Optional[InitKwargs] = None TransformerModelConfig_T = tp.TypeVar("TransformerModelConfig_T", bound=TransformerModelConfig) @@ -262,7 +302,7 @@ def __init__( # pylint: disable=too-many-arguments, too-many-locals pos_encoding_type: tp.Type[PositionalEncodingBase] = LearnableInversePositionalEncoding, lightning_module_type: tp.Type[TransformerLightningModuleBase] = TransformerLightningModule, similarity_module_type: tp.Type[SimilarityModuleBase] = DistanceSimilarityModule, - torch_backbone_type: tp.Type[TransformerBackboneBase] = TransformerTorchBackbone, + backbone_type: tp.Type[TransformerBackboneBase] = TransformerTorchBackbone, get_val_mask_func: tp.Optional[ValMaskCallable] = None, get_trainer_func: tp.Optional[TrainerCallable] = None, data_preparator_kwargs: tp.Optional[InitKwargs] = None, @@ -271,6 +311,7 @@ def __init__( # pylint: disable=too-many-arguments, too-many-locals pos_encoding_kwargs: tp.Optional[InitKwargs] = None, lightning_module_kwargs: tp.Optional[InitKwargs] = None, similarity_module_kwargs: tp.Optional[InitKwargs] = None, + backbone_kwargs: tp.Optional[InitKwargs] = None, **kwargs: tp.Any, ) -> None: super().__init__(verbose=verbose) @@ -300,7 +341,7 @@ def __init__( # pylint: disable=too-many-arguments, too-many-locals self.item_net_constructor_type = item_net_constructor_type self.pos_encoding_type = pos_encoding_type self.lightning_module_type = lightning_module_type - self.torch_backbone_type = torch_backbone_type + self.backbone_type = backbone_type self.get_val_mask_func = get_val_mask_func self.get_trainer_func = get_trainer_func self.data_preparator_kwargs = data_preparator_kwargs @@ -309,6 +350,7 @@ def __init__( # pylint: disable=too-many-arguments, too-many-locals self.pos_encoding_kwargs = pos_encoding_kwargs self.lightning_module_kwargs = lightning_module_kwargs self.similarity_module_kwargs = similarity_module_kwargs + self.backbone_kwargs = backbone_kwargs self._init_data_preparator() self._init_trainer() @@ -394,7 +436,7 @@ def _init_torch_model(self, item_model: ItemNetBase) -> TransformerTorchBackbone pos_encoding_layer = self._init_pos_encoding_layer() transformer_layers = self._init_transformer_layers() similarity_module = self._init_similarity_module() - return self.torch_backbone_type( + return self.backbone_type( n_heads=self.n_heads, dropout_rate=self.dropout_rate, item_model=item_model, @@ -403,6 +445,7 @@ def _init_torch_model(self, item_model: ItemNetBase) -> TransformerTorchBackbone similarity_module=similarity_module, use_causal_attn=self.use_causal_attn, use_key_padding_mask=self.use_key_padding_mask, + **self._get_kwargs(self.backbone_kwargs), ) def _init_lightning_model( diff --git a/rectools/models/nn/transformers/bert4rec.py b/rectools/models/nn/transformers/bert4rec.py index a1cf6f0d..5644fd2a 100644 --- a/rectools/models/nn/transformers/bert4rec.py +++ b/rectools/models/nn/transformers/bert4rec.py @@ -260,7 +260,7 @@ class BERT4RecModel(TransformerModelBase[BERT4RecModelConfig]): Type of lightning module defining training procedure. similarity_module_type : type(SimilarityModuleBase), default `DistanceSimilarityModule` Type of similarity module. - torch_backbone_type : type(TransformerBackboneBase), default `TransformerTorchBackbone` + backbone_type : type(TransformerBackboneBase), default `TransformerTorchBackbone` Type of torch backbone. get_val_mask_func : Callable, default ``None`` Function to get validation mask. @@ -298,6 +298,9 @@ class BERT4RecModel(TransformerModelBase[BERT4RecModelConfig]): similarity_module_kwargs: optional(dict), default ``None`` Additional keyword arguments to pass during `similarity_module_type` initialization. Make sure all dict values have JSON serializable types. + backbone_kwargs: optional(dict), default ``None`` + Additional keyword arguments to pass during `backbone_type` initialization. + Make sure all dict values have JSON serializable types. """ config_class = BERT4RecModelConfig @@ -330,7 +333,7 @@ def __init__( # pylint: disable=too-many-arguments, too-many-locals data_preparator_type: tp.Type[TransformerDataPreparatorBase] = BERT4RecDataPreparator, lightning_module_type: tp.Type[TransformerLightningModuleBase] = TransformerLightningModule, similarity_module_type: tp.Type[SimilarityModuleBase] = DistanceSimilarityModule, - torch_backbone_type: tp.Type[TransformerBackboneBase] = TransformerTorchBackbone, + backbone_type: tp.Type[TransformerBackboneBase] = TransformerTorchBackbone, get_val_mask_func: tp.Optional[ValMaskCallable] = None, get_trainer_func: tp.Optional[TrainerCallable] = None, recommend_batch_size: int = 256, @@ -344,6 +347,7 @@ def __init__( # pylint: disable=too-many-arguments, too-many-locals pos_encoding_kwargs: tp.Optional[InitKwargs] = None, lightning_module_kwargs: tp.Optional[InitKwargs] = None, similarity_module_kwargs: tp.Optional[InitKwargs] = None, + backbone_kwargs: tp.Optional[InitKwargs] = None, ): self.mask_prob = mask_prob @@ -377,7 +381,7 @@ def __init__( # pylint: disable=too-many-arguments, too-many-locals item_net_constructor_type=item_net_constructor_type, pos_encoding_type=pos_encoding_type, lightning_module_type=lightning_module_type, - torch_backbone_type=torch_backbone_type, + backbone_type=backbone_type, get_val_mask_func=get_val_mask_func, get_trainer_func=get_trainer_func, data_preparator_kwargs=data_preparator_kwargs, @@ -387,6 +391,7 @@ def __init__( # pylint: disable=too-many-arguments, too-many-locals pos_encoding_kwargs=pos_encoding_kwargs, lightning_module_kwargs=lightning_module_kwargs, similarity_module_kwargs=similarity_module_kwargs, + backbone_kwargs=backbone_kwargs, ) def _init_data_preparator(self) -> None: diff --git a/rectools/models/nn/transformers/sasrec.py b/rectools/models/nn/transformers/sasrec.py index 43c5227b..df7c0ea3 100644 --- a/rectools/models/nn/transformers/sasrec.py +++ b/rectools/models/nn/transformers/sasrec.py @@ -340,7 +340,7 @@ class SASRecModel(TransformerModelBase[SASRecModelConfig]): Type of lightning module defining training procedure. similarity_module_type : type(SimilarityModuleBase), default `DistanceSimilarityModule` Type of similarity module. - torch_backbone_type : type(TransformerBackboneBase), default `TransformerTorchBackbone` + backbone_type : type(TransformerBackboneBase), default `TransformerTorchBackbone` Type of torch backbone. get_val_mask_func : Callable, default ``None`` Function to get validation mask. @@ -378,6 +378,9 @@ class SASRecModel(TransformerModelBase[SASRecModelConfig]): similarity_module_kwargs: optional(dict), default ``None`` Additional keyword arguments to pass during `similarity_module_type` initialization. Make sure all dict values have JSON serializable types. + backbone_kwargs: optional(dict), default ``None`` + Additional keyword arguments to pass during `backbone_type` initialization. + Make sure all dict values have JSON serializable types. """ config_class = SASRecModelConfig @@ -409,7 +412,7 @@ def __init__( # pylint: disable=too-many-arguments, too-many-locals data_preparator_type: tp.Type[TransformerDataPreparatorBase] = SASRecDataPreparator, lightning_module_type: tp.Type[TransformerLightningModuleBase] = TransformerLightningModule, similarity_module_type: tp.Type[SimilarityModuleBase] = DistanceSimilarityModule, - torch_backbone_type: tp.Type[TransformerBackboneBase] = TransformerTorchBackbone, + backbone_type: tp.Type[TransformerBackboneBase] = TransformerTorchBackbone, get_val_mask_func: tp.Optional[ValMaskCallable] = None, get_trainer_func: tp.Optional[TrainerCallable] = None, recommend_batch_size: int = 256, @@ -422,6 +425,7 @@ def __init__( # pylint: disable=too-many-arguments, too-many-locals pos_encoding_kwargs: tp.Optional[InitKwargs] = None, lightning_module_kwargs: tp.Optional[InitKwargs] = None, similarity_module_kwargs: tp.Optional[InitKwargs] = None, + backbone_kwargs: tp.Optional[InitKwargs] = None, ): super().__init__( transformer_layers_type=transformer_layers_type, @@ -453,7 +457,7 @@ def __init__( # pylint: disable=too-many-arguments, too-many-locals item_net_constructor_type=item_net_constructor_type, pos_encoding_type=pos_encoding_type, lightning_module_type=lightning_module_type, - torch_backbone_type=torch_backbone_type, + backbone_type=backbone_type, get_val_mask_func=get_val_mask_func, get_trainer_func=get_trainer_func, data_preparator_kwargs=data_preparator_kwargs, @@ -462,4 +466,5 @@ def __init__( # pylint: disable=too-many-arguments, too-many-locals pos_encoding_kwargs=pos_encoding_kwargs, lightning_module_kwargs=lightning_module_kwargs, similarity_module_kwargs=similarity_module_kwargs, + backbone_kwargs=backbone_kwargs, ) diff --git a/tests/models/nn/transformers/test_bert4rec.py b/tests/models/nn/transformers/test_bert4rec.py index 9c62a57f..f7c87cd8 100644 --- a/tests/models/nn/transformers/test_bert4rec.py +++ b/tests/models/nn/transformers/test_bert4rec.py @@ -845,7 +845,7 @@ def initial_config(self) -> tp.Dict[str, tp.Any]: "data_preparator_type": BERT4RecDataPreparator, "lightning_module_type": TransformerLightningModule, "similarity_module_type": DistanceSimilarityModule, - "torch_backbone_type": TransformerTorchBackbone, + "backbone_type": TransformerTorchBackbone, "mask_prob": 0.15, "get_val_mask_func": leave_one_out_mask, "get_trainer_func": None, @@ -855,6 +855,7 @@ def initial_config(self) -> tp.Dict[str, tp.Any]: "pos_encoding_kwargs": None, "lightning_module_kwargs": None, "similarity_module_kwargs": None, + "backbone_kwargs": None, } return config @@ -895,7 +896,7 @@ def test_get_config( "lightning_module_type": "rectools.models.nn.transformers.lightning.TransformerLightningModule", "get_val_mask_func": "tests.models.nn.transformers.utils.leave_one_out_mask", "similarity_module_type": "rectools.models.nn.transformers.similarity.DistanceSimilarityModule", - "torch_backbone_type": "rectools.models.nn.transformers.torch_backbone.TransformerTorchBackbone", + "backbone_type": "rectools.models.nn.transformers.torch_backbone.TransformerTorchBackbone", } expected.update(simple_types_params) if use_custom_trainer: diff --git a/tests/models/nn/transformers/test_sasrec.py b/tests/models/nn/transformers/test_sasrec.py index 1bb666de..f371935d 100644 --- a/tests/models/nn/transformers/test_sasrec.py +++ b/tests/models/nn/transformers/test_sasrec.py @@ -941,7 +941,7 @@ def initial_config(self) -> tp.Dict[str, tp.Any]: "data_preparator_type": SASRecDataPreparator, "lightning_module_type": TransformerLightningModule, "similarity_module_type": DistanceSimilarityModule, - "torch_backbone_type": TransformerTorchBackbone, + "backbone_type": TransformerTorchBackbone, "get_val_mask_func": leave_one_out_mask, "get_trainer_func": None, "data_preparator_kwargs": None, @@ -950,6 +950,7 @@ def initial_config(self) -> tp.Dict[str, tp.Any]: "pos_encoding_kwargs": None, "lightning_module_kwargs": None, "similarity_module_kwargs": None, + "backbone_kwargs": None, } return config @@ -990,7 +991,7 @@ def test_get_config( "lightning_module_type": "rectools.models.nn.transformers.lightning.TransformerLightningModule", "get_val_mask_func": "tests.models.nn.transformers.utils.leave_one_out_mask", "similarity_module_type": "rectools.models.nn.transformers.similarity.DistanceSimilarityModule", - "torch_backbone_type": "rectools.models.nn.transformers.torch_backbone.TransformerTorchBackbone", + "backbone_type": "rectools.models.nn.transformers.torch_backbone.TransformerTorchBackbone", } expected.update(simple_types_params) if use_custom_trainer: From e05eaf3074f60db10ba77451673f568818db1204 Mon Sep 17 00:00:00 2001 From: Daria Tikhonovich Date: Wed, 2 Apr 2025 14:33:37 +0300 Subject: [PATCH 5/7] changelog --- CHANGELOG.md | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 06cf7177..a80d13d5 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,8 +7,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added -- `SimilarityModuleBase`, `DistanceSimilarityModule`, similarity module to `TransformerTorchBackbone`, parameters to transformer-based models `similarity_module_type`, `similarity_module_kwargs` ([#272](https://github.com/MobileTeleSystems/RecTools/pull/272)) +- `SimilarityModuleBase`, `DistanceSimilarityModule`, similarity module to `TransformerTorchBackbone` parameters to transformer-based models `similarity_module_type`, `similarity_module_kwargs` ([#272](https://github.com/MobileTeleSystems/RecTools/pull/272)) - `out_dim` property to `IdEmbeddingsItemNet`, `CatFeaturesItemNet` and `SumOfEmbeddingsConstructor` ([#276](https://github.com/MobileTeleSystems/RecTools/pull/276)) +- `TransformerBackboneBase`, `backbone_type` and `backbone_kwargs` parameters to transformer-based models ([#277](https://github.com/MobileTeleSystems/RecTools/pull/277)) ## [0.12.0] - 24.02.2025 From 50a490312e60df7a7fdb6ad295b91d76c3aaa2d5 Mon Sep 17 00:00:00 2001 From: Daria Tikhonovich Date: Wed, 2 Apr 2025 14:35:21 +0300 Subject: [PATCH 6/7] config fix --- rectools/models/nn/transformers/base.py | 41 +------------------------ 1 file changed, 1 insertion(+), 40 deletions(-) diff --git a/rectools/models/nn/transformers/base.py b/rectools/models/nn/transformers/base.py index 2243e455..ef71f095 100644 --- a/rectools/models/nn/transformers/base.py +++ b/rectools/models/nn/transformers/base.py @@ -176,46 +176,7 @@ def _serialize_type_sequence(obj: tp.Sequence[tp.Type]) -> tp.Tuple[str, ...]: class TransformerModelConfig(ModelConfig): - """Transformer model base config. - - data_preparator_type: TransformerDataPreparatorType - n_blocks: int = 2 - n_heads: int = 4 - n_factors: int = 256 - use_pos_emb: bool = True - use_causal_attn: bool = False - use_key_padding_mask: bool = False - dropout_rate: float = 0.2 - session_max_len: int = 100 - dataloader_num_workers: int = 0 - batch_size: int = 128 - loss: str = "softmax" - n_negatives: int = 1 - gbce_t: float = 0.2 - lr: float = 0.001 - epochs: int = 3 - verbose: int = 0 - deterministic: bool = False - recommend_batch_size: int = 256 - recommend_torch_device: tp.Optional[str] = None - train_min_user_interactions: int = 2 - item_net_block_types: ItemNetBlockTypes = (IdEmbeddingsItemNet, CatFeaturesItemNet) - item_net_constructor_type: ItemNetConstructorType = SumOfEmbeddingsConstructor - pos_encoding_type: PositionalEncodingType = LearnableInversePositionalEncoding - transformer_layers_type: TransformerLayersType = PreLNTransformerLayers - lightning_module_type: TransformerLightningModuleType = TransformerLightningModule - similarity_module_type: SimilarityModuleType = DistanceSimilarityModule - backbone_type: TransformerBackboneType = TransformerTorchBackbone - get_val_mask_func: tp.Optional[ValMaskCallableSerialized] = None - get_trainer_func: tp.Optional[TrainerCallableSerialized] = None - data_preparator_kwargs: tp.Optional[InitKwargs] = None - transformer_layers_kwargs: tp.Optional[InitKwargs] = None - item_net_constructor_kwargs: tp.Optional[InitKwargs] = None - pos_encoding_kwargs: tp.Optional[InitKwargs] = None - lightning_module_kwargs: tp.Optional[InitKwargs] = None - similarity_module_kwargs: tp.Optional[InitKwargs] = None - backbone_kwargs: tp.Optional[InitKwargs] = None - """ + """Transformer model base config.""" data_preparator_type: TransformerDataPreparatorType n_blocks: int = 2 From acba15fb2dcab5d1905ca49647a56c33e55d3ef4 Mon Sep 17 00:00:00 2001 From: Daria Tikhonovich Date: Wed, 2 Apr 2025 15:08:39 +0300 Subject: [PATCH 7/7] linters --- rectools/models/nn/transformers/base.py | 6 ++--- rectools/models/nn/transformers/lightning.py | 23 +++++++++----------- 2 files changed, 13 insertions(+), 16 deletions(-) diff --git a/rectools/models/nn/transformers/base.py b/rectools/models/nn/transformers/base.py index ef71f095..fec4a99b 100644 --- a/rectools/models/nn/transformers/base.py +++ b/rectools/models/nn/transformers/base.py @@ -393,7 +393,7 @@ def _init_transformer_layers(self) -> TransformerLayersBase: def _init_similarity_module(self) -> SimilarityModuleBase: return self.similarity_module_type(**self._get_kwargs(self.similarity_module_kwargs)) - def _init_torch_model(self, item_model: ItemNetBase) -> TransformerTorchBackbone: + def _init_torch_model(self, item_model: ItemNetBase) -> TransformerBackboneBase: pos_encoding_layer = self._init_pos_encoding_layer() transformer_layers = self._init_transformer_layers() similarity_module = self._init_similarity_module() @@ -411,7 +411,7 @@ def _init_torch_model(self, item_model: ItemNetBase) -> TransformerTorchBackbone def _init_lightning_model( self, - torch_model: TransformerTorchBackbone, + torch_model: TransformerBackboneBase, dataset_schema: DatasetSchemaDict, item_external_ids: ExternalIds, model_config: tp.Dict[str, tp.Any], @@ -507,7 +507,7 @@ def _recommend_i2i( ) @property - def torch_model(self) -> TransformerTorchBackbone: + def torch_model(self) -> TransformerBackboneBase: """Pytorch model.""" return self.lightning_model.torch_model diff --git a/rectools/models/nn/transformers/lightning.py b/rectools/models/nn/transformers/lightning.py index 2d1c043f..0c554659 100644 --- a/rectools/models/nn/transformers/lightning.py +++ b/rectools/models/nn/transformers/lightning.py @@ -26,7 +26,7 @@ from rectools.types import InternalIdsArray from .data_preparator import TransformerDataPreparatorBase -from .torch_backbone import TransformerTorchBackbone +from .torch_backbone import TransformerBackboneBase # #### -------------- Lightning Base Model -------------- #### # @@ -38,7 +38,7 @@ class TransformerLightningModuleBase(LightningModule): # pylint: disable=too-ma Parameters ---------- - torch_model : TransformerTorchBackbone + torch_model : TransformerBackboneBase Torch model to make recommendations. lr : float Learning rate. @@ -61,7 +61,7 @@ class TransformerLightningModuleBase(LightningModule): # pylint: disable=too-ma def __init__( self, - torch_model: TransformerTorchBackbone, + torch_model: TransformerBackboneBase, model_config: tp.Dict[str, tp.Any], dataset_schema: DatasetSchemaDict, item_external_ids: ExternalIds, @@ -350,17 +350,14 @@ def _recommend_u2i( user_embs, item_embs = self._get_user_item_embeddings(recommend_dataloader, torch_device) - all_user_ids, all_reco_ids, all_scores = ( - self.torch_model.similarity_module._recommend_u2i( # pylint: disable=protected-access - user_embs=user_embs, - item_embs=item_embs, - user_ids=user_ids, - k=k, - sorted_item_ids_to_recommend=sorted_item_ids_to_recommend, - ui_csr_for_filter=ui_csr_for_filter, - ) + return self.torch_model.similarity_module._recommend_u2i( # pylint: disable=protected-access + user_embs=user_embs, + item_embs=item_embs, + user_ids=user_ids, + k=k, + sorted_item_ids_to_recommend=sorted_item_ids_to_recommend, + ui_csr_for_filter=ui_csr_for_filter, ) - return all_user_ids, all_reco_ids, all_scores def _recommend_i2i( self,