Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
018ddef
Added cosine distance to u2i transformer models.
In48semenov Mar 14, 2025
269430a
Fixed docs and added cosine to test_bert4rec.
In48semenov Mar 15, 2025
3750ce8
reverted `rank torch` to the original code
In48semenov Mar 17, 2025
6f3f4a4
Fixed cosine fusion on train stage.
In48semenov Mar 17, 2025
f387186
Fixed u2i tests with cosine distance and config tests.
In48semenov Mar 17, 2025
f5cd773
Fixed code style.
In48semenov Mar 18, 2025
bfd4c63
Added test with raises u2i_dist.
In48semenov Mar 18, 2025
2548c64
Added similarity module base and distance.
In48semenov Mar 22, 2025
fe607ae
Added Similarity Module to torch model.
In48semenov Mar 25, 2025
846ba46
Renamed similarity classes, added properties for requires negs, fixed c…
In48semenov Mar 27, 2025
7e00297
Added epsilon for cosine similarity.
In48semenov Mar 27, 2025
274d1c9
Added pylint disable.
In48semenov Mar 27, 2025
d033e6d
Changed Enum Distance.
In48semenov Mar 27, 2025
d434456
Removed unnecessary kwargs in test_base.py in transformers.
In48semenov Mar 27, 2025
bc61feb
Removed kwargs `item_net_block_types` from test_base.py in transformers.
In48semenov Mar 27, 2025
57213fd
Put filter train interactions as separate method.
In48semenov Mar 27, 2025
0d006b5
Fixed docs.
In48semenov Mar 27, 2025
325e8c7
Fixed docs and removed df copy in filter interactions method.
In48semenov Mar 29, 2025
3ac5755
Added method get_loss_calculator, made default value n_negative as 0.
In48semenov Mar 29, 2025
c430918
Returned Optional value for n_negatives.
In48semenov Mar 29, 2025
2c71287
n_negatives is Optional[int]
In48semenov Mar 29, 2025
113dba4
Returned # type: ignore to test_torch_dataset.py
In48semenov Mar 29, 2025
e1c40b5
Added `pragma: no cover` to `_get_reduced_overconfidence_logits`
In48semenov Mar 29, 2025
f1583c5
Added annotation for get_loss_calculator.
In48semenov Mar 29, 2025
da6d122
Separated logic `get_loss_calculator` and `requires_negatives`.
In48semenov Mar 30, 2025
606432b
Added device for epsilon cosine dist.
In48semenov Mar 31, 2025
9ac6150
Added copyright.
In48semenov Mar 31, 2025
cf66a58
Updated CHANGELOG.md
In48semenov Mar 31, 2025
5ed7605
Made `epsilon_cosine_dist` as Parameter.
In48semenov Mar 31, 2025
ccffb16
Fixed epsilon_cosine_dist device
In48semenov Mar 31, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).


### Added
- `SimilarityModuleBase`, `DistanceSimilarityModule`, similarity module to `TransformerTorchBackbone`, parameters to transformer-based models `similarity_module_type`, `similarity_module_kwargs` ([#272](https://github.com/MobileTeleSystems/RecTools/pull/272))

## [0.12.0] - 24.02.2025

### Added
Expand Down
25 changes: 24 additions & 1 deletion rectools/models/nn/transformers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
PreLNTransformerLayers,
TransformerLayersBase,
)
from .similarity import DistanceSimilarityModule, SimilarityModuleBase
from .torch_backbone import TransformerTorchBackbone

InitKwargs = tp.Dict[str, tp.Any]
Expand Down Expand Up @@ -97,6 +98,16 @@ def _serialize_type_sequence(obj: tp.Sequence[tp.Type]) -> tp.Tuple[str, ...]:
),
]

SimilarityModuleType = tpe.Annotated[
tp.Type[SimilarityModuleBase],
BeforeValidator(_get_class_obj),
PlainSerializer(
func=get_class_or_function_full_path,
return_type=str,
when_used="json",
),
]

