From 2887cdf076ea1dcd1a9ea12b290d7b03071b965e Mon Sep 17 00:00:00 2001
From: spirinamayya
Date: Mon, 24 Mar 2025 23:40:39 +0300
Subject: [PATCH 1/6] sampled_softmax

---
 rectools/models/nn/transformers/lightning.py | 21 +++++++++++++++++----
 rectools/models/nn/transformers/sasrec.py    |  2 +-
 2 files changed, 18 insertions(+), 5 deletions(-)

diff --git a/rectools/models/nn/transformers/lightning.py b/rectools/models/nn/transformers/lightning.py
index 05e363fc..7897a45e 100644
--- a/rectools/models/nn/transformers/lightning.py
+++ b/rectools/models/nn/transformers/lightning.py
@@ -150,7 +150,7 @@ def training_step(self, batch: tp.Dict[str, torch.Tensor], batch_idx: int) -> torch.Tensor:
         x, y, w = batch["x"], batch["y"], batch["yw"]
         if self.loss == "softmax":
             logits = self._get_full_catalog_logits(x)
-            loss = self._calc_softmax_loss(logits, y, w)
+            loss = self._calc_softmax_loss(logits, y, w, ingnore_index=0)
         elif self.loss == "BCE":
             negatives = batch["negatives"]
             logits = self._get_pos_neg_logits(x, y, negatives)
@@ -159,6 +159,11 @@ def training_step(self, batch: tp.Dict[str, torch.Tensor], batch_idx: int) -> torch.Tensor:
             negatives = batch["negatives"]
             logits = self._get_pos_neg_logits(x, y, negatives)
             loss = self._calc_gbce_loss(logits, y, w, negatives)
+        elif self.loss == "sampled_softmax":
+            negatives = batch["negatives"]
+            logits = self._get_pos_neg_logits(x, y, negatives)
+            target = (y == 0).long()
+            loss = self._calc_softmax_loss(logits, target, w, ingnore_index=1)
         else:
             loss = self._calc_custom_loss(batch, batch_idx)
 
@@ -189,7 +194,7 @@ def validation_step(self, batch: tp.Dict[str, torch.Tensor], batch_idx: int) -> tp.Dict[str, torch.Tensor]:
         outputs = {}
         if self.loss == "softmax":
             logits = self._get_full_catalog_logits(x)[:, -1:, :]
-            outputs["loss"] = self._calc_softmax_loss(logits, y, w)
+            outputs["loss"] = self._calc_softmax_loss(logits, y, w, ingnore_index=0)
             outputs["logits"] = logits.squeeze()
         elif self.loss == "BCE":
             negatives = batch["negatives"]
@@ -201,6 +206,12 @@ def validation_step(self, batch: tp.Dict[str, torch.Tensor], batch_idx: int) -> tp.Dict[str, torch.Tensor]:
             pos_neg_logits = self._get_pos_neg_logits(x, y, negatives)[:, -1:, :]
             outputs["loss"] = self._calc_gbce_loss(pos_neg_logits, y, w, negatives)
             outputs["pos_neg_logits"] = pos_neg_logits.squeeze()
+        elif self.loss == "sampled_softmax":
+            negatives = batch["negatives"]
+            pos_neg_logits = self._get_pos_neg_logits(x, y, negatives)[:, -1:, :]
+            target = (y == 0).long()
+            outputs["loss"] = self._calc_softmax_loss(pos_neg_logits, target, w, ingnore_index=1)
+            outputs["pos_neg_logits"] = pos_neg_logits.squeeze()
         else:
             outputs = self._calc_custom_loss_outputs(batch, batch_idx)  # pragma: no cover
 
@@ -245,7 +256,9 @@ def _get_reduced_overconfidence_logits(self, logits: torch.Tensor, n_items: int,
         return logits
 
     @classmethod
-    def _calc_softmax_loss(cls, logits: torch.Tensor, y: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
+    def _calc_softmax_loss(
+        cls, logits: torch.Tensor, y: torch.Tensor, w: torch.Tensor, ingnore_index: int
+    ) -> torch.Tensor:
         # We are using CrossEntropyLoss with a multi-dimensional case
         # Logits must be passed in form of [batch_size, n_items + n_item_extra_tokens, session_max_len],
@@ -257,7 +270,7 @@ def _calc_softmax_loss(cls, logits: torch.Tensor, y: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
         # Loss output will have a shape of [batch_size, session_max_len]
         # and will have zeros for every `0` target label
         loss = torch.nn.functional.cross_entropy(
-            logits.transpose(1, 2), y, ignore_index=0, reduction="none"
+            logits.transpose(1, 2), y, ignore_index=ingnore_index, reduction="none"
         )  # [batch_size, session_max_len]
         loss = loss * w
         n = (loss > 0).to(loss.dtype)
diff --git a/rectools/models/nn/transformers/sasrec.py b/rectools/models/nn/transformers/sasrec.py
index 343c9c7c..4c806f1b 100644
--- a/rectools/models/nn/transformers/sasrec.py
+++ b/rectools/models/nn/transformers/sasrec.py
@@ -288,7 +288,7 @@ class SASRecModel(TransformerModelBase[SASRecModelConfig]):
     train_min_user_interactions : int, default 2
         Minimum number of interactions user should have to be used for training.
         Should be greater than 1.
-    loss : {"softmax", "BCE", "gBCE"}, default "softmax"
+    loss : {"softmax", "BCE", "gBCE", "sampled_softmax}, default "softmax"
         Loss function.
     n_negatives : int, default 1
         Number of negatives for BCE and gBCE losses.

From aeb7cb7c279adab584b44fc5b656f6ea62d094b7 Mon Sep 17 00:00:00 2001
From: spirinamayya
Date: Mon, 24 Mar 2025 23:50:29 +0300
Subject: [PATCH 2/6] ssm tests

---
 tests/models/nn/transformers/test_base.py   |  2 +-
 tests/models/nn/transformers/test_sasrec.py | 10 ++++++++++
 2 files changed, 11 insertions(+), 1 deletion(-)

diff --git a/tests/models/nn/transformers/test_base.py b/tests/models/nn/transformers/test_base.py
index bc61bea5..2c247a83 100644
--- a/tests/models/nn/transformers/test_base.py
+++ b/tests/models/nn/transformers/test_base.py
@@ -273,7 +273,7 @@ def test_raises_when_load_weights_from_checkpoint_not_fitted_model(
             (True, ["epoch", "step", "train_loss", "val_loss"]),
         ),
     )
-    @pytest.mark.parametrize("loss", ("softmax", "BCE", "gBCE"))
+    @pytest.mark.parametrize("loss", ("softmax", "BCE", "gBCE", "sampled_softmax"))
     def test_log_metrics(
         self,
         model_cls: tp.Type[TransformerModelBase],
diff --git a/tests/models/nn/transformers/test_sasrec.py b/tests/models/nn/transformers/test_sasrec.py
index 58442de3..ce4eb4ac 100644
--- a/tests/models/nn/transformers/test_sasrec.py
+++ b/tests/models/nn/transformers/test_sasrec.py
@@ -319,6 +319,16 @@ def get_trainer() -> Trainer:
                 }
             ),
         ),
+        (
+            "sampled_softmax",
+            pd.DataFrame(
+                {
+                    Columns.User: [10, 10, 30, 30, 30, 40, 40, 40],
+                    Columns.Item: [17, 15, 13, 17, 14, 13, 14, 15],
+                    Columns.Rank: [1, 2, 1, 2, 3, 1, 2, 3],
+                }
+            ),
+        ),
     ),
 )
 def test_u2i_losses(
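The mechanism behind patches 1-2: sampled softmax scores the true next item against a handful of sampled negatives instead of the full catalog, so cross-entropy runs over only `1 + n_negatives` classes per position, while padding positions are masked through an ignore index (the misspelled `ingnore_index` parameter above, which patch 3 later removes). A minimal standalone sketch of that idea in plain PyTorch — tensor names, shapes, and values here are illustrative assumptions, not RecTools internals:

    import torch
    import torch.nn.functional as F

    # Illustrative shapes only (not RecTools internals):
    # pos_logits [batch, length, 1]           - score of the true next item
    # neg_logits [batch, length, n_negatives] - scores of sampled negatives
    batch, length, n_negatives = 2, 3, 4
    pos_logits = torch.randn(batch, length, 1)
    neg_logits = torch.randn(batch, length, n_negatives)
    y = torch.tensor([[0, 15, 13], [14, 17, 12]])  # target item ids, 0 = padding

    # The positive goes to class index 0 and negatives follow, mirroring
    # the [positive, negatives] layout produced by _get_pos_neg_logits.
    logits = torch.cat([pos_logits, neg_logits], dim=-1)  # [batch, length, 1 + n_negatives]

    # Multi-dimensional cross_entropy expects [batch, classes, length].
    # The correct class is always 0; padding positions get the ignored label.
    target = torch.zeros_like(y)
    target[y == 0] = -100  # torch's default ignore_index
    loss = F.cross_entropy(logits.transpose(1, 2), target, reduction="none")
    loss = loss.sum() / (loss > 0).sum()  # average over non-padding positions
    print(loss)

Because the positive class always sits at a fixed index, no per-position class bookkeeping is needed; this is what lets the existing `_calc_softmax_loss` be reused with nothing more than a different ignore index.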
From c106e16bb8163528db3b30bcc35bf20d13c3fe5e Mon Sep 17 00:00:00 2001
From: spirinamayya
Date: Wed, 26 Mar 2025 18:39:52 +0300
Subject: [PATCH 3/6] removed ignore_index

---
 rectools/models/nn/transformers/bert4rec.py   |  2 +-
 rectools/models/nn/transformers/lightning.py  | 20 ++++++++++----------
 rectools/models/nn/transformers/sasrec.py     |  2 +-
 tests/models/nn/transformers/test_bert4rec.py | 10 ++++++++++
 4 files changed, 22 insertions(+), 12 deletions(-)

diff --git a/rectools/models/nn/transformers/bert4rec.py b/rectools/models/nn/transformers/bert4rec.py
index 71675ebd..b9e09062 100644
--- a/rectools/models/nn/transformers/bert4rec.py
+++ b/rectools/models/nn/transformers/bert4rec.py
@@ -208,7 +208,7 @@ class BERT4RecModel(TransformerModelBase[BERT4RecModelConfig]):
     train_min_user_interactions : int, default 2
         Minimum number of interactions user should have to be used for training.
         Should be greater than 1.
-    loss : {"softmax", "BCE", "gBCE"}, default "softmax"
+    loss : {"softmax", "BCE", "gBCE", "sampled_softmax"}, default "softmax"
         Loss function.
     n_negatives : int, default 1
         Number of negatives for BCE and gBCE losses.
diff --git a/rectools/models/nn/transformers/lightning.py b/rectools/models/nn/transformers/lightning.py
index 7897a45e..617d13ce 100644
--- a/rectools/models/nn/transformers/lightning.py
+++ b/rectools/models/nn/transformers/lightning.py
@@ -150,7 +150,7 @@ def training_step(self, batch: tp.Dict[str, torch.Tensor], batch_idx: int) -> torch.Tensor:
         x, y, w = batch["x"], batch["y"], batch["yw"]
         if self.loss == "softmax":
             logits = self._get_full_catalog_logits(x)
-            loss = self._calc_softmax_loss(logits, y, w, ingnore_index=0)
+            loss = self._calc_softmax_loss(logits, y, w)
         elif self.loss == "BCE":
             negatives = batch["negatives"]
             logits = self._get_pos_neg_logits(x, y, negatives)
@@ -162,8 +162,9 @@ def training_step(self, batch: tp.Dict[str, torch.Tensor], batch_idx: int) -> torch.Tensor:
         elif self.loss == "sampled_softmax":
             negatives = batch["negatives"]
             logits = self._get_pos_neg_logits(x, y, negatives)
-            target = (y == 0).long()
-            loss = self._calc_softmax_loss(logits, target, w, ingnore_index=1)
+            logits[:, :, [0, 1]] = logits[:, :, [1, 0]]
+            target = (y != 0).long()
+            loss = self._calc_softmax_loss(logits, target, w)
         else:
             loss = self._calc_custom_loss(batch, batch_idx)
 
@@ -194,7 +195,7 @@ def validation_step(self, batch: tp.Dict[str, torch.Tensor], batch_idx: int) -> tp.Dict[str, torch.Tensor]:
         outputs = {}
         if self.loss == "softmax":
             logits = self._get_full_catalog_logits(x)[:, -1:, :]
-            outputs["loss"] = self._calc_softmax_loss(logits, y, w, ingnore_index=0)
+            outputs["loss"] = self._calc_softmax_loss(logits, y, w)
             outputs["logits"] = logits.squeeze()
         elif self.loss == "BCE":
             negatives = batch["negatives"]
@@ -209,8 +210,9 @@ def validation_step(self, batch: tp.Dict[str, torch.Tensor], batch_idx: int) -> tp.Dict[str, torch.Tensor]:
         elif self.loss == "sampled_softmax":
             negatives = batch["negatives"]
             pos_neg_logits = self._get_pos_neg_logits(x, y, negatives)[:, -1:, :]
-            target = (y == 0).long()
-            outputs["loss"] = self._calc_softmax_loss(pos_neg_logits, target, w, ingnore_index=1)
+            pos_neg_logits[:, :, [0, 1]] = pos_neg_logits[:, :, [1, 0]]
+            target = (y != 0).long()
+            outputs["loss"] = self._calc_softmax_loss(pos_neg_logits, target, w)
             outputs["pos_neg_logits"] = pos_neg_logits.squeeze()
         else:
             outputs = self._calc_custom_loss_outputs(batch, batch_idx)  # pragma: no cover
@@ -256,9 +258,7 @@ def _get_reduced_overconfidence_logits(self, logits: torch.Tensor, n_items: int,
         return logits
 
     @classmethod
-    def _calc_softmax_loss(
-        cls, logits: torch.Tensor, y: torch.Tensor, w: torch.Tensor, ingnore_index: int
-    ) -> torch.Tensor:
+    def _calc_softmax_loss(cls, logits: torch.Tensor, y: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
         # We are using CrossEntropyLoss with a multi-dimensional case
         # Logits must be passed in form of [batch_size, n_items + n_item_extra_tokens, session_max_len],
@@ -270,7 +270,7 @@ def _calc_softmax_loss(
         # Loss output will have a shape of [batch_size, session_max_len]
         # and will have zeros for every `0` target label
         loss = torch.nn.functional.cross_entropy(
-            logits.transpose(1, 2), y, ignore_index=ingnore_index, reduction="none"
+            logits.transpose(1, 2), y, ignore_index=0, reduction="none"
         )  # [batch_size, session_max_len]
         loss = loss * w
         n = (loss > 0).to(loss.dtype)
diff --git a/rectools/models/nn/transformers/sasrec.py b/rectools/models/nn/transformers/sasrec.py
index 4c806f1b..8460eff8 100644
--- a/rectools/models/nn/transformers/sasrec.py
+++ b/rectools/models/nn/transformers/sasrec.py
@@ -288,7 +288,7 @@ class SASRecModel(TransformerModelBase[SASRecModelConfig]):
     train_min_user_interactions : int, default 2
         Minimum number of interactions user should have to be used for training.
         Should be greater than 1.
-    loss : {"softmax", "BCE", "gBCE", "sampled_softmax}, default "softmax"
+    loss : {"softmax", "BCE", "gBCE", "sampled_softmax"}, default "softmax"
         Loss function.
     n_negatives : int, default 1
         Number of negatives for BCE and gBCE losses.
diff --git a/tests/models/nn/transformers/test_bert4rec.py b/tests/models/nn/transformers/test_bert4rec.py
index 62a73d83..6bbbbf4c 100644
--- a/tests/models/nn/transformers/test_bert4rec.py
+++ b/tests/models/nn/transformers/test_bert4rec.py
@@ -290,6 +290,16 @@ def get_trainer() -> Trainer:
                 }
             ),
         ),
+        (
+            "sampled_softmax",
+            pd.DataFrame(
+                {
+                    Columns.User: [30, 40, 40],
+                    Columns.Item: [12, 12, 13],
+                    Columns.Rank: [1, 1, 2],
+                }
+            ),
+        ),
     ),
 )
 def test_u2i_losses(

From a164610ee7bee5352ad999f5c3fc57259c1555e8 Mon Sep 17 00:00:00 2001
From: spirinamayya
Date: Sun, 30 Mar 2025 11:03:41 +0300
Subject: [PATCH 4/6] changelog

---
 CHANGELOG.md                                 | 4 ++++
 rectools/models/nn/transformers/lightning.py | 1 +
 2 files changed, 5 insertions(+)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 97483034..17d952fa 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -5,6 +5,10 @@ All notable changes to this project will be documented in this file.
 The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
 and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
 
+## Unreleased
+
+### Added
+- `sampled_softmax` loss [#274](https://github.com/MobileTeleSystems/RecTools/pull/274)
 
 ## [0.12.0] - 24.02.2025
 
diff --git a/rectools/models/nn/transformers/lightning.py b/rectools/models/nn/transformers/lightning.py
index 617d13ce..4d4e5e61 100644
--- a/rectools/models/nn/transformers/lightning.py
+++ b/rectools/models/nn/transformers/lightning.py
@@ -162,6 +162,7 @@ def training_step(self, batch: tp.Dict[str, torch.Tensor], batch_idx: int) -> torch.Tensor:
         elif self.loss == "sampled_softmax":
             negatives = batch["negatives"]
             logits = self._get_pos_neg_logits(x, y, negatives)
+            # We put positive logits at index 1 since index 0 is used to ignore padding
             logits[:, :, [0, 1]] = logits[:, :, [1, 0]]
             target = (y != 0).long()
             loss = self._calc_softmax_loss(logits, target, w)

From db83c5669b85cf133fae215c6f54041a5cdc7570 Mon Sep 17 00:00:00 2001
From: spirinamayya
Date: Wed, 2 Apr 2025 15:02:57 +0300
Subject: [PATCH 5/6] fix tests

---
 tests/models/nn/transformers/test_sasrec.py | 14 +++++++++++++-
 1 file changed, 13 insertions(+), 1 deletion(-)

diff --git a/tests/models/nn/transformers/test_sasrec.py b/tests/models/nn/transformers/test_sasrec.py
index 24b760bd..385e3197 100644
--- a/tests/models/nn/transformers/test_sasrec.py
+++ b/tests/models/nn/transformers/test_sasrec.py
@@ -326,6 +326,17 @@ def get_trainer() -> Trainer:
             ),
             "dot",
         ),
+        (
+            "sampled_softmax",
+            pd.DataFrame(
+                {
+                    Columns.User: [10, 10, 30, 30, 30, 40, 40, 40],
+                    Columns.Item: [17, 15, 13, 17, 14, 13, 14, 15],
+                    Columns.Rank: [1, 2, 1, 2, 3, 1, 2, 3],
+                }
+            ),
+            "dot",
+        ),
         (
             "BCE",
             pd.DataFrame(
@@ -353,10 +364,11 @@ def get_trainer() -> Trainer:
             pd.DataFrame(
                 {
                     Columns.User: [10, 10, 30, 30, 30, 40, 40, 40],
-                    Columns.Item: [17, 15, 13, 17, 14, 13, 14, 15],
+                    Columns.Item: [17, 15, 13, 14, 17, 13, 14, 15],
                     Columns.Rank: [1, 2, 1, 2, 3, 1, 2, 3],
                 }
             ),
+            "cosine",
         ),
     ),
 )
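The index swap that patch 3 introduces (and patch 4 documents with a comment) can be checked in isolation: `_calc_softmax_loss` hard-codes `ignore_index=0`, so the positive logit is moved from class index 0 to index 1, targets become 1 exactly at non-padding positions, and padding targets (0) contribute nothing to the loss. A toy verification with made-up numbers — standalone PyTorch, not the patched module:

    import torch
    import torch.nn.functional as F

    # [batch=1, len=2, classes=3]: class 0 holds the positive logit,
    # classes 1-2 hold negatives, as _get_pos_neg_logits lays them out.
    logits = torch.tensor([[[2.0, 0.5, -1.0],   # real target at this position
                            [1.5, 0.0, 0.3]]])  # this position is padding
    y = torch.tensor([[13, 0]])  # target item ids; 0 means padding

    logits[:, :, [0, 1]] = logits[:, :, [1, 0]]  # move the positive to index 1
    target = (y != 0).long()                     # 1 for real tokens, 0 for padding

    loss = F.cross_entropy(logits.transpose(1, 2), target, ignore_index=0, reduction="none")
    print(loss)  # non-zero at the first position, exactly 0 at the padding position

The in-place fancy-indexed assignment is safe here because the right-hand side `logits[:, :, [1, 0]]` is materialized as a copy before the write.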
From 62cb4299968767e05c821d9ce1c6c11f5bf6ea91 Mon Sep 17 00:00:00 2001
From: spirinamayya
Date: Wed, 2 Apr 2025 15:13:57 +0300
Subject: [PATCH 6/6] changelog

---
 CHANGELOG.md | 4 +---
 1 file changed, 1 insertion(+), 3 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 378f3032..4815ac75 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -7,12 +7,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
 
 ## Unreleased
 
-### Added
-- `sampled_softmax` loss [#274](https://github.com/MobileTeleSystems/RecTools/pull/274)
-
 ### Added
 - `SimilarityModuleBase`, `DistanceSimilarityModule`, similarity module to `TransformerTorchBackbone`, parameters to transformer-based models `similarity_module_type`, `similarity_module_kwargs` ([#272](https://github.com/MobileTeleSystems/RecTools/pull/272))
 - `out_dim` property to `IdEmbeddingsItemNet`, `CatFeaturesItemNet` and `SumOfEmbeddingsConstructor` ([#276](https://github.com/MobileTeleSystems/RecTools/pull/276))
+- `sampled_softmax` loss option for transformer models ([#274](https://github.com/MobileTeleSystems/RecTools/pull/274))
 
 ## [0.12.0] - 24.02.2025
 
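For reference, once the series is merged the new loss is selected with a single constructor argument. A usage sketch — the parameter values are arbitrary and `interactions_df` is a placeholder for a user-supplied interactions table, so treat this as an assumption-laden example rather than canonical documentation:

    from rectools.dataset import Dataset
    from rectools.models import SASRecModel

    # "sampled_softmax" joins "softmax", "BCE" and "gBCE";
    # n_negatives controls how many negatives are sampled per position.
    model = SASRecModel(loss="sampled_softmax", n_negatives=256)

    # dataset = Dataset.construct(interactions_df)
    # model.fit(dataset)
    # recos = model.recommend(users=[10, 30], dataset=dataset, k=10, filter_viewed=True)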