package org.graalvm.compiler.loop.phases;
import static org.graalvm.compiler.core.common.GraalOptions.MaximumDesiredSize;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import jdk.internal.vm.compiler.collections.EconomicMap;
import org.graalvm.compiler.core.common.RetryableBailoutException;
import org.graalvm.compiler.core.common.calc.CanonicalCondition;
import org.graalvm.compiler.debug.DebugContext;
import org.graalvm.compiler.graph.Graph.Mark;
import org.graalvm.compiler.graph.Graph.NodeEventScope;
import org.graalvm.compiler.graph.Node;
import org.graalvm.compiler.graph.Position;
import org.graalvm.compiler.graph.spi.Simplifiable;
import org.graalvm.compiler.graph.spi.SimplifierTool;
import org.graalvm.compiler.loop.CountedLoopInfo;
import org.graalvm.compiler.loop.DefaultLoopPolicies;
import org.graalvm.compiler.loop.InductionVariable.Direction;
import org.graalvm.compiler.loop.LoopEx;
import org.graalvm.compiler.loop.LoopFragmentInside;
import org.graalvm.compiler.loop.LoopFragmentWhole;
import org.graalvm.compiler.nodeinfo.InputType;
import org.graalvm.compiler.nodes.AbstractBeginNode;
import org.graalvm.compiler.nodes.AbstractEndNode;
import org.graalvm.compiler.nodes.AbstractMergeNode;
import org.graalvm.compiler.nodes.BeginNode;
import org.graalvm.compiler.nodes.ControlSplitNode;
import org.graalvm.compiler.nodes.EndNode;
import org.graalvm.compiler.nodes.FixedGuardNode;
import org.graalvm.compiler.nodes.FixedNode;
import org.graalvm.compiler.nodes.FixedWithNextNode;
import org.graalvm.compiler.nodes.IfNode;
import org.graalvm.compiler.nodes.LogicNode;
import org.graalvm.compiler.nodes.LoopBeginNode;
import org.graalvm.compiler.nodes.NodeView;
import org.graalvm.compiler.nodes.PhiNode;
import org.graalvm.compiler.nodes.SafepointNode;
import org.graalvm.compiler.nodes.StructuredGraph;
import org.graalvm.compiler.nodes.ValueNode;
import org.graalvm.compiler.nodes.calc.AddNode;
import org.graalvm.compiler.nodes.calc.CompareNode;
import org.graalvm.compiler.nodes.calc.ConditionalNode;
import org.graalvm.compiler.nodes.extended.OpaqueNode;
import org.graalvm.compiler.nodes.extended.SwitchNode;
import org.graalvm.compiler.nodes.spi.CoreProviders;
import org.graalvm.compiler.nodes.util.GraphUtil;
import org.graalvm.compiler.nodes.util.IntegerHelper;
import org.graalvm.compiler.phases.common.CanonicalizerPhase;
import org.graalvm.compiler.phases.common.util.EconomicSetNodeEventListener;
/**
 * Static loop transformations operating on a {@link LoopEx}: peeling, full unrolling, unswitching,
 * partial unrolling and insertion of pre/main/post loops, together with the predicates
 * {@link #findUnswitchable(LoopEx)} and {@link #isUnrollableLoop(LoopEx)} used to pick candidates.
 *
 * <p>The class only holds static methods; the private constructor prevents instantiation.
 */
public abstract class LoopTransformations {
    private LoopTransformations() {
    }

    /**
     * Peels one iteration off {@code loop}: the loop's inside fragment is duplicated and the copy is
     * inserted before the loop.
     *
     * <p>Counted-loop detection is run first. If the loop is counted, its recorded loop frequency is
     * reduced by one (but never below {@code 1.0}). The loop's peeling counter is always incremented.
     *
     * @param loop the loop to peel
     */
    public static void peel(LoopEx loop) {
        loop.detectCounted();
        loop.inside().duplicate().insertBefore(loop);
        if (loop.isCounted()) {
            // One iteration is now executed by the peeled copy, so lower the estimated frequency.
            loop.loopBegin().setLoopFrequency(Math.max(1.0, loop.loopBegin().loopFrequency() - 1));
        }
        loop.loopBegin().incrementPeelings();
    }

