package org.graalvm.compiler.replacements;
import static org.graalvm.compiler.nodes.graphbuilderconf.IntrinsicContext.CompilationContext.INLINE_AFTER_PARSING;
import java.lang.reflect.Method;
import java.lang.reflect.Modifier;
import java.util.ArrayList;
import java.util.List;
import org.graalvm.compiler.core.common.spi.ConstantFieldProvider;
import org.graalvm.compiler.core.common.type.StampFactory;
import org.graalvm.compiler.core.common.type.StampPair;
import org.graalvm.compiler.debug.GraalError;
import org.graalvm.compiler.graph.Graph;
import org.graalvm.compiler.graph.Node.ValueNumberable;
import org.graalvm.compiler.java.FrameStateBuilder;
import org.graalvm.compiler.java.GraphBuilderPhase;
import org.graalvm.compiler.nodes.AbstractBeginNode;
import org.graalvm.compiler.nodes.AbstractMergeNode;
import org.graalvm.compiler.nodes.BeginNode;
import org.graalvm.compiler.nodes.CallTargetNode.InvokeKind;
import org.graalvm.compiler.nodes.EndNode;
import org.graalvm.compiler.nodes.FixedNode;
import org.graalvm.compiler.nodes.FixedWithNextNode;
import org.graalvm.compiler.nodes.IfNode;
import org.graalvm.compiler.nodes.InvokeNode;
import org.graalvm.compiler.nodes.InvokeWithExceptionNode;
import org.graalvm.compiler.nodes.KillingBeginNode;
import org.graalvm.compiler.nodes.LogicNode;
import org.graalvm.compiler.nodes.MergeNode;
import org.graalvm.compiler.nodes.NodeView;
import org.graalvm.compiler.nodes.StructuredGraph;
import org.graalvm.compiler.nodes.ValueNode;
import org.graalvm.compiler.nodes.calc.FloatingNode;
import org.graalvm.compiler.nodes.graphbuilderconf.GraphBuilderConfiguration;
import org.graalvm.compiler.nodes.graphbuilderconf.GraphBuilderConfiguration.Plugins;
import org.graalvm.compiler.nodes.graphbuilderconf.GraphBuilderTool;
import org.graalvm.compiler.nodes.graphbuilderconf.IntrinsicContext;
import org.graalvm.compiler.nodes.java.ExceptionObjectNode;
import org.graalvm.compiler.nodes.java.MethodCallTargetNode;
import org.graalvm.compiler.nodes.spi.StampProvider;
import org.graalvm.compiler.nodes.type.StampTool;
import org.graalvm.compiler.phases.OptimisticOptimizations;
import org.graalvm.compiler.phases.common.DeadCodeEliminationPhase;
import org.graalvm.compiler.phases.common.DeadCodeEliminationPhase.Optionality;
import org.graalvm.compiler.phases.common.inlining.InliningUtil;
import org.graalvm.compiler.phases.util.Providers;
import org.graalvm.compiler.word.WordTypes;
import org.graalvm.word.LocationIdentity;
import jdk.vm.ci.code.BytecodeFrame;
import jdk.vm.ci.meta.ConstantReflectionProvider;
import jdk.vm.ci.meta.JavaKind;
import jdk.vm.ci.meta.JavaType;
import jdk.vm.ci.meta.MetaAccessProvider;
import jdk.vm.ci.meta.ResolvedJavaMethod;
import jdk.vm.ci.meta.Signature;
/**
 * Helper for building a {@link StructuredGraph} node by node.
 *
 * Fixed nodes passed to {@link #append} are linked after {@link #lastFixedNode}. Structured
 * control flow (if/then/else and invokes with an exception edge) is tracked on an internal stack
 * of {@link Structure} objects that is pushed by the {@code start*} methods and popped by the
 * matching {@code end*} methods.
 */
public class GraphKit implements GraphBuilderTool {

