package com.oracle.truffle.trufflenode.jniboundaryprofiler;
import java.lang.instrument.Instrumentation;
import java.util.ArrayDeque;
import java.util.Collections;
import java.util.Deque;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.Map.Entry;
import java.util.stream.Collectors;
import com.oracle.truffle.api.CompilerDirectives.TruffleBoundary;
import com.oracle.truffle.api.object.DynamicObject;
import com.oracle.truffle.js.runtime.builtins.JSFunction;
/**
 * Java agent that measures time spent crossing the boundary between JavaScript and native (C++)
 * code, and the Java methods invoked back from native code while such a call is in progress.
 *
 * <p>The hooks {@link #bindingCallBegin}/{@link #bindingCallEnd} bracket a JS-to-native binding
 * call and {@link #jniCallBegin}/{@link #jniCallEnd} bracket a Java method run during one. They are
 * presumably inserted by {@code ProfilingTransformer} (registered in {@link #premain}) - TODO
 * confirm. Only the outermost binding call and the outermost JNI call are measured; nested calls
 * merely adjust the nesting depth.
 *
 * <p>Counters are printed to {@code System.out} on JVM shutdown and, when
 * {@code node.native.profiler.interval} is positive, after any outermost binding call that ends
 * more than that many seconds after the previous dump. Every dump resets all counters.
 *
 * <p>NOTE(review): state lives in unsynchronized static fields, so this presumably assumes the
 * hooks are called from a single thread; the shutdown-hook dump runs on a separate thread.
 */
public class ProfilingAgent {

    /** Periodic dump interval in seconds ({@code node.native.profiler.interval}); 0 disables periodic dumps. */
    public static final int DumpEvery = Integer.getInteger("node.native.profiler.interval", 0);
    /** If positive, limits the per-binding JNI breakdown to this many bindings ({@code node.native.profiler.dumptop}). */
    public static final int DumpOnlyTopMethods = Integer.getInteger("node.native.profiler.dumptop", 0);

    /** Labels of the binding calls currently in progress, innermost on top. */
    private static final Deque<String> callStack = new ArrayDeque<>();
    /** Accumulated nanoseconds of outermost binding calls, per binding label. */
    private static final Map<String, PerfCounter> bindingExecTimes = new HashMap<>(100);
    /** Number of completed outermost binding calls, per binding label. */
    private static final Map<String, PerfCounter> bindingCalls = new HashMap<>(100);
    /** Per binding label: accumulated nanoseconds spent in each outermost JNI method. */
    private static final Map<String, Map<String, PerfCounter>> jniExecTimes = new HashMap<>(100);
    /** Per binding label: number of outermost calls to each JNI method. */
    private static final Map<String, Map<String, PerfCounter>> jniCalls = new HashMap<>(100);
    /** {@link System#nanoTime()} at the start of the current sampling window. */
    private static long last = System.nanoTime();
    /** {@link System#nanoTime()} when the current outermost JNI call began. */
    private static long lastJniCallBegin = 0;
    /** JNI call counters of the outermost binding call currently in progress. */
    private static Map<String, PerfCounter> currentJNICalls;
    /** JNI time counters of the outermost binding call currently in progress. */
    private static Map<String, PerfCounter> currentJNIExecTimes;
    /** Nesting depth of JNI calls within the current binding call. */
    private static int jniMethodCallStack = 0;
    /** {@link System#nanoTime()} when the current outermost binding call began. */
    private static long firstBoundaryCrossedAt = 0;

    /**
     * Java agent entry point ({@code premain}); announces the agent and registers the transformer.
     *
     * @param agentArgs agent options (ignored)
     * @param inst instrumentation service used to register the transformer
     */
    public static void premain(@SuppressWarnings("unused") String agentArgs, Instrumentation inst) {
        System.out.println("=== Native boundary profiling agent active ===");
        inst.addTransformer(new ProfilingTransformer());
    }

    static {
        // Emit whatever was collected since the last dump when the JVM shuts down.
        Runtime.getRuntime().addShutdownHook(new Thread(ProfilingAgent::dumpCounters));
    }

    /** Mutable {@code long} counter, ordered by its current value. */
    private static final class PerfCounter implements Comparable<PerfCounter> {
        private long value;

        public long longValue() {
            return value;
        }

        public void increment() {
            value++;
        }

        public void increment(long inc) {
            value += inc;
        }

        @Override
        public int compareTo(PerfCounter o) {
            return Long.compare(value, o.value);
        }
    }

    /**
     * Returns how many outermost calls to the given binding completed since the last dump.
     *
     * @param lbl binding label in the form {@code "<apiName>: <function name>"}
     * @return the call count, or 0 if the binding has not completed a call since the last dump
     */
    public static long getNativeCalls(String lbl) {
        return countOf(bindingCalls, lbl);
    }

    /**
     * Returns how many outermost calls to a JNI method were made while the given binding was the
     * outermost binding call, since the last dump.
     *
     * @param binding binding label in the form {@code "<apiName>: <function name>"}
     * @param jniLabel label passed to {@link #jniCallBegin}
     * @return the call count, or 0 if no such call was recorded
     */
    public static long getJniCalls(String binding, String jniLabel) {
        return countOf(jniCalls.get(binding), jniLabel);
    }

    /** Returns the value of the counter stored under {@code key}, or 0 if the map or the counter is absent. */
    private static long countOf(Map<String, PerfCounter> counters, String key) {
        PerfCounter counter = counters == null ? null : counters.get(key);
        return counter == null ? 0 : counter.longValue();
    }

    /**
     * Builds the label {@code "<apiName>: <function name>"} for a binding call.
     *
     * @param apiName prefix identifying the instrumented API
     * @param label either the called JS function itself, or an array holding it at index 1
     * @return the label, using {@code <unknown>} for functions with an empty name
     * @throws AssertionError if no JS function can be found in {@code label}
     */
    private static String getLabel(String apiName, Object label) throws AssertionError {
        Object function = label;
        if (!JSFunction.isJSFunction(label) && label instanceof Object[]) {
            // Array form: the callee is expected at index 1. Check explicitly rather than with
            // `assert`, so malformed input fails the same way whether or not -ea is set.
            Object[] args = (Object[]) label;
            function = args.length > 1 ? args[1] : null;
        }
        if (function == null || !JSFunction.isJSFunction(function)) {
            throw new AssertionError("Must instrument calls to JSFunction objects");
        }
        String name = JSFunction.getName((DynamicObject) function);
        return apiName + ": " + ("".equals(name) ? "<unknown>" : name);
    }

    /** Returns the seconds elapsed since the current sampling window started. */
    public static double getSamplingTime() {
        long elapsedTime = System.nanoTime() - last;
        return elapsedTime / 1_000_000_000.0;
    }

    /**
     * Hook called when a binding call starts. For the outermost call, selects this binding's JNI
     * counters and starts its clock.
     */
    @TruffleBoundary
    public static void bindingCallBegin(String apiName, Object label) {
        String lbl = getLabel(apiName, label);
        if (callStack.isEmpty()) {
            currentJNICalls = jniCalls.computeIfAbsent(lbl, k -> new HashMap<>());
            currentJNIExecTimes = jniExecTimes.computeIfAbsent(lbl, k -> new HashMap<>());
            firstBoundaryCrossedAt = System.nanoTime();
        }
        callStack.push(lbl);
    }

    /**
     * Hook called when a binding call ends. For the outermost call, records its count and elapsed
     * time and, if the dump interval has passed, dumps all counters.
     *
     * @throws AssertionError if begin/end or JNI begin/end calls are unbalanced
     */
    @TruffleBoundary
    public static void bindingCallEnd() {
        if (callStack.isEmpty()) {
            throw new AssertionError("Broken instrumentation! (binding call ended without a matching begin)");
        }
        String lbl = callStack.pop();
        if (!callStack.isEmpty()) {
            return; // nested binding call: only the outermost one is measured
        }
        if (jniMethodCallStack != 0) {
            throw new AssertionError("Broken instrumentation! (not all JNI method calls have returned: " + jniMethodCallStack + ")");
        }
        long elapsedTime = System.nanoTime() - firstBoundaryCrossedAt;
        bindingCalls.computeIfAbsent(lbl, k -> new PerfCounter()).increment();
        bindingExecTimes.computeIfAbsent(lbl, k -> new PerfCounter()).increment(elapsedTime);
        // Re-register both per-binding JNI maps (not just one) so a dump that cleared them while
        // this call was in flight cannot leave them out of sync with bindingCalls.
        jniExecTimes.put(lbl, currentJNIExecTimes);
        jniCalls.put(lbl, currentJNICalls);
        if (DumpEvery > 0 && getSamplingTime() > DumpEvery) {
            dumpCounters();
        }
    }

    /** Hook called when a Java method is entered from native code; counts outermost calls only. */
    @TruffleBoundary
    public static void jniCallBegin(String lbl) {
        if (!callStack.isEmpty() && jniMethodCallStack++ == 0) {
            currentJNICalls.computeIfAbsent(lbl, k -> new PerfCounter()).increment();
            lastJniCallBegin = System.nanoTime();
        }
    }

    /** Hook called when a Java method entered from native code returns; times outermost calls only. */
    @TruffleBoundary
    public static void jniCallEnd(String lbl) {
        if (!callStack.isEmpty() && --jniMethodCallStack == 0) {
            long elapsedTime = System.nanoTime() - lastJniCallBegin;
            currentJNIExecTimes.computeIfAbsent(lbl, k -> new PerfCounter()).increment(elapsedTime);
        }
    }

    /**
     * Prints all counters for the current sampling window to {@code System.out}, then starts a new
     * window and clears every counter.
     */
    @TruffleBoundary
    public static void dumpCounters() {
        double window = getSamplingTime();
        last = System.nanoTime();
        System.out.println("\n=== Sampling interval: " + window + " seconds ===");
        Map<String, PerfCounter> sortedTimes = sortedByValueDescending(bindingExecTimes);
        printBindingTimes(sortedTimes, window);
        printJniBreakdown(sortedTimes, window);
        jniExecTimes.clear();
        jniCalls.clear();
        bindingCalls.clear();
        bindingExecTimes.clear();
    }

    /** Returns an insertion-ordered copy of {@code counters}, iterating from largest to smallest value. */
    private static Map<String, PerfCounter> sortedByValueDescending(Map<String, PerfCounter> counters) {
        return counters.entrySet().stream().sorted(Collections.reverseOrder(Entry.comparingByValue())).collect(
                        Collectors.toMap(Entry::getKey, Entry::getValue, (e1, e2) -> e1, LinkedHashMap::new));
    }

    /** Converts nanoseconds to milliseconds. */
    private static double nanosToMillis(long nanos) {
        return nanos / 1000000.0;
    }

    /** Returns {@code millis} as a percentage of a sampling window given in seconds. */
    private static double percentOfWindow(double millis, double windowSeconds) {
        return (millis / (windowSeconds * 1000)) * 100;
    }

    /** Prints one line per binding: share of the window, total time and number of calls. */
    private static void printBindingTimes(Map<String, PerfCounter> sortedTimes, double window) {
        System.out.println("\n=== Time spent in node.js native calls ===");
        for (Entry<String, PerfCounter> entry : sortedTimes.entrySet()) {
            double time = nanosToMillis(entry.getValue().longValue());
            String line = String.format("[%6.2f %%] %-80s |time %10.3f ms |#calls %7d (JS->Cpp)", percentOfWindow(time, window), entry.getKey(), time,
                            countOf(bindingCalls, entry.getKey()));
            System.out.println(line);
        }
    }

    /**
     * Prints, per binding, the JNI methods called during it with counts, times and average calls
     * per binding call, followed by per-binding totals. Honors {@link #DumpOnlyTopMethods}.
     */
    private static void printJniBreakdown(Map<String, PerfCounter> sortedTimes, double window) {
        System.out.println("\n=== Breakdown of Java methods executed during native calls (presumably JNI calls) ===");
        int dumped = 0;
        for (Entry<String, PerfCounter> entry : sortedTimes.entrySet()) {
            String binding = entry.getKey();
            double nativeTime = nanosToMillis(entry.getValue().longValue());
            long nativeCalls = countOf(bindingCalls, binding);
            System.out.println(String.format("[%6.2f %%] %-80s ", percentOfWindow(nativeTime, window), binding));

            // Tolerate bindings whose JNI maps are missing (e.g. cleared by a concurrent dump).
            Map<String, PerfCounter> jniTime = jniExecTimes.getOrDefault(binding, Collections.<String, PerfCounter> emptyMap());
            Map<String, PerfCounter> sortedCalls = sortedByValueDescending(jniCalls.getOrDefault(binding, Collections.<String, PerfCounter> emptyMap()));
            double avgJniCallsPerNativeCall = 0;
            double totalTime = 0;
            for (Entry<String, PerfCounter> nested : sortedCalls.entrySet()) {
                double time = nanosToMillis(countOf(jniTime, nested.getKey()));
                double ratio = nested.getValue().longValue() / (double) nativeCalls;
                String nestedLine = String.format(" %-91s |#calls %7d |time %10.3f ms |jni calls avg ~%4.1f (Cpp->JS)", nested.getKey(), nested.getValue().longValue(),
                                time,
                                ratio);
                System.out.println(nestedLine);
                avgJniCallsPerNativeCall += ratio;
                totalTime += time;
            }
            System.out.println(String.format("\n %92s |total native time %10.3f ms", "", nativeTime));
            System.out.println(String.format(" %92s |total time in Java space (~) %10.3f ms", "", totalTime));
            System.out.println(String.format(" %92s |total native calls %7d ", "", nativeCalls));
            System.out.println(String.format(" %92s |avg JNI Java calls per native call (~) %4.1f \n", "", avgJniCallsPerNativeCall));
            if (DumpOnlyTopMethods > 0 && ++dumped == DumpOnlyTopMethods) {
                break;
            }
        }
    }
}