package org.graalvm.compiler.phases.common;
import static org.graalvm.compiler.core.common.SpeculativeExecutionAttacksMitigations.NonDeoptGuardTargets;
import static org.graalvm.compiler.core.common.SpeculativeExecutionAttacksMitigations.Options.MitigateSpeculativeExecutionAttacks;
import static org.graalvm.compiler.core.common.SpeculativeExecutionAttacksMitigations.Options.UseIndexMasking;
import java.util.ArrayList;
import java.util.List;
import org.graalvm.compiler.core.common.type.IntegerStamp;
import org.graalvm.compiler.core.common.type.StampFactory;
import org.graalvm.compiler.debug.DebugContext;
import org.graalvm.compiler.graph.Node;
import org.graalvm.compiler.graph.Position;
import org.graalvm.compiler.nodeinfo.InputType;
import org.graalvm.compiler.nodes.AbstractBeginNode;
import org.graalvm.compiler.nodes.DeoptimizeNode;
import org.graalvm.compiler.nodes.IfNode;
import org.graalvm.compiler.nodes.LoopExitNode;
import org.graalvm.compiler.nodes.NamedLocationIdentity;
import org.graalvm.compiler.nodes.PiNode;
import org.graalvm.compiler.nodes.StructuredGraph;
import org.graalvm.compiler.nodes.extended.MultiGuardNode;
import org.graalvm.compiler.nodes.memory.MemoryAccess;
import org.graalvm.compiler.phases.Phase;
import jdk.vm.ci.meta.DeoptimizationReason;
/**
 * Phase that visits every {@link AbstractBeginNode} of a graph and calls
 * {@link AbstractBeginNode#setWithSpeculationFence()} on those that are used as a guard input by
 * at least one other node.
 *
 * <p>
 * Two options can suppress the fence for an individual begin node:
 * <ul>
 * <li>when {@code MitigateSpeculativeExecutionAttacks} is {@code NonDeoptGuardTargets}, begin
 * nodes for which {@code isDeoptGuard} holds are skipped;</li>
 * <li>when {@code UseIndexMasking} is enabled, begin nodes for which {@code isBoundsCheckGuard}
 * holds are skipped.</li>
 * </ul>
 */
public class InsertGuardFencesPhase extends Phase {

    /**
     * Integer stamp built with {@code StampFactory.forInteger(32, 0, Integer.MAX_VALUE - 1)}.
     * {@code isBoundsCheckGuard} requires every {@link PiNode} guarded by a candidate begin node
     * to carry exactly this stamp.
     */
    public static final IntegerStamp POSITIVE_ARRAY_INDEX_STAMP = StampFactory.forInteger(32, 0, Integer.MAX_VALUE - 1);

    /**
     * Marks guard-carrying begin nodes of {@code graph} with a speculation fence, honoring the
     * two skip options described on the class.
     *
     * @param graph the graph whose begin nodes are inspected and possibly marked
     */
    @Override
    protected void run(StructuredGraph graph) {
        DebugContext debug = graph.getDebug();
        // The option values do not change while this phase runs, so look them up once.
        boolean skipDeoptGuards = MitigateSpeculativeExecutionAttacks.getValue(graph.getOptions()) == NonDeoptGuardTargets;
        boolean skipBoundsCheckGuards = UseIndexMasking.getValue(graph.getOptions());

        for (AbstractBeginNode beginNode : graph.getNodes(AbstractBeginNode.TYPE)) {
            if (!hasGuardUsages(beginNode)) {
                debug.log(DebugContext.DETAILED_LEVEL, "No guards on %s", beginNode);
                continue;
            }
            if (skipDeoptGuards && isDeoptGuard(beginNode)) {
                debug.log(DebugContext.VERBOSE_LEVEL, "Skipping deoptimizing guard speculation fence at %s", beginNode);
                continue;
            }
            if (skipBoundsCheckGuards && isBoundsCheckGuard(beginNode)) {
                debug.log(DebugContext.VERBOSE_LEVEL, "Skipping bounds-check speculation fence at %s", beginNode);
                continue;
            }
            if (debug.isLogEnabled(DebugContext.DETAILED_LEVEL)) {
                // Collecting the usages allocates a list, so only do it when it will be printed.
                debug.log(DebugContext.DETAILED_LEVEL, "Adding speculation fence for %s at %s", guardUsages(beginNode), beginNode);
            } else {
                debug.log(DebugContext.VERBOSE_LEVEL, "Adding speculation fence at %s", beginNode);
            }
            beginNode.setWithSpeculationFence();
        }
    }

    /**
     * Returns the successor of {@code ifNode} that is not {@code beginNode}.
     *
     * @param ifNode the branch whose successors are inspected
     * @param beginNode one of the two successors of {@code ifNode}
     * @return the false successor if {@code beginNode} is the true successor, otherwise the true
     *         successor
     */
    private static AbstractBeginNode otherSuccessor(IfNode ifNode, AbstractBeginNode beginNode) {
        if (ifNode.trueSuccessor() == beginNode) {
            return ifNode.falseSuccessor();
        }
        assert ifNode.falseSuccessor() == beginNode;
        return ifNode.trueSuccessor();
    }

    /**
     * Tells whether {@code beginNode} is one branch of an {@link IfNode} whose other branch leads
     * directly to a {@link DeoptimizeNode} with an action that invalidates the compilation.
     *
     * @param beginNode the begin node to classify
     * @return {@code true} only if the predecessor is an {@link IfNode}, the node following the
     *         sibling begin is a {@link DeoptimizeNode}, and that node's action reports
     *         {@code doesInvalidateCompilation()}
     */
    private static boolean isDeoptGuard(AbstractBeginNode beginNode) {
        Node pred = beginNode.predecessor();
        if (!(pred instanceof IfNode)) {
            return false;
        }
        Node afterOther = otherSuccessor((IfNode) pred, beginNode).next();
        if (!(afterOther instanceof DeoptimizeNode)) {
            return false;
        }
        return ((DeoptimizeNode) afterOther).getAction().doesInvalidateCompilation();
    }

    /**
     * Tells whether {@code beginNode} is recognized as the guard of an array bounds check.
     *
     * <p>
     * The predecessor must be an {@link IfNode}. The result is then {@code true} if any of the
     * following holds:
     * <ol>
     * <li>the sibling branch is followed by a {@link DeoptimizeNode} with reason
     * {@link DeoptimizationReason#BoundsCheckException} and {@code beginNode} does not have
     * multiple guard usages;</li>
     * <li>the sibling branch is not followed by a {@link DeoptimizeNode}, it is a
     * {@link LoopExitNode}, {@code beginNode} has a {@link MultiGuardNode} usage, and
     * {@code beginNode} does not have multiple guard usages;</li>
     * <li>every node that uses {@code beginNode} as a guard input is either a {@link PiNode} with
     * stamp {@link #POSITIVE_ARRAY_INDEX_STAMP} or a {@link MemoryAccess} to an array
     * location.</li>
     * </ol>
     *
     * @param beginNode the begin node to classify
     * @return whether the node matches one of the patterns above
     */
    private static boolean isBoundsCheckGuard(AbstractBeginNode beginNode) {
        Node pred = beginNode.predecessor();
        if (!(pred instanceof IfNode)) {
            return false;
        }
        AbstractBeginNode otherBegin = otherSuccessor((IfNode) pred, beginNode);
        Node afterOther = otherBegin.next();
        if (afterOther instanceof DeoptimizeNode) {
            DeoptimizeNode deopt = (DeoptimizeNode) afterOther;
            if (deopt.getReason() == DeoptimizationReason.BoundsCheckException && !hasMultipleGuardUsages(beginNode)) {
                return true;
            }
            // A deopt with another reason (or several guard usages) still goes through the
            // usage scan below.
        } else if (otherBegin instanceof LoopExitNode && beginNode.usages().filter(MultiGuardNode.class).isNotEmpty() && !hasMultipleGuardUsages(beginNode)) {
            return true;
        }

        for (Node usage : beginNode.usages()) {
            for (Position pos : usage.inputPositions()) {
                if (isGuardInput(pos, usage, beginNode)) {
                    if (!isArrayRelatedUsage(usage)) {
                        return false;
                    }
                    // One matching guard position is enough to classify this usage.
                    break;
                }
            }
        }
        return true;
    }

    /**
     * Tells whether a guarded node is of a kind accepted by {@code isBoundsCheckGuard}.
     *
     * @param usage a node that uses the begin node under inspection as a guard
     * @return {@code true} for a {@link PiNode} whose pi stamp equals
     *         {@link #POSITIVE_ARRAY_INDEX_STAMP}, or for a {@link MemoryAccess} whose location
     *         identity is an array location; {@code false} for everything else
     */
    private static boolean isArrayRelatedUsage(Node usage) {
        if (usage instanceof PiNode) {
            return ((PiNode) usage).piStamp().equals(POSITIVE_ARRAY_INDEX_STAMP);
        }
        if (usage instanceof MemoryAccess) {
            return NamedLocationIdentity.isArrayLocation(((MemoryAccess) usage).getLocationIdentity());
        }
        return false;
    }

    /**
     * Tells whether input position {@code pos} of {@code usage} is a guard input that currently
     * refers to {@code guard}.
     *
     * @param pos an input position of {@code usage}
     * @param usage the node owning {@code pos}
     * @param guard the node expected at that position
     * @return {@code true} if the position has input type {@link InputType#Guard} and holds
     *         {@code guard}
     */
    private static boolean isGuardInput(Position pos, Node usage, Node guard) {
        return pos.getInputType() == InputType.Guard && pos.get(usage) == guard;
    }

    /**
     * Tells whether at least one usage of {@code n} references it through a guard input.
     *
     * @param n the node whose usages are scanned
     * @return {@code true} as soon as one guard input referring to {@code n} is found
     */
    private static boolean hasGuardUsages(Node n) {
        for (Node usage : n.usages()) {
            for (Position pos : usage.inputPositions()) {
                if (isGuardInput(pos, usage, n)) {
                    return true;
                }
            }
        }
        return false;
    }

    /**
     * Tells whether {@code n} is referenced through more than one guard input.
     *
     * <p>
     * Guard input positions are counted, not distinct usage nodes: a single usage that refers to
     * {@code n} from two guard positions also yields {@code true}.
     *
     * @param n the node whose usages are scanned
     * @return {@code true} once a second guard input referring to {@code n} is found
     */
    private static boolean hasMultipleGuardUsages(Node n) {
        boolean foundOne = false;
        for (Node usage : n.usages()) {
            for (Position pos : usage.inputPositions()) {
                if (isGuardInput(pos, usage, n)) {
                    if (foundOne) {
                        return true;
                    }
                    foundOne = true;
                }
            }
        }
        return false;
    }

    /**
     * Collects the usages of {@code n} that reference it through a guard input.
     *
     * <p>
     * A usage is added once per matching guard position, so it can appear more than once.
     *
     * @param n the node whose usages are scanned
     * @return a new, possibly empty list of guarded usages in iteration order
     */
    private static List<Node> guardUsages(Node n) {
        List<Node> ret = new ArrayList<>();
        for (Node usage : n.usages()) {
            for (Position pos : usage.inputPositions()) {
                if (isGuardInput(pos, usage, n)) {
                    ret.add(usage);
                }
            }
        }
        return ret;
    }
}