    protected final Providers providers;
    protected final StructuredGraph graph;
    /** Word type support; may be {@code null}, in which case no word-stamp conversion is done. */
    protected final WordTypes wordTypes;
    protected final GraphBuilderConfiguration.Plugins graphBuilderPlugins;
    /**
     * The fixed node after which the next appended fixed node is inserted. It is {@code null} when
     * no fixed node may currently be appended, for example right after {@link #startIf} or after
     * appending a fixed node that is not a {@link FixedWithNextNode}.
     */
    protected FixedWithNextNode lastFixedNode;

    /** Stack of currently open control-flow structures; the top is the last element. */
    private final List<Structure> structures;

    /** Base class for an open control-flow construct tracked by this kit. */
    protected abstract static class Structure {
    }

    public GraphKit(StructuredGraph graph, Providers providers, WordTypes wordTypes, GraphBuilderConfiguration.Plugins graphBuilderPlugins) {
        this.providers = providers;
        this.graph = graph;
        this.wordTypes = wordTypes;
        this.graphBuilderPlugins = graphBuilderPlugins;
        this.lastFixedNode = graph.start();

        structures = new ArrayList<>();
        // Bottom-of-stack placeholder representing the outermost, unstructured level.
        structures.add(new Structure() {
        });
    }

    @Override
    public StructuredGraph getGraph() {
        return graph;
    }

    @Override
    public ConstantReflectionProvider getConstantReflection() {
        return providers.getConstantReflection();
    }

    @Override
    public ConstantFieldProvider getConstantFieldProvider() {
        return providers.getConstantFieldProvider();
    }

    @Override
    public MetaAccessProvider getMetaAccess() {
        return providers.getMetaAccess();
    }

    @Override
    public StampProvider getStampProvider() {
        return providers.getStampProvider();
    }

    /** Always {@code true}: graphs built with this kit are treated as intrinsics. */
    @Override
    public boolean parsingIntrinsic() {
        return true;
    }

    /**
     * Adds a floating node to the graph via {@link StructuredGraph#unique}, after converting its
     * stamp to a word stamp if applicable (see {@link #changeToWord}).
     */
    public <T extends FloatingNode & ValueNumberable> T unique(T node) {
        return graph.unique(changeToWord(node));
    }

    /**
     * Adds a node to the graph via {@link StructuredGraph#add}, after converting its stamp to a
     * word stamp if applicable. The node is not linked into the control flow.
     */
    public <T extends ValueNode> T add(T node) {
        return graph.add(changeToWord(node));
    }

    /**
     * Replaces the stamp of {@code node} with the matching word stamp when word types are
     * available and the node is word-typed. Otherwise the node is left unchanged.
     *
     * @return {@code node} itself
     */
    public <T extends ValueNode> T changeToWord(T node) {
        if (wordTypes != null && wordTypes.isWord(node)) {
            node.setStamp(wordTypes.getWordStamp(StampTool.typeOrNull(node)));
        }
        return node;
    }

    /**
     * Adds {@code node} and its inputs to the graph. If the resulting node is a {@link FixedNode},
     * it is linked after {@link #lastFixedNode}.
     *
     * @return the node actually present in the graph, as returned by
     *         {@link StructuredGraph#addOrUniqueWithInputs}
     */
    @Override
    public <T extends ValueNode> T append(T node) {
        T result = graph.addOrUniqueWithInputs(changeToWord(node));
        if (result instanceof FixedNode) {
            updateLastFixed((FixedNode) result);
        }
        return result;
    }

    /**
     * Links {@code result} after {@link #lastFixedNode}. The result becomes the new insertion
     * point only if it can have a successor; otherwise the insertion point is cleared.
     */
    private void updateLastFixed(FixedNode result) {
        assert lastFixedNode != null;
        assert result.predecessor() == null;
        graph.addAfterFixed(lastFixedNode, result);
        if (result instanceof FixedWithNextNode) {
            lastFixedNode = (FixedWithNextNode) result;
        } else {
            lastFixedNode = null;
        }
    }

    /**
     * Appends a static invoke of the method {@code declaringClass.name}, with no frame state and
     * bci {@link BytecodeFrame#UNKNOWN_BCI}.
     */
    public InvokeNode createInvoke(Class<?> declaringClass, String name, ValueNode... args) {
        return createInvoke(declaringClass, name, InvokeKind.Static, null, BytecodeFrame.UNKNOWN_BCI, args);
    }

    /**
     * Looks up the method by name via {@link #findMethod(Class, String, boolean)}, with
     * staticness derived from {@code invokeKind}, and appends an invoke of it.
     */
    public InvokeNode createInvoke(Class<?> declaringClass, String name, InvokeKind invokeKind, FrameStateBuilder frameStateBuilder, int bci, ValueNode... args) {
        boolean isStatic = invokeKind == InvokeKind.Static;
        ResolvedJavaMethod method = findMethod(declaringClass, name, isStatic);
        return createInvoke(method, invokeKind, frameStateBuilder, bci, args);
    }

    /**
     * Finds the method named {@code name} with the requested staticness among the methods declared
     * by {@code declaringClass}.
     *
     * When Java assertions are enabled, more than one match is reported as an error. Without
     * assertions, the last match in {@link Class#getDeclaredMethods()} order wins.
     *
     * @throws GraalError if no matching method exists
     */
    public ResolvedJavaMethod findMethod(Class<?> declaringClass, String name, boolean isStatic) {
        ResolvedJavaMethod method = null;
        for (Method m : declaringClass.getDeclaredMethods()) {
            if (Modifier.isStatic(m.getModifiers()) == isStatic && m.getName().equals(name)) {
                assert method == null : "found more than one method in " + declaringClass + " named " + name;
                method = providers.getMetaAccess().lookupJavaMethod(m);
            }
        }
        GraalError.guarantee(method != null, "Could not find %s.%s (%s)", declaringClass, name, isStatic ? "static" : "non-static");
        return method;
    }

    /**
     * Finds the method declared by {@code declaringClass} with the exact name and parameter types.
     *
     * @throws AssertionError wrapping the reflective {@link NoSuchMethodException} or
     *             {@link SecurityException}
     */
    public ResolvedJavaMethod findMethod(Class<?> declaringClass, String name, Class<?>... parameterTypes) {
        try {
            Method m = declaringClass.getDeclaredMethod(name, parameterTypes);
            return providers.getMetaAccess().lookupJavaMethod(m);
        } catch (NoSuchMethodException | SecurityException e) {
            throw new AssertionError(e);
        }
    }

    /**
     * Appends an {@link InvokeNode} of {@code method} at the current position.
     *
     * If {@code frameStateBuilder} is non-null, the invoke's state-after is created at
     * {@code bci}. A non-void result is temporarily pushed while the state is created and popped
     * afterwards, so the builder's stack is unchanged on return.
     */
    public InvokeNode createInvoke(ResolvedJavaMethod method, InvokeKind invokeKind, FrameStateBuilder frameStateBuilder, int bci, ValueNode... args) {
        assert method.isStatic() == (invokeKind == InvokeKind.Static);
        assert checkArgs(method, args);
        StampPair returnStamp = computeReturnStamp(method);

        MethodCallTargetNode callTarget = graph.add(createMethodCallTarget(invokeKind, method, args, returnStamp, bci));
        InvokeNode invoke = append(new InvokeNode(callTarget, bci));

        if (frameStateBuilder != null) {
            if (invoke.getStackKind() != JavaKind.Void) {
                frameStateBuilder.push(invoke.getStackKind(), invoke);
            }
            invoke.setStateAfter(frameStateBuilder.create(bci, invoke));
            if (invoke.getStackKind() != JavaKind.Void) {
                frameStateBuilder.pop(invoke.getStackKind());
            }
        }
        return invoke;
    }

    /**
     * Computes the return stamp for a call to {@code method}. A plugin-provided overriding stamp
     * is preferred; otherwise the stamp is derived from the declared return type.
     */
    private StampPair computeReturnStamp(ResolvedJavaMethod method) {
        JavaType returnType = method.getSignature().getReturnType(null);
        StampPair returnStamp = graphBuilderPlugins.getOverridingStamp(this, returnType, false);
        if (returnStamp == null) {
            returnStamp = StampFactory.forDeclaredType(graph.getAssumptions(), returnType, false);
        }
        return returnStamp;
    }

