package com.oracle.truffle.api.benchmark;
import java.util.function.Supplier;
import org.openjdk.jmh.annotations.Benchmark;
import org.openjdk.jmh.annotations.Fork;
import org.openjdk.jmh.annotations.Level;
import org.openjdk.jmh.annotations.OperationsPerInvocation;
import org.openjdk.jmh.annotations.Scope;
import org.openjdk.jmh.annotations.Setup;
import org.openjdk.jmh.annotations.State;
import com.oracle.truffle.api.benchmark.DSLInterpreterBenchmarkFactory.CachedDSLNodeGen;
import com.oracle.truffle.api.benchmark.DSLInterpreterBenchmarkFactory.SimpleDSLNodeGen;
import com.oracle.truffle.api.dsl.Cached;
import com.oracle.truffle.api.dsl.Specialization;
import com.oracle.truffle.api.dsl.UnsupportedSpecializationException;
import com.oracle.truffle.api.frame.VirtualFrame;
import com.oracle.truffle.api.nodes.Node;
import com.oracle.truffle.api.nodes.RootNode;
@State(Scope.Thread)
@Fork(value = 1)
public class DSLInterpreterBenchmark extends TruffleBenchmark {
/**
 * Number of nodes held by every state object below. It is also the value
 * passed to {@code @OperationsPerInvocation} on the benchmark methods, so
 * each benchmark invocation executes exactly this many nodes.
 */
private static final int NODES = 10000;

/**
 * Single root node shared by all nodes produced through {@code createNode}.
 * The reference itself is never reassigned (hence {@code final}); only its
 * {@code child} field is overwritten each time a new node is created.
 */
private static final TestRootNode root = new TestRootNode();
/**
 * Obtains a new node from {@code nodeFactory} and attaches it to the shared
 * {@code root}.
 *
 * <p>The node is stored in {@code root.child} (replacing whatever was stored
 * there before) and {@code root.adoptChildren()} is then invoked so that the
 * root adopts it.
 *
 * @param nodeFactory supplier that produces the node instance
 * @param <T> concrete node type produced by the supplier
 * @return the node returned by the supplier, after it has been attached
 */
private static <T extends AbstractNode> T createNode(Supplier<T> nodeFactory) {
    final T fresh = nodeFactory.get();
    root.child = fresh;
    root.adoptChildren();
    return fresh;
}
/**
 * Per-thread state for {@code simpleFirstIteration}: an array of
 * {@link SimpleDSLNode} instances that have not been executed yet.
 */
@State(Scope.Thread)
public static class SimpleFirstIterationState {

    /** Nodes handed to the benchmark; every slot is replaced by {@link #setup()}. */
    final SimpleDSLNode[] nodes = new SimpleDSLNode[NODES];

    /**
     * Replaces every slot with a newly created node. Declared with
     * {@code Level.Invocation}, so the benchmark always sees nodes that have
     * never been executed.
     */
    @Setup(Level.Invocation)
    public void setup() {
        for (int slot = 0; slot < nodes.length; slot++) {
            nodes[slot] = createNode(SimpleDSLNodeGen::create);
        }
    }
}
/**
 * Per-thread state for {@code cachedFirstIteration}: an array of
 * {@link CachedDSLNode} instances that have not been executed yet.
 */
@State(Scope.Thread)
public static class CachedFirstIterationState {

    /** Nodes handed to the benchmark; every slot is replaced by {@link #setup()}. */
    final CachedDSLNode[] nodes = new CachedDSLNode[NODES];

