From 10bb57ac33e6b7b6a201baed35d4035f26d9819f Mon Sep 17 00:00:00 2001
From: Adam Paszke
Date: Tue, 11 Nov 2025 02:08:50 -0800
Subject: [PATCH] [Mosaic TPU][NFC] Share checks between the reshape/store and
 load/reshape optimization rules

They support exactly the same cases and are symmetrical.

PiperOrigin-RevId: 830814515
---
 .../pre_canonicalization_optimization.cc      | 189 ++++++++----------
 1 file changed, 80 insertions(+), 109 deletions(-)

diff --git a/jaxlib/mosaic/dialect/tpu/transforms/pre_canonicalization_optimization.cc b/jaxlib/mosaic/dialect/tpu/transforms/pre_canonicalization_optimization.cc
index d4407cf19cfa..d1abb89afa0f 100644
--- a/jaxlib/mosaic/dialect/tpu/transforms/pre_canonicalization_optimization.cc
+++ b/jaxlib/mosaic/dialect/tpu/transforms/pre_canonicalization_optimization.cc
@@ -50,79 +50,50 @@ namespace mlir::tpu {
 
 namespace {
 
-void optimizeLoadReshape(int hardware_generation,
-                         std::array<int64_t, 2> target_shape,
-                         Operation& raw_op) {
-  // Below, we try to look for reshapes that flatten multiple dims into the
-  // lane dimension. If the source of the reshape originates from a load of a
-  // ref with 128 minor dimension (effectively untiled), we can replace the
-  // load/reshape sequence with an efficient strided load. In essence, the
-  // strided load creates vregs with a narrow slice along the target minor
-  // dimension, but with the 2nd minor dim after the reshape already in
-  // sublanes. The results of strided load can be concatenated to form the
-  // final vector result.
-  //
-  // A little extra care needs to be applied to packed types, which we handle by
-  // briefly extending to 32-bit and repacking them after concatenation.
-  TypedValue<VectorType> src;
-  VectorType tgt_ty;
-  if (auto op = dyn_cast(&raw_op)) {
-    src = op.getSource();
-    tgt_ty = op.getResult().getType();
-  } else if (auto op = dyn_cast(&raw_op)) {
-    src = op.getSource();
-    tgt_ty = op.getResult().getType();
-  } else {
-    return;
-  }
-  VectorType src_ty = src.getType();
-  if (src_ty.getRank() < 2 || tgt_ty.getRank() < 1) {
-    return;
+std::optional<int> canOptimizeReshapeMemory(
+    int hardware_generation, std::array<int64_t, 2> target_shape,
+    TypedValue<MemRefType> ref, VectorType expanded_ty,
+    VectorType collapsed_ty) {
+  if (expanded_ty.getRank() < 2 || collapsed_ty.getRank() < 1) {
+    return std::nullopt;
   }
-  const int bitwidth = src_ty.getElementTypeBitWidth();
+  const int bitwidth = expanded_ty.getElementTypeBitWidth();
   const int packing = 32 / bitwidth;
   if (hardware_generation < 4 && packing > 1) {
-    return;
-  }
-
-  auto load_op = dyn_cast_if_present(src.getDefiningOp());
-  // This rewrite might not be profitable if the load has other users.
-  if (!load_op || !load_op.getBase().hasOneUse()) {
-    return;
+    return std::nullopt;
   }
-  TypedValue<MemRefType> ref = load_op.getBase();
-  MemRefType ref_ty = getMemRefType(ref);
   // The reshape below might be invalid if the memref is not contiguous, but it
   // is an overly conservative check (we don't need all dims to be contiguous).
   if (!isContiguousMemref(ref)) {
-    return;
+    return std::nullopt;
   }
   const int64_t lane = target_shape[1];
-  auto src_shape = src_ty.getShape();
-  auto tgt_shape = tgt_ty.getShape();
+  int64_t collapsed_minor = collapsed_ty.getShape().back();
+  int64_t expanded_minor = expanded_ty.getShape().back();
   // Only handle the cases where the minor dim starts out as the number of lanes
   // and we fold at least the second minor dim into it, in a way that changes
   // its shape.
-  if (src_shape.back() != lane ||
-      tgt_shape.back() % (packing * lane) != 0 ||
-      tgt_shape.back() == src_shape.back() ||
-      tgt_shape.back() < llvm::product_of(src_shape.take_back(2))) {
-    return;
+  if (expanded_minor != lane ||
+      collapsed_minor % (packing * lane) != 0 ||
+      collapsed_minor == expanded_minor ||
+      collapsed_minor < llvm::product_of(expanded_ty.getShape().take_back(2))) {
+    return std::nullopt;
   }
   // We don't handle memrefs with padding.
+  MemRefType ref_ty = getMemRefType(ref);
   auto tiled_layout = dyn_cast<tpu::TiledLayoutAttr>(ref_ty.getLayout());
   if (!tiled_layout || tiled_layout.getTiles().empty()) {
-    return;
+    return std::nullopt;
   }
   ArrayRef<int64_t> front_tile = tiled_layout.getTiles().front().dimensions();
   ArrayRef<int64_t> ref_tiled_shape =
       ref_ty.getShape().take_back(front_tile.size());
   for (int i = 0; i < front_tile.size(); ++i) {
     if (ref_tiled_shape[i] % front_tile[i]) {
-      return;
+      return std::nullopt;
     }
   }
 
@@ -130,24 +101,74 @@ void optimizeLoadReshape(int hardware_generation,
   int folded_dims = 0;
   {
     int suffix_size = 1;
-    auto sizes_it = src_shape.rbegin();
-    while (suffix_size < tgt_shape.back()) {
+    auto sizes_it = expanded_ty.getShape().rbegin();
+    while (suffix_size < collapsed_minor) {
       suffix_size *= *(sizes_it++);
     }
     // Make sure that the minor dim is folded only from entire major dims, not
     // from a part of some minor dim.
-    if (suffix_size != tgt_shape.back()) {
-      return;
+    if (suffix_size != collapsed_minor) {
+      return std::nullopt;
     }
-    folded_dims = sizes_it - src_shape.rbegin();
+    folded_dims = sizes_it - expanded_ty.getShape().rbegin();
   }
   DCHECK_GE(folded_dims, 2);  // Should fold at least 2nd minor into minor.
   // We don't handle slicing in the folded dims at the moment.
   if (ref_ty.getShape().take_back(folded_dims) !=
-      src_ty.getShape().take_back(folded_dims)) {
+      expanded_ty.getShape().take_back(folded_dims)) {
+    return std::nullopt;
+  }
+
+  return folded_dims;
+}
+
+void optimizeLoadReshape(int hardware_generation,
+                         std::array<int64_t, 2> target_shape,
+                         Operation& raw_op) {
+  // Below, we try to look for reshapes that flatten multiple dims into the
+  // lane dimension. If the source of the reshape originates from a load of a
+  // ref with 128 minor dimension (effectively untiled), we can replace the
+  // load/reshape sequence with an efficient strided load. In essence, the
+  // strided load creates vregs with a narrow slice along the target minor
+  // dimension, but with the 2nd minor dim after the reshape already in
+  // sublanes. The results of strided load can be concatenated to form the
+  // final vector result.
+  //
+  // A little extra care needs to be applied to packed types, which we handle by
+  // briefly extending to 32-bit and repacking them after concatenation.
+  TypedValue<VectorType> src;
+  VectorType tgt_ty;
+  if (auto op = dyn_cast(&raw_op)) {
+    src = op.getSource();
+    tgt_ty = op.getResult().getType();
+  } else if (auto op = dyn_cast(&raw_op)) {
+    src = op.getSource();
+    tgt_ty = op.getResult().getType();
+  } else {
     return;
   }
+  VectorType src_ty = src.getType();
+  ArrayRef<int64_t> src_shape = src_ty.getShape();
+  ArrayRef<int64_t> tgt_shape = tgt_ty.getShape();
+  const int lane = target_shape[1];
+  const int bitwidth = src_ty.getElementTypeBitWidth();
+  const int packing = 32 / bitwidth;
+
+  auto load_op = dyn_cast_if_present(src.getDefiningOp());
+  // This rewrite might not be profitable if the load has other users.
+  if (!load_op || !load_op.getBase().hasOneUse()) {
+    return;
+  }
+  TypedValue<MemRefType> ref = load_op.getBase();
+  MemRefType ref_ty = getMemRefType(ref);
+
+  auto maybe_folded_dims = canOptimizeReshapeMemory(
+      hardware_generation, target_shape, ref, src_ty, tgt_ty);
+  if (!maybe_folded_dims.has_value()) {
+    return;
+  }
+  int folded_dims = *maybe_folded_dims;
 
   Location loc = raw_op.getLoc();
   ImplicitLocOpBuilder b(loc, &raw_op);
@@ -277,68 +298,18 @@ void optimizeStore(int hardware_generation, std::array<int64_t, 2> target_shape,
   MemRefType ref_ty = getMemRefType(base);
   VectorType src_ty = shape_cast_op.getSource().getType();
   VectorType tgt_ty = shape_cast_op.getResult().getType();
-  if (src_ty.getRank() < 1 || tgt_ty.getRank() < 2) {
-    return;
-  }
   auto src_shape = src_ty.getShape();
   auto tgt_shape = tgt_ty.getShape();
-
+  const int64_t lane = target_shape[1];
   const int bitwidth = src_ty.getElementTypeBitWidth();
   const int packing = 32 / bitwidth;
-  if (hardware_generation < 4 && packing > 1) {
-    return;
-  }
-
-  // The reshape below might be invalid if the memref is not contiguous, but it
-  // is an overly conservative check (we don't need all dims to be contiguous).
-  if (!isContiguousMemref(base)) {
-    return;
-  }
-  const int64_t lane = target_shape[1];
-  // Only handle the cases where the minor dim starts out as the number of lanes
-  // and we fold at least the second minor dim into it, in a way that changes
-  // its shape.
-  if (tgt_shape.back() != lane ||
-      src_shape.back() % (packing * lane) != 0 ||
-      src_shape.back() == tgt_shape.back() ||
-      src_shape.back() < llvm::product_of(tgt_shape.take_back(2))) {
-    return;
-  }
-  // We don't handle memrefs with padding.
-  auto tiled_layout = dyn_cast<tpu::TiledLayoutAttr>(ref_ty.getLayout());
-  if (!tiled_layout || tiled_layout.getTiles().empty()) {
-    return;
-  }
-  ArrayRef<int64_t> front_tile = tiled_layout.getTiles().front().dimensions();
-  ArrayRef<int64_t> ref_tiled_shape =
-      ref_ty.getShape().take_back(front_tile.size());
-  for (int i = 0; i < front_tile.size(); ++i) {
-    if (ref_tiled_shape[i] % front_tile[i]) {
-      return;
-    }
-  }
-
-  int expanded_dims = 0;
-  {
-    int suffix_size = 1;
-    auto sizes_it = tgt_shape.rbegin();
-    while (suffix_size < src_shape.back()) {
-      suffix_size *= *(sizes_it++);
-    }
-    // Make sure the minor dim is expanded into its own dims and not folded into
-    // other major dims.
-    if (suffix_size != src_shape.back()) {
-      return;
-    }
-    expanded_dims = sizes_it - tgt_shape.rbegin();
-  }
-  DCHECK_GE(expanded_dims, 2);  // Minor should expand at least into 2 dims.
-  // We don't support slicing in the expanded dims at the moment.
-  if (tgt_ty.getShape().take_back(expanded_dims) !=
-      ref_ty.getShape().take_back(expanded_dims)) {
+  std::optional<int> maybe_expanded_dims = canOptimizeReshapeMemory(
+      hardware_generation, target_shape, base, tgt_ty, src_ty);
+  if (!maybe_expanded_dims.has_value()) {
    return;
  }
+  int expanded_dims = *maybe_expanded_dims;
 
   ImplicitLocOpBuilder b(raw_op.getLoc(), &raw_op);
   auto loc = raw_op.getLoc();
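
Reviewer note (not part of the patch): the sketch below restates the shape checks that canOptimizeReshapeMemory shares between the two rules, using plain integer shapes instead of MLIR types. The helper name, the default lane count of 128, and the example shapes are illustrative assumptions; the hardware-generation, memref-contiguity, and tiling/padding checks from the patch are omitted.

// Standalone sketch of the shared reshape qualification checks.
#include <cstdint>
#include <optional>
#include <vector>

// Returns how many trailing dims of `expanded_shape` fold into the collapsed
// minor dimension, or nullopt if the reshape is not of the supported form.
std::optional<int> canOptimizeReshapeSketch(
    const std::vector<int64_t>& expanded_shape, int64_t collapsed_minor,
    int bitwidth, int64_t lane = 128) {
  if (expanded_shape.size() < 2) return std::nullopt;
  const int packing = 32 / bitwidth;
  const int64_t expanded_minor = expanded_shape.back();
  const int64_t second_minor = expanded_shape[expanded_shape.size() - 2];
  // The expanded minor dim must equal the lane count, and the collapsed minor
  // must be a multiple of (packing * lane), different from the expanded minor,
  // and at least the product of the two trailing expanded dims.
  if (expanded_minor != lane || collapsed_minor % (packing * lane) != 0 ||
      collapsed_minor == expanded_minor ||
      collapsed_minor < second_minor * expanded_minor) {
    return std::nullopt;
  }
  // The collapsed minor must be the product of a whole suffix of the expanded
  // shape; folding only part of some dim is not supported.
  int64_t suffix_size = 1;
  int folded_dims = 0;
  for (auto it = expanded_shape.rbegin();
       suffix_size < collapsed_minor && it != expanded_shape.rend(); ++it) {
    suffix_size *= *it;
    ++folded_dims;
  }
  if (suffix_size != collapsed_minor) return std::nullopt;
  return folded_dims;  // e.g. {4, 8, 128} with collapsed minor 1024 gives 2
}

In both real call sites, expanded_ty is the higher-rank vector on the memory side of the reshape (the loaded value in optimizeLoadReshape, the stored value in optimizeStore) and collapsed_ty is the flattened one, which is the symmetry the commit message refers to.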