    /**
     * Creates the call target for an invoke. This is a subclass hook; the default implementation
     * ignores {@code bci}.
     */
    protected MethodCallTargetNode createMethodCallTarget(InvokeKind invokeKind, ResolvedJavaMethod targetMethod, ValueNode[] args, StampPair returnStamp, @SuppressWarnings("unused") int bci) {
        return new MethodCallTargetNode(invokeKind, targetMethod, args, returnStamp, null);
    }

    /** Maps {@code type} to its kind, honoring word types when they are available. */
    protected final JavaKind asKind(JavaType type) {
        return wordTypes != null ? wordTypes.asKind(type) : type.getJavaKind();
    }

    /**
     * Checks that {@code args} matches the signature of {@code method}, including the receiver
     * for non-static methods: first the argument count, then each argument's stack kind.
     *
     * Every mismatch is reported by throwing, so the check is enforced even when this method is
     * called directly with Java assertions disabled.
     *
     * @return {@code true}, so that the call can be used inside an {@code assert}
     * @throws AssertionError on any count or kind mismatch
     */
    public boolean checkArgs(ResolvedJavaMethod method, ValueNode... args) {
        Signature signature = method.getSignature();
        boolean isStatic = method.isStatic();
        if (signature.getParameterCount(!isStatic) != args.length) {
            throw new AssertionError(graph + ": wrong number of arguments to " + method);
        }
        int argIndex = 0;
        if (!isStatic) {
            JavaKind expected = asKind(method.getDeclaringClass());
            JavaKind actual = args[argIndex++].stamp(NodeView.DEFAULT).getStackKind();
            // Previously a nested 'assert', which was silently skipped when assertions were
            // disabled; throw like the parameter checks below.
            if (expected != actual) {
                throw new AssertionError(graph + ": wrong kind of value for receiver argument of call to " + method + " [" + actual + " != " + expected + "]");
            }
        }
        for (int i = 0; i != signature.getParameterCount(false); i++) {
            JavaKind expected = asKind(signature.getParameterType(i, method.getDeclaringClass())).getStackKind();
            JavaKind actual = args[argIndex++].stamp(NodeView.DEFAULT).getStackKind();
            if (expected != actual) {
                throw new AssertionError(graph + ": wrong kind of value for argument " + i + " of call to " + method + " [" + actual + " != " + expected + "]");
            }
        }
        return true;
    }

    /**
     * Repeatedly inlines every {@link InvokeNode} in the graph until none remain, then runs dead
     * code elimination.
     *
     * NOTE: this loops as long as inlining keeps introducing new invokes, for example through
     * recursion.
     */
    public void inlineInvokes() {
        while (!graph.getNodes().filter(InvokeNode.class).isEmpty()) {
            for (InvokeNode invoke : graph.getNodes().filter(InvokeNode.class).snapshot()) {
                inline(invoke);
            }
        }

        // Clean up
        new DeadCodeEliminationPhase().apply(graph);
    }

    /**
     * Inlines a single invoke.
     *
     * The target method is parsed into a fresh graph using the snippet-default configuration and
     * an {@link IntrinsicContext} with {@code INLINE_AFTER_PARSING}. The callee's state-after
     * frame states are cleared, required dead code elimination is run, and the result is inlined
     * via {@link InliningUtil#inline}.
     */
    public void inline(InvokeNode invoke) {
        ResolvedJavaMethod method = ((MethodCallTargetNode) invoke.callTarget()).targetMethod();

        MetaAccessProvider metaAccess = providers.getMetaAccess();
        Plugins plugins = new Plugins(graphBuilderPlugins);
        GraphBuilderConfiguration config = GraphBuilderConfiguration.getSnippetDefault(plugins);

        StructuredGraph calleeGraph = new StructuredGraph.Builder(invoke.getOptions(), invoke.getDebug()).method(method).build();
        IntrinsicContext initialReplacementContext = new IntrinsicContext(method, method, providers.getReplacements().getDefaultReplacementBytecodeProvider(), INLINE_AFTER_PARSING);
        GraphBuilderPhase.Instance instance = new GraphBuilderPhase.Instance(metaAccess, providers.getStampProvider(), providers.getConstantReflection(), providers.getConstantFieldProvider(), config,
                        OptimisticOptimizations.NONE,
                        initialReplacementContext);
        instance.apply(calleeGraph);

        // Remove all frame states from inlinee
        calleeGraph.clearAllStateAfter();
        new DeadCodeEliminationPhase(Optionality.Required).apply(calleeGraph);

        InliningUtil.inline(invoke, calleeGraph, false, method);
    }