    /**
     * Replaces every slot with a newly created node. Declared with
     * {@code Level.Invocation}, so the benchmark always sees nodes that have
     * never been executed.
     */
    @Setup(Level.Invocation)
    public void setup() {
        for (int slot = 0; slot < nodes.length; slot++) {
            nodes[slot] = createNode(CachedDSLNodeGen::create);
        }
    }
}
/**
 * Exercises freshly created nodes with several argument types before any
 * measurement takes place.
 *
 * <p>Each of the 100 rounds creates three nodes and runs them with a boxed
 * {@code long}, a boxed {@code int} and a {@code String}:
 * <ol>
 * <li>a {@link SimpleDSLNode}, {@code long} first and then {@code int};</li>
 * <li>a {@link SimpleDSLNode}, {@code int} first and then {@code long};</li>
 * <li>a {@link CachedDSLNode}, {@code int} first and then {@code long}.</li>
 * </ol>
 * The order of the first two arguments differs between the nodes on purpose
 * and must be kept as is.
 *
 * <p>NOTE(review): judging by the method name, the goal is to populate the
 * host VM's interpreter profile for the generated node code so the benchmarks
 * do not run against a pristine profile - confirm with the benchmark owners.
 */
@Setup
public void setupInterpreterProfile() {
    for (int round = 0; round < 100; round++) {
        AbstractNode longThenInt = createNode(SimpleDSLNodeGen::create);
        exerciseWithAllTypes(longThenInt, 42L, 42);

        AbstractNode intThenLong = createNode(SimpleDSLNodeGen::create);
        exerciseWithAllTypes(intThenLong, 42, 42L);

        AbstractNode cached = createNode(CachedDSLNodeGen::create);
        exerciseWithAllTypes(cached, 42, 42L);
    }
}

/**
 * Executes {@code node} with {@code first}, then {@code second}, and finally
 * with an empty string, ignoring the {@link UnsupportedSpecializationException}
 * that the string call may raise.
 *
 * @param node node to execute
 * @param first argument of the first call
 * @param second argument of the second call
 */
private static void exerciseWithAllTypes(AbstractNode node, Object first, Object second) {
    node.execute(first);
    node.execute(second);
    try {
        node.execute("");
    } catch (UnsupportedSpecializationException ignored) {
        // Deliberately ignored: neither node class declares a specialization
        // accepting a String, and the call is made only for its side effects.
    }
}
/**
 * Per-thread state for {@code simpleSecondIteration}: an array of
 * {@link SimpleDSLNode} instances that have each been executed exactly once
 * with the int value 42.
 */
@State(Scope.Thread)
public static class SimpleSecondIterationState {

    /** Nodes handed to the benchmark; every slot is replaced by {@link #setup()}. */
    final SimpleDSLNode[] nodes = new SimpleDSLNode[NODES];

    /**
     * Replaces every slot with a newly created node and executes that node
     * once, so the benchmark measures the second execution. Declared with
     * {@code Level.Invocation}.
     */
    @Setup(Level.Invocation)
    public void setup() {
        for (int slot = 0; slot < nodes.length; slot++) {
            final SimpleDSLNode fresh = createNode(SimpleDSLNodeGen::create);
            nodes[slot] = fresh;
            // First execution happens here, outside the measured method.
            fresh.execute(42);
        }
    }
}
/**
 * Per-thread state for {@code cachedSecondIteration}: an array of
 * {@link CachedDSLNode} instances that have each been executed exactly once
 * with the int value 42.
 */
@State(Scope.Thread)
public static class CachedSecondIterationState {

    /** Nodes handed to the benchmark; every slot is replaced by {@link #setup()}. */
    final CachedDSLNode[] nodes = new CachedDSLNode[NODES];

    /**
     * Replaces every slot with a newly created node and executes that node
     * once, so the benchmark measures the second execution. Declared with
     * {@code Level.Invocation}.
     */
    @Setup(Level.Invocation)
    public void setup() {
        for (int slot = 0; slot < nodes.length; slot++) {
            final CachedDSLNode fresh = createNode(CachedDSLNodeGen::create);
            nodes[slot] = fresh;
            // First execution happens here, outside the measured method.
            fresh.execute(42);
        }
    }
}
/**
 * Minimal root node whose only purpose in this file is to act as the parent
 * of the nodes under test. Its own {@link #execute(VirtualFrame)} does no
 * work.
 */
static final class TestRootNode extends RootNode {

    /** The node currently attached to this root; assigned by {@code createNode}. */
    @Child Node child;

    /** Creates the root, passing {@code null} to the {@link RootNode} constructor. */
    protected TestRootNode() {
        super(null);
    }

    /**
     * Does nothing and does not touch {@link #child}.
     *
     * @param frame ignored
     * @return always {@code null}
     */
    @Override
    public Object execute(VirtualFrame frame) {
        return null;
    }
}
/**
 * Common base of the benchmarked nodes. It fixes the single execute signature
 * that the state classes and benchmark methods call.
 */
abstract static class AbstractNode extends Node {
/**
 * Executes the node for the given (boxed) argument.
 *
 * @param v argument value; callers in this file pass Integer, Long and String instances
 * @return the int result produced by the node
 */
abstract int execute(Object v);
}
/**
 * Node with two plain specializations, one for int and one for long.
 * The concrete implementation is the generated SimpleDSLNodeGen imported at
 * the top of the file.
 *
 * NOTE(review): the specializations are declared int first, long second; the
 * DSL is understood to take declaration order into account, so confirm before
 * reordering them.
 */
abstract static class SimpleDSLNode extends AbstractNode {
/** Returns the int argument unchanged. */
@Specialization
int doInt(int v) {
return v;
}
/** Returns the long argument narrowed to int by a plain cast. */
@Specialization
int doLong(long v) {
return (int) v;
}
}
/**
 * Node whose int specialization takes an additional @Cached parameter.
 * The concrete implementation is the generated CachedDSLNodeGen imported at
 * the top of the file.
 *
 * NOTE(review): the specializations are declared int first, long second; the
 * DSL is understood to take declaration order into account, so confirm before
 * reordering them.
 */
abstract static class CachedDSLNode extends AbstractNode {
/**
 * Ignores the int argument and returns the cached parameter, whose
 * initializer expression "CACHED" names the constant declared below.
 */
@Specialization
int doCached(@SuppressWarnings("unused") int v, @Cached("CACHED") int cached) {
return cached;
}
/** Returns the long argument narrowed to int by a plain cast. */
@Specialization
int doLong(long v) {
return (int) v;
}
/** Constant referenced by name from the @Cached expression on doCached. */
static final int CACHED = 42;
}
/**
 * Executes every node of {@code state} once with a boxed 42. The state refills
 * its array with never-executed {@link SimpleDSLNode}s before each invocation,
 * so this is the first execution of each node.
 *
 * @param state holder of the nodes to execute
 * @return sum of the values returned by all {@code NODES} executions
 */
@Benchmark
@OperationsPerInvocation(NODES)
public int simpleFirstIteration(SimpleFirstIterationState state) {
    // Box the argument once, before the loop, so the loop body is only the execute call.
    Integer boxedArgument = Integer.valueOf(42);
    int total = 0;
    for (int index = 0; index < NODES; index++) {
        total += state.nodes[index].execute(boxedArgument);
    }
    return total;
}
/**
 * Executes every node of {@code state} once with a boxed 42. The state has
 * already executed each {@link SimpleDSLNode} once with 42 during setup, so
 * this is the second execution of each node.
 *
 * @param state holder of the nodes to execute
 * @return sum of the values returned by all {@code NODES} executions
 */
@Benchmark
@OperationsPerInvocation(NODES)
public int simpleSecondIteration(SimpleSecondIterationState state) {
    // Box the argument once, before the loop, so the loop body is only the execute call.
    Integer boxedArgument = Integer.valueOf(42);
    int total = 0;
    for (int index = 0; index < NODES; index++) {
        total += state.nodes[index].execute(boxedArgument);
    }
    return total;
}
/**
 * Executes every node of {@code state} once with a boxed 42. The state refills
 * its array with never-executed {@link CachedDSLNode}s before each invocation,
 * so this is the first execution of each node.
 *
 * @param state holder of the nodes to execute
 * @return sum of the values returned by all {@code NODES} executions
 */
@Benchmark
@OperationsPerInvocation(NODES)
public int cachedFirstIteration(CachedFirstIterationState state) {
    // Box the argument once, before the loop, so the loop body is only the execute call.
    Integer boxedArgument = Integer.valueOf(42);
    int total = 0;
    for (int index = 0; index < NODES; index++) {
        total += state.nodes[index].execute(boxedArgument);
    }
    return total;
}
/**
 * Executes every node of {@code state} once with a boxed 42. The state has
 * already executed each {@link CachedDSLNode} once with 42 during setup, so
 * this is the second execution of each node.
 *
 * @param state holder of the nodes to execute
 * @return sum of the values returned by all {@code NODES} executions
 */
@Benchmark
@OperationsPerInvocation(NODES)
public int cachedSecondIteration(CachedSecondIterationState state) {
    // Box the argument once, before the loop, so the loop body is only the execute call.
    Integer boxedArgument = Integer.valueOf(42);
    int total = 0;
    for (int index = 0; index < NODES; index++) {
        total += state.nodes[index].execute(boxedArgument);
    }
    return total;
}
}