package com.oracle.truffle.tools.profiler.test;
import java.util.Collection;
import java.util.Iterator;
import java.util.Map;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import org.graalvm.polyglot.Source;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Ignore;
import org.junit.Test;
import com.oracle.truffle.api.CallTarget;
import com.oracle.truffle.api.RootCallTarget;
import com.oracle.truffle.api.Truffle;
import com.oracle.truffle.api.TruffleLanguage;
import com.oracle.truffle.api.frame.VirtualFrame;
import com.oracle.truffle.api.instrumentation.GenerateWrapper;
import com.oracle.truffle.api.instrumentation.InstrumentableNode;
import com.oracle.truffle.api.instrumentation.ProbeNode;
import com.oracle.truffle.api.instrumentation.ProvidedTags;
import com.oracle.truffle.api.instrumentation.StandardTags;
import com.oracle.truffle.api.instrumentation.Tag;
import com.oracle.truffle.api.nodes.DirectCallNode;
import com.oracle.truffle.api.nodes.Node;
import com.oracle.truffle.api.source.SourceSection;
import com.oracle.truffle.tools.profiler.CPUSampler;
import com.oracle.truffle.tools.profiler.ProfilerNode;
public class CPUSamplerTest extends AbstractProfilerTest {
/**
 * Sampler under test; reassigned from the current engine before every test by
 * {@link #setupSampler()}. It is static, which lets the static nested
 * {@link RecreateShadowStackTestLanguage} pass it to its {@code StartSamplerNode} in {@code parse}.
 */
private static CPUSampler sampler;
/** Number of times the sampling tests evaluate their source in a loop. */
final int executionCount = 10;
/**
 * Runs before each test: looks up the {@link CPUSampler} of the current engine, stores it in
 * {@link #sampler} and, while holding the sampler's monitor, enables gathering of self hit times
 * and switches off {@code delaySamplingUntilNonInternalLangInit}.
 */
@Before
public void setupSampler() {
    final CPUSampler engineSampler = CPUSampler.find(context.getEngine());
    sampler = engineSampler;
    Assert.assertNotNull(engineSampler);
    synchronized (engineSampler) {
        engineSampler.setGatherSelfHitTimes(true);
        engineSampler.setDelaySamplingUntilNonInternalLangInit(false);
    }
}
/**
 * Walks the sampler through its collecting life cycle and checks {@code isCollecting()},
 * {@code hasData()} and {@code getSampleCount()} at each stage: freshly enabled, after running
 * the default source, after disabling collection and after clearing the data.
 */
@Test
public void testCollectingAndHasData() {
    // Stage 1: collecting has just been switched on, nothing has been sampled yet.
    sampler.setCollecting(true);
    Assert.assertEquals(0, sampler.getSampleCount());
    Assert.assertTrue(sampler.isCollecting());
    Assert.assertFalse(sampler.hasData());

    // Stage 2: run the workload executionCount times; samples must have been recorded.
    int remainingRuns = executionCount;
    while (remainingRuns-- > 0) {
        eval(defaultSourceForSampling);
    }
    Assert.assertNotEquals(0, sampler.getSampleCount());
    Assert.assertTrue(sampler.isCollecting());
    Assert.assertTrue(sampler.hasData());

    // Stage 3: stopping collection keeps the gathered data.
    sampler.setCollecting(false);
    Assert.assertFalse(sampler.isCollecting());
    Assert.assertTrue(sampler.hasData());

    // Stage 4: clearing the data resets the sample count and leaves the sampler stopped.
    sampler.clearData();
    Assert.assertFalse(sampler.isCollecting());
    Assert.assertEquals(0, sampler.getSampleCount());
    Assert.assertFalse(sampler.hasData());
}
/**
 * Default workload for the sampling tests, built with {@code makeSource} (declared outside this
 * file). The program text defines {@code foo} (a 1-unit {@code SLEEP}), {@code bar} (a statement
 * followed by a loop of 10 calls to {@code foo}) and {@code baz} (a statement followed by a loop
 * of 10 calls to {@code bar}), and then calls {@code baz} and {@code bar}.
 */
Source defaultSourceForSampling = makeSource("ROOT(" +
"DEFINE(foo,ROOT(SLEEP(1)))," +
"DEFINE(bar,ROOT(BLOCK(STATEMENT,LOOP(10, CALL(foo)))))," +
"DEFINE(baz,ROOT(BLOCK(STATEMENT,LOOP(10, CALL(bar)))))," +
"CALL(baz),CALL(bar)" +
")");
/**
 * Samples the default source with the root-tag filter and verifies the shape of the resulting
 * tree: one unnamed program root with two children, one of which is {@code baz}, followed by the
 * chain {@code baz -> bar -> foo}, where {@code foo} has no children. The payload timeline of
 * every visited node is validated with {@link #checkTimeline}.
 */
@Test
public void testCorrectRootStructure() {
    sampler.setFilter(NO_INTERNAL_ROOT_TAG_FILTER);
    sampler.setCollecting(true);
    for (int run = 0; run < executionCount; run++) {
        eval(defaultSourceForSampling);
    }

    // Level 0: the single, unnamed program root.
    Collection<ProfilerNode<CPUSampler.Payload>> roots = sampler.getRootNodes();
    Assert.assertEquals(1, roots.size());
    ProfilerNode<CPUSampler.Payload> program = roots.iterator().next();
    Assert.assertEquals("", program.getRootName());
    checkTimeline(program.getPayload());

    // Level 1: two children; baz is either the first or the second one.
    Collection<ProfilerNode<CPUSampler.Payload>> programChildren = program.getChildren();
    Assert.assertEquals(2, programChildren.size());
    Iterator<ProfilerNode<CPUSampler.Payload>> candidates = programChildren.iterator();
    ProfilerNode<CPUSampler.Payload> firstCandidate = candidates.next();
    ProfilerNode<CPUSampler.Payload> baz = "baz".equals(firstCandidate.getRootName()) ? firstCandidate : candidates.next();
    Assert.assertEquals("baz", baz.getRootName());
    checkTimeline(baz.getPayload());

    // Level 2: baz only calls bar.
    Collection<ProfilerNode<CPUSampler.Payload>> bazChildren = baz.getChildren();
    Assert.assertEquals(1, bazChildren.size());
    ProfilerNode<CPUSampler.Payload> bar = bazChildren.iterator().next();
    Assert.assertEquals("bar", bar.getRootName());
    checkTimeline(bar.getPayload());

    // Level 3: bar only calls foo.
    Collection<ProfilerNode<CPUSampler.Payload>> barChildren = bar.getChildren();
    Assert.assertEquals(1, barChildren.size());
    ProfilerNode<CPUSampler.Payload> foo = barChildren.iterator().next();
    Assert.assertEquals("foo", foo.getRootName());
    checkTimeline(foo.getPayload());

    // foo is a leaf.
    Assert.assertTrue(foo.getChildren().isEmpty());
}
/**
 * Recursive workload for the sampling tests, built with {@code makeSource} (declared outside this
 * file). The program text defines {@code rfoo} (a {@code RECURSIVE_CALL} with argument 10 followed
 * by a 1-unit {@code SLEEP}) and {@code rbar} (a statement followed by a loop of 10 calls).
 */
// NOTE(review): the definitions are named rfoo/rbar, but the calls inside the program reference
// foo/bar (RECURSIVE_CALL(foo, 10), CALL(foo), CALL(bar)), while testCorrectRootStructureRecursive
// expects roots named "rbar" and "rfoo". This looks inconsistent - confirm against the test
// language's name resolution before relying on this source.
final Source defaultRecursiveSourceForSampling = makeSource("ROOT(" +
"DEFINE(rfoo,ROOT(BLOCK(RECURSIVE_CALL(foo, 10),SLEEP(1))))," +
"DEFINE(rbar,ROOT(BLOCK(STATEMENT,LOOP(10, CALL(foo)))))," +
"CALL(bar)" +
")");
/**
 * Samples the recursive source with the root-tag filter and verifies the shape of the resulting
 * tree: one unnamed program root, a single {@code rbar} child, a single {@code rfoo} child below
 * it, and then ten further nested {@code rfoo} nodes, the last of which has no children. The
 * payload timeline of every visited node is validated with {@link #checkTimeline}.
 */
@Test
@Ignore("non-deterministic failures on spark")
public void testCorrectRootStructureRecursive() {
    sampler.setFilter(NO_INTERNAL_ROOT_TAG_FILTER);
    sampler.setCollecting(true);
    for (int i = 0; i < executionCount; i++) {
        eval(defaultRecursiveSourceForSampling);
    }

    Collection<ProfilerNode<CPUSampler.Payload>> children = sampler.getRootNodes();
    Assert.assertEquals(1, children.size());
    ProfilerNode<CPUSampler.Payload> program = children.iterator().next();
    Assert.assertEquals("", program.getRootName());
    checkTimeline(program.getPayload());

    children = program.getChildren();
    Assert.assertEquals(1, children.size());
    ProfilerNode<CPUSampler.Payload> bar = children.iterator().next();
    Assert.assertEquals("rbar", bar.getRootName());
    checkTimeline(bar.getPayload());

    children = bar.getChildren();
    Assert.assertEquals(1, children.size());
    ProfilerNode<CPUSampler.Payload> foo = children.iterator().next();
    Assert.assertEquals("rfoo", foo.getRootName());
    checkTimeline(foo.getPayload());

    // Descend through the ten recursive invocations of rfoo.
    for (int i = 0; i < 10; i++) {
        children = foo.getChildren();
        Assert.assertEquals(1, children.size());
        foo = children.iterator().next();
        Assert.assertEquals("rfoo", foo.getRootName());
        // Validate the node just reached; the previous code re-checked the rbar payload here,
        // so the nested rfoo payloads were never verified.
        checkTimeline(foo.getPayload());
    }

    // The innermost rfoo must be a leaf; assertEquals reports the actual size on failure.
    children = foo.getChildren();
    Assert.assertEquals(0, children.size());
}
/**
 * Limits the sampler's stack to 2 entries, runs the default source (whose call chain is deeper
 * than that) and expects the sampler to report a stack overflow.
 */
@Test
public void testShadowStackOverflows() {
    sampler.setFilter(NO_INTERNAL_ROOT_TAG_FILTER);
    sampler.setStackLimit(2);
    sampler.setCollecting(true);
    int run = 0;
    while (run < executionCount) {
        eval(defaultSourceForSampling);
        run++;
    }
    final boolean overflowed = sampler.hasStackOverflowed();
    Assert.assertTrue(overflowed);
}
/**
 * Asserts that the number of recorded self hit times of the payload equals its self hit count.
 *
 * @param payload the sampler payload to validate
 */
private static void checkTimeline(CPUSampler.Payload payload) {
    final long selfHits = payload.getSelfHitCount();
    final long recordedTimes = payload.getSelfHitTimes().size();
    Assert.assertEquals("Timeline length and self hit count to not match!", selfHits, recordedTimes);
}
@TruffleLanguage.Registration(id = RecreateShadowStackTestLanguage.ID, name = "RecreateShadowStackTestLanguage", version = "0.1")
@ProvidedTags({StandardTags.StatementTag.class, StandardTags.RootTag.class})
public static class RecreateShadowStackTestLanguage extends TruffleLanguage<Integer> {
/** Language id; used in the {@code @TruffleLanguage.Registration} of this class and by the tests when building sources. */
public static final String ID = "RecreateShadowStackTestLanguage";
/** Creates the language context; this test language carries no state, so a constant {@code 0} is used. */
@Override
protected Integer createContext(Env env) {
    return Integer.valueOf(0);
}
/** Permits access from any thread, regardless of whether access is single- or multi-threaded. */
@Override
protected boolean isThreadAccessAllowed(Thread thread, boolean singleThreaded) {
return true;
}
/**
 * Builds the AST for one of the two supported programs; any other source text fails the test.
 * <ul>
 * <li>{@code "Statement Root"}: the node executed after the sampler is started is a statement
 * wrapping a {@code SleepNode}.</li>
 * <li>{@code "Root Root"}: the node executed after the sampler is started is a statement (with an
 * unavailable source section) that calls a separate root whose body is a {@code SleepNode}.</li>
 * </ul>
 * In both cases the returned target is an outer root that, through two nested statements, calls
 * an inner root which, again through two nested statements, reaches a {@code StartSamplerNode}
 * that enables the sampler before running the program-specific child.
 *
 * @param request the parsing request supplying the source
 * @return the call target of the outer root
 */
@Override
protected CallTarget parse(ParsingRequest request) throws Exception {
    final com.oracle.truffle.api.source.Source source = request.getSource();
    SourceSection statement = null;
    SourceSection root = null;
    StatementNode startSamplerChild = null;
    final String program = source.getCharacters().toString();
    switch (program) {
        case "Statement Root":
            statement = source.createSection(0, 9);
            root = source.createSection(10, 4);
            startSamplerChild = new StatementNode(statement, new SleepNode());
            break;
        case "Root Root": {
            statement = source.createUnavailableSection();
            root = source.createSection(0, 4);
            SRootNode sleepRoot = new SRootNode(this, new RootNode(root, new SleepNode()));
            RootCallTarget sleepTarget = Truffle.getRuntime().createCallTarget(sleepRoot);
            DirectCallNode sleepCall = Truffle.getRuntime().createDirectCallNode(sleepTarget);
            startSamplerChild = new StatementNode(statement, new CallNode(sleepCall));
            break;
        }
        default:
            Assert.fail("Unsupported parse request.");
            break;
    }

    // Inner root: two nested statements leading to the node that switches the sampler on.
    SamplerTestNode innerBody = nestInTwoStatements(statement, new StartSamplerNode(sampler, startSamplerChild));
    RootCallTarget innerTarget = Truffle.getRuntime().createCallTarget(new SRootNode(this, new RootNode(root, innerBody)));
    DirectCallNode directCallNode = Truffle.getRuntime().createDirectCallNode(innerTarget);

    // Outer root: two nested statements leading to the call of the inner root.
    SamplerTestNode outerBody = nestInTwoStatements(statement, new CallNode(directCallNode));
    return Truffle.getRuntime().createCallTarget(new SRootNode(this, new RootNode(root, outerBody)));
}

/** Wraps {@code innermost} as Dummy(Statement(Dummy(Statement(Dummy(innermost))))), with both statements using {@code statement} as their section. */
private static SamplerTestNode nestInTwoStatements(SourceSection statement, SamplerTestNode innermost) {
    StatementNode innerStatement = new StatementNode(statement, new DummyNode(innermost));
    StatementNode outerStatement = new StatementNode(statement, new DummyNode(innerStatement));
    return new DummyNode(outerStatement);
}
/** Common base class of all AST nodes of this test language. */
abstract static class SamplerTestNode extends Node {
/** Executes this node with the given frame and returns its result. */
public abstract Object execute(VirtualFrame frame);
}
/**
 * Base class of the instrumentable nodes of this test language. It stores the source section it is
 * constructed with, always reports itself as instrumentable and creates the wrapper generated for
 * it via {@code @GenerateWrapper}.
 */
@GenerateWrapper
abstract static class SamplerTestInstrumentableNode extends SamplerTestNode implements InstrumentableNode {
/** Source section reported by {@link #getSourceSection()}. */
final SourceSection sourceSection;
protected SamplerTestInstrumentableNode(SourceSection sourceSection) {
this.sourceSection = sourceSection;
}
@Override
public abstract Object execute(VirtualFrame frame);
/** Every node of this kind can be instrumented. */
@Override
public boolean isInstrumentable() {
return true;
}
/** Returns the section passed to the constructor. */
@Override
public SourceSection getSourceSection() {
return sourceSection;
}
/** Creates the generated wrapper around this node, handing on this node's source section and the given probe. */
@Override
public WrapperNode createWrapper(ProbeNode probe) {
return new SamplerTestInstrumentableNodeWrapper(sourceSection, this, probe);
}
}
/** Leaf node that blocks the executing thread for three seconds and then returns {@code 5}. */
static class SleepNode extends SamplerTestNode {
    /** How long {@link #execute} sleeps, in milliseconds. */
    private static final long SLEEP_MILLIS = 3000;

