Skip to content

Commit 9b7db93

Browse files
committed
[GR-68519] Vector API: Implement VectorMask::compress
PullRequest: graal/21779
2 parents c3fc9c4 + a9426d2 commit 9b7db93

File tree

1 file changed

+16
-6
lines changed

1 file changed

+16
-6
lines changed

compiler/src/jdk.graal.compiler/src/jdk/graal/compiler/vector/replacements/vectorapi/nodes/VectorAPICompressExpandOpNode.java

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -36,12 +36,17 @@
3636
import jdk.graal.compiler.graph.NodeMap;
3737
import jdk.graal.compiler.nodeinfo.NodeInfo;
3838
import jdk.graal.compiler.nodes.FrameState;
39+
import jdk.graal.compiler.nodes.NodeView;
3940
import jdk.graal.compiler.nodes.ValueNode;
41+
import jdk.graal.compiler.nodes.calc.CompressBitsNode;
4042
import jdk.graal.compiler.nodes.spi.Canonicalizable;
4143
import jdk.graal.compiler.nodes.spi.CanonicalizerTool;
4244
import jdk.graal.compiler.nodes.spi.CoreProviders;
4345
import jdk.graal.compiler.replacements.nodes.MacroNode.MacroParams;
4446
import jdk.graal.compiler.vector.architecture.VectorArchitecture;
47+
import jdk.graal.compiler.vector.nodes.amd64.IntegerToOpMaskNode;
48+
import jdk.graal.compiler.vector.nodes.amd64.OpMaskToIntegerNode;
49+
import jdk.graal.compiler.vector.nodes.simd.LogicValueStamp;
4550
import jdk.graal.compiler.vector.nodes.simd.SimdConstant;
4651
import jdk.graal.compiler.vector.nodes.simd.SimdCompressNode;
4752
import jdk.graal.compiler.vector.nodes.simd.SimdExpandNode;
@@ -103,7 +108,7 @@ private ValueNode mask() {
103108

104109
@Override
105110
public Iterable<ValueNode> vectorInputs() {
106-
return List.of(source(), mask());
111+
return source().isNullConstant() ? List.of(mask()) : List.of(source(), mask());
107112
}
108113

109114
@Override
@@ -121,7 +126,7 @@ public Node canonical(CanonicalizerTool tool) {
121126

122127
ValueNode[] args = toArgumentArray();
123128
ObjectStamp newSpeciesStamp = improveResultBoxStamp(tool);
124-
SimdStamp newVectorStamp = improveVectorStamp(vectorStamp, args, VCLASS_ARG_INDEX, ECLASS_ARG_INDEX, LENGTH_ARG_INDEX, tool);
129+
SimdStamp newVectorStamp = improveResultStamp(vectorStamp, args, tool);
125130
if (newSpeciesStamp != speciesStamp || newVectorStamp != vectorStamp) {
126131
return new VectorAPICompressExpandOpNode(copyParamsWithImprovedStamp(newSpeciesStamp), newVectorStamp, null, stateAfter());
127132
}
@@ -138,7 +143,7 @@ public boolean canExpand(VectorArchitecture vectorArch, EconomicMap<ValueNode, S
138143
int opr = opr().asJavaConstant().asInt();
139144
GraalError.guarantee(opr == COMPRESS_OP || opr == EXPAND_OP || opr == MASK_COMPRESS_OP, "%d", opr);
140145
if (opr == MASK_COMPRESS_OP) {
141-
return false;
146+
return elementStamp instanceof LogicValueStamp;
142147
} else {
143148
return vectorArch.getSupportedVectorCompressExpandLength(elementStamp, vectorStamp.getVectorLength()) == vectorStamp.getVectorLength();
144149
}
@@ -147,13 +152,18 @@ public boolean canExpand(VectorArchitecture vectorArch, EconomicMap<ValueNode, S
147152
@Override
148153
public ValueNode expand(VectorArchitecture vectorArch, NodeMap<ValueNode> expanded) {
149154
int opr = opr().asJavaConstant().asInt();
150-
ValueNode src = expanded.get(source());
151155
ValueNode mask = expanded.get(mask());
152156
if (opr == COMPRESS_OP) {
157+
ValueNode src = expanded.get(source());
153158
return SimdCompressNode.create(src, mask);
154-
} else {
155-
GraalError.guarantee(opr == EXPAND_OP, "%d", opr);
159+
} else if (opr == EXPAND_OP) {
160+
ValueNode src = expanded.get(source());
156161
return SimdExpandNode.create(src, mask);
162+
} else {
163+
GraalError.guarantee(opr == MASK_COMPRESS_OP, "unexpected opcode %d", opr);
164+
ValueNode maskToInt = OpMaskToIntegerNode.create(mask);
165+
ValueNode compressedInt = new CompressBitsNode(maskToInt, maskToInt);
166+
return new IntegerToOpMaskNode(compressedInt, mask.stamp(NodeView.DEFAULT).unrestricted());
157167
}
158168
}
159169

0 commit comments

Comments
 (0)