|
33 | 33 | ) |
34 | 34 |
|
35 | 35 |
|
36 | | -def propagate_tensor_meta(op, user_args, user_kwargs, out_strat): |
| 36 | +def _get_meta_tensors_for_op(op, user_args, user_kwargs): |
37 | 37 | out_t = op(*user_args, **user_kwargs) |
38 | 38 |
|
39 | 39 | if isinstance(out_t, torch.Tensor): |
40 | | - new_tensor_meta = TensorMeta(out_t.shape, out_t.stride(), out_t.dtype) |
| 40 | + out_tensor_meta = TensorMeta(out_t.shape, out_t.stride(), out_t.dtype) |
41 | 41 | else: |
42 | | - new_tensor_meta = tree_map_only( |
| 42 | + out_tensor_meta = tree_map_only( |
43 | 43 | torch.Tensor, lambda x: TensorMeta(x.shape, x.stride(), x.dtype), out_t |
44 | 44 | ) |
45 | 45 |
|
46 | | - tensor_metas = tree_flatten(user_args)[0] |
47 | | - tensor_metas = tree_map_only( |
48 | | - torch.Tensor, lambda x: TensorMeta(x.shape, x.stride(), x.dtype), tensor_metas |
| 46 | + input_tensor_metas = tree_flatten(user_args)[0] |
| 47 | + input_tensor_metas = tree_map_only( |
| 48 | + torch.Tensor, |
| 49 | + lambda x: TensorMeta(x.shape, x.stride(), x.dtype), |
| 50 | + input_tensor_metas, |
| 51 | + ) |
| 52 | + input_tensor_metas = tuple( |
| 53 | + x for x in input_tensor_metas if isinstance(x, TensorMeta) |
49 | 54 | ) |
50 | | - tensor_metas = tuple(x for x in tensor_metas if isinstance(x, TensorMeta)) |
| 55 | + return out_tensor_meta, input_tensor_metas |
| 56 | + |
| 57 | + |
| 58 | +def propagate_tensor_meta(op, user_args, user_kwargs, out_strat): |
| 59 | + new_tensor_meta, tensor_metas = _get_meta_tensors_for_op(op, user_args, user_kwargs) |
51 | 60 |
|
52 | 61 | for strat in out_strat.strategies: |
53 | 62 | if isinstance(new_tensor_meta, TensorMeta): |
@@ -120,6 +129,7 @@ def fill_missing_redistribute_cost(op, specs, out_strat): |
120 | 129 | torch.ops.aten.empty_like.default, |
121 | 130 | torch.ops.prims.convert_element_type.default, |
122 | 131 | torch.ops.aten.slice.Tensor, |
| 132 | + torch.ops.aten.select.int, |
123 | 133 | } |
124 | 134 | assert op in handled_ops, f"got {op}, supported ops here are {handled_ops}" |
125 | 135 | # assert len(specs) == 1, f"Expected len(specs) == 1, got {len(specs)}" |
|
0 commit comments