package org.graalvm.compiler.nodes.extended;
import static org.graalvm.compiler.nodeinfo.NodeCycles.CYCLES_2;
import static org.graalvm.compiler.nodeinfo.NodeCycles.CYCLES_64;
import static org.graalvm.compiler.nodeinfo.NodeCycles.CYCLES_8;
import static org.graalvm.compiler.nodeinfo.NodeCycles.CYCLES_UNKNOWN;
import static org.graalvm.compiler.nodeinfo.NodeSize.SIZE_2;
import static org.graalvm.compiler.nodeinfo.NodeSize.SIZE_64;
import static org.graalvm.compiler.nodeinfo.NodeSize.SIZE_8;
import static org.graalvm.compiler.nodeinfo.NodeSize.SIZE_UNKNOWN;
import java.util.Arrays;
import org.graalvm.compiler.core.common.type.AbstractPointerStamp;
import org.graalvm.compiler.core.common.type.Stamp;
import org.graalvm.compiler.core.common.type.StampFactory;
import org.graalvm.compiler.debug.GraalError;
import org.graalvm.compiler.graph.Node;
import org.graalvm.compiler.graph.NodeClass;
import org.graalvm.compiler.graph.NodeSuccessorList;
import org.graalvm.compiler.graph.spi.SimplifierTool;
import org.graalvm.compiler.nodeinfo.NodeCycles;
import org.graalvm.compiler.nodeinfo.NodeInfo;
import org.graalvm.compiler.nodeinfo.NodeSize;
import org.graalvm.compiler.nodes.AbstractBeginNode;
import org.graalvm.compiler.nodes.ControlSplitNode;
import org.graalvm.compiler.nodes.NodeView;
import org.graalvm.compiler.nodes.ValueNode;
import jdk.vm.ci.meta.Constant;
@NodeInfo(cycles = CYCLES_UNKNOWN,
cyclesRationale = "We cannot estimate the runtime cost of a switch statement without knowing the number" +
"of case statements and the involved keys.",
size = SIZE_UNKNOWN,
sizeRationale = "We cannot estimate the code size of a switch statement without knowing the number" +
"of case statements.")
public abstract class SwitchNode extends ControlSplitNode {
public static final NodeClass<SwitchNode> TYPE = NodeClass.create(SwitchNode.class);
@Successor protected NodeSuccessorList<AbstractBeginNode> successors;
@Input protected ValueNode value;
protected final double[] keyProbabilities;
protected final int[] keySuccessors;
protected SwitchNode(NodeClass<? extends SwitchNode> c, ValueNode value, AbstractBeginNode[] successors, int[] keySuccessors, double[] keyProbabilities) {
super(c, StampFactory.forVoid());
assert value.stamp(NodeView.DEFAULT).getStackKind().isNumericInteger() || value.stamp(NodeView.DEFAULT) instanceof AbstractPointerStamp : value.stamp(NodeView.DEFAULT) +
" key not supported by SwitchNode";
assert keySuccessors.length == keyProbabilities.length;
this.successors = new NodeSuccessorList<>(this, successors);
this.value = value;
this.keySuccessors = keySuccessors;
this.keyProbabilities = keyProbabilities;
assert assertProbabilities();
}
private boolean assertProbabilities() {
double total = 0;
for (double d : keyProbabilities) {
total += d;
assert d >= 0.0 : "Cannot have negative probabilities in switch node: " + d;
}
assert total > 0.999 && total < 1.001 : "Total " + total;
return true;
}
@Override
public int getSuccessorCount() {
return successors.count();
}
@Override
public double probability(AbstractBeginNode successor) {
double sum = 0;
for (int i = 0; i < keySuccessors.length; i++) {
if (successors.get(keySuccessors[i]) == successor) {
sum += keyProbabilities[i];
}
}
return sum;
}
@Override
public boolean setProbability(AbstractBeginNode successor, double value) {
assert value <= 1.0 && value >= 0.0 : value;
assert assertProbabilities();
double sum = 0;
double otherSum = 0;
for (int i = 0; i < keySuccessors.length; i++) {
if (successors.get(keySuccessors[i]) == successor) {
sum += keyProbabilities[i];
} else {
otherSum += keyProbabilities[i];
}
}
if (otherSum == 0 || sum == 0) {
return false;
}
double delta = value - sum;
for (int i = 0; i < keySuccessors.length; i++) {
if (successors.get(keySuccessors[i]) == successor) {
keyProbabilities[i] = Math.max(0.0, keyProbabilities[i] + (delta * keyProbabilities[i]) / sum);
} else {
keyProbabilities[i] = Math.max(0.0, keyProbabilities[i] - (delta * keyProbabilities[i]) / otherSum);
}
}
assert assertProbabilities();
return true;
}
public ValueNode value() {
return value;
}
public abstract boolean isSorted();
public abstract int keyCount();
public abstract Constant keyAt(int i);
public boolean structureEquals(SwitchNode switchNode) {
return Arrays.equals(keySuccessors, switchNode.keySuccessors) && equalKeys(switchNode);
}
public abstract boolean equalKeys(SwitchNode switchNode);
public int keySuccessorIndex(int i) {
return keySuccessors[i];
}
public AbstractBeginNode keySuccessor(int i) {
return successors.get(keySuccessors[i]);
}
public double keyProbability(int i) {
return keyProbabilities[i];
}
public int defaultSuccessorIndex() {
return keySuccessors[keySuccessors.length - 1];
}
public AbstractBeginNode blockSuccessor(int i) {
return successors.get(i);
}
public void setBlockSuccessor(int i, AbstractBeginNode s) {
successors.set(i, s);
}
public int blockSuccessorCount() {
return successors.count();
}
public AbstractBeginNode defaultSuccessor() {
if (defaultSuccessorIndex() == -1) {
throw new GraalError("unexpected");
}
return successors.get(defaultSuccessorIndex());
}
@Override
public AbstractBeginNode getPrimarySuccessor() {
return null;
}
protected void killOtherSuccessors(SimplifierTool tool, int survivingEdge) {
for (Node successor : successors()) {
if (successor != blockSuccessor(survivingEdge)) {
tool.deleteBranch(successor);
}
}
tool.addToWorkList(blockSuccessor(survivingEdge));
graph().removeSplit(this, blockSuccessor(survivingEdge));
}
public abstract Stamp getValueStampForSuccessor(AbstractBeginNode beginNode);
@Override
public NodeCycles estimatedNodeCycles() {
if (keyCount() == 1) {
return CYCLES_2;
} else if (isSorted()) {
return CYCLES_8;
} else {
return CYCLES_64;
}
}
@Override
public NodeSize estimatedNodeSize() {
if (keyCount() == 1) {
return SIZE_2;
} else if (isSorted()) {
return SIZE_8;
} else {
return SIZE_64;
}
}
}