package org.graalvm.compiler.loop.phases;
import static org.graalvm.compiler.core.common.GraalOptions.MaximumDesiredSize;
import static org.graalvm.compiler.loop.MathUtil.add;
import static org.graalvm.compiler.loop.MathUtil.sub;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import org.graalvm.compiler.core.common.RetryableBailoutException;
import org.graalvm.compiler.core.common.calc.Condition;
import org.graalvm.compiler.debug.DebugContext;
import org.graalvm.compiler.debug.GraalError;
import org.graalvm.compiler.graph.Graph.Mark;
import org.graalvm.compiler.graph.Node;
import org.graalvm.compiler.graph.Position;
import org.graalvm.compiler.loop.CountedLoopInfo;
import org.graalvm.compiler.loop.InductionVariable;
import org.graalvm.compiler.loop.InductionVariable.Direction;
import org.graalvm.compiler.loop.LoopEx;
import org.graalvm.compiler.loop.LoopFragmentInside;
import org.graalvm.compiler.loop.LoopFragmentWhole;
import org.graalvm.compiler.nodeinfo.InputType;
import org.graalvm.compiler.nodes.AbstractBeginNode;
import org.graalvm.compiler.nodes.AbstractEndNode;
import org.graalvm.compiler.nodes.AbstractMergeNode;
import org.graalvm.compiler.nodes.BeginNode;
import org.graalvm.compiler.nodes.ConstantNode;
import org.graalvm.compiler.nodes.ControlSplitNode;
import org.graalvm.compiler.nodes.EndNode;
import org.graalvm.compiler.nodes.FixedNode;
import org.graalvm.compiler.nodes.FixedWithNextNode;
import org.graalvm.compiler.nodes.IfNode;
import org.graalvm.compiler.nodes.LogicNode;
import org.graalvm.compiler.nodes.LoopBeginNode;
import org.graalvm.compiler.nodes.LoopExitNode;
import org.graalvm.compiler.nodes.PhiNode;
import org.graalvm.compiler.nodes.SafepointNode;
import org.graalvm.compiler.nodes.StructuredGraph;
import org.graalvm.compiler.nodes.ValueNode;
import org.graalvm.compiler.nodes.calc.CompareNode;
import org.graalvm.compiler.nodes.calc.ConditionalNode;
import org.graalvm.compiler.nodes.calc.IntegerLessThanNode;
import org.graalvm.compiler.nodes.extended.SwitchNode;
import org.graalvm.compiler.phases.common.CanonicalizerPhase;
import org.graalvm.compiler.phases.tiers.PhaseContext;
/**
 * Static helpers that restructure loops in a Graal {@link StructuredGraph}.
 * <p>
 * The transformations are peeling, full and partial unrolling, unswitching, and splitting a
 * counted loop into pre/main/post loops. The class also provides the predicates that pick
 * candidates for unswitching and unrolling. All transformations modify the graph in place.
 * <p>
 * Not instantiable: the class is {@code abstract} and its only constructor is private.
 */