    /**
     * Fully unrolls {@code loop} by peeling it repeatedly until its {@link LoopBeginNode} has been
     * deleted.
     *
     * <p>After every peel, the nodes that had node events during that peel are canonicalized
     * incrementally with a copy of {@code canonicalizer} that has simplification disabled. Then every
     * newly created {@link IfNode}, {@link SwitchNode}, {@link FixedGuardNode} or {@link BeginNode}
     * that is still alive is simplified explicitly with a default {@link SimplifierTool}. Once the
     * loop is gone, {@code canonicalizer} itself is applied incrementally to all nodes recorded by an
     * event listener that was active for the whole unrolling.
     *
     * <p>Termination relies on those canonicalization/simplification steps eventually deleting the
     * loop begin; as a safeguard the method bails out when the graph grows too large or too many
     * peels have been performed.
     *
     * @param loop the loop to unroll
     * @param context providers passed to the canonicalizers and used to build the simplifier
     * @param canonicalizer canonicalizer used for the final pass; a copy of it without
     *            simplification is used between peels
     * @throws RetryableBailoutException if the node count exceeds the initial count by more than
     *             {@code 2 * MaximumDesiredSize}, or the peel counter exceeds
     *             {@code FullUnrollMaxIterations}
     */
    @SuppressWarnings("try")
    public static void fullUnroll(LoopEx loop, CoreProviders context, CanonicalizerPhase canonicalizer) {
        LoopBeginNode loopBegin = loop.loopBegin();
        StructuredGraph graph = loopBegin.graph();
        int initialNodeCount = graph.getNodeCount();
        SimplifierTool defaultSimplifier = GraphUtil.getDefaultSimplifier(context, canonicalizer.getCanonicalizeReads(), graph.getAssumptions(), graph.getOptions());
        // Between peels, canonicalize without simplification; control-flow nodes of each peeled
        // iteration are simplified explicitly below instead.
        CanonicalizerPhase c = canonicalizer.copyWithoutSimplification();
        // Collects nodes with events over the whole unrolling, for the final canonicalization.
        EconomicSetNodeEventListener l = new EconomicSetNodeEventListener();
        int peelings = 0;
        // The scope resources (ev, peeledScope) are never referenced inside their try bodies,
        // hence the @SuppressWarnings("try") on this method.
        try (NodeEventScope ev = graph.trackNodeEvents(l)) {
            while (!loopBegin.isDeleted()) {
                // Nodes allocated after this mark belong to the iteration peeled below.
                Mark newNodes = graph.getMark();
                // Also record every node with an event during this peel (including pre-existing
                // nodes, which the mark does not cover); those are what gets canonicalized.
                EconomicSetNodeEventListener peeledListener = new EconomicSetNodeEventListener();
                try (NodeEventScope peeledScope = graph.trackNodeEvents(peeledListener)) {
                    LoopTransformations.peel(loop);
                }
                graph.getDebug().dump(DebugContext.VERY_DETAILED_LEVEL, graph, "After peeling loop %s", loop);
                c.applyIncremental(graph, context, peeledListener.getNodes());
                // The graph changed, so cached loop fragments are stale.
                loop.invalidateFragments();
                // Explicitly simplify only the newly created control-flow nodes that survived.
                for (Node n : graph.getNewNodes(newNodes)) {
                    if (n.isAlive() && (n instanceof IfNode || n instanceof SwitchNode || n instanceof FixedGuardNode || n instanceof BeginNode)) {
                        Simplifiable s = (Simplifiable) n;
                        s.simplify(defaultSimplifier);
                        graph.getDebug().dump(DebugContext.VERY_DETAILED_LEVEL, graph, "After simplifying if %s", s);
                    }
                }
                // Growth guard. Note that peelings is compared before it is incremented.
                if (graph.getNodeCount() > initialNodeCount + MaximumDesiredSize.getValue(graph.getOptions()) * 2 ||
                                peelings > DefaultLoopPolicies.Options.FullUnrollMaxIterations.getValue(graph.getOptions())) {
                    throw new RetryableBailoutException("FullUnroll : Graph seems to grow out of proportion");
                }
                peelings++;
            }
        }
        // Final pass with the original (simplifying) canonicalizer over everything touched.
        canonicalizer.applyIncremental(graph, context, l.getNodes());
    }

