package com.oracle.svm.diagnosticsagent;
import static com.oracle.svm.jni.JNIObjectHandles.nullHandle;
import static com.oracle.svm.jvmtiagentbase.Support.check;
import static com.oracle.svm.jvmtiagentbase.Support.jvmtiFunctions;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import org.graalvm.compiler.serviceprovider.JavaVersionUtil;
import org.graalvm.nativeimage.StackValue;
import org.graalvm.nativeimage.UnmanagedMemory;
import org.graalvm.nativeimage.c.function.CEntryPoint;
import org.graalvm.nativeimage.c.function.CEntryPointLiteral;
import org.graalvm.nativeimage.c.function.CFunctionPointer;
import org.graalvm.nativeimage.c.struct.SizeOf;
import org.graalvm.nativeimage.c.type.CIntPointer;
import org.graalvm.nativeimage.c.type.CTypeConversion;
import org.graalvm.nativeimage.c.type.WordPointer;
import org.graalvm.nativeimage.hosted.Feature;
import org.graalvm.word.WordFactory;
import com.oracle.svm.core.c.function.CEntryPointOptions;
import com.oracle.svm.core.util.VMError;
import com.oracle.svm.hosted.agent.TracingAdvisor;
import com.oracle.svm.jni.nativeapi.JNIEnvironment;
import com.oracle.svm.jni.nativeapi.JNIJavaVM;
import com.oracle.svm.jni.nativeapi.JNIMethodId;
import com.oracle.svm.jni.nativeapi.JNIObjectHandle;
import com.oracle.svm.jvmtiagentbase.AgentIsolate;
import com.oracle.svm.jvmtiagentbase.JNIHandleSet;
import com.oracle.svm.jvmtiagentbase.JvmtiAgentBase;
import com.oracle.svm.jvmtiagentbase.Support;
import com.oracle.svm.jvmtiagentbase.jvmti.JvmtiCapabilities;
import com.oracle.svm.jvmtiagentbase.jvmti.JvmtiEnv;
import com.oracle.svm.jvmtiagentbase.jvmti.JvmtiEnv11;
import com.oracle.svm.jvmtiagentbase.jvmti.JvmtiEvent;
import com.oracle.svm.jvmtiagentbase.jvmti.JvmtiEventCallbacks;
import com.oracle.svm.jvmtiagentbase.jvmti.JvmtiEventMode;
import com.oracle.svm.jvmtiagentbase.jvmti.JvmtiInterface;
/**
 * JVMTI agent that traces class initialization and object instantiation for the classes selected by a
 * {@link TracingAdvisor}.
 * <p>
 * Class initialization is traced by setting a breakpoint at location 0 of a class's {@code <clinit>} method; object
 * instantiation is traced by setting a breakpoint at location 0 of each of the class's {@code <init>} methods. When such
 * a breakpoint is hit, the agent builds a stack trace array and passes it, together with the class or the newly created
 * object, to the static reporting methods whose IDs are provided by the agent's handle set.
 */
public class NativeImageDiagnosticsAgent extends JvmtiAgentBase<NativeImageDiagnosticsAgentJNIHandleSet> {

    private static final CEntryPointLiteral<CFunctionPointer> ON_CLASS_PREPARE = CEntryPointLiteral.create(NativeImageDiagnosticsAgent.class, "onClassPrepare",
                    JvmtiEnv.class, JNIEnvironment.class, JNIObjectHandle.class, JNIObjectHandle.class);
    private static final CEntryPointLiteral<CFunctionPointer> ON_BREAKPOINT = CEntryPointLiteral.create(NativeImageDiagnosticsAgent.class, "onBreakpoint",
                    JvmtiEnv.class, JNIEnvironment.class, JNIObjectHandle.class, JNIMethodId.class, long.class);

    /** Maps the raw value of a traced {@code <clinit>} method ID to (a global reference to) its declaring class. */
    private final Map<Long, ClassHandleHolder> clinitClassMap = new ConcurrentHashMap<>();
    /** Maps the raw value of a traced {@code <init>} method ID to (a global reference to) its declaring class. */
    private final Map<Long, ClassHandleHolder> initClassMap = new ConcurrentHashMap<>();