    protected void pushStructure(Structure structure) {
        structures.add(structure);
    }

    /**
     * Returns the innermost open structure.
     *
     * @throws ClassCastException if it is not an instance of {@code expectedClass}
     */
    protected <T extends Structure> T getTopStructure(Class<T> expectedClass) {
        return expectedClass.cast(structures.get(structures.size() - 1));
    }

    protected void popStructure() {
        structures.remove(structures.size() - 1);
    }

    /** Progress of an if/then/else construct started by {@link #startIf}. */
    protected enum IfState {
        CONDITION,
        THEN_PART,
        ELSE_PART,
        FINISHED
    }

    /**
     * State of an open if/then/else. {@code thenPart} and {@code elsePart} hold the last fixed
     * node of each branch; initially these are the branches' begin nodes.
     */
    static class IfStructure extends Structure {
        protected IfState state;
        protected FixedNode thenPart;
        protected FixedNode elsePart;
    }

    /**
     * Starts an if-then-else construct: appends an {@link IfNode} on {@code condition} with two
     * fresh begin nodes as successors.
     *
     * Afterwards no fixed node may be appended until {@link #thenPart()} or {@link #elsePart()}
     * selects a branch. The construct is closed by {@link #endIf()}.
     */
    public void startIf(LogicNode condition, double trueProbability) {
        AbstractBeginNode thenSuccessor = graph.add(new BeginNode());
        AbstractBeginNode elseSuccessor = graph.add(new BeginNode());
        append(new IfNode(condition, thenSuccessor, elseSuccessor, trueProbability));
        lastFixedNode = null;

        IfStructure s = new IfStructure();
        s.state = IfState.CONDITION;
        s.thenPart = thenSuccessor;
        s.elsePart = elseSuccessor;
        pushStructure(s);
    }

    /**
     * Records the current insertion point as the end of the branch being built, if any. It then
     * clears the insertion point and returns the innermost {@link IfStructure}.
     */
    private IfStructure saveLastIfNode() {
        IfStructure s = getTopStructure(IfStructure.class);
        switch (s.state) {
            case CONDITION:
                assert lastFixedNode == null;
                break;
            case THEN_PART:
                s.thenPart = lastFixedNode;
                break;
            case ELSE_PART:
                s.elsePart = lastFixedNode;
                break;
            case FINISHED:
                assert false;
                break;
        }
        lastFixedNode = null;
        return s;
    }

    /** Continues appending fixed nodes at the end of the then-branch. */
    public void thenPart() {
        IfStructure s = saveLastIfNode();
        lastFixedNode = (FixedWithNextNode) s.thenPart;
        s.state = IfState.THEN_PART;
    }

    /** Continues appending fixed nodes at the end of the else-branch. */
    public void elsePart() {
        IfStructure s = saveLastIfNode();
        lastFixedNode = (FixedWithNextNode) s.elsePart;
        s.state = IfState.ELSE_PART;
    }