    /**
     * Unswitches {@code loop} on the given control splits.
     *
     * <p>A copy of the first split (with its inputs) is placed in front of the loop's entry point.
     * The original loop becomes that copy's first successor. For every further successor position,
     * a duplicate of the whole loop is created and attached at that position. In each duplicate,
     * every still-alive duplicated split is removed in favor of its successor at the same position,
     * with {@code Guard} usages of that successor redirected to the duplicate's new begin. Finally
     * the splits in the original loop are removed in favor of their first successor in the same way.
     *
     * <p>All splits are accessed by the successor {@link Position}s of the first split, so they are
     * expected to share its successor structure (see {@link #findUnswitchable(LoopEx)}).
     *
     * @param loop the loop to unswitch
     * @param controlSplitNodeSet the splits to unswitch on; must not be empty (its first element is
     *            read unconditionally)
     */
    public static void unswitch(LoopEx loop, List<ControlSplitNode> controlSplitNodeSet) {
        ControlSplitNode firstNode = controlSplitNodeSet.iterator().next();
        LoopFragmentWhole originalLoop = loop.whole();
        StructuredGraph graph = firstNode.graph();
        loop.loopBegin().incrementUnswitches();
        // Hoisted control split that now selects between the loop versions.
        ControlSplitNode newControlSplit = (ControlSplitNode) firstNode.copyWithInputs();
        originalLoop.entryPoint().replaceAtPredecessor(newControlSplit);
        Iterator<Position> successors = firstNode.successorPositions().iterator();
        assert successors.hasNext();
        // The original loop is attached to the first successor position.
        Position firstPosition = successors.next();
        AbstractBeginNode originalLoopBegin = BeginNode.begin(originalLoop.entryPoint());
        firstPosition.set(newControlSplit, originalLoopBegin);
        originalLoopBegin.setNodeSourcePosition(firstPosition.get(firstNode).getNodeSourcePosition());
        while (successors.hasNext()) {
            Position position = successors.next();
            // Each remaining successor position gets its own copy of the loop.
            LoopFragmentWhole duplicateLoop = originalLoop.duplicate();
            AbstractBeginNode newBegin = BeginNode.begin(duplicateLoop.entryPoint());
            newBegin.setNodeSourcePosition(position.get(firstNode).getNodeSourcePosition());
            position.set(newControlSplit, newBegin);
            // Inside this copy, keep only the path corresponding to this position.
            for (ControlSplitNode controlSplitNode : controlSplitNodeSet) {
                ControlSplitNode duplicatedControlSplit = duplicateLoop.getDuplicatedNode(controlSplitNode);
                if (duplicatedControlSplit.isAlive()) {
                    AbstractBeginNode survivingSuccessor = (AbstractBeginNode) position.get(duplicatedControlSplit);
                    survivingSuccessor.replaceAtUsages(newBegin, InputType.Guard);
                    graph.removeSplitPropagate(duplicatedControlSplit, survivingSuccessor);
                }
            }
        }
        // The original loop is simplified last: every duplicate above is made from it and must
        // still contain the original control splits.
        for (ControlSplitNode controlSplitNode : controlSplitNodeSet) {
            if (controlSplitNode.isAlive()) {
                AbstractBeginNode survivingSuccessor = (AbstractBeginNode) firstPosition.get(controlSplitNode);
                survivingSuccessor.replaceAtUsages(originalLoopBegin, InputType.Guard);
                graph.removeSplitPropagate(controlSplitNode, survivingSuccessor);
            }
        }
    }