    /** Decides which classes are traced; created from the agent options in {@link #onLoadCallback}. */
    private TracingAdvisor advisor;

    /*
     * Object wrappers for word-typed values. NOTE(review): presumably required because word types such as
     * JNIObjectHandle and JNIMethodId cannot be stored directly in Java collections.
     */
    private static final class ClassHandleHolder {
        final JNIObjectHandle clazz;

        ClassHandleHolder(JNIObjectHandle clazz) {
            this.clazz = clazz;
        }
    }

    private static final class MethodIdHolder {
        final JNIMethodId methodId;

        MethodIdHolder(JNIMethodId methodId) {
            this.methodId = methodId;
        }
    }

    @Override
    protected JNIHandleSet constructJavaHandles(JNIEnvironment env) {
        return new NativeImageDiagnosticsAgentJNIHandleSet(env);
    }

    /**
     * Creates the tracing advisor from the agent options, acquires the required capabilities, installs the event
     * callbacks and enables breakpoint events. Class-prepare events are only enabled once the VM is initialized.
     */
    @Override
    protected int onLoadCallback(JNIJavaVM vm, JvmtiEnv jvmti, JvmtiEventCallbacks callbacks, String options) {
        advisor = new TracingAdvisor(options);
        enableCapabilities(jvmti);
        callbacks.setClassPrepare(ON_CLASS_PREPARE.getFunctionPointer());
        callbacks.setBreakpoint(ON_BREAKPOINT.getFunctionPointer());
        check(jvmti.getFunctions().SetEventNotificationMode().invoke(jvmti, JvmtiEventMode.JVMTI_ENABLE, JvmtiEvent.JVMTI_EVENT_BREAKPOINT, nullHandle()));
        return 0;
    }

    /**
     * Prepares the reporting support and starts tracing: newly prepared classes are handled by the class-prepare
     * callback, classes that are already loaded are scanned once here.
     */
    @Override
    protected void onVMInitCallback(JvmtiEnv jvmti, JNIEnvironment jni, JNIObjectHandle thread) {
        if (JavaVersionUtil.JAVA_SPEC > 8) {
            // NOTE(review): presumably so that code in any module can access the reporting support package.
            openInstrumentationModuleToAllOtherModules((JvmtiEnv11) jvmti, jni);
        }
        handles().initializeTrackingSupportHandles(jni);
        /*
         * Class-prepare events are enabled before the loaded classes are scanned so that no class slips through in
         * between. As a consequence a class can be handled by both paths; setConstructorBreakpointsForClass tolerates
         * that.
         */
        check(jvmti.getFunctions().SetEventNotificationMode().invoke(jvmti, JvmtiEventMode.JVMTI_ENABLE, JvmtiEvent.JVMTI_EVENT_CLASS_PREPARE, nullHandle()));
        setConstructorBreakpointsForLoadedClasses(jvmti, jni);
    }

    /** Sets constructor breakpoints on every currently loaded class for which object instantiation is traced. */
    private void setConstructorBreakpointsForLoadedClasses(JvmtiEnv jvmti, JNIEnvironment jni) {
        CIntPointer classCountPtr = StackValue.get(CIntPointer.class);
        WordPointer classesPtr = StackValue.get(WordPointer.class);
        check(jvmtiFunctions().GetLoadedClasses().invoke(jvmti, classCountPtr, classesPtr));
        WordPointer classesArray = classesPtr.read();
        int classCount = classCountPtr.read();
        for (int i = 0; i < classCount; ++i) {
            JNIObjectHandle clazz = classesArray.read(i);
            String className = Support.getClassNameOrNull(jni, clazz);
            // Same null guard as onClassPrepareCallback: classes without a name cannot be matched by the advisor.
            if (className != null && advisor.shouldTraceObjectInstantiation(className)) {
                setConstructorBreakpointsForClass(jvmti, jni, clazz, className);
            }
        }
        // The array was allocated by JVMTI and must be released through it.
        check(jvmtiFunctions().Deallocate().invoke(jvmti, classesArray));
    }

