Skip to content

Commit 79376e2

Browse files
committed
preserve tensor striding during compute estimation
1 parent d5edb93 commit 79376e2

File tree

1 file changed

+7
-4
lines changed

1 file changed

+7
-4
lines changed

autoparallel/compute_estimation.py

Lines changed: 7 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -164,11 +164,13 @@ def _get_sharded_shape(spec):
164164
# TODO: find a better heuristic other than
165165
# running DTensor
166166
new_tensor_shape = list(tensor_shape)
167+
new_tensor_stride = list(spec.tensor_meta.stride)
167168
for mesh_size, placement in zip(mesh.shape, placements):
168169
if placement.is_shard():
169170
dim = placement.dim
170171
new_tensor_shape[dim] = (new_tensor_shape[dim] + mesh_size - 1) // mesh_size
171-
return new_tensor_shape
172+
new_tensor_stride[dim] = (new_tensor_stride[dim] + mesh_size - 1) // mesh_size
173+
return new_tensor_shape, new_tensor_stride
172174

173175

174176
def estimate_strategy_runtime_cost(node, strategy):
@@ -191,15 +193,16 @@ def estimate_strategy_runtime_cost(node, strategy):
191193
if len(kwargs) > 0:
192194
for k, v in kwargs.items():
193195
assert not isinstance(v, torch.Tensor), f"{node} {v}"
194-
args_shapes = tuple(_get_sharded_shape(spec) for spec in strategy.input_specs)
196+
args_sizes_strides = tuple(_get_sharded_shape(spec) for spec in strategy.input_specs)
195197

196198
counter = 0
197199
args = list(args)
198200
for i, arg in enumerate(args):
199201
if isinstance(arg, torch.Tensor):
200202
with fake_mode:
201-
args[i] = torch.empty(
202-
args_shapes[counter], device=arg.device, dtype=arg.dtype
203+
sizes, strides = args_sizes_strides[counter]
204+
args[i] = torch.empty_strided(
205+
sizes, strides, device=arg.device, dtype=arg.dtype
203206
)
204207
counter += 1
205208

0 commit comments

Comments
 (0)