
Commit 10bb57a

apaszke authored and Google-ML-Automation committed
[Mosaic TPU][NFC] Share checks between the reshape/store and load/reshape optimization rules
They support exactly the same cases and are symmetrical.

PiperOrigin-RevId: 830814515
1 parent 42e56cf commit 10bb57a
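
To illustrate the pattern this commit applies, here is a minimal sketch, not the pass code itself: a single predicate validates an (expanded, collapsed) shape pair, and the load/reshape and reshape/store rules invoke it with their source and target vector types in opposite roles. The names reshapeMemoryCheck, loadReshapeApplies, and reshapeStoreApplies and the std::vector-based shapes are hypothetical simplifications of the canOptimizeReshapeMemory helper introduced in the diff below, which additionally checks packing, memref contiguity, and tiling.

#include <cstdint>
#include <vector>

using Shape = std::vector<int64_t>;

// Simplified stand-in for the shared predicate: "expanded" is the shape with
// the minor dims still separate, "collapsed" is the shape after they have been
// flattened into the lane dimension.
bool reshapeMemoryCheck(const Shape& expanded, const Shape& collapsed,
                        int64_t lane = 128) {
  if (expanded.size() < 2 || collapsed.empty()) return false;
  if (expanded.back() != lane) return false;        // minor dim must be #lanes
  if (collapsed.back() % lane != 0) return false;   // fold whole lanes only
  if (collapsed.back() == expanded.back()) return false;  // shape must change
  return true;
}

// load + reshape: the loaded vector (src) carries the expanded shape.
bool loadReshapeApplies(const Shape& src, const Shape& tgt) {
  return reshapeMemoryCheck(/*expanded=*/src, /*collapsed=*/tgt);
}

// reshape + store: the same predicate, with the two shapes in swapped roles.
bool reshapeStoreApplies(const Shape& src, const Shape& tgt) {
  return reshapeMemoryCheck(/*expanded=*/tgt, /*collapsed=*/src);
}

Because both rules now funnel through one helper, any future tightening or relaxation of these legality conditions applies to the load and store paths at once.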

File tree: 1 file changed, +80 -109 lines

jaxlib/mosaic/dialect/tpu/transforms/pre_canonicalization_optimization.cc

Lines changed: 80 additions & 109 deletions
@@ -50,104 +50,125 @@ namespace mlir::tpu {
 
 namespace {
 
-void optimizeLoadReshape(int hardware_generation,
-                         std::array<int64_t, 2> target_shape,
-                         Operation& raw_op) {
-  // Below, we try to look for reshapes that flatten multiple dims into the
-  // lane dimension. If the source of the reshape originates from a load of a
-  // ref with 128 minor dimension (effectively untiled), we can replace the
-  // load/reshape sequence with an efficient strided load. In essence, the
-  // strided load creates vregs with a narrow slice along the target minor
-  // dimension, but with the 2nd minor dim after the reshape already in
-  // sublanes. The results of strided load can be concatenated to form the
-  // final vector result.
-  //
-  // A little extra care needs to be applied to packed types, which we handle by
-  // briefly extending to 32-bit and repacking them after concatenation.
-  TypedValue<VectorType> src;
-  VectorType tgt_ty;
-  if (auto op = dyn_cast<tpu::ReshapeOp>(&raw_op)) {
-    src = op.getSource();
-    tgt_ty = op.getResult().getType();
-  } else if (auto op = dyn_cast<vector::ShapeCastOp>(&raw_op)) {
-    src = op.getSource();
-    tgt_ty = op.getResult().getType();
-  } else {
-    return;
-  }
-  VectorType src_ty = src.getType();
-  if (src_ty.getRank() < 2 || tgt_ty.getRank() < 1) {
-    return;
+std::optional<int64_t> canOptimizeReshapeMemory(
+    int hardware_generation, std::array<int64_t, 2> target_shape,
+    TypedValue<MemRefType> ref, VectorType expanded_ty,
+    VectorType collapsed_ty) {
+  if (expanded_ty.getRank() < 2 || collapsed_ty.getRank() < 1) {
+    return std::nullopt;
   }
-  const int bitwidth = src_ty.getElementTypeBitWidth();
+  const int bitwidth = expanded_ty.getElementTypeBitWidth();
   const int packing = 32 / bitwidth;
   if (hardware_generation < 4 && packing > 1) {
-    return;
-  }
-
-  auto load_op = dyn_cast_if_present<vector::LoadOp>(src.getDefiningOp());
-  // This rewrite might not be profitable if the load has other users.
-  if (!load_op || !load_op.getBase().hasOneUse()) {
-    return;
+    return std::nullopt;
   }
 
-  TypedValue<MemRefType> ref = load_op.getBase();
-  MemRefType ref_ty = getMemRefType(ref);
   // The reshape below might be invalid if the memref is not contiguous, but it
   // is an overly conservative check (we don't need all dims to be contiguous).
   if (!isContiguousMemref(ref)) {
-    return;
+    return std::nullopt;
   }
 
   const int64_t lane = target_shape[1];
-  auto src_shape = src_ty.getShape();
-  auto tgt_shape = tgt_ty.getShape();
+  int64_t collapsed_minor = collapsed_ty.getShape().back();
+  int64_t expanded_minor = expanded_ty.getShape().back();
   // Only handle the cases where the minor dim starts out as the number of lanes
   // and we fold at least the second minor dim into it, in a way that changes
   // its shape.
-  if (src_shape.back() != lane ||
-      tgt_shape.back() % (packing * lane) != 0 ||
-      tgt_shape.back() == src_shape.back() ||
-      tgt_shape.back() < llvm::product_of(src_shape.take_back(2))) {
-    return;
+  if (expanded_minor != lane ||
+      collapsed_minor % (packing * lane) != 0 ||
+      collapsed_minor == expanded_minor ||
+      collapsed_minor < llvm::product_of(expanded_ty.getShape().take_back(2))) {
+    return std::nullopt;
   }
 
   // We don't handle memrefs with padding.
+  MemRefType ref_ty = getMemRefType(ref);
   auto tiled_layout = dyn_cast<tpu::TiledLayoutAttr>(ref_ty.getLayout());
   if (!tiled_layout || tiled_layout.getTiles().empty()) {
-    return;
+    return std::nullopt;
   }
   ArrayRef<int64_t> front_tile = tiled_layout.getTiles().front().dimensions();
   ArrayRef<int64_t> ref_tiled_shape =
       ref_ty.getShape().take_back(front_tile.size());
   for (int i = 0; i < front_tile.size(); ++i) {
     if (ref_tiled_shape[i] % front_tile[i]) {
-      return;
+      return std::nullopt;
     }
   }
 
   // NOTE: We could generalize this to allow only flattening part of a dimension
   int folded_dims = 0;
   {
     int suffix_size = 1;
-    auto sizes_it = src_shape.rbegin();
-    while (suffix_size < tgt_shape.back()) {
+    auto sizes_it = expanded_ty.getShape().rbegin();
+    while (suffix_size < collapsed_minor) {
       suffix_size *= *(sizes_it++);
     }
     // Make sure that the minor dim is folded only from entire major dims, not
     // from a part of some minor dim.
-    if (suffix_size != tgt_shape.back()) {
-      return;
+    if (suffix_size != collapsed_minor) {
+      return std::nullopt;
     }
-    folded_dims = sizes_it - src_shape.rbegin();
+    folded_dims = sizes_it - expanded_ty.getShape().rbegin();
   }
   DCHECK_GE(folded_dims, 2);  // Should fold at least 2nd minor into minor.
 
   // We don't handle slicing in the folded dims at the moment.
   if (ref_ty.getShape().take_back(folded_dims) !=
-      src_ty.getShape().take_back(folded_dims)) {
+      expanded_ty.getShape().take_back(folded_dims)) {
+    return std::nullopt;
+  }
+
+  return folded_dims;
+}
+
+void optimizeLoadReshape(int hardware_generation,
+                         std::array<int64_t, 2> target_shape,
+                         Operation& raw_op) {
+  // Below, we try to look for reshapes that flatten multiple dims into the
+  // lane dimension. If the source of the reshape originates from a load of a
+  // ref with 128 minor dimension (effectively untiled), we can replace the
+  // load/reshape sequence with an efficient strided load. In essence, the
+  // strided load creates vregs with a narrow slice along the target minor
+  // dimension, but with the 2nd minor dim after the reshape already in
+  // sublanes. The results of strided load can be concatenated to form the
+  // final vector result.
+  //
+  // A little extra care needs to be applied to packed types, which we handle by
+  // briefly extending to 32-bit and repacking them after concatenation.
+  TypedValue<VectorType> src;
+  VectorType tgt_ty;
+  if (auto op = dyn_cast<tpu::ReshapeOp>(&raw_op)) {
+    src = op.getSource();
+    tgt_ty = op.getResult().getType();
+  } else if (auto op = dyn_cast<vector::ShapeCastOp>(&raw_op)) {
+    src = op.getSource();
+    tgt_ty = op.getResult().getType();
+  } else {
     return;
   }
+  VectorType src_ty = src.getType();
+  ArrayRef<int64_t> src_shape = src_ty.getShape();
+  ArrayRef<int64_t> tgt_shape = tgt_ty.getShape();
+  const int lane = target_shape[1];
+  const int bitwidth = src_ty.getElementTypeBitWidth();
+  const int packing = 32 / bitwidth;
+
+  auto load_op = dyn_cast_if_present<vector::LoadOp>(src.getDefiningOp());
+  // This rewrite might not be profitable if the load has other users.
+  if (!load_op || !load_op.getBase().hasOneUse()) {
+    return;
+  }
+  TypedValue<MemRefType> ref = load_op.getBase();
+  MemRefType ref_ty = getMemRefType(ref);
+
+  auto maybe_folded_dims = canOptimizeReshapeMemory(
+      hardware_generation, target_shape, ref, src_ty, tgt_ty);
+  if (!maybe_folded_dims.has_value()) {
+    return;
+  }
+  int folded_dims = *maybe_folded_dims;
 
   Location loc = raw_op.getLoc();
   ImplicitLocOpBuilder b(loc, &raw_op);
@@ -277,68 +298,18 @@ void optimizeStore(int hardware_generation, std::array<int64_t, 2> target_shape,
   MemRefType ref_ty = getMemRefType(base);
   VectorType src_ty = shape_cast_op.getSource().getType();
   VectorType tgt_ty = shape_cast_op.getResult().getType();
-  if (src_ty.getRank() < 1 || tgt_ty.getRank() < 2) {
-    return;
-  }
   auto src_shape = src_ty.getShape();
   auto tgt_shape = tgt_ty.getShape();
-
+  const int64_t lane = target_shape[1];
   const int bitwidth = src_ty.getElementTypeBitWidth();
   const int packing = 32 / bitwidth;
-  if (hardware_generation < 4 && packing > 1) {
-    return;
-  }
-
-  // The reshape below might be invalid if the memref is not contiguous, but it
-  // is an overly conservative check (we don't need all dims to be contiguous).
-  if (!isContiguousMemref(base)) {
-    return;
-  }
-  const int64_t lane = target_shape[1];
-  // Only handle the cases where the minor dim starts out as the number of lanes
-  // and we fold at least the second minor dim into it, in a way that changes
-  // its shape.
-  if (tgt_shape.back() != lane ||
-      src_shape.back() % (packing * lane) != 0 ||
-      src_shape.back() == tgt_shape.back() ||
-      src_shape.back() < llvm::product_of(tgt_shape.take_back(2))) {
-    return;
-  }
-  // We don't handle memrefs with padding.
-  auto tiled_layout = dyn_cast<tpu::TiledLayoutAttr>(ref_ty.getLayout());
-  if (!tiled_layout || tiled_layout.getTiles().empty()) {
-    return;
-  }
-  ArrayRef<int64_t> front_tile = tiled_layout.getTiles().front().dimensions();
-  ArrayRef<int64_t> ref_tiled_shape =
-      ref_ty.getShape().take_back(front_tile.size());
-  for (int i = 0; i < front_tile.size(); ++i) {
-    if (ref_tiled_shape[i] % front_tile[i]) {
-      return;
-    }
-  }
-
-  int expanded_dims = 0;
-  {
-    int suffix_size = 1;
-    auto sizes_it = tgt_shape.rbegin();
-    while (suffix_size < src_shape.back()) {
-      suffix_size *= *(sizes_it++);
-    }
-    // Make sure the minor dim is expanded into its own dims and not folded into
-    // other major dims.
-    if (suffix_size != src_shape.back()) {
-      return;
-    }
-    expanded_dims = sizes_it - tgt_shape.rbegin();
-  }
-  DCHECK_GE(expanded_dims, 2);  // Minor should expand at least into 2 dims.
 
-  // We don't support slicing in the expanded dims at the moment.
-  if (tgt_ty.getShape().take_back(expanded_dims) !=
-      ref_ty.getShape().take_back(expanded_dims)) {
+  std::optional<int> maybe_expanded_dims = canOptimizeReshapeMemory(
+      hardware_generation, target_shape, base, tgt_ty, src_ty);
+  if (!maybe_expanded_dims.has_value()) {
     return;
   }
+  int expanded_dims = *maybe_expanded_dims;
 
   ImplicitLocOpBuilder b(raw_op.getLoc(), &raw_op);
   auto loc = raw_op.getLoc();
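
For intuition about the folded_dims / expanded_dims value the shared check returns, here is a small, self-contained sketch with hypothetical shapes; foldedDims mirrors the suffix-product loop in the helper above but is not code from this file, and it ignores packing, tiling, and contiguity.

#include <cassert>
#include <cstdint>
#include <vector>

// How many trailing dims of the expanded shape fold exactly into the collapsed
// minor dimension? Returns -1 when the fold would split a dimension, in which
// case the rewrite does not apply.
int foldedDims(const std::vector<int64_t>& expanded, int64_t collapsed_minor) {
  int64_t suffix_size = 1;
  int folded = 0;
  auto sizes_it = expanded.rbegin();
  while (suffix_size < collapsed_minor && sizes_it != expanded.rend()) {
    suffix_size *= *(sizes_it++);
    ++folded;
  }
  return suffix_size == collapsed_minor ? folded : -1;
}

int main() {
  const std::vector<int64_t> expanded = {4, 8, 128};
  // (4, 8, 128) -> (4, 1024): 1024 == 8 * 128, so exactly the two trailing
  // dims fold into the minor dim and the check reports 2 folded dims.
  assert(foldedDims(expanded, 1024) == 2);
  // (4, 8, 128) -> (8, 512): 512 would split the second-minor dim
  // (8 * 128 == 1024), so the rule bails out, matching the comment about
  // folding only entire major dims.
  assert(foldedDims(expanded, 512) == -1);
  return 0;
}

The same count is what optimizeLoadReshape and optimizeStore unpack from the helper as folded_dims and expanded_dims, respectively.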
