Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion torchrec/modules/mc_modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,11 @@ def _cat_jagged_values(jd: Dict[str, JaggedTensor]) -> torch.Tensor:
return torch.cat([jt.values() for jt in jd.values()])


@torch.fx.wrap
def _cat_jagged_lengths(jd: Dict[str, JaggedTensor]) -> torch.Tensor:
return torch.cat([jt.lengths() for jt in jd.values()])


# TODO: keep the old implementation for backward compatibility and will remove it later
@torch.fx.wrap
def _mcc_lazy_init(
Expand Down Expand Up @@ -416,10 +421,11 @@ def forward(

assert output is not None
values: torch.Tensor = _cat_jagged_values(output)
lengths: torch.Tensor = _cat_jagged_lengths(output)
return KeyedJaggedTensor(
keys=features.keys(),
values=values,
lengths=features.lengths(),
lengths=lengths,
weights=features.weights_or_none(),
)

Expand Down
81 changes: 80 additions & 1 deletion torchrec/modules/tests/test_mc_modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,16 +11,23 @@
from typing import Dict

import torch
from torchrec.fb.modules.hash_mc_evictions import (
HashZchEvictionConfig,
HashZchEvictionPolicyName,
)
from torchrec.fb.modules.hash_mc_modules import HashZchManagedCollisionModule
from torchrec.modules.embedding_configs import EmbeddingConfig
from torchrec.modules.mc_modules import (
average_threshold_filter,
DistanceLFU_EvictionPolicy,
dynamic_threshold_filter,
LFU_EvictionPolicy,
LRU_EvictionPolicy,
ManagedCollisionCollection,
MCHManagedCollisionModule,
probabilistic_threshold_filter,
)
from torchrec.sparse.jagged_tensor import JaggedTensor
from torchrec.sparse.jagged_tensor import JaggedTensor, KeyedJaggedTensor


class TestEvictionPolicy(unittest.TestCase):
Expand Down Expand Up @@ -427,3 +434,75 @@ def test_fx_jit_script_not_training(self) -> None:
model.train(False)
gm = torch.fx.symbolic_trace(model)
torch.jit.script(gm)

def test_mc_module_forward(self) -> None:
embedding_configs = [
EmbeddingConfig(
name="t1",
num_embeddings=100,
embedding_dim=8,
feature_names=["f1", "f2"],
),
EmbeddingConfig(
name="t2",
num_embeddings=100,
embedding_dim=8,
feature_names=["f3", "f4"],
),
]

mc_modules = {
"t1": HashZchManagedCollisionModule(
zch_size=100,
device=torch.device("cpu"),
total_num_buckets=1,
eviction_policy_name=HashZchEvictionPolicyName.SINGLE_TTL_EVICTION,
eviction_config=HashZchEvictionConfig(
features=[],
single_ttl=10,
),
),
"t2": HashZchManagedCollisionModule(
zch_size=100,
device=torch.device("cpu"),
total_num_buckets=1,
eviction_policy_name=HashZchEvictionPolicyName.SINGLE_TTL_EVICTION,
eviction_config=HashZchEvictionConfig(
features=[],
single_ttl=10,
),
),
}
for mc_module in mc_modules.values():
mc_module.reset_inference_mode()
mc_ebc = ManagedCollisionCollection(
# Pyre-ignore [6]: In call `ManagedCollisionCollection.__init__`, for argument `managed_collision_modules`, expected `Dict[str, ManagedCollisionModule]` but got `Dict[str, HashZchManagedCollisionModule]`
managed_collision_modules=mc_modules,
embedding_configs=embedding_configs,
)
kjt = KeyedJaggedTensor(
keys=["f1", "f2", "f3", "f4"],
values=torch.cat(
[
torch.arange(0, 20, 2, dtype=torch.int64, device="cpu"),
torch.arange(30, 60, 3, dtype=torch.int64, device="cpu"),
torch.arange(20, 30, 2, dtype=torch.int64, device="cpu"),
torch.arange(0, 20, 2, dtype=torch.int64, device="cpu"),
]
),
lengths=torch.cat(
[
torch.tensor([4, 6], dtype=torch.int64, device="cpu"),
torch.tensor([5, 5], dtype=torch.int64, device="cpu"),
torch.tensor([1, 4], dtype=torch.int64, device="cpu"),
torch.tensor([7, 3], dtype=torch.int64, device="cpu"),
]
),
)
res = mc_ebc.forward(kjt)
self.assertTrue(torch.equal(res.lengths(), kjt.lengths()))
self.assertTrue(
torch.equal(
res.lengths(), torch.tensor([4, 6, 5, 5, 1, 4, 7, 3], dtype=torch.int64)
)
)
Loading