- 
                Notifications
    You must be signed in to change notification settings 
- Fork 15k
[LLVM][ConstantFolding] Extend constantFoldVectorReduce to include scalable vectors. #165437
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
[LLVM][ConstantFolding] Extend constantFoldVectorReduce to include scalable vectors. #165437
Conversation
| @llvm/pr-subscribers-llvm-transforms @llvm/pr-subscribers-llvm-analysis Author: Paul Walker (paulwalker-arm) ChangesFull diff: https://github.com/llvm/llvm-project/pull/165437.diff 2 Files Affected: 
 diff --git a/llvm/lib/Analysis/ConstantFolding.cpp b/llvm/lib/Analysis/ConstantFolding.cpp
index e9e2e7d0316c7..bd84bd1a14113 100755
--- a/llvm/lib/Analysis/ConstantFolding.cpp
+++ b/llvm/lib/Analysis/ConstantFolding.cpp
@@ -2163,18 +2163,39 @@ Constant *ConstantFoldBinaryFP(double (*NativeFP)(double, double),
 }
 
 Constant *constantFoldVectorReduce(Intrinsic::ID IID, Constant *Op) {
-  FixedVectorType *VT = dyn_cast<FixedVectorType>(Op->getType());
-  if (!VT)
-    return nullptr;
-
-  // This isn't strictly necessary, but handle the special/common case of zero:
-  // all integer reductions of a zero input produce zero.
-  if (isa<ConstantAggregateZero>(Op))
-    return ConstantInt::get(VT->getElementType(), 0);
+  auto *OpVT = cast<VectorType>(Op->getType());
 
   // This is the same as the underlying binops - poison propagates.
   if (isa<PoisonValue>(Op) || Op->containsPoisonElement())
-    return PoisonValue::get(VT->getElementType());
+    return PoisonValue::get(OpVT->getElementType());
+
+  // Shortcut non-accumulating reductions.
+  if (Constant *SplatVal = Op->getSplatValue()) {
+    switch (IID) {
+    case Intrinsic::vector_reduce_and:
+    case Intrinsic::vector_reduce_or:
+    case Intrinsic::vector_reduce_smin:
+    case Intrinsic::vector_reduce_smax:
+    case Intrinsic::vector_reduce_umin:
+    case Intrinsic::vector_reduce_umax:
+      return SplatVal;
+    case Intrinsic::vector_reduce_add:
+    case Intrinsic::vector_reduce_mul:
+      if (SplatVal->isZeroValue())
+        return SplatVal;
+      break;
+    case Intrinsic::vector_reduce_xor:
+      if (SplatVal->isZeroValue())
+        return SplatVal;
+      if (OpVT->getElementCount().isKnownMultipleOf(2))
+        return Constant::getNullValue(OpVT->getElementType());
+      break;
+    }
+  }
+
+  FixedVectorType *VT = dyn_cast<FixedVectorType>(OpVT);
+  if (!VT)
+    return nullptr;
 
   // TODO: Handle undef.
   auto *EltC = dyn_cast_or_null<ConstantInt>(Op->getAggregateElement(0U));
diff --git a/llvm/test/Transforms/InstSimplify/ConstProp/vecreduce.ll b/llvm/test/Transforms/InstSimplify/ConstProp/vecreduce.ll
index 77a7f0d4e4acf..1ba8b0eff1d1a 100644
--- a/llvm/test/Transforms/InstSimplify/ConstProp/vecreduce.ll
+++ b/llvm/test/Transforms/InstSimplify/ConstProp/vecreduce.ll
@@ -12,8 +12,7 @@ define i32 @add_0() {
 
 define i32 @add_0_scalable_vector() {
 ; CHECK-LABEL: @add_0_scalable_vector(
-; CHECK-NEXT:    [[X:%.*]] = call i32 @llvm.vector.reduce.add.nxv8i32(<vscale x 8 x i32> zeroinitializer)
-; CHECK-NEXT:    ret i32 [[X]]
+; CHECK-NEXT:    ret i32 0
 ;
   %x = call i32 @llvm.vector.reduce.add.nxv8i32(<vscale x 8 x i32> zeroinitializer)
   ret i32 %x
@@ -89,8 +88,7 @@ define i32 @add_poison() {
 
 define i32 @add_poison_scalable_vector() {
 ; CHECK-LABEL: @add_poison_scalable_vector(
-; CHECK-NEXT:    [[X:%.*]] = call i32 @llvm.vector.reduce.add.nxv8i32(<vscale x 8 x i32> poison)
-; CHECK-NEXT:    ret i32 [[X]]
+; CHECK-NEXT:    ret i32 poison
 ;
   %x = call i32 @llvm.vector.reduce.add.nxv8i32(<vscale x 8 x i32> poison)
   ret i32 %x
@@ -123,8 +121,7 @@ define i32 @mul_0() {
 
 define i32 @mul_0_scalable_vector() {
 ; CHECK-LABEL: @mul_0_scalable_vector(
-; CHECK-NEXT:    [[X:%.*]] = call i32 @llvm.vector.reduce.mul.nxv8i32(<vscale x 8 x i32> zeroinitializer)
-; CHECK-NEXT:    ret i32 [[X]]
+; CHECK-NEXT:    ret i32 0
 ;
   %x = call i32 @llvm.vector.reduce.mul.nxv8i32(<vscale x 8 x i32> zeroinitializer)
   ret i32 %x
@@ -200,8 +197,7 @@ define i32 @mul_poison() {
 
 define i32 @mul_poison_scalable_vector() {
 ; CHECK-LABEL: @mul_poison_scalable_vector(
-; CHECK-NEXT:    [[X:%.*]] = call i32 @llvm.vector.reduce.mul.nxv8i32(<vscale x 8 x i32> poison)
-; CHECK-NEXT:    ret i32 [[X]]
+; CHECK-NEXT:    ret i32 poison
 ;
   %x = call i32 @llvm.vector.reduce.mul.nxv8i32(<vscale x 8 x i32> poison)
   ret i32 %x
@@ -225,8 +221,7 @@ define i32 @and_0() {
 
 define i32 @and_0_scalable_vector() {
 ; CHECK-LABEL: @and_0_scalable_vector(
-; CHECK-NEXT:    [[X:%.*]] = call i32 @llvm.vector.reduce.and.nxv8i32(<vscale x 8 x i32> zeroinitializer)
-; CHECK-NEXT:    ret i32 [[X]]
+; CHECK-NEXT:    ret i32 0
 ;
   %x = call i32 @llvm.vector.reduce.and.nxv8i32(<vscale x 8 x i32> zeroinitializer)
   ret i32 %x
@@ -242,8 +237,7 @@ define i32 @and_1() {
 
 define i32 @and_1_scalable_vector() {
 ; CHECK-LABEL: @and_1_scalable_vector(
-; CHECK-NEXT:    [[X:%.*]] = call i32 @llvm.vector.reduce.and.nxv8i32(<vscale x 8 x i32> splat (i32 1))
-; CHECK-NEXT:    ret i32 [[X]]
+; CHECK-NEXT:    ret i32 1
 ;
   %x = call i32 @llvm.vector.reduce.and.nxv8i32(<vscale x 8 x i32> splat (i32 1))
   ret i32 %x
@@ -302,8 +296,7 @@ define i32 @and_poison() {
 
 define i32 @and_poison_scalable_vector() {
 ; CHECK-LABEL: @and_poison_scalable_vector(
-; CHECK-NEXT:    [[X:%.*]] = call i32 @llvm.vector.reduce.and.nxv8i32(<vscale x 8 x i32> poison)
-; CHECK-NEXT:    ret i32 [[X]]
+; CHECK-NEXT:    ret i32 poison
 ;
   %x = call i32 @llvm.vector.reduce.and.nxv8i32(<vscale x 8 x i32> poison)
   ret i32 %x
@@ -327,8 +320,7 @@ define i32 @or_0() {
 
 define i32 @or_0_scalable_vector() {
 ; CHECK-LABEL: @or_0_scalable_vector(
-; CHECK-NEXT:    [[X:%.*]] = call i32 @llvm.vector.reduce.or.nxv8i32(<vscale x 8 x i32> zeroinitializer)
-; CHECK-NEXT:    ret i32 [[X]]
+; CHECK-NEXT:    ret i32 0
 ;
   %x = call i32 @llvm.vector.reduce.or.nxv8i32(<vscale x 8 x i32> zeroinitializer)
   ret i32 %x
@@ -344,8 +336,7 @@ define i32 @or_1() {
 
 define i32 @or_1_scalable_vector() {
 ; CHECK-LABEL: @or_1_scalable_vector(
-; CHECK-NEXT:    [[X:%.*]] = call i32 @llvm.vector.reduce.or.nxv8i32(<vscale x 8 x i32> splat (i32 1))
-; CHECK-NEXT:    ret i32 [[X]]
+; CHECK-NEXT:    ret i32 1
 ;
   %x = call i32 @llvm.vector.reduce.or.nxv8i32(<vscale x 8 x i32> splat (i32 1))
   ret i32 %x
@@ -404,8 +395,7 @@ define i32 @or_poison() {
 
 define i32 @or_poison_scalable_vector() {
 ; CHECK-LABEL: @or_poison_scalable_vector(
-; CHECK-NEXT:    [[X:%.*]] = call i32 @llvm.vector.reduce.or.nxv8i32(<vscale x 8 x i32> poison)
-; CHECK-NEXT:    ret i32 [[X]]
+; CHECK-NEXT:    ret i32 poison
 ;
   %x = call i32 @llvm.vector.reduce.or.nxv8i32(<vscale x 8 x i32> poison)
   ret i32 %x
@@ -429,8 +419,7 @@ define i32 @xor_0() {
 
 define i32 @xor_0_scalable_vector() {
 ; CHECK-LABEL: @xor_0_scalable_vector(
-; CHECK-NEXT:    [[X:%.*]] = call i32 @llvm.vector.reduce.xor.nxv8i32(<vscale x 8 x i32> zeroinitializer)
-; CHECK-NEXT:    ret i32 [[X]]
+; CHECK-NEXT:    ret i32 0
 ;
   %x = call i32 @llvm.vector.reduce.xor.nxv8i32(<vscale x 8 x i32> zeroinitializer)
   ret i32 %x
@@ -446,13 +435,21 @@ define i32 @xor_1() {
 
 define i32 @xor_1_scalable_vector() {
 ; CHECK-LABEL: @xor_1_scalable_vector(
-; CHECK-NEXT:    [[X:%.*]] = call i32 @llvm.vector.reduce.xor.nxv8i32(<vscale x 8 x i32> splat (i32 1))
-; CHECK-NEXT:    ret i32 [[X]]
+; CHECK-NEXT:    ret i32 0
 ;
   %x = call i32 @llvm.vector.reduce.xor.nxv8i32(<vscale x 8 x i32> splat(i32 1))
   ret i32 %x
 }
 
+define i32 @xor_1_scalable_vector_lane_count_not_known_even() {
+; CHECK-LABEL: @xor_1_scalable_vector_lane_count_not_known_even(
+; CHECK-NEXT:    [[X:%.*]] = call i32 @llvm.vector.reduce.xor.nxv1i32(<vscale x 1 x i32> splat (i32 1))
+; CHECK-NEXT:    ret i32 [[X]]
+;
+  %x = call i32 @llvm.vector.reduce.xor.nxv8i32(<vscale x 1 x i32> splat(i32 1))
+  ret i32 %x
+}
+
 define i32 @xor_inc() {
 ; CHECK-LABEL: @xor_inc(
 ; CHECK-NEXT:    ret i32 10
@@ -506,8 +503,7 @@ define i32 @xor_poison() {
 
 define i32 @xor_poison_scalable_vector() {
 ; CHECK-LABEL: @xor_poison_scalable_vector(
-; CHECK-NEXT:    [[X:%.*]] = call i32 @llvm.vector.reduce.xor.nxv8i32(<vscale x 8 x i32> poison)
-; CHECK-NEXT:    ret i32 [[X]]
+; CHECK-NEXT:    ret i32 poison
 ;
   %x = call i32 @llvm.vector.reduce.xor.nxv8i32(<vscale x 8 x i32> poison)
   ret i32 %x
@@ -531,8 +527,7 @@ define i32 @smin_0() {
 
 define i32 @smin_0_scalable_vector() {
 ; CHECK-LABEL: @smin_0_scalable_vector(
-; CHECK-NEXT:    [[X:%.*]] = call i32 @llvm.vector.reduce.smin.nxv8i32(<vscale x 8 x i32> zeroinitializer)
-; CHECK-NEXT:    ret i32 [[X]]
+; CHECK-NEXT:    ret i32 0
 ;
   %x = call i32 @llvm.vector.reduce.smin.nxv8i32(<vscale x 8 x i32> zeroinitializer)
   ret i32 %x
@@ -548,8 +543,7 @@ define i32 @smin_1() {
 
 define i32 @smin_1_scalable_vector() {
 ; CHECK-LABEL: @smin_1_scalable_vector(
-; CHECK-NEXT:    [[X:%.*]] = call i32 @llvm.vector.reduce.smin.nxv8i32(<vscale x 8 x i32> splat (i32 1))
-; CHECK-NEXT:    ret i32 [[X]]
+; CHECK-NEXT:    ret i32 1
 ;
   %x = call i32 @llvm.vector.reduce.smin.nxv8i32(<vscale x 8 x i32> splat(i32 1))
   ret i32 %x
@@ -608,8 +602,7 @@ define i32 @smin_poison() {
 
 define i32 @smin_poison_scalable_vector() {
 ; CHECK-LABEL: @smin_poison_scalable_vector(
-; CHECK-NEXT:    [[X:%.*]] = call i32 @llvm.vector.reduce.smin.nxv8i32(<vscale x 8 x i32> poison)
-; CHECK-NEXT:    ret i32 [[X]]
+; CHECK-NEXT:    ret i32 poison
 ;
   %x = call i32 @llvm.vector.reduce.smin.nxv8i32(<vscale x 8 x i32> poison)
   ret i32 %x
@@ -633,8 +626,7 @@ define i32 @smax_0() {
 
 define i32 @smax_0_scalable_vector() {
 ; CHECK-LABEL: @smax_0_scalable_vector(
-; CHECK-NEXT:    [[X:%.*]] = call i32 @llvm.vector.reduce.smax.nxv8i32(<vscale x 8 x i32> zeroinitializer)
-; CHECK-NEXT:    ret i32 [[X]]
+; CHECK-NEXT:    ret i32 0
 ;
   %x = call i32 @llvm.vector.reduce.smax.nxv8i32(<vscale x 8 x i32> zeroinitializer)
   ret i32 %x
@@ -650,8 +642,7 @@ define i32 @smax_1() {
 
 define i32 @smax_1_scalable_vector() {
 ; CHECK-LABEL: @smax_1_scalable_vector(
-; CHECK-NEXT:    [[X:%.*]] = call i32 @llvm.vector.reduce.smax.nxv8i32(<vscale x 8 x i32> splat (i32 1))
-; CHECK-NEXT:    ret i32 [[X]]
+; CHECK-NEXT:    ret i32 1
 ;
   %x = call i32 @llvm.vector.reduce.smax.nxv8i32(<vscale x 8 x i32> splat(i32 1))
   ret i32 %x
@@ -710,8 +701,7 @@ define i32 @smax_poison() {
 
 define i32 @smax_poison_scalable_vector() {
 ; CHECK-LABEL: @smax_poison_scalable_vector(
-; CHECK-NEXT:    [[X:%.*]] = call i32 @llvm.vector.reduce.smax.nxv8i32(<vscale x 8 x i32> poison)
-; CHECK-NEXT:    ret i32 [[X]]
+; CHECK-NEXT:    ret i32 poison
 ;
   %x = call i32 @llvm.vector.reduce.smax.nxv8i32(<vscale x 8 x i32> poison)
   ret i32 %x
@@ -735,8 +725,7 @@ define i32 @umin_0() {
 
 define i32 @umin_0_scalable_vector() {
 ; CHECK-LABEL: @umin_0_scalable_vector(
-; CHECK-NEXT:    [[X:%.*]] = call i32 @llvm.vector.reduce.umin.nxv8i32(<vscale x 8 x i32> zeroinitializer)
-; CHECK-NEXT:    ret i32 [[X]]
+; CHECK-NEXT:    ret i32 0
 ;
   %x = call i32 @llvm.vector.reduce.umin.nxv8i32(<vscale x 8 x i32> zeroinitializer)
   ret i32 %x
@@ -752,8 +741,7 @@ define i32 @umin_1() {
 
 define i32 @umin_1_scalable_vector() {
 ; CHECK-LABEL: @umin_1_scalable_vector(
-; CHECK-NEXT:    [[X:%.*]] = call i32 @llvm.vector.reduce.umin.nxv8i32(<vscale x 8 x i32> splat (i32 1))
-; CHECK-NEXT:    ret i32 [[X]]
+; CHECK-NEXT:    ret i32 1
 ;
   %x = call i32 @llvm.vector.reduce.umin.nxv8i32(<vscale x 8 x i32> splat (i32 1))
   ret i32 %x
@@ -812,8 +800,7 @@ define i32 @umin_poison() {
 
 define i32 @umin_poison_scalable_vector() {
 ; CHECK-LABEL: @umin_poison_scalable_vector(
-; CHECK-NEXT:    [[X:%.*]] = call i32 @llvm.vector.reduce.umin.nxv8i32(<vscale x 8 x i32> poison)
-; CHECK-NEXT:    ret i32 [[X]]
+; CHECK-NEXT:    ret i32 poison
 ;
   %x = call i32 @llvm.vector.reduce.umin.nxv8i32(<vscale x 8 x i32> poison)
   ret i32 %x
@@ -837,8 +824,7 @@ define i32 @umax_0() {
 
 define i32 @umax_0_scalable_vector() {
 ; CHECK-LABEL: @umax_0_scalable_vector(
-; CHECK-NEXT:    [[X:%.*]] = call i32 @llvm.vector.reduce.umax.nxv8i32(<vscale x 8 x i32> zeroinitializer)
-; CHECK-NEXT:    ret i32 [[X]]
+; CHECK-NEXT:    ret i32 0
 ;
   %x = call i32 @llvm.vector.reduce.umax.nxv8i32(<vscale x 8 x i32> zeroinitializer)
   ret i32 %x
@@ -854,8 +840,7 @@ define i32 @umax_1() {
 
 define i32 @umax_1_scalable_vector() {
 ; CHECK-LABEL: @umax_1_scalable_vector(
-; CHECK-NEXT:    [[X:%.*]] = call i32 @llvm.vector.reduce.umax.nxv8i32(<vscale x 8 x i32> splat (i32 1))
-; CHECK-NEXT:    ret i32 [[X]]
+; CHECK-NEXT:    ret i32 1
 ;
   %x = call i32 @llvm.vector.reduce.umax.nxv8i32(<vscale x 8 x i32> splat(i32 1))
   ret i32 %x
@@ -914,8 +899,7 @@ define i32 @umax_poison() {
 
 define i32 @umax_poison_scalable_vector() {
 ; CHECK-LABEL: @umax_poison_scalable_vector(
-; CHECK-NEXT:    [[X:%.*]] = call i32 @llvm.vector.reduce.umax.nxv8i32(<vscale x 8 x i32> poison)
-; CHECK-NEXT:    ret i32 [[X]]
+; CHECK-NEXT:    ret i32 poison
 ;
   %x = call i32 @llvm.vector.reduce.umax.nxv8i32(<vscale x 8 x i32> poison)
   ret i32 %x
 | 
| auto *OpVT = cast<VectorType>(Op->getType()); | ||
|  | ||
| // This is the same as the underlying binops - poison propagates. | ||
| if (isa<PoisonValue>(Op) || Op->containsPoisonElement()) | 
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Interestingly, containsUndefinedElement doesn't try as hard as it could to find poison values in scalable vectors. For example, it could check if there is an insert of poison somewhere.
This is just an observation - no expectation from you to do anything!
However, it looks like the additional call to isa<PoisonValue>(Op) is redundant since containsPoisonElement asks the same question. Since you're here perhaps worth removing?
| return SplatVal; | ||
| case Intrinsic::vector_reduce_add: | ||
| case Intrinsic::vector_reduce_mul: | ||
| if (SplatVal->isZeroValue()) | 
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
In theory you can also do a short cut for a mul reduce with a splat value of 1, if you think it's worth adding?
…nxv##(1). Patch also uses isNullValue() rather than isZeroValue(), which is less relevant for integer types.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM. Thanks.
No description provided.