    @Override
    public Object execute(VirtualFrame frame) {
        try {
            Thread.sleep(SLEEP_MILLIS);
        } catch (InterruptedException interrupted) {
            // An interrupt while sleeping is treated as a test failure.
            Assert.fail("Thread interrupted");
        }
        return 5;
    }
}
/** Non-instrumentable pass-through node: executing it just executes its child. */
static class DummyNode extends SamplerTestNode {
    @Child SamplerTestNode node;

    DummyNode(SamplerTestNode child) {
        node = child;
    }

    @Override
    public Object execute(VirtualFrame frame) {
        final Object result = node.execute(frame);
        return result;
    }
}
/** Instrumentable node carrying only the statement tag; executing it just executes its child. */
static class StatementNode extends SamplerTestInstrumentableNode {
    @Child SamplerTestNode node;

    StatementNode(SourceSection sourceSection, SamplerTestNode child) {
        super(sourceSection);
        node = child;
    }

    /** Answers {@code true} only for {@link StandardTags.StatementTag}. */
    @Override
    public boolean hasTag(Class<? extends Tag> tag) {
        return StandardTags.StatementTag.class == tag;
    }

    @Override
    public Object execute(VirtualFrame frame) {
        final Object result = node.execute(frame);
        return result;
    }
}
/** Instrumentable node carrying only the root tag; executing it just executes its child. */
static class RootNode extends SamplerTestInstrumentableNode {
    @Child SamplerTestNode node;

    RootNode(SourceSection sourceSection, SamplerTestNode child) {
        super(sourceSection);
        node = child;
    }

    /** Answers {@code true} only for {@link StandardTags.RootTag}. */
    @Override
    public boolean hasTag(Class<? extends Tag> tag) {
        return StandardTags.RootTag.class == tag;
    }

    @Override
    public Object execute(VirtualFrame frame) {
        final Object result = node.execute(frame);
        return result;
    }
}
/** Truffle root node of the test language; executing it just executes the wrapped body node. */
static class SRootNode extends com.oracle.truffle.api.nodes.RootNode {
    @Child SamplerTestNode child;

    SRootNode(RecreateShadowStackTestLanguage language, SamplerTestNode body) {
        super(language);
        child = body;
    }

    @Override
    public Object execute(VirtualFrame frame) {
        final Object result = child.execute(frame);
        return result;
    }
}
/** Node that invokes the wrapped direct call node with an empty argument array. */
static class CallNode extends SamplerTestNode {
    @Child DirectCallNode callNode;

    CallNode(DirectCallNode target) {
        callNode = target;
    }

    @Override
    public Object execute(VirtualFrame frame) {
        final Object[] noArguments = new Object[]{};
        return callNode.call(noArguments);
    }
}
/**
 * Node that asserts the sampler has no root nodes yet, switches collecting on and only then
 * executes its child, so sampling starts while guest frames are already on the stack.
 */
static class StartSamplerNode extends SamplerTestNode {
    CPUSampler sampler;
    @Child SamplerTestNode child;

