package org.graalvm.compiler.hotspot.phases.aot;
import static org.graalvm.compiler.core.common.cfg.AbstractControlFlowGraph.strictlyDominates;
import static org.graalvm.compiler.hotspot.nodes.aot.LoadMethodCountersNode.getLoadMethodCountersNodes;
import static org.graalvm.compiler.nodes.ConstantNode.getConstantNodes;
import java.util.HashSet;
import java.util.List;
import org.graalvm.compiler.core.common.cfg.BlockMap;
import org.graalvm.compiler.core.common.type.ObjectStamp;
import org.graalvm.compiler.core.common.type.Stamp;
import org.graalvm.compiler.core.common.type.StampFactory;
import org.graalvm.compiler.debug.GraalError;
import org.graalvm.compiler.graph.Node;
import org.graalvm.compiler.graph.NodeMap;
import org.graalvm.compiler.hotspot.meta.HotSpotConstantLoadAction;
import org.graalvm.compiler.hotspot.nodes.aot.InitializeKlassNode;
import org.graalvm.compiler.hotspot.nodes.aot.LoadConstantIndirectlyFixedNode;
import org.graalvm.compiler.hotspot.nodes.aot.LoadConstantIndirectlyNode;
import org.graalvm.compiler.hotspot.nodes.aot.LoadMethodCountersNode;
import org.graalvm.compiler.hotspot.nodes.aot.ResolveConstantNode;
import org.graalvm.compiler.hotspot.nodes.aot.ResolveDynamicConstantNode;
import org.graalvm.compiler.hotspot.nodes.aot.ResolveMethodAndLoadCountersNode;
import org.graalvm.compiler.nodes.AbstractBeginNode;
import org.graalvm.compiler.nodes.AbstractMergeNode;
import org.graalvm.compiler.nodes.ConstantNode;
import org.graalvm.compiler.nodes.FixedNode;
import org.graalvm.compiler.nodes.FixedWithNextNode;
import org.graalvm.compiler.nodes.FrameState;
import org.graalvm.compiler.nodes.LoopBeginNode;
import org.graalvm.compiler.nodes.LoopExitNode;
import org.graalvm.compiler.nodes.StateSplit;
import org.graalvm.compiler.nodes.StructuredGraph;
import org.graalvm.compiler.nodes.StructuredGraph.ScheduleResult;
import org.graalvm.compiler.nodes.ValueNode;
import org.graalvm.compiler.nodes.calc.FloatingNode;
import org.graalvm.compiler.nodes.cfg.Block;
import org.graalvm.compiler.phases.BasePhase;
import org.graalvm.compiler.phases.graph.ReentrantNodeIterator;
import org.graalvm.compiler.phases.graph.ReentrantNodeIterator.NodeIteratorClosure;
import org.graalvm.compiler.phases.schedule.SchedulePhase;
import org.graalvm.compiler.phases.schedule.SchedulePhase.SchedulingStrategy;
import org.graalvm.compiler.phases.tiers.PhaseContext;
import org.graalvm.util.EconomicMap;
import jdk.vm.ci.code.BytecodeFrame;
import jdk.vm.ci.hotspot.HotSpotMetaspaceConstant;
import jdk.vm.ci.hotspot.HotSpotObjectConstant;
import jdk.vm.ci.hotspot.HotSpotResolvedJavaType;
import jdk.vm.ci.hotspot.HotSpotResolvedObjectType;
import jdk.vm.ci.meta.Constant;
import jdk.vm.ci.meta.ConstantReflectionProvider;
import jdk.vm.ci.meta.ResolvedJavaType;
/**
 * Replaces HotSpot-specific constants in a graph with explicit indirect-load or resolution nodes.
 * <p>
 * The phase rewrites three kinds of nodes:
 * <ul>
 * <li>{@link LoadMethodCountersNode}s whose usages are not already
 * {@link ResolveMethodAndLoadCountersNode}s. They are replaced by a fixed
 * {@link ResolveMethodAndLoadCountersNode} that takes a klass hint for the method's declaring
 * class.</li>
 * <li>{@link ConstantNode}s holding a {@link HotSpotMetaspaceConstant}. Each usage is rewired, in
 * order of preference, to:
 * <ul>
 * <li>an existing replacement node of the same constant that precedes the usage (same block or
 * strictly dominating block);</li>
 * <li>a floating {@link LoadConstantIndirectlyNode};</li>
 * <li>a fixed {@link ResolveConstantNode}.</li>
 * </ul>
 * </li>
 * <li>{@link ConstantNode}s holding a {@link HotSpotObjectConstant} of type {@link String}. They are
 * resolved by a fixed {@link ResolveConstantNode}. Any other object constant type is rejected with
 * a {@link GraalError}.</li>
 * </ul>
 * A fixed replacement is inserted after the closest preceding fixed node whose reaching
 * {@link FrameState} has a non-placeholder BCI, falling back to the graph's start node.
 * <p>
 * NOTE(review): the package name indicates this phase serves AOT compilation. The runtime
 * requirements that motivate each rewrite are not visible in this file.
 */
public class ReplaceConstantNodesPhase extends BasePhase<PhaseContext> {

    /**
     * Types whose klass constants are resolved with {@link HotSpotConstantLoadAction#INITIALIZE}
     * rather than with the default {@link ResolveConstantNode} action. These are {@link Boolean} and
     * the private box-cache holders of {@link Character}, {@link Byte}, {@link Short},
     * {@link Integer} and {@link Long}. NOTE(review): presumably they come from boxing snippets;
     * confirm.
     */
    private static final HashSet<Class<?>> builtIns = new HashSet<>();

    /** When {@code true}, a klass constant whose type has a zero fingerprint raises a {@link GraalError}. */
    private final boolean verifyFingerprints;

    static {
        builtIns.add(Boolean.class);
        // The cache classes are private, so they are located by binary name among the declared
        // classes. Class.getDeclaredClasses() returns its elements in no particular order, and
        // Character, for instance, also declares Subset, UnicodeBlock and UnicodeScript. Indexing
        // the array positionally could therefore register an unrelated class.
        addDeclaredClass(Character.class, "CharacterCache");
        addDeclaredClass(Byte.class, "ByteCache");
        addDeclaredClass(Short.class, "ShortCache");
        addDeclaredClass(Integer.class, "IntegerCache");
        addDeclaredClass(Long.class, "LongCache");
    }

    /**
     * Adds the nested class {@code outer$simpleName} to {@link #builtIns}.
     * <p>
     * If {@code outer} declares no such class, an assertion fails (when assertions are enabled) and
     * nothing is added. A different class is never registered in its place.
     *
     * @param outer the class declaring the nested class
     * @param simpleName the simple name of the nested class
     */
    private static void addDeclaredClass(Class<?> outer, String simpleName) {
        String binaryName = outer.getName() + "$" + simpleName;
        for (Class<?> declared : outer.getDeclaredClasses()) {
            if (binaryName.equals(declared.getName())) {
                builtIns.add(declared);
                return;
            }
        }
        assert false : "nested class not found: " + binaryName;
    }

    /**
     * Returns {@code true} if {@code n} is a node kind this phase treats as an already-replaced usage
     * of a klass or object constant.
     */
    private static boolean isReplacementNode(Node n) {
        return n instanceof LoadConstantIndirectlyNode ||
                        n instanceof LoadConstantIndirectlyFixedNode ||
                        n instanceof ResolveDynamicConstantNode ||
                        n instanceof ResolveConstantNode ||
                        n instanceof InitializeKlassNode;
    }

    /** Returns {@code true} if at least one usage of {@code node} is not a replacement node. */
    private static boolean anyUsagesNeedReplacement(ConstantNode node) {
        return node.usages().filter(n -> !isReplacementNode(n)).isNotEmpty();
    }

    /**
     * Returns {@code true} if at least one usage of {@code node} is not a
     * {@link ResolveMethodAndLoadCountersNode}.
     */
    private static boolean anyUsagesNeedReplacement(LoadMethodCountersNode node) {
        return node.usages().filter(n -> !(n instanceof ResolveMethodAndLoadCountersNode)).isNotEmpty();
    }

    /**
     * Returns {@code true} if {@code type} has a fingerprint of zero.
     * <p>
     * For arrays, the elemental type is checked instead. Arrays of primitives are never reported.
     * Non-array types are cast to {@link HotSpotResolvedObjectType}.
     */
    private static boolean checkForBadFingerprint(HotSpotResolvedJavaType type) {
        if (type.isArray()) {
            if (type.getElementalType().isPrimitive()) {
                return false;
            }
            return ((HotSpotResolvedObjectType) (type.getElementalType())).getFingerprint() == 0;
        }
        return ((HotSpotResolvedObjectType) type).getFingerprint() == 0;
    }

    /**
     * Inserts {@code replacement} into the control flow at the insertion point computed for
     * {@code node}.
     * <p>
     * The insertion point's reaching frame state is also recorded for {@code replacement}, so that
     * later insertions can find a state when they walk back past the new node.
     */
    private static void insertReplacement(StructuredGraph graph, FrameStateMapperClosure stateMapper, FloatingNode node, FixedWithNextNode replacement) {
        FixedWithNextNode insertionPoint = findInsertionPoint(graph, stateMapper, node);
        graph.addAfterFixed(insertionPoint, replacement);
        stateMapper.addState(replacement, stateMapper.getState(insertionPoint));
    }

    /**
     * Returns the fixed node after which a fixed replacement for {@code node} should be inserted.
     * <p>
     * This is the closest fixed node scheduled before {@code node} that has a valid reaching
     * frame state.
     */
    private static FixedWithNextNode findInsertionPoint(StructuredGraph graph, FrameStateMapperClosure stateMapper, FloatingNode node) {
        FixedWithNextNode fixed = findFixedBeforeFloating(graph, node);
        return findFixedWithValidState(graph, stateMapper, fixed);
    }

    /**
     * Returns the last {@link FixedWithNextNode} scheduled before {@code node} within
     * {@code node}'s block, according to {@link StructuredGraph#getLastSchedule()}.
     */
    private static FixedWithNextNode findFixedBeforeFloating(StructuredGraph graph, FloatingNode node) {
        ScheduleResult schedule = graph.getLastSchedule();
        NodeMap<Block> nodeToBlock = schedule.getNodeToBlockMap();
        Block block = nodeToBlock.get(node);
        BlockMap<List<Node>> blockToNodes = schedule.getBlockToNodesMap();
        FixedWithNextNode result = null;
        for (Node n : blockToNodes.get(block)) {
            if (n.equals(node)) {
                break;
            }
            if (n instanceof FixedWithNextNode) {
                result = (FixedWithNextNode) n;
            }
        }
        assert result != null;
        return result;
    }

    /**
     * Walks backwards from {@code node} to find a fixed node whose reaching frame state has a
     * non-placeholder BCI.
     * <p>
     * The walk first follows predecessors up to the begin node of the current block. It then
     * continues from the end node of each successive dominator block. If no candidate is found, the
     * graph's start node is returned.
     */
    private static FixedWithNextNode findFixedWithValidState(StructuredGraph graph, FrameStateMapperClosure stateMapper, FixedWithNextNode node) {
        ScheduleResult schedule = graph.getLastSchedule();
        NodeMap<Block> nodeToBlock = schedule.getNodeToBlockMap();
        Block block = nodeToBlock.get(node);

        Node n = node;
        do {
            if (isFixedWithValidState(stateMapper, n)) {
                return (FixedWithNextNode) n;
            }
            while (n != block.getBeginNode()) {
                n = n.predecessor();
                if (isFixedWithValidState(stateMapper, n)) {
                    return (FixedWithNextNode) n;
                }
            }
            block = block.getDominator();
            if (block != null) {
                n = block.getEndNode();
            }
        } while (block != null);

        return graph.start();
    }

    /**
     * Returns {@code true} if {@code n} is a {@link FixedWithNextNode} whose recorded reaching frame
     * state does not carry a placeholder BCI.
     */
    private static boolean isFixedWithValidState(FrameStateMapperClosure stateMapper, Node n) {
        if (n instanceof FixedWithNextNode) {
            FrameState state = stateMapper.getState(n);
            assert state != null;
            return !BytecodeFrame.isPlaceholderBci(state.bci);
        }
        return false;
    }

    /**
     * Records the frame state reaching each fixed node visited by {@link ReentrantNodeIterator}.
     * <p>
     * For a {@link StateSplit} with a non-null {@code stateAfter}, that state becomes the current
     * state; otherwise the incoming state is kept. At a merge, the incoming state is used if all
     * predecessors agree on it, and the merge's own {@code stateAfter()} is used otherwise.
     */
    private static class FrameStateMapperClosure extends NodeIteratorClosure<FrameState> {
        private final NodeMap<FrameState> reachingStates;

