9 changes: 7 additions & 2 deletions autoparallel/propagation_rules.py
@@ -379,7 +379,7 @@ def factory_rule(mesh, op_schema: OpSchema) -> OpStrategy:
This util applies to any factory function that takes 'size' as the first argument,
and supports Replication and Shard placements all at zero cost.
"""
assert isinstance(op_schema.args_schema[0], torch.Size)
assert isinstance(op_schema.args_schema[0], (torch.Size, list))
shape = op_schema.args_schema[0]
x = torch.empty(shape, device="meta")
stride = x.stride()
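As a side note on the relaxed assert above: a minimal standalone sketch (not part of the PR) of why accepting a plain list in addition to torch.Size is safe here — torch.empty on the meta device normalizes both forms to the same shape and stride without allocating any data.

```python
import torch

# The 'size' argument of a factory op may arrive as a torch.Size or as a
# plain Python list of ints; both describe the same shape.
size_as_torch_size = torch.Size([4, 8])
size_as_list = [4, 8]

# Meta-device allocation materializes no data but still yields the
# shape/stride metadata the strategy needs.
a = torch.empty(size_as_torch_size, device="meta")
b = torch.empty(size_as_list, device="meta")

assert a.shape == b.shape == torch.Size([4, 8])
assert a.stride() == b.stride() == (8, 1)
```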
@@ -408,7 +408,7 @@ def factory_rule(mesh, op_schema: OpSchema) -> OpStrategy:
for strategy_comb in strategy_combs:
spec_list = [DTensorSpec(mesh, specs) for specs in zip(*strategy_comb)]
output_specs = spec_list[0]
output_specs.tensor_meta = TensorMeta(shape, stride, dtype)
output_specs.tensor_meta = TensorMeta(shape, stride, dtype) # type: ignore[arg-type]

if not is_tensor_shardable(shape, output_specs):
continue
@@ -421,8 +421,13 @@ def factory_rule(mesh, op_schema: OpSchema) -> OpStrategy:
* len(strategy_combs)
]

# NOTE: why do we have input_specs for constructor nodes, given that they have no inputs?
# This is because the optimizer code expects to see input_specs for all nodes, and it
# uses the input_specs to determine the sharding of the output. So we have to give it
# something, even though it is in principle not needed.
strategy = OpSpec(
output_specs=output_specs,
input_specs=[output_specs],
redistribute_cost=redistribute_cost,
)
all_strategies.append(strategy)
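To make the convention from the NOTE concrete, here is a simplified, self-contained sketch of a strategy for a no-input constructor node whose single input spec simply mirrors its output spec. The Spec and Strategy dataclasses below are hypothetical stand-ins for illustration only, not the real DTensorSpec/OpSpec classes.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class Spec:
    # Stand-in for DTensorSpec: only records placements, for illustration.
    placements: tuple


@dataclass
class Strategy:
    # Stand-in for OpSpec.
    output_specs: Spec
    input_specs: List[Spec]
    redistribute_cost: List[List[float]]


def constructor_strategy(placements: tuple) -> Strategy:
    out = Spec(placements=placements)
    # Constructor nodes have no real inputs, but the solver expects every node
    # to expose input_specs, so the output spec is mirrored at zero cost.
    return Strategy(output_specs=out, input_specs=[out], redistribute_cost=[[0.0]])


strat = constructor_strategy(("Replicate",))
assert strat.input_specs[0] is strat.output_specs
```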
18 changes: 12 additions & 6 deletions autoparallel/utils.py
@@ -73,9 +73,6 @@ def propagate_tensor_meta(op, user_args, user_kwargs, out_strat):
else:
assert tm is None
if strat.input_specs is None:
if op in TENSOR_FACTORY_OPS:
# there isn't an input spec bc the op has no input!
continue

supported_ops = {
torch.ops.prims.convert_element_type.default,
@@ -88,9 +85,18 @@ def propagate_tensor_meta(op, user_args, user_kwargs, out_strat):
)
strat.input_specs = (strat.output_specs,)
assert strat.redistribute_cost is None
assert len(tensor_metas) == len(
strat.input_specs
), f"{op}, {len(tensor_metas)}, {len(strat.input_specs)}"
# NOTE: this is an invariant I expect to keep for the solver, so we need
# to be consistent here: even though factory ops have no inputs, we still
# give each one an input spec that is equal to its output spec
if op in TENSOR_FACTORY_OPS:
assert len(tensor_metas) == 0, f"{op}, {len(tensor_metas)}"
assert len(strat.input_specs) == 1, f"{op}, {len(strat.input_specs)}"
else:
assert len(tensor_metas) == len(
strat.input_specs
), f"{op}, {len(tensor_metas)}, {len(strat.input_specs)}"

for tm, ispec in zip(tensor_metas, strat.input_specs):
if ispec.tensor_meta is None:
ispec.tensor_meta = tm
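Restating the invariant enforced above as a standalone helper (a sketch, assuming factory_ops plays the role of TENSOR_FACTORY_OPS): factory ops carry no input tensor metadata but still expose exactly one mirrored input spec, while every other op must have one input spec per input tensor meta.

```python
def check_input_spec_invariant(op, tensor_metas, input_specs, factory_ops) -> None:
    """Sketch of the consistency check performed in propagate_tensor_meta."""
    if op in factory_ops:
        # No tensor inputs, but the solver still expects a single input spec
        # that mirrors the output spec.
        assert len(tensor_metas) == 0, f"{op}, {len(tensor_metas)}"
        assert len(input_specs) == 1, f"{op}, {len(input_specs)}"
    else:
        # Regular ops: one input spec per input tensor meta.
        assert len(tensor_metas) == len(input_specs), (
            f"{op}, {len(tensor_metas)}, {len(input_specs)}"
        )
```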