    /**
     * Partially unrolls a main loop: the loop's inside fragment is duplicated and the copy is
     * inserted via {@code insertWithinAfter}.
     *
     * @param loop the loop to unroll; must be marked as a main loop (asserted)
     * @param opaqueUnrolledStrides map forwarded to {@code insertWithinAfter}; presumably caches the
     *            opaque stride node per loop begin across repeated unrollings — confirm
     */
    public static void partialUnroll(LoopEx loop, EconomicMap<LoopBeginNode, OpaqueNode> opaqueUnrolledStrides) {
        assert loop.loopBegin().isMainLoop();
        loop.loopBegin().graph().getDebug().log("LoopPartialUnroll %s", loop);
        LoopFragmentInside newSegment = loop.inside().duplicate();
        newSegment.insertWithinAfter(loop, opaqueUnrolledStrides);
    }

    /**
     * Splits a counted loop into a pre loop, a main loop and a post loop.
     *
     * <p>The whole loop is duplicated twice; the original becomes the pre loop, the first copy the
     * main loop and the second copy the post loop. The resulting control flow is:
     * pre-loop counted exit, then main loop, then (main exit) post loop, then (post exit) the code
     * that originally followed the merge after the main copy's exit. Main-loop phis start from the
     * pre-loop phis, post-loop phis start from the main-loop phis, and uses outside the loop are
     * redirected to the post-loop values. The pre-loop limit is rewritten so that it runs at most
     * one iteration, loop frequencies are adjusted, and safepoints are removed from the pre and
     * post loops.
     *
     * @param loop a counted loop ({@code loop.counted()} is used unconditionally)
     * @return the {@link LoopBeginNode} of the newly created main loop
     */
    public static LoopBeginNode insertPrePostLoops(LoopEx loop) {
        StructuredGraph graph = loop.loopBegin().graph();
        graph.getDebug().log("LoopTransformations.insertPrePostLoops %s", loop);
        LoopFragmentWhole preLoop = loop.whole();
        CountedLoopInfo preCounted = loop.counted();
        LoopBeginNode preLoopBegin = loop.loopBegin();
        AbstractBeginNode preLoopExitNode = preCounted.getCountedExit();
        assert preLoop.nodes().contains(preLoopBegin);
        assert preLoop.nodes().contains(preLoopExitNode);
        LoopFragmentWhole mainLoop = preLoop.duplicate();
        LoopFragmentWhole postLoop = preLoop.duplicate();
        // Two splits: one for the main copy and one for the post copy.
        preLoopBegin.incrementSplits();
        preLoopBegin.incrementSplits();
        preLoopBegin.setPreLoop();
        graph.getDebug().dump(DebugContext.VERBOSE_LEVEL, graph, "After duplication");
        LoopBeginNode mainLoopBegin = mainLoop.getDuplicatedNode(preLoopBegin);
        mainLoopBegin.setMainLoop();
        LoopBeginNode postLoopBegin = postLoop.getDuplicatedNode(preLoopBegin);
        postLoopBegin.setPostLoop();
        // The block after each copy's counted exit is expected to end in an EndNode of a merge.
        AbstractBeginNode postLoopExitNode = postLoop.getDuplicatedNode(preLoopExitNode);
        EndNode postEndNode = getBlockEndAfterLoopExit(postLoopExitNode);
        AbstractMergeNode postMergeNode = postEndNode.merge();
        // The main loop's phis take their initial (index 0) value from the pre loop's phis.
        for (PhiNode prePhiNode : preLoopBegin.phis()) {
            PhiNode mainPhiNode = mainLoop.getDuplicatedNode(prePhiNode);
            mainPhiNode.setValueAt(0, prePhiNode);
        }
        AbstractBeginNode mainLoopExitNode = mainLoop.getDuplicatedNode(preLoopExitNode);
        EndNode mainEndNode = getBlockEndAfterLoopExit(mainLoopExitNode);
        AbstractMergeNode mainMergeNode = mainEndNode.merge();
        AbstractEndNode postEntryNode = postLoopBegin.forwardEnd();
        // Code following the main copy's exit merge; it will be reattached after the post loop.
        FixedNode continuationNode = mainMergeNode.next();
        // Main loop exit -> post loop entry.
        AbstractBeginNode mainLandingNode = BeginNode.begin(postEntryNode);
        mainLoopExitNode.setNext(mainLandingNode);
        // Pre loop exit -> main loop entry.
        preLoopExitNode.setNext(mainLoopBegin.forwardEnd());
        processPreLoopPhis(loop, mainLoop, postLoop);
        // Detach the continuation from its old predecessor and hang it after the post loop exit.
        continuationNode.predecessor().clearSuccessors();
        postLoopExitNode.setNext(continuationNode);
        // Remove the merges that are no longer needed.
        cleanupMerge(postMergeNode, postLoopExitNode);
        cleanupMerge(mainMergeNode, mainLandingNode);
        updatePreLoopLimit(preCounted);
        // Frequency estimates: pre loop runs once; main/post lowered by 2 and 1, clamped at 1.0.
        preLoopBegin.setLoopFrequency(1.0);
        mainLoopBegin.setLoopFrequency(Math.max(1.0, mainLoopBegin.loopFrequency() - 2));
        postLoopBegin.setLoopFrequency(Math.max(1.0, postLoopBegin.loopFrequency() - 1));
        // Safepoints are removed from the pre and post loops only; the main loop keeps its own.
        for (SafepointNode safepoint : preLoop.nodes().filter(SafepointNode.class)) {
            graph.removeFixed(safepoint);
        }
        for (SafepointNode safepoint : postLoop.nodes().filter(SafepointNode.class)) {
            graph.removeFixed(safepoint);
        }
        graph.getDebug().dump(DebugContext.DETAILED_LEVEL, graph, "InsertPrePostLoops %s", loop);
        return mainLoopBegin;
    }

