
Commit 266e639

apaszke authored and Google-ML-Automation committed
[Mosaic TPU] Add support for slicing in reshape->store fusion optimization
PiperOrigin-RevId: 829359950
1 parent 062e6a4 commit 266e639

2 files changed: +43 -9 lines changed
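For orientation: the optimization this commit extends rewrites a reshape feeding a store into a loop of strided stores over an i32 view of the destination ref. The sketch below is not the Mosaic implementation; it only reproduces the scalar shape arithmetic visible in the second file's diff (lane width, packing, sublane product, stride), with hypothetical example values.

// Illustrative only: mirrors the stride arithmetic from the diff below with
// hypothetical example values; it is not the Mosaic TPU implementation.
#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  const int64_t lane = 128;   // lane width used by the diff
  const int64_t packing = 2;  // e.g. 16-bit elements packed into i32 (hypothetical)

  // Hypothetical source vector shape before the reshape (flattened minor dim).
  std::vector<int64_t> src_shape = {8, 1024};

  // sublane_prod and stride as computed in optimizeStore.
  const int64_t sublane_prod = src_shape.back() / lane;  // 1024 / 128 = 8
  const int64_t stride = sublane_prod / packing;         // 8 / 2 = 4

  std::cout << "sublane_prod = " << sublane_prod
            << ", stride = " << stride << "\n";
  // The fused store then emits `stride` strided stores, one per slice.
  return 0;
}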

jaxlib/mosaic/dialect/tpu/transforms/canonicalize_mosaic.cc

Lines changed: 2 additions & 0 deletions
@@ -1592,6 +1592,8 @@ FailureOr<Value> canonicalize_reshape(const CanonicalizeContext &ctx,
       MemRefType::get(mem_shape, b.getI32Type()), reshaped_ref);

   // Define the shape of the small i32 chunk we will load in each iteration.
+  // TODO(b/458291444): The loads we emit here might use suboptimal shapes and
+  // we could do better by folding some dims (as much as slicing allows).
   SmallVector<int64_t> chunk_shape(src_shape.drop_back(folded_dims));
   if (chunk_shape.empty()) {
     chunk_shape.push_back(1);
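The new TODO notes that the chunk shape chosen here can be suboptimal. As a minimal illustration of the surrounding computation (drop the folded trailing dims, keep at least rank 1), here is a standalone sketch with hypothetical src_shape and folded_dims values:

// Sketch of the chunk_shape computation around the new TODO; the shape and
// folded_dims value are hypothetical, chosen only to illustrate drop_back.
#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  std::vector<int64_t> src_shape = {4, 8, 1024};  // hypothetical source shape
  const int64_t folded_dims = 2;                  // trailing dims folded away

  std::vector<int64_t> chunk_shape(src_shape.begin(),
                                   src_shape.end() - folded_dims);
  if (chunk_shape.empty()) chunk_shape.push_back(1);  // keep at least rank 1

  for (int64_t d : chunk_shape) std::cout << d << " ";  // prints: 4
  std::cout << "\n";
  return 0;
}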

jaxlib/mosaic/dialect/tpu/transforms/pre_canonicalization_optimization.cc

Lines changed: 41 additions & 9 deletions
@@ -22,6 +22,7 @@ limitations under the License.
 #include <tuple>
 #include <utility>

+#include "absl/log/check.h"
 #include "llvm/ADT/ArrayRef.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SmallVector.h"
@@ -140,8 +141,9 @@ void optimizeStore(int hardware_generation, std::array<int64_t, 2> target_shape,
   }
   DCHECK_GE(expanded_dims, 2);  // Minor should expand at least into 2 dims.

-  // TODO(mvoz,apaszke): Add slicing support for stores (analogous to loads).
-  if (tgt_ty.getShape() != ref_ty.getShape()) {
+  // We don't support slicing in the expanded dims at the moment.
+  if (tgt_ty.getShape().take_back(expanded_dims) !=
+      ref_ty.getShape().take_back(expanded_dims)) {
     return;
   }
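The rewritten check compares only the trailing expanded_dims dimensions of the target and ref shapes, so slicing in the leading dims no longer disables the optimization, while slicing inside the expanded dims still bails out. A standalone sketch of that comparison with hypothetical shapes (not the actual ArrayRef-based code):

// Minimal sketch of the take_back(expanded_dims) comparison added in this
// hunk; all shapes are hypothetical examples.
#include <algorithm>
#include <cstdint>
#include <iostream>
#include <vector>

bool expanded_dims_match(const std::vector<int64_t>& tgt,
                         const std::vector<int64_t>& ref,
                         int64_t expanded_dims) {
  // Compare only the trailing `expanded_dims` dimensions, mirroring
  // ArrayRef::take_back.
  return std::equal(tgt.end() - expanded_dims, tgt.end(),
                    ref.end() - expanded_dims);
}

int main() {
  const int64_t expanded_dims = 2;
  // Slicing only in the leading dim: the trailing dims still match.
  std::cout << expanded_dims_match({4, 8, 128}, {16, 8, 128}, expanded_dims)
            << "\n";  // prints 1
  // Slicing inside the expanded dims: the optimization bails out.
  std::cout << expanded_dims_match({4, 4, 128}, {16, 8, 128}, expanded_dims)
            << "\n";  // prints 0
  return 0;
}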

@@ -160,7 +162,10 @@ void optimizeStore(int hardware_generation, std::array<int64_t, 2> target_shape,
     src_shape = src_ty.getShape();
   }

-  SmallVector<int64_t> mem_shape(src_ty.getShape().drop_back(1));
+  SmallVector<int64_t> mem_shape(ref_ty.getShape().drop_back(expanded_dims));
+  if (mem_shape.empty()) {
+    mem_shape.push_back(1);
+  }
   mem_shape.back() *= src_shape.back() / lane;
   mem_shape.push_back(lane);
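mem_shape is now derived from the ref shape rather than the source vector: the trailing expanded dims are dropped, the last remaining dim is multiplied by the number of sublanes in the flattened source minor dim, and a lane-sized minor dim is appended. A sketch of that computation with hypothetical shapes:

// Sketch of the new mem_shape computation; ref_shape, src_shape and
// expanded_dims are hypothetical examples, the arithmetic follows the diff.
#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  const int64_t lane = 128;
  const int64_t expanded_dims = 2;
  std::vector<int64_t> ref_shape = {16, 8, 128};  // full (unsliced) ref shape
  std::vector<int64_t> src_shape = {4, 1024};     // flattened source vector

  std::vector<int64_t> mem_shape(ref_shape.begin(),
                                 ref_shape.end() - expanded_dims);
  if (mem_shape.empty()) mem_shape.push_back(1);
  mem_shape.back() *= src_shape.back() / lane;  // fold sublanes into 2nd minor
  mem_shape.push_back(lane);

  for (int64_t d : mem_shape) std::cout << d << " ";  // prints: 128 128
  std::cout << "\n";
  return 0;
}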

@@ -170,20 +175,37 @@ void optimizeStore(int hardware_generation, std::array<int64_t, 2> target_shape,
   Value i32_view = b.create<tpu::MemRefBitcastOp>(
       MemRefType::get(mem_shape, i32_type), reshaped_ref);

+
   Value src_vec = shape_cast_op.getSource();
-  SmallVector<int64_t> slice_shape(src_ty.getShape());
+  SmallVector<int64_t> slice_shape(src_shape);
   slice_shape.back() = lane;
+  SmallVector<int64_t> slice_strides(slice_shape.size(), 1);
+  SmallVector<int64_t> slice_offsets(slice_shape.size(), 0);

   // We don't support slicing so the program only didn't contain OOB stores
   // if all indices were 0.
-  SmallVector<Value> store_indices(mem_shape.size(), IdxConst(0, b, loc));
   SmallVector<int32_t> store_strides(mem_shape.size(), 1);
   const int64_t sublane_prod = src_shape.back() / lane;
   const int64_t stride = sublane_prod / packing;
   *(store_strides.end() - 2) = stride;

-  SmallVector<int64_t> slice_strides(src_ty.getRank(), 1);
-  SmallVector<int64_t> slice_offsets(src_ty.getRank(), 0);
+  SmallVector<Value> store_indices(indices.drop_back(expanded_dims));
+  if (store_indices.empty()) {
+    store_indices.push_back(IdxConst(0, b, loc));
+  }
+  Value second_minor_base = b.create<arith::MulIOp>(
+      store_indices.back(),
+      b.create<arith::ConstantOp>(b.getIndexType(),
+                                  b.getI32IntegerAttr(stride)));
+  store_indices.back() = nullptr;
+  store_indices.push_back(IdxConst(0, b, loc));
+  SmallVector<int64_t> to_store_shape(tgt_shape.drop_back(expanded_dims));
+  if (to_store_shape.empty()) {
+    to_store_shape.push_back(1);
+  }
+  to_store_shape.push_back(lane);
+  auto store_vty = VectorType::get(to_store_shape, b.getI32Type());
+
   for (int64_t i = 0; i < stride; ++i) {
     slice_offsets.back() = i * packing * lane;
     Value slice_i = b.create<vector::ExtractStridedSliceOp>(
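Because the leading store indices are now kept, the second-minor index of each emitted strided store becomes the original index in the sliced dim scaled by stride, plus the loop counter. A minimal sketch of that index arithmetic with hypothetical values:

// Sketch of the per-iteration second-minor index computed in the loop:
// second_minor_base = leading_index * stride, then + i for each slice.
// All concrete numbers are hypothetical.
#include <cstdint>
#include <iostream>

int main() {
  const int64_t stride = 4;         // sublane_prod / packing, as in the diff
  const int64_t leading_index = 3;  // original store index in the sliced dim

  const int64_t second_minor_base = leading_index * stride;
  for (int64_t i = 0; i < stride; ++i) {
    // Index used by the i-th strided store into the i32 view of the ref.
    std::cout << "store " << i << " -> second minor index "
              << (second_minor_base + i) << "\n";
  }
  return 0;
}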
@@ -217,9 +239,19 @@ void optimizeStore(int hardware_generation, std::array<int64_t, 2> target_shape,
         VectorType::get(slice_shape, i32_type), slice_i);
     }

-    Value chunk_to_store = packed_chunk;
+    // TODO(b/458291444): This reshape might end up being non-trivial and might
+    // produce a vector with an unnecessarily bad layout. Consider the case
+    // where src_shape is (24, 1024) and tgt_shape is (8, 3, 8, 128). In that
+    // case slice_shape is (24, 128), which can be neatly packed into vregs,
+    // but here we would reshape to (8, 3, 128), which of course is problematic
+    // and will introduce lots of padding... We could work around this by
+    // flattening the ref dimensions, but it is complicated by non-contiguous
+    // slices which might prevent this. In case we find a non-contiguous slice
+    // we could still try unrolling into multiple strided stores.
+    Value chunk_to_store = b.create<tpu::ReshapeOp>(store_vty, packed_chunk);
     CHECK_GE(store_indices.size(), 2);
-    *(store_indices.end() - 2) = IdxConst(i, b, loc);
+    *(store_indices.end() - 2) =
+        b.create<arith::AddIOp>(second_minor_base, IdxConst(i, b, loc));
     b.create<tpu::StridedStoreOp>(chunk_to_store, i32_view, store_indices,
                                   store_strides);
   }