From 6535d2ee0bf4a8bb31d06d48f18e223d8c014a6d Mon Sep 17 00:00:00 2001 From: spirinamayya Date: Tue, 1 Apr 2025 11:49:42 +0300 Subject: [PATCH 1/3] add out_dim --- rectools/models/nn/item_net.py | 15 +++++++++++++++ tests/models/nn/test_item_net.py | 15 +++++++++++++++ 2 files changed, 30 insertions(+) diff --git a/rectools/models/nn/item_net.py b/rectools/models/nn/item_net.py index 0eb160fe..95cc6c1e 100644 --- a/rectools/models/nn/item_net.py +++ b/rectools/models/nn/item_net.py @@ -46,6 +46,11 @@ def get_all_embeddings(self) -> torch.Tensor: """Return item embeddings.""" raise NotImplementedError() + @property + def out_dim(self) -> int: + """Return item embedding output dimension.""" + raise NotImplementedError() + @property def device(self) -> torch.device: """Return ItemNet device.""" @@ -222,6 +227,11 @@ def from_dataset_schema( ) return None + @property + def out_dim(self) -> int: + """Return categorical item embedding output dimension.""" + return self.embedding_bag.embedding_dim + class IdEmbeddingsItemNet(ItemNetBase): """ @@ -317,6 +327,11 @@ def from_dataset_schema( n_items = dataset_schema.items.n_hot return cls(n_factors, n_items, dropout_rate) + @property + def out_dim(self) -> int: + """Return item embedding output dimension.""" + return self.ids_emb.embedding_dim + class ItemNetConstructorBase(ItemNetBase): """ diff --git a/tests/models/nn/test_item_net.py b/tests/models/nn/test_item_net.py index b0ae369b..2598d950 100644 --- a/tests/models/nn/test_item_net.py +++ b/tests/models/nn/test_item_net.py @@ -60,6 +60,12 @@ def test_embedding_shape_after_model_pass(self, n_items: int, n_factors: int) -> expected_item_ids = item_id_embeddings(items) assert expected_item_ids.shape == (n_items, n_factors) + @pytest.mark.parametrize("n_factors", ((2), (10))) + def test_out_dim(self, n_factors: int) -> None: + item_id_embeddings = IdEmbeddingsItemNet.from_dataset(DATASET, n_factors=n_factors, dropout_rate=0.5) + out_dim = item_id_embeddings.out_dim + assert out_dim == n_factors + @pytest.mark.filterwarnings("ignore::DeprecationWarning") class TestCatFeaturesItemNet: @@ -295,6 +301,15 @@ def test_warns_when_dataset_schema_categorical_features_are_none(self) -> None: """ ) + @pytest.mark.parametrize("n_factors", ((2), (10))) + def test_out_dim(self, dataset_item_features: Dataset, n_factors: int) -> None: + cat_item_embeddings = CatFeaturesItemNet.from_dataset( + dataset_item_features, n_factors=n_factors, dropout_rate=0.5 + ) + if cat_item_embeddings is not None: + out_dim = cat_item_embeddings.out_dim + assert out_dim == n_factors + @pytest.mark.filterwarnings("ignore::DeprecationWarning") class TestSumOfEmbeddingsConstructor: From c382251b8a68decb67cf97181f1cd356a0887bf8 Mon Sep 17 00:00:00 2001 From: spirinamayya Date: Tue, 1 Apr 2025 15:01:30 +0300 Subject: [PATCH 2/3] out_dim for sum of embeddings constructor --- CHANGELOG.md | 1 + rectools/models/nn/item_net.py | 10 ++++++++++ tests/models/nn/test_item_net.py | 23 ++++++++++++++++++++--- 3 files changed, 31 insertions(+), 3 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index bba9d340..06cf7177 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added - `SimilarityModuleBase`, `DistanceSimilarityModule`, similarity module to `TransformerTorchBackbone`, parameters to transformer-based models `similarity_module_type`, `similarity_module_kwargs` ([#272](https://github.com/MobileTeleSystems/RecTools/pull/272)) +- `out_dim` property to `IdEmbeddingsItemNet`, `CatFeaturesItemNet` and `SumOfEmbeddingsConstructor` ([#276](https://github.com/MobileTeleSystems/RecTools/pull/276)) ## [0.12.0] - 24.02.2025 diff --git a/rectools/models/nn/item_net.py b/rectools/models/nn/item_net.py index 95cc6c1e..bdfc5853 100644 --- a/rectools/models/nn/item_net.py +++ b/rectools/models/nn/item_net.py @@ -449,6 +449,11 @@ def forward(self, items: torch.Tensor) -> torch.Tensor: """ raise NotImplementedError() + @property + def out_dim(self) -> int: + """Return item net constructor output dimension.""" + raise NotImplementedError() + class SumOfEmbeddingsConstructor(ItemNetConstructorBase): """ @@ -482,3 +487,8 @@ def forward(self, items: torch.Tensor) -> torch.Tensor: item_emb = self.item_net_blocks[idx_block](items) item_embs.append(item_emb) return torch.sum(torch.stack(item_embs, dim=0), dim=0) + + @property + def out_dim(self) -> int: + """Return item net constructor output dimension.""" + return self.item_net_blocks[0].out_dim diff --git a/tests/models/nn/test_item_net.py b/tests/models/nn/test_item_net.py index 2598d950..115e1f3a 100644 --- a/tests/models/nn/test_item_net.py +++ b/tests/models/nn/test_item_net.py @@ -306,9 +306,9 @@ def test_out_dim(self, dataset_item_features: Dataset, n_factors: int) -> None: cat_item_embeddings = CatFeaturesItemNet.from_dataset( dataset_item_features, n_factors=n_factors, dropout_rate=0.5 ) - if cat_item_embeddings is not None: - out_dim = cat_item_embeddings.out_dim - assert out_dim == n_factors + assert isinstance(cat_item_embeddings, CatFeaturesItemNet) + out_dim = cat_item_embeddings.out_dim + assert out_dim == n_factors @pytest.mark.filterwarnings("ignore::DeprecationWarning") @@ -500,3 +500,20 @@ def test_raise_when_no_item_net_blocks( SumOfEmbeddingsConstructor.from_dataset( ds, n_factors=10, dropout_rate=0.5, item_net_block_types=item_net_block_types ) + + @pytest.mark.parametrize( + "item_net_block_types,n_factors", + ( + ((IdEmbeddingsItemNet,), 8), + ((IdEmbeddingsItemNet, CatFeaturesItemNet), 16), + ((CatFeaturesItemNet,), 16), + ), + ) + def test_out_dim( + self, dataset_item_features: Dataset, item_net_block_types: tp.Sequence[tp.Type[ItemNetBase]], n_factors: int + ) -> None: + item_net = SumOfEmbeddingsConstructor.from_dataset( + dataset_item_features, n_factors=n_factors, dropout_rate=0.5, item_net_block_types=item_net_block_types + ) + out_dim = item_net.out_dim + assert out_dim == n_factors From bb1c9e606d2024aa8b7255102ac37192a7db8129 Mon Sep 17 00:00:00 2001 From: spirinamayya Date: Tue, 1 Apr 2025 16:10:47 +0300 Subject: [PATCH 3/3] remove out_dim from base constructor --- rectools/models/nn/item_net.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/rectools/models/nn/item_net.py b/rectools/models/nn/item_net.py index bdfc5853..65c2f98f 100644 --- a/rectools/models/nn/item_net.py +++ b/rectools/models/nn/item_net.py @@ -449,11 +449,6 @@ def forward(self, items: torch.Tensor) -> torch.Tensor: """ raise NotImplementedError() - @property - def out_dim(self) -> int: - """Return item net constructor output dimension.""" - raise NotImplementedError() - class SumOfEmbeddingsConstructor(ItemNetConstructorBase): """