package org.graalvm.compiler.truffle.test;
import static org.graalvm.compiler.core.common.CompilationIdentifier.INVALID_COMPILATION_ID;
import static org.graalvm.compiler.core.common.CompilationRequestIdentifier.asCompilationRequest;
import static org.graalvm.compiler.debug.DebugOptions.DumpOnError;
import java.util.function.Supplier;
import org.graalvm.compiler.code.CompilationResult;
import org.graalvm.compiler.core.common.CompilationIdentifier;
import org.graalvm.compiler.debug.DebugContext;
import org.graalvm.compiler.debug.DebugDumpScope;
import org.graalvm.compiler.nodes.FrameState;
import org.graalvm.compiler.nodes.StructuredGraph;
import org.graalvm.compiler.nodes.StructuredGraph.AllowAssumptions;
import org.graalvm.compiler.nodes.java.MethodCallTargetNode;
import org.graalvm.compiler.options.OptionValues;
import org.graalvm.compiler.phases.PhaseSuite;
import org.graalvm.compiler.phases.common.DeadCodeEliminationPhase;
import org.graalvm.compiler.phases.tiers.HighTierContext;
import org.graalvm.compiler.truffle.common.TruffleCompilationTask;
import org.graalvm.compiler.truffle.common.TruffleCompilerRuntime;
import org.graalvm.compiler.truffle.common.TruffleDebugJavaMethod;
import org.graalvm.compiler.truffle.compiler.PartialEvaluator;
import org.graalvm.compiler.truffle.runtime.OptimizedCallTarget;
import org.graalvm.compiler.truffle.runtime.TruffleInlining;
import org.junit.Assert;
import com.oracle.truffle.api.CallTarget;
import com.oracle.truffle.api.Truffle;
import com.oracle.truffle.api.nodes.ControlFlowException;
import com.oracle.truffle.api.nodes.RootNode;
import jdk.vm.ci.code.BailoutException;
import jdk.vm.ci.meta.SpeculationLog;
/**
 * Base class for tests that partially evaluate Truffle ASTs and inspect or compare the resulting
 * Graal graphs.
 */
public abstract class PartialEvaluationTest extends TruffleCompilerImplTest {

    /** Number of times a call target is executed before it is partially evaluated. */
    private static final int WARMUP_CALLS = 3;

    /**
     * Maximum number of attempts made by the comparison helpers when compilation bails out with a
     * non-permanent {@link BailoutException}.
     */
    private static final int MAX_BAILOUT_ATTEMPTS = 10;

    /** Result of the most recent {@link #compileHelper} invocation. */
    protected CompilationResult lastCompilationResult;

    /** Debug context created by the most recent {@link #partialEval} invocation. */
    DebugContext lastDebug;

    /**
     * Graph builder suite shared by all compilations of this test instance; created lazily by
     * {@link #getSuite} (double-checked locking on this volatile field).
     */
    private volatile PhaseSuite<HighTierContext> suite;

    /** When {@code true}, {@link #getGraalOptions} disables {@code DumpOnError}. */
    private boolean preventDumping = false;

    public PartialEvaluationTest() {
    }

    /**
     * Asserts that partially evaluating {@code root} with no arguments yields the same graph as
     * parsing the Java method {@code methodName}.
     */
    protected OptimizedCallTarget assertPartialEvalEquals(String methodName, RootNode root) {
        return assertPartialEvalEquals(methodName, root, new Object[0]);
    }

    /** Creates a fresh compilation identifier for {@code compilable}. */
    private CompilationIdentifier getCompilationId(final OptimizedCallTarget compilable) {
        return this.getTruffleCompiler(compilable).createCompilationIdentifier(compilable);
    }

    /**
     * Partially evaluates and compiles {@code root}, recording the result in
     * {@link #lastCompilationResult} and the graph in {@code lastCompiledGraph}. Unlike the
     * comparison helpers, this passes {@code null} as the graph builder suite.
     *
     * @return the call target created for {@code root}
     */
    protected OptimizedCallTarget compileHelper(String methodName, RootNode root, Object[] arguments) {
        final OptimizedCallTarget compilable = (OptimizedCallTarget) (Truffle.getRuntime()).createCallTarget(root);
        CompilationIdentifier compilationId = getCompilationId(compilable);
        StructuredGraph graph = partialEval(compilable, arguments, compilationId);
        this.lastCompilationResult = getTruffleCompiler(compilable).compilePEGraph(graph, methodName, null, compilable, asCompilationRequest(compilationId), null,
                        newTask());
        this.lastCompiledGraph = graph;
        return compilable;
    }