        FrameStateMapperClosure(StructuredGraph graph) {
            reachingStates = new NodeMap<>(graph);
        }

        @Override
        protected FrameState processNode(FixedNode node, FrameState previousState) {
            FrameState currentState = previousState;
            if (node instanceof StateSplit) {
                StateSplit stateSplit = (StateSplit) node;
                FrameState stateAfter = stateSplit.stateAfter();
                if (stateAfter != null) {
                    currentState = stateAfter;
                }
            }
            reachingStates.put(node, currentState);
            return currentState;
        }

        @Override
        protected FrameState merge(AbstractMergeNode merge, List<FrameState> states) {
            FrameState singleFrameState = singleFrameState(states);
            FrameState currentState = singleFrameState == null ? merge.stateAfter() : singleFrameState;
            reachingStates.put(merge, currentState);
            return currentState;
        }

        @Override
        protected FrameState afterSplit(AbstractBeginNode node, FrameState oldState) {
            return oldState;
        }

        @Override
        protected EconomicMap<LoopExitNode, FrameState> processLoop(LoopBeginNode loop, FrameState initialState) {
            return ReentrantNodeIterator.processLoop(this, loop, initialState).exitStates;
        }

        /**
         * Returns the common element of {@code states} if all elements are the identical object,
         * otherwise {@code null}.
         */
        private static FrameState singleFrameState(List<FrameState> states) {
            FrameState singleState = states.get(0);
            for (int i = 1; i < states.size(); ++i) {
                if (states.get(i) != singleState) {
                    return null;
                }
            }
            return singleState;
        }

        /** Returns the recorded reaching state of {@code n}, or {@code null} if none was recorded. */
        public FrameState getState(Node n) {
            return reachingStates.get(n);
        }

        /**
         * Records {@code s} as the reaching state of {@code n}. The map is grown as needed, which
         * allows nodes created after this closure was built.
         */
        public void addState(Node n, FrameState s) {
            reachingStates.setAndGrow(n, s);
        }
    }

    /**
     * Rewires non-replacement usages of {@code node} to an existing replacement node of the same
     * constant, where that node is known to precede the usage.
     * <p>
     * A candidate precedes the usage if it is scheduled earlier in the same block, or if it lives in
     * a strictly dominating block. Usages that cannot be rewired are left unchanged. Only the first
     * input edge of each usage that refers to {@code node} is replaced.
     */
    private static void tryToReplaceWithExisting(StructuredGraph graph, ConstantNode node) {
        ScheduleResult schedule = graph.getLastSchedule();
        NodeMap<Block> nodeToBlock = schedule.getNodeToBlockMap();
        BlockMap<List<Node>> blockToNodes = schedule.getBlockToNodesMap();

        // One existing replacement node per block. A later node in the same block overwrites an
        // earlier one.
        EconomicMap<Block, Node> blockToExisting = EconomicMap.create();
        for (Node existing : node.usages().filter(n -> isReplacementNode(n))) {
            blockToExisting.put(nodeToBlock.get(existing), existing);
        }
        for (Node use : node.usages().filter(n -> !isReplacementNode(n)).snapshot()) {
            boolean replaced = false;
            Block b = nodeToBlock.get(use);
            Node e = blockToExisting.get(b);
            if (e != null) {
                // Same block: the existing node can only be used if it is scheduled before the use.
                for (Node n : blockToNodes.get(b)) {
                    if (n.equals(use)) {
                        break;
                    }
                    if (n.equals(e)) {
                        use.replaceFirstInput(node, e);
                        replaced = true;
                        break;
                    }
                }
            }
            if (!replaced) {
                // Otherwise, use an existing node from any strictly dominating block.
                for (Block d : blockToExisting.getKeys()) {
                    if (strictlyDominates(d, b)) {
                        use.replaceFirstInput(node, blockToExisting.get(d));
                        break;
                    }
                }
            }
        }
    }

