From 666b50d5b5f88bea10948bb53c3a0118141e7810 Mon Sep 17 00:00:00 2001 From: Eddy Li Date: Mon, 7 Jul 2025 14:39:18 -0700 Subject: [PATCH] torchrec support on kvzch ebc emb lookup module (#3106) Summary: # Change logs 1. add ZeroCollisionKeyValueEmbeddingBag emb lookup 2. address existing unit test missing for ebc ssd offloading 3. add new ut for kv zch embedding bag module Embedding bag update is same as embedding update. Details can be found in D73567631 Reviewed By: duduyi2013, kausv Differential Revision: D76636927 --- .../distributed/batched_embedding_kernel.py | 385 +++++++++++++-- torchrec/distributed/embedding_lookup.py | 25 +- torchrec/distributed/embeddingbag.py | 121 ++++- ...test_model_parallel_nccl_ssd_single_gpu.py | 439 +++++++++++++++++- 4 files changed, 914 insertions(+), 56 deletions(-) diff --git a/torchrec/distributed/batched_embedding_kernel.py b/torchrec/distributed/batched_embedding_kernel.py index cdf3e26fa..55e83e00e 100644 --- a/torchrec/distributed/batched_embedding_kernel.py +++ b/torchrec/distributed/batched_embedding_kernel.py @@ -65,6 +65,7 @@ compute_kernel_to_embedding_location, DTensorMetadata, GroupedEmbeddingConfig, + ShardedEmbeddingTable, ) from torchrec.distributed.shards_wrapper import LocalShardsWrapper from torchrec.distributed.types import ( @@ -216,6 +217,42 @@ def _populate_zero_collision_tbe_params( ) +def _get_sharded_local_buckets_for_zero_collision( + embedding_tables: List[ShardedEmbeddingTable], + pg: Optional[dist.ProcessGroup] = None, +) -> List[Tuple[int, int, int]]: + """ + utils to get bucket offset start, bucket offset end, bucket size based on embedding sharding spec + """ + sharded_local_buckets: List[Tuple[int, int, int]] = [] + world_size = dist.get_world_size(pg) + local_rank = dist.get_rank(pg) + + for table in embedding_tables: + total_num_buckets = none_throws(table.total_num_buckets) + assert ( + total_num_buckets % world_size == 0 + ), f"total_num_buckets={total_num_buckets} must be divisible by world_size={world_size}" + assert ( + table.total_num_buckets + and table.num_embeddings % table.total_num_buckets == 0 + ), f"Table size '{table.num_embeddings}' must be divisible by num_buckets '{table.total_num_buckets}'" + bucket_offset_start = total_num_buckets // world_size * local_rank + bucket_offset_end = min( + total_num_buckets, total_num_buckets // world_size * (local_rank + 1) + ) + bucket_size = ( + table.num_embeddings + total_num_buckets - 1 + ) // total_num_buckets + sharded_local_buckets.append( + (bucket_offset_start, bucket_offset_end, bucket_size) + ) + logger.info( + f"bucket_offset: {bucket_offset_start}:{bucket_offset_end}, bucket_size: {bucket_size} for table {table.name}" + ) + return sharded_local_buckets + + class KeyValueEmbeddingFusedOptimizer(FusedOptimizer): def __init__( self, @@ -1076,6 +1113,11 @@ def __init__( assert ( len({table.embedding_dim for table in config.embedding_tables}) == 1 ), "Currently we expect all tables in SSD TBE to have the same embedding dimension." + for table in config.embedding_tables: + assert table.local_cols % 4 == 0, ( + f"table {table.name} has local_cols={table.local_cols} " + "not divisible by 4. " + ) ssd_tbe_params = _populate_ssd_tbe_params(config) compute_kernel = config.embedding_tables[0].compute_kernel @@ -1263,9 +1305,18 @@ def __init__( assert ( len({table.embedding_dim for table in config.embedding_tables}) == 1 ), "Currently we expect all tables in SSD TBE to have the same embedding dimension." + for table in config.embedding_tables: + assert table.local_cols % 4 == 0, ( + f"table {table.name} has local_cols={table.local_cols} " + "not divisible by 4. " + ) ssd_tbe_params = _populate_ssd_tbe_params(config) - self._bucket_spec: List[Tuple[int, int, int]] = self.get_sharded_local_buckets() + self._bucket_spec: List[Tuple[int, int, int]] = ( + _get_sharded_local_buckets_for_zero_collision( + self._config.embedding_tables, self._pg + ) + ) _populate_zero_collision_tbe_params(ssd_tbe_params, self._bucket_spec) compute_kernel = config.embedding_tables[0].compute_kernel embedding_location = compute_kernel_to_embedding_location(compute_kernel) @@ -1334,38 +1385,6 @@ def fused_optimizer(self) -> FusedOptimizer: """ return self._optim - def get_sharded_local_buckets(self) -> List[Tuple[int, int, int]]: - """ - utils to get bucket offset start, bucket offset end, bucket size based on embedding sharding spec - """ - sharded_local_buckets: List[Tuple[int, int, int]] = [] - world_size = dist.get_world_size(self._pg) - local_rank = dist.get_rank(self._pg) - - for table in self._config.embedding_tables: - total_num_buckets = none_throws(table.total_num_buckets) - assert ( - total_num_buckets % world_size == 0 - ), f"total_num_buckets={total_num_buckets} must be divisible by world_size={world_size}" - assert ( - table.total_num_buckets - and table.num_embeddings % table.total_num_buckets == 0 - ), f"Table size '{table.num_embeddings}' must be divisible by num_buckets '{table.total_num_buckets}'" - bucket_offset_start = total_num_buckets // world_size * local_rank - bucket_offset_end = min( - total_num_buckets, total_num_buckets // world_size * (local_rank + 1) - ) - bucket_size = ( - table.num_embeddings + total_num_buckets - 1 - ) // total_num_buckets - sharded_local_buckets.append( - (bucket_offset_start, bucket_offset_end, bucket_size) - ) - logger.info( - f"bucket_offset: {bucket_offset_start}:{bucket_offset_end}, bucket_size: {bucket_size} for table {table.name}" - ) - return sharded_local_buckets - def state_dict( self, destination: Optional[Dict[str, Any]] = None, @@ -1553,10 +1572,7 @@ def forward(self, features: KeyedJaggedTensor) -> torch.Tensor: self._split_weights_res = None self._optim.set_sharded_embedding_weight_ids(sharded_embedding_weight_ids=None) - return self.emb_module( - indices=features.values().long(), - offsets=features.offsets().long(), - ) + return super().forward(features) class BatchedFusedEmbedding(BaseBatchedEmbedding[torch.Tensor], FusedOptimizerModule): @@ -1901,6 +1917,11 @@ def __init__( assert ( len({table.embedding_dim for table in config.embedding_tables}) == 1 ), "Currently we expect all tables in SSD TBE to have the same embedding dimension." + for table in config.embedding_tables: + assert table.local_cols % 4 == 0, ( + f"table {table.name} has local_cols={table.local_cols} " + "not divisible by 4. " + ) ssd_tbe_params = _populate_ssd_tbe_params(config) compute_kernel = config.embedding_tables[0].compute_kernel @@ -2070,6 +2091,296 @@ def split_embedding_weights(self, no_snapshot: bool = True) -> Tuple[ return self.emb_module.split_embedding_weights(no_snapshot) +class ZeroCollisionKeyValueEmbeddingBag( + BaseBatchedEmbeddingBag[torch.Tensor], FusedOptimizerModule +): + def __init__( + self, + config: GroupedEmbeddingConfig, + pg: Optional[dist.ProcessGroup] = None, + device: Optional[torch.device] = None, + sharding_type: Optional[ShardingType] = None, + backend_type: BackendType = BackendType.SSD, + ) -> None: + super().__init__(config, pg, device, sharding_type) + + assert ( + len(config.embedding_tables) > 0 + ), "Expected to see at least one table in SSD TBE, but found 0." + assert ( + len({table.embedding_dim for table in config.embedding_tables}) == 1 + ), "Currently we expect all tables in SSD TBE to have the same embedding dimension." + + for table in config.embedding_tables: + assert table.local_cols % 4 == 0, ( + f"table {table.name} has local_cols={table.local_cols} " + "not divisible by 4. " + ) + + ssd_tbe_params = _populate_ssd_tbe_params(config) + self._bucket_spec: List[Tuple[int, int, int]] = ( + _get_sharded_local_buckets_for_zero_collision( + self._config.embedding_tables, self._pg + ) + ) + _populate_zero_collision_tbe_params(ssd_tbe_params, self._bucket_spec) + compute_kernel = config.embedding_tables[0].compute_kernel + embedding_location = compute_kernel_to_embedding_location(compute_kernel) + + # every split_embeding_weights call is expensive, since it iterates over all the elements in the backend kv db + # use split weights result cache so that multiple calls in the same train iteration will only trigger once + self._split_weights_res: Optional[ + Tuple[ + List[ShardedTensor], + List[ShardedTensor], + List[ShardedTensor], + ] + ] = None + + self._emb_module: SSDTableBatchedEmbeddingBags = SSDTableBatchedEmbeddingBags( + embedding_specs=list(zip(self._num_embeddings, self._local_cols)), + feature_table_map=self._feature_table_map, + ssd_cache_location=embedding_location, + pooling_mode=self._pooling, + backend_type=backend_type, + **ssd_tbe_params, + ).to(device) + + logger.info( + f"tbe_unique_id:{self._emb_module.tbe_unique_id} => table name to count dict:{self.table_name_to_count}" + ) + self._table_name_to_weight_count_per_rank: Dict[str, List[int]] = {} + self._init_sharded_split_embedding_weights() # this will populate self._split_weights_res + self._optim: ZeroCollisionKeyValueEmbeddingFusedOptimizer = ( + ZeroCollisionKeyValueEmbeddingFusedOptimizer( + config, + self._emb_module, + # pyre-ignore[16] + sharded_embedding_weights_by_table=self._split_weights_res[0], + table_name_to_weight_count_per_rank=self._table_name_to_weight_count_per_rank, + sharded_embedding_weight_ids=self._split_weights_res[1], + pg=pg, + ) + ) + self._param_per_table: Dict[str, nn.Parameter] = dict( + _gen_named_parameters_by_table_ssd_pmt( + emb_module=self._emb_module, + table_name_to_count=self.table_name_to_count.copy(), + config=self._config, + pg=pg, + ) + ) + self.init_parameters() + + def init_parameters(self) -> None: + """ + An advantage of KV TBE is that we don't need to init weights. Hence skipping. + """ + pass + + @property + def emb_module( + self, + ) -> SSDTableBatchedEmbeddingBags: + return self._emb_module + + @property + def fused_optimizer(self) -> FusedOptimizer: + """ + SSD Embedding fuses backward with backward. + """ + return self._optim + + def state_dict( + self, + destination: Optional[Dict[str, Any]] = None, + prefix: str = "", + keep_vars: bool = False, + no_snapshot: bool = True, + ) -> Dict[str, Any]: + """ + Args: + no_snapshot (bool): the tensors in the returned dict are + PartiallyMaterializedTensors. this argument controls wether the + PartiallyMaterializedTensor owns a RocksDB snapshot handle. True means the + PartiallyMaterializedTensor doesn't have a RocksDB snapshot handle. False means the + PartiallyMaterializedTensor has a RocksDB snapshot handle + """ + # in the case no_snapshot=False, a flush is required. we rely on the flush operation in + # ShardedEmbeddingBagCollection._pre_state_dict_hook() + + emb_tables, _, _ = self.split_embedding_weights(no_snapshot=no_snapshot) + emb_table_config_copy = copy.deepcopy(self._config.embedding_tables) + for emb_table in emb_table_config_copy: + emb_table.local_metadata.placement._device = torch.device("cpu") + ret = get_state_dict( + emb_table_config_copy, + emb_tables, + self._pg, + destination, + prefix, + ) + return ret + + def named_parameters( + self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True + ) -> Iterator[Tuple[str, nn.Parameter]]: + """ + Only allowed ways to get state_dict. + """ + for name, tensor in self.named_split_embedding_weights( + prefix, recurse, remove_duplicate + ): + # hack before we support optimizer on sharded parameter level + # can delete after PEA deprecation + # pyre-ignore [6] + param = nn.Parameter(tensor) + # pyre-ignore + param._in_backward_optimizers = [EmptyFusedOptimizer()] + yield name, param + + # pyre-ignore [15] + def named_split_embedding_weights( + self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True + ) -> Iterator[Tuple[str, Union[PartiallyMaterializedTensor, torch.Tensor]]]: + assert ( + remove_duplicate + ), "remove_duplicate=False not supported in BaseBatchedEmbedding.named_split_embedding_weights" + for config, tensor in zip( + self._config.embedding_tables, + self.split_embedding_weights()[0], + ): + key = append_prefix(prefix, f"{config.name}.weight") + yield key, tensor + + # initialize sharded _split_weights_res if it's None + # this method is used to generate sharded embedding weights once for all following state_dict + # calls in checkpointing and publishing. + # When training is resumed, the cached value will be reset to None and the value needs to be + # rebuilt for next checkpointing and publishing, as the weight id, weight embedding will be updated + # during training in backend k/v store. + def _init_sharded_split_embedding_weights( + self, prefix: str = "", force_regenerate: bool = False + ) -> None: + if not force_regenerate and self._split_weights_res is not None: + return + + pmt_list, weight_ids_list, bucket_cnt_list = self.split_embedding_weights( + no_snapshot=False, + ) + emb_table_config_copy = copy.deepcopy(self._config.embedding_tables) + for emb_table in emb_table_config_copy: + none_throws( + none_throws( + emb_table.local_metadata, + f"local_metadata is None for emb_table: {emb_table.name}", + ).placement, + f"placement is None for local_metadata of emb table: {emb_table.name}", + )._device = torch.device("cpu") + + pmt_sharded_t_list = create_virtual_sharded_tensors( + emb_table_config_copy, + pmt_list, + self._pg, + prefix, + self._table_name_to_weight_count_per_rank, + ) + weight_id_sharded_t_list = create_virtual_sharded_tensors( + emb_table_config_copy, + weight_ids_list, # pyre-ignore [6] + self._pg, + prefix, + self._table_name_to_weight_count_per_rank, + ) + bucket_cnt_sharded_t_list = create_virtual_sharded_tensors( + emb_table_config_copy, + bucket_cnt_list, # pyre-ignore [6] + self._pg, + prefix, + self._table_name_to_weight_count_per_rank, + use_param_size_as_rows=True, + ) + # pyre-ignore + assert len(pmt_list) == len(weight_ids_list) == len(bucket_cnt_list) + assert ( + len(pmt_sharded_t_list) + == len(weight_id_sharded_t_list) + == len(bucket_cnt_sharded_t_list) + ) + self._split_weights_res = ( + pmt_sharded_t_list, + weight_id_sharded_t_list, + bucket_cnt_sharded_t_list, + ) + + def get_named_split_embedding_weights_snapshot(self, prefix: str = "") -> Iterator[ + Tuple[ + str, + Union[ShardedTensor, PartiallyMaterializedTensor], + Optional[ShardedTensor], + Optional[ShardedTensor], + ] + ]: + """ + Return an iterator over embedding tables, for each table yielding + table name, + PMT for embedding table with a valid RocksDB snapshot to support tensor IO + optional ShardedTensor for weight_id + optional ShardedTensor for bucket_cnt + """ + self._init_sharded_split_embedding_weights() + # pyre-ignore[16] + self._optim.set_sharded_embedding_weight_ids(self._split_weights_res[1]) + + pmt_sharded_t_list = self._split_weights_res[0] + weight_id_sharded_t_list = self._split_weights_res[1] + bucket_cnt_sharded_t_list = self._split_weights_res[2] + for table_idx, pmt_sharded_t in enumerate(pmt_sharded_t_list): + table_config = self._config.embedding_tables[table_idx] + key = append_prefix(prefix, f"{table_config.name}") + + yield key, pmt_sharded_t, weight_id_sharded_t_list[ + table_idx + ], bucket_cnt_sharded_t_list[table_idx] + + def flush(self) -> None: + """ + Flush the embeddings in cache back to SSD. Should be pretty expensive. + """ + self.emb_module.flush() + + def purge(self) -> None: + """ + Reset the cache space. This is needed when we load state dict. + """ + # TODO: move the following to SSD TBE. + self.emb_module.lxu_cache_weights.zero_() + self.emb_module.lxu_cache_state.fill_(-1) + + def create_rocksdb_hard_link_snapshot(self) -> None: + """ + Create a RocksDB checkpoint. This is needed before we call state_dict() for publish. + """ + self.emb_module.create_rocksdb_hard_link_snapshot() + + # pyre-ignore [15] + def split_embedding_weights( + self, no_snapshot: bool = True, should_flush: bool = False + ) -> Tuple[ + Union[List[PartiallyMaterializedTensor], List[torch.Tensor]], + Optional[List[torch.Tensor]], + Optional[List[torch.Tensor]], + ]: + return self.emb_module.split_embedding_weights(no_snapshot, should_flush) + + def forward(self, features: KeyedJaggedTensor) -> torch.Tensor: + # reset split weights during training + self._split_weights_res = None + self._optim.set_sharded_embedding_weight_ids(sharded_embedding_weight_ids=None) + + return super().forward(features) + + class BatchedFusedEmbeddingBag( BaseBatchedEmbeddingBag[torch.Tensor], FusedOptimizerModule ): diff --git a/torchrec/distributed/embedding_lookup.py b/torchrec/distributed/embedding_lookup.py index b4e5948f4..6a7a723e6 100644 --- a/torchrec/distributed/embedding_lookup.py +++ b/torchrec/distributed/embedding_lookup.py @@ -40,6 +40,7 @@ KeyValueEmbedding, KeyValueEmbeddingBag, ZeroCollisionKeyValueEmbedding, + ZeroCollisionKeyValueEmbeddingBag, ) from torchrec.distributed.comm_ops import get_gradient_division from torchrec.distributed.composable.table_batched_embedding_slice import ( @@ -511,6 +512,24 @@ def _create_embedding_kernel( device=device, sharding_type=sharding_type, ) + elif config.compute_kernel == EmbeddingComputeKernel.SSD_VIRTUAL_TABLE: + # for ssd kv + return ZeroCollisionKeyValueEmbeddingBag( + config=config, + pg=pg, + device=device, + sharding_type=sharding_type, + backend_type=BackendType.SSD, + ) + elif config.compute_kernel == EmbeddingComputeKernel.DRAM_VIRTUAL_TABLE: + # for dram kv + return ZeroCollisionKeyValueEmbeddingBag( + config=config, + pg=pg, + device=device, + sharding_type=sharding_type, + backend_type=BackendType.DRAM, + ) else: raise ValueError(f"Compute kernel not supported {config.compute_kernel}") @@ -524,6 +543,8 @@ def _need_prefetch(config: GroupedEmbeddingConfig) -> bool: if ( table.compute_kernel == EmbeddingComputeKernel.FUSED_UVM_CACHING or table.compute_kernel == EmbeddingComputeKernel.KEY_VALUE + or table.compute_kernel == EmbeddingComputeKernel.SSD_VIRTUAL_TABLE + or table.compute_kernel == EmbeddingComputeKernel.DRAM_VIRTUAL_TABLE ): return True return False @@ -719,7 +740,9 @@ def get_named_split_embedding_weights_snapshot( RocksDB snapshot to support windowed access. """ for emb_module in self._emb_modules: - if isinstance(emb_module, KeyValueEmbeddingBag): + if isinstance(emb_module, KeyValueEmbeddingBag) or isinstance( + emb_module, ZeroCollisionKeyValueEmbeddingBag + ): yield from emb_module.get_named_split_embedding_weights_snapshot() def flush(self) -> None: diff --git a/torchrec/distributed/embeddingbag.py b/torchrec/distributed/embeddingbag.py index 2beaf3aef..7eac305e2 100644 --- a/torchrec/distributed/embeddingbag.py +++ b/torchrec/distributed/embeddingbag.py @@ -35,6 +35,7 @@ from torch.nn.modules.module import _IncompatibleKeys from torch.nn.parallel import DistributedDataParallel from torchrec.distributed.comm import get_local_size +from torchrec.distributed.embedding_lookup import PartiallyMaterializedTensor from torchrec.distributed.embedding_sharding import ( EmbeddingSharding, EmbeddingShardingContext, @@ -292,6 +293,8 @@ def create_sharding_infos_by_sharding_device_group( getattr(config, "num_embeddings_post_pruning", None) # TODO: Need to check if attribute exists for BC ), + total_num_buckets=config.total_num_buckets, + use_virtual_table=config.use_virtual_table, ), param_sharding=parameter_sharding, param=param, @@ -558,6 +561,7 @@ def __init__( tbe_module.fused_optimizer.params = params optims.append(("", tbe_module.fused_optimizer)) self._optim: CombinedOptimizer = CombinedOptimizer(optims) + self._skip_missing_weight_key: List[str] = [] for i, (sharding, lookup) in enumerate( zip(self._embedding_shardings, self._lookups) @@ -687,6 +691,8 @@ def create_grouped_sharding_infos( getattr(config, "num_embeddings_post_pruning", None) # TODO: Need to check if attribute exists for BC ), + total_num_buckets=config.total_num_buckets, + use_virtual_table=config.use_virtual_table, ), param_sharding=parameter_sharding, param=param, @@ -788,6 +794,27 @@ def _pre_load_state_dict_hook( to transform from ShardedTensors/DTensors into tensors """ for table_name in self._model_parallel_name_to_local_shards.keys(): + if self._table_name_to_config[table_name].use_virtual_table: + # weight_id and bucket are generated at the runtime of state_dict instead of registered class + # so we need to erase them before passing into load_state_dict + weight_key = f"{prefix}embedding_bags.{table_name}.weight" + weight_id_key = f"{prefix}embedding_bags.{table_name}.weight_id" + bucket_key = f"{prefix}embedding_bags.{table_name}.bucket" + if weight_id_key in state_dict: + del state_dict[weight_id_key] + if bucket_key in state_dict: + del state_dict[bucket_key] + assert weight_key in state_dict + assert ( + len(self._model_parallel_name_to_local_shards[table_name]) == 1 + ), "currently only support 1 shard per rank" + + # for loading state_dict into virtual table, we skip the weights assignment + # if needed, for now this should be handled separately outside of load_state_dict call + self._skip_missing_weight_key.append(weight_key) + del state_dict[weight_key] + continue + key = f"{prefix}embedding_bags.{table_name}.weight" # gather model shards from both DTensor and ShardedTensor maps model_shards_sharded_tensor = self._model_parallel_name_to_local_shards[ @@ -947,6 +974,8 @@ def _initialize_torch_state(self, skip_registering: bool = False) -> None: # no shards_wrapper["global_stride"] = v.stride() shards_wrapper["placements"] = v.placements elif isinstance(v, ShardedTensor): + # for virtual table, we only populate the shardedTensor for Embedding Table during + # initial state_dict calls, skip weight id and bucket tensor self._model_parallel_name_to_local_shards[table_name].extend( v.local_shards() ) @@ -956,6 +985,10 @@ def _initialize_torch_state(self, skip_registering: bool = False) -> None: # no # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute # `named_parameters_by_table`. ) in lookup.named_parameters_by_table(): + # for virtual table, currently we don't expose id tensor and bucket tensor + # because they are not updated in real time, and they are created on the fly + # whenever state_dict is called + # reference: ƒbgs _gen_named_parameters_by_table_ssd_pmt self.embedding_bags[table_name].register_parameter("weight", tbe_slice) for table_name in self._model_parallel_name_to_local_shards.keys(): @@ -1061,7 +1094,9 @@ def extract_sharded_kvtensors( sharded_t, ) in module._model_parallel_name_to_sharded_tensor.items(): if _model_parallel_name_to_compute_kernel[table_name] in { - EmbeddingComputeKernel.KEY_VALUE.value + EmbeddingComputeKernel.KEY_VALUE.value, + EmbeddingComputeKernel.SSD_VIRTUAL_TABLE.value, + EmbeddingComputeKernel.DRAM_VIRTUAL_TABLE.value, }: ret[table_name] = sharded_t return ret @@ -1093,24 +1128,88 @@ def post_state_dict_hook( return sharded_kvtensors_copy = copy.deepcopy(sharded_kvtensors) + virtual_table_sharded_t_map: Optional[ + Dict[str, Tuple[ShardedTensor, ShardedTensor]] + ] = None for lookup, sharding in zip(module._lookups, module._embedding_shardings): if not isinstance(sharding, DpPooledEmbeddingSharding): for ( - key, - v, - _, - _, + table_name, + weights_t, + weight_ids_sharded_t, + id_cnt_per_bucket_sharded_t, ) in ( lookup.get_named_split_embedding_weights_snapshot() # pyre-ignore ): - assert key in sharded_kvtensors_copy - sharded_kvtensors_copy[key].local_shards()[0].tensor = v + assert table_name in sharded_kvtensors_copy + if self._table_name_to_config[table_name].use_virtual_table: + assert isinstance(weights_t, ShardedTensor) + if virtual_table_sharded_t_map is None: + virtual_table_sharded_t_map = {} + assert ( + weight_ids_sharded_t is not None + and id_cnt_per_bucket_sharded_t is not None + ) + # The logic here assumes there is only one shard per table on any particular rank + # if there are cases each rank has >1 shards, we need to update here accordingly + sharded_kvtensors_copy[table_name] = weights_t + virtual_table_sharded_t_map[table_name] = ( + weight_ids_sharded_t, + id_cnt_per_bucket_sharded_t, + ) + else: + assert isinstance(weights_t, PartiallyMaterializedTensor) + assert ( + weight_ids_sharded_t is None + and id_cnt_per_bucket_sharded_t is None + ) + # The logic here assumes there is only one shard per table on any particular rank + # if there are cases each rank has >1 shards, we need to update here accordingly + # pyre-ignore + sharded_kvtensors_copy[table_name].local_shards()[ + 0 + ].tensor = weights_t + + def update_destination( + table_name: str, + tensor_name: str, + destination: Dict[str, torch.Tensor], + value: torch.Tensor, + ) -> None: + destination_key = f"{prefix}embedding_bags.{table_name}.{tensor_name}" + destination[destination_key] = value + for ( table_name, sharded_kvtensor, ) in sharded_kvtensors_copy.items(): - destination_key = f"{prefix}embedding_bags.{table_name}.weight" - destination[destination_key] = sharded_kvtensor + update_destination(table_name, "weight", destination, sharded_kvtensor) + if ( + virtual_table_sharded_t_map + and table_name in virtual_table_sharded_t_map + ): + update_destination( + table_name, + "weight_id", + destination, + virtual_table_sharded_t_map[table_name][0], + ) + update_destination( + table_name, + "bucket", + destination, + virtual_table_sharded_t_map[table_name][1], + ) + + def _post_load_state_dict_hook( + module: "ShardedEmbeddingBagCollection", + incompatible_keys: _IncompatibleKeys, + ) -> None: + if incompatible_keys.missing_keys: + # has to remove the key inplace + for skip_key in module._skip_missing_weight_key: + if skip_key in incompatible_keys.missing_keys: + incompatible_keys.missing_keys.remove(skip_key) if not skip_registering: self.register_state_dict_pre_hook(self._pre_state_dict_hook) @@ -1118,6 +1217,8 @@ def post_state_dict_hook( self._register_load_state_dict_pre_hook( self._pre_load_state_dict_hook, with_module=True ) + self.register_load_state_dict_post_hook(_post_load_state_dict_hook) + self.reset_parameters() def reset_parameters(self) -> None: @@ -1128,6 +1229,8 @@ def reset_parameters(self) -> None: for table_config in self._embedding_bag_configs: if self.module_sharding_plan[table_config.name].compute_kernel in { EmbeddingComputeKernel.KEY_VALUE.value, + EmbeddingComputeKernel.SSD_VIRTUAL_TABLE.value, + EmbeddingComputeKernel.DRAM_VIRTUAL_TABLE.value, }: continue assert table_config.init_fn is not None diff --git a/torchrec/distributed/tests/test_model_parallel_nccl_ssd_single_gpu.py b/torchrec/distributed/tests/test_model_parallel_nccl_ssd_single_gpu.py index 6b9ee360b..60f29d2f6 100644 --- a/torchrec/distributed/tests/test_model_parallel_nccl_ssd_single_gpu.py +++ b/torchrec/distributed/tests/test_model_parallel_nccl_ssd_single_gpu.py @@ -20,6 +20,7 @@ KeyValueEmbedding, KeyValueEmbeddingBag, ZeroCollisionKeyValueEmbedding, + ZeroCollisionKeyValueEmbeddingBag, ) from torchrec.distributed.embedding_types import ( EmbeddingComputeKernel, @@ -853,6 +854,435 @@ def test_ssd_load_state_dict( self._compare_models(m1, m2, is_deterministic=is_deterministic) +class ZeroCollisionModelParallelTest(ModelParallelSingleRankBase): + def _create_tables(self) -> None: + num_features = 4 + self.tables += [ + EmbeddingBagConfig( + num_embeddings=(i + 1) * 1000, + embedding_dim=256, + name="table_" + str(i), + feature_names=["feature_" + str(i)], + total_num_buckets=10, + use_virtual_table=True, + ) + for i in range(num_features) + ] + + @staticmethod + def _copy_ssd_emb_modules( + m1: DistributedModelParallel, m2: DistributedModelParallel + ) -> None: + """ + Util function to copy and set the SSD TBE modules of two models. It + requires both DMP modules to have the same sharding plan. + """ + for lookup1, lookup2 in zip( + # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute `ebc`. + m1.module.sparse.ebc._lookups, + # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute `ebc`. + m2.module.sparse.ebc._lookups, + ): + for emb_module1, emb_module2 in zip( + lookup1._emb_modules, lookup2._emb_modules + ): + ssd_emb_modules = { + ZeroCollisionKeyValueEmbeddingBag, + ZeroCollisionKeyValueEmbedding, + } + if type(emb_module1) in ssd_emb_modules: + assert type(emb_module1) is type(emb_module2), ( + "Expect two emb_modules to be of the same type, either both " + "SSDEmbeddingBag or SSDEmbeddingBag." + ) + emb_module1.flush() + emb_module2.flush() + + emb1_kv = { + t: (sharded_t, sharded_w_id, bucket) + for t, sharded_t, sharded_w_id, bucket in emb_module1.get_named_split_embedding_weights_snapshot() + } + for ( + t, + sharded_t2, + _, + _, + ) in emb_module2.get_named_split_embedding_weights_snapshot(): + assert t in emb1_kv + sharded_t1 = emb1_kv[t][0] + sharded_w1_id = emb1_kv[t][1] + w1_id = sharded_w1_id.local_shards()[0].tensor + + pmt1 = sharded_t1.local_shards()[0].tensor + w1 = pmt1.get_weights_by_ids(w1_id) + + # write value into ssd for both emb module for later comparison + pmt2 = sharded_t2.local_shards()[0].tensor + pmt2.wrapped.set_weights_and_ids(w1, w1_id.view(-1)) + + # purge after loading. This is needed, since we pass a batch + # through dmp when instantiating them. + emb_module1.purge() + emb_module2.purge() + + @staticmethod + def _copy_fused_modules_into_ssd_emb_modules( + fused_m: DistributedModelParallel, ssd_m: DistributedModelParallel + ) -> None: + """ + Util function to copy from fused embedding module to SSD TBE for initialization. It + requires both DMP modules to have the same sharding plan. + """ + + for fused_lookup, ssd_lookup in zip( + # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute `ec`. + fused_m.module.sparse.ebc._lookups, + # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute `ec`. + ssd_m.module.sparse.ebc._lookups, + ): + for fused_emb_module, ssd_emb_module in zip( + fused_lookup._emb_modules, ssd_lookup._emb_modules + ): + ssd_emb_modules = { + ZeroCollisionKeyValueEmbeddingBag, + ZeroCollisionKeyValueEmbedding, + } + if type(ssd_emb_module) in ssd_emb_modules: + fused_state_dict = fused_emb_module.state_dict() + for ( + t, + sharded_t, + _, + _, + ) in ssd_emb_module.get_named_split_embedding_weights_snapshot(): + weight_key = f"{t}.weight" + fused_sharded_t = fused_state_dict[weight_key] + fused_weight = fused_sharded_t.local_shards()[0].tensor.to( + "cpu" + ) + + # write value into ssd for both emb module for later comparison + pmt = sharded_t.local_shards()[0].tensor + pmt.wrapped.set_range(0, 0, fused_weight.size(0), fused_weight) + + # purge after loading. This is needed, since we pass a batch + # through dmp when instantiating them. + fused_emb_module.purge() + ssd_emb_module.purge() + + @unittest.skipIf( + not torch.cuda.is_available(), + "Not enough GPUs, this test requires at least one GPU", + ) + # pyre-ignore[56] + @given( + sharder_type=st.sampled_from( + [ + SharderType.EMBEDDING_BAG_COLLECTION.value, + ] + ), + kernel_type=st.sampled_from( + [ + EmbeddingComputeKernel.SSD_VIRTUAL_TABLE.value, + ] + ), + sharding_type=st.sampled_from( + [ + # TODO: add other test cases when kv embedding support other sharding + ShardingType.ROW_WISE.value, + ] + ), + is_training=st.booleans(), + stochastic_rounding=st.booleans(), + dtype=st.sampled_from([DataType.FP32, DataType.FP16]), + ) + @settings(verbosity=Verbosity.verbose, max_examples=4, deadline=None) + def test_kv_zch_load_state_dict( + self, + sharder_type: str, + kernel_type: str, + sharding_type: str, + is_training: bool, + stochastic_rounding: bool, + dtype: DataType, + ) -> None: + """ + This test checks that if SSD TBE is deterministic. That is, if two SSD + TBEs start with the same state, they would produce the same output. + """ + self._set_table_weights_precision(dtype) + + fused_params = { + "learning_rate": 0.1, + "stochastic_rounding": stochastic_rounding, + } + is_deterministic = dtype == DataType.FP32 or not stochastic_rounding + constraints = { + table.name: ParameterConstraints( + sharding_types=[sharding_type], + compute_kernels=[kernel_type], + ) + for i, table in enumerate(self.tables) + } + sharders = [ + create_test_sharder( + sharder_type, + sharding_type, + kernel_type, + fused_params=fused_params, + ), + ] + + # pyre-ignore + models, batch = self._generate_dmps_and_batch(sharders, constraints=constraints) + m1, m2 = models + + # load state dict for dense modules + m2.load_state_dict(cast("OrderedDict[str, torch.Tensor]", m1.state_dict())) + self._copy_ssd_emb_modules(m1, m2) + + if is_training: + self._train_models(m1, m2, batch) + self._eval_models(m1, m2, batch, is_deterministic=is_deterministic) + self._compare_models( + m1, m2, is_deterministic=is_deterministic, use_virtual_table=True + ) + + @unittest.skipIf( + not torch.cuda.is_available(), + "Not enough GPUs, this test requires at least one GPU", + ) + # pyre-ignore[56] + @given( + sharder_type=st.sampled_from( + [ + SharderType.EMBEDDING_BAG_COLLECTION.value, + ] + ), + kernel_type=st.sampled_from( + [ + EmbeddingComputeKernel.SSD_VIRTUAL_TABLE.value, + ] + ), + sharding_type=st.sampled_from( + [ + # TODO: add other test cases when kv embedding support other sharding + ShardingType.ROW_WISE.value, + ] + ), + is_training=st.booleans(), + stochastic_rounding=st.booleans(), + dtype=st.sampled_from([DataType.FP32, DataType.FP16]), + ) + @settings(verbosity=Verbosity.verbose, max_examples=4, deadline=None) + def test_kv_zch_numerical_accuracy( + self, + sharder_type: str, + kernel_type: str, + sharding_type: str, + is_training: bool, + stochastic_rounding: bool, + dtype: DataType, + ) -> None: + """ + Make sure it produces same numbers as normal TBE. + """ + self._set_table_weights_precision(dtype) + + base_kernel_type = EmbeddingComputeKernel.FUSED.value + learning_rate = 0.1 + fused_params = { + "optimizer": EmbOptimType.EXACT_ROWWISE_ADAGRAD, + "learning_rate": learning_rate, + "stochastic_rounding": stochastic_rounding, + } + is_deterministic = dtype == DataType.FP32 or not stochastic_rounding + fused_sharders = [ + cast( + ModuleSharder[nn.Module], + create_test_sharder( + sharder_type, + sharding_type, + base_kernel_type, # base kernel type + fused_params=fused_params, + ), + ), + ] + ssd_sharders = [ + cast( + ModuleSharder[nn.Module], + create_test_sharder( + sharder_type, + sharding_type, + kernel_type, + fused_params=fused_params, + ), + ), + ] + ssd_constraints = { + table.name: ParameterConstraints( + sharding_types=[sharding_type], + compute_kernels=[kernel_type], + ) + for i, table in enumerate(self.tables) + } + + # for fused model, we need to change the table config to non-kvzch + ssd_tables = copy.deepcopy(self.tables) + for table in self.tables: + table.total_num_buckets = None + table.use_virtual_table = False + (fused_model, _), _ = self._generate_dmps_and_batch(fused_sharders) + self.tables = ssd_tables + (ssd_model, _), batch = self._generate_dmps_and_batch( + ssd_sharders, constraints=ssd_constraints + ) + + # load state dict for dense modules + ssd_model.load_state_dict( + cast("OrderedDict[str, torch.Tensor]", fused_model.state_dict()) + ) + + # for this to work, we expect the order of lookups to be the same + # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute `ebc`. + assert len(fused_model.module.sparse.ebc._lookups) == len( + # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute `ebc`. + ssd_model.module.sparse.ebc._lookups + ), "Expect same number of lookups" + + for fused_lookup, ssd_lookup in zip( + # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute `ebc`. + fused_model.module.sparse.ebc._lookups, + # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute `ebc`. + ssd_model.module.sparse.ebc._lookups, + ): + assert len(fused_lookup._emb_modules) == len( + ssd_lookup._emb_modules + ), "Expect same number of emb modules" + + self._copy_fused_modules_into_ssd_emb_modules(fused_model, ssd_model) + + if is_training: + self._train_models(fused_model, ssd_model, batch) + self._eval_models( + fused_model, ssd_model, batch, is_deterministic=is_deterministic + ) + + @unittest.skipIf( + not torch.cuda.is_available(), + "Not enough GPUs, this test requires at least one GPU", + ) + # pyre-ignore[56] + @given( + sharder_type=st.sampled_from( + [ + SharderType.EMBEDDING_BAG_COLLECTION.value, + ] + ), + kernel_type=st.sampled_from( + [ + EmbeddingComputeKernel.SSD_VIRTUAL_TABLE.value, + ] + ), + sharding_type=st.sampled_from( + [ + # TODO: add other test cases when kv embedding support other sharding + ShardingType.ROW_WISE.value, + ] + ), + is_training=st.booleans(), + stochastic_rounding=st.booleans(), + dtype=st.sampled_from([DataType.FP32, DataType.FP16]), + ) + @settings(verbosity=Verbosity.verbose, max_examples=4, deadline=None) + def test_kv_zch_fused_optimizer( + self, + sharder_type: str, + kernel_type: str, + sharding_type: str, + is_training: bool, + stochastic_rounding: bool, + dtype: DataType, + ) -> None: + """ + Purpose of this test is to make sure it works with warm up policy. + """ + self._set_table_weights_precision(dtype) + + is_deterministic = dtype == DataType.FP32 or not stochastic_rounding + + constraints = { + table.name: ParameterConstraints( + sharding_types=[sharding_type], + compute_kernels=[kernel_type], + ) + for i, table in enumerate(self.tables) + } + + base_sharders = [ + create_test_sharder( + sharder_type, + sharding_type, + kernel_type, + fused_params={ + "learning_rate": 0.2, + "stochastic_rounding": stochastic_rounding, + }, + ), + ] + models, batch = self._generate_dmps_and_batch( + base_sharders, # pyre-ignore + constraints=constraints, + ) + base_model, _ = models + + test_sharders = [ + create_test_sharder( + sharder_type, + sharding_type, + kernel_type, + fused_params={ + "learning_rate": 0.1, + "stochastic_rounding": stochastic_rounding, + }, + ), + ] + models, _ = self._generate_dmps_and_batch( + test_sharders, # pyre-ignore + constraints=constraints, + ) + test_model, _ = models + + # load state dict for dense modules + test_model.load_state_dict( + cast("OrderedDict[str, torch.Tensor]", base_model.state_dict()) + ) + self._copy_ssd_emb_modules(base_model, test_model) + + self._eval_models( + base_model, test_model, batch, is_deterministic=is_deterministic + ) + + # change learning rate for test_model + fused_opt = test_model.fused_optimizer + # pyre-ignore + fused_opt.param_groups[0]["lr"] = 0.2 + fused_opt.zero_grad() + + if is_training: + self._train_models(base_model, test_model, batch) + self._eval_models( + base_model, test_model, batch, is_deterministic=is_deterministic + ) + self._compare_models(base_model, test_model, is_deterministic=is_deterministic) + + # TODO: uncomment this when we have multiple kernels in rw support(unblock input dist) + # def test_ssd_mixed_kernels + + # TODO: uncomment this when we support different sharding types, e.g. tw, tw_rw together with rw + # def test_ssd_mixed_sharding_types + + class ZeroCollisionSequenceModelParallelStateDictTest(ModelParallelSingleRankBase): def setUp(self, backend: str = "nccl") -> None: self.shared_features = [] @@ -1196,15 +1626,6 @@ def test_kv_zch_numerical_accuracy( fused_model, ssd_model, batch, is_deterministic=is_deterministic ) - # TODO: uncomment this when we have optimizer plumb through - # def test_ssd_fused_optimizer( - - # TODO: uncomment this when we have multiple kernels in rw support(unblock input dist) - # def test_ssd_mixed_kernels - - # TODO: uncomment this when we support different sharding types, e.g. tw, tw_rw together with rw - # def test_ssd_mixed_sharding_types - # TODO: remove after development is done def main() -> None: