package org.graalvm.compiler.phases.common;
import static org.graalvm.compiler.graph.Graph.NodeEvent.NODE_ADDED;
import static org.graalvm.compiler.graph.Graph.NodeEvent.ZERO_USAGES;
import static jdk.internal.vm.compiler.word.LocationIdentity.any;
import java.util.EnumSet;
import java.util.Iterator;
import java.util.List;
import jdk.internal.vm.compiler.collections.EconomicMap;
import jdk.internal.vm.compiler.collections.EconomicSet;
import jdk.internal.vm.compiler.collections.Equivalence;
import jdk.internal.vm.compiler.collections.UnmodifiableMapCursor;
import org.graalvm.compiler.core.common.cfg.Loop;
import org.graalvm.compiler.debug.DebugCloseable;
import org.graalvm.compiler.graph.Graph.NodeEventScope;
import org.graalvm.compiler.graph.Node;
import org.graalvm.compiler.nodes.AbstractBeginNode;
import org.graalvm.compiler.nodes.AbstractMergeNode;
import org.graalvm.compiler.nodes.FixedNode;
import org.graalvm.compiler.nodes.InvokeWithExceptionNode;
import org.graalvm.compiler.nodes.LoopBeginNode;
import org.graalvm.compiler.nodes.LoopEndNode;
import org.graalvm.compiler.nodes.LoopExitNode;
import org.graalvm.compiler.nodes.PhiNode;
import org.graalvm.compiler.nodes.ReturnNode;
import org.graalvm.compiler.nodes.StartNode;
import org.graalvm.compiler.nodes.StructuredGraph;
import org.graalvm.compiler.nodes.ValueNodeUtil;
import org.graalvm.compiler.nodes.calc.FloatingNode;
import org.graalvm.compiler.nodes.cfg.Block;
import org.graalvm.compiler.nodes.cfg.ControlFlowGraph;
import org.graalvm.compiler.nodes.cfg.HIRLoop;
import org.graalvm.compiler.nodes.memory.FloatableAccessNode;
import org.graalvm.compiler.nodes.memory.FloatingAccessNode;
import org.graalvm.compiler.nodes.memory.FloatingReadNode;
import org.graalvm.compiler.nodes.memory.MemoryAccess;
import org.graalvm.compiler.nodes.memory.MemoryAnchorNode;
import org.graalvm.compiler.nodes.memory.MemoryCheckpoint;
import org.graalvm.compiler.nodes.memory.MemoryMap;
import org.graalvm.compiler.nodes.memory.MemoryMapNode;
import org.graalvm.compiler.nodes.memory.MemoryNode;
import org.graalvm.compiler.nodes.memory.MemoryPhiNode;
import org.graalvm.compiler.nodes.memory.ReadNode;
import org.graalvm.compiler.nodes.util.GraphUtil;
import org.graalvm.compiler.phases.Phase;
import org.graalvm.compiler.phases.common.util.HashSetNodeEventListener;
import org.graalvm.compiler.phases.graph.ReentrantNodeIterator;
import org.graalvm.compiler.phases.graph.ReentrantNodeIterator.LoopInfo;
import org.graalvm.compiler.phases.graph.ReentrantNodeIterator.NodeIteratorClosure;
import jdk.internal.vm.compiler.word.LocationIdentity;
/**
 * Threads a memory state through the fixed nodes of a graph so that memory accesses are connected
 * to the last node that may have killed their location.
 *
 * <p>The state is a {@link MemoryMapImpl}: a map from {@link LocationIdentity} to the last
 * {@link MemoryNode} that killed that location, where the {@link LocationIdentity#any()} entry
 * serves as the fallback for locations without an entry of their own. {@link MemoryPhiNode}s are
 * created at merges whose predecessors disagree and at loop headers for every location modified
 * inside the loop. Depending on the constructor flags, the phase also replaces floatable fixed
 * accesses with floating nodes and attaches a {@link MemoryMapNode} to every {@link ReturnNode}.
 */
public class FloatingReadPhase extends Phase {

    /** Whether {@link FloatableAccessNode}s that can float are replaced by floating nodes. */
    private final boolean createFloatingReads;

    /** Whether a {@link MemoryMapNode} snapshot of the memory state is attached to each return. */
    private final boolean createMemoryMapNodes;

    /**
     * Mutable {@link MemoryMap} keyed by location identity. A location without its own entry is
     * resolved through the {@link LocationIdentity#any()} entry.
     */
    public static class MemoryMapImpl implements MemoryMap {

        /** Last killing node per location; the {@code any()} entry is the fallback. */
        private final EconomicMap<LocationIdentity, MemoryNode> lastMemorySnapshot;

        /** Creates a new map holding a copy of {@code memoryMap}'s entries. */
        public MemoryMapImpl(MemoryMapImpl memoryMap) {
            lastMemorySnapshot = EconomicMap.create(Equivalence.DEFAULT, memoryMap.lastMemorySnapshot);
        }

        /** Creates the initial state, in which {@code start} is recorded for {@code any()}. */
        public MemoryMapImpl(StartNode start) {
            this();
            lastMemorySnapshot.put(any(), start);
        }

        /** Creates an empty map. */
        public MemoryMapImpl() {
            lastMemorySnapshot = EconomicMap.create(Equivalence.DEFAULT);
        }

        /**
         * Returns the last node recorded for {@code locationIdentity}, falling back to the entry
         * for {@link LocationIdentity#any()} when the location has no entry of its own.
         *
         * @return {@code null} for immutable locations (only mutable locations are ever recorded),
         *         otherwise the last recorded access
         */
        @Override
        public MemoryNode getLastLocationAccess(LocationIdentity locationIdentity) {
            if (locationIdentity.isImmutable()) {
                return null;
            }
            MemoryNode lastLocationAccess = lastMemorySnapshot.get(locationIdentity);
            if (lastLocationAccess == null) {
                lastLocationAccess = lastMemorySnapshot.get(any());
                assert lastLocationAccess != null;
            }
            return lastLocationAccess;
        }

        /** Returns the locations that currently have an explicit entry (a live view of the keys). */
        @Override
        public Iterable<LocationIdentity> getLocations() {
            return lastMemorySnapshot.getKeys();
        }

        /** Returns the backing map itself (not a copy). */
        public EconomicMap<LocationIdentity, MemoryNode> getMap() {
            return lastMemorySnapshot;
        }
    }

    /** Creates a phase that creates floating reads but no memory map nodes. */
    public FloatingReadPhase() {
        this(true, false);
    }

    /**
     * @param createFloatingReads whether floatable fixed accesses are replaced by floating nodes
     * @param createMemoryMapNodes whether a {@link MemoryMapNode} is attached to each return
     */
    public FloatingReadPhase(boolean createFloatingReads, boolean createMemoryMapNodes) {
        this.createFloatingReads = createFloatingReads;
        this.createMemoryMapNodes = createMemoryMapNodes;
    }

    /** Returns the fixed factor {@code 1.25f} reported as this phase's code size increase. */
    @Override
    public float codeSizeIncrease() {
        return 1.25f;
    }

    /**
     * Shrinks {@code set} to the nodes whose usages all lie inside the set. Removing one node can
     * give another node a usage outside the set, so the scan repeats until nothing is removed.
     *
     * @param set candidate nodes; modified in place
     * @return {@code set} itself
     */
    private static EconomicSet<Node> removeExternallyUsedNodes(EconomicSet<Node> set) {
        boolean change;
        do {
            change = false;
            for (Iterator<Node> iter = set.iterator(); iter.hasNext();) {
                Node node = iter.next();
                for (Node usage : node.usages()) {
                    if (!set.contains(usage)) {
                        change = true;
                        iter.remove();
                        break;
                    }
                }
            }
        } while (change);
        return set;
    }