    StartSamplerNode(CPUSampler samplerToStart, SamplerTestNode body) {
        sampler = samplerToStart;
        child = body;
    }

    @Override
    public Object execute(VirtualFrame frame) {
        final boolean nothingSampledYet = sampler.getRootNodes().isEmpty();
        Assert.assertTrue("Found roots before enabling sampler", nothingSampledYet);
        sampler.setCollecting(true);
        return child.execute(frame);
    }
}
}
/**
 * Evaluates the {@code "Statement Root"} program in statement mode and checks that the first
 * sampled root node heads a chain of exactly five nodes whose source text is {@code "Statement"}.
 */
@Test
public void testCorrectInitShadowStackStatements() {
    sampler.setMode(CPUSampler.Mode.STATEMENTS);
    sampler.setFilter(NO_INTERNAL_STATEMENT_TAG_FILTER);
    Source test = Source.newBuilder(RecreateShadowStackTestLanguage.ID, "Statement Root", "test").buildLiteral();
    context.eval(test);

    ProfilerNode<CPUSampler.Payload> current = sampler.getRootNodes().iterator().next();
    // Four statements that must each have a child ...
    for (int depth = 0; depth < 4; depth++) {
        current = checkStackState(current, "Statement", false);
    }
    // ... and a fifth one that must be the top of the stack.
    checkStackState(current, "Statement", true);
}
/**
 * Asserts that the source text of {@code current} equals {@code expectedSource} and that the node
 * has children exactly when it is not expected to be the top of the stack.
 *
 * @param current the node to check
 * @param expectedSource the expected characters of the node's source section
 * @param top whether {@code current} is expected to be the top of the stack (no children)
 * @return the first child of {@code current}, or {@code null} when {@code top} is {@code true}
 */
private static ProfilerNode<CPUSampler.Payload> checkStackState(ProfilerNode<CPUSampler.Payload> current, String expectedSource, boolean top) {
    final String actualSource = current.getSourceSection().getCharacters().toString();
    Assert.assertEquals("Stack not correct", expectedSource, actualSource);
    if (!top) {
        Assert.assertTrue("Stack not deep enough", current.getChildren().iterator().hasNext());
        return current.getChildren().iterator().next();
    }
    Assert.assertFalse("Stack too deep", current.getChildren().iterator().hasNext());
    return null;
}
/**
 * Evaluates the {@code "Root Root"} program in root mode and checks that the first sampled root
 * node heads a chain of exactly three nodes whose source text is {@code "Root"}.
 */
@Test
public void testCorrectInitShadowStackRoots() {
    sampler.setMode(CPUSampler.Mode.ROOTS);
    sampler.setFilter(NO_INTERNAL_ROOT_TAG_FILTER);
    Source test = Source.newBuilder(RecreateShadowStackTestLanguage.ID, "Root Root", "test").buildLiteral();
    context.eval(test);

    ProfilerNode<CPUSampler.Payload> current = sampler.getRootNodes().iterator().next();
    // Two roots that must each have a child ...
    for (int depth = 0; depth < 2; depth++) {
        current = checkStackState(current, "Root", false);
    }
    // ... and a third one that must be the top of the stack.
    checkStackState(current, "Root", true);
}
/**
 * Runs the default source on one thread and, after it has finished, the recursive source on a
 * second thread, then checks that every node of each per-thread tree has a counterpart with the
 * same hit counts in the merged tree returned by {@code getRootNodes()}.
 */
@Test
public void testMultiThreadedRecursive() {
    sampler.setFilter(NO_INTERNAL_ROOT_TAG_FILTER);
    sampler.setCollecting(true);

    final Runnable plainWorkload = () -> {
        for (int run = 0; run < executionCount; run++) {
            eval(defaultSourceForSampling);
        }
    };
    final Runnable recursiveWorkload = () -> {
        for (int run = 0; run < executionCount; run++) {
            eval(defaultRecursiveSourceForSampling);
        }
    };

    // The threads run one after the other, never concurrently.
    final Thread first = new Thread(plainWorkload);
    first.start();
    try {
        first.join();
    } catch (InterruptedException e) {
        Assert.fail("Thread interrupted");
    }

    final Thread second = new Thread(recursiveWorkload);
    second.start();
    try {
        second.join();
    } catch (InterruptedException e) {
        Assert.fail("Thread interrupted");
    }

    final Map<Thread, Collection<ProfilerNode<CPUSampler.Payload>>> nodesByThread = sampler.getThreadToNodesMap();
    final Collection<ProfilerNode<CPUSampler.Payload>> mergedRoots = sampler.getRootNodes();
    traverseAndCompareForDifferentSources(mergedRoots, nodesByThread.get(first));
    traverseAndCompareForDifferentSources(mergedRoots, nodesByThread.get(second));
}
/**
 * Recursively checks that every node of a per-thread tree has a counterpart (same source section
 * and root name) in the merged tree and that self hit count and hit count are equal for both.
 *
 * @param merged nodes of the merged tree at the current level
 * @param perThread nodes of one thread's tree at the current level
 */
private static void traverseAndCompareForDifferentSources(Collection<ProfilerNode<CPUSampler.Payload>> merged, Collection<ProfilerNode<CPUSampler.Payload>> perThread) {
    final String mismatch = "Merged structure does not mach per thread structure";
    for (ProfilerNode<CPUSampler.Payload> threadNode : perThread) {
        final ProfilerNode<CPUSampler.Payload> counterpart = findNodeBySourceAndRoot(merged, threadNode.getSourceSection(), threadNode.getRootName());
        Assert.assertNotNull(mismatch, counterpart);

        final CPUSampler.Payload counterpartPayload = counterpart.getPayload();
        final CPUSampler.Payload threadPayload = threadNode.getPayload();
        final boolean sameSelfHits = threadPayload.getSelfHitCount() == counterpartPayload.getSelfHitCount();
        Assert.assertTrue(mismatch, sameSelfHits);
        final boolean sameHits = threadPayload.getHitCount() == counterpartPayload.getHitCount();
        Assert.assertTrue(mismatch, sameHits);

        // Descend into the matching subtrees.
        traverseAndCompareForDifferentSources(counterpart.getChildren(), threadNode.getChildren());
    }
}
/**
 * Finds the first node whose source section and root name both equal the given ones.
 *
 * @param merged the nodes to search
 * @param sourceSection the source section to match
 * @param rootName the root name to match
 * @return the first matching node, or {@code null} if there is none
 */
private static ProfilerNode<CPUSampler.Payload> findNodeBySourceAndRoot(Collection<ProfilerNode<CPUSampler.Payload>> merged, SourceSection sourceSection, String rootName) {
    for (ProfilerNode<CPUSampler.Payload> candidate : merged) {
        if (!candidate.getSourceSection().equals(sourceSection)) {
            continue;
        }
        if (candidate.getRootName().equals(rootName)) {
            return candidate;
        }
    }
    return null;
}
/**
 * Runs the default source on two threads, one after the other, and checks that the hit counts of
 * every node in the merged tree are the sums of the matching nodes in the two per-thread trees.
 */
@Test
public void testMultiThreaded() {
    sampler.setFilter(NO_INTERNAL_ROOT_TAG_FILTER);
    sampler.setCollecting(true);

    final Runnable workload = () -> {
        for (int run = 0; run < executionCount; run++) {
            eval(defaultSourceForSampling);
        }
    };

    // The threads run one after the other, never concurrently.
    final Thread first = new Thread(workload);
    first.start();
    try {
        first.join();
    } catch (InterruptedException e) {
        Assert.fail("Thread interrupted");
    }

    final Thread second = new Thread(workload);
    second.start();
    try {
        second.join();
    } catch (InterruptedException e) {
        Assert.fail("Thread interrupted");
    }

    final Map<Thread, Collection<ProfilerNode<CPUSampler.Payload>>> nodesByThread = sampler.getThreadToNodesMap();
    final Collection<ProfilerNode<CPUSampler.Payload>> mergedRoots = sampler.getRootNodes();
    traverseAndCompareForSameSource(mergedRoots, nodesByThread.get(first), nodesByThread.get(second));
}
/**
 * Recursively checks that every merged node has a counterpart (same source section and root
 * name) in both per-thread trees and that its self hit count and hit count equal the sums of the
 * two counterparts' counts.
 *
 * @param rootNodes nodes of the merged tree at the current level
 * @param profilerNodes1 nodes of the first thread's tree at the current level
 * @param profilerNodes2 nodes of the second thread's tree at the current level
 */
private void traverseAndCompareForSameSource(Collection<ProfilerNode<CPUSampler.Payload>> rootNodes, Collection<ProfilerNode<CPUSampler.Payload>> profilerNodes1,
                Collection<ProfilerNode<CPUSampler.Payload>> profilerNodes2) {
    final String mismatch = "Merged structure does not mach per thread structure";
    for (ProfilerNode<CPUSampler.Payload> mergedNode : rootNodes) {
        final ProfilerNode<CPUSampler.Payload> inFirst = findNodeBySourceAndRoot(profilerNodes1, mergedNode.getSourceSection(), mergedNode.getRootName());
        final ProfilerNode<CPUSampler.Payload> inSecond = findNodeBySourceAndRoot(profilerNodes2, mergedNode.getSourceSection(), mergedNode.getRootName());
        Assert.assertNotNull(mismatch, inFirst);
        Assert.assertNotNull(mismatch, inSecond);

        final CPUSampler.Payload mergedPayload = mergedNode.getPayload();
        final CPUSampler.Payload firstPayload = inFirst.getPayload();
        final CPUSampler.Payload secondPayload = inSecond.getPayload();
        Assert.assertEquals(mismatch, mergedPayload.getSelfHitCount(), firstPayload.getSelfHitCount() + secondPayload.getSelfHitCount());
        Assert.assertEquals(mismatch, mergedPayload.getHitCount(), firstPayload.getHitCount() + secondPayload.getHitCount());

        // Descend into the matching subtrees.
        traverseAndCompareForSameSource(mergedNode.getChildren(), inFirst.getChildren(), inSecond.getChildren());
    }
}
/**
 * Reads the sampler's root nodes repeatedly from the test thread while a worker thread keeps
 * evaluating the default source, and checks with {@link #isSuperset} that each snapshot covers
 * the one taken before it. The worker is cancelled in a {@code finally} block and then joined
 * for at most ten seconds.
 */
@Test
public void testThreadSafe() throws InterruptedException {
    sampler.setFilter(NO_INTERNAL_ROOT_TAG_FILTER);
    sampler.setCollecting(true);

    final AtomicBoolean cancelled = new AtomicBoolean();
    final AtomicInteger iterations = new AtomicInteger();
    final Thread execThread = new Thread(() -> {
        while (!cancelled.get()) {
            iterations.incrementAndGet();
            eval(defaultSourceForSampling);
        }
    });
    execThread.start();

    // Let the worker get going before the first snapshot is taken.
    ensureIterations(iterations, 2);
    Collection<ProfilerNode<CPUSampler.Payload>> previous = sampler.getRootNodes();
    try {
        int remainingSnapshots = 5;
        while (remainingSnapshots-- > 0) {
            final Collection<ProfilerNode<CPUSampler.Payload>> latest = sampler.getRootNodes();
            isSuperset(previous, latest);
            previous = latest;
            ensureIterations(iterations, 1);
        }
    } finally {
        cancelled.set(true);
    }
    execThread.join(10000);
}
/**
 * Resets the counter to zero and then blocks, polling every 5 ms, until it has reached at least
 * {@code numberIterations} again.
 *
 * @param iterations counter that another thread is expected to increment
 * @param numberIterations number of increments to wait for
 * @throws InterruptedException if the waiting thread is interrupted while sleeping
 */
private static void ensureIterations(AtomicInteger iterations, int numberIterations) throws InterruptedException {
    final long pollIntervalMillis = 5;
    for (iterations.set(0); iterations.get() < numberIterations;) {
        Thread.sleep(pollIntervalMillis);
    }
}
/**
 * Recursively asserts that the later profile covers the earlier one: every node of
 * {@code firstRootNodes} must have a counterpart in {@code secondRootNodes} with an equal source
 * section and root name, and the counterpart's self hit count must not be smaller.
 *
 * @param firstRootNodes nodes of the earlier profile at the current level
 * @param secondRootNodes nodes of the later profile at the current level
 */
private static void isSuperset(Collection<ProfilerNode<CPUSampler.Payload>> firstRootNodes, Collection<ProfilerNode<CPUSampler.Payload>> secondRootNodes) {
    for (ProfilerNode<CPUSampler.Payload> firstNode : firstRootNodes) {
        // Only a real match is recorded. The previous code assigned every visited candidate, so a
        // failed search left the last (non-matching) node in place and the null check below
        // could never fail for a non-empty collection.
        ProfilerNode<CPUSampler.Payload> secondNode = null;
        for (ProfilerNode<CPUSampler.Payload> candidate : secondRootNodes) {
            if (candidate.getSourceSection().equals(firstNode.getSourceSection()) && candidate.getRootName().equals(firstNode.getRootName())) {
                secondNode = candidate;
                break;
            }
        }
        Assert.assertNotNull("Profile taken later in execution is not superset of earlier one", secondNode);
        Assert.assertTrue("Profile taken later in execution is not superset of earlier one", firstNode.getPayload().getSelfHitCount() <= secondNode.getPayload().getSelfHitCount());
        isSuperset(firstNode.getChildren(), secondNode.getChildren());
    }
}
/**
 * Sets a period of {@code -1} and hands that action, together with one that enables collecting,
 * to {@code expectProfilerException} (declared outside this file).
 */
@Test
public void testNegativePeriod() {
    expectProfilerException(() -> {
        sampler.setPeriod(-1);
    }, () -> {
        sampler.setCollecting(true);
    });
}
/**
 * Sets a delay of {@code -1} and hands that action, together with one that enables collecting,
 * to {@code expectProfilerException} (declared outside this file).
 */
@Test
public void testNegativeDelay() {
    expectProfilerException(() -> {
        sampler.setDelay(-1);
    }, () -> {
        sampler.setCollecting(true);
    });
}
/**
 * Sets a stack limit of {@code -1} and hands that action, together with one that enables
 * collecting, to {@code expectProfilerException} (declared outside this file).
 */
@Test
public void testStackLimit() {
    expectProfilerException(() -> {
        sampler.setStackLimit(-1);
    }, () -> {
        sampler.setCollecting(true);
    });
}
/**
 * Closes the sampler and then tries to change its delay; that action, together with one that
 * enables collecting, is handed to {@code expectProfilerException} (declared outside this file).
 */
@Test
public void testClosedConfig() {
    expectProfilerException(() -> {
        // Reconfigure only after the sampler has been closed.
        sampler.close();
        sampler.setDelay(1);
    }, () -> {
        sampler.setCollecting(true);
    });
}
}