package org.graalvm.compiler.virtual.phases.ea;
import static org.graalvm.compiler.core.common.GraalOptions.ReadEliminationMaxLoopVisits;
import static jdk.internal.vm.compiler.word.LocationIdentity.any;
import java.util.Iterator;
import java.util.List;
import jdk.internal.vm.compiler.collections.EconomicMap;
import jdk.internal.vm.compiler.collections.EconomicSet;
import jdk.internal.vm.compiler.collections.Equivalence;
import jdk.internal.vm.compiler.collections.MapCursor;
import org.graalvm.compiler.core.common.cfg.Loop;
import org.graalvm.compiler.core.common.type.Stamp;
import org.graalvm.compiler.graph.Node;
import org.graalvm.compiler.nodes.FieldLocationIdentity;
import org.graalvm.compiler.nodes.FixedWithNextNode;
import org.graalvm.compiler.nodes.LoopExitNode;
import org.graalvm.compiler.nodes.NodeView;
import org.graalvm.compiler.nodes.PhiNode;
import org.graalvm.compiler.nodes.ProxyNode;
import org.graalvm.compiler.nodes.ValueNode;
import org.graalvm.compiler.nodes.ValuePhiNode;
import org.graalvm.compiler.nodes.ValueProxyNode;
import org.graalvm.compiler.nodes.cfg.Block;
import org.graalvm.compiler.nodes.cfg.ControlFlowGraph;
import org.graalvm.compiler.nodes.extended.GuardedNode;
import org.graalvm.compiler.nodes.extended.GuardingNode;
import org.graalvm.compiler.nodes.extended.RawLoadNode;
import org.graalvm.compiler.nodes.extended.RawStoreNode;
import org.graalvm.compiler.nodes.extended.UnsafeAccessNode;
import org.graalvm.compiler.nodes.java.AccessFieldNode;
import org.graalvm.compiler.nodes.java.LoadFieldNode;
import org.graalvm.compiler.nodes.java.StoreFieldNode;
import org.graalvm.compiler.nodes.memory.MemoryCheckpoint;
import org.graalvm.compiler.nodes.memory.ReadNode;
import org.graalvm.compiler.nodes.memory.WriteNode;
import org.graalvm.compiler.nodes.type.StampTool;
import org.graalvm.compiler.nodes.util.GraphUtil;
import org.graalvm.compiler.options.OptionValues;
import org.graalvm.compiler.virtual.phases.ea.ReadEliminationBlockState.CacheEntry;
import org.graalvm.compiler.virtual.phases.ea.ReadEliminationBlockState.LoadCacheEntry;
import org.graalvm.compiler.virtual.phases.ea.ReadEliminationBlockState.UnsafeLoadCacheEntry;
import jdk.internal.vm.compiler.word.LocationIdentity;
import jdk.vm.ci.meta.JavaKind;
import jdk.vm.ci.meta.ResolvedJavaType;
public final class ReadEliminationClosure extends EffectsClosure<ReadEliminationBlockState> {
/**
 * Whether guards are compared when deciding if a cached value may replace a {@link ReadNode} or
 * a {@link RawLoadNode}. Field loads do not consult this flag.
 */
private final boolean considerGuards;
/**
 * Creates a read elimination closure over the given control flow graph.
 *
 * @param cfg the control flow graph handed to the superclass
 * @param considerGuards if {@code true}, a read is only replaced by a cached value when the read
 *            has no guard or both nodes have the same guard
 */
public ReadEliminationClosure(ControlFlowGraph cfg, boolean considerGuards) {
// NOTE(review): the first superclass argument is passed as null; presumably this is an optional
// schedule - confirm against EffectsClosure.
super(null, cfg);
this.considerGuards = considerGuards;
}
/**
 * Supplies the state the analysis starts with.
 *
 * @return a newly constructed {@link ReadEliminationBlockState} (presumably with an empty read
 *         cache - confirm against the state's no-argument constructor)
 */
@Override
protected ReadEliminationBlockState getInitialState() {
return new ReadEliminationBlockState();
}
/**
 * Applies read elimination to a single node and updates the read cache in {@code state}.
 *
 * <p>
 * Depending on the node type this either
 * <ul>
 * <li>replaces a load/read by a value already present in the cache,</li>
 * <li>deletes a store/write that stores the value the cache already holds for that location,</li>
 * <li>records the loaded or stored value in the cache, or</li>
 * <li>kills cache entries for the location(s) the node writes to.</li>
 * </ul>
 *
 * <p>
 * The type checks are performed in a fixed order (field access, read, write, unsafe access,
 * memory checkpoint) and only the first matching category is handled.
 *
 * @param node the node to process
 * @param state the block state whose read cache is consulted and updated
 * @param effects collects the graph modifications (usage replacement, node deletion)
 * @param lastFixedNode not used by this implementation
 * @return {@code true} if a replacement or deletion of {@code node} was recorded in
 *         {@code effects}, {@code false} otherwise
 */
@Override
protected boolean processNode(Node node, ReadEliminationBlockState state, GraphEffectList effects, FixedWithNextNode lastFixedNode) {
    if (node instanceof AccessFieldNode) {
        return eliminateFieldAccess((AccessFieldNode) node, state, effects);
    }
    if (node instanceof ReadNode) {
        return eliminateRead((ReadNode) node, state, effects);
    }
    if (node instanceof WriteNode) {
        return eliminateWrite((WriteNode) node, state, effects);
    }
    if (node instanceof UnsafeAccessNode) {
        // NOTE(review): an unsafe access never reaches the MemoryCheckpoint handling below, even
        // if its concrete class also implements MemoryCheckpoint - confirm this is intended.
        return eliminateUnsafeAccess((UnsafeAccessNode) node, state, effects);
    }
    if (node instanceof MemoryCheckpoint.Single) {
        processIdentity(state, ((MemoryCheckpoint.Single) node).getLocationIdentity());
    } else if (node instanceof MemoryCheckpoint.Multi) {
        for (LocationIdentity identity : ((MemoryCheckpoint.Multi) node).getLocationIdentities()) {
            processIdentity(state, identity);
        }
    }
    return false;
}

/**
 * Handles {@link LoadFieldNode} and {@link StoreFieldNode}. A volatile access kills the whole
 * read cache and is never eliminated.
 */
private boolean eliminateFieldAccess(AccessFieldNode access, ReadEliminationBlockState state, GraphEffectList effects) {
    if (access.isVolatile()) {
        processIdentity(state, any());
        return false;
    }
    ValueNode object = GraphUtil.unproxify(access.object());
    LoadCacheEntry identifier = new LoadCacheEntry(object, new FieldLocationIdentity(access.field()));
    ValueNode cachedValue = state.getCacheEntry(identifier);
    if (access instanceof LoadFieldNode) {
        if (cachedValue != null && access.stamp(NodeView.DEFAULT).isCompatible(cachedValue.stamp(NodeView.DEFAULT))) {
            // The field value is already known: reuse it instead of loading again.
            effects.replaceAtUsages(access, cachedValue, access);
            addScalarAlias(access, cachedValue);
            return true;
        }
        // Nothing usable cached: from now on this load is the known value of the field.
        state.addCacheEntry(identifier, access);
        return false;
    }
    assert access instanceof StoreFieldNode;
    StoreFieldNode store = (StoreFieldNode) access;
    ValueNode value = getScalarAlias(store.value());
    // A store of the value that the cache already holds for this field is redundant.
    boolean redundant = GraphUtil.unproxify(value) == GraphUtil.unproxify(cachedValue);
    if (redundant) {
        effects.deleteNode(store);
    }
    // Kill every entry for this field first, then record the stored value for this object.
    state.killReadCache(identifier.identity);
    state.addCacheEntry(identifier, value);
    return redundant;
}

/**
 * Handles a {@link ReadNode}. Only reads from a single location are considered; the cache is
 * keyed by the unproxified address and the location identity.
 */
private boolean eliminateRead(ReadNode read, ReadEliminationBlockState state, GraphEffectList effects) {
    LocationIdentity location = read.getLocationIdentity();
    if (!location.isSingle()) {
        return false;
    }
    ValueNode object = GraphUtil.unproxify(read.getAddress());
    LoadCacheEntry identifier = new LoadCacheEntry(object, location);
    ValueNode cachedValue = state.getCacheEntry(identifier);
    if (cachedValue != null && areValuesReplaceable(read, cachedValue, considerGuards)) {
        effects.replaceAtUsages(read, cachedValue, read);
        addScalarAlias(read, cachedValue);
        return true;
    }
    state.addCacheEntry(identifier, read);
    return false;
}

/**
 * Handles a {@link WriteNode}. The written location is always killed in the cache; for a single
 * location the written value is recorded afterwards and a redundant write is deleted.
 */
private boolean eliminateWrite(WriteNode write, ReadEliminationBlockState state, GraphEffectList effects) {
    LocationIdentity location = write.getLocationIdentity();
    if (!location.isSingle()) {
        processIdentity(state, location);
        return false;
    }
    ValueNode object = GraphUtil.unproxify(write.getAddress());
    LoadCacheEntry identifier = new LoadCacheEntry(object, location);
    ValueNode cachedValue = state.getCacheEntry(identifier);
    ValueNode value = getScalarAlias(write.value());
    boolean redundant = GraphUtil.unproxify(value) == GraphUtil.unproxify(cachedValue);
    if (redundant) {
        effects.deleteNode(write);
    }
    processIdentity(state, location);
    state.addCacheEntry(identifier, value);
    return redundant;
}

/**
 * Handles {@link RawLoadNode} and {@link RawStoreNode}. Accesses whose object has no known type
 * or has an array type are left alone and do not touch the cache at all.
 */
private boolean eliminateUnsafeAccess(UnsafeAccessNode access, ReadEliminationBlockState state, GraphEffectList effects) {
    ResolvedJavaType type = StampTool.typeOrNull(access.object());
    if (type == null || type.isArray()) {
        // NOTE(review): a raw store taking this path kills nothing in the read cache - confirm
        // that stale entries cannot survive such a store.
        return false;
    }
    if (access instanceof RawLoadNode) {
        RawLoadNode load = (RawLoadNode) access;
        LocationIdentity location = load.getLocationIdentity();
        if (!location.isSingle()) {
            return false;
        }
        ValueNode object = GraphUtil.unproxify(load.object());
        UnsafeLoadCacheEntry identifier = new UnsafeLoadCacheEntry(object, load.offset(), location);
        ValueNode cachedValue = state.getCacheEntry(identifier);
        if (cachedValue != null && areValuesReplaceable(load, cachedValue, considerGuards)) {
            effects.replaceAtUsages(load, cachedValue, load);
            addScalarAlias(load, cachedValue);
            return true;
        }
        state.addCacheEntry(identifier, load);
        return false;
    }
    assert access instanceof RawStoreNode;
    RawStoreNode store = (RawStoreNode) access;
    LocationIdentity location = store.getLocationIdentity();
    if (!location.isSingle()) {
        processIdentity(state, location);
        return false;
    }
    ValueNode object = GraphUtil.unproxify(store.object());
    UnsafeLoadCacheEntry identifier = new UnsafeLoadCacheEntry(object, store.offset(), location);
    ValueNode cachedValue = state.getCacheEntry(identifier);
    ValueNode value = getScalarAlias(store.value());
    boolean redundant = GraphUtil.unproxify(value) == GraphUtil.unproxify(cachedValue);
    if (redundant) {
        effects.deleteNode(store);
    }
    processIdentity(state, location);
    state.addCacheEntry(identifier, value);
    return redundant;
}
/**
 * Decides whether {@code replacementValue} may be used in place of {@code originalValue}.
 *
 * <p>
 * The stamps of both values must be compatible. If {@code considerGuards} is set, the original
 * value must additionally either have no guard or have the same guard as the replacement.
 *
 * @param originalValue the value that would be replaced
 * @param replacementValue the candidate replacement
 * @param considerGuards whether the guards of the two values are compared
 * @return {@code true} if the replacement is permitted
 */
private static boolean areValuesReplaceable(ValueNode originalValue, ValueNode replacementValue, boolean considerGuards) {
    Stamp originalStamp = originalValue.stamp(NodeView.DEFAULT);
    if (!originalStamp.isCompatible(replacementValue.stamp(NodeView.DEFAULT))) {
        return false;
    }
    if (!considerGuards) {
        return true;
    }
    GuardingNode originalGuard = getGuard(originalValue);
    return originalGuard == null || originalGuard == getGuard(replacementValue);
}
/**
 * Returns the guard of {@code node}.
 *
 * @param node the node to inspect
 * @return the node's guard if it is a {@link GuardedNode} (which may itself be {@code null}),
 *         otherwise {@code null}
 */
private static GuardingNode getGuard(ValueNode node) {
    return node instanceof GuardedNode ? ((GuardedNode) node).getGuard() : null;
}
/**
 * Kills the read cache entries affected by a write to {@code identity}.
 *
 * @param state the state whose read cache is updated
 * @param identity the written location; the "any" location kills the whole cache, every other
 *            location kills only the entries for that location
 */
private static void processIdentity(ReadEliminationBlockState state, LocationIdentity identity) {
    if (!identity.isAny()) {
        state.killReadCache(identity);
    } else {
        state.killReadCache();
    }
}
/**
 * Proxies cached values at a loop exit.
 *
 * <p>
 * Nothing happens unless the graph has value proxies. Otherwise every entry of the exit state's
 * read cache whose value is not identical to the value the initial state holds for the same key
 * is replaced by a new {@link ValueProxyNode} anchored at {@code exitNode}.
 *
 * @param exitNode the loop exit being processed
 * @param initialState the state used for comparison against the exit state
 * @param exitState the state at the loop exit; its read cache is updated in place
 * @param effects receives the newly created proxy nodes
 */
@Override
protected void processLoopExit(LoopExitNode exitNode, ReadEliminationBlockState initialState, ReadEliminationBlockState exitState, GraphEffectList effects) {
    if (!exitNode.graph().hasValueProxies()) {
        return;
    }
    MapCursor<CacheEntry<?>, ValueNode> cursor = exitState.getReadCache().getEntries();
    while (cursor.advance()) {
        CacheEntry<?> key = cursor.getKey();
        if (initialState.getReadCache().get(key) == cursor.getValue()) {
            // Same value as in the initial state: no proxy is created for this entry.
            continue;
        }
        ProxyNode proxy = new ValueProxyNode(exitState.getCacheEntry(key), exitNode);
        effects.addFloatingNode(proxy, "readCacheProxy");
        exitState.getReadCache().put(key, proxy);
    }
}
/**
 * Creates a copy of a block state.
 *
 * @param other the state to copy
 * @return a new {@link ReadEliminationBlockState} built from {@code other} via its copy
 *         constructor
 */
@Override
protected ReadEliminationBlockState cloneState(ReadEliminationBlockState other) {
return new ReadEliminationBlockState(other);
}
/**
 * Creates the processor that merges read elimination states at a control flow merge.
 *
 * @param merge the merge block
 * @return a new {@link ReadEliminationMergeProcessor} for {@code merge}
 */
@Override
protected MergeProcessor createMergeProcessor(Block merge) {
return new ReadEliminationMergeProcessor(merge);
}
/**
 * Merges the read caches of the predecessor states of a merge block into the new state,
 * creating phis where the predecessors hold different values for the same cache key.
 */
private class ReadEliminationMergeProcessor extends EffectsClosure<ReadEliminationBlockState>.MergeProcessor {

