Skip to content

Commit c382251

Browse files
committed
out_dim for sum of embeddings constructor
1 parent 8c1a299 commit c382251

File tree

3 files changed

+31
-3
lines changed

3 files changed

+31
-3
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
88

99
### Added
1010
- `SimilarityModuleBase`, `DistanceSimilarityModule`, similarity module to `TransformerTorchBackbone`, parameters to transformer-based models `similarity_module_type`, `similarity_module_kwargs` ([#272](https://github.com/MobileTeleSystems/RecTools/pull/272))
11+
- `out_dim` property to `IdEmbeddingsItemNet`, `CatFeaturesItemNet` and `SumOfEmbeddingsConstructor` ([#276](https://github.com/MobileTeleSystems/RecTools/pull/276))
1112

1213
## [0.12.0] - 24.02.2025
1314

rectools/models/nn/item_net.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -449,6 +449,11 @@ def forward(self, items: torch.Tensor) -> torch.Tensor:
449449
"""
450450
raise NotImplementedError()
451451

452+
@property
453+
def out_dim(self) -> int:
454+
"""Return item net constructor output dimension."""
455+
raise NotImplementedError()
456+
452457

453458
class SumOfEmbeddingsConstructor(ItemNetConstructorBase):
454459
"""
@@ -482,3 +487,8 @@ def forward(self, items: torch.Tensor) -> torch.Tensor:
482487
item_emb = self.item_net_blocks[idx_block](items)
483488
item_embs.append(item_emb)
484489
return torch.sum(torch.stack(item_embs, dim=0), dim=0)
490+
491+
@property
492+
def out_dim(self) -> int:
493+
"""Return item net constructor output dimension."""
494+
return self.item_net_blocks[0].out_dim

tests/models/nn/test_item_net.py

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -306,9 +306,9 @@ def test_out_dim(self, dataset_item_features: Dataset, n_factors: int) -> None:
306306
cat_item_embeddings = CatFeaturesItemNet.from_dataset(
307307
dataset_item_features, n_factors=n_factors, dropout_rate=0.5
308308
)
309-
if cat_item_embeddings is not None:
310-
out_dim = cat_item_embeddings.out_dim
311-
assert out_dim == n_factors
309+
assert isinstance(cat_item_embeddings, CatFeaturesItemNet)
310+
out_dim = cat_item_embeddings.out_dim
311+
assert out_dim == n_factors
312312

313313

314314
@pytest.mark.filterwarnings("ignore::DeprecationWarning")
@@ -500,3 +500,20 @@ def test_raise_when_no_item_net_blocks(
500500
SumOfEmbeddingsConstructor.from_dataset(
501501
ds, n_factors=10, dropout_rate=0.5, item_net_block_types=item_net_block_types
502502
)
503+
504+
@pytest.mark.parametrize(
505+
"item_net_block_types,n_factors",
506+
(
507+
((IdEmbeddingsItemNet,), 8),
508+
((IdEmbeddingsItemNet, CatFeaturesItemNet), 16),
509+
((CatFeaturesItemNet,), 16),
510+
),
511+
)
512+
def test_out_dim(
513+
self, dataset_item_features: Dataset, item_net_block_types: tp.Sequence[tp.Type[ItemNetBase]], n_factors: int
514+
) -> None:
515+
item_net = SumOfEmbeddingsConstructor.from_dataset(
516+
dataset_item_features, n_factors=n_factors, dropout_rate=0.5, item_net_block_types=item_net_block_types
517+
)
518+
out_dim = item_net.out_dim
519+
assert out_dim == n_factors

0 commit comments

Comments
 (0)