From 7caf8a149e3e068bf416f8bb5d18ee7f8d3507cd Mon Sep 17 00:00:00 2001 From: TJ Date: Thu, 6 Jul 2023 19:15:19 +0000 Subject: [PATCH] kinput reduction fusion changes --- xla/service/gpu/gpu_fusible.cc | 287 +++++++++++++-------- xla/service/gpu/gpu_fusible.h | 10 +- xla/service/gpu/horizontal_input_fusion.cc | 3 +- xla/service/gpu/ir_emission_utils.cc | 129 ++++++++- xla/service/gpu/ir_emission_utils.h | 57 +++- xla/service/gpu/ir_emitter_unnested.cc | 214 ++++++++++----- xla/service/gpu/ir_emitter_unnested.h | 29 ++- 7 files changed, 527 insertions(+), 202 deletions(-) diff --git a/xla/service/gpu/gpu_fusible.cc b/xla/service/gpu/gpu_fusible.cc index a7e60e0f221c3..36035e838badc 100644 --- a/xla/service/gpu/gpu_fusible.cc +++ b/xla/service/gpu/gpu_fusible.cc @@ -91,14 +91,8 @@ bool IsPhysicallyTransposing(const HloInstruction& instr) { instr.shape(), instr.dimensions())); } -bool IsReduceInputFusion(const HloInstruction& instr) { - return instr.opcode() == HloOpcode::kFusion && - HasAnyUnnestedReductionRoot(instr.called_computations()[0]); -} - bool IsInputFusibleReduction(const HloInstruction& instr) { - return IsReduceInputFusion(instr) || - IsReductionFromOrToContiguousDimensions(instr); + return FindRealHero(instr).type == FusionHero::kReduction; } bool IsNestableVariadicReduction(const HloInstruction& instr) { @@ -119,54 +113,27 @@ bool IsTransposeInputFusion(const HloInstruction& instr) { } bool IsInputFusibleTranspose(const HloInstruction& instr) { - return FindAnyTiledTranspose(instr) || IsTransposeInputFusion(instr); -} - -const HloInstruction* GetRealHeroForMultiOutputFusion( - const HloInstruction& instr) { - if (instr.opcode() != HloOpcode::kFusion) { - return &instr; - } - auto fused_expression_root = instr.fused_expression_root(); - if (!instr.IsMultiOutputFusion()) { - if (IsReductionFromOrToContiguousDimensions(*fused_expression_root) || - FindAnyTiledTranspose(*fused_expression_root)) { - return &FindNonTrivialHero(*fused_expression_root); - } - return fused_expression_root; - } - // If possible, we want to pick a reduction-from-or-to-contiguous-dims - // operand of the fusion root or a tiled transpose, because they have the most - // constraints. Note that we cannot have both kinds at the same time, so once - // we find any, we can immediately return it. - for (const auto* inst : fused_expression_root->operands()) { - if (IsReductionFromOrToContiguousDimensions(*inst) || - FindAnyTiledTranspose(*inst)) { - return &FindNonTrivialHero(*inst); - } - } - return fused_expression_root->operands()[0]; + FusionHero fh = FindRealHero(instr); + return fh.type == FusionHero::kTranspose; } // Returns whether the output of a fusion with reduction are consistent with // `first_reduce`. -static bool IsFusedReductionOutputConsistent( - const HloInstruction* inst, const HloInstruction* first_reduce) { - if (IsReductionFromOrToContiguousDimensions(*inst)) { - // Shapes, layouts and dimensions must be the same for all reduces - // inside of this fusion. 
- return ShapeUtil::EqualIgnoringElementType(first_reduce->shape(), - inst->shape()) && - ShapeUtil::EqualIgnoringElementType( - first_reduce->operand(0)->shape(), inst->operand(0)->shape()) && - ShapeUtil::EqualIgnoringElementType( - first_reduce->operand(1)->shape(), inst->operand(1)->shape()) && - first_reduce->dimensions() == inst->dimensions(); - } - return ShapeUtil::CompatibleIgnoringElementType( - first_reduce->operand(0)->shape(), inst->shape()) && - LayoutUtil::Equal(first_reduce->operand(0)->shape().layout(), - inst->shape().layout()); +static bool IsFusedReductionOutputConsistent(const FusionHero& a, + const FusionHero& b) { + if (a.type != FusionHero::kReduction || b.type != FusionHero::kReduction) { + return false; + } + + // Shapes, layouts and dimensions must be the same for all reduces + // inside of this fusion. + return ShapeUtil::EqualIgnoringElementType(a.hlo->shape(), b.hlo->shape()) && + a.hlo->dimensions() == b.hlo->dimensions() && + absl::c_equal(a.hlo->operands(), b.hlo->operands(), + [&](const HloInstruction* c1, const HloInstruction* c2) { + return ShapeUtil::EqualIgnoringElementType( + c1->shape(), c2->shape()); + }); } FusionDecision FusionHeroesAreCompatible(const HloInstruction* hero1, @@ -235,34 +202,67 @@ FusionDecision FusionHeroesAreCompatible(const HloInstruction* hero1, FusionDecision ShapesCompatibleForMultiOutputFusion( const HloInstruction& instr1, const HloInstruction& instr2) { - // Multi-output fusion kernels share a common parallel loop. The loop - // dimensions are determined by instruction shapes. - auto get_loop_shape = [&](const HloInstruction* element_instr) { - // Special-case reduction-to-vector ops: The loop dimensions are determined - // by the shape of the first operand. - if (IsReductionFromOrToContiguousDimensions(*element_instr) || - FindAnyTiledTranspose(*element_instr)) { - return FindNonTrivialHero(*element_instr).operand(0)->shape(); - } - return element_instr->shape(); - }; // All shapes of the root tuple of multi-output fusions should agree, i.e. all // root ops should have equal output shapes. An exception are // reduction-to-vector ops. Here the input shapes of the reduction (first // operand shape) and the reduction dimensions need to match. - const HloInstruction* hero1 = GetRealHeroForMultiOutputFusion(instr1); - const HloInstruction* hero2 = GetRealHeroForMultiOutputFusion(instr2); + FusionHero hero1 = FindRealHero(instr1); + FusionHero hero2 = FindRealHero(instr2); + + // TODO(cheshire): This should be covered elsewhere? + if (hero1.type == FusionHero::kScatter || + hero2.type == FusionHero::kScatter) { + return "scatter is not MOF-fusible"; + } + + // Multi-output fusion kernels share a common parallel loop. The loop + // dimensions are determined by instruction shapes. + auto get_loop_shape = [&](const FusionHero& hero) { + switch (hero.type) { + case FusionHero::kReduction: + case FusionHero::kTranspose: + return hero.hlo->operand(0)->shape(); + case FusionHero::kLoop: + return hero.hlo->shape(); + case FusionHero::kScatter: // TODO(cheshire): ?? 
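+        // Unreachable: scatter heroes were already rejected by the early return above.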
+ LOG(FATAL) << "Unexpected"; + } + }; if (auto compatible = FusionHeroesAreCompatible(hero1, hero2); !compatible) { return compatible; + if (hero1.type == FusionHero::kReduction && + hero2.type == FusionHero::kReduction) { + if (!IsFusedReductionOutputConsistent(hero1, hero2)) { + return "tiled reductions with different shapes"; + } + return {}; + } + + if (hero1.type == FusionHero::kTranspose && + hero2.type == FusionHero::kTranspose) { + if (!ShapeUtil::EqualIgnoringElementType(hero1.hlo->shape(), + hero2.hlo->shape()) || + !ShapeUtil::EqualIgnoringElementType(hero1.hlo->operand(0)->shape(), + hero2.hlo->operand(0)->shape())) { + return "tiled transposes with different shapes"; + } else { + return {}; + } + } + + if (hero1.type != FusionHero::kLoop && hero2.type != FusionHero::kLoop) { + return "MOF-fusion of a transpose and a reduction"; } const Shape& l1 = get_loop_shape(hero1); const Shape& l2 = get_loop_shape(hero2); - // We accept different shapes provided shapes are trivially reshapable. - bool accept_unequal_shape = !l1.IsTuple() && !l2.IsTuple(); + // We accept different shapes provided shapes are trivially reshapable. Not + // between variadic reductions though. + bool accept_unequal_shape = !l1.IsTuple() && !l2.IsTuple() + && (hero1.type != FusionHero::kLoop || hero2.type != FusionHero::kLoop); if (!ShapeUtil::EqualIgnoringElementType(l1, l2) && (!accept_unequal_shape || @@ -283,6 +283,7 @@ bool IsInputFusibleScatter(const HloInstruction& instr) { return false; } +// TODO(cheshire): This is something we really should refactor. bool IsInputFusible(const HloInstruction& instr) { // Input fusion only handles non-elemental reduction and scatter operations. return instr.IsFusible() && @@ -290,6 +291,32 @@ bool IsInputFusible(const HloInstruction& instr) { IsInputFusibleTranspose(instr)); } +const HloInstruction* GetRealHeroForMultiOutputFusion( + const HloInstruction& instr) { + if (instr.opcode() != HloOpcode::kFusion) { + return &instr; + } + auto fused_expression_root = instr.fused_expression_root(); + if (!instr.IsMultiOutputFusion()) { + if (IsReductionFromOrToContiguousDimensions(*fused_expression_root) || + FindAnyTiledTranspose(*fused_expression_root)) { + return &FindNonTrivialHero(*fused_expression_root); + } + return fused_expression_root; + } + // If possible, we want to pick a reduction-from-or-to-contiguous-dims + // operand of the fusion root or a tiled transpose, because they have the most + // constraints. Note that we cannot have both kinds at the same time, so once + // we find any, we can immediately return it. + for (const auto* inst : fused_expression_root->operands()) { + if (IsReductionFromOrToContiguousDimensions(*inst) || + FindAnyTiledTranspose(*inst)) { + return &FindNonTrivialHero(*inst); + } + } + return fused_expression_root->operands()[0]; +} + bool IsUniversallyLoopFusible(const HloInstruction& instr) { // Don't fuse get-tuple-element on GPU: We can, but it's slower than not // fusing. We never generate kernels for unfused GTEs. Instead, if an @@ -297,24 +324,22 @@ bool IsUniversallyLoopFusible(const HloInstruction& instr) { // compute the address of the GTE at the top of the kernel. Often we know the // address of the GTE result statically, so we can do this without chasing any // pointers. 
- return ( - (instr.IsElementwise() && instr.operand_count() > 0 && - instr.opcode() != HloOpcode::kCopy) || - (instr.opcode() == HloOpcode::kCopy && !FindAnyTiledTranspose(instr)) || - instr.opcode() == HloOpcode::kBitcast || + return instr.IsFusible() && + ((instr.IsElementwise() && instr.operand_count() > 0) || + instr.opcode() == HloOpcode::kBitcast || instr.opcode() == HloOpcode::kBroadcast || instr.opcode() == HloOpcode::kConcatenate || instr.opcode() == HloOpcode::kDynamicSlice || - instr.opcode() == HloOpcode::kDynamicUpdateSlice || - (instr.opcode() == HloOpcode::kFusion && - instr.fusion_kind() == HloInstruction::FusionKind::kLoop) || - instr.opcode() == HloOpcode::kGather || - instr.opcode() == HloOpcode::kPad || - instr.opcode() == HloOpcode::kReduceWindow || - instr.opcode() == HloOpcode::kReshape || - instr.opcode() == HloOpcode::kReverse || - instr.opcode() == HloOpcode::kSlice || - instr.opcode() == HloOpcode::kTranspose); + instr.opcode() == HloOpcode::kDynamicUpdateSlice || + (instr.opcode() == HloOpcode::kFusion) || + instr.opcode() == HloOpcode::kGather || + instr.opcode() == HloOpcode::kIota || + instr.opcode() == HloOpcode::kPad || + (instr.opcode() == HloOpcode::kReduce && + !instr.shape().IsTuple()) || // TODO(b/129089333): Don't fuse + // variadic reductions. + instr.opcode() == HloOpcode::kReduceWindow || + instr.opcode() == HloOpcode::kTranspose); } bool IsLoopFusibleAsConsumer(const HloInstruction& instr) { @@ -330,16 +355,69 @@ bool IsLoopFusibleAsProducer(const HloInstruction& instr) { instr.opcode() == HloOpcode::kConstant || // Non-variadic elemental reductions can be fused as producers. (instr.opcode() == HloOpcode::kReduce && - !IsReductionFromOrToContiguousDimensions(instr) && + // !IsReductionFromOrToContiguousDimensions(instr) && !instr.shape().IsTuple()))); } +static bool AllSatisfy(const HloInstruction& instr, + const HloPredicate& predicate) { + if (instr.opcode() != HloOpcode::kFusion) { + return predicate(&instr); + } + + // Magic number, TODO: refactor. + if (instr.fused_instruction_count() > 5) { + return false; + } + + return absl::c_all_of( + instr.fused_instructions(), [&](const HloInstruction* i) { + return i->opcode() == HloOpcode::kParameter || predicate(i); + }); +} + FusionDecision IsProducerConsumerFusible(const HloInstruction& producer, const HloInstruction& consumer) { - if (!IsLoopFusibleAsProducer(producer) && - !(FindAnyTiledTranspose(producer) && - &FindNonTrivialHero(consumer) == &producer)) { - return "the producer is not loop-fusible"; + if (!IsLoopFusibleAsProducer(producer)) { + return "producer is not loop-fusible"; + } + + FusionHero consumer_hero = FindRealHero(consumer); + + // TODO(cheshire): Do we actually need this rule? 
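+  // (Fusing ops onto a scatter's output is not supported by the current emitters.)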
+ if (producer.opcode() == HloOpcode::kScatter) { + return "can't fuse scatter output"; + } + + if (IsReductionFromOrToContiguousDimensions(producer)) { + if (!AllSatisfy(consumer, &IsReduceIntermediate)) { + return "epilogue not fusible"; + } + + if (producer.user_count() > 1) { + return "reduction output fusion only works for single user"; + } + if (consumer_hero.hlo != &producer && + consumer_hero.type == FusionHero::kReduction) { + return "nested tiled reduction fusion is not beneficial"; + } + } + + if (FindAnyTiledTranspose(producer)) { + if (!AllSatisfy(consumer, &IsTransposeIntermediate)) { + return "epilogue not fusible"; + } + + if (producer.user_count() > 1) { + return "transpose output fusion only works for single user"; + } + if (consumer_hero.hlo != &producer && + consumer_hero.type == FusionHero::kTranspose) { + return "nested tiled transpose fusion is not beneficial"; + } + if (consumer_hero.type == FusionHero::kReduction) { + return "transpose into reduction fusion not supported"; + } } if (!IsInputFusible(consumer) && !IsLoopFusibleAsConsumer(consumer)) { @@ -403,7 +481,8 @@ FusionDecision IsProducerMultiOutputFusible(const HloInstruction& producer) { return "In-place operations are present"; } - if (!IsLoopFusibleAsProducer(producer)) { + if (!IsLoopFusibleAsProducer(producer) || + FindRealHero(producer).type != FusionHero::kLoop) { return "producer is not loop-fusible"; } @@ -416,15 +495,23 @@ FusionDecision IsProducerMultiOutputFusible(const HloInstruction& producer) { // Returns shared memory usage for a given instruction in bytes. static int64_t SharedMemoryUsageNoCache(const HloInstruction& instr) { - // For now we are only fusing reductions. - if (instr.opcode() == HloOpcode::kReduce && - IsReductionFromOrToContiguousDimensions(instr)) { + if (instr.opcode() == HloOpcode::kFusion) { + int64_t sum = 0; + for (const HloInstruction* hlo : + instr.fused_instructions_computation()->instructions()) { + sum += SharedMemoryUsageNoCache(*hlo); + } + return sum; + } + FusionHero h = FindRealHero(instr); + if (h.type == FusionHero::kReduction) { + const HloInstruction& hero = *h.hlo; ReductionDimensions reduction_info = - GetReductionKindAndContiguousComponents(instr); + GetReductionKindAndContiguousComponents(hero); int64_t primitive_size = ShapeUtil::ByteSizeOfPrimitiveType( - instr.operand(0)->shape().element_type()); + hero.operand(0)->shape().element_type()); int num_variadic = - instr.shape().IsTuple() ? instr.shape().tuple_shapes_size() : 1; + hero.shape().IsTuple() ? hero.shape().tuple_shapes_size() : 1; if (reduction_info.is_row_reduction) { // __shared__[32] is used for row reduction. return 32 * primitive_size * num_variadic; @@ -433,18 +520,11 @@ static int64_t SharedMemoryUsageNoCache(const HloInstruction& instr) { // from potential x-tiling). return 2 * 32 * 33 * primitive_size * num_variadic; } - } else if (FindAnyTiledTranspose(instr)) { + } else if (h.type == FusionHero::kTranspose) { // Tile size for transposition. int64_t primitive_size = - ShapeUtil::ByteSizeOfPrimitiveType(instr.shape().element_type()); + ShapeUtil::ByteSizeOfPrimitiveType(h.hlo->shape().element_type()); return 32 * 33 * primitive_size; - } else if (instr.opcode() == HloOpcode::kFusion) { - int64_t sum = 0; - for (const HloInstruction* hlo : - instr.fused_instructions_computation()->instructions()) { - sum += SharedMemoryUsageNoCache(*hlo); - } - return sum; } // Other fused expressions for now don't need the shared memory budget. 
return 0; @@ -475,8 +555,7 @@ constexpr int64_t kMaxUnnestedReductionOutputsPerFusion = 8; // Returns the number of unnested reductions in the instruction output. static int64_t NumUnnestedReductionsNoCache(const HloInstruction& instr) { - if (instr.opcode() == HloOpcode::kReduce && - IsReductionFromOrToContiguousDimensions(instr)) { + if (IsReductionFromOrToContiguousDimensions(instr)) { return 1; } if (instr.opcode() == HloOpcode::kFusion) { diff --git a/xla/service/gpu/gpu_fusible.h b/xla/service/gpu/gpu_fusible.h index e5080e49ad80f..47d98aa6918e6 100644 --- a/xla/service/gpu/gpu_fusible.h +++ b/xla/service/gpu/gpu_fusible.h @@ -103,11 +103,6 @@ FusionDecision FusionFitsInBudget(const HloInstruction& instr1, bool CreatesHeavyComputation(const HloInstruction& producer, const HloInstruction& consumer); -// Returns the instruction that determines the emitter used for lowering, -// sometimes referred to as "the real hero". -const HloInstruction* GetRealHeroForMultiOutputFusion( - const HloInstruction& instr); - // Whether 'hero1' and 'hero2' are compatible if the two fusions containing // 'hero1' and 'hero2' are merged together. For example merging two fusions with // a reduction hero and a transpose here, respectively, does not work. @@ -124,6 +119,11 @@ FusionDecision FusionHeroesAreCompatible(const HloInstruction* hero1, FusionDecision ShapesCompatibleForMultiOutputFusion( const HloInstruction& instr1, const HloInstruction& instr2); +// Returns the instruction that determines the emitter used for lowering, +// sometimes referred to as "the real hero". +const HloInstruction* GetRealHeroForMultiOutputFusion( + const HloInstruction& instr); + // Whether the instructions are compatible for producer-consumer fusion // i.e. whether the producer and consumer are loop/input fusible and // they are not library calls. diff --git a/xla/service/gpu/horizontal_input_fusion.cc b/xla/service/gpu/horizontal_input_fusion.cc index 3cf913806b453..0abbaf02edd51 100644 --- a/xla/service/gpu/horizontal_input_fusion.cc +++ b/xla/service/gpu/horizontal_input_fusion.cc @@ -16,6 +16,7 @@ limitations under the License. #include "xla/service/gpu/horizontal_input_fusion.h" #include +#include #include "absl/container/flat_hash_set.h" #include "absl/types/span.h" @@ -30,7 +31,7 @@ namespace { // Gets the representative input shape of the multi-output fusion. Shape GetInputShapeForMultiOutputFusion(const HloInstruction& instr) { // Get the HLO that determines the emitter used for lowering. - const HloInstruction* real_hero = GetRealHeroForMultiOutputFusion(instr); + const HloInstruction* real_hero = FindRealHero(instr).hlo; if (real_hero->operands().empty()) { // Simply return an empty shape if the representative node has no input // operands. diff --git a/xla/service/gpu/ir_emission_utils.cc b/xla/service/gpu/ir_emission_utils.cc index 6b10c59aec4eb..241a00c6750fd 100644 --- a/xla/service/gpu/ir_emission_utils.cc +++ b/xla/service/gpu/ir_emission_utils.cc @@ -879,18 +879,15 @@ std::optional FindTiledLogicalTranspose( std::optional FindAnyTiledTranspose( const HloInstruction& instr) { - const HloInstruction& hero = FindNonTrivialHero(instr); - // TODO(b/284431534): Figure out how to make the shared memory transpose - // emitter faster for this case. 
if (hero.shape().element_type() == F32 && instr.shape().element_type() == S8) { return std::nullopt; } - if (auto d1 = FindTiledTranspose(hero)) { + if (auto d1 = FindTiledTranspose(instr)) { return d1; } - if (auto d2 = FindTiledLogicalTranspose(hero)) { + if (auto d2 = FindTiledLogicalTranspose(instr)) { return d2; } return std::nullopt; @@ -913,6 +910,128 @@ bool IsIntermediate(const HloInstruction* instr, int allowed_operand_count) { instr->shape(), instr->dimensions())))); } +bool IsTransposeIntermediate(const HloInstruction* instr) { + return ( + instr->operand_count() <= 1 && instr->user_count() <= 1 && + ((instr->IsElementwise() && instr->opcode() != HloOpcode::kCopy) || + instr->opcode() == HloOpcode::kBitcast || + (instr->opcode() == HloOpcode::kReshape && + ShapeUtil::ReshapeIsBitcast(instr->operand(0)->shape(), + instr->shape())) || + (instr->opcode() == HloOpcode::kTranspose && + ShapeUtil::TransposeIsBitcast(instr->operand(0)->shape(), + instr->shape(), instr->dimensions())))); +} + +bool IsReduceIntermediate(const HloInstruction* instr) { + // TODO(cheshire): Do not duplicate logic with the function above. + if (instr->operand_count() > 1 || instr->user_count() > 1) { + return false; + } + + // Monotone functions are fine, and pure data movement is fine. + // The reason we want monotone, is that reduction with kMin and kMax can still + // use atomics. + switch (instr->opcode()) { + case HloOpcode::kAtan2: + case HloOpcode::kBitcast: + case HloOpcode::kBitcastConvert: + case HloOpcode::kExp: + case HloOpcode::kFloor: + case HloOpcode::kLog: + case HloOpcode::kLog1p: + case HloOpcode::kSqrt: + case HloOpcode::kRsqrt: + case HloOpcode::kTanh: + case HloOpcode::kReshape: // TODO(cheshire): only when it is a bitcast? + //case HloOpcode::kTranspose: TODO(cheshire): probably more support is + // required? Also: kCopy + return true; + // TODO(cheshire): Also support other bitcasts? + //return ShapeUtil::ReshapeIsBitcast(instr->operand(0)->shape(), + //instr->shape()); + default: + return false; + } +} + +// TODO(cheshire): Avoid duplication. +static const HloInstruction* FindNonTrivialReductionHero( + const HloInstruction& instr) { + const HloInstruction* idx = &instr; + while (IsReduceIntermediate(idx) && idx->operand_count() == 1) { + idx = idx->operand(0); + } + if (IsReductionFromOrToContiguousDimensions(*idx)) { + return idx; + } + return nullptr; +} + +static const HloInstruction* FindNonTrivialTransposeHero( + const HloInstruction& instr) { + const HloInstruction* idx = &instr; + + // Go up the chain of trivial elementwise(+bitcast, -copy) operations. Such + // chains are bound to be quite small, as we restrict the number of users as + // well. Note that no memoization is needed due to user number constraints: we + // never have to revisit same nodes. + while (IsTransposeIntermediate(idx) && idx->operand_count() == 1) { + idx = idx->operand(0); + } + if (FindAnyTiledTranspose(*idx)) { + return idx; + } + return nullptr; +} + +FusionHero FindRealHero(const HloComputation& cmp) { + // TODO(cheshire): Refactor the code to avoid this hack, currently this is OK + // as instructions are never modified. + std::vector roots = GetFusionRoots( + &const_cast(cmp)); + CHECK(!roots.empty()); + FusionHero found; + for (const HloInstruction* r : roots) { + FusionHero h = FindRealHero(*r); + if (found.hlo == nullptr || + static_cast(h.type) < static_cast(found.type)) { + found = h; + } + } + CHECK(found.hlo); + return found; +} + +// TODO(cheshire): Add cache. 
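+// Walks up from `hlo` (or from the roots of its fused computation) through
+// trivial intermediate ops to find a reduction or transpose hero that
+// determines the emitter; scatter maps to kScatter and everything else
+// falls back to a loop hero.
+// Example: for a fusion rooted at sqrt(reduce(p, ...)) with a contiguous,
+// race-free reduction, the hero is the reduce and the type is kReduction.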
+FusionHero FindRealHero(const HloInstruction& hlo) { + if (hlo.opcode() == HloOpcode::kFusion) { + return FindRealHero(*hlo.fused_instructions_computation()); + } + + if (hlo.opcode() == HloOpcode::kScatter) { + return {&hlo, FusionHero::kScatter}; + } + + if (const HloInstruction* rh = FindNonTrivialReductionHero(hlo)) { + // No output fusions in case we have multiple users. + // No output fusions for reductions requiring atomics. + if (rh == &hlo || + (rh->user_count() == 1 && + ReductionIsRaceFree(GetReductionKindAndContiguousComponents(*rh)))) { + return {rh, FusionHero::kReduction}; + } + } + + if (const HloInstruction* th = FindNonTrivialTransposeHero(hlo)) { + if (th == &hlo || th->user_count() == 1) { + // No output fusions in case we need multiple users. + return {th, FusionHero::kTranspose}; + } + } + return {&hlo, FusionHero::kLoop}; +} + const HloInstruction& FindNonTrivialHero(const HloInstruction& instr) { const HloInstruction* idx = &instr; diff --git a/xla/service/gpu/ir_emission_utils.h b/xla/service/gpu/ir_emission_utils.h index 9a6a4f67a22e4..37a50838f6280 100644 --- a/xla/service/gpu/ir_emission_utils.h +++ b/xla/service/gpu/ir_emission_utils.h @@ -224,14 +224,54 @@ struct TransposeDimsAndParams { // Expected output: [R1] std::vector GetFusionRoots(HloComputation* computation); -// Returns whether the computation has at least one root triggering unnested -// reduction emitter. -bool HasAnyUnnestedReductionRoot(HloComputation* computation); - -const HloInstruction& FindNonTrivialHero(const HloInstruction& instr); +struct FusionHero { + // Preference in returning. + enum FusionType { + kScatter = -1, + kReduction = 0, + kTranspose = 1, + kLoop = 2 + }; + + const HloInstruction* hlo = nullptr; + FusionType type = kLoop; + + std::string ToString() { + std::string s; + absl::StrAppend(&s, hlo->ToString()); + absl::StrAppend(&s, "; "); + std::string t = [&] { + switch (type) { + case kReduction: + return "(reduction)"; + case kTranspose: + return "(transpose)"; + case kLoop: + return "(loop)"; + case kScatter: + return "(scatter)"; + } + }(); + absl::StrAppend(&s, t); + return s; + } +}; -// Whether there is a fusion root triggering transposition emitter. -bool HasAnyTiledTransposeRoot(HloComputation* computation); +// Finds real hero for the fusion. +// +// Invariant: walk up the graph, etc etc. +// TODO(cheshire): Write up definition. +// +// For instruction: find the real hero. +// +// For fusion: find the real hero of fusion root. +// +// For MOF: iterate through outputs, find "main" one. +// +// In case of multiple heros possible, always returns one with smallest fusion +// type. +FusionHero FindRealHero(const HloInstruction& hlo); +FusionHero FindRealHero(const HloComputation& cmp); struct TransposeDescription { Vector3 dimensions; @@ -265,6 +305,9 @@ std::optional FindAnyTiledTranspose( bool IsIntermediate(const HloInstruction* instr, int allowed_operand_count = 1); +bool IsTransposeIntermediate(const HloInstruction* instr); +bool IsReduceIntermediate(const HloInstruction* instr); + // Log and verify an LLVM module. void LogAndVerify(const llvm::Module* m); diff --git a/xla/service/gpu/ir_emitter_unnested.cc b/xla/service/gpu/ir_emitter_unnested.cc index 625cfbc2a82fa..e3af7fd163c9e 100644 --- a/xla/service/gpu/ir_emitter_unnested.cc +++ b/xla/service/gpu/ir_emitter_unnested.cc @@ -28,6 +28,7 @@ limitations under the License. 
#include #include #include +#include #include #include @@ -2066,6 +2067,7 @@ Status IrEmitterUnnested::EmitLoopFusion(mlir::Operation* op) { TF_ASSIGN_OR_RETURN(const HloComputation* fused_computation, GetOrCreateSubComputationFromRegion(&fusion.getRegion(), /*is_fusion=*/true)); + VLOG(10) << "Emitting: " << fused_computation->ToString(); const GpuDeviceInfo gpu_device_info = ir_emitter_context_->gpu_device_info(); @@ -2179,7 +2181,8 @@ Status IrEmitterUnnested::EmitLoopFusion(mlir::Operation* op) { } Status IrEmitterUnnested::EmitUnnestedTranspose( - mlir::lmhlo::FusionOp fusion, HloComputation* fused_computation) { + mlir::lmhlo::FusionOp fusion, HloComputation* fused_computation, + const FusionHero& hero) { std::vector hlo_roots = GetFusionRoots(fused_computation); // TODO(cheshire): avoid duplication of FindTiledTranspose function, is it @@ -2232,19 +2235,24 @@ Status IrEmitterUnnested::EmitUnnestedTranspose( fusion, fused_computation, absl::MakeSpan(ir_arrays).subspan(0, fusion.getInputBuffers().size()), absl::MakeSpan(ir_arrays).subspan(fusion.getInputBuffers().size()), - tiling_scheme, launch_dimensions)); + tiling_scheme, launch_dimensions, + hero)); return OkStatus(); } // Returns true if the fusion has consistent transpose heros. static bool HasConsistentTransposeHeros(HloComputation* fusion) { - std::vector hlo_roots = GetFusionRoots(fusion); + std::vectorhlo_roots = GetFusionRoots(fusion); if (!HasAnyTiledTransposeRoot(fusion)) { return false; } - const HloInstruction* first_transpose = &FindNonTrivialHero(**absl::c_find_if( + // TEST TJ + // const HloInstruction* first_transpose = &FindNonTrivialHero(**absl::c_find_if( + // hlo_roots, + // [](HloInstruction* instr) { return FindAnyTiledTranspose(*instr); })); + const HloInstruction* first_transpose = FindRealHero(**absl::c_find_if( hlo_roots, - [](HloInstruction* instr) { return FindAnyTiledTranspose(*instr); })); + [](HloInstruction* instr) { return FindAnyTiledTranspose(*instr); })).hlo; const Shape& transpose_in_shape = first_transpose->operand(0)->shape(); std::optional first_tiled_transpose = FindAnyTiledTranspose(*first_transpose); @@ -2288,7 +2296,7 @@ Status IrEmitterUnnested::EmitFusion(mlir::Operation* op) { } TF_ASSIGN_OR_RETURN( - HloComputation * fused_computation, + HloComputation* fused_computation, GetOrCreateSubComputationFromRegion(&fusion_op.getRegion(), /*is_fusion=*/true)); @@ -2314,14 +2322,18 @@ Status IrEmitterUnnested::EmitFusion(mlir::Operation* op) { } #endif // GOOGLE_CUDA - if (HasAnyUnnestedReductionRoot(fused_computation)) { - return EmitUnnestedReduction(fusion_op, fused_computation); - } + FusionHero fh = FindRealHero(*fused_computation); + if (fh.type == FusionHero::kReduction) { + return EmitUnnestedReduction(fusion_op, fused_computation, fh); + } // Triton fusions can have transposes too but they are intercepted earlier. // TODO(b/286029825): Do not generate fusions with inconsistent transposes. 
- if (HasConsistentTransposeHeros(fused_computation)) { - return EmitUnnestedTranspose(fusion_op, fused_computation); + // if (HasConsistentTransposeHeros(fused_computation)) { + // return EmitUnnestedTranspose(fusion_op, fused_computation); + // } + if (fh.type == FusionHero::kTranspose) { + return EmitUnnestedTranspose(fusion_op, fused_computation, fh); } auto fusion_results = fusion_op.getFusionResults(); @@ -2400,6 +2412,8 @@ Status IrEmitterUnnested::EmitExtraOutputsForReduce( auto get_index = [&](const HloInstruction* instr) { const Shape& s = instr->shape(); + // TODO(cheshire): This doesn't quite work, need a more general converter + // here. return ShapeUtil::EqualIgnoringElementType(reduction_operand_shape, s) ? index : index.SourceIndexOfBitcast(reduction_operand_shape, s, &b_); @@ -4141,6 +4155,7 @@ ReductionCodegenState IrEmitterUnnested::GenerateReductionCodegenState( mlir::lmhlo::FusionOp fusion, const ReductionCodegenInfo& reduction_info, absl::Span reduce_instr_index_group, FusedIrEmitter& fused_emitter) { + // TODO(cheshire): ReductionCodegenState reduction_codegen_state(reduction_info); VLOG(10) << "Emit prologue for reduction: " << llvm_ir::DumpToString(fusion); @@ -4235,8 +4250,8 @@ void IrEmitterUnnested::EmitFullWarpShuffleDownLoopForReduce( for (int distance = 16 / num_results_per_warp; distance >= 1; distance /= 2) { absl::InlinedVector reduction_params; - for (auto acc : partial_result_addresses) { - reduction_params.push_back(acc.first); + for (auto [addr, type] : partial_result_addresses) { + reduction_params.push_back(addr); } for (auto [partial_result_address, element_type] : @@ -4276,12 +4291,44 @@ void IrEmitterUnnested::EmitFullWarpShuffleDownLoopForReduce( } } -llvm::Value* IrEmitterUnnested::GetOutputAddressForReduction( +static IrArray::Index TransformIndexForOutputFusion( + const IrArray::Index& hero_index, const HloReduceInstruction* hero, + const HloInstruction* root, llvm::IRBuilder<>* b) { + IrArray::Index out = hero_index; + const HloInstruction* it = root; + while (it != hero) { + CHECK_EQ(it->operand_count(), 1); + const HloInstruction* op = it->operand(0); + if (!ShapeUtil::EqualIgnoringElementType(op->shape(), it->shape())) { + // Only have: bitcast, transpose, reshape. + out = [&] { + switch (it->opcode()) { + // TODO(cheshire): For now, don't support transposing bitcast + // case HloOpcode::kTranspose: + //// TODO(cheshire): check that we are doing it the right way. + // return out.SourceIndexOfTranspose(op->shape(), it->shape(), + // InversePermutation(it->dimensions())); + // case HloOpcode::kReshape: + // return out.SourceIndexOfReshape(op->shape(), it->shape(), b); + default: + return out.SourceIndexOfBitcast(op->shape(), it->shape(), b); + } + }(); + } + it = op; + } + return out; +} + +IrArray::Index IrEmitterUnnested::GetOutputAddressForReduction( int partial_result_idx, llvm::Type* index_ty, const ReductionCodegenState& reduction_codegen_state, const TilingKernelInfo& tiling_kernel_info, + + // TODO(cheshire): Why are we even passing this argument around? 
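+    // It is no longer dereferenced in this function; WriteReductionOutput now
+    // looks up the IrArray for `root` itself.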
const IrEmitterUnnested::ReductionOutputMap& output_arrays, - const HloReduceInstruction* reduction, int output_idx) { + const HloReduceInstruction* reduction, const HloInstruction* root, + int output_idx) { auto constant = [&](uint64_t c) -> llvm::Constant* { return llvm::ConstantInt::get(index_ty, c); }; @@ -4301,7 +4348,7 @@ llvm::Value* IrEmitterUnnested::GetOutputAddressForReduction( .AddOffsetToDim(start_offset_x, kDimX, &b_); }(); - const IrArray& output_array = output_arrays.at(reduction)[output_idx]; + // TODO(cheshire): Not correct for variadic reduction case. const Shape& operand_shape = reduction->inputs()[output_idx]->shape(); Shape reduction_kept_element_shape = ShapeUtil::DeleteDimensions(reduction->dimensions(), operand_shape); @@ -4335,11 +4382,12 @@ llvm::Value* IrEmitterUnnested::GetOutputAddressForReduction( IrArray::Index element_index( /*linear=*/untransposed_output_linear_address, reduction_kept_element_shape, &b_); - IrArray::Index output_index(element_index.multidim(), output_array.GetShape(), + const Shape& output_shape = !reduction->shape().IsTuple() + ? reduction->shape() + : reduction->shape().tuple_shapes(output_idx); + IrArray::Index output_index(element_index.multidim(), output_shape, element_index.GetType()); - - return output_array.EmitArrayElementAddress(output_index, &b_, - "output_element_address"); + return TransformIndexForOutputFusion(output_index, reduction, root, &b_); } llvm::Value* IrEmitterUnnested::EmitBlockId(int32_t num_blocks, @@ -4407,18 +4455,34 @@ void IrEmitterUnnested::WriteReductionOutput( llvm::Type* index_ty, const ReductionCodegenState& reduction_codegen_state, const TilingKernelInfo& tiling_kernel_info, const ReductionOutputMap& output_arrays, - const HloReduceInstruction* reduction, int partial_result_idx, - const absl::Span values) { + const HloReduceInstruction* reduction, const HloInstruction* root, + int partial_result_idx, const absl::Span values) { const HloComputation* reducer = reduction->to_apply(); for (const auto& [oidx, typed_ptr] : llvm::enumerate(values)) { auto [output_ptr, type] = typed_ptr; - llvm::Value* output_address = GetOutputAddressForReduction( + + IrArray::Index output_index = GetOutputAddressForReduction( partial_result_idx, index_ty, reduction_codegen_state, - tiling_kernel_info, output_arrays, reduction, oidx); + tiling_kernel_info, output_arrays, reduction, root, oidx); + llvm::Value* output_address = + output_arrays.at(root)[oidx].EmitArrayElementAddress( + output_index, &b_, "output_element_address"); + if (reduction_codegen_state.IsRaceFree()) { - b_.CreateStore(b_.CreateLoad(type, output_ptr, "output"), output_address); + FusedIrEmitter fused_emitter(elemental_emitter_); + llvm::Value* loaded = b_.CreateLoad(type, output_ptr, "output"); + fused_emitter.BindGenerator( + *reduction, [&](const IrArray::Index& index) { return loaded; }); + + // TODO(cheshire): Better checks. 
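+      // Re-emit the fused epilogue rooted at `root` on top of the
+      // already-reduced value, then store the result.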
+ llvm_ir::ElementGenerator gen = *fused_emitter.GetGenerator(*root); + llvm::Value* generated = *gen(output_index); + b_.CreateStore(generated, output_address); } else { CHECK_EQ(values.size(), 1); + CHECK_EQ(reduction, root) + << "output fusion is not allowed for racing reductions"; + TF_CHECK_OK(EmitAtomicOperationForNestedComputation( *reducer, output_address, output_ptr, type)); } @@ -4429,7 +4493,8 @@ void IrEmitterUnnested::EmitReductionOutputForRowReduction( const TilingKernelInfo& tiling_kernel_info, const ReductionCodegenState& reduction_codegen_state, llvm::Type* index_ty, const ReductionOutputMap& output_arrays, - const HloReduceInstruction* reduction, int partial_result_idx) { + const HloReduceInstruction* reduction, const HloInstruction* root, + int partial_result_idx) { const HloComputation* reducer = reduction->to_apply(); const auto& thread_id_info = tiling_kernel_info.thread_id_info; auto constant = [&](uint64_t c) -> llvm::Constant* { @@ -4455,8 +4520,8 @@ void IrEmitterUnnested::EmitReductionOutputForRowReduction( int reduced_dimension_size = tiling_scheme.GetDimsInElems()[2]; int num_rows_per_warp = RowReductionGetRowsPerWarp(reduced_dimension_size); EmitFullWarpShuffleDownLoopForReduce( - reducer, absl::MakeSpan(current_outputs), - tiling_scheme.GetNumThreadsPerBlockPhysical(), num_rows_per_warp); + reducer, current_outputs, tiling_scheme.GetNumThreadsPerBlockPhysical(), + num_rows_per_warp); KernelSupportLibrary ksl(&b_); llvm::Value* warp_id = @@ -4466,7 +4531,7 @@ void IrEmitterUnnested::EmitReductionOutputForRowReduction( const absl::Span values) { ksl.If("reduction_write_output", write_condition, [&] { WriteReductionOutput(index_ty, reduction_codegen_state, - tiling_kernel_info, output_arrays, reduction, + tiling_kernel_info, output_arrays, reduction, root, partial_result_idx, values); }); }; @@ -4540,7 +4605,8 @@ void IrEmitterUnnested::EmitReductionOutputForColumnReduction( const TilingKernelInfo& tiling_kernel_info, const ReductionCodegenState& reduction_codegen_state, llvm::Type* index_ty, const ReductionOutputMap& output_arrays, - const HloReduceInstruction* reduction, int partial_result_idx) { + const HloReduceInstruction* reduction, const HloInstruction* root, + int partial_result_idx) { KernelSupportLibrary ksl(&b_); const HloComputation* reducer = reduction->to_apply(); const auto& thread_id_info = tiling_kernel_info.thread_id_info; @@ -4617,7 +4683,8 @@ void IrEmitterUnnested::EmitReductionOutputForColumnReduction( b_.CreateAnd(has_output, is_zero(thread_id_info.lane_id)), [&] { WriteReductionOutput(index_ty, reduction_codegen_state, tiling_kernel_info, output_arrays, reduction, - partial_result_idx, shmem_transposed_addrs); + root, partial_result_idx, + shmem_transposed_addrs); }); } @@ -4787,9 +4854,15 @@ Status IrEmitterUnnested::EmitTransposeTile( absl::Span operand_arrays, absl::Span output_arrays, const TilingScheme& tiling_scheme, - const LaunchDimensions& launch_dimensions) { + const LaunchDimensions& launch_dimensions, + const FusionHero& hero) { std::vector hlo_roots = GetFusionRoots(fusion_hlo); FusedIrEmitter fused_emitter(elemental_emitter_); + const HloInstruction* first_transpose = hero.hlo; + const Shape& out_shape = first_transpose->shape(); + const Shape& transpose_in_shape = first_transpose->operand(0)->shape(); + + for (int i = 0; i < fusion_hlo->num_parameters(); i++) { llvm_ir::IrArray ir_array = operand_arrays[i]; HloInstruction* fused_operand = fusion_hlo->parameter_instruction(i); @@ -4804,13 +4877,12 @@ Status 
IrEmitterUnnested::EmitTransposeTile( absl::flat_hash_map tiles; Vector3 permutation; for (const auto& [tile_idx, root] : llvm::enumerate(hlo_roots)) { - if (auto tr = FindAnyTiledTranspose(*root)) { - permutation = tr->permutation; - const HloInstruction& hero = FindNonTrivialHero(*root); - tiles[&hero] = + FusionHero hero = FindRealHero(*root); + if (hero.type == FusionHero::kTranspose) { + tiles[hero.hlo] = AllocateShared(tiling_scheme, llvm_ir::PrimitiveTypeToIrType( - hero.operand(0)->shape().element_type(), module_), + hero.hlo->operand(0)->shape().element_type(), module_), {tiling_scheme.GetBlockTileSizeFor(permutation[kDimX]), tiling_scheme.GetBlockTileSizeFor(kDimX) + 1}, absl::StrCat("tr_tile_", tile_idx)); @@ -4834,16 +4906,16 @@ Status IrEmitterUnnested::EmitTransposeTile( scheduled_writes; for (const auto& [output_idx, root] : llvm::enumerate(hlo_roots)) { - if (FindAnyTiledTranspose(*root)) { - const HloInstruction& hero = FindNonTrivialHero(*root); + FusionHero hero = FindRealHero(*root); + if (hero.type == FusionHero::kTranspose) { llvm_ir::ElementGenerator input_gen = - *fused_emitter.GetGenerator(*hero.operand(0)); + *fused_emitter.GetGenerator(*hero.hlo->operand(0)); IrArray::Index untiled_index = GetUnnormalizedIndex(index, hero.operand(0)->shape(), &b_, tiling_scheme.GetDimsInElems()); llvm::Value* value = *input_gen(untiled_index); llvm::Value* addr = thread_id_info.GEPIntoSharedMemory( - &b_, tiles[&hero], {y_loc, x_loc}); + &b_, tiles[hero.hlo], {y_loc, x_loc}); b_.CreateStore(value, addr); } else { @@ -4876,8 +4948,9 @@ Status IrEmitterUnnested::EmitTransposeTile( const llvm_ir::IrArray::Index& index, llvm::Value* y_loc, llvm::Value* x_loc) { for (const auto& [output_idx, root] : llvm::enumerate(hlo_roots)) { - if (FindAnyTiledTranspose(*root)) { - const HloInstruction& hero = FindNonTrivialHero(*root); + FusionHero fh = FindRealHero(*root); + const HloInstruction& hero = *fh.hlo; + if (fh.type == FusionHero::kTranspose) { std::vector idx = {x_loc, y_loc}; llvm::Value* gep = @@ -5024,11 +5097,12 @@ static bool IsUnrollingColumnReductionBeneficial( std::vector hlo_roots = GetFusionRoots(fused_computation); for (int i = 0; i < fusion_roots.size(); i++) { - if (!reduction_is_race_free && - IsReductionFromOrToContiguousDimensions(*hlo_roots[i])) { - // Atomics cannot be vectorized. + FusionHero fh = FindRealHero(*hlo_roots[i]); + if (fh.type == FusionHero::kReduction) { + // Atomic.add of the reduction result can't be vectorized. cannot_be_vectorized++; } else { + // Write of the non-reduction result can be vectorized. 
can_be_vectorized++; } use_chain_endings.insert(fusion_roots[i]); @@ -5204,8 +5278,9 @@ static int64_t ProjectedShmemUsageBytes( for (const std::vector& group : instr_index_groups) { int64_t sum = 0; for (HloInstruction* root : group) { - if (IsReductionFromOrToContiguousDimensions(*root)) { - sum += SharedMemoryUsage(*root); + FusionHero fh = FindRealHero(*root); + if (fh.type == FusionHero::kReduction) { + sum += SharedMemoryUsage(*fh.hlo); } } out = std::max(out, sum); @@ -5215,7 +5290,7 @@ static int64_t ProjectedShmemUsageBytes( StatusOr IrEmitterUnnested::ComputeReductionCodegenInfo( mlir::lmhlo::FusionOp fusion, HloComputation* fused_computation, - HloInstruction* first_reduce, + const HloInstruction* first_reduce, const std::vector>& instr_index_groups) { Shape input_shape = first_reduce->operand(0)->shape(); ReductionDimensions reduction_dimensions = @@ -5385,18 +5460,22 @@ Status IrEmitterUnnested::EmitIRForReduction( absl::Span instr_index_group, FusedIrEmitter& fused_emitter, const ReductionOutputMap& result_ir_arrays, const ReductionCodegenInfo& reduction_info, const Shape& input_shape) { - std::vector reductions; + std::vector roots; + std::vector heros; ExtraOutputGensMap extra_output_gens; for (const HloInstruction* hlo : instr_index_group) { - if (IsReductionFromOrToContiguousDimensions(*hlo)) { - reductions.push_back(Cast(hlo)); + FusionHero fh = FindRealHero(*hlo); + if (fh.type == FusionHero::kReduction) { + auto hero = Cast(fh.hlo); + roots.push_back(hlo); + heros.push_back(hero); } else { extra_output_gens[hlo] = *fused_emitter.GetGenerator(*hlo); } } - CHECK(!reductions.empty()) << " expect at least one reduce instructions."; + CHECK(!roots.empty()) << " expect at least one reduce instructions."; const TilingScheme& tiling_scheme = reduction_info.GetTilingScheme(); CHECK_EQ(tiling_scheme.GetNumThreadsPerBlockPhysical() % WarpSize(), 0); llvm::Type* index_ty = @@ -5405,7 +5484,7 @@ Status IrEmitterUnnested::EmitIRForReduction( tiling_scheme.GetNumberOfBlocksPhysical(), &b_); ReductionCodegenState codegen_state = GenerateReductionCodegenState( - fusion, reduction_info, reductions, fused_emitter); + fusion, reduction_info, heros, fused_emitter); EmitElementFunction emit_reduction_element = [&](const ThreadIdInfo& thread_id_info, const IrArray::Index& index, @@ -5431,7 +5510,7 @@ Status IrEmitterUnnested::EmitIRForReduction( // Emit code to generate the input and perform the reduction computation // for each reduction instruction. 
- for (const HloReduceInstruction* reduce : reductions) { + for (const HloReduceInstruction* reduce : heros) { GenerateElementForReducer(reduce, partial_result_index, codegen_state, index_without_linear, input_index, num_partial_results, result_ir_arrays); @@ -5455,18 +5534,18 @@ Status IrEmitterUnnested::EmitIRForReduction( })); KernelSupportLibrary ksl(&b_); - for (const HloReduceInstruction* reduce : reductions) { + for (auto [reduce, root] : llvm::zip(heros, roots)) { for (int partial_result_idx = 0; partial_result_idx < reduction_info.GetNumPartialResults(); ++partial_result_idx) { if (codegen_state.IsRowReduction()) { EmitReductionOutputForRowReduction(tiling_kernel_info, codegen_state, index_ty, result_ir_arrays, reduce, - partial_result_idx); + root, partial_result_idx); } else { EmitReductionOutputForColumnReduction(tiling_kernel_info, codegen_state, index_ty, result_ir_arrays, - reduce, partial_result_idx); + reduce, root, partial_result_idx); } } } @@ -5505,7 +5584,8 @@ std::vector> GroupDisjointReductions( for (HloInstruction* root : roots) { disjoint_sets[root].Get() = root; - if (!IsReductionFromOrToContiguousDimensions(*root)) { + FusionHero fh = FindRealHero(*root); + if (fh.type != FusionHero::kReduction) { if (!first_non_reduction_root) { first_non_reduction_root = root; } else { @@ -5520,7 +5600,8 @@ std::vector> GroupDisjointReductions( std::vector reached_output_ids; bool added_to_reduce = false; for (HloInstruction* output : roots) { - if (IsReductionFromOrToContiguousDimensions(*output) && + FusionHero fh = FindRealHero(*output); + if (fh.type == FusionHero::kReduction && (hlo_query::IsBroadcastedConstantOrScalar(*instr))) { if (added_to_reduce) { // Do not group more than one output reduce instructions through @@ -5536,7 +5617,7 @@ std::vector> GroupDisjointReductions( VLOG(3) << "Reaching " << output->ToString() << " from " << instr->ToString(); reached_output_ids.push_back(output); - if (IsReductionFromOrToContiguousDimensions(*output)) { + if (fh.type == FusionHero::kReduction) { added_to_reduce = true; } } @@ -5562,7 +5643,8 @@ std::vector> GroupDisjointReductions( } // namespace Status IrEmitterUnnested::EmitUnnestedReduction( - mlir::lmhlo::FusionOp fusion, HloComputation* fused_computation) { + mlir::lmhlo::FusionOp fusion, HloComputation* fused_computation, + const FusionHero& hero) { llvm::SmallVector fusion_roots = fusion.getFusionRoots(); // Group disjoint reductions in groups, to be executed in parallel. @@ -5574,10 +5656,7 @@ Status IrEmitterUnnested::EmitUnnestedReduction( // hlo_roots has same ordering as fusion_roots. 
auto hlo_roots = GetFusionRoots(fused_computation); - HloInstruction* first_reduce = - *absl::c_find_if(hlo_roots, [](HloInstruction* instr) { - return IsReductionFromOrToContiguousDimensions(*instr); - }); + const HloInstruction* first_reduce = hero.hlo; // We always use the first reduce as representative to construct // ReductionCodegenInfo, since all the reductions are required to have the @@ -5600,7 +5679,8 @@ Status IrEmitterUnnested::EmitUnnestedReduction( << launch_dimensions.ToString(); if (!reduction_codegen_info.IsRaceFree()) { for (int i = 0; i < fusion_roots.size(); ++i) { - if (IsReductionFromOrToContiguousDimensions(*hlo_roots[i])) { + FusionHero fh = FindRealHero(*hlo_roots[i]); + if (fh.type == FusionHero::kReduction) { TF_RETURN_IF_ERROR(BuildFusedInitializerThunk(fusion, i)); } } diff --git a/xla/service/gpu/ir_emitter_unnested.h b/xla/service/gpu/ir_emitter_unnested.h index b769b089d5b4c..5e9eab7965b55 100644 --- a/xla/service/gpu/ir_emitter_unnested.h +++ b/xla/service/gpu/ir_emitter_unnested.h @@ -469,7 +469,8 @@ class IrEmitterUnnested : public IrEmitter { // instructions. In other words, a block_id_y is assigned to a group and so // different groups can be run in parallel. Status EmitUnnestedReduction(mlir::lmhlo::FusionOp fusion, - HloComputation* fused_computation); + HloComputation* fused_computation, + const FusionHero& hero); // Emits a kernel for the given hlo instruction using a tiled 0-2-1 transpose // algorithm to improve the memory access patterns for the input parameters @@ -492,11 +493,9 @@ class IrEmitterUnnested : public IrEmitter { // // `kTileSize` should usually be same as warp size. We currently choose 32 for // `kTileSize` and 4 for `kNumRows`. The CUDA algorithm uses 8 for `kNumRows`. - // - // TODO(b/33320379): Here each block transposes 1 tile. It may be more - // efficient to launch fewer blocks so each transposes many tiles. Status EmitUnnestedTranspose(mlir::lmhlo::FusionOp fusion, - HloComputation* fused_computation); + HloComputation* fused_computation, + const FusionHero& hero); // Computes the KernelMappingScheme for the reduce HLO and indicates whether // the reduction is a row reduction. For an un-fused reduce op, unnested_hlo @@ -505,7 +504,7 @@ class IrEmitterUnnested : public IrEmitter { // reduce op. StatusOr ComputeReductionCodegenInfo( mlir::lmhlo::FusionOp fusion, HloComputation* fused_computation, - HloInstruction* first_reduce, + const HloInstruction* first_reduce, const std::vector>& instr_index_groups); // Generates code for input-fusible slices. @@ -563,7 +562,8 @@ class IrEmitterUnnested : public IrEmitter { absl::Span operand_arrays, absl::Span output_arrays, const TilingScheme& tiling_scheme, - const LaunchDimensions& launch_dimensions); + const LaunchDimensions& launch_dimensions, + const FusionHero& hero); Status EmitScatter(mlir::lmhlo::FusionOp fusion_op, const HloComputation* fused_computation); @@ -635,12 +635,13 @@ class IrEmitterUnnested : public IrEmitter { const TilingKernelInfo& tiling_kernel_info); // Returns the address to write the reduction output to. 
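+  // (It now returns the output index; the caller emits the actual element address.)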
- llvm::Value* GetOutputAddressForReduction( + llvm_ir::IrArray::Index GetOutputAddressForReduction( int partial_result_idx, llvm::Type* index_ty, const ReductionCodegenState& reduction_codegen_state, const TilingKernelInfo& tiling_kernel_info, const ReductionOutputMap& output_arrays, - const HloReduceInstruction* reduction, int output_idx); + const HloReduceInstruction* reduction, const HloInstruction* root, + int output_idx); // Performs the actual write of the reduction result. using TypedPointer = std::pair; @@ -649,8 +650,8 @@ class IrEmitterUnnested : public IrEmitter { const ReductionCodegenState& reduction_codegen_state, const TilingKernelInfo& tiling_kernel_info, const ReductionOutputMap& output_arrays, - const HloReduceInstruction* reduction, int partial_result_idx, - const absl::Span values); + const HloReduceInstruction* reduction, const HloInstruction* root, + int partial_result_idx, const absl::Span values); // `current_output`: the value the tile has calculated. // `output_address`: address where the output value has to be written. @@ -658,14 +659,16 @@ class IrEmitterUnnested : public IrEmitter { const TilingKernelInfo& tiling_kernel_info, const ReductionCodegenState& reduction_codegen_state, llvm::Type* index_ty, const ReductionOutputMap& output_arrays, - const HloReduceInstruction* reduction, int partial_result_idx); + const HloReduceInstruction* reduction, const HloInstruction* root, + int partial_result_idx); // Same arguments as EmitReductionOutputForRowReduction. void EmitReductionOutputForColumnReduction( const TilingKernelInfo& tiling_kernel_info, const ReductionCodegenState& reduction_codegen_state, llvm::Type* index_ty, const ReductionOutputMap& output_arrays, - const HloReduceInstruction* reduction, int partial_result_idx); + const HloReduceInstruction* reduction, const HloInstruction* root, + int partial_result_idx); // Emits code for reductions in the output_instructions. Status EmitIRForReduction(mlir::lmhlo::FusionOp fusion,