    /**
     * Adds to {@code currentState} every mutable location killed by {@code node}, if {@code node}
     * is a {@link MemoryCheckpoint.Single} or {@link MemoryCheckpoint.Multi}.
     */
    protected void processNode(FixedNode node, EconomicSet<LocationIdentity> currentState) {
        if (node instanceof MemoryCheckpoint.Single) {
            processIdentity(currentState, ((MemoryCheckpoint.Single) node).getLocationIdentity());
        } else if (node instanceof MemoryCheckpoint.Multi) {
            for (LocationIdentity identity : ((MemoryCheckpoint.Multi) node).getLocationIdentities()) {
                processIdentity(currentState, identity);
            }
        }
    }

    /** Adds {@code identity} to {@code currentState} unless it is immutable. */
    private static void processIdentity(EconomicSet<LocationIdentity> currentState, LocationIdentity identity) {
        if (identity.isMutable()) {
            currentState.add(identity);
        }
    }

    /** Applies {@link #processNode(FixedNode, EconomicSet)} to every fixed node of {@code b}. */
    protected void processBlock(Block b, EconomicSet<LocationIdentity> currentState) {
        for (FixedNode n : b.getNodes()) {
            processNode(n, currentState);
        }
    }

    /**
     * Returns the mutable locations killed anywhere in {@code loop}, including its nested loops.
     * Results are memoized in {@code modifiedInLoops}, keyed by the loop's {@link LoopBeginNode}.
     * Blocks of nested loops are accounted for through the nested loop's own result, so only
     * blocks whose innermost loop is {@code loop} are scanned directly.
     */
    private EconomicSet<LocationIdentity> processLoop(HIRLoop loop, EconomicMap<LoopBeginNode, EconomicSet<LocationIdentity>> modifiedInLoops) {
        LoopBeginNode loopBegin = (LoopBeginNode) loop.getHeader().getBeginNode();
        EconomicSet<LocationIdentity> result = modifiedInLoops.get(loopBegin);
        if (result != null) {
            return result;
        }
        result = EconomicSet.create(Equivalence.DEFAULT);
        for (Loop<Block> inner : loop.getChildren()) {
            result.addAll(processLoop((HIRLoop) inner, modifiedInLoops));
        }
        for (Block b : loop.getBlocks()) {
            if (b.getLoop() == loop) {
                processBlock(b, result);
            }
        }
        modifiedInLoops.put(loopBegin, result);
        return result;
    }

    @Override
    @SuppressWarnings("try")
    protected void run(StructuredGraph graph) {
        // Pre-compute, per loop, the locations killed inside it. The closure uses this to create
        // the loop-header memory phis before it visits the loop body.
        EconomicMap<LoopBeginNode, EconomicSet<LocationIdentity>> modifiedInLoops = null;
        if (graph.hasLoops()) {
            modifiedInLoops = EconomicMap.create(Equivalence.IDENTITY);
            ControlFlowGraph cfg = ControlFlowGraph.compute(graph, true, true, false, false);
            for (Loop<?> l : cfg.getLoops()) {
                HIRLoop loop = (HIRLoop) l;
                processLoop(loop, modifiedInLoops);
            }
        }

        // Record the nodes that are added, or that lose all usages, while the closure runs.
        HashSetNodeEventListener listener = new HashSetNodeEventListener(EnumSet.of(NODE_ADDED, ZERO_USAGES));
        try (NodeEventScope nes = graph.trackNodeEvents(listener)) {
            ReentrantNodeIterator.apply(new FloatingReadClosure(modifiedInLoops, createFloatingReads, createMemoryMapNodes), graph.start(), new MemoryMapImpl(graph.start()));
        }

        // Delete recorded floating nodes that nothing outside the recorded set uses, for
        // example memory phis created above that ended up unused.
        for (Node n : removeExternallyUsedNodes(listener.getNodes())) {
            if (n.isAlive() && n instanceof FloatingNode) {
                n.replaceAtUsages(null);
                GraphUtil.killWithUnusedFloatingInputs(n);
            }
        }
        if (createFloatingReads) {
            assert !graph.isAfterFloatingReadPhase();
            graph.setAfterFloatingReadPhase(true);
        }
    }

