diff --git a/autoparallel/propagation_rules.py b/autoparallel/propagation_rules.py
index abba8964..26efa2bf 100644
--- a/autoparallel/propagation_rules.py
+++ b/autoparallel/propagation_rules.py
@@ -379,7 +379,7 @@ def factory_rule(mesh, op_schema: OpSchema) -> OpStrategy:
     This util applies to any factory function that takes 'size' as the first
     argument, and supports Replication and Shard placements all at zero cost.
     """
-    assert isinstance(op_schema.args_schema[0], torch.Size)
+    assert isinstance(op_schema.args_schema[0], (torch.Size, list))
     shape = op_schema.args_schema[0]
     x = torch.empty(shape, device="meta")
     stride = x.stride()
@@ -408,7 +408,7 @@ def factory_rule(mesh, op_schema: OpSchema) -> OpStrategy:
     for strategy_comb in strategy_combs:
         spec_list = [DTensorSpec(mesh, specs) for specs in zip(*strategy_comb)]
         output_specs = spec_list[0]
-        output_specs.tensor_meta = TensorMeta(shape, stride, dtype)
+        output_specs.tensor_meta = TensorMeta(shape, stride, dtype)  # type: ignore[arg-type]
 
         if not is_tensor_shardable(shape, output_specs):
             continue
@@ -421,8 +421,13 @@ def factory_rule(mesh, op_schema: OpSchema) -> OpStrategy:
             * len(strategy_combs)
         ]
 
+        # NOTE: why do we have input_specs for constructor nodes, given that they have no inputs?
+        # This is because the optimizer code expects to see input_specs for all nodes, and it
+        # uses the input_specs to determine the sharding of the output. So we have to give it
+        # something, even though it is in principle not needed.
         strategy = OpSpec(
             output_specs=output_specs,
+            input_specs=[output_specs],
             redistribute_cost=redistribute_cost,
         )
         all_strategies.append(strategy)
diff --git a/autoparallel/utils.py b/autoparallel/utils.py
index 96ecbfe4..2627cb1d 100644
--- a/autoparallel/utils.py
+++ b/autoparallel/utils.py
@@ -73,9 +73,6 @@ def propagate_tensor_meta(op, user_args, user_kwargs, out_strat):
         else:
             assert tm is None
         if strat.input_specs is None:
-            if op in TENSOR_FACTORY_OPS:
-                # there isn't an input spec bc the op has no input!
-                continue
 
             supported_ops = {
                 torch.ops.prims.convert_element_type.default,
@@ -88,9 +85,18 @@ def propagate_tensor_meta(op, user_args, user_kwargs, out_strat):
             )
             strat.input_specs = (strat.output_specs,)
             assert strat.redistribute_cost is None
-        assert len(tensor_metas) == len(
-            strat.input_specs
-        ), f"{op}, {len(tensor_metas)}, {len(strat.input_specs)}"
+        # NOTE: this invariant wrt factory ops is something I believe
+        # I'll keep for the solver, so we need to have some consistency here
+        # i.e., even though factory ops don't have inputs, we do put an
+        # input spec for it which is equal to the output spec
+        if op in TENSOR_FACTORY_OPS:
+            assert len(tensor_metas) == 0, f"{op}, {len(tensor_metas)}"
+            assert len(strat.input_specs) == 1, f"{op}, {len(strat.input_specs)}"
+        else:
+            assert len(tensor_metas) == len(
+                strat.input_specs
+            ), f"{op}, {len(tensor_metas)}, {len(strat.input_specs)}"
+
         for tm, ispec in zip(tensor_metas, strat.input_specs):
             if ispec.tensor_meta is None:
                 ispec.tensor_meta = tm
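
Illustrative sketch, not part of the patch above: a minimal, self-contained check of the invariant the change establishes for factory ops (no input tensor metas, but exactly one input spec that mirrors the output spec). FakeSpec, FakeOpSpec, and check_factory_invariant are hypothetical stand-ins introduced here for illustration; they are not autoparallel or DTensor APIs.

# Sketch only: stand-in types mimic the shape of DTensorSpec / OpSpec so the
# invariant can be exercised without importing autoparallel.
from dataclasses import dataclass
from typing import Optional, Tuple


@dataclass
class FakeSpec:
    placements: Tuple[str, ...]
    tensor_meta: Optional[tuple] = None


@dataclass
class FakeOpSpec:
    output_specs: FakeSpec
    input_specs: Optional[Tuple[FakeSpec, ...]] = None
    redistribute_cost: Optional[list] = None


def check_factory_invariant(strat: FakeOpSpec, tensor_metas: list) -> None:
    # A factory op (zeros, ones, full, ...) has no tensor inputs, so there are
    # no input tensor metas to propagate.
    assert len(tensor_metas) == 0
    # It still carries exactly one input spec, equal to its output spec, so
    # that solver code expecting input_specs on every node keeps working.
    assert strat.input_specs is not None and len(strat.input_specs) == 1
    assert strat.input_specs[0] is strat.output_specs


out = FakeSpec(placements=("Replicate()",))
check_factory_invariant(FakeOpSpec(output_specs=out, input_specs=(out,)), tensor_metas=[])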