160 changes: 99 additions & 61 deletions torchrec/distributed/test_utils/test_sharding.py
@@ -492,69 +492,16 @@ def dynamic_sharding_test(

assert ctx.pg is not None

num_tables = len(tables)

ranks_per_tables = [1 for _ in range(num_tables)]
new_ranks = generate_rank_placements(
world_size, num_tables, ranks_per_tables, random_seed
)

ranks_per_tables_for_CW = []
for table in tables:

# CW sharding
valid_candidates = [
i for i in range(1, world_size + 1) if table.embedding_dim % i == 0
]
ranks_per_tables_for_CW.append(random.choice(valid_candidates))

new_ranks_cw = generate_rank_placements(
world_size, num_tables, ranks_per_tables_for_CW, random_seed
)

new_per_param_sharding = {}

assert len(sharders) == 1
# pyre-ignore
kernel_type = sharders[0]._kernel_type
# Construct parameter shardings
for i in range(num_tables):
table_name = tables[i].name
table_constraint = constraints[table_name] # pyre-ignore
assert hasattr(table_constraint, "sharding_types")
assert (
len(table_constraint.sharding_types) == 1
), "Dynamic Sharding currently only supports 1 sharding type per table"
sharding_type = ShardingType(table_constraint.sharding_types[0])
sharding_type_constructor = get_sharding_constructor_from_type(
sharding_type
)

if sharding_type == ShardingType.TABLE_WISE:
new_per_param_sharding[table_name] = sharding_type_constructor(
rank=new_ranks[i][0], compute_kernel=kernel_type
)
elif sharding_type == ShardingType.COLUMN_WISE:
new_per_param_sharding[table_name] = sharding_type_constructor(
ranks=new_ranks_cw[i], compute_kernel=kernel_type
)
else:
raise NotImplementedError(
f"Dynamic Sharding currently does not support {sharding_type}"
)

new_module_sharding_plan = construct_module_sharding_plan(
local_m2.sparse.ebc,
sharder=sharders[0],
per_param_sharding=new_per_param_sharding,
local_size=world_size,
plan_1 = create_alternative_sharding_plan(
tables=tables,
world_size=world_size,
device_type="cuda" if torch.cuda.is_available() else "cpu",
random_seed=random_seed,
sharders=sharders,
constraints=constraints,
local_model=local_m2,
original_plan=plan,
)

plan_1 = copy.deepcopy(plan)
plan_1.plan["sparse.ebc"] = new_module_sharding_plan

local_m1_dmp = DistributedModelParallel(
local_m1,
env=ShardingEnv.from_process_group(ctx.pg), # pyre-ignore
@@ -591,7 +538,7 @@ def dynamic_sharding_test(
)

new_module_sharding_plan_delta = output_sharding_plan_delta(
plan.plan["sparse.ebc"], new_module_sharding_plan # pyre-ignore
plan.plan["sparse.ebc"], plan_1.plan["sparse.ebc"] # pyre-ignore
)

dense_m1_optim = KeyedOptimizerWrapper(
@@ -1304,6 +1251,97 @@ def generate_rank_placements(
return placements


def create_alternative_sharding_plan(
tables: List[EmbeddingTableConfig],
    world_size: int,
    device_type: str,
random_seed: int,
sharders: List[ModuleSharder[nn.Module]],
constraints: Optional[Dict[str, ParameterConstraints]],
local_model: nn.Module,
original_plan: ShardingPlan,
) -> ShardingPlan:
"""
Creates an alternative sharding plan for dynamic sharding tests.

Args:
tables: List of embedding table configurations
        world_size: Number of processes in the distributed group
        device_type: Device type ("cuda" or "cpu") used when constructing the module sharding plan
random_seed: Random seed for reproducible rank placement generation
sharders: List of module sharders
constraints: Parameter constraints for sharding
local_model: Local model to create sharding plan for
original_plan: Original sharding plan to copy and modify

Returns:
Modified sharding plan with alternative parameter sharding
"""
if constraints is None:
raise ValueError("constraints parameter is required for dynamic sharding")

num_tables = len(tables)

ranks_per_tables = [1 for _ in range(num_tables)]
new_ranks = generate_rank_placements(
world_size, num_tables, ranks_per_tables, random_seed
)

    ranks_per_tables_for_CW = []
    for table in tables:
        # CW sharding: pick a shard count in [1, world_size] that evenly divides
        # the table's embedding dim so the columns can be split cleanly.
        valid_candidates = [
            i for i in range(1, world_size + 1) if table.embedding_dim % i == 0
        ]
        ranks_per_tables_for_CW.append(random.choice(valid_candidates))

new_ranks_cw = generate_rank_placements(
world_size, num_tables, ranks_per_tables_for_CW, random_seed
)

new_per_param_sharding = {}

assert len(sharders) == 1
kernel_type = sharders[0]._kernel_type # pyre-ignore
# Construct parameter shardings
for i in range(num_tables):
table_name = tables[i].name
table_constraint = constraints[table_name]
assert hasattr(table_constraint, "sharding_types")
assert table_constraint.sharding_types is not None
assert (
len(table_constraint.sharding_types) == 1
), "Dynamic Sharding currently only supports 1 sharding type per table"
sharding_type = ShardingType(table_constraint.sharding_types[0]) # pyre-ignore
sharding_type_constructor = get_sharding_constructor_from_type(sharding_type)

if sharding_type == ShardingType.TABLE_WISE:
new_per_param_sharding[table_name] = sharding_type_constructor(
rank=new_ranks[i][0], compute_kernel=kernel_type
)
elif sharding_type == ShardingType.COLUMN_WISE:
new_per_param_sharding[table_name] = sharding_type_constructor(
ranks=new_ranks_cw[i], compute_kernel=kernel_type
)
else:
raise NotImplementedError(
f"Dynamic Sharding currently does not support {sharding_type}"
)

new_module_sharding_plan = construct_module_sharding_plan(
local_model.sparse.ebc, # pyre-ignore
sharder=sharders[0],
per_param_sharding=new_per_param_sharding,
local_size=world_size,
world_size=world_size,
device_type="cuda" if torch.cuda.is_available() else "cpu",
)

plan_1 = copy.deepcopy(original_plan)
plan_1.plan["sparse.ebc"] = new_module_sharding_plan

return plan_1
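

# Hedged usage sketch (illustrative, not part of the original change): it mirrors
# how `dynamic_sharding_test` above drives the helper and then derives the
# per-table delta that resharding actually consumes. The wrapper name
# `_example_alternative_plan_flow` and its bare parameters are placeholders for
# whatever the calling test already has in scope.
def _example_alternative_plan_flow(
    tables, world_size, random_seed, sharders, constraints, local_model, plan
):
    plan_1 = create_alternative_sharding_plan(
        tables=tables,
        world_size=world_size,
        device_type="cuda" if torch.cuda.is_available() else "cpu",
        random_seed=random_seed,
        sharders=sharders,
        constraints=constraints,
        local_model=local_model,
        original_plan=plan,
    )
    # Only the module sharding entries that actually changed are returned,
    # matching the `output_sharding_plan_delta` call in `dynamic_sharding_test`.
    return output_sharding_plan_delta(
        plan.plan["sparse.ebc"], plan_1.plan["sparse.ebc"]
    )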


def compare_opt_local_t(
opt_1: CombinedOptimizer,
opt_2: CombinedOptimizer,