    /**
     * Asserts that the partially evaluated and compiled graphs of {@code expected} and
     * {@code actual} are equal once frame states are removed. Transient bailouts are retried.
     */
    protected void assertPartialEvalEquals(RootNode expected, RootNode actual, Object[] arguments) {
        final OptimizedCallTarget expectedTarget = (OptimizedCallTarget) Truffle.getRuntime().createCallTarget(expected);
        final OptimizedCallTarget actualTarget = (OptimizedCallTarget) Truffle.getRuntime().createCallTarget(actual);
        retryOnTransientBailout(() -> {
            StructuredGraph expectedGraph = partialEvalAndCompileForComparison(expectedTarget, "expectedTest", arguments);
            StructuredGraph actualGraph = partialEvalAndCompileForComparison(actualTarget, "actualTest", arguments);
            assertEquals(expectedGraph, actualGraph, true, true);
            return null;
        });
    }

    /** Returns a compilation task that is never cancelled and always reports the last tier. */
    private static TruffleCompilationTask newTask() {
        return new TruffleCompilationTask() {
            @Override
            public boolean isCancelled() {
                return false;
            }

            @Override
            public boolean isLastTier() {
                return true;
            }
        };
    }

    /**
     * Asserts that partially evaluating {@code root} with {@code arguments} yields the same graph
     * (frame states removed) as parsing and compiling the Java method {@code methodName}.
     * Transient bailouts are retried.
     *
     * @return the call target created for {@code root}
     */
    protected OptimizedCallTarget assertPartialEvalEquals(String methodName, RootNode root, Object[] arguments) {
        final OptimizedCallTarget compilable = (OptimizedCallTarget) Truffle.getRuntime().createCallTarget(root);
        return retryOnTransientBailout(() -> {
            StructuredGraph actual = partialEvalAndCompileForComparison(compilable, methodName, arguments);
            StructuredGraph expected = parseForComparison(methodName, actual.getDebug());
            removeFrameStates(expected);
            assertEquals(expected, actual, true, true);
            return compilable;
        });
    }

    /**
     * Partially evaluates {@code target}, compiles the graph with the shared graph builder suite
     * under the name {@code name}, and strips its frame states so it can be compared.
     */
    private StructuredGraph partialEvalAndCompileForComparison(OptimizedCallTarget target, String name, Object[] arguments) {
        CompilationIdentifier compilationId = getCompilationId(target);
        StructuredGraph graph = partialEval(target, arguments, compilationId);
        getTruffleCompiler(target).compilePEGraph(graph, name, getSuite(target), target, asCompilationRequest(compilationId), null, newTask());
        removeFrameStates(graph);
        return graph;
    }

    /**
     * Runs {@code attempt} up to {@link #MAX_BAILOUT_ATTEMPTS} times. A permanent
     * {@link BailoutException} is rethrown immediately; a transient one triggers another attempt.
     * If every attempt bails out, the last bailout is rethrown.
     */
    private static <T> T retryOnTransientBailout(Supplier<T> attempt) {
        BailoutException lastBailout = null;
        for (int i = 0; i < MAX_BAILOUT_ATTEMPTS; i++) {
            try {
                return attempt.get();
            } catch (BailoutException e) {
                if (e.isPermanent()) {
                    throw e;
                }
                lastBailout = e;
            }
        }
        // MAX_BAILOUT_ATTEMPTS > 0, so reaching this point means at least one bailout was recorded.
        assert lastBailout != null;
        throw lastBailout;
    }

    /** Asserts that partially evaluating {@code root} with no arguments leaves no method calls. */
    protected void assertPartialEvalNoInvokes(RootNode root) {
        assertPartialEvalNoInvokes(root, new Object[0]);
    }

    /** Asserts that partially evaluating {@code root} with {@code arguments} leaves no method calls. */
    protected void assertPartialEvalNoInvokes(RootNode root, Object[] arguments) {
        CallTarget callTarget = Truffle.getRuntime().createCallTarget(root);
        assertPartialEvalNoInvokes(callTarget, arguments);
    }

    /**
     * Fails if the partially evaluated graph of {@code callTarget} (which must be an
     * {@link OptimizedCallTarget}) contains any {@link MethodCallTargetNode}.
     */
    protected void assertPartialEvalNoInvokes(CallTarget callTarget, Object[] arguments) {
        StructuredGraph actual = partialEval((OptimizedCallTarget) callTarget, arguments, INVALID_COMPILATION_ID);
        for (MethodCallTargetNode node : actual.getNodes(MethodCallTargetNode.TYPE)) {
            Assert.fail("Found invalid method call target node: " + node + " (" + node.targetMethod() + ")");
        }
    }

    /** Partially evaluates a new call target for {@code root} without a compilation id. */
    protected StructuredGraph partialEval(RootNode root, Object... arguments) {
        return partialEval((OptimizedCallTarget) Truffle.getRuntime().createCallTarget(root), arguments, INVALID_COMPILATION_ID);
    }

    /** Partially evaluates {@code compilable} without a compilation id. */
    protected StructuredGraph partialEval(OptimizedCallTarget compilable, Object[] arguments) {
        return partialEval(compilable, arguments, INVALID_COMPILATION_ID);
    }

