
Commit d458c8a

Fix invariant that factory ops should have input specs (#71)
The solver assumes that all ops (including placeholder and factory ops) have an input spec, so let's just respect this condition in the code.
1 parent 3fba727 commit d458c8a
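
For context, here is a minimal, hedged sketch of the invariant the solver relies on. The FakeSpec/FakeOpSpec classes below are hypothetical stand-ins, not the real DTensorSpec/OpSpec types; the point is only that every strategy carries input_specs, and that for factory ops the single input spec mirrors the output spec.

from dataclasses import dataclass
from typing import Optional

@dataclass
class FakeSpec:              # stand-in for DTensorSpec
    placements: tuple

@dataclass
class FakeOpSpec:            # stand-in for OpSpec
    output_specs: FakeSpec
    input_specs: Optional[list] = None

def check_solver_invariant(spec: FakeOpSpec, is_factory_op: bool) -> None:
    # The solver reads input_specs for every op, so it must never be None.
    assert spec.input_specs is not None
    if is_factory_op:
        # Factory ops have no real inputs, so they carry exactly one
        # placeholder input spec that mirrors the output spec.
        assert len(spec.input_specs) == 1
        assert spec.input_specs[0] == spec.output_specs

out = FakeSpec(placements=("Replicate",))
check_solver_invariant(FakeOpSpec(output_specs=out, input_specs=[out]), is_factory_op=True)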

2 files changed: +19 −8 lines changed

autoparallel/propagation_rules.py

Lines changed: 7 additions & 2 deletions
@@ -379,7 +379,7 @@ def factory_rule(mesh, op_schema: OpSchema) -> OpStrategy:
     This util applies to any factory function that takes 'size' as the first argument,
     and supports Replication and Shard placements all at zero cost.
     """
-    assert isinstance(op_schema.args_schema[0], torch.Size)
+    assert isinstance(op_schema.args_schema[0], (torch.Size, list))
     shape = op_schema.args_schema[0]
     x = torch.empty(shape, device="meta")
     stride = x.stride()
@@ -408,7 +408,7 @@ def factory_rule(mesh, op_schema: OpSchema) -> OpStrategy:
     for strategy_comb in strategy_combs:
         spec_list = [DTensorSpec(mesh, specs) for specs in zip(*strategy_comb)]
         output_specs = spec_list[0]
-        output_specs.tensor_meta = TensorMeta(shape, stride, dtype)
+        output_specs.tensor_meta = TensorMeta(shape, stride, dtype)  # type: ignore[arg-type]
 
         if not is_tensor_shardable(shape, output_specs):
             continue
@@ -421,8 +421,13 @@ def factory_rule(mesh, op_schema: OpSchema) -> OpStrategy:
             * len(strategy_combs)
         ]
 
+        # NOTE: why do we have input_specs for constructor nodes, given that they have no inputs?
+        # This is because the optimizer code expects to see input_specs for all nodes, and it
+        # uses the input_specs to determine the sharding of the output. So we have to give it
+        # something, even though it is in principle not needed.
         strategy = OpSpec(
             output_specs=output_specs,
+            input_specs=[output_specs],
             redistribute_cost=redistribute_cost,
         )
         all_strategies.append(strategy)
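
As a side note on the first hunk, a small hedged illustration of why the isinstance check was presumably widened: the 'size' argument can arrive either as a torch.Size or as a plain Python list of ints, and torch.empty accepts both, so both forms should pass the assert and produce the same meta tensor used to compute the stride.

import torch

# Both forms of a 'size' argument behave identically here.
for shape in (torch.Size([4, 8]), [4, 8]):
    assert isinstance(shape, (torch.Size, list))
    x = torch.empty(shape, device="meta")
    print(tuple(x.shape), x.stride())  # (4, 8) (8, 1) in both cases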

autoparallel/utils.py

Lines changed: 12 additions & 6 deletions
@@ -73,9 +73,6 @@ def propagate_tensor_meta(op, user_args, user_kwargs, out_strat):
         else:
             assert tm is None
         if strat.input_specs is None:
-            if op in TENSOR_FACTORY_OPS:
-                # there isn't an input spec bc the op has no input!
-                continue
 
             supported_ops = {
                 torch.ops.prims.convert_element_type.default,
@@ -88,9 +85,18 @@ def propagate_tensor_meta(op, user_args, user_kwargs, out_strat):
             )
             strat.input_specs = (strat.output_specs,)
         assert strat.redistribute_cost is None
-        assert len(tensor_metas) == len(
-            strat.input_specs
-        ), f"{op}, {len(tensor_metas)}, {len(strat.input_specs)}"
+        # NOTE: this invariant wrt factory ops is something I believe
+        # I'll keep for the solver, so we need to have some consistency here
+        # i.e., even though factory ops don't have inputs, we do put an
+        # input spec for it which is equal to the output spec
+        if op in TENSOR_FACTORY_OPS:
+            assert len(tensor_metas) == 0, f"{op}, {len(tensor_metas)}"
+            assert len(strat.input_specs) == 1, f"{op}, {len(strat.input_specs)}"
+        else:
+            assert len(tensor_metas) == len(
+                strat.input_specs
+            ), f"{op}, {len(tensor_metas)}, {len(strat.input_specs)}"
+
         for tm, ispec in zip(tensor_metas, strat.input_specs):
             if ispec.tensor_meta is None:
                 ispec.tensor_meta = tm
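
To make the new assertion split concrete, here is a small hedged sketch (a standalone helper, not code from the repo) of the consistency rule it encodes: tensor_metas is derived from the op's real inputs, so factory ops contribute none, yet their strategy still carries exactly one input spec that mirrors the output spec.

def expected_input_spec_count(num_input_tensor_metas: int, is_factory_op: bool) -> int:
    """Return how many input specs a strategy should carry under the invariant."""
    if is_factory_op:
        # Factory ops have no input tensors, but still carry one mirrored spec.
        assert num_input_tensor_metas == 0
        return 1
    # All other ops carry one input spec per input tensor meta.
    return num_input_tensor_metas

assert expected_input_spec_count(0, is_factory_op=True) == 1
assert expected_input_spec_count(3, is_factory_op=False) == 3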