    /**
     * Merges the memory states flowing into {@code merge}. For each location known to any
     * incoming state: if all states resolve to the same node, that node is kept; otherwise a
     * {@link MemoryPhiNode} is created at {@code merge} with one input per state, in the order of
     * {@code states}.
     */
    public static MemoryMapImpl mergeMemoryMaps(AbstractMergeNode merge, List<? extends MemoryMap> states) {
        MemoryMapImpl newState = new MemoryMapImpl();

        EconomicSet<LocationIdentity> keys = EconomicSet.create(Equivalence.DEFAULT);
        for (MemoryMap other : states) {
            keys.addAll(other.getLocations());
        }
        assert checkNoImmutableLocations(keys);

        for (LocationIdentity key : keys) {
            int mergedStatesCount = 0;
            boolean isPhi = false;
            MemoryNode merged = null;
            for (MemoryMap state : states) {
                MemoryNode last = state.getLastLocationAccess(key);
                if (isPhi) {
                    // isPhi is only set after 'merged' was replaced by the phi created below.
                    ((MemoryPhiNode) merged).addInput(ValueNodeUtil.asNode(last));
                } else if (merged != last) {
                    if (merged == null) {
                        merged = last;
                    } else {
                        // First disagreement: create the phi lazily and back-fill one input for
                        // each previously visited state with the value they shared.
                        MemoryPhiNode phi = merge.graph().addWithoutUnique(new MemoryPhiNode(merge, key));
                        for (int j = 0; j < mergedStatesCount; j++) {
                            phi.addInput(ValueNodeUtil.asNode(merged));
                        }
                        phi.addInput(ValueNodeUtil.asNode(last));
                        merged = phi;
                        isPhi = true;
                    }
                }
                mergedStatesCount++;
            }
            newState.lastMemorySnapshot.put(key, merged);
        }
        return newState;
    }

    /** Assertion helper: checks every key is mutable; always returns {@code true}. */
    private static boolean checkNoImmutableLocations(EconomicSet<LocationIdentity> keys) {
        keys.forEach(t -> {
            assert t.isMutable();
        });
        return true;
    }

    /**
     * Iterator closure that carries a {@link MemoryMapImpl} along the fixed nodes, wiring memory
     * accesses to their last location access and recording memory checkpoints.
     */
    public static class FloatingReadClosure extends NodeIteratorClosure<MemoryMapImpl> {

        /** Per-loop sets of killed locations; {@code null} when the graph has no loops. */
        private final EconomicMap<LoopBeginNode, EconomicSet<LocationIdentity>> modifiedInLoops;
        private final boolean createFloatingReads;
        private final boolean createMemoryMapNodes;

        public FloatingReadClosure(EconomicMap<LoopBeginNode, EconomicSet<LocationIdentity>> modifiedInLoops, boolean createFloatingReads, boolean createMemoryMapNodes) {
            this.modifiedInLoops = modifiedInLoops;
            this.createFloatingReads = createFloatingReads;
            this.createMemoryMapNodes = createMemoryMapNodes;
        }

