Skip to content

Commit 449e35e

Browse files
authored
Small refactoring of propagate_tensor_meta (#262)
This extracts the refactoring from #29, and also adds as an allowed op for automatic redistribute cost filling
1 parent 6e2085e commit 449e35e

File tree

1 file changed

+17
-7
lines changed

1 file changed

+17
-7
lines changed

autoparallel/utils.py

Lines changed: 17 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -33,21 +33,30 @@
3333
)
3434

3535

36-
def propagate_tensor_meta(op, user_args, user_kwargs, out_strat):
36+
def _get_meta_tensors_for_op(op, user_args, user_kwargs):
3737
out_t = op(*user_args, **user_kwargs)
3838

3939
if isinstance(out_t, torch.Tensor):
40-
new_tensor_meta = TensorMeta(out_t.shape, out_t.stride(), out_t.dtype)
40+
out_tensor_meta = TensorMeta(out_t.shape, out_t.stride(), out_t.dtype)
4141
else:
42-
new_tensor_meta = tree_map_only(
42+
out_tensor_meta = tree_map_only(
4343
torch.Tensor, lambda x: TensorMeta(x.shape, x.stride(), x.dtype), out_t
4444
)
4545

46-
tensor_metas = tree_flatten(user_args)[0]
47-
tensor_metas = tree_map_only(
48-
torch.Tensor, lambda x: TensorMeta(x.shape, x.stride(), x.dtype), tensor_metas
46+
input_tensor_metas = tree_flatten(user_args)[0]
47+
input_tensor_metas = tree_map_only(
48+
torch.Tensor,
49+
lambda x: TensorMeta(x.shape, x.stride(), x.dtype),
50+
input_tensor_metas,
51+
)
52+
input_tensor_metas = tuple(
53+
x for x in input_tensor_metas if isinstance(x, TensorMeta)
4954
)
50-
tensor_metas = tuple(x for x in tensor_metas if isinstance(x, TensorMeta))
55+
return out_tensor_meta, input_tensor_metas
56+
57+
58+
def propagate_tensor_meta(op, user_args, user_kwargs, out_strat):
59+
new_tensor_meta, tensor_metas = _get_meta_tensors_for_op(op, user_args, user_kwargs)
5160

5261
for strat in out_strat.strategies:
5362
if isinstance(new_tensor_meta, TensorMeta):
@@ -120,6 +129,7 @@ def fill_missing_redistribute_cost(op, specs, out_strat):
120129
torch.ops.aten.empty_like.default,
121130
torch.ops.prims.convert_element_type.default,
122131
torch.ops.aten.slice.Tensor,
132+
torch.ops.aten.select.int,
123133
}
124134
assert op in handled_ops, f"got {op}, supported ops here are {handled_ops}"
125135
# assert len(specs) == 1, f"Expected len(specs) == 1, got {len(specs)}"

0 commit comments

Comments
 (0)