    /**
     * Handles a newly prepared class: sets a {@code <clinit>} breakpoint if its initialization is traced and
     * constructor breakpoints if its instantiation is traced.
     */
    private void onClassPrepareCallback(JvmtiEnv jvmti, JNIEnvironment jni, JNIObjectHandle clazz) {
        String className = Support.getClassNameOrNull(jni, clazz);
        if (className == null) {
            return;
        }
        if (advisor.shouldTraceClassInitialization(className)) {
            JNIMethodId clinitMethodId = getClassClinitMethodIdOrNull(jvmti, clazz);
            if (clinitMethodId.isNonNull()) {
                JNIObjectHandle klass = handles().newTrackedGlobalRef(jni, clazz);
                clinitClassMap.put(clinitMethodId.rawValue(), new ClassHandleHolder(klass));
                check(jvmti.getFunctions().SetBreakpoint().invoke(jvmti, clinitMethodId, 0L));
            } else {
                System.err.println("Trace class initialization requested for " + className + " but the class has not been instrumented with <clinit>.");
            }
        }
        if (advisor.shouldTraceObjectInstantiation(className)) {
            setConstructorBreakpointsForClass(jvmti, jni, clazz, className);
        }
    }

    /**
     * Sets a breakpoint at location 0 of every {@code <init>} method of {@code clazz} and records the method IDs in
     * {@link #initClassMap}. Constructors that are already tracked are skipped, so calling this twice for the same class
     * does not set duplicate breakpoints.
     */
    private void setConstructorBreakpointsForClass(JvmtiEnv jvmti, JNIEnvironment jni, JNIObjectHandle clazz, String className) {
        List<MethodIdHolder> initMethodIds = getClassMethodIdsWithName(jvmti, clazz, "<init>");
        if (initMethodIds.isEmpty()) {
            System.err.println("Trace object instantiation requested for " + className + " but the class has no constructors.");
            return;
        }
        JNIObjectHandle klass = handles().newTrackedGlobalRef(jni, clazz);
        ClassHandleHolder classHandleHolder = new ClassHandleHolder(klass);
        for (MethodIdHolder holder : initMethodIds) {
            // SetBreakpoint fails with JVMTI_ERROR_DUPLICATE for an already-set breakpoint, so only the first
            // registration of a constructor (atomic via ConcurrentHashMap.putIfAbsent) sets it.
            if (initClassMap.putIfAbsent(holder.methodId.rawValue(), classHandleHolder) == null) {
                check(jvmti.getFunctions().SetBreakpoint().invoke(jvmti, holder.methodId, 0L));
            }
        }
    }

    /** Dispatches a breakpoint hit to the class-initialization or object-instantiation handler. */
    private void onBreakpointCallback(JvmtiEnv jvmti, JNIEnvironment jni, JNIObjectHandle thread, JNIMethodId method) {
        long methodKey = method.rawValue();
        ClassHandleHolder clinitHolder = clinitClassMap.get(methodKey);
        if (clinitHolder != null) {
            handleClinitBreakpoint(jvmti, jni, clinitHolder.clazz);
        } else if (initClassMap.containsKey(methodKey)) {
            handleInitBreakpoint(jvmti, jni, thread);
        } else {
            throw VMError.shouldNotReachHere(
                            "Breakpoint hit for a method that isn't tracked in the diagnostics agent. (For developers: have you set a breakpoint in a method that isn't <clinit> or <init>)");
        }
    }

    /** Reports the initialization of {@code clazz} together with the current stack trace. */
    private void handleClinitBreakpoint(JvmtiEnv jvmti, JNIEnvironment jni, JNIObjectHandle clazz) {
        JavaStackTraceCreator stackTraceCreator = new JavaStackTraceCreator(jvmti, jni);
        JNIObjectHandle threadStackTrace = stackTraceCreator.getStackTraceArray();
        reportClassInitialized(jni, clazz, threadStackTrace);
    }