    /** Phis created so far by this processor, keyed by the cache entry they were created for. */
    private final EconomicMap<Object, ValuePhiNode> materializedPhis = EconomicMap.create(Equivalence.DEFAULT);

    ReadEliminationMergeProcessor(Block mergeBlock) {
        super(mergeBlock);
    }

    /**
     * Returns the phi associated with {@code virtual}, creating and remembering a new one with
     * the given stamp if none exists yet.
     *
     * @param virtual the cache entry the phi belongs to
     * @param stamp the stamp used when a new phi has to be created
     * @return the existing or newly created phi
     */
    protected ValuePhiNode getCachedPhi(CacheEntry<?> virtual, Stamp stamp) {
        ValuePhiNode known = materializedPhis.get(virtual);
        if (known != null) {
            return known;
        }
        ValuePhiNode created = createValuePhi(stamp);
        materializedPhis.put(virtual, created);
        return created;
    }

    /**
     * Builds the read cache of the merged state.
     *
     * <p>
     * First every entry of the first predecessor's cache is merged with the corresponding
     * entries of the other predecessors. Afterwards, for every object phi at the merge, the
     * entries of the first predecessor whose object is the phi's first input are merged across
     * the phi's inputs.
     *
     * @param states the predecessor states, in predecessor order
     */
    @Override
    protected void merge(List<ReadEliminationBlockState> states) {
        ReadEliminationBlockState firstState = states.get(0);
        MapCursor<CacheEntry<?>, ValueNode> entries = firstState.readCache.getEntries();
        while (entries.advance()) {
            mergeCacheEntry(entries.getKey(), entries.getValue(), states);
        }
        for (PhiNode phi : getPhis()) {
            if (phi.getStackKind() != JavaKind.Object) {
                continue;
            }
            for (CacheEntry<?> entry : firstState.readCache.getKeys()) {
                if (entry.object == getPhiValueAt(phi, 0)) {
                    mergeReadCachePhi(phi, entry, states);
                }
            }
        }
    }

    /**
     * Merges one cache entry of the first predecessor with the other predecessors.
     *
     * <p>
     * The entry is dropped if any other predecessor lacks it or holds a value with an
     * incompatible stamp. If all predecessors hold the identical value, that value is kept.
     * Otherwise a phi over the predecessors' values becomes the merged value.
     */
    private void mergeCacheEntry(CacheEntry<?> key, ValueNode firstValue, List<ReadEliminationBlockState> states) {
        boolean needsPhi = false;
        for (int i = 1; i < states.size(); i++) {
            ValueNode otherValue = states.get(i).readCache.get(key);
            if (otherValue == null || !firstValue.stamp(NodeView.DEFAULT).isCompatible(otherValue.stamp(NodeView.DEFAULT))) {
                // Not available (or not usable) on every path: the entry does not survive.
                return;
            }
            if (otherValue != firstValue) {
                needsPhi = true;
            }
        }
        if (!needsPhi) {
            if (firstValue != null) {
                newState.addCacheEntry(key, firstValue);
            }
            return;
        }
        PhiNode phiNode = getCachedPhi(key, firstValue.stamp(NodeView.DEFAULT).unrestricted());
        mergeEffects.addFloatingNode(phiNode, "mergeReadCache");
        for (int i = 0; i < states.size(); i++) {
            ValueNode input = states.get(i).getCacheEntry(key);
            assert phiNode.stamp(NodeView.DEFAULT).isCompatible(input.stamp(NodeView.DEFAULT)) : "Cannot create read elimination phi for inputs with incompatible stamps.";
            setPhiInput(phiNode, i, input);
        }
        newState.addCacheEntry(key, phiNode);
    }