public abstract class LoopTransformations {
private LoopTransformations() {
}
/**
 * Peels one iteration off {@code loop}.
 * <p>
 * The loop body fragment ({@code loop.inside()}) is duplicated and the copy is inserted in front
 * of the loop. The loop's recorded frequency is then lowered by one, clamped at {@code 0.0}.
 *
 * @param loop the loop to peel
 */
public static void peel(LoopEx loop) {
loop.inside().duplicate().insertBefore(loop);
loop.loopBegin().setLoopFrequency(Math.max(0.0, loop.loopBegin().loopFrequency() - 1));
}
/**
 * Fully unrolls {@code loop} by peeling it repeatedly until its {@link LoopBeginNode} has been
 * deleted from the graph.
 * <p>
 * After each peel, the canonicalizer runs incrementally over the nodes created since the mark
 * taken just before that peel. The loop's cached fragments are then invalidated.
 * <p>
 * Termination relies on canonicalization eventually removing the loop. Presumably callers only
 * use this on loops with a small constant trip count (TODO confirm with callers). The node-count
 * check is the only safety net against unbounded growth.
 *
 * @param loop the loop to unroll completely
 * @param context phase context handed to the canonicalizer
 * @param canonicalizer canonicalizer applied incrementally after each peel
 * @throws RetryableBailoutException if the graph's node count exceeds its count on entry by more
 *             than twice {@code MaximumDesiredSize}
 */
public static void fullUnroll(LoopEx loop, PhaseContext context, CanonicalizerPhase canonicalizer) {
LoopBeginNode loopBegin = loop.loopBegin();
StructuredGraph graph = loopBegin.graph();
int initialNodeCount = graph.getNodeCount();
while (!loopBegin.isDeleted()) {
// Mark before peeling so the canonicalizer only revisits nodes created by this peel.
Mark mark = graph.getMark();
peel(loop);
canonicalizer.applyIncremental(graph, context, mark);
// Cached fragments presumably describe the pre-peel graph; drop them before the next round.
loop.invalidateFragments();
if (graph.getNodeCount() > initialNodeCount + MaximumDesiredSize.getValue(graph.getOptions()) * 2) {
throw new RetryableBailoutException("FullUnroll : Graph seems to grow out of proportion");
}
}
}
/**
 * Unswitches {@code loop} on the given control splits, whose selector is defined outside the loop.
 * <p>
 * A copy of the first split (with its inputs) is placed where the loop used to be entered. Its
 * first successor leads into the original loop, and each further successor leads into a fresh
 * duplicate of the whole loop.
 * <p>
 * In each loop version, every split from {@code controlSplitNodeSet} is removed in favour of the
 * successor that matches that version. Guard usages of that successor are redirected to the begin
 * node placed in front of the loop version.
 * <p>
 * The successor {@link Position}s are taken from the first split and reused for every split. All
 * splits are therefore assumed to share one successor structure. {@link #findUnswitchable}
 * requires {@code structureEquals} when it groups switches.
 *
 * @param loop the loop to unswitch
 * @param controlSplitNodeSet splits inside {@code loop}; must be non-empty, and the first one is
 *            used as the template for the hoisted split
 */
public static void unswitch(LoopEx loop, List<ControlSplitNode> controlSplitNodeSet) {
ControlSplitNode firstNode = controlSplitNodeSet.iterator().next();
LoopFragmentWhole originalLoop = loop.whole();
StructuredGraph graph = firstNode.graph();
loop.loopBegin().incrementUnswitches();
// Hoisted copy of the split, placed in front of the loop entry.
ControlSplitNode newControlSplit = (ControlSplitNode) firstNode.copyWithInputs();
originalLoop.entryPoint().replaceAtPredecessor(newControlSplit);
Iterator<Position> successors = firstNode.successorPositions().iterator();
assert successors.hasNext();
// The first successor position keeps the original loop.
Position firstPosition = successors.next();
AbstractBeginNode originalLoopBegin = BeginNode.begin(originalLoop.entryPoint());
firstPosition.set(newControlSplit, originalLoopBegin);
while (successors.hasNext()) {
Position position = successors.next();
// Every other successor position gets its own copy of the whole loop.
LoopFragmentWhole duplicateLoop = originalLoop.duplicate();
AbstractBeginNode newBegin = BeginNode.begin(duplicateLoop.entryPoint());
position.set(newControlSplit, newBegin);
for (ControlSplitNode controlSplitNode : controlSplitNodeSet) {
ControlSplitNode duplicatedControlSplit = duplicateLoop.getDuplicatedNode(controlSplitNode);
// A copy may already be dead, presumably removed while simplifying an earlier split.
if (duplicatedControlSplit.isAlive()) {
AbstractBeginNode survivingSuccessor = (AbstractBeginNode) position.get(duplicatedControlSplit);
survivingSuccessor.replaceAtUsages(InputType.Guard, newBegin);
graph.removeSplitPropagate(duplicatedControlSplit, survivingSuccessor);
}
}
}
// Simplify the original loop last. Each duplicate() above copies the original loop, so its
// splits must still be present while the duplicates are being made.
for (ControlSplitNode controlSplitNode : controlSplitNodeSet) {
if (controlSplitNode.isAlive()) {
AbstractBeginNode survivingSuccessor = (AbstractBeginNode) firstPosition.get(controlSplitNode);
survivingSuccessor.replaceAtUsages(InputType.Guard, originalLoopBegin);
graph.removeSplitPropagate(controlSplitNode, survivingSuccessor);
}
}
}
/**
 * Partially unrolls a main loop by duplicating its body once and inserting the copy inside the
 * loop, after the original body ({@code insertWithinAfter}).
 *
 * @param loop the loop to unroll; asserted to be marked as a main loop (see
 *            {@link #insertPrePostLoops})
 */
public static void partialUnroll(LoopEx loop) {
assert loop.loopBegin().isMainLoop();
loop.loopBegin().graph().getDebug().log("LoopPartialUnroll %s", loop);
LoopFragmentInside newSegment = loop.inside().duplicate();
newSegment.insertWithinAfter(loop);
}
/**
 * Splits a counted loop into a pre loop, a main loop and a post loop, and returns the main loop.
 * <p>
 * The original loop becomes the pre loop, and two duplicates of it become the main and post
 * loops. Control then flows: pre-loop exit, main loop, post loop, then the node that originally
 * followed the loop exit.
 * <p>
 * Loop bounds are adjusted as follows:
 * <ul>
 * <li>The pre loop's bound is clamped so that it runs at most one iteration.</li>
 * <li>The main loop's bound is pulled in by one stride.</li>
 * <li>Values used after the loop are rewired to come from the post loop.</li>
 * </ul>
 * Safepoints are removed from the pre and post loops. Loop frequencies are set to 1 for the pre
 * loop, reduced by 2 for the main loop and by 1 for the post loop (the latter two clamped at 0.0).
 *
 * @param loop a counted loop with a non-null limit test (asserted); it is expected to have a
 *            single loop exit ({@code getSingleLoopExit})
 * @return the {@link LoopBeginNode} of the new main loop
 */
public static LoopBeginNode insertPrePostLoops(LoopEx loop) {
StructuredGraph graph = loop.loopBegin().graph();
graph.getDebug().log("LoopTransformations.insertPrePostLoops %s", loop);
LoopFragmentWhole preLoop = loop.whole();
CountedLoopInfo preCounted = loop.counted();
IfNode preLimit = preCounted.getLimitTest();
assert preLimit != null;
LoopBeginNode preLoopBegin = loop.loopBegin();
InductionVariable preIv = preCounted.getCounter();
LoopExitNode preLoopExitNode = preLoopBegin.getSingleLoopExit();
// Captured before duplication; re-attached after the post loop's exit further down.
FixedNode continuationNode = preLoopExitNode.next();
// Both copies are made from the untouched original before any rewiring.
LoopFragmentWhole mainLoop = preLoop.duplicate();
LoopFragmentWhole postLoop = preLoop.duplicate();
// Called twice, presumably once per duplicate created above.
preLoopBegin.incrementSplits();
preLoopBegin.incrementSplits();
preLoopBegin.setPreLoop();
graph.getDebug().dump(DebugContext.VERBOSE_LEVEL, graph, "After duplication");
LoopBeginNode mainLoopBegin = mainLoop.getDuplicatedNode(preLoopBegin);
mainLoopBegin.setMainLoop();
LoopBeginNode postLoopBegin = postLoop.getDuplicatedNode(preLoopBegin);
postLoopBegin.setPostLoop();
// After duplication, the block following each copy's loop exit ends in an EndNode into a
// merge. These merges are dismantled by cleanupMerge once the loops are chained.
EndNode postEndNode = getBlockEndAfterLoopExit(postLoopBegin);
AbstractMergeNode postMergeNode = postEndNode.merge();
LoopExitNode postLoopExitNode = postLoopBegin.getSingleLoopExit();
// Main loop phis take input 0 (presumably the loop-entry value; TODO confirm) from the pre loop phis.
for (PhiNode prePhiNode : preLoopBegin.phis()) {
PhiNode mainPhiNode = mainLoop.getDuplicatedNode(prePhiNode);
mainPhiNode.setValueAt(0, prePhiNode);
}
EndNode mainEndNode = getBlockEndAfterLoopExit(mainLoopBegin);
AbstractMergeNode mainMergeNode = mainEndNode.merge();
// Chain the loops: main-loop exit -> new begin -> post loop's forward end,
// and pre-loop exit -> main loop's forward end.
AbstractEndNode postEntryNode = postLoopBegin.forwardEnd();
AbstractBeginNode mainLandingNode = BeginNode.begin(postEntryNode);
LoopExitNode mainLoopExitNode = mainLoopBegin.getSingleLoopExit();
mainLoopExitNode.setNext(mainLandingNode);
preLoopExitNode.setNext(mainLoopBegin.forwardEnd());
processPreLoopPhis(loop, mainLoop, postLoop);
// Detach the original continuation from its current predecessor and hang it off the post-loop exit.
continuationNode.predecessor().clearSuccessors();
postLoopExitNode.setNext(continuationNode);
cleanupMerge(postMergeNode, postLoopExitNode);
cleanupMerge(mainMergeNode, mainLandingNode);
updateMainLoopLimit(preLimit, preIv, mainLoop);
updatePreLoopLimit(preLimit, preIv, preCounted);
preLoopBegin.setLoopFrequency(1);
mainLoopBegin.setLoopFrequency(Math.max(0.0, mainLoopBegin.loopFrequency() - 2));
postLoopBegin.setLoopFrequency(Math.max(0.0, postLoopBegin.loopFrequency() - 1));
// Only the main loop keeps its safepoints.
for (SafepointNode safepoint : preLoop.nodes().filter(SafepointNode.class)) {
graph.removeFixed(safepoint);
}
for (SafepointNode safepoint : postLoop.nodes().filter(SafepointNode.class)) {
graph.removeFixed(safepoint);
}
graph.getDebug().dump(DebugContext.DETAILED_LEVEL, graph, "InsertPrePostLoops %s", loop);
return mainLoopBegin;
}
/**
 * Dismantles {@code mergeNode}.
 * <p>
 * Each incoming {@link EndNode} is removed from the merge and deleted, and then the merge itself
 * is deleted. Before deletion, {@code prepareDelete(landingNode)} is called. Presumably this moves
 * the merge's successor and any dependents onto {@code landingNode}; verify against
 * {@code AbstractMergeNode.prepareDelete}.
 *
 * @param mergeNode the merge to remove
 * @param landingNode the node that takes over from the merge
 */
private static void cleanupMerge(AbstractMergeNode mergeNode, AbstractBeginNode landingNode) {
// Snapshot because removeEnd mutates the predecessor list being iterated.
for (EndNode end : mergeNode.cfgPredecessors().snapshot()) {
mergeNode.removeEnd(end);
end.safeDelete();
}
mergeNode.prepareDelete(landingNode);
mergeNode.safeDelete();
}
/**
 * Redirects values that flow out of the pre loop so they come from the post loop instead, since
 * the post loop is now the last loop executed.
 * <p>
 * Each post-loop phi takes input 0 from the matching main-loop phi. Certain usages outside the
 * pre loop are then redirected to the post-loop duplicate:
 * <ul>
 * <li>usages of a pre-loop phi, except the main-loop phi that consumes it as its entry value;</li>
 * <li>usages of any node in the pre loop's {@code inside()} fragment.</li>
 * </ul>
 *
 * @param preLoop the original loop, now acting as the pre loop
 * @param mainLoop the main-loop duplicate
 * @param postLoop the post-loop duplicate
 */
private static void processPreLoopPhis(LoopEx preLoop, LoopFragmentWhole mainLoop, LoopFragmentWhole postLoop) {
LoopBeginNode preLoopBegin = preLoop.loopBegin();
for (PhiNode prePhiNode : preLoopBegin.phis()) {
PhiNode postPhiNode = postLoop.getDuplicatedNode(prePhiNode);
PhiNode mainPhiNode = mainLoop.getDuplicatedNode(prePhiNode);
postPhiNode.setValueAt(0, mainPhiNode);
// Snapshot because replaceFirstInput mutates the usage list.
for (Node usage : prePhiNode.usages().snapshot()) {
// The main loop phi must keep reading the pre loop phi as its entry value.
if (usage == mainPhiNode) {
continue;
}
if (preLoop.isOutsideLoop(usage)) {
usage.replaceFirstInput(prePhiNode, postPhiNode);
}
}
}
for (Node node : preLoop.inside().nodes()) {
for (Node externalUsage : node.usages().snapshot()) {
if (preLoop.isOutsideLoop(externalUsage)) {
Node postUsage = postLoop.getDuplicatedNode(node);
assert postUsage != null;
externalUsage.replaceFirstInput(node, postUsage);
}
}
}
}
/**
 * Returns the {@link EndNode} that terminates the block starting right after the single loop exit
 * of {@code curLoopBegin}.
 *
 * @param curLoopBegin a loop begin with a single loop exit
 * @return the end node reached by following {@code next()} from the exit's successor
 */
private static EndNode getBlockEndAfterLoopExit(LoopBeginNode curLoopBegin) {
FixedNode node = curLoopBegin.getSingleLoopExit().next();
return getBlockEnd(node);
}
/**
 * Follows {@code next()} links from {@code node} through {@link FixedWithNextNode}s and returns
 * the first node that is not one.
 *
 * @param node start of the walk
 * @return the terminating node, cast to {@link EndNode}; throws {@code ClassCastException} if the
 *         block ends in some other kind of node
 */
private static EndNode getBlockEnd(FixedNode node) {
FixedNode curNode = node;
while (curNode instanceof FixedWithNextNode) {
curNode = ((FixedWithNextNode) curNode).next();
}
return (EndNode) curNode;
}
/**
 * Pulls in the bound of the main loop's copy of the limit test by one stride.
 * <p>
 * In the main loop's duplicate of {@code preLimit}, the compare operand that is not the main
 * loop's induction-variable node (the bound {@code ub}) is replaced by {@code ub - stride}. A
 * constant stride node is shared as-is; otherwise the main loop's duplicate of the stride node is
 * used. The condition is assumed to be a {@link CompareNode} (unchecked cast).
 *
 * @param preLimit the pre loop's limit test
 * @param preIv the pre loop's counter induction variable
 * @param mainLoop the main-loop duplicate
 * @throws GraalError if neither compare operand is the main loop's induction-variable node
 */
private static void updateMainLoopLimit(IfNode preLimit, InductionVariable preIv, LoopFragmentWhole mainLoop) {
StructuredGraph graph = preLimit.graph();
IfNode mainLimit = mainLoop.getDuplicatedNode(preLimit);
LogicNode ifTest = mainLimit.condition();
CompareNode compareNode = (CompareNode) ifTest;
ValueNode prePhi = preIv.valueNode();
ValueNode mainPhi = mainLoop.getDuplicatedNode(prePhi);
ValueNode preStride = preIv.strideNode();
ValueNode mainStride;
// Constant nodes are shared rather than duplicated per loop copy.
if (preStride instanceof ConstantNode) {
mainStride = preStride;
} else {
mainStride = mainLoop.getDuplicatedNode(preStride);
}
// The bound is whichever compare operand is not the induction variable.
ValueNode ub = null;
if (compareNode.getX() == mainPhi) {
ub = compareNode.getY();
} else if (compareNode.getY() == mainPhi) {
ub = compareNode.getX();
} else {
throw GraalError.shouldNotReachHere();
}
ValueNode newLimit = sub(graph, ub, mainStride);
compareNode.replaceFirstInput(ub, newLimit);
}
/**
 * Clamps the pre loop's bound to {@code start + stride} so the pre loop runs at most one
 * iteration.
 * <p>
 * When counting up, the bound becomes {@code min(start + stride, ub)}; when counting down, it
 * becomes {@code max(start + stride, ub)}. Both are built as a {@link ConditionalNode} over an
 * {@link IntegerLessThanNode}.
 * <p>
 * NOTE(review): {@link IntegerLessThanNode} is a signed comparison and is used whatever condition
 * the limit test itself uses. Verify that this is correct if unsigned limit tests can reach here.
 *
 * @param preLimit the pre loop's limit test; its condition is assumed to be a {@link CompareNode}
 * @param preIv the pre loop's counter induction variable
 * @param preCounted counted-loop info supplying the induction variable's start value
 * @throws GraalError if neither compare operand is the induction-variable node
 */
private static void updatePreLoopLimit(IfNode preLimit, InductionVariable preIv, CountedLoopInfo preCounted) {
StructuredGraph graph = preLimit.graph();
LogicNode ifTest = preLimit.condition();
CompareNode compareNode = (CompareNode) ifTest;
ValueNode prePhi = preIv.valueNode();
ValueNode initIv = preCounted.getStart();
// Value of the induction variable after exactly one iteration.
ValueNode newLimit = add(graph, initIv, preIv.strideNode());
ValueNode ub;
if (compareNode.getX() == prePhi) {
ub = compareNode.getY();
} else if (compareNode.getY() == prePhi) {
ub = compareNode.getX();
} else {
throw GraalError.shouldNotReachHere();
}
if (preIv.direction() == Direction.Up) {
// newLimit < ub ? newLimit : ub, i.e. min(newLimit, ub)
compareNode.replaceFirstInput(ub, graph.unique(new ConditionalNode(graph.unique(new IntegerLessThanNode(newLimit, ub)), newLimit, ub)));
} else {
// ub < newLimit ? newLimit : ub, i.e. max(newLimit, ub)
compareNode.replaceFirstInput(ub, graph.unique(new ConditionalNode(graph.unique(new IntegerLessThanNode(ub, newLimit)), newLimit, ub)));
}
}
/**
 * Collects control splits in {@code loop} that all select on the same node defined outside the
 * loop, so that they can be unswitched together.
 * <p>
 * {@link IfNode}s are tried first. The first if whose condition is outside the loop fixes the
 * invariant, and later ifs are added only if they use that same condition node.
 * <p>
 * Only when no such if exists are {@link SwitchNode}s considered. A switch qualifies if it has
 * more than one successor and its value is outside the loop. Later switches must also have the
 * same value node and {@code structureEquals} the first switch.
 *
 * @param loop the loop to inspect
 * @return the matching splits, or {@code null} (not an empty list) if none qualify
 */
public static List<ControlSplitNode> findUnswitchable(LoopEx loop) {
List<ControlSplitNode> controls = null;
ValueNode invariantValue = null;
for (IfNode ifNode : loop.whole().nodes().filter(IfNode.class)) {
if (loop.isOutsideLoop(ifNode.condition())) {
if (controls == null) {
invariantValue = ifNode.condition();
controls = new ArrayList<>();
controls.add(ifNode);
} else if (ifNode.condition() == invariantValue) {
controls.add(ifNode);
}
}
}
if (controls == null) {
SwitchNode firstSwitch = null;
for (SwitchNode switchNode : loop.whole().nodes().filter(SwitchNode.class)) {
if (switchNode.successors().count() > 1 && loop.isOutsideLoop(switchNode.value())) {
if (controls == null) {
firstSwitch = switchNode;
invariantValue = switchNode.value();
controls = new ArrayList<>();
controls.add(switchNode);
} else if (switchNode.value() == invariantValue && firstSwitch.structureEquals(switchNode)) {
// Switches are grouped only if they share successor structure, which unswitch relies on.
controls.add(switchNode);
}
}
}
}
return controls;
}
/**
 * Decides whether {@code loop} is a candidate for partial unrolling.
 * <p>
 * Returns {@code true} only if all of the following hold:
 * <ul>
 * <li>the loop is counted, with a constant-stride counter;</li>
 * <li>it has no child loops;</li>
 * <li>its limit-test condition is a {@link CompareNode} whose condition is neither {@code EQ}
 * nor {@code NE};</li>
 * <li>it is marked as a main or simple loop;</li>
 * <li>it spans fewer than 3 blocks.</li>
 * </ul>
 * Rejections caused by the condition or by the block count are logged at
 * {@code VERBOSE_LEVEL}.
 *
 * @param loop the loop to check
 * @return whether the loop may be partially unrolled
 */
public static boolean isUnrollableLoop(LoopEx loop) {
if (!loop.isCounted() || !loop.counted().getCounter().isConstantStride() || !loop.loop().getChildren().isEmpty()) {
return false;
}
LoopBeginNode loopBegin = loop.loopBegin();
LogicNode condition = loop.counted().getLimitTest().condition();
if (!(condition instanceof CompareNode)) {
return false;
}
if (((CompareNode) condition).condition() == Condition.EQ || ((CompareNode) condition).condition() == Condition.NE) {
condition.getDebug().log(DebugContext.VERBOSE_LEVEL, "isUnrollableLoop %s condition unsupported %s ", loopBegin, ((CompareNode) condition).condition());
return false;
}
if (loopBegin.isMainLoop() || loopBegin.isSimpleLoop()) {
// Fewer than 3 blocks: presumably just the header plus the exit/continue split, i.e. no
// internal control flow. TODO confirm.
if (loop.loop().getBlocks().size() < 3) {
return true;
}
condition.getDebug().log(DebugContext.VERBOSE_LEVEL, "isUnrollableLoop %s too large to unroll %s ", loopBegin, loop.loop().getBlocks().size());
}
return false;
}
}