package org.graalvm.compiler.phases.graph;
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Deque;
import java.util.List;
import java.util.function.Predicate;
import org.graalvm.compiler.core.common.PermanentBailoutException;
import org.graalvm.compiler.core.common.RetryableBailoutException;
import org.graalvm.compiler.core.common.cfg.Loop;
import org.graalvm.compiler.core.common.util.CompilationAlarm;
import org.graalvm.compiler.nodes.AbstractEndNode;
import org.graalvm.compiler.nodes.AbstractMergeNode;
import org.graalvm.compiler.nodes.FixedNode;
import org.graalvm.compiler.nodes.LoopBeginNode;
import org.graalvm.compiler.nodes.StructuredGraph;
import org.graalvm.compiler.nodes.cfg.Block;
import org.graalvm.util.Equivalence;
import org.graalvm.util.EconomicMap;
/**
 * Iterates over the blocks of a control-flow graph, threading a client-defined state through the
 * traversal via a {@link BlockIteratorClosure}.
 *
 * <p>
 * While iterating, pending states are kept in an identity map keyed by {@link FixedNode}:
 * <ul>
 * <li>on the end node of a block that precedes a merge, until all forward ends of that merge have
 * been reached;</li>
 * <li>on the begin node of each queued successor of a control split;</li>
 * <li>on the begin node of each loop exit.</li>
 * </ul>
 * Loops are not walked directly by {@link #apply}. Instead, {@link BlockIteratorClosure#processLoop}
 * is invoked, and a closure can implement it by calling back into
 * {@link #processLoop(BlockIteratorClosure, Loop, Object)}.
 */
public final class ReentrantBlockIterator {

    /**
     * An expired {@link CompilationAlarm} becomes a {@link PermanentBailoutException} when the
     * configured expiration period (in seconds) is greater than this value. Otherwise it becomes a
     * {@link RetryableBailoutException}.
     */
    private static final int PERMANENT_BAILOUT_PERIOD_SECONDS = 120;

    /**
     * States collected by {@link ReentrantBlockIterator#processLoop} for a single loop.
     *
     * @param <StateT> the closure's state type
     */
    public static class LoopInfo<StateT> {
        /**
         * Cloned states recorded at every predecessor of the loop header except the first
         * (presumably the loop ends).
         */
        public final List<StateT> endStates;
        /** Cloned states recorded at the loop's exits, in the order of {@link Loop#getExits()}. */
        public final List<StateT> exitStates;

        /**
         * Creates empty state lists presized for the given number of loop ends and loop exits.
         *
         * @param endCount expected number of loop end states
         * @param exitCount expected number of loop exit states
         */
        public LoopInfo(int endCount, int exitCount) {
            endStates = new ArrayList<>(endCount);
            exitStates = new ArrayList<>(exitCount);
        }
    }

    /**
     * Client callbacks that compute the state threaded through the iteration.
     *
     * @param <StateT> the type of state threaded through the blocks
     */
    public abstract static class BlockIteratorClosure<StateT> {

        /**
         * Returns the state on entry to the start block when iterating via
         * {@link ReentrantBlockIterator#apply(BlockIteratorClosure, Block)}.
         */
        protected abstract StateT getInitialState();

        /**
         * Processes a single block.
         *
         * @param block the block to process
         * @param currentState the state on entry to {@code block}
         * @return the state on exit from {@code block}. It is handed to the block's successor, or
         *         cloned for the additional successors of a control split.
         */
        protected abstract StateT processBlock(Block block, StateT currentState);

        /**
         * Combines the states flowing into a merge.
         *
         * @param merge the successor block being merged into
         * @param states one state per predecessor of {@code merge}, in predecessor order
         * @return the state on entry to {@code merge}
         */
        protected abstract StateT merge(Block merge, List<StateT> states);

        /**
         * Returns an independent copy of {@code oldState}. The iterator calls this so that each
         * additional successor of a control split gets its own state object. It is also used for
         * each loop end and loop exit state reported by
         * {@link ReentrantBlockIterator#processLoop}.
         */
        protected abstract StateT cloneState(StateT oldState);

        /**
         * Processes an entire loop.
         *
         * @param loop the loop being entered
         * @param initialState the state flowing into the loop header
         * @return one state per loop exit, in the order of {@link Loop#getExits()}
         */
        protected abstract List<StateT> processLoop(Loop<Block> loop, StateT initialState);
    }

    /** Static utility class; not instantiable. */
    private ReentrantBlockIterator() {
    }

    /**
     * Iterates over the body of {@code loop} and collects the states that reach its loop ends and
     * loop exits.
     *
     * <p>
     * Iteration starts at the loop header with {@code initialState}. A block that is neither part of
     * this loop nor a loop header is not processed; the state reaching it is only recorded.
     *
     * @param closure the callbacks that compute the states
     * @param loop the loop to iterate
     * @param initialState the state on entry to the loop header
     * @return cloned end states (one for every header predecessor except the first) and cloned exit
     *         states (in the order of {@link Loop#getExits()})
     */
    public static <StateT> LoopInfo<StateT> processLoop(BlockIteratorClosure<StateT> closure, Loop<Block> loop, StateT initialState) {
        EconomicMap<FixedNode, StateT> blockEndStates = apply(closure, loop.getHeader(), initialState, block -> !(block.getLoop() == loop || block.isLoopHeader()));

        Block[] predecessors = loop.getHeader().getPredecessors();
        LoopInfo<StateT> info = new LoopInfo<>(predecessors.length - 1, loop.getExits().size());
        // Index 0 is skipped; presumably it is the loop's forward entry, whose state is initialState.
        for (int i = 1; i < predecessors.length; i++) {
            FixedNode loopEnd = predecessors[i].getEndNode();
            // Mirrors the exit check below. Without it, a missing state would reach cloneState as null.
            assert blockEndStates.containsKey(loopEnd) : loopEnd + " " + blockEndStates;
            StateT endState = blockEndStates.get(loopEnd);
            // Clone so that every recorded end state is a distinct object.
            info.endStates.add(closure.cloneState(endState));
        }
        for (Block loopExit : loop.getExits()) {
            assert loopExit.getPredecessorCount() == 1;
            assert blockEndStates.containsKey(loopExit.getBeginNode()) : loopExit.getBeginNode() + " " + blockEndStates;
            StateT exitState = blockEndStates.get(loopExit.getBeginNode());
            // Clone so that every recorded exit state is a distinct object.
            info.exitStates.add(closure.cloneState(exitState));
        }
        return info;
    }

    /**
     * Iterates from {@code start} using {@link BlockIteratorClosure#getInitialState()} and no stop
     * predicate. Any states left over at the end are discarded.
     *
     * @param closure the callbacks that compute the states
     * @param start the block to start at
     */
    public static <StateT> void apply(BlockIteratorClosure<StateT> closure, Block start) {
        apply(closure, start, closure.getInitialState(), null);
    }

    /**
     * Runs the iteration starting at {@code start}.
     *
     * <p>
     * Blocks are processed along one path at a time. How the path continues depends on the current
     * block:
     * <ul>
     * <li>A block with a single successor continues directly with that successor. If the successor
     * is a merge that is still waiting for other forward ends, the state is parked on the end node
     * instead.</li>
     * <li>A control split continues with its first successor and queues the others, each with a
     * cloned state.</li>
     * <li>Entering a loop delegates to {@link BlockIteratorClosure#processLoop} and queues the loop
     * exits.</li>
     * </ul>
     * When a path ends, the most recently queued block is resumed.
     *
     * @param closure the callbacks that compute the states
     * @param start the block to start at
     * @param initialState the state on entry to {@code start}
     * @param stopAtBlock if non-null, blocks matching this predicate are not processed. The state
     *            reaching such a block is recorded at its begin node instead.
     * @return the states still stored when no work is left, keyed by the node they were recorded at
     *         (for example, begin nodes of stopped blocks and end nodes of loop ends)
     * @throws PermanentBailoutException if the compilation alarm expires and the configured
     *             expiration period exceeds {@link #PERMANENT_BAILOUT_PERIOD_SECONDS}
     * @throws RetryableBailoutException if the compilation alarm expires otherwise
     */
    public static <StateT> EconomicMap<FixedNode, StateT> apply(BlockIteratorClosure<StateT> closure, Block start, StateT initialState, Predicate<Block> stopAtBlock) {
        Deque<Block> blockQueue = new ArrayDeque<>();
        /*
         * Pending states are stored:
         * - on end nodes before merges,
         * - on begin nodes after control splits and loop exits,
         * - on begin nodes of blocks rejected by stopAtBlock.
         */
        EconomicMap<FixedNode, StateT> states = EconomicMap.create(Equivalence.IDENTITY);

        StateT state = initialState;
        Block current = start;

        StructuredGraph graph = start.getBeginNode().graph();
        CompilationAlarm compilationAlarm = CompilationAlarm.current();
        while (true) {
            checkCompilationAlarm(compilationAlarm, graph);

            Block next = null;
            if (stopAtBlock != null && stopAtBlock.test(current)) {
                // Do not process this block; hand its incoming state back to the caller.
                states.put(current.getBeginNode(), state);
            } else {
                state = closure.processBlock(current, state);

                Block[] successors = current.getSuccessors();
                if (successors.length == 1) {
                    Block successor = successors[0];
                    if (successor.isLoopHeader()) {
                        if (current.isLoopEnd()) {
                            // Do not re-enter the loop header; record the state for processLoop.
                            states.put(current.getEndNode(), state);
                        } else {
                            recurseIntoLoop(closure, blockQueue, states, state, successor);
                        }
                    } else if (current.getEndNode() instanceof AbstractEndNode) {
                        AbstractEndNode end = (AbstractEndNode) current.getEndNode();
                        AbstractMergeNode merge = end.merge();
                        if (allEndsVisited(states, current, merge)) {
                            // This is the last forward end to arrive: merge and continue into the merge block.
                            ArrayList<StateT> mergedStates = mergeStates(states, state, current, successor, merge);
                            state = closure.merge(successor, mergedStates);
                            next = successor;
                        } else {
                            // Park the state until the remaining forward ends have been reached.
                            assert !states.containsKey(end);
                            states.put(end, state);
                        }
                    } else {
                        next = successor;
                    }
                } else if (successors.length > 1) {
                    next = processMultipleSuccessors(closure, blockQueue, states, state, successors);
                }
                // A block without successors ends the current path.
            }

            if (next != null) {
                current = next;
            } else if (blockQueue.isEmpty()) {
                return states;
            } else {
                // Resume the most recently queued block with the state stored at its begin node.
                current = blockQueue.removeFirst();
                assert current.getPredecessorCount() == 1;
                assert states.containsKey(current.getBeginNode());
                state = states.removeKey(current.getBeginNode());
            }
        }
    }

    /**
     * Bails out of the compilation if {@code compilationAlarm} has expired.
     *
     * @param compilationAlarm the alarm of the current compilation
     * @param graph the graph being iterated; its options supply the configured expiration period
     * @throws PermanentBailoutException if expired and the period exceeds
     *             {@link #PERMANENT_BAILOUT_PERIOD_SECONDS}
     * @throws RetryableBailoutException if expired otherwise
     */
    private static void checkCompilationAlarm(CompilationAlarm compilationAlarm, StructuredGraph graph) {
        if (compilationAlarm.hasExpired()) {
            int period = CompilationAlarm.Options.CompilationExpirationPeriod.getValue(graph.getOptions());
            if (period > PERMANENT_BAILOUT_PERIOD_SECONDS) {
                throw new PermanentBailoutException("Compilation exceeded %d seconds during CFG traversal", period);
            } else {
                throw new RetryableBailoutException("Compilation exceeded %d seconds during CFG traversal", period);
            }
        }
    }

    /**
     * Returns whether every forward end of {@code merge}, other than the end node of
     * {@code current}, already has a parked state in {@code states}.
     */
    private static <StateT> boolean allEndsVisited(EconomicMap<FixedNode, StateT> states, Block current, AbstractMergeNode merge) {
        for (AbstractEndNode forwardEnd : merge.forwardEnds()) {
            if (forwardEnd != current.getEndNode() && !states.containsKey(forwardEnd)) {
                return false;
            }
        }
        return true;
    }

    /**
     * Handles a control split by queueing every successor except the first.
     *
     * <p>
     * Each queued successor gets a cloned state stored at its begin node. Successors are pushed to
     * the front of the queue in ascending index order, so they are later resumed in descending
     * index order.
     *
     * @return the first successor, which continues with the original {@code state}
     */
    private static <StateT> Block processMultipleSuccessors(BlockIteratorClosure<StateT> closure, Deque<Block> blockQueue, EconomicMap<FixedNode, StateT> states, StateT state, Block[] successors) {
        assert successors.length > 1;
        for (int i = 1; i < successors.length; i++) {
            Block successor = successors[i];
            blockQueue.addFirst(successor);
            states.put(successor.getBeginNode(), closure.cloneState(state));
        }
        return successors[0];
    }

    /**
     * Collects the incoming states of {@code successor} in predecessor order.
     *
     * <p>
     * For {@code current}, the live {@code state} is used. For every other predecessor, the state
     * parked on its end node is removed from {@code states} and used. {@code merge} is only used to
     * presize the result.
     */
    private static <StateT> ArrayList<StateT> mergeStates(EconomicMap<FixedNode, StateT> states, StateT state, Block current, Block successor, AbstractMergeNode merge) {
        ArrayList<StateT> mergedStates = new ArrayList<>(merge.forwardEndCount());
        for (Block predecessor : successor.getPredecessors()) {
            assert predecessor == current || states.containsKey(predecessor.getEndNode());
            StateT endState = predecessor == current ? state : states.removeKey(predecessor.getEndNode());
            mergedStates.add(endState);
        }
        return mergedStates;
    }

    /**
     * Delegates the loop headed by {@code successor} to {@link BlockIteratorClosure#processLoop}.
     *
     * <p>
     * Each returned exit state is stored at the begin node of the corresponding loop exit, and the
     * exit is pushed to the front of the queue.
     */
    private static <StateT> void recurseIntoLoop(BlockIteratorClosure<StateT> closure, Deque<Block> blockQueue, EconomicMap<FixedNode, StateT> states, StateT state, Block successor) {
        Loop<Block> loop = successor.getLoop();
        LoopBeginNode loopBegin = (LoopBeginNode) loop.getHeader().getBeginNode();
        assert successor.getBeginNode() == loopBegin;

        List<StateT> exitStates = closure.processLoop(loop, state);

        int i = 0;
        assert loop.getExits().size() == exitStates.size();
        for (Block exit : loop.getExits()) {
            states.put(exit.getBeginNode(), exitStates.get(i++));
            blockQueue.addFirst(exit);
        }
    }
}