@@ -164,11 +164,13 @@ def _get_sharded_shape(spec):
164164 # TODO: find a better heuristic other than
165165 # running DTensor
166166 new_tensor_shape = list (tensor_shape )
167+ new_tensor_stride = list (spec .tensor_meta .stride )
167168 for mesh_size , placement in zip (mesh .shape , placements ):
168169 if placement .is_shard ():
169170 dim = placement .dim
170171 new_tensor_shape [dim ] = (new_tensor_shape [dim ] + mesh_size - 1 ) // mesh_size
171- return new_tensor_shape
172+ new_tensor_stride [dim ] = (new_tensor_stride [dim ] + mesh_size - 1 ) // mesh_size
173+ return new_tensor_shape , new_tensor_stride
172174
173175
174176def estimate_strategy_runtime_cost (node , strategy ):
@@ -191,15 +193,16 @@ def estimate_strategy_runtime_cost(node, strategy):
191193 if len (kwargs ) > 0 :
192194 for k , v in kwargs .items ():
193195 assert not isinstance (v , torch .Tensor ), f"{ node } { v } "
194- args_shapes = tuple (_get_sharded_shape (spec ) for spec in strategy .input_specs )
196+ args_sizes_strides = tuple (_get_sharded_shape (spec ) for spec in strategy .input_specs )
195197
196198 counter = 0
197199 args = list (args )
198200 for i , arg in enumerate (args ):
199201 if isinstance (arg , torch .Tensor ):
200202 with fake_mode :
201- args [i ] = torch .empty (
202- args_shapes [counter ], device = arg .device , dtype = arg .dtype
203+ sizes , strides = args_sizes_strides [counter ]
204+ args [i ] = torch .empty_strided (
205+ sizes , strides , device = arg .device , dtype = arg .dtype
203206 )
204207 counter += 1
205208
0 commit comments