diff --git a/CHANGELOG.md b/CHANGELOG.md index bba9d340..06cf7177 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added - `SimilarityModuleBase`, `DistanceSimilarityModule`, similarity module to `TransformerTorchBackbone`, parameters to transformer-based models `similarity_module_type`, `similarity_module_kwargs` ([#272](https://github.com/MobileTeleSystems/RecTools/pull/272)) +- `out_dim` property to `IdEmbeddingsItemNet`, `CatFeaturesItemNet` and `SumOfEmbeddingsConstructor` ([#276](https://github.com/MobileTeleSystems/RecTools/pull/276)) ## [0.12.0] - 24.02.2025 diff --git a/rectools/models/nn/item_net.py b/rectools/models/nn/item_net.py index 0eb160fe..65c2f98f 100644 --- a/rectools/models/nn/item_net.py +++ b/rectools/models/nn/item_net.py @@ -46,6 +46,11 @@ def get_all_embeddings(self) -> torch.Tensor: """Return item embeddings.""" raise NotImplementedError() + @property + def out_dim(self) -> int: + """Return item embedding output dimension.""" + raise NotImplementedError() + @property def device(self) -> torch.device: """Return ItemNet device.""" @@ -222,6 +227,11 @@ def from_dataset_schema( ) return None + @property + def out_dim(self) -> int: + """Return categorical item embedding output dimension.""" + return self.embedding_bag.embedding_dim + class IdEmbeddingsItemNet(ItemNetBase): """ @@ -317,6 +327,11 @@ def from_dataset_schema( n_items = dataset_schema.items.n_hot return cls(n_factors, n_items, dropout_rate) + @property + def out_dim(self) -> int: + """Return item embedding output dimension.""" + return self.ids_emb.embedding_dim + class ItemNetConstructorBase(ItemNetBase): """ @@ -467,3 +482,8 @@ def forward(self, items: torch.Tensor) -> torch.Tensor: item_emb = self.item_net_blocks[idx_block](items) item_embs.append(item_emb) return torch.sum(torch.stack(item_embs, dim=0), dim=0) + + @property + def out_dim(self) -> int: + """Return item net constructor output dimension.""" + return self.item_net_blocks[0].out_dim diff --git a/tests/models/nn/test_item_net.py b/tests/models/nn/test_item_net.py index b0ae369b..115e1f3a 100644 --- a/tests/models/nn/test_item_net.py +++ b/tests/models/nn/test_item_net.py @@ -60,6 +60,12 @@ def test_embedding_shape_after_model_pass(self, n_items: int, n_factors: int) -> expected_item_ids = item_id_embeddings(items) assert expected_item_ids.shape == (n_items, n_factors) + @pytest.mark.parametrize("n_factors", ((2), (10))) + def test_out_dim(self, n_factors: int) -> None: + item_id_embeddings = IdEmbeddingsItemNet.from_dataset(DATASET, n_factors=n_factors, dropout_rate=0.5) + out_dim = item_id_embeddings.out_dim + assert out_dim == n_factors + @pytest.mark.filterwarnings("ignore::DeprecationWarning") class TestCatFeaturesItemNet: @@ -295,6 +301,15 @@ def test_warns_when_dataset_schema_categorical_features_are_none(self) -> None: """ ) + @pytest.mark.parametrize("n_factors", ((2), (10))) + def test_out_dim(self, dataset_item_features: Dataset, n_factors: int) -> None: + cat_item_embeddings = CatFeaturesItemNet.from_dataset( + dataset_item_features, n_factors=n_factors, dropout_rate=0.5 + ) + assert isinstance(cat_item_embeddings, CatFeaturesItemNet) + out_dim = cat_item_embeddings.out_dim + assert out_dim == n_factors + @pytest.mark.filterwarnings("ignore::DeprecationWarning") class TestSumOfEmbeddingsConstructor: @@ -485,3 +500,20 @@ def test_raise_when_no_item_net_blocks( SumOfEmbeddingsConstructor.from_dataset( ds, n_factors=10, dropout_rate=0.5, item_net_block_types=item_net_block_types ) + + @pytest.mark.parametrize( + "item_net_block_types,n_factors", + ( + ((IdEmbeddingsItemNet,), 8), + ((IdEmbeddingsItemNet, CatFeaturesItemNet), 16), + ((CatFeaturesItemNet,), 16), + ), + ) + def test_out_dim( + self, dataset_item_features: Dataset, item_net_block_types: tp.Sequence[tp.Type[ItemNetBase]], n_factors: int + ) -> None: + item_net = SumOfEmbeddingsConstructor.from_dataset( + dataset_item_features, n_factors=n_factors, dropout_rate=0.5, item_net_block_types=item_net_block_types + ) + out_dim = item_net.out_dim + assert out_dim == n_factors