diff --git a/autoparallel/utils.py b/autoparallel/utils.py index 96ecbfe4..013d9287 100644 --- a/autoparallel/utils.py +++ b/autoparallel/utils.py @@ -17,7 +17,7 @@ from torch.distributed.tensor.placement_types import Replicate from torch.utils._pytree import tree_flatten, tree_map_only -from .dtensor_util import get_op_strategy +from .dtensor_util import get_op_strategy, with_implicit_strategies from .propagation_rules import ( TENSOR_FACTORY_OPS, _op_partial_rules, @@ -163,7 +163,8 @@ def get_placement_options(mesh, op, specs, user_args, user_kwargs): if op in _op_partial_rules: out_strat = _op_partial_rules[op](mesh, op_schema) else: - out_strat = get_op_strategy(op, op_schema) + with with_implicit_strategies(): + out_strat = get_op_strategy(op, op_schema) propagate_tensor_meta(op, user_args, user_kwargs, out_strat) fill_missing_redistribute_cost(op, specs, out_strat)