287 changes: 183 additions & 104 deletions xla/service/gpu/gpu_fusible.cc
@@ -91,14 +91,8 @@ bool IsPhysicallyTransposing(const HloInstruction& instr) {
instr.shape(), instr.dimensions()));
}

bool IsReduceInputFusion(const HloInstruction& instr) {
return instr.opcode() == HloOpcode::kFusion &&
HasAnyUnnestedReductionRoot(instr.called_computations()[0]);
}

bool IsInputFusibleReduction(const HloInstruction& instr) {
return IsReduceInputFusion(instr) ||
IsReductionFromOrToContiguousDimensions(instr);
return FindRealHero(instr).type == FusionHero::kReduction;
}
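
The rewritten checks in this file lean on FindRealHero and FusionHero, which are introduced elsewhere in this PR. A minimal sketch of the interface the call sites here assume (the enum values and fields are inferred from their uses in this diff, not copied from the PR's actual declaration):

// Sketch only: inferred from call sites in this file; the real declaration
// in the PR may differ.
struct FusionHero {
  enum Type { kLoop, kReduction, kTranspose, kScatter };
  Type type;                  // Which kind of hero drives code generation.
  const HloInstruction* hlo;  // The hero instruction itself.
};

// Resolves an instruction (or a fusion's root/operands) to its real hero.
FusionHero FindRealHero(const HloInstruction& instr);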

bool IsNestableVariadicReduction(const HloInstruction& instr) {
@@ -119,54 +113,27 @@ bool IsTransposeInputFusion(const HloInstruction& instr) {
}

bool IsInputFusibleTranspose(const HloInstruction& instr) {
return FindAnyTiledTranspose(instr) || IsTransposeInputFusion(instr);
}

const HloInstruction* GetRealHeroForMultiOutputFusion(
const HloInstruction& instr) {
if (instr.opcode() != HloOpcode::kFusion) {
return &instr;
}
auto fused_expression_root = instr.fused_expression_root();
if (!instr.IsMultiOutputFusion()) {
if (IsReductionFromOrToContiguousDimensions(*fused_expression_root) ||
FindAnyTiledTranspose(*fused_expression_root)) {
return &FindNonTrivialHero(*fused_expression_root);
}
return fused_expression_root;
}
// If possible, we want to pick a reduction-from-or-to-contiguous-dims
// operand of the fusion root or a tiled transpose, because they have the most
// constraints. Note that we cannot have both kinds at the same time, so once
// we find any, we can immediately return it.
for (const auto* inst : fused_expression_root->operands()) {
if (IsReductionFromOrToContiguousDimensions(*inst) ||
FindAnyTiledTranspose(*inst)) {
return &FindNonTrivialHero(*inst);
}
}
return fused_expression_root->operands()[0];
FusionHero fh = FindRealHero(instr);
return fh.type == FusionHero::kTranspose;
}

// Returns whether the outputs of a fusion with reductions are consistent with
// `first_reduce`.
static bool IsFusedReductionOutputConsistent(
const HloInstruction* inst, const HloInstruction* first_reduce) {
if (IsReductionFromOrToContiguousDimensions(*inst)) {
// Shapes, layouts and dimensions must be the same for all reduces
// inside of this fusion.
return ShapeUtil::EqualIgnoringElementType(first_reduce->shape(),
inst->shape()) &&
ShapeUtil::EqualIgnoringElementType(
first_reduce->operand(0)->shape(), inst->operand(0)->shape()) &&
ShapeUtil::EqualIgnoringElementType(
first_reduce->operand(1)->shape(), inst->operand(1)->shape()) &&
first_reduce->dimensions() == inst->dimensions();
}
return ShapeUtil::CompatibleIgnoringElementType(
first_reduce->operand(0)->shape(), inst->shape()) &&
LayoutUtil::Equal(first_reduce->operand(0)->shape().layout(),
inst->shape().layout());
static bool IsFusedReductionOutputConsistent(const FusionHero& a,
const FusionHero& b) {
if (a.type != FusionHero::kReduction || b.type != FusionHero::kReduction) {
return false;
}

// Shapes, layouts and dimensions must be the same for all reduces
// inside of this fusion.
return ShapeUtil::EqualIgnoringElementType(a.hlo->shape(), b.hlo->shape()) &&
a.hlo->dimensions() == b.hlo->dimensions() &&
absl::c_equal(a.hlo->operands(), b.hlo->operands(),
[&](const HloInstruction* c1, const HloInstruction* c2) {
return ShapeUtil::EqualIgnoringElementType(
c1->shape(), c2->shape());
});
}
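
A concrete illustration of the consistency check (hypothetical HLO, not taken from this PR):

//   a.hlo = f32[1024] reduce(f32[1024,128] p0, f32[] c0), dimensions={1}
//   b.hlo = f16[1024] reduce(f16[1024,128] p1, f16[] c1), dimensions={1}
// Output shapes, reduction dimensions and operand shapes all match up to
// element type, so this pair is consistent; reducing over a different
// dimension or from a differently shaped input would fail the check.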

FusionDecision FusionHeroesAreCompatible(const HloInstruction* hero1,
@@ -235,34 +202,67 @@ FusionDecision FusionHeroesAreCompatible(const HloInstruction* hero1,

FusionDecision ShapesCompatibleForMultiOutputFusion(
const HloInstruction& instr1, const HloInstruction& instr2) {
// Multi-output fusion kernels share a common parallel loop. The loop
// dimensions are determined by instruction shapes.
auto get_loop_shape = [&](const HloInstruction* element_instr) {
// Special-case reduction-to-vector ops: The loop dimensions are determined
// by the shape of the first operand.
if (IsReductionFromOrToContiguousDimensions(*element_instr) ||
FindAnyTiledTranspose(*element_instr)) {
return FindNonTrivialHero(*element_instr).operand(0)->shape();
}
return element_instr->shape();
};

// All shapes of the root tuple of multi-output fusions should agree, i.e. all
// root ops should have equal output shapes. An exception is
// reduction-to-vector ops. Here the input shapes of the reduction (first
// operand shape) and the reduction dimensions need to match.
const HloInstruction* hero1 = GetRealHeroForMultiOutputFusion(instr1);
const HloInstruction* hero2 = GetRealHeroForMultiOutputFusion(instr2);
FusionHero hero1 = FindRealHero(instr1);
FusionHero hero2 = FindRealHero(instr2);

// TODO(cheshire): This should be covered elsewhere?
if (hero1.type == FusionHero::kScatter ||
hero2.type == FusionHero::kScatter) {
return "scatter is not MOF-fusible";
}

// Multi-output fusion kernels share a common parallel loop. The loop
// dimensions are determined by instruction shapes.
auto get_loop_shape = [&](const FusionHero& hero) {
switch (hero.type) {
case FusionHero::kReduction:
case FusionHero::kTranspose:
return hero.hlo->operand(0)->shape();
case FusionHero::kLoop:
return hero.hlo->shape();
case FusionHero::kScatter: // TODO(cheshire): ??
LOG(FATAL) << "Unexpected";
}
};
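// Worked example (hypothetical shapes, not from this PR): for a reduction
// hero f32[1024] reduce(f32[1024,128], f32[]) the loop shape is the input
// f32[1024,128]; a transpose hero likewise contributes its operand shape,
// while a plain loop hero contributes its own output shape.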

if (auto compatible = FusionHeroesAreCompatible(hero1, hero2); !compatible) {
return compatible;
if (hero1.type == FusionHero::kReduction &&
hero2.type == FusionHero::kReduction) {
if (!IsFusedReductionOutputConsistent(hero1, hero2)) {
return "tiled reductions with different shapes";
}
return {};
}

if (hero1.type == FusionHero::kTranspose &&
hero2.type == FusionHero::kTranspose) {
if (!ShapeUtil::EqualIgnoringElementType(hero1.hlo->shape(),
hero2.hlo->shape()) ||
!ShapeUtil::EqualIgnoringElementType(hero1.hlo->operand(0)->shape(),
hero2.hlo->operand(0)->shape())) {
return "tiled transposes with different shapes";
} else {
return {};
}
}

if (hero1.type != FusionHero::kLoop && hero2.type != FusionHero::kLoop) {
return "MOF-fusion of a transpose and a reduction";
}

const Shape& l1 = get_loop_shape(hero1);
const Shape& l2 = get_loop_shape(hero2);

// We accept different shapes provided shapes are trivially reshapable.
bool accept_unequal_shape = !l1.IsTuple() && !l2.IsTuple();
// We accept different shapes provided the shapes are trivially reshapable,
// though not between variadic reductions.
bool accept_unequal_shape = !l1.IsTuple() && !l2.IsTuple()
&& (hero1.type != FusionHero::kLoop || hero2.type != FusionHero::kLoop);

if (!ShapeUtil::EqualIgnoringElementType(l1, l2) &&
(!accept_unequal_shape ||
@@ -283,38 +283,63 @@ bool IsInputFusibleScatter(const HloInstruction& instr) {
return false;
}

// TODO(cheshire): This is something we really should refactor.
bool IsInputFusible(const HloInstruction& instr) {
// Input fusion only handles non-elemental reduction, scatter and transpose
// operations.
return instr.IsFusible() &&
(IsInputFusibleReduction(instr) || IsInputFusibleScatter(instr) ||
IsInputFusibleTranspose(instr));
}

const HloInstruction* GetRealHeroForMultiOutputFusion(
const HloInstruction& instr) {
if (instr.opcode() != HloOpcode::kFusion) {
return &instr;
}
auto fused_expression_root = instr.fused_expression_root();
if (!instr.IsMultiOutputFusion()) {
if (IsReductionFromOrToContiguousDimensions(*fused_expression_root) ||
FindAnyTiledTranspose(*fused_expression_root)) {
return &FindNonTrivialHero(*fused_expression_root);
}
return fused_expression_root;
}
// If possible, we want to pick a reduction-from-or-to-contiguous-dims
// operand of the fusion root or a tiled transpose, because they have the most
// constraints. Note that we cannot have both kinds at the same time, so once
// we find any, we can immediately return it.
for (const auto* inst : fused_expression_root->operands()) {
if (IsReductionFromOrToContiguousDimensions(*inst) ||
FindAnyTiledTranspose(*inst)) {
return &FindNonTrivialHero(*inst);
}
}
return fused_expression_root->operands()[0];
}
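
For intuition (hypothetical HLO, not from this PR), consider a multi-output fusion with the following root:

//   tuple(f32[1024,128] exponential(...),
//         f32[1024] reduce(f32[1024,128] ..., f32[] ...))
// The operand scan above picks the reduce as the real hero, since the tiled
// reduction carries the codegen constraints; only when no reduction or tiled
// transpose operand exists does the first operand win.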

bool IsUniversallyLoopFusible(const HloInstruction& instr) {
// Don't fuse get-tuple-element on GPU: We can, but it's slower than not
// fusing. We never generate kernels for unfused GTEs. Instead, if an
// unfused GTE is an input to a kernel (including a fusion kernel), we
// compute the address of the GTE at the top of the kernel. Often we know the
// address of the GTE result statically, so we can do this without chasing any
// pointers.
return (
(instr.IsElementwise() && instr.operand_count() > 0 &&
instr.opcode() != HloOpcode::kCopy) ||
(instr.opcode() == HloOpcode::kCopy && !FindAnyTiledTranspose(instr)) ||
instr.opcode() == HloOpcode::kBitcast ||
return instr.IsFusible() &&
((instr.IsElementwise() && instr.operand_count() > 0) ||
instr.opcode() == HloOpcode::kBitcast ||
instr.opcode() == HloOpcode::kBroadcast ||
instr.opcode() == HloOpcode::kConcatenate ||
instr.opcode() == HloOpcode::kDynamicSlice ||
instr.opcode() == HloOpcode::kDynamicUpdateSlice ||
(instr.opcode() == HloOpcode::kFusion &&
instr.fusion_kind() == HloInstruction::FusionKind::kLoop) ||
instr.opcode() == HloOpcode::kGather ||
instr.opcode() == HloOpcode::kPad ||
instr.opcode() == HloOpcode::kReduceWindow ||
instr.opcode() == HloOpcode::kReshape ||
instr.opcode() == HloOpcode::kReverse ||
instr.opcode() == HloOpcode::kSlice ||
instr.opcode() == HloOpcode::kTranspose);
instr.opcode() == HloOpcode::kDynamicUpdateSlice ||
(instr.opcode() == HloOpcode::kFusion) ||
instr.opcode() == HloOpcode::kGather ||
instr.opcode() == HloOpcode::kIota ||
instr.opcode() == HloOpcode::kPad ||
(instr.opcode() == HloOpcode::kReduce &&
!instr.shape().IsTuple()) || // TODO(b/129089333): Don't fuse
// variadic reductions.
instr.opcode() == HloOpcode::kReduceWindow ||
instr.opcode() == HloOpcode::kTranspose);
}

bool IsLoopFusibleAsConsumer(const HloInstruction& instr) {
Expand All @@ -330,16 +355,69 @@ bool IsLoopFusibleAsProducer(const HloInstruction& instr) {
instr.opcode() == HloOpcode::kConstant ||
// Non-variadic elemental reductions can be fused as producers.
(instr.opcode() == HloOpcode::kReduce &&
!IsReductionFromOrToContiguousDimensions(instr) &&
// !IsReductionFromOrToContiguousDimensions(instr) &&
!instr.shape().IsTuple())));
}

static bool AllSatisfy(const HloInstruction& instr,
const HloPredicate& predicate) {
if (instr.opcode() != HloOpcode::kFusion) {
return predicate(&instr);
}

// Magic number, TODO: refactor.
if (instr.fused_instruction_count() > 5) {
return false;
}

return absl::c_all_of(
instr.fused_instructions(), [&](const HloInstruction* i) {
return i->opcode() == HloOpcode::kParameter || predicate(i);
});
}
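
AllSatisfy is applied below with the predicates IsReduceIntermediate and IsTransposeIntermediate, which come from elsewhere in this PR. A caller-side sketch under that assumption, with a stand-in elementwise predicate (EpilogueIsElementwiseOnly is hypothetical, not part of the PR):

// Hypothetical usage sketch: accept a consumer only if everything it fuses,
// parameters aside, is elementwise.
bool EpilogueIsElementwiseOnly(const HloInstruction& consumer) {
  return AllSatisfy(consumer, [](const HloInstruction* i) {
    return i->IsElementwise();
  });
}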

FusionDecision IsProducerConsumerFusible(const HloInstruction& producer,
const HloInstruction& consumer) {
if (!IsLoopFusibleAsProducer(producer) &&
!(FindAnyTiledTranspose(producer) &&
&FindNonTrivialHero(consumer) == &producer)) {
return "the producer is not loop-fusible";
if (!IsLoopFusibleAsProducer(producer)) {
return "producer is not loop-fusible";
}

FusionHero consumer_hero = FindRealHero(consumer);

// TODO(cheshire): Do we actually need this rule?
if (producer.opcode() == HloOpcode::kScatter) {
return "can't fuse scatter output";
}

if (IsReductionFromOrToContiguousDimensions(producer)) {
if (!AllSatisfy(consumer, &IsReduceIntermediate)) {
return "epilogue not fusible";
}

if (producer.user_count() > 1) {
return "reduction output fusion only works for single user";
}
if (consumer_hero.hlo != &producer &&
consumer_hero.type == FusionHero::kReduction) {
return "nested tiled reduction fusion is not beneficial";
}
}

if (FindAnyTiledTranspose(producer)) {
if (!AllSatisfy(consumer, &IsTransposeIntermediate)) {
return "epilogue not fusible";
}

if (producer.user_count() > 1) {
return "transpose output fusion only works for single user";
}
if (consumer_hero.hlo != &producer &&
consumer_hero.type == FusionHero::kTranspose) {
return "nested tiled transpose fusion is not beneficial";
}
if (consumer_hero.type == FusionHero::kReduction) {
return "transpose into reduction fusion not supported";
}
}

if (!IsInputFusible(consumer) && !IsLoopFusibleAsConsumer(consumer)) {
@@ -403,7 +481,8 @@ FusionDecision IsProducerMultiOutputFusible(const HloInstruction& producer) {
return "In-place operations are present";
}

if (!IsLoopFusibleAsProducer(producer)) {
if (!IsLoopFusibleAsProducer(producer) ||
FindRealHero(producer).type != FusionHero::kLoop) {
return "producer is not loop-fusible";
}

@@ -416,15 +495,23 @@ FusionDecision IsProducerMultiOutputFusible(const HloInstruction& producer) {

// Returns shared memory usage for a given instruction in bytes.
static int64_t SharedMemoryUsageNoCache(const HloInstruction& instr) {
// For now we are only fusing reductions.
if (instr.opcode() == HloOpcode::kReduce &&
IsReductionFromOrToContiguousDimensions(instr)) {
if (instr.opcode() == HloOpcode::kFusion) {
int64_t sum = 0;
for (const HloInstruction* hlo :
instr.fused_instructions_computation()->instructions()) {
sum += SharedMemoryUsageNoCache(*hlo);
}
return sum;
}
FusionHero h = FindRealHero(instr);
if (h.type == FusionHero::kReduction) {
const HloInstruction& hero = *h.hlo;
ReductionDimensions reduction_info =
GetReductionKindAndContiguousComponents(instr);
GetReductionKindAndContiguousComponents(hero);
int64_t primitive_size = ShapeUtil::ByteSizeOfPrimitiveType(
instr.operand(0)->shape().element_type());
hero.operand(0)->shape().element_type());
int num_variadic =
instr.shape().IsTuple() ? instr.shape().tuple_shapes_size() : 1;
hero.shape().IsTuple() ? hero.shape().tuple_shapes_size() : 1;
if (reduction_info.is_row_reduction) {
// __shared__[32] is used for row reduction.
return 32 * primitive_size * num_variadic;
@@ -433,18 +520,11 @@ static int64_t SharedMemoryUsageNoCache(const HloInstruction& instr) {
// from potential x-tiling).
return 2 * 32 * 33 * primitive_size * num_variadic;
}
} else if (FindAnyTiledTranspose(instr)) {
} else if (h.type == FusionHero::kTranspose) {
// Tile size for transposition.
int64_t primitive_size =
ShapeUtil::ByteSizeOfPrimitiveType(instr.shape().element_type());
ShapeUtil::ByteSizeOfPrimitiveType(h.hlo->shape().element_type());
return 32 * 33 * primitive_size;
} else if (instr.opcode() == HloOpcode::kFusion) {
int64_t sum = 0;
for (const HloInstruction* hlo :
instr.fused_instructions_computation()->instructions()) {
sum += SharedMemoryUsageNoCache(*hlo);
}
return sum;
}
// Other fused expressions for now don't need the shared memory budget.
return 0;
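
Plugging numbers into the branches above for f32 (4-byte) elements, as a sanity check: a single-output row reduction reserves 32 * 4 = 128 bytes of shared memory, a column reduction reserves 2 * 32 * 33 * 4 = 8448 bytes (the 33-wide tile stride is the usual padding that avoids shared-memory bank conflicts), and a tiled transpose reserves 32 * 33 * 4 = 4224 bytes. Variadic reductions scale the reduction figures by the number of tuple outputs.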
@@ -475,8 +555,7 @@ constexpr int64_t kMaxUnnestedReductionOutputsPerFusion = 8;

// Returns the number of unnested reductions in the instruction output.
static int64_t NumUnnestedReductionsNoCache(const HloInstruction& instr) {
if (instr.opcode() == HloOpcode::kReduce &&
IsReductionFromOrToContiguousDimensions(instr)) {
if (IsReductionFromOrToContiguousDimensions(instr)) {
return 1;
}
if (instr.opcode() == HloOpcode::kFusion) {