@@ -50,104 +50,125 @@ namespace mlir::tpu {

namespace {

void optimizeLoadReshape(int hardware_generation,
std::array<int64_t, 2> target_shape,
Operation& raw_op) {
// Below, we look for reshapes that flatten multiple dims into the lane
// dimension. If the source of the reshape is a load of a ref whose minor
// dimension is 128 (effectively untiled), we can replace the
// load/reshape sequence with an efficient strided load. In essence, the
// strided load creates vregs with a narrow slice along the target minor
// dimension, but with the 2nd minor dim after the reshape already in
// sublanes. The results of strided load can be concatenated to form the
// final vector result.
//
// A little extra care needs to be applied to packed types, which we handle by
// briefly extending to 32-bit and repacking them after concatenation.
TypedValue<VectorType> src;
VectorType tgt_ty;
if (auto op = dyn_cast<tpu::ReshapeOp>(&raw_op)) {
src = op.getSource();
tgt_ty = op.getResult().getType();
} else if (auto op = dyn_cast<vector::ShapeCastOp>(&raw_op)) {
src = op.getSource();
tgt_ty = op.getResult().getType();
} else {
return;
}
VectorType src_ty = src.getType();
if (src_ty.getRank() < 2 || tgt_ty.getRank() < 1) {
return;
std::optional<int64_t> canOptimizeReshapeMemory(
int hardware_generation, std::array<int64_t, 2> target_shape,
TypedValue<MemRefType> ref, VectorType expanded_ty,
VectorType collapsed_ty) {
if (expanded_ty.getRank() < 2 || collapsed_ty.getRank() < 1) {
return std::nullopt;
}
const int bitwidth = src_ty.getElementTypeBitWidth();
const int bitwidth = expanded_ty.getElementTypeBitWidth();
const int packing = 32 / bitwidth;
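  // [Editorial note, not part of this change] packing is the number of
  // elements per 32-bit word: 1 for 32-bit types, 2 for 16-bit, 4 for 8-bit.
  // Per the check below, packed (packing > 1) cases are only attempted on
  // hardware_generation >= 4.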
if (hardware_generation < 4 && packing > 1) {
return;
}

auto load_op = dyn_cast_if_present<vector::LoadOp>(src.getDefiningOp());
// This rewrite might not be profitable if the load has other users.
if (!load_op || !load_op.getBase().hasOneUse()) {
return;
return std::nullopt;
}

TypedValue<MemRefType> ref = load_op.getBase();
MemRefType ref_ty = getMemRefType(ref);
// The reshape below might be invalid if the memref is not contiguous, but
// this check is overly conservative (we don't need all dims to be contiguous).
if (!isContiguousMemref(ref)) {
return;
return std::nullopt;
}

const int64_t lane = target_shape[1];
auto src_shape = src_ty.getShape();
auto tgt_shape = tgt_ty.getShape();
int64_t collapsed_minor = collapsed_ty.getShape().back();
int64_t expanded_minor = expanded_ty.getShape().back();
// Only handle the cases where the minor dim starts out as the number of lanes
// and we fold at least the second minor dim into it, in a way that changes
// its shape.
if (src_shape.back() != lane ||
tgt_shape.back() % (packing * lane) != 0 ||
tgt_shape.back() == src_shape.back() ||
tgt_shape.back() < llvm::product_of(src_shape.take_back(2))) {
return;
if (expanded_minor != lane ||
collapsed_minor % (packing * lane) != 0 ||
collapsed_minor == expanded_minor ||
collapsed_minor < llvm::product_of(expanded_ty.getShape().take_back(2))) {
return std::nullopt;
}
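  // [Editorial example, not part of this change] For 32-bit elements
  // (packing == 1) and lane == 128, an expanded shape (..., 4, 128) collapsed
  // to a minor dim of 512 satisfies all four conditions; collapsing to 128
  // fails (the shape would be unchanged) and collapsing to 384 fails (it
  // would cover only part of the second-minor dim of size 4).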

// We don't handle memrefs with padding.
MemRefType ref_ty = getMemRefType(ref);
auto tiled_layout = dyn_cast<tpu::TiledLayoutAttr>(ref_ty.getLayout());
if (!tiled_layout || tiled_layout.getTiles().empty()) {
return;
return std::nullopt;
}
ArrayRef<int64_t> front_tile = tiled_layout.getTiles().front().dimensions();
ArrayRef<int64_t> ref_tiled_shape =
ref_ty.getShape().take_back(front_tile.size());
for (int i = 0; i < front_tile.size(); ++i) {
if (ref_tiled_shape[i] % front_tile[i]) {
return;
return std::nullopt;
}
}
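  // [Editorial example, not part of this change] With a front tile of
  // (8, 128), a ref whose two trailing dims are (48, 128) divides evenly and
  // passes; (50, 128) would imply a padded last tile and is rejected above.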

// NOTE: We could generalize this to allow flattening only part of a dimension.
int folded_dims = 0;
{
int suffix_size = 1;
auto sizes_it = src_shape.rbegin();
while (suffix_size < tgt_shape.back()) {
auto sizes_it = expanded_ty.getShape().rbegin();
while (suffix_size < collapsed_minor) {
suffix_size *= *(sizes_it++);
}
// Make sure that the minor dim is folded only from entire major dims, not
// from a part of one of them.
if (suffix_size != tgt_shape.back()) {
return;
if (suffix_size != collapsed_minor) {
return std::nullopt;
}
folded_dims = sizes_it - src_shape.rbegin();
folded_dims = sizes_it - expanded_ty.getShape().rbegin();
}
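  // [Editorial example, not part of this change] For an expanded shape
  // (2, 3, 4, 128) and a collapsed minor dim of 1536, suffix_size grows
  // 1 -> 128 -> 512 -> 1536 and folded_dims becomes 3. A collapsed minor dim
  // of 768 (which passes the earlier checks) would overshoot to 1536 != 768
  // and return std::nullopt, since it would split the dim of size 3.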
DCHECK_GE(folded_dims, 2); // Should fold at least 2nd minor into minor.

// We don't handle slicing in the folded dims at the moment.
if (ref_ty.getShape().take_back(folded_dims) !=
src_ty.getShape().take_back(folded_dims)) {
expanded_ty.getShape().take_back(folded_dims)) {
return std::nullopt;
}

return folded_dims;
}
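// ---------------------------------------------------------------------------
// [Editorial sketch, not part of this change] A standalone approximation of
// the suffix-folding logic in canOptimizeReshapeMemory above, operating on
// plain shapes so the accepted and rejected cases can be exercised in
// isolation. The function name, shapes, and the main() driver below are all
// hypothetical.
// ---------------------------------------------------------------------------
#include <cassert>
#include <cstdint>
#include <optional>
#include <vector>

// Returns how many trailing dims of `expanded` fold exactly into
// `collapsed_minor`, or std::nullopt if the fold would split a dimension.
std::optional<int> foldedTrailingDims(const std::vector<int64_t>& expanded,
                                      int64_t collapsed_minor) {
  int64_t suffix_size = 1;
  int folded = 0;
  auto it = expanded.rbegin();
  while (suffix_size < collapsed_minor && it != expanded.rend()) {
    suffix_size *= *(it++);
    ++folded;
  }
  if (suffix_size != collapsed_minor) {
    return std::nullopt;  // The fold would split a dimension.
  }
  return folded;
}

int main() {
  assert(foldedTrailingDims({2, 3, 4, 128}, 1536) == 3);  // whole dims fold
  assert(!foldedTrailingDims({2, 3, 4, 128}, 768));       // would split the 3
  return 0;
}
// ---------------------------------------------------------------------------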

void optimizeLoadReshape(int hardware_generation,
std::array<int64_t, 2> target_shape,
Operation& raw_op) {
// Below, we look for reshapes that flatten multiple dims into the lane
// dimension. If the source of the reshape is a load of a ref whose minor
// dimension is 128 (effectively untiled), we can replace the
// load/reshape sequence with an efficient strided load. In essence, the
// strided load creates vregs with a narrow slice along the target minor
// dimension, but with the 2nd minor dim after the reshape already in
// sublanes. The results of strided load can be concatenated to form the
// final vector result.
//
// A little extra care needs to be applied to packed types, which we handle by
// briefly extending to 32-bit and repacking them after concatenation.
TypedValue<VectorType> src;
VectorType tgt_ty;
if (auto op = dyn_cast<tpu::ReshapeOp>(&raw_op)) {
src = op.getSource();
tgt_ty = op.getResult().getType();
} else if (auto op = dyn_cast<vector::ShapeCastOp>(&raw_op)) {
src = op.getSource();
tgt_ty = op.getResult().getType();
} else {
return;
}
VectorType src_ty = src.getType();
ArrayRef<int64_t> src_shape = src_ty.getShape();
ArrayRef<int64_t> tgt_shape = tgt_ty.getShape();
const int lane = target_shape[1];
const int bitwidth = src_ty.getElementTypeBitWidth();
const int packing = 32 / bitwidth;

auto load_op = dyn_cast_if_present<vector::LoadOp>(src.getDefiningOp());
// This rewrite might not be profitable if the load has other users.
if (!load_op || !load_op.getBase().hasOneUse()) {
return;
}
TypedValue<MemRefType> ref = load_op.getBase();
MemRefType ref_ty = getMemRefType(ref);

auto maybe_folded_dims = canOptimizeReshapeMemory(
hardware_generation, target_shape, ref, src_ty, tgt_ty);
if (!maybe_folded_dims.has_value()) {
return;
}
int folded_dims = *maybe_folded_dims;

Location loc = raw_op.getLoc();
ImplicitLocOpBuilder b(loc, &raw_op);
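// ---------------------------------------------------------------------------
// [Editorial sketch, not part of this change] Illustrates the index mapping
// behind the rewrite described at the top of optimizeLoadReshape: for a
// row-major reshape (R, N, 128) -> (R, N * 128), the j-th lane-wide column
// block of the collapsed result is exactly the slice source[:, j, :], so each
// block can be produced by one strided load (the rows R land in sublanes) and
// the N results concatenated along the target minor dimension. All shapes
// below are hypothetical.
// ---------------------------------------------------------------------------
#include <cassert>
#include <cstdint>

int main() {
  constexpr int64_t kLane = 128;
  constexpr int64_t kRows = 4;  // R: stays as the second-minor (sublane) dim
  constexpr int64_t kN = 6;     // N: folded into the lane dim by the reshape
  for (int64_t r = 0; r < kRows; ++r) {
    for (int64_t c = 0; c < kN * kLane; ++c) {
      const int64_t j = c / kLane;  // which strided slice the column falls in
      const int64_t l = c % kLane;  // lane within that slice
      // Flat buffer offset computed from the collapsed (R, N*128) view...
      const int64_t via_collapsed = r * (kN * kLane) + c;
      // ...matches the offset of source[r, j, l] in the expanded (R, N, 128)
      // view, confirming that column block j is source[:, j, :].
      const int64_t via_expanded = (r * kN + j) * kLane + l;
      assert(via_collapsed == via_expanded);
    }
  }
  return 0;
}
// ---------------------------------------------------------------------------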
@@ -277,68 +298,18 @@ void optimizeStore(int hardware_generation, std::array<int64_t, 2> target_shape,
MemRefType ref_ty = getMemRefType(base);
VectorType src_ty = shape_cast_op.getSource().getType();
VectorType tgt_ty = shape_cast_op.getResult().getType();
if (src_ty.getRank() < 1 || tgt_ty.getRank() < 2) {
return;
}
auto src_shape = src_ty.getShape();
auto tgt_shape = tgt_ty.getShape();

const int64_t lane = target_shape[1];
const int bitwidth = src_ty.getElementTypeBitWidth();
const int packing = 32 / bitwidth;
if (hardware_generation < 4 && packing > 1) {
return;
}

// The reshape below might be invalid if the memref is not contiguous, but
// this check is overly conservative (we don't need all dims to be contiguous).
if (!isContiguousMemref(base)) {
return;
}
const int64_t lane = target_shape[1];
// Only handle the cases where the minor dim starts out as the number of lanes
// and we fold at least the second minor dim into it, in a way that changes
// its shape.
if (tgt_shape.back() != lane ||
src_shape.back() % (packing * lane) != 0 ||
src_shape.back() == tgt_shape.back() ||
src_shape.back() < llvm::product_of(tgt_shape.take_back(2))) {
return;
}
// We don't handle memrefs with padding.
auto tiled_layout = dyn_cast<tpu::TiledLayoutAttr>(ref_ty.getLayout());
if (!tiled_layout || tiled_layout.getTiles().empty()) {
return;
}
ArrayRef<int64_t> front_tile = tiled_layout.getTiles().front().dimensions();
ArrayRef<int64_t> ref_tiled_shape =
ref_ty.getShape().take_back(front_tile.size());
for (int i = 0; i < front_tile.size(); ++i) {
if (ref_tiled_shape[i] % front_tile[i]) {
return;
}
}

int expanded_dims = 0;
{
int suffix_size = 1;
auto sizes_it = tgt_shape.rbegin();
while (suffix_size < src_shape.back()) {
suffix_size *= *(sizes_it++);
}
// Make sure the minor dim is expanded into its own dims and not folded into
// other major dims.
if (suffix_size != src_shape.back()) {
return;
}
expanded_dims = sizes_it - tgt_shape.rbegin();
}
DCHECK_GE(expanded_dims, 2); // Minor should expand at least into 2 dims.

// We don't support slicing in the expanded dims at the moment.
if (tgt_ty.getShape().take_back(expanded_dims) !=
ref_ty.getShape().take_back(expanded_dims)) {
std::optional<int> maybe_expanded_dims = canOptimizeReshapeMemory(
hardware_generation, target_shape, base, tgt_ty, src_ty);
if (!maybe_expanded_dims.has_value()) {
return;
}
int expanded_dims = *maybe_expanded_dims;
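  // [Editorial note, not part of this change] canOptimizeReshapeMemory takes
  // (expanded_ty, collapsed_ty): optimizeLoadReshape passes (src_ty, tgt_ty)
  // because the loaded vector is the expanded one and the reshape collapses
  // it, while optimizeStore passes (tgt_ty, src_ty) because the shape_cast
  // result being stored is the expanded one.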

ImplicitLocOpBuilder b(raw_op.getLoc(), &raw_op);
auto loc = raw_op.getLoc();