Skip to content

Commit 22baa4f

Browse files
EddyLXJmeta-codesync[bot]
authored and committed
st publish mode only load weight (#3538)
Summary: X-link: pytorch/FBGEMM#5116 Pull Request resolved: #3538 X-link: https://github.com/facebookresearch/FBGEMM/pull/2122 For silvertorch (st) publish, we don't want to load the optimizer state (opt) into the backend, because CPU memory on the publish host is limited. So, while loading the checkpoint in st publish, we need to load the whole row into the state dict and then save only the weight into the backend; after that, the backend will only hold metaheader + weight. For the first load, we need to set dim to metaheader_dim + emb_dim + optimizer_state_dim, otherwise checkpoint loading will throw a size-mismatch error. After the first load, we only need to get metaheader + weight from the backend for the state dict, so we can set dim to metaheader_dim + emb_dim. Reviewed By: emlin Differential Revision: D85830053 fbshipit-source-id: 0eddbe9e69ea8271e8c77dc0147e87a08f0b3934
1 parent 979f102 commit 22baa4f

File tree

2 files changed

+5
-0
lines changed

2 files changed

+5
-0
lines changed

torchrec/distributed/batched_embedding_kernel.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -477,13 +477,15 @@ def _populate_zero_collision_tbe_params(
477477
else False
478478
)
479479
)
480+
480481
tbe_params["kv_zch_params"] = KVZCHParams(
481482
bucket_offsets=bucket_offsets,
482483
bucket_sizes=bucket_sizes,
483484
enable_optimizer_offloading=True,
484485
backend_return_whole_row=(backend_type == BackendType.DRAM),
485486
eviction_policy=eviction_policy,
486487
embedding_cache_mode=embedding_cache_mode_,
488+
load_ckpt_without_opt=eviction_tbe_config.load_ckpt_without_opt,
487489
)
488490

489491

torchrec/distributed/types.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -664,6 +664,7 @@ class KeyValueParams:
664664
enable_raw_embedding_streaming: Optional[bool]: enable raw embedding streaming for SSD TBE
665665
res_store_shards: Optional[int] = None: the number of shards to store the raw embeddings
666666
kvzch_tbe_config: Optional[KVZCHTBEConfig]: KVZCH config for TBE
667+
load_ckpt_without_opt: bool: whether it is st publish
667668
668669
# Parameter Server (PS) Attributes
669670
ps_hosts (Optional[Tuple[Tuple[str, int]]]): List of PS host ip addresses
@@ -690,6 +691,7 @@ class KeyValueParams:
690691
)
691692
res_store_shards: Optional[int] = None # shards to store the raw embeddings
692693
kvzch_tbe_config: Optional[KVZCHTBEConfig] = None
694+
load_ckpt_without_opt: bool = False # is st publish
693695

694696
# Parameter Server (PS) Attributes
695697
ps_hosts: Optional[Tuple[Tuple[str, int], ...]] = None
@@ -719,6 +721,7 @@ def __hash__(self) -> int:
719721
self.enable_raw_embedding_streaming,
720722
self.res_store_shards,
721723
self.kvzch_tbe_config,
724+
self.load_ckpt_without_opt,
722725
)
723726
)
724727

0 commit comments

Comments
 (0)