Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
## Unreleased

### Added
- `SimilarityModuleBase`, `DistanceSimilarityModule`, similarity module to `TransformerTorchBackbone`, parameters to transformer-based models `similarity_module_type`, `similarity_module_kwargs` ([#272](https://github.com/MobileTeleSystems/RecTools/pull/272))
- `SimilarityModuleBase`, `DistanceSimilarityModule`, similarity module to `TransformerTorchBackbone` parameters to transformer-based models `similarity_module_type`, `similarity_module_kwargs` ([#272](https://github.com/MobileTeleSystems/RecTools/pull/272))
- `out_dim` property to `IdEmbeddingsItemNet`, `CatFeaturesItemNet` and `SumOfEmbeddingsConstructor` ([#276](https://github.com/MobileTeleSystems/RecTools/pull/276))
- `TransformerBackboneBase`, `backbone_type` and `backbone_kwargs` parameters to transformer-based models ([#277](https://github.com/MobileTeleSystems/RecTools/pull/277))
- `sampled_softmax` loss option for transformer models ([#274](https://github.com/MobileTeleSystems/RecTools/pull/274))

## [0.12.0] - 24.02.2025
Expand Down
36 changes: 18 additions & 18 deletions examples/tutorials/transformers_advanced_training_guide.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -412,15 +412,15 @@
"name": "stdout",
"output_type": "stream",
"text": [
"epoch,step,train_loss,val_loss\r",
"epoch,step,train_loss,val_loss\r\n",
"\r\n",
"0,1,,22.365339279174805\r",
"0,1,,22.365339279174805\r\n",
"\r\n",
"0,1,22.38391876220703,\r",
"0,1,22.38391876220703,\r\n",
"\r\n",
"1,3,,22.189851760864258\r",
"1,3,,22.189851760864258\r\n",
"\r\n",
"1,3,22.898216247558594,\r",
"1,3,22.898216247558594,\r\n",
"\r\n"
]
}
Expand Down Expand Up @@ -526,23 +526,23 @@
"name": "stdout",
"output_type": "stream",
"text": [
"epoch,step,train_loss,val_loss\r",
"epoch,step,train_loss,val_loss\r\n",
"\r\n",
"0,1,,22.343637466430664\r",
"0,1,,22.343637466430664\r\n",
"\r\n",
"0,1,22.36273765563965,\r",
"0,1,22.36273765563965,\r\n",
"\r\n",
"1,3,,22.159835815429688\r",
"1,3,,22.159835815429688\r\n",
"\r\n",
"1,3,22.33755874633789,\r",
"1,3,22.33755874633789,\r\n",
"\r\n",
"2,5,,21.94308853149414\r",
"2,5,,21.94308853149414\r\n",
"\r\n",
"2,5,22.244243621826172,\r",
"2,5,22.244243621826172,\r\n",
"\r\n",
"3,7,,21.702259063720703\r",
"3,7,,21.702259063720703\r\n",
"\r\n",
"3,7,22.196012496948242,\r",
"3,7,22.196012496948242,\r\n",
"\r\n"
]
}
Expand Down Expand Up @@ -898,7 +898,7 @@
" ) -> None:\n",
" logits = outputs[\"logits\"]\n",
" if logits is None:\n",
" logits = pl_module.torch_model.encode_sessions(batch[\"x\"], pl_module.item_embs)[:, -1, :]\n",
" logits = pl_module.torch_model.encode_sessions(batch, pl_module.item_embs)[:, -1, :]\n",
" _, sorted_batch_recos = logits.topk(k=self.top_k)\n",
"\n",
" batch_recos = sorted_batch_recos.tolist()\n",
Expand Down Expand Up @@ -2039,9 +2039,9 @@
],
"metadata": {
"kernelspec": {
"display_name": "rectools",
"display_name": ".venv",
"language": "python",
"name": "rectools"
"name": "python3"
},
"language_info": {
"codemirror_mode": {
Expand All @@ -2053,7 +2053,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.12"
"version": "3.10.13"
}
},
"nbformat": 4,
Expand Down
27 changes: 22 additions & 5 deletions rectools/models/nn/transformers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@
TransformerLayersBase,
)
from .similarity import DistanceSimilarityModule, SimilarityModuleBase
from .torch_backbone import TransformerTorchBackbone
from .torch_backbone import TransformerBackboneBase, TransformerTorchBackbone

InitKwargs = tp.Dict[str, tp.Any]

Expand Down Expand Up @@ -108,6 +108,16 @@ def _serialize_type_sequence(obj: tp.Sequence[tp.Type]) -> tp.Tuple[str, ...]:
),
]

TransformerBackboneType = tpe.Annotated[
tp.Type[TransformerBackboneBase],
BeforeValidator(_get_class_obj),
PlainSerializer(
func=get_class_or_function_full_path,
return_type=str,
when_used="json",
),
]

TransformerDataPreparatorType = tpe.Annotated[
tp.Type[TransformerDataPreparatorBase],
BeforeValidator(_get_class_obj),
Expand Down Expand Up @@ -195,6 +205,7 @@ class TransformerModelConfig(ModelConfig):
transformer_layers_type: TransformerLayersType = PreLNTransformerLayers
lightning_module_type: TransformerLightningModuleType = TransformerLightningModule
similarity_module_type: SimilarityModuleType = DistanceSimilarityModule
backbone_type: TransformerBackboneType = TransformerTorchBackbone
get_val_mask_func: tp.Optional[ValMaskCallableSerialized] = None
get_trainer_func: tp.Optional[TrainerCallableSerialized] = None
data_preparator_kwargs: tp.Optional[InitKwargs] = None
Expand All @@ -203,6 +214,7 @@ class TransformerModelConfig(ModelConfig):
pos_encoding_kwargs: tp.Optional[InitKwargs] = None
lightning_module_kwargs: tp.Optional[InitKwargs] = None
similarity_module_kwargs: tp.Optional[InitKwargs] = None
backbone_kwargs: tp.Optional[InitKwargs] = None


TransformerModelConfig_T = tp.TypeVar("TransformerModelConfig_T", bound=TransformerModelConfig)
Expand Down Expand Up @@ -251,6 +263,7 @@ def __init__( # pylint: disable=too-many-arguments, too-many-locals
pos_encoding_type: tp.Type[PositionalEncodingBase] = LearnableInversePositionalEncoding,
lightning_module_type: tp.Type[TransformerLightningModuleBase] = TransformerLightningModule,
similarity_module_type: tp.Type[SimilarityModuleBase] = DistanceSimilarityModule,
backbone_type: tp.Type[TransformerBackboneBase] = TransformerTorchBackbone,
get_val_mask_func: tp.Optional[ValMaskCallable] = None,
get_trainer_func: tp.Optional[TrainerCallable] = None,
data_preparator_kwargs: tp.Optional[InitKwargs] = None,
Expand All @@ -259,6 +272,7 @@ def __init__( # pylint: disable=too-many-arguments, too-many-locals
pos_encoding_kwargs: tp.Optional[InitKwargs] = None,
lightning_module_kwargs: tp.Optional[InitKwargs] = None,
similarity_module_kwargs: tp.Optional[InitKwargs] = None,
backbone_kwargs: tp.Optional[InitKwargs] = None,
**kwargs: tp.Any,
) -> None:
super().__init__(verbose=verbose)
Expand Down Expand Up @@ -288,6 +302,7 @@ def __init__( # pylint: disable=too-many-arguments, too-many-locals
self.item_net_constructor_type = item_net_constructor_type
self.pos_encoding_type = pos_encoding_type
self.lightning_module_type = lightning_module_type
self.backbone_type = backbone_type
self.get_val_mask_func = get_val_mask_func
self.get_trainer_func = get_trainer_func
self.data_preparator_kwargs = data_preparator_kwargs
Expand All @@ -296,6 +311,7 @@ def __init__( # pylint: disable=too-many-arguments, too-many-locals
self.pos_encoding_kwargs = pos_encoding_kwargs
self.lightning_module_kwargs = lightning_module_kwargs
self.similarity_module_kwargs = similarity_module_kwargs
self.backbone_kwargs = backbone_kwargs

self._init_data_preparator()
self._init_trainer()
Expand Down Expand Up @@ -377,11 +393,11 @@ def _init_transformer_layers(self) -> TransformerLayersBase:
def _init_similarity_module(self) -> SimilarityModuleBase:
return self.similarity_module_type(**self._get_kwargs(self.similarity_module_kwargs))

def _init_torch_model(self, item_model: ItemNetBase) -> TransformerTorchBackbone:
def _init_torch_model(self, item_model: ItemNetBase) -> TransformerBackboneBase:
pos_encoding_layer = self._init_pos_encoding_layer()
transformer_layers = self._init_transformer_layers()
similarity_module = self._init_similarity_module()
return TransformerTorchBackbone(
return self.backbone_type(
n_heads=self.n_heads,
dropout_rate=self.dropout_rate,
item_model=item_model,
Expand All @@ -390,11 +406,12 @@ def _init_torch_model(self, item_model: ItemNetBase) -> TransformerTorchBackbone
similarity_module=similarity_module,
use_causal_attn=self.use_causal_attn,
use_key_padding_mask=self.use_key_padding_mask,
**self._get_kwargs(self.backbone_kwargs),
)

def _init_lightning_model(
self,
torch_model: TransformerTorchBackbone,
torch_model: TransformerBackboneBase,
dataset_schema: DatasetSchemaDict,
item_external_ids: ExternalIds,
model_config: tp.Dict[str, tp.Any],
Expand Down Expand Up @@ -490,7 +507,7 @@ def _recommend_i2i(
)

@property
def torch_model(self) -> TransformerTorchBackbone:
def torch_model(self) -> TransformerBackboneBase:
"""Pytorch model."""
return self.lightning_model.torch_model

Expand Down
10 changes: 10 additions & 0 deletions rectools/models/nn/transformers/bert4rec.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
TransformerLayersBase,
)
from .similarity import DistanceSimilarityModule, SimilarityModuleBase
from .torch_backbone import TransformerBackboneBase, TransformerTorchBackbone


class BERT4RecDataPreparator(TransformerDataPreparatorBase):
Expand Down Expand Up @@ -259,6 +260,8 @@ class BERT4RecModel(TransformerModelBase[BERT4RecModelConfig]):
Type of lightning module defining training procedure.
similarity_module_type : type(SimilarityModuleBase), default `DistanceSimilarityModule`
Type of similarity module.
backbone_type : type(TransformerBackboneBase), default `TransformerTorchBackbone`
Type of torch backbone.
get_val_mask_func : Callable, default ``None``
Function to get validation mask.
get_trainer_func : Callable, default ``None``
Expand Down Expand Up @@ -295,6 +298,9 @@ class BERT4RecModel(TransformerModelBase[BERT4RecModelConfig]):
similarity_module_kwargs: optional(dict), default ``None``
Additional keyword arguments to pass during `similarity_module_type` initialization.
Make sure all dict values have JSON serializable types.
backbone_kwargs: optional(dict), default ``None``
Additional keyword arguments to pass during `backbone_type` initialization.
Make sure all dict values have JSON serializable types.
"""

config_class = BERT4RecModelConfig
Expand Down Expand Up @@ -327,6 +333,7 @@ def __init__( # pylint: disable=too-many-arguments, too-many-locals
data_preparator_type: tp.Type[TransformerDataPreparatorBase] = BERT4RecDataPreparator,
lightning_module_type: tp.Type[TransformerLightningModuleBase] = TransformerLightningModule,
similarity_module_type: tp.Type[SimilarityModuleBase] = DistanceSimilarityModule,
backbone_type: tp.Type[TransformerBackboneBase] = TransformerTorchBackbone,
get_val_mask_func: tp.Optional[ValMaskCallable] = None,
get_trainer_func: tp.Optional[TrainerCallable] = None,
recommend_batch_size: int = 256,
Expand All @@ -340,6 +347,7 @@ def __init__( # pylint: disable=too-many-arguments, too-many-locals
pos_encoding_kwargs: tp.Optional[InitKwargs] = None,
lightning_module_kwargs: tp.Optional[InitKwargs] = None,
similarity_module_kwargs: tp.Optional[InitKwargs] = None,
backbone_kwargs: tp.Optional[InitKwargs] = None,
):
self.mask_prob = mask_prob

Expand Down Expand Up @@ -373,6 +381,7 @@ def __init__( # pylint: disable=too-many-arguments, too-many-locals
item_net_constructor_type=item_net_constructor_type,
pos_encoding_type=pos_encoding_type,
lightning_module_type=lightning_module_type,
backbone_type=backbone_type,
get_val_mask_func=get_val_mask_func,
get_trainer_func=get_trainer_func,
data_preparator_kwargs=data_preparator_kwargs,
Expand All @@ -382,6 +391,7 @@ def __init__( # pylint: disable=too-many-arguments, too-many-locals
pos_encoding_kwargs=pos_encoding_kwargs,
lightning_module_kwargs=lightning_module_kwargs,
similarity_module_kwargs=similarity_module_kwargs,
backbone_kwargs=backbone_kwargs,
)

def _init_data_preparator(self) -> None:
Expand Down
31 changes: 14 additions & 17 deletions rectools/models/nn/transformers/lightning.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
from rectools.types import InternalIdsArray

from .data_preparator import TransformerDataPreparatorBase
from .torch_backbone import TransformerTorchBackbone
from .torch_backbone import TransformerBackboneBase

# #### -------------- Lightning Base Model -------------- #### #

Expand All @@ -38,7 +38,7 @@ class TransformerLightningModuleBase(LightningModule): # pylint: disable=too-ma

Parameters
----------
torch_model : TransformerTorchBackbone
torch_model : TransformerBackboneBase
Torch model to make recommendations.
lr : float
Learning rate.
Expand All @@ -61,7 +61,7 @@ class TransformerLightningModuleBase(LightningModule): # pylint: disable=too-ma

def __init__(
self,
torch_model: TransformerTorchBackbone,
torch_model: TransformerBackboneBase,
model_config: tp.Dict[str, tp.Any],
dataset_schema: DatasetSchemaDict,
item_external_ids: ExternalIds,
Expand Down Expand Up @@ -250,13 +250,12 @@ def on_train_start(self) -> None:

def get_batch_logits(self, batch: tp.Dict[str, torch.Tensor]) -> torch.Tensor:
        """Get batch logits."""
x = batch["x"] # x: [batch_size, session_max_len]
if self._requires_negatives:
y, negatives = batch["y"], batch["negatives"]
pos_neg = torch.cat([y.unsqueeze(-1), negatives], dim=-1)
logits = self.torch_model(sessions=x, item_ids=pos_neg)
logits = self.torch_model(batch=batch, candidate_item_ids=pos_neg)
else:
logits = self.torch_model(sessions=x)
logits = self.torch_model(batch=batch)
return logits

def training_step(self, batch: tp.Dict[str, torch.Tensor], batch_idx: int) -> torch.Tensor:
Expand Down Expand Up @@ -338,7 +337,8 @@ def _get_user_item_embeddings(
item_embs = self.torch_model.item_model.get_all_embeddings()
user_embs = []
for batch in recommend_dataloader:
batch_embs = self.torch_model.encode_sessions(batch["x"].to(device), item_embs)[:, -1, :]
batch = {k: v.to(device) for k, v in batch.items()}
batch_embs = self.torch_model.encode_sessions(batch, item_embs)[:, -1, :]
user_embs.append(batch_embs)

return torch.cat(user_embs), item_embs
Expand All @@ -360,17 +360,14 @@ def _recommend_u2i(

user_embs, item_embs = self._get_user_item_embeddings(recommend_dataloader, torch_device)

all_user_ids, all_reco_ids, all_scores = (
self.torch_model.similarity_module._recommend_u2i( # pylint: disable=protected-access
user_embs=user_embs,
item_embs=item_embs,
user_ids=user_ids,
k=k,
sorted_item_ids_to_recommend=sorted_item_ids_to_recommend,
ui_csr_for_filter=ui_csr_for_filter,
)
return self.torch_model.similarity_module._recommend_u2i( # pylint: disable=protected-access
user_embs=user_embs,
item_embs=item_embs,
user_ids=user_ids,
k=k,
sorted_item_ids_to_recommend=sorted_item_ids_to_recommend,
ui_csr_for_filter=ui_csr_for_filter,
)
return all_user_ids, all_reco_ids, all_scores

def _recommend_i2i(
self,
Expand Down
10 changes: 10 additions & 0 deletions rectools/models/nn/transformers/sasrec.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
TransformerLayersBase,
)
from .similarity import DistanceSimilarityModule, SimilarityModuleBase
from .torch_backbone import TransformerBackboneBase, TransformerTorchBackbone


class SASRecDataPreparator(TransformerDataPreparatorBase):
Expand Down Expand Up @@ -339,6 +340,8 @@ class SASRecModel(TransformerModelBase[SASRecModelConfig]):
Type of lightning module defining training procedure.
similarity_module_type : type(SimilarityModuleBase), default `DistanceSimilarityModule`
Type of similarity module.
backbone_type : type(TransformerBackboneBase), default `TransformerTorchBackbone`
Type of torch backbone.
get_val_mask_func : Callable, default ``None``
Function to get validation mask.
get_trainer_func : Callable, default ``None``
Expand Down Expand Up @@ -375,6 +378,9 @@ class SASRecModel(TransformerModelBase[SASRecModelConfig]):
similarity_module_kwargs: optional(dict), default ``None``
Additional keyword arguments to pass during `similarity_module_type` initialization.
Make sure all dict values have JSON serializable types.
backbone_kwargs: optional(dict), default ``None``
Additional keyword arguments to pass during `backbone_type` initialization.
Make sure all dict values have JSON serializable types.
"""

config_class = SASRecModelConfig
Expand Down Expand Up @@ -406,6 +412,7 @@ def __init__( # pylint: disable=too-many-arguments, too-many-locals
data_preparator_type: tp.Type[TransformerDataPreparatorBase] = SASRecDataPreparator,
lightning_module_type: tp.Type[TransformerLightningModuleBase] = TransformerLightningModule,
similarity_module_type: tp.Type[SimilarityModuleBase] = DistanceSimilarityModule,
backbone_type: tp.Type[TransformerBackboneBase] = TransformerTorchBackbone,
get_val_mask_func: tp.Optional[ValMaskCallable] = None,
get_trainer_func: tp.Optional[TrainerCallable] = None,
recommend_batch_size: int = 256,
Expand All @@ -418,6 +425,7 @@ def __init__( # pylint: disable=too-many-arguments, too-many-locals
pos_encoding_kwargs: tp.Optional[InitKwargs] = None,
lightning_module_kwargs: tp.Optional[InitKwargs] = None,
similarity_module_kwargs: tp.Optional[InitKwargs] = None,
backbone_kwargs: tp.Optional[InitKwargs] = None,
):
super().__init__(
transformer_layers_type=transformer_layers_type,
Expand Down Expand Up @@ -449,6 +457,7 @@ def __init__( # pylint: disable=too-many-arguments, too-many-locals
item_net_constructor_type=item_net_constructor_type,
pos_encoding_type=pos_encoding_type,
lightning_module_type=lightning_module_type,
backbone_type=backbone_type,
get_val_mask_func=get_val_mask_func,
get_trainer_func=get_trainer_func,
data_preparator_kwargs=data_preparator_kwargs,
Expand All @@ -457,4 +466,5 @@ def __init__( # pylint: disable=too-many-arguments, too-many-locals
pos_encoding_kwargs=pos_encoding_kwargs,
lightning_module_kwargs=lightning_module_kwargs,
similarity_module_kwargs=similarity_module_kwargs,
backbone_kwargs=backbone_kwargs,
)
Loading