    /**
     * Closes the innermost if-then-else construct.
     *
     * Branches whose last node can have a successor "fall through". If both fall through, they
     * are joined by a new {@link MergeNode}, which becomes the insertion point. If only one falls
     * through, its end becomes the insertion point. If neither does, the insertion point stays
     * {@code null}.
     */
    public void endIf() {
        IfStructure s = saveLastIfNode();

        FixedWithNextNode thenPart = s.thenPart instanceof FixedWithNextNode ? (FixedWithNextNode) s.thenPart : null;
        FixedWithNextNode elsePart = s.elsePart instanceof FixedWithNextNode ? (FixedWithNextNode) s.elsePart : null;

        if (thenPart != null && elsePart != null) {
            /* Both parts are alive, we need a real merge. */
            EndNode thenEnd = graph.add(new EndNode());
            graph.addAfterFixed(thenPart, thenEnd);
            EndNode elseEnd = graph.add(new EndNode());
            graph.addAfterFixed(elsePart, elseEnd);

            AbstractMergeNode merge = graph.add(new MergeNode());
            merge.addForwardEnd(thenEnd);
            merge.addForwardEnd(elseEnd);

            lastFixedNode = merge;

        } else if (thenPart != null) {
            /* elsePart ended with a control sink, so we can continue with thenPart. */
            lastFixedNode = thenPart;

        } else if (elsePart != null) {
            /* thenPart ended with a control sink, so we can continue with elsePart. */
            lastFixedNode = elsePart;

        } else {
            /* Both parts ended with a control sink, so no nodes can be added after the if. */
            assert lastFixedNode == null;
        }
        s.state = IfState.FINISHED;
        popStructure();
    }

    /**
     * State of an open invoke-with-exception. {@code noExceptionEdge} and {@code exceptionEdge}
     * hold the last fixed node of each edge.
     */
    static class InvokeWithExceptionStructure extends Structure {
        protected enum State {
            INVOKE,
            NO_EXCEPTION_EDGE,
            EXCEPTION_EDGE,
            FINISHED
        }

        protected State state;
        protected ExceptionObjectNode exceptionObject;
        protected FixedNode noExceptionEdge;
        protected FixedNode exceptionEdge;
    }

    /**
     * Appends an {@link InvokeWithExceptionNode} of {@code method}.
     *
     * The exception edge starts at a new {@link ExceptionObjectNode}. The no-exception edge starts
     * at a {@link KillingBeginNode} for {@link LocationIdentity#any()}.
     *
     * If {@code frameStateBuilder} is non-null:
     * <ul>
     * <li>The exception object gets a state at {@code exceptionEdgeBci}, built from a copy of the
     * builder with the stack cleared and holding only the exception object.</li>
     * <li>The invoke gets a state-after at {@code invokeBci}, created with any non-void result
     * temporarily pushed.</li>
     * </ul>
     *
     * Afterwards no fixed node may be appended until {@link #noExceptionPart()} or
     * {@link #exceptionPart()} is called. The construct is closed by
     * {@link #endInvokeWithException()}.
     */
    public InvokeWithExceptionNode startInvokeWithException(ResolvedJavaMethod method, InvokeKind invokeKind,
                    FrameStateBuilder frameStateBuilder, int invokeBci, int exceptionEdgeBci, ValueNode... args) {

        assert method.isStatic() == (invokeKind == InvokeKind.Static);
        assert checkArgs(method, args);
        StampPair returnStamp = computeReturnStamp(method);

        ExceptionObjectNode exceptionObject = add(new ExceptionObjectNode(getMetaAccess()));
        if (frameStateBuilder != null) {
            FrameStateBuilder exceptionState = frameStateBuilder.copy();
            exceptionState.clearStack();
            exceptionState.push(JavaKind.Object, exceptionObject);
            exceptionState.setRethrowException(false);
            exceptionObject.setStateAfter(exceptionState.create(exceptionEdgeBci, exceptionObject));
        }
        MethodCallTargetNode callTarget = graph.add(createMethodCallTarget(invokeKind, method, args, returnStamp, invokeBci));
        InvokeWithExceptionNode invoke = append(new InvokeWithExceptionNode(callTarget, exceptionObject, invokeBci));
        AbstractBeginNode noExceptionEdge = graph.add(KillingBeginNode.create(LocationIdentity.any()));
        invoke.setNext(noExceptionEdge);
        if (frameStateBuilder != null) {
            if (invoke.getStackKind() != JavaKind.Void) {
                frameStateBuilder.push(invoke.getStackKind(), invoke);
            }
            invoke.setStateAfter(frameStateBuilder.create(invokeBci, invoke));
            if (invoke.getStackKind() != JavaKind.Void) {
                frameStateBuilder.pop(invoke.getStackKind());
            }
        }

        lastFixedNode = null;

        InvokeWithExceptionStructure s = new InvokeWithExceptionStructure();
        s.state = InvokeWithExceptionStructure.State.INVOKE;
        s.noExceptionEdge = noExceptionEdge;
        s.exceptionEdge = exceptionObject;
        s.exceptionObject = exceptionObject;
        pushStructure(s);

        return invoke;
    }