    /**
     * Deletes {@code mergeNode}: every {@link EndNode} predecessor is detached and deleted, then the
     * merge itself is prepared for deletion against {@code landingNode} and deleted.
     *
     * @param mergeNode the merge to remove
     * @param landingNode node handed to {@code prepareDelete}; presumably receives the merge's
     *            remaining users/successor — confirm against {@code AbstractMergeNode}
     */
    private static void cleanupMerge(AbstractMergeNode mergeNode, AbstractBeginNode landingNode) {
        // Snapshot because removeEnd mutates the predecessor list being iterated.
        for (EndNode end : mergeNode.cfgPredecessors().snapshot()) {
            mergeNode.removeEnd(end);
            end.safeDelete();
        }
        mergeNode.prepareDelete(landingNode);
        mergeNode.safeDelete();
    }

    /**
     * Rewires values after the pre/main/post split so that code outside the pre loop observes the
     * post loop's values.
     *
     * <p>For each pre-loop phi, the corresponding post-loop phi's initial (index 0) value becomes
     * the main-loop phi. Every use of the pre-loop phi outside the pre loop, other than the main-loop
     * phi, is redirected to the post-loop phi. Then, for every node in the pre loop's inside
     * fragment, uses outside the pre loop are redirected to that node's post-loop duplicate.
     *
     * @param preLoop the original loop, now acting as the pre loop
     * @param mainLoop the main-loop duplicate
     * @param postLoop the post-loop duplicate
     */
    private static void processPreLoopPhis(LoopEx preLoop, LoopFragmentWhole mainLoop, LoopFragmentWhole postLoop) {
        LoopBeginNode preLoopBegin = preLoop.loopBegin();
        for (PhiNode prePhiNode : preLoopBegin.phis()) {
            PhiNode postPhiNode = postLoop.getDuplicatedNode(prePhiNode);
            PhiNode mainPhiNode = mainLoop.getDuplicatedNode(prePhiNode);
            postPhiNode.setValueAt(0, mainPhiNode);
            // Snapshot: replaceFirstInput changes prePhiNode's usage list during iteration.
            for (Node usage : prePhiNode.usages().snapshot()) {
                // The main loop must keep starting from the pre loop's value (set by the caller).
                if (usage == mainPhiNode) {
                    continue;
                }
                if (preLoop.isOutsideLoop(usage)) {
                    usage.replaceFirstInput(prePhiNode, postPhiNode);
                }
            }
        }
        for (Node node : preLoop.inside().nodes()) {
            for (Node externalUsage : node.usages().snapshot()) {
                if (preLoop.isOutsideLoop(externalUsage)) {
                    Node postUsage = postLoop.getDuplicatedNode(node);
                    assert postUsage != null;
                    externalUsage.replaceFirstInput(node, postUsage);
                }
            }
        }
    }