TransformerDataPreparatorType = tpe.Annotated[
tp.Type[TransformerDataPreparatorBase],
BeforeValidator(_get_class_obj),
Expand Down Expand Up @@ -183,13 +194,15 @@ class TransformerModelConfig(ModelConfig):
pos_encoding_type: PositionalEncodingType = LearnableInversePositionalEncoding
transformer_layers_type: TransformerLayersType = PreLNTransformerLayers
lightning_module_type: TransformerLightningModuleType = TransformerLightningModule
similarity_module_type: SimilarityModuleType = DistanceSimilarityModule
get_val_mask_func: tp.Optional[ValMaskCallableSerialized] = None
get_trainer_func: tp.Optional[TrainerCallableSerialized] = None
data_preparator_kwargs: tp.Optional[InitKwargs] = None
transformer_layers_kwargs: tp.Optional[InitKwargs] = None
item_net_constructor_kwargs: tp.Optional[InitKwargs] = None
pos_encoding_kwargs: tp.Optional[InitKwargs] = None
lightning_module_kwargs: tp.Optional[InitKwargs] = None
similarity_module_kwargs: tp.Optional[InitKwargs] = None


TransformerModelConfig_T = tp.TypeVar("TransformerModelConfig_T", bound=TransformerModelConfig)
Expand Down Expand Up @@ -237,13 +250,15 @@ def __init__( # pylint: disable=too-many-arguments, too-many-locals
item_net_constructor_type: tp.Type[ItemNetConstructorBase] = SumOfEmbeddingsConstructor,
pos_encoding_type: tp.Type[PositionalEncodingBase] = LearnableInversePositionalEncoding,
lightning_module_type: tp.Type[TransformerLightningModuleBase] = TransformerLightningModule,
similarity_module_type: tp.Type[SimilarityModuleBase] = DistanceSimilarityModule,
get_val_mask_func: tp.Optional[ValMaskCallable] = None,
get_trainer_func: tp.Optional[TrainerCallable] = None,
data_preparator_kwargs: tp.Optional[InitKwargs] = None,
transformer_layers_kwargs: tp.Optional[InitKwargs] = None,
item_net_constructor_kwargs: tp.Optional[InitKwargs] = None,
pos_encoding_kwargs: tp.Optional[InitKwargs] = None,
lightning_module_kwargs: tp.Optional[InitKwargs] = None,
similarity_module_kwargs: tp.Optional[InitKwargs] = None,
**kwargs: tp.Any,
) -> None:
super().__init__(verbose=verbose)
Expand All @@ -268,6 +283,7 @@ def __init__( # pylint: disable=too-many-arguments, too-many-locals
self.recommend_batch_size = recommend_batch_size
self.recommend_torch_device = recommend_torch_device
self.train_min_user_interactions = train_min_user_interactions
self.similarity_module_type = similarity_module_type
self.item_net_block_types = item_net_block_types
self.item_net_constructor_type = item_net_constructor_type
self.pos_encoding_type = pos_encoding_type
Expand All @@ -279,6 +295,7 @@ def __init__( # pylint: disable=too-many-arguments, too-many-locals
self.item_net_constructor_kwargs = item_net_constructor_kwargs
self.pos_encoding_kwargs = pos_encoding_kwargs
self.lightning_module_kwargs = lightning_module_kwargs
self.similarity_module_kwargs = similarity_module_kwargs

self._init_data_preparator()
self._init_trainer()
Expand All @@ -295,12 +312,13 @@ def _get_kwargs(actual_kwargs: tp.Optional[InitKwargs]) -> InitKwargs:
return kwargs

def _init_data_preparator(self) -> None:
requires_negatives = self.lightning_module_type.requires_negatives(self.loss)
self.data_preparator = self.data_preparator_type(
session_max_len=self.session_max_len,
batch_size=self.batch_size,
dataloader_num_workers=self.dataloader_num_workers,
train_min_user_interactions=self.train_min_user_interactions,
n_negatives=self.n_negatives if self.loss != "softmax" else None,
n_negatives=self.n_negatives if requires_negatives else None,
get_val_mask_func=self.get_val_mask_func,
shuffle_train=True,
**self._get_kwargs(self.data_preparator_kwargs),
Expand Down Expand Up @@ -356,15 +374,20 @@ def _init_transformer_layers(self) -> TransformerLayersBase:
**self._get_kwargs(self.transformer_layers_kwargs),
)

def _init_similarity_module(self) -> SimilarityModuleBase:
    """Instantiate the configured similarity module, forwarding any user-supplied kwargs."""
    init_kwargs = self._get_kwargs(self.similarity_module_kwargs)
    return self.similarity_module_type(**init_kwargs)

def _init_torch_model(self, item_model: ItemNetBase) -> TransformerTorchBackbone:
    """Assemble the torch backbone from its sub-modules.

    Builds the positional-encoding layer, transformer layers and similarity
    module (in that order) and wires them, together with the given item
    network, into a `TransformerTorchBackbone`.
    """
    return TransformerTorchBackbone(
        n_heads=self.n_heads,
        dropout_rate=self.dropout_rate,
        item_model=item_model,
        # Sub-modules are constructed here in the same order as before:
        # positional encoding, then transformer layers, then similarity.
        pos_encoding_layer=self._init_pos_encoding_layer(),
        transformer_layers=self._init_transformer_layers(),
        similarity_module=self._init_similarity_module(),
        use_causal_attn=self.use_causal_attn,
        use_key_padding_mask=self.use_key_padding_mask,
    )
Expand Down
10 changes: 10 additions & 0 deletions rectools/models/nn/transformers/bert4rec.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
PreLNTransformerLayers,
TransformerLayersBase,
)
from .similarity import DistanceSimilarityModule, SimilarityModuleBase


class BERT4RecDataPreparator(TransformerDataPreparatorBase):
Expand Down Expand Up @@ -256,6 +257,8 @@ class BERT4RecModel(TransformerModelBase[BERT4RecModelConfig]):
Type of data preparator used for dataset processing and dataloader creation.
lightning_module_type : type(TransformerLightningModuleBase), default `TransformerLightningModule`
Type of lightning module defining training procedure.
similarity_module_type : type(SimilarityModuleBase), default `DistanceSimilarityModule`
Type of similarity module.
get_val_mask_func : Callable, default ``None``
Function to get validation mask.
get_trainer_func : Callable, default ``None``
Expand Down Expand Up @@ -289,6 +292,9 @@ class BERT4RecModel(TransformerModelBase[BERT4RecModelConfig]):
lightning_module_kwargs: optional(dict), default ``None``
Additional keyword arguments to pass during `lightning_module_type` initialization.
Make sure all dict values have JSON serializable types.
similarity_module_kwargs: optional(dict), default ``None``
Additional keyword arguments to pass during `similarity_module_type` initialization.
Make sure all dict values have JSON serializable types.
"""

config_class = BERT4RecModelConfig
Expand Down Expand Up @@ -320,6 +326,7 @@ def __init__( # pylint: disable=too-many-arguments, too-many-locals
transformer_layers_type: tp.Type[TransformerLayersBase] = PreLNTransformerLayers,
data_preparator_type: tp.Type[TransformerDataPreparatorBase] = BERT4RecDataPreparator,
lightning_module_type: tp.Type[TransformerLightningModuleBase] = TransformerLightningModule,
similarity_module_type: tp.Type[SimilarityModuleBase] = DistanceSimilarityModule,
get_val_mask_func: tp.Optional[ValMaskCallable] = None,
get_trainer_func: tp.Optional[TrainerCallable] = None,
recommend_batch_size: int = 256,
Expand All @@ -332,6 +339,7 @@ def __init__( # pylint: disable=too-many-arguments, too-many-locals
item_net_constructor_kwargs: tp.Optional[InitKwargs] = None,
pos_encoding_kwargs: tp.Optional[InitKwargs] = None,
lightning_module_kwargs: tp.Optional[InitKwargs] = None,
similarity_module_kwargs: tp.Optional[InitKwargs] = None,
):
self.mask_prob = mask_prob

Expand Down Expand Up @@ -360,6 +368,7 @@ def __init__( # pylint: disable=too-many-arguments, too-many-locals
recommend_n_threads=recommend_n_threads,
recommend_use_torch_ranking=recommend_use_torch_ranking,
train_min_user_interactions=train_min_user_interactions,
similarity_module_type=similarity_module_type,
item_net_block_types=item_net_block_types,
item_net_constructor_type=item_net_constructor_type,
pos_encoding_type=pos_encoding_type,
Expand All @@ -372,6 +381,7 @@ def __init__( # pylint: disable=too-many-arguments, too-many-locals
item_net_constructor_kwargs=item_net_constructor_kwargs,
pos_encoding_kwargs=pos_encoding_kwargs,
lightning_module_kwargs=lightning_module_kwargs,
similarity_module_kwargs=similarity_module_kwargs,
)

def _init_data_preparator(self) -> None:
Expand Down
21 changes: 13 additions & 8 deletions rectools/models/nn/transformers/data_preparator.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,18 @@ def _process_features_for_id_map(
full_feature_values = np.vstack([extra_token_feature_values, sorted_features.values])
return DenseFeatures.from_iterables(values=full_feature_values, names=raw_features.names)

def _filter_train_interactions(self, train_interactions: pd.DataFrame) -> pd.DataFrame:
    """Filter train interactions.

    Drops users having fewer than ``train_min_user_interactions`` interactions,
    then keeps, per remaining user, only the latest interactions by datetime
    (at most ``session_max_len + train_session_max_len_addition`` rows;
    a stable sort preserves the original order of equal datetimes).
    """
    interaction_counts = train_interactions[Columns.User].value_counts()
    eligible_users = interaction_counts[interaction_counts >= self.train_min_user_interactions].index
    filtered = train_interactions[train_interactions[Columns.User].isin(eligible_users)]
    max_session_len = self.session_max_len + self.train_session_max_len_addition
    filtered = (
        filtered.sort_values(Columns.Datetime, kind="stable")
        .groupby(Columns.User, sort=False)
        .tail(max_session_len)
    )
    return filtered

def process_dataset_train(self, dataset: Dataset) -> None:
"""Process train dataset and save data."""
raw_interactions = dataset.get_raw_interactions()
Expand All @@ -179,14 +191,7 @@ def process_dataset_train(self, dataset: Dataset) -> None:
interactions = raw_interactions[~val_mask]

# Filter train interactions
user_stats = interactions[Columns.User].value_counts()
users = user_stats[user_stats >= self.train_min_user_interactions].index
interactions = interactions[(interactions[Columns.User].isin(users))]
interactions = (
interactions.sort_values(Columns.Datetime, kind="stable")
.groupby(Columns.User, sort=False)
.tail(self.session_max_len + self.train_session_max_len_addition)
)
interactions = self._filter_train_interactions(interactions)

# Prepare id maps
user_id_map = IdMap.from_values(interactions[Columns.User].values)
Expand Down
Loading