        /**
         * Updates {@code state} for one fixed node and returns it. Anchors are resolved and
         * handled exclusively. Accesses get their last location access first; then the node is
         * either floated (if enabled and floatable) or, if it is a checkpoint, recorded as the last
         * access of the locations it kills. Returns optionally get a memory map snapshot.
         */
        @Override
        protected MemoryMapImpl processNode(FixedNode node, MemoryMapImpl state) {
            if (node instanceof MemoryAnchorNode) {
                processAnchor((MemoryAnchorNode) node, state);
                return state;
            }

            if (node instanceof MemoryAccess) {
                processAccess((MemoryAccess) node, state);
            }

            if (createFloatingReads && node instanceof FloatableAccessNode) {
                processFloatable((FloatableAccessNode) node, state);
            } else if (node instanceof MemoryCheckpoint.Single) {
                processCheckpoint((MemoryCheckpoint.Single) node, state);
            } else if (node instanceof MemoryCheckpoint.Multi) {
                processCheckpoint((MemoryCheckpoint.Multi) node, state);
            }
            assert MemoryCheckpoint.TypeAssertion.correctType(node) : node;

            if (createMemoryMapNodes && node instanceof ReturnNode) {
                ((ReturnNode) node).setMemoryMap(node.graph().unique(new MemoryMapNode(state.lastMemorySnapshot)));
            }
            return state;
        }

        /**
         * Redirects every {@link MemoryAccess} usage of {@code anchor} whose last location access is
         * the anchor to the current last access for its location, then removes the anchor from the
         * control flow if it has no usages left.
         */
        private static void processAnchor(MemoryAnchorNode anchor, MemoryMapImpl state) {
            // Snapshot because setLastLocationAccess changes the anchor's usage list.
            for (Node node : anchor.usages().snapshot()) {
                if (node instanceof MemoryAccess) {
                    MemoryAccess access = (MemoryAccess) node;
                    if (access.getLastLocationAccess() == anchor) {
                        MemoryNode lastLocationAccess = state.getLastLocationAccess(access.getLocationIdentity());
                        assert lastLocationAccess != null;
                        access.setLastLocationAccess(lastLocationAccess);
                    }
                }
            }

            if (anchor.hasNoUsages()) {
                anchor.graph().removeFixed(anchor);
            }
        }

        /** Sets the last location access of {@code access}, unless its location is {@code any()}. */
        private static void processAccess(MemoryAccess access, MemoryMapImpl state) {
            LocationIdentity locationIdentity = access.getLocationIdentity();
            if (!locationIdentity.equals(LocationIdentity.any())) {
                MemoryNode lastLocationAccess = state.getLastLocationAccess(locationIdentity);
                access.setLastLocationAccess(lastLocationAccess);
            }
        }

        private static void processCheckpoint(MemoryCheckpoint.Single checkpoint, MemoryMapImpl state) {
            processIdentity(checkpoint.getLocationIdentity(), checkpoint, state);
        }

        private static void processCheckpoint(MemoryCheckpoint.Multi checkpoint, MemoryMapImpl state) {
            for (LocationIdentity identity : checkpoint.getLocationIdentities()) {
                processIdentity(identity, checkpoint, state);
            }
        }

        /**
         * Records {@code checkpoint} as the last access of {@code identity}. Killing {@code any()}
         * first discards all previously recorded locations.
         */
        private static void processIdentity(LocationIdentity identity, MemoryCheckpoint checkpoint, MemoryMapImpl state) {
            if (identity.isAny()) {
                state.lastMemorySnapshot.clear();
            }
            if (identity.isMutable()) {
                state.lastMemorySnapshot.put(identity, checkpoint);
            }
        }

        /**
         * If {@code accessNode} can float, replaces it in the graph with its floating counterpart,
         * anchored at the current last access of its location. The replacement happens inside the
         * access's node-source-position scope.
         */
        @SuppressWarnings("try")
        private static void processFloatable(FloatableAccessNode accessNode, MemoryMapImpl state) {
            StructuredGraph graph = accessNode.graph();
            LocationIdentity locationIdentity = accessNode.getLocationIdentity();
            if (accessNode.canFloat()) {
                assert !accessNode.getNullCheck();
                MemoryNode lastLocationAccess = state.getLastLocationAccess(locationIdentity);
                try (DebugCloseable position = accessNode.withNodeSourcePosition()) {
                    FloatingAccessNode floatingNode = accessNode.asFloatingNode(lastLocationAccess);
                    graph.replaceFixedWithFloating(accessNode, floatingNode);
                }
            }
        }