    /**
     * Replaces all non-replacement usages of the klass constant {@code node} with a new load or
     * resolution node.
     * <ul>
     * <li>Arrays of primitives use a floating {@link LoadConstantIndirectlyNode}.</li>
     * <li>The declaring class of the compiled method, or a non-interface supertype of it, also uses a
     * floating {@link LoadConstantIndirectlyNode}. NOTE(review): presumably such types are
     * guaranteed to be resolved already when the method runs; confirm.</li>
     * <li>Any other type gets a fixed {@link ResolveConstantNode}. Types in {@link #builtIns} use the
     * {@link HotSpotConstantLoadAction#INITIALIZE} action.</li>
     * </ul>
     */
    private static void replaceWithResolution(StructuredGraph graph, FrameStateMapperClosure stateMapper, ConstantNode node) {
        HotSpotMetaspaceConstant metaspaceConstant = (HotSpotMetaspaceConstant) node.asConstant();
        HotSpotResolvedJavaType type = (HotSpotResolvedJavaType) metaspaceConstant.asResolvedJavaType();
        ResolvedJavaType topMethodHolder = graph.method().getDeclaringClass();
        ValueNode replacement;

        if (type.isArray() && type.getComponentType().isPrimitive()) {
            replacement = graph.addOrUnique(new LoadConstantIndirectlyNode(node));
        } else if (type.equals(topMethodHolder) || (type.isAssignableFrom(topMethodHolder) && !type.isInterface())) {
            replacement = graph.addOrUnique(new LoadConstantIndirectlyNode(node));
        } else {
            FixedWithNextNode fixedReplacement;
            if (builtIns.contains(type.mirror())) {
                fixedReplacement = graph.add(new ResolveConstantNode(node, HotSpotConstantLoadAction.INITIALIZE));
            } else {
                fixedReplacement = graph.add(new ResolveConstantNode(node));
            }
            insertReplacement(graph, stateMapper, node, fixedReplacement);
            replacement = fixedReplacement;
        }
        node.replaceAtUsages(replacement, n -> !isReplacementNode(n));
    }

    /**
     * Handles a {@link ConstantNode} holding a {@link HotSpotMetaspaceConstant}.
     * <p>
     * The constant must denote a Java type. If fingerprint verification is enabled, a type with a
     * zero fingerprint is rejected. Usages are first rewired to existing replacement nodes; any that
     * remain get a new load or resolution node.
     *
     * @throws GraalError if the constant does not denote a Java type, or if fingerprint
     *             verification fails
     */
    private void handleHotSpotMetaspaceConstant(StructuredGraph graph, FrameStateMapperClosure stateMapper, ConstantNode node) {
        HotSpotMetaspaceConstant metaspaceConstant = (HotSpotMetaspaceConstant) node.asConstant();
        HotSpotResolvedJavaType type = (HotSpotResolvedJavaType) metaspaceConstant.asResolvedJavaType();

        if (type == null) {
            // Report the constant itself: the resolved type is always null on this path.
            throw new GraalError("Unsupported metaspace constant type: " + metaspaceConstant);
        }
        if (verifyFingerprints && checkForBadFingerprint(type)) {
            throw new GraalError("Type with bad fingerprint: " + type);
        }
        assert !metaspaceConstant.isCompressed() : "No support for replacing compressed metaspace constants";
        tryToReplaceWithExisting(graph, node);
        if (anyUsagesNeedReplacement(node)) {
            replaceWithResolution(graph, stateMapper, node);
        }
    }

    /**
     * Handles a {@link ConstantNode} holding a {@link HotSpotObjectConstant}.
     * <p>
     * Only {@link String} constants are supported. Their usages, except existing
     * {@link ResolveConstantNode}s, are redirected to a newly inserted fixed
     * {@link ResolveConstantNode}.
     *
     * @throws GraalError for object constants of any other type
     */
    private static void handleHotSpotObjectConstant(StructuredGraph graph, FrameStateMapperClosure stateMapper, ConstantNode node) {
        HotSpotObjectConstant constant = (HotSpotObjectConstant) node.asJavaConstant();
        HotSpotResolvedJavaType type = (HotSpotResolvedJavaType) constant.getType();
        if (!type.mirror().equals(String.class)) {
            throw new GraalError("Unsupported object constant type: " + type);
        }
        assert !constant.isCompressed() : "No support for replacing compressed oop constants";
        FixedWithNextNode replacement = graph.add(new ResolveConstantNode(node));
        insertReplacement(graph, stateMapper, node, replacement);
        node.replaceAtUsages(replacement, n -> !(n instanceof ResolveConstantNode));
    }