    /** Compiles an already partially evaluated {@code graph} under the name {@code "test"}. */
    protected void compile(OptimizedCallTarget compilable, StructuredGraph graph) {
        String methodName = "test";
        CompilationIdentifier compilationId = getCompilationId(compilable);
        getTruffleCompiler(compilable).compilePEGraph(graph, methodName, getSuite(compilable), compilable, asCompilationRequest(compilationId), null, newTask());
    }

    /**
     * Warms up {@code compilable} by calling it {@link #WARMUP_CALLS} times, then partially
     * evaluates it inside a fresh debug context (stored in {@link #lastDebug}).
     *
     * Failed speculations in the target's compilation speculation log are collected before
     * evaluation. Any throwable is routed through {@link DebugContext#handle}.
     */
    @SuppressWarnings("try")
    protected StructuredGraph partialEval(OptimizedCallTarget compilable, Object[] arguments, CompilationIdentifier compilationId) {
        warmUp(compilable, arguments);
        OptionValues options = getGraalOptions();
        DebugContext debug = getDebugContext(options);
        lastDebug = debug;
        try (DebugContext.Scope s = debug.scope("TruffleCompilation", new TruffleDebugJavaMethod(compilable))) {
            SpeculationLog speculationLog = compilable.getCompilationSpeculationLog();
            if (speculationLog != null) {
                speculationLog.collectFailedSpeculations();
            }
            final PartialEvaluator partialEvaluator = getTruffleCompiler(compilable).getPartialEvaluator();
            final PartialEvaluator.Request request = partialEvaluator.new Request(compilable.getOptionValues(), debug, compilable, partialEvaluator.rootForCallTarget(compilable),
                            new TruffleInlining(),
                            compilationId, speculationLog, null);
            return partialEvaluator.evaluate(request);
        } catch (Throwable e) {
            throw debug.handle(e);
        }
    }

    /** Calls {@code compilable} {@link #WARMUP_CALLS} times, ignoring {@link IgnoreError}. */
    private static void warmUp(OptimizedCallTarget compilable, Object[] arguments) {
        for (int i = 0; i < WARMUP_CALLS; i++) {
            try {
                compilable.call(arguments);
            } catch (IgnoreError ignored) {
                // Deliberately swallowed: IgnoreError is the test-provided control-flow signal,
                // presumably thrown by test ASTs whose warm-up executions are expected to exit
                // this way. Any other exception still propagates.
            }
        }
    }

    /**
     * Returns the runtime's Graal options, with {@code DumpOnError} disabled while
     * {@link PreventDumping} is active.
     */
    protected OptionValues getGraalOptions() {
        OptionValues options = TruffleCompilerRuntime.getRuntime().getGraalOptions(OptionValues.class);
        if (preventDumping) {
            options = new OptionValues(options, DumpOnError, false);
        }
        return options;
    }

    /**
     * Detaches and deletes every {@link FrameState} in {@code graph}, then runs dead code
     * elimination to remove nodes left unused. The comparison helpers call this before comparing
     * graphs.
     */
    protected void removeFrameStates(StructuredGraph graph) {
        for (FrameState frameState : graph.getNodes(FrameState.TYPE)) {
            frameState.replaceAtUsages(null);
            frameState.safeDelete();
        }
        new DeadCodeEliminationPhase().apply(graph);
    }

    /**
     * Parses the Java method {@code methodName} eagerly (assumptions allowed) and compiles it,
     * inside a debug scope labelled with the method name.
     */
    @SuppressWarnings("try")
    protected StructuredGraph parseForComparison(final String methodName, DebugContext debug) {
        try (DebugContext.Scope s = debug.scope("Truffle", new DebugDumpScope("Comparison: " + methodName))) {
            StructuredGraph graph = parseEager(methodName, AllowAssumptions.YES);
            compile(graph.method(), graph);
            return graph;
        } catch (Throwable e) {
            throw debug.handle(e);
        }
    }

    /** Returns the shared graph builder suite, creating it on first use. */
    private PhaseSuite<HighTierContext> getSuite(OptimizedCallTarget callTarget) {
        PhaseSuite<HighTierContext> result = suite;
        if (result == null) {
            synchronized (this) {
                result = suite;
                if (result == null) {
                    result = getTruffleCompiler(callTarget).createGraphBuilderSuite();
                    suite = result;
                }
            }
        }
        return result;
    }

    /** Control-flow exception that the warm-up calls in {@link #partialEval} ignore. */
    @SuppressWarnings("serial")
    protected static final class IgnoreError extends ControlFlowException {
    }

    /**
     * Scope that sets {@code preventDumping} while open and restores the previous value on
     * {@link #close()}; use with try-with-resources.
     */
    protected class PreventDumping implements AutoCloseable {
        private final boolean previous;

        protected PreventDumping() {
            previous = preventDumping;
            preventDumping = true;
        }

        @Override
        public void close() {
            preventDumping = previous;
        }
    }
}