diff --git a/CHANGELOG.md b/CHANGELOG.md index 06cf7177..4815ac75 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,10 +5,12 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). +## Unreleased ### Added - `SimilarityModuleBase`, `DistanceSimilarityModule`, similarity module to `TransformerTorchBackbone`, parameters to transformer-based models `similarity_module_type`, `similarity_module_kwargs` ([#272](https://github.com/MobileTeleSystems/RecTools/pull/272)) - `out_dim` property to `IdEmbeddingsItemNet`, `CatFeaturesItemNet` and `SumOfEmbeddingsConstructor` ([#276](https://github.com/MobileTeleSystems/RecTools/pull/276)) +- `sampled_softmax` loss option for transformer models ([#274](https://github.com/MobileTeleSystems/RecTools/pull/274)) ## [0.12.0] - 24.02.2025 diff --git a/rectools/models/nn/transformers/bert4rec.py b/rectools/models/nn/transformers/bert4rec.py index c2f2d814..44eb9cff 100644 --- a/rectools/models/nn/transformers/bert4rec.py +++ b/rectools/models/nn/transformers/bert4rec.py @@ -209,7 +209,7 @@ class BERT4RecModel(TransformerModelBase[BERT4RecModelConfig]): train_min_user_interactions : int, default 2 Minimum number of interactions user should have to be used for training. Should be greater than 1. - loss : {"softmax", "BCE", "gBCE"}, default "softmax" + loss : {"softmax", "BCE", "gBCE", "sampled_softmax"}, default "softmax" Loss function. n_negatives : int, default 1 - Number of negatives for BCE and gBCE losses. + Number of negatives for BCE, gBCE and sampled_softmax losses. 
diff --git a/rectools/models/nn/transformers/lightning.py b/rectools/models/nn/transformers/lightning.py
index 5bfb660d..d99f3556 100644
--- a/rectools/models/nn/transformers/lightning.py
+++ b/rectools/models/nn/transformers/lightning.py
@@ -102,7 +102,7 @@ def requires_negatives(loss: str) -> tp.Optional[bool]:
 
         if loss == "softmax":
             return False
-        if loss in ["BCE", "gBCE"]:
+        if loss in ["BCE", "gBCE", "sampled_softmax"]:
             return True
         return None
 