    /**
     * Replaces {@code node} with a fixed {@link ResolveMethodAndLoadCountersNode}.
     * <p>
     * The new node receives a hub constant of the method's declaring class as a klass hint. All
     * usages except existing {@link ResolveMethodAndLoadCountersNode}s are redirected to it.
     */
    private static void handleLoadMethodCounters(StructuredGraph graph, FrameStateMapperClosure stateMapper, LoadMethodCountersNode node, PhaseContext context) {
        ResolvedJavaType type = node.getMethod().getDeclaringClass();
        Stamp hubStamp = context.getStampProvider().createHubStamp((ObjectStamp) StampFactory.objectNonNull());
        ConstantReflectionProvider constantReflection = context.getConstantReflection();
        ConstantNode klassHint = ConstantNode.forConstant(hubStamp, constantReflection.asObjectHub(type), context.getMetaAccess(), graph);
        FixedWithNextNode replacement = graph.add(new ResolveMethodAndLoadCountersNode(node.getMethod(), klassHint));
        insertReplacement(graph, stateMapper, node, replacement);
        node.replaceAtUsages(replacement, n -> !(n instanceof ResolveMethodAndLoadCountersNode));
    }

    /**
     * Reschedules the graph, then replaces every {@link LoadMethodCountersNode} that still has
     * usages needing replacement.
     * <p>
     * The reschedule runs first because the helpers read {@code graph.getLastSchedule()}.
     */
    private static void replaceLoadMethodCounters(StructuredGraph graph, FrameStateMapperClosure stateMapper, PhaseContext context) {
        new SchedulePhase(SchedulingStrategy.LATEST_OUT_OF_LOOPS, true).apply(graph, false);

        for (LoadMethodCountersNode node : getLoadMethodCountersNodes(graph)) {
            if (anyUsagesNeedReplacement(node)) {
                handleLoadMethodCounters(graph, stateMapper, node, context);
            }
        }
    }

    /**
     * Reschedules the graph, then handles every metaspace or object {@link ConstantNode} that still
     * has usages needing replacement.
     * <p>
     * The reschedule runs first because the helpers read {@code graph.getLastSchedule()}.
     */
    private void replaceKlassesAndObjects(StructuredGraph graph, FrameStateMapperClosure stateMapper) {
        new SchedulePhase(SchedulingStrategy.LATEST_OUT_OF_LOOPS, true).apply(graph, false);

        for (ConstantNode node : getConstantNodes(graph)) {
            Constant constant = node.asConstant();
            if (constant instanceof HotSpotMetaspaceConstant && anyUsagesNeedReplacement(node)) {
                handleHotSpotMetaspaceConstant(graph, stateMapper, node);
            } else if (constant instanceof HotSpotObjectConstant && anyUsagesNeedReplacement(node)) {
                handleHotSpotObjectConstant(graph, stateMapper, node);
            }
        }
    }

    /**
     * Computes reaching frame states from the start node, then replaces method-counter loads, and
     * finally replaces klass and object constants.
     */
    @Override
    protected void run(StructuredGraph graph, PhaseContext context) {
        FrameStateMapperClosure stateMapper = new FrameStateMapperClosure(graph);
        ReentrantNodeIterator.apply(stateMapper, graph.start(), null);

        replaceLoadMethodCounters(graph, stateMapper, context);
        replaceKlassesAndObjects(graph, stateMapper);
    }

    /**
     * Returns {@code false}. NOTE(review): presumably this opts the phase out of
     * {@link BasePhase}'s contract checking; confirm against {@code BasePhase}.
     */
    @Override
    public boolean checkContract() {
        return false;
    }

    /** Creates the phase with fingerprint verification enabled. */
    public ReplaceConstantNodesPhase() {
        this(true);
    }

    /**
     * Creates the phase.
     *
     * @param verifyFingerprints whether klass constants whose type has a zero fingerprint should
     *            cause a {@link GraalError}
     */
    public ReplaceConstantNodesPhase(boolean verifyFingerprints) {
        this.verifyFingerprints = verifyFingerprints;
    }
}