    /**
     * Records the current insertion point as the end of the edge being built, if any. It then
     * clears the insertion point and returns the innermost {@link InvokeWithExceptionStructure}.
     */
    private InvokeWithExceptionStructure saveLastInvokeWithExceptionNode() {
        InvokeWithExceptionStructure s = getTopStructure(InvokeWithExceptionStructure.class);
        switch (s.state) {
            case INVOKE:
                assert lastFixedNode == null;
                break;
            case NO_EXCEPTION_EDGE:
                s.noExceptionEdge = lastFixedNode;
                break;
            case EXCEPTION_EDGE:
                s.exceptionEdge = lastFixedNode;
                break;
            case FINISHED:
                assert false;
                break;
        }
        lastFixedNode = null;
        return s;
    }

    /** Continues appending fixed nodes at the end of the no-exception edge. */
    public void noExceptionPart() {
        InvokeWithExceptionStructure s = saveLastInvokeWithExceptionNode();
        lastFixedNode = (FixedWithNextNode) s.noExceptionEdge;
        s.state = InvokeWithExceptionStructure.State.NO_EXCEPTION_EDGE;
    }

    /** Continues appending fixed nodes at the end of the exception edge. */
    public void exceptionPart() {
        InvokeWithExceptionStructure s = saveLastInvokeWithExceptionNode();
        lastFixedNode = (FixedWithNextNode) s.exceptionEdge;
        s.state = InvokeWithExceptionStructure.State.EXCEPTION_EDGE;
    }

    /** Returns the exception object of the innermost open invoke-with-exception. */
    public ExceptionObjectNode exceptionObject() {
        InvokeWithExceptionStructure s = getTopStructure(InvokeWithExceptionStructure.class);
        return s.exceptionObject;
    }

    /**
     * Finishes a control flow started with {@link #startInvokeWithException}.
     *
     * If both the no-exception and the exception edge "fall through", they are joined by a new
     * {@link MergeNode}, which becomes the insertion point and is returned. If only one falls
     * through, its end becomes the insertion point.
     *
     * @return the created {@link AbstractMergeNode}, or {@code null} if no merge was needed
     */
    public AbstractMergeNode endInvokeWithException() {
        InvokeWithExceptionStructure s = saveLastInvokeWithExceptionNode();
        FixedWithNextNode noExceptionEdge = s.noExceptionEdge instanceof FixedWithNextNode ? (FixedWithNextNode) s.noExceptionEdge : null;
        FixedWithNextNode exceptionEdge = s.exceptionEdge instanceof FixedWithNextNode ? (FixedWithNextNode) s.exceptionEdge : null;
        AbstractMergeNode merge = null;
        if (noExceptionEdge != null && exceptionEdge != null) {
            EndNode noExceptionEnd = graph.add(new EndNode());
            graph.addAfterFixed(noExceptionEdge, noExceptionEnd);
            EndNode exceptionEnd = graph.add(new EndNode());
            graph.addAfterFixed(exceptionEdge, exceptionEnd);
            merge = graph.add(new MergeNode());
            merge.addForwardEnd(noExceptionEnd);
            merge.addForwardEnd(exceptionEnd);
            lastFixedNode = merge;
        } else if (noExceptionEdge != null) {
            lastFixedNode = noExceptionEdge;
        } else if (exceptionEdge != null) {
            lastFixedNode = exceptionEdge;
        } else {
            assert lastFixedNode == null;
        }
        s.state = InvokeWithExceptionStructure.State.FINISHED;
        popStructure();
        return merge;
    }
}