diff --git a/torchrec/modules/mc_modules.py b/torchrec/modules/mc_modules.py
index d67b4ff4a..896221cd7 100644
--- a/torchrec/modules/mc_modules.py
+++ b/torchrec/modules/mc_modules.py
@@ -54,6 +54,11 @@ def _cat_jagged_values(jd: Dict[str, JaggedTensor]) -> torch.Tensor:
     return torch.cat([jt.values() for jt in jd.values()])
 
 
+@torch.fx.wrap
+def _cat_jagged_lengths(jd: Dict[str, JaggedTensor]) -> torch.Tensor:
+    return torch.cat([jt.lengths() for jt in jd.values()])
+
+
 # TODO: keep the old implementation for backward compatibility and will remove it later
 @torch.fx.wrap
 def _mcc_lazy_init(
@@ -416,10 +421,11 @@ def forward(
 
         assert output is not None
         values: torch.Tensor = _cat_jagged_values(output)
+        lengths: torch.Tensor = _cat_jagged_lengths(output)
         return KeyedJaggedTensor(
             keys=features.keys(),
             values=values,
-            lengths=features.lengths(),
+            lengths=lengths,
             weights=features.weights_or_none(),
         )
 
diff --git a/torchrec/modules/tests/test_mc_modules.py b/torchrec/modules/tests/test_mc_modules.py
index 8fac2ac25..8bddd7907 100644
--- a/torchrec/modules/tests/test_mc_modules.py
+++ b/torchrec/modules/tests/test_mc_modules.py
@@ -11,16 +11,23 @@
 from typing import Dict
 
 import torch
+from torchrec.fb.modules.hash_mc_evictions import (
+    HashZchEvictionConfig,
+    HashZchEvictionPolicyName,
+)
+from torchrec.fb.modules.hash_mc_modules import HashZchManagedCollisionModule
+from torchrec.modules.embedding_configs import EmbeddingConfig
 from torchrec.modules.mc_modules import (
     average_threshold_filter,
     DistanceLFU_EvictionPolicy,
     dynamic_threshold_filter,
     LFU_EvictionPolicy,
     LRU_EvictionPolicy,
+    ManagedCollisionCollection,
     MCHManagedCollisionModule,
     probabilistic_threshold_filter,
 )
-from torchrec.sparse.jagged_tensor import JaggedTensor
+from torchrec.sparse.jagged_tensor import JaggedTensor, KeyedJaggedTensor
 
 
 class TestEvictionPolicy(unittest.TestCase):
@@ -427,3 +434,75 @@ def test_fx_jit_script_not_training(self) -> None:
         model.train(False)
         gm = torch.fx.symbolic_trace(model)
         torch.jit.script(gm)
+
+    def test_mc_module_forward(self) -> None:
+        embedding_configs = [
+            EmbeddingConfig(
+                name="t1",
+                num_embeddings=100,
+                embedding_dim=8,
+                feature_names=["f1", "f2"],
+            ),
+            EmbeddingConfig(
+                name="t2",
+                num_embeddings=100,
+                embedding_dim=8,
+                feature_names=["f3", "f4"],
+            ),
+        ]
+
+        mc_modules = {
+            "t1": HashZchManagedCollisionModule(
+                zch_size=100,
+                device=torch.device("cpu"),
+                total_num_buckets=1,
+                eviction_policy_name=HashZchEvictionPolicyName.SINGLE_TTL_EVICTION,
+                eviction_config=HashZchEvictionConfig(
+                    features=[],
+                    single_ttl=10,
+                ),
+            ),
+            "t2": HashZchManagedCollisionModule(
+                zch_size=100,
+                device=torch.device("cpu"),
+                total_num_buckets=1,
+                eviction_policy_name=HashZchEvictionPolicyName.SINGLE_TTL_EVICTION,
+                eviction_config=HashZchEvictionConfig(
+                    features=[],
+                    single_ttl=10,
+                ),
+            ),
+        }
+        for mc_module in mc_modules.values():
+            mc_module.reset_inference_mode()
+        mc_ebc = ManagedCollisionCollection(
+            # Pyre-ignore [6]: In call `ManagedCollisionCollection.__init__`, for argument `managed_collision_modules`, expected `Dict[str, ManagedCollisionModule]` but got `Dict[str, HashZchManagedCollisionModule]`
+            managed_collision_modules=mc_modules,
+            embedding_configs=embedding_configs,
+        )
+        kjt = KeyedJaggedTensor(
+            keys=["f1", "f2", "f3", "f4"],
+            values=torch.cat(
+                [
+                    torch.arange(0, 20, 2, dtype=torch.int64, device="cpu"),
+                    torch.arange(30, 60, 3, dtype=torch.int64, device="cpu"),
+                    torch.arange(20, 30, 2, dtype=torch.int64, device="cpu"),
+                    torch.arange(0, 20, 2, dtype=torch.int64, device="cpu"),
+                ]
+            ),
+            lengths=torch.cat(
+                [
+                    torch.tensor([4, 6], dtype=torch.int64, device="cpu"),
+                    torch.tensor([5, 5], dtype=torch.int64, device="cpu"),
+                    torch.tensor([1, 4], dtype=torch.int64, device="cpu"),
+                    torch.tensor([7, 3], dtype=torch.int64, device="cpu"),
+                ]
+            ),
+        )
+        res = mc_ebc.forward(kjt)
+        self.assertTrue(torch.equal(res.lengths(), kjt.lengths()))
+        self.assertTrue(
+            torch.equal(
+                res.lengths(), torch.tensor([4, 6, 5, 5, 1, 4, 7, 3], dtype=torch.int64)
+            )
+        )
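
Note (not part of the patch): a minimal sketch of how the new _cat_jagged_lengths helper pairs with the existing _cat_jagged_values inside ManagedCollisionCollection.forward. Both helpers iterate the same dict of per-feature JaggedTensor outputs, so the values and lengths of the returned KeyedJaggedTensor are concatenated in the same order and stay aligned with each other, instead of the lengths being taken from the input features. The feature names and tensors below are made up for illustration only.

    import torch
    from torchrec.sparse.jagged_tensor import JaggedTensor

    # Hypothetical per-feature output of the managed collision modules.
    output = {
        "f1": JaggedTensor(values=torch.tensor([7, 3, 9]), lengths=torch.tensor([2, 1])),
        "f2": JaggedTensor(values=torch.tensor([5, 1]), lengths=torch.tensor([0, 2])),
    }

    # Mirrors _cat_jagged_values / _cat_jagged_lengths: both concatenate in the
    # same dict iteration order, so the per-feature groups described by lengths
    # line up with the concatenated values.
    values = torch.cat([jt.values() for jt in output.values()])    # tensor([7, 3, 9, 5, 1])
    lengths = torch.cat([jt.lengths() for jt in output.values()])  # tensor([2, 1, 0, 2])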