    /**
     * Returns the {@link EndNode} terminating the block that starts right after {@code exit}.
     *
     * @throws ClassCastException if that block does not end in an {@link EndNode}
     */
    private static EndNode getBlockEndAfterLoopExit(AbstractBeginNode exit) {
        FixedNode node = exit.next();
        return getBlockEnd(node);
    }

    /**
     * Follows {@code next()} links from {@code node} through consecutive {@link FixedWithNextNode}s
     * and returns the first node that is not one, cast to {@link EndNode}.
     *
     * @throws ClassCastException if the node reached is not an {@link EndNode}
     */
    private static EndNode getBlockEnd(FixedNode node) {
        FixedNode curNode = node;
        while (curNode instanceof FixedWithNextNode) {
            curNode = ((FixedWithNextNode) curNode).next();
        }
        return (EndNode) curNode;
    }

    /**
     * Rewrites the pre loop's limit test so that the limit becomes {@code start + stride},
     * clamped by the original limit {@code ub} through a {@link ConditionalNode}.
     *
     * <p>For an upward loop, the new limit is used when {@code compare(start + stride, ub)} holds,
     * otherwise {@code ub}; for a downward loop the operands of the comparison are swapped.
     * Presumably {@code createCompareNode} builds a "less than" test in the counter's signedness,
     * making this a min/max against {@code ub} — confirm against {@code IntegerHelper}. The
     * resulting node replaces the first occurrence of {@code ub} among the inputs of the limit
     * test's {@link CompareNode}.
     *
     * @param preCounted counted-loop information of the pre loop
     */
    private static void updatePreLoopLimit(CountedLoopInfo preCounted) {
        // Limit after exactly one iteration.
        ValueNode newLimit = AddNode.add(preCounted.getStart(), preCounted.getCounter().strideNode(), NodeView.DEFAULT);
        ValueNode ub = preCounted.getLimit();
        IntegerHelper helper = preCounted.getCounterIntegerHelper();
        LogicNode entryCheck;
        if (preCounted.getDirection() == Direction.Up) {
            entryCheck = helper.createCompareNode(newLimit, ub, NodeView.DEFAULT);
        } else {
            entryCheck = helper.createCompareNode(ub, newLimit, NodeView.DEFAULT);
        }
        newLimit = ConditionalNode.create(entryCheck, newLimit, ub, NodeView.DEFAULT);
        CompareNode compareNode = (CompareNode) preCounted.getLimitTest().condition();
        // The conditional (and any new inputs) may not be in the graph yet, hence addOrUniqueWithInputs.
        compareNode.replaceFirstInput(ub, compareNode.graph().addOrUniqueWithInputs(newLimit));
    }