    /** Reports the object under construction (the constructor's {@code this}) together with the current stack trace. */
    private void handleInitBreakpoint(JvmtiEnv jvmti, JNIEnvironment jni, JNIObjectHandle thread) {
        WordPointer thisPtr = StackValue.get(WordPointer.class);
        // Depth 0 is the frame that hit the breakpoint, i.e. the constructor.
        check(jvmti.getFunctions().GetLocalInstance().invoke(jvmti, thread, 0, thisPtr));
        JNIObjectHandle thisHandle = thisPtr.read();
        ObjectInstantiationTraceCreator traceCreator = new ObjectInstantiationTraceCreator(jvmti, jni);
        JNIObjectHandle threadStackTrace = traceCreator.getStackTraceArray();
        // NOTE(review): presumably prevents reporting objects created by the reporting code itself.
        if (!traceCreator.encounteredObjectInstantiatedReportCall()) {
            reportObjectInstantiated(jni, thisHandle, threadStackTrace);
        }
    }

    /**
     * Acquires the capabilities needed for breakpoints and local-variable access (mandatory), then tries to also
     * acquire line-number and source-file-name capabilities (best effort).
     */
    private static void enableCapabilities(JvmtiEnv jvmti) {
        JvmtiCapabilities capabilities = UnmanagedMemory.calloc(SizeOf.get(JvmtiCapabilities.class));
        try {
            check(jvmti.getFunctions().GetCapabilities().invoke(jvmti, capabilities));
            capabilities.setCanGenerateBreakpointEvents(1);
            capabilities.setCanAccessLocalVariables(1);
            check(jvmti.getFunctions().AddCapabilities().invoke(jvmti, capabilities));
            capabilities.setCanGetLineNumbers(1);
            capabilities.setCanGetSourceFileName(1);
            // Result deliberately not checked: the agent keeps working without these optional capabilities.
            jvmti.getFunctions().AddCapabilities().invoke(jvmti, capabilities);
        } finally {
            UnmanagedMemory.free(capabilities);
        }
    }

    /**
     * Opens the package {@code org.graalvm.nativeimage.impl.clinit} of module {@code org.graalvm.sdk} to every module
     * known to the VM.
     */
    private void openInstrumentationModuleToAllOtherModules(JvmtiEnv11 jvmti, JNIEnvironment jni) {
        JNIObjectHandle moduleClass = handles().findClass(jni, "java/lang/Module");
        JNIMethodId moduleGetName = handles().getMethodId(jni, moduleClass, "getName", "()Ljava/lang/String;", false);
        try (CTypeConversion.CCharPointerHolder packageName = Support.toCString("org.graalvm.nativeimage.impl.clinit")) {
            CIntPointer moduleCountPtr = StackValue.get(CIntPointer.class);
            WordPointer modulesPtr = StackValue.get(WordPointer.class);
            check(jvmti.getFunctions().GetAllModules().invoke(jvmti, moduleCountPtr, modulesPtr));
            int moduleCount = moduleCountPtr.read();
            WordPointer modulesArrayPtr = modulesPtr.read();

            // Locate the module that provides the reporting support.
            JNIObjectHandle clinitTrackingSupportModule = nullHandle();
            for (int i = 0; i < moduleCount; ++i) {
                JNIObjectHandle module = modulesArrayPtr.read(i);
                VMError.guarantee(module.notEqual(nullHandle()), "Unexpected null handle while iterating over modules.");
                JNIObjectHandle moduleName = Support.callObjectMethod(jni, module, moduleGetName);
                String name = Support.fromJniString(jni, moduleName);
                if ("org.graalvm.sdk".equals(name)) {
                    clinitTrackingSupportModule = module;
                    break;
                }
            }
            VMError.guarantee(clinitTrackingSupportModule.notEqual(nullHandle()), "The module name that provides clinit reporting support has changed.");

            for (int i = 0; i < moduleCount; ++i) {
                JNIObjectHandle module = modulesArrayPtr.read(i);
                check(jvmti.getFunctions().AddModuleOpens().invoke(jvmti, clinitTrackingSupportModule, packageName.get(), module));
            }
            // The array was allocated by JVMTI and must be released through it.
            check(jvmti.getFunctions().Deallocate().invoke(jvmti, modulesArrayPtr));
        }
    }