    /**
     * Merges the cache entries reachable through an object phi.
     *
     * <p>
     * For each predecessor, the entry obtained by re-targeting {@code identifier} to that
     * predecessor's phi input is looked up. If every predecessor has a value and each value's
     * stamp is compatible with the one before it, a value phi over these values is added to the
     * new state under {@code identifier} re-targeted to {@code phi}; otherwise nothing is added.
     *
     * @param phi the object phi at the merge
     * @param identifier a cache key of the first predecessor whose object is the phi's first
     *            input
     * @param states the predecessor states, in predecessor order
     */
    private void mergeReadCachePhi(PhiNode phi, CacheEntry<?> identifier, List<ReadEliminationBlockState> states) {
        ValueNode[] inputs = new ValueNode[states.size()];
        ValueNode previous = null;
        for (int i = 0; i < states.size(); i++) {
            ValueNode current = states.get(i).getCacheEntry(identifier.duplicateWithObject(getPhiValueAt(phi, i)));
            if (current == null) {
                return;
            }
            if (previous != null && !previous.stamp(NodeView.DEFAULT).isCompatible(current.stamp(NodeView.DEFAULT))) {
                return;
            }
            inputs[i] = current;
            previous = current;
        }
        CacheEntry<?> mergedIdentifier = identifier.duplicateWithObject(phi);
        PhiNode valuePhi = getCachedPhi(mergedIdentifier, inputs[0].stamp(NodeView.DEFAULT).unrestricted());
        mergeEffects.addFloatingNode(valuePhi, "mergeReadCachePhi");
        for (int i = 0; i < inputs.length; i++) {
            setPhiInput(valuePhi, i, inputs[i]);
        }
        newState.addCacheEntry(mergedIdentifier, valuePhi);
    }
}
/**
 * Records which locations a loop kills, based on the cache before and after processing it.
 *
 * <p>
 * Nothing is recorded when the initial state's read cache is empty. On the first call for a
 * loop a {@code LoopKillCache} constructed with {@code 1} is registered. On later calls, if the
 * recorded visit count exceeds {@code ReadEliminationMaxLoopVisits}, the loop is marked as
 * killing everything; otherwise every location that has an entry in {@code initialState} but no
 * entry in {@code mergedStates} is remembered as killed. Each later call also counts as a visit.
 *
 * @param loop the loop that has been processed
 * @param initialState the state before the loop; must not be {@code null}
 * @param mergedStates the merged state to compare against; must not be {@code null}
 */
@Override
protected void processKilledLoopLocations(Loop<Block> loop, ReadEliminationBlockState initialState, ReadEliminationBlockState mergedStates) {
    assert initialState != null;
    assert mergedStates != null;
    if (initialState.readCache.size() == 0) {
        return;
    }
    LoopKillCache killCache = loopLocationKillCache.get(loop);
    if (killCache == null) {
        // First time this loop is seen: only register it.
        loopLocationKillCache.put(loop, new LoopKillCache(1));
        return;
    }
    OptionValues options = loop.getHeader().getBeginNode().getOptions();
    if (killCache.visits() > ReadEliminationMaxLoopVisits.getValue(options)) {
        killCache.setKillsAll();
    } else {
        // Locations cached on entry that have no entry left afterwards were killed by the loop.
        EconomicSet<LocationIdentity> killedLocations = EconomicSet.create(Equivalence.DEFAULT);
        for (CacheEntry<?> before : initialState.readCache.getKeys()) {
            killedLocations.add(before.getIdentity());
        }
        for (CacheEntry<?> after : mergedStates.readCache.getKeys()) {
            killedLocations.remove(after.getIdentity());
        }
        for (LocationIdentity killed : killedLocations) {
            killCache.rememberLoopKilledLocation(killed);
        }
        if (debug.isLogEnabled()) {
            debug.log("[Early Read Elimination] Setting loop killed locations of loop at node %s with %s",
                            loop.getHeader().getBeginNode(), killedLocations);
        }
    }
    killCache.visited();
}
/**
 * Removes from a loop's initial state all cache entries whose location the loop is known to
 * kill.
 *
 * @param loop the loop about to be processed
 * @param originalInitialState the state before stripping
 * @return the state returned by the superclass implementation, with the entries for killed
 *         locations removed from its read cache
 */
@Override
protected ReadEliminationBlockState stripKilledLoopLocations(Loop<Block> loop, ReadEliminationBlockState originalInitialState) {
    ReadEliminationBlockState initialState = super.stripKilledLoopLocations(loop, originalInitialState);
    LoopKillCache killCache = loopLocationKillCache.get(loop);
    if (killCache == null || !killCache.loopKillsLocations()) {
        return initialState;
    }
    for (Iterator<CacheEntry<?>> keys = initialState.readCache.getKeys().iterator(); keys.hasNext();) {
        if (killCache.containsLocation(keys.next().getIdentity())) {
            keys.remove();
        }
    }
    return initialState;
}
}