    /**
     * Finds a group of control splits in {@code loop} that all branch on the same loop-invariant
     * value and can therefore be unswitched together.
     *
     * <p>{@link IfNode}s are considered first: the first if whose condition lies outside the loop
     * fixes the invariant condition, and every if using that identical condition node is collected.
     * Only if no such if exists are {@link SwitchNode}s considered: a switch qualifies when it has
     * more than one successor and its value lies outside the loop. The first qualifying switch fixes
     * the value; later switches are collected only if they use the identical value node and are
     * {@code structureEquals} to the first one.
     *
     * @param loop the loop to search
     * @return the collected control splits, or {@code null} (never an empty list) if none was found
     */
    public static List<ControlSplitNode> findUnswitchable(LoopEx loop) {
        List<ControlSplitNode> controls = null;
        ValueNode invariantValue = null;
        for (IfNode ifNode : loop.whole().nodes().filter(IfNode.class)) {
            if (loop.isOutsideLoop(ifNode.condition())) {
                if (controls == null) {
                    invariantValue = ifNode.condition();
                    controls = new ArrayList<>();
                    controls.add(ifNode);
                } else if (ifNode.condition() == invariantValue) {
                    controls.add(ifNode);
                }
            }
        }
        if (controls == null) {
            SwitchNode firstSwitch = null;
            for (SwitchNode switchNode : loop.whole().nodes().filter(SwitchNode.class)) {
                if (switchNode.successors().count() > 1 && loop.isOutsideLoop(switchNode.value())) {
                    if (controls == null) {
                        firstSwitch = switchNode;
                        invariantValue = switchNode.value();
                        controls = new ArrayList<>();
                        controls.add(switchNode);
                    } else if (switchNode.value() == invariantValue) {
                        assert firstSwitch != null;
                        // unswitch() addresses all splits by the first split's successor
                        // positions, so only structurally equal switches can be grouped.
                        if (firstSwitch.structureEquals(switchNode)) {
                            controls.add(switchNode);
                        }
                    }
                }
            }
        }
        return controls;
    }

    /**
     * Decides whether {@code loop} is a candidate for partial unrolling.
     *
     * <p>Returns {@code true} only if all of the following hold: the loop is counted with a constant
     * stride and has no child loops; its limit test is a {@link CompareNode} whose condition is not
     * {@code EQ}; doubling the stride does not overflow a {@code long}; the loop can be duplicated;
     * it is a main loop or a simple loop; and it consists of fewer than 3 blocks. Some rejections are
     * logged at verbose level.
     *
     * @param loop the loop to examine
     * @return whether the loop may be partially unrolled
     */
    public static boolean isUnrollableLoop(LoopEx loop) {
        if (!loop.isCounted() || !loop.counted().getCounter().isConstantStride() || !loop.loop().getChildren().isEmpty()) {
            return false;
        }
        assert loop.counted().getDirection() != null;
        LoopBeginNode loopBegin = loop.loopBegin();
        LogicNode condition = loop.counted().getLimitTest().condition();
        if (!(condition instanceof CompareNode)) {
            return false;
        }
        if (((CompareNode) condition).condition() == CanonicalCondition.EQ) {
            condition.getDebug().log(DebugContext.VERBOSE_LEVEL, "isUnrollableLoop %s condition unsupported %s ", loopBegin, ((CompareNode) condition).condition());
            return false;
        }
        long stride = loop.counted().getCounter().constantStride();
        try {
            // Result intentionally discarded: the call only checks that 2 * stride fits in a long.
            // NOTE(review): this check is 64-bit; if the counter is narrower (e.g. int), a doubled
            // stride overflowing that width is not detected here — confirm this is intended.
            Math.addExact(stride, stride);
        } catch (ArithmeticException ae) {
            condition.getDebug().log(DebugContext.VERBOSE_LEVEL, "isUnrollableLoop %s doubling the stride overflows %d", loopBegin, stride);
            return false;
        }
        if (!loop.canDuplicateLoop()) {
            return false;
        }
        if (loopBegin.isMainLoop() || loopBegin.isSimpleLoop()) {
            // Only very small loops (fewer than 3 CFG blocks) are unrolled.
            if (loop.loop().getBlocks().size() < 3) {
                return true;
            }
            condition.getDebug().log(DebugContext.VERBOSE_LEVEL, "isUnrollableLoop %s too large to unroll %s ", loopBegin, loop.loop().getBlocks().size());
        }
        return false;
    }
}