From 90019f46b0f5ec5db3ed048fe6d2d75a9c3c399a Mon Sep 17 00:00:00 2001 From: Francisco Massa Date: Mon, 4 Aug 2025 11:37:01 +0000 Subject: [PATCH] Fix invariant that factory ops should have input specs The solver assumes all ops (including placeholder and factory ops) have an input spec. So let's just respect this condition in the code --- autoparallel/propagation_rules.py | 9 +++++++-- autoparallel/utils.py | 18 ++++++++++++------ 2 files changed, 19 insertions(+), 8 deletions(-) diff --git a/autoparallel/propagation_rules.py b/autoparallel/propagation_rules.py index abba8964..26efa2bf 100644 --- a/autoparallel/propagation_rules.py +++ b/autoparallel/propagation_rules.py @@ -379,7 +379,7 @@ def factory_rule(mesh, op_schema: OpSchema) -> OpStrategy: This util applies to any factory function that takes 'size' as the first argument, and supports Replication and Shard placements all at zero cost. """ - assert isinstance(op_schema.args_schema[0], torch.Size) + assert isinstance(op_schema.args_schema[0], (torch.Size, list)) shape = op_schema.args_schema[0] x = torch.empty(shape, device="meta") stride = x.stride() @@ -408,7 +408,7 @@ def factory_rule(mesh, op_schema: OpSchema) -> OpStrategy: for strategy_comb in strategy_combs: spec_list = [DTensorSpec(mesh, specs) for specs in zip(*strategy_comb)] output_specs = spec_list[0] - output_specs.tensor_meta = TensorMeta(shape, stride, dtype) + output_specs.tensor_meta = TensorMeta(shape, stride, dtype) # type: ignore[arg-type] if not is_tensor_shardable(shape, output_specs): continue @@ -421,8 +421,13 @@ def factory_rule(mesh, op_schema: OpSchema) -> OpStrategy: * len(strategy_combs) ] + # NOTE: why do we have input_specs for constructor nodes, given that they have no inputs? + # This is because the optimizer code expects to see input_specs for all nodes, and it + # uses the input_specs to determine the sharding of the output. So we have to give it + # something, even though it is in principle not needed. strategy = OpSpec( output_specs=output_specs, + input_specs=[output_specs], redistribute_cost=redistribute_cost, ) all_strategies.append(strategy) diff --git a/autoparallel/utils.py b/autoparallel/utils.py index 96ecbfe4..2627cb1d 100644 --- a/autoparallel/utils.py +++ b/autoparallel/utils.py @@ -73,9 +73,6 @@ def propagate_tensor_meta(op, user_args, user_kwargs, out_strat): else: assert tm is None if strat.input_specs is None: - if op in TENSOR_FACTORY_OPS: - # there isn't an input spec bc the op has no input! - continue supported_ops = { torch.ops.prims.convert_element_type.default, @@ -88,9 +85,18 @@ def propagate_tensor_meta(op, user_args, user_kwargs, out_strat): ) strat.input_specs = (strat.output_specs,) assert strat.redistribute_cost is None - assert len(tensor_metas) == len( - strat.input_specs - ), f"{op}, {len(tensor_metas)}, {len(strat.input_specs)}" + # NOTE: this invariant wrt factory ops is something I believe + # I'll keep for the solver, so we need to have some consistency here + # i.e., even though factory ops don't have inputs, we do put an + # input spec for it which is equal to the output spec + if op in TENSOR_FACTORY_OPS: + assert len(tensor_metas) == 0, f"{op}, {len(tensor_metas)}" + assert len(strat.input_specs) == 1, f"{op}, {len(strat.input_specs)}" + else: + assert len(tensor_metas) == len( + strat.input_specs + ), f"{op}, {len(tensor_metas)}, {len(strat.input_specs)}" + for tm, ispec in zip(tensor_metas, strat.input_specs): if ispec.tensor_meta is None: ispec.tensor_meta = tm