    /** Returns the IDs of all methods declared by {@code clazz} whose name equals {@code methodName}. */
    private static List<MethodIdHolder> getClassMethodIdsWithName(JvmtiEnv jvmti, JNIObjectHandle clazz, String methodName) {
        List<MethodIdHolder> methodIds = new ArrayList<>();
        CIntPointer methodCountPtr = StackValue.get(CIntPointer.class);
        WordPointer methodsPtr = StackValue.get(WordPointer.class);
        check(jvmti.getFunctions().GetClassMethods().invoke(jvmti, clazz, methodCountPtr, methodsPtr));
        int methodCount = methodCountPtr.read();
        WordPointer methodsArray = methodsPtr.read();
        for (int i = 0; i < methodCount; ++i) {
            JNIMethodId methodId = methodsArray.read(i);
            String currentMethodName = Support.getMethodNameOr(methodId, "");
            if (currentMethodName.equals(methodName)) {
                methodIds.add(new MethodIdHolder(methodId));
            }
        }
        check(jvmti.getFunctions().Deallocate().invoke(jvmti, methodsArray));
        return methodIds;
    }

    /** Returns the ID of the {@code <clinit>} method of {@code clazz}, or a null pointer if it has none. */
    private static JNIMethodId getClassClinitMethodIdOrNull(JvmtiEnv jvmti, JNIObjectHandle clazz) {
        List<MethodIdHolder> classMethodIdsWithName = getClassMethodIdsWithName(jvmti, clazz, "<clinit>");
        VMError.guarantee(classMethodIdsWithName.size() < 2, "A class must not declare more than one <clinit> method.");
        return classMethodIdsWithName.size() == 1 ? classMethodIdsWithName.get(0).methodId : WordFactory.nullPointer();
    }

    private void reportClassInitialized(JNIEnvironment jni, JNIObjectHandle clazz, JNIObjectHandle stackTrace) {
        Support.callStaticVoidMethodLL(jni, handles().getClassInitializationTrackingClassHandle(), handles().getReportClassInitializedMethodId(), clazz, stackTrace);
    }

    private void reportObjectInstantiated(JNIEnvironment jni, JNIObjectHandle object, JNIObjectHandle stackTrace) {
        Support.callStaticVoidMethodLL(jni, handles().getClassInitializationTrackingClassHandle(), handles().getReportObjectInstantiatedMethodId(), object, stackTrace);
    }

    /** Native entry point for the JVMTI ClassPrepare event. */
    @CEntryPoint
    @CEntryPointOptions(prologue = AgentIsolate.Prologue.class)
    @SuppressWarnings("unused")
    private static void onClassPrepare(JvmtiEnv jvmti, JNIEnvironment jni,
                    JNIObjectHandle thread, JNIObjectHandle clazz) {
        NativeImageDiagnosticsAgent agent = singleton();
        agent.onClassPrepareCallback(jvmti, jni, clazz);
    }

    /** Native entry point for the JVMTI Breakpoint event. */
    @CEntryPoint
    @CEntryPointOptions(prologue = AgentIsolate.Prologue.class)
    @SuppressWarnings("unused")
    private static void onBreakpoint(JvmtiEnv jvmti, JNIEnvironment jni, JNIObjectHandle thread, JNIMethodId method, long location) {
        NativeImageDiagnosticsAgent agent = singleton();
        agent.onBreakpointCallback(jvmti, jni, thread, method);
    }

    @Override
    protected int onUnloadCallback(JNIJavaVM vm) {
        return 0;
    }

    @Override
    protected void onVMStartCallback(JvmtiEnv jvmti, JNIEnvironment jni) {
    }

    @Override
    protected void onVMDeathCallback(JvmtiEnv jvmti, JNIEnvironment jni) {
    }

    /** JVMTI 9 is requested on JDK 9+ (the module functions are used there); JVMTI 1.2 otherwise. */
    @Override
    protected int getRequiredJvmtiVersion() {
        if (JavaVersionUtil.JAVA_SPEC > 8) {
            return JvmtiInterface.JVMTI_VERSION_9;
        }
        return JvmtiInterface.JVMTI_VERSION_1_2;
    }

    /** Registers this agent with {@link JvmtiAgentBase} when the image is built. */
    public static class RegistrationFeature implements Feature {
        @Override
        public void afterRegistration(AfterRegistrationAccess access) {
            JvmtiAgentBase.registerAgent(new NativeImageDiagnosticsAgent());
        }
    }
}