@@ -120,6 +120,9 @@ def get_loss_calculator(
         if self.loss == "gBCE":
             return self._calc_gbce_loss
 
+        if self.loss == "sampled_softmax":
+            return self._calc_sampled_softmax_loss
+
         return None
 
     @classmethod
@@ -185,6 +188,33 @@ def _calc_gbce_loss(self, logits: torch.Tensor, y: torch.Tensor, w: torch.Tensor
         loss = self._calc_bce_loss(logits, y, w)
         return loss
 
+    def _calc_sampled_softmax_loss(self, logits: torch.Tensor, y: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
+        """
+        Calculate sampled softmax loss by reusing ``_calc_softmax_loss`` with the positive moved to class 1.
+
+        Parameters
+        ----------
+        logits : torch.Tensor
+            Candidate scores with candidates on the last axis. The swap below implies the positive
+            comes in at index 0 of that axis; the rest are presumably sampled negatives (TODO confirm).
+        y : torch.Tensor
+            Targets; zeros are treated as padding and mapped to the ignored class 0.
+        w : torch.Tensor
+            Weights forwarded unchanged to ``_calc_softmax_loss``.
+
+        Returns
+        -------
+        torch.Tensor
+            Whatever ``_calc_softmax_loss`` returns for the rearranged logits and binary targets.
+        """
+        # We put positive logits at index 1 since index 0 is used to ignore padding.
+        # A new tensor is built rather than assigning into `logits`: an in-place write would leak the swap
+        # to the caller and can break backward when the op that produced `logits` saved its output.
+        logits = torch.cat((logits[:, :, [1, 0]], logits[:, :, 2:]), dim=-1)
+        target = (y != 0).long()
+        loss = self._calc_softmax_loss(logits, target, w)
+        return loss
+
     def configure_optimizers(self) -> torch.optim.Adam:
         """Choose what optimizers and learning-rate schedulers to use in optimization"""
         optimizer = torch.optim.Adam(self.torch_model.parameters(), lr=self.lr, betas=self.adam_betas)
diff --git a/rectools/models/nn/transformers/sasrec.py b/rectools/models/nn/transformers/sasrec.py
index 4bc36907..b24f9079 100644
--- a/rectools/models/nn/transformers/sasrec.py
+++ b/rectools/models/nn/transformers/sasrec.py
@@ -289,7 +289,7 @@ class SASRecModel(TransformerModelBase[SASRecModelConfig]):
     train_min_user_interactions : int, default 2
         Minimum number of interactions user should have to be used for training. Should be greater
         than 1.
-    loss : {"softmax", "BCE", "gBCE"}, default "softmax"
+    loss : {"softmax", "BCE", "gBCE", "sampled_softmax"}, default "softmax"
         Loss function.
n_negatives : int, default 1 - Number of negatives for BCE and gBCE losses. + Number of negatives for BCE, gBCE and sampled_softmax losses. diff --git a/tests/models/nn/transformers/test_base.py b/tests/models/nn/transformers/test_base.py index a29d8f24..3b3b5110 100644 --- a/tests/models/nn/transformers/test_base.py +++ b/tests/models/nn/transformers/test_base.py @@ -261,7 +261,7 @@ def test_raises_when_load_weights_from_checkpoint_not_fitted_model( (True, ["epoch", "step", "train_loss", "val_loss"]), ), ) - @pytest.mark.parametrize("loss", ("softmax", "BCE", "gBCE")) + @pytest.mark.parametrize("loss", ("softmax", "BCE", "gBCE", "sampled_softmax")) def test_log_metrics( self, model_cls: tp.Type[TransformerModelBase], diff --git a/tests/models/nn/transformers/test_bert4rec.py b/tests/models/nn/transformers/test_bert4rec.py index a8121c3d..1c15c453 100644 --- a/tests/models/nn/transformers/test_bert4rec.py +++ b/tests/models/nn/transformers/test_bert4rec.py @@ -295,6 +295,16 @@ def get_trainer() -> Trainer: } ), ), + ( + "sampled_softmax", + pd.DataFrame( + { + Columns.User: [30, 40, 40], + Columns.Item: [12, 12, 13], + Columns.Rank: [1, 1, 2], + } + ), + ), ), ) @pytest.mark.parametrize("u2i_dist", ("dot", "cosine")) diff --git a/tests/models/nn/transformers/test_sasrec.py b/tests/models/nn/transformers/test_sasrec.py index 189d4e53..385e3197 100644 --- a/tests/models/nn/transformers/test_sasrec.py +++ b/tests/models/nn/transformers/test_sasrec.py @@ -326,6 +326,17 @@ def get_trainer() -> Trainer: ), "dot", ), + ( + "sampled_softmax", + pd.DataFrame( + { + Columns.User: [10, 10, 30, 30, 30, 40, 40, 40], + Columns.Item: [17, 15, 13, 17, 14, 13, 14, 15], + Columns.Rank: [1, 2, 1, 2, 3, 1, 2, 3], + } + ), + "dot", + ), ( "BCE", pd.DataFrame( @@ -348,6 +359,17 @@ def get_trainer() -> Trainer: ), "cosine", ), + ( + "sampled_softmax", + pd.DataFrame( + { + Columns.User: [10, 10, 30, 30, 30, 40, 40, 40], + Columns.Item: [17, 15, 13, 14, 17, 13, 14, 15], + Columns.Rank: [1, 2, 1, 2, 3, 1, 2, 3], + } + ), + "cosine", + ), ), ) def 
test_u2i_losses(