@@ -22,6 +22,7 @@ limitations under the License.
 #include <tuple>
 #include <utility>
 
+#include "absl/log/check.h"
 #include "llvm/ADT/ArrayRef.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SmallVector.h"
@@ -140,8 +141,9 @@ void optimizeStore(int hardware_generation, std::array<int64_t, 2> target_shape,
   }
   DCHECK_GE(expanded_dims, 2);  // Minor should expand at least into 2 dims.
 
-  // TODO(mvoz,apaszke): Add slicing support for stores (analogous to loads).
-  if (tgt_ty.getShape() != ref_ty.getShape()) {
+  // We don't support slicing in the expanded dims at the moment.
+  if (tgt_ty.getShape().take_back(expanded_dims) !=
+      ref_ty.getShape().take_back(expanded_dims)) {
     return;
   }
 
@@ -160,7 +162,10 @@ void optimizeStore(int hardware_generation, std::array<int64_t, 2> target_shape,
     src_shape = src_ty.getShape();
   }
 
-  SmallVector<int64_t> mem_shape(src_ty.getShape().drop_back(1));
+  SmallVector<int64_t> mem_shape(ref_ty.getShape().drop_back(expanded_dims));
+  if (mem_shape.empty()) {
+    mem_shape.push_back(1);
+  }
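+  // Fold the minormost data into the last remaining dim so the view becomes
+  // rows of `lane` elements.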
   mem_shape.back() *= src_shape.back() / lane;
   mem_shape.push_back(lane);
 
@@ -170,20 +175,37 @@ void optimizeStore(int hardware_generation, std::array<int64_t, 2> target_shape,
   Value i32_view = b.create<tpu::MemRefBitcastOp>(
       MemRefType::get(mem_shape, i32_type), reshaped_ref);
 
+
   Value src_vec = shape_cast_op.getSource();
-  SmallVector<int64_t> slice_shape(src_ty.getShape());
+  SmallVector<int64_t> slice_shape(src_shape);
   slice_shape.back() = lane;
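+  // Strides and offsets for extracting lane-wide slices of the source below.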
+  SmallVector<int64_t> slice_strides(slice_shape.size(), 1);
+  SmallVector<int64_t> slice_offsets(slice_shape.size(), 0);
 
   // We don't support slicing so the program only didn't contain OOB stores
   // if all indices were 0.
-  SmallVector<Value> store_indices(mem_shape.size(), IdxConst(0, b, loc));
   SmallVector<int32_t> store_strides(mem_shape.size(), 1);
   const int64_t sublane_prod = src_shape.back() / lane;
   const int64_t stride = sublane_prod / packing;
   *(store_strides.end() - 2) = stride;
 
-  SmallVector<int64_t> slice_strides(src_ty.getRank(), 1);
-  SmallVector<int64_t> slice_offsets(src_ty.getRank(), 0);
+  SmallVector<Value> store_indices(indices.drop_back(expanded_dims));
+  if (store_indices.empty()) {
+    store_indices.push_back(IdxConst(0, b, loc));
+  }
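+  // Starting second-minor index into the i32 view for this store; the loop
+  // below sets it to second_minor_base + i for each chunk, while the
+  // minormost (lane) index stays 0.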
+  Value second_minor_base = b.create<arith::MulIOp>(
+      store_indices.back(), IdxConst(stride, b, loc));
+  store_indices.back() = nullptr;
+  store_indices.push_back(IdxConst(0, b, loc));
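+  // Shape of each stored chunk: the leading (unexpanded) target dims followed
+  // by a trailing lane dim.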
+  SmallVector<int64_t> to_store_shape(tgt_shape.drop_back(expanded_dims));
+  if (to_store_shape.empty()) {
+    to_store_shape.push_back(1);
+  }
+  to_store_shape.push_back(lane);
+  auto store_vty = VectorType::get(to_store_shape, b.getI32Type());
+
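+  // Each iteration packs one lane-wide portion of the source into an i32
+  // chunk and writes it into the bitcast view with a strided store.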
   for (int64_t i = 0; i < stride; ++i) {
     slice_offsets.back() = i * packing * lane;
     Value slice_i = b.create<vector::ExtractStridedSliceOp>(
@@ -217,9 +239,19 @@ void optimizeStore(int hardware_generation, std::array<int64_t, 2> target_shape,
           VectorType::get(slice_shape, i32_type), slice_i);
     }
 
-    Value chunk_to_store = packed_chunk;
+    // TODO(b/458291444): This reshape might end up being non-trivial and
+    // might produce a vector with an unnecessarily bad layout. Consider the
+    // case where src_shape is (24, 1024) and tgt_shape is (8, 3, 8, 128). In
+    // that case slice_shape is (24, 128), which can be neatly packed into
+    // vregs, but here we would reshape to (8, 3, 128), which of course is
+    // problematic and will introduce lots of padding... We could work around
+    // this by flattening the ref dimensions, but it is complicated by
+    // non-contiguous slices which might prevent this. In case we find a
+    // non-contiguous slice we could still try unrolling into multiple strided
+    // stores.
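+    // (In that example 1024 == 8 * 128, so expanded_dims is 2 and
+    // to_store_shape is tgt_shape.drop_back(2) plus the lane dim: (8, 3, 128).)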
+    Value chunk_to_store = b.create<tpu::ReshapeOp>(store_vty, packed_chunk);
     CHECK_GE(store_indices.size(), 2);
-    *(store_indices.end() - 2) = IdxConst(i, b, loc);
+    *(store_indices.end() - 2) =
+        b.create<arith::AddIOp>(second_minor_base, IdxConst(i, b, loc));
     b.create<tpu::StridedStoreOp>(chunk_to_store, i32_view, store_indices,
                                   store_strides);
   }