        @Override
        protected MemoryMapImpl merge(AbstractMergeNode merge, List<MemoryMapImpl> states) {
            return mergeMemoryMaps(merge, states);
        }

        /**
         * Gives each successor of a split its own copy of the state. For the successors of an
         * {@link InvokeWithExceptionNode}, the successor begin node (cast to
         * {@link MemoryCheckpoint}) is recorded as the last access of the invoke's location.
         * NOTE(review): presumably this is because the invoke itself, being a control-flow split,
         * cannot serve as the anchor of a floating read — confirm.
         */
        @Override
        protected MemoryMapImpl afterSplit(AbstractBeginNode node, MemoryMapImpl oldState) {
            MemoryMapImpl result = new MemoryMapImpl(oldState);
            if (node.predecessor() instanceof InvokeWithExceptionNode) {
                InvokeWithExceptionNode invoke = (InvokeWithExceptionNode) node.predecessor();
                result.lastMemorySnapshot.put(invoke.getLocationIdentity(), (MemoryCheckpoint) node);
            }
            return result;
        }

        /**
         * Processes a loop in four steps:
         * <ol>
         * <li>Creates a {@link MemoryPhiNode} at the header for every location modified in the
         * loop. If {@code any()} is modified, a phi is also created for every location currently
         * known to {@code initialState}.</li>
         * <li>Seeds each phi with the entry value.</li>
         * <li>Walks the loop body.</li>
         * <li>Fills each phi's back-edge inputs from the states at the loop ends.</li>
         * </ol>
         *
         * @return the states at the loop exits
         */
        @Override
        protected EconomicMap<LoopExitNode, MemoryMapImpl> processLoop(LoopBeginNode loop, MemoryMapImpl initialState) {
            EconomicSet<LocationIdentity> modifiedLocations = modifiedInLoops.get(loop);
            EconomicMap<LocationIdentity, MemoryPhiNode> phis = EconomicMap.create(Equivalence.DEFAULT);
            if (modifiedLocations.contains(LocationIdentity.any())) {
                // Copy first: the set in modifiedInLoops is shared and must not be mutated.
                modifiedLocations = EconomicSet.create(Equivalence.DEFAULT, modifiedLocations);
                modifiedLocations.addAll(initialState.lastMemorySnapshot.getKeys());
            }

            for (LocationIdentity location : modifiedLocations) {
                createMemoryPhi(loop, initialState, phis, location);
            }
            initialState.lastMemorySnapshot.putAll(phis);

            LoopInfo<MemoryMapImpl> loopInfo = ReentrantNodeIterator.processLoop(this, loop, initialState);

            UnmodifiableMapCursor<LoopEndNode, MemoryMapImpl> endStateCursor = loopInfo.endStates.getEntries();
            while (endStateCursor.advance()) {
                int endIndex = loop.phiPredecessorIndex(endStateCursor.getKey());
                UnmodifiableMapCursor<LocationIdentity, MemoryPhiNode> phiCursor = phis.getEntries();
                while (phiCursor.advance()) {
                    LocationIdentity key = phiCursor.getKey();
                    PhiNode phi = phiCursor.getValue();
                    phi.initializeValueAt(endIndex, ValueNodeUtil.asNode(endStateCursor.getValue().getLastLocationAccess(key)));
                }
            }
            return loopInfo.exitStates;
        }

        /** Adds a header phi for {@code location}, seeded with the loop-entry value, to {@code phis}. */
        private static void createMemoryPhi(LoopBeginNode loop, MemoryMapImpl initialState, EconomicMap<LocationIdentity, MemoryPhiNode> phis, LocationIdentity location) {
            MemoryPhiNode phi = loop.graph().addWithoutUnique(new MemoryPhiNode(loop, location));
            phi.addInput(ValueNodeUtil.asNode(initialState.getLastLocationAccess(location)));
            phis.put(location, phi);
        }
    }
}