package com.oracle.svm.hosted.code;
import java.lang.reflect.Constructor;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.graalvm.nativeimage.ImageSingletons;
import org.graalvm.nativeimage.hosted.Feature;
import com.oracle.graal.pointsto.meta.AnalysisMethod;
import com.oracle.svm.core.annotate.AutomaticFeature;
import com.oracle.svm.core.annotate.RestrictHeapAccess;
import com.oracle.svm.core.annotate.RestrictHeapAccess.Access;
import com.oracle.svm.core.annotate.Uninterruptible;
import com.oracle.svm.core.heap.RestrictHeapAccessCallees;
import com.oracle.svm.core.util.VMError;
import com.oracle.svm.hosted.FeatureImpl.DuringAnalysisAccessImpl;
import com.oracle.svm.hosted.meta.HostedMethod;
import jdk.vm.ci.code.BytecodePosition;
import jdk.vm.ci.meta.ResolvedJavaMethod;
/**
 * Build-time implementation of {@link RestrictHeapAccessCallees}.
 *
 * <p>
 * {@link #aggregateMethods(Collection)} walks the call trees rooted at methods that carry a
 * restricting {@link RestrictHeapAccess} annotation or an {@link Uninterruptible} annotation and
 * records, for every method visited, the {@link Access} level that applies to it together with the
 * call edge that imposed it. The result can afterwards be queried per method.
 */
public class RestrictHeapAccessCalleesImpl implements RestrictHeapAccessCallees {

    /**
     * Restricted method to the record describing its restriction. Empty until
     * {@link #aggregateMethods(Collection)} replaces it with an unmodifiable map.
     */
    private Map<AnalysisMethod, RestrictionInfo> calleeToCallerMap;

    /** Constructors of {@link AssertionError} that the aggregation skips under NO_ALLOCATION. */
    private List<ResolvedJavaMethod> assertionErrorConstructorList;

    /** Set once {@link #aggregateMethods(Collection)} has run. */
    private boolean initialized;

    /** Creates an instance with an empty lookup table and no known assertion constructors. */
    public RestrictHeapAccessCalleesImpl() {
        this.calleeToCallerMap = Collections.emptyMap();
        this.assertionErrorConstructorList = Collections.emptyList();
        this.initialized = false;
    }

    /**
     * Installs the list of {@link AssertionError} constructors. Only the first non-empty list is
     * kept; later calls are ignored.
     *
     * @param resolvedConstructorList the constructors to install; a {@code null} argument is
     *            ignored so that repeated calls cannot fail on a missing list
     */
    public void setAssertionErrorConstructors(List<ResolvedJavaMethod> resolvedConstructorList) {
        // Never store null: the emptiness check below would throw on the next invocation.
        if (resolvedConstructorList != null && assertionErrorConstructorList.isEmpty()) {
            assertionErrorConstructorList = resolvedConstructorList;
        }
    }

    /**
     * Looks up the restriction recorded for a method.
     *
     * @param method an {@link AnalysisMethod} or a {@link HostedMethod}
     * @return the recorded restriction, or {@code null} if none was recorded for the method
     */
    public RestrictionInfo getRestrictionInfo(ResolvedJavaMethod method) {
        return calleeToCallerMap.get(methodToKey(method));
    }

    /**
     * Tells whether the method has a recorded restriction of {@link Access#NO_ALLOCATION} or
     * anything more restrictive than that.
     */
    @Override
    public boolean mustNotAllocate(ResolvedJavaMethod method) {
        RestrictionInfo info = getRestrictionInfo(method);
        if (info == null) {
            return false;
        }
        Access recorded = info.getAccess();
        return recorded == Access.NO_ALLOCATION || recorded.isMoreRestrictiveThan(Access.NO_ALLOCATION);
    }

    /** Returns the current lookup table; it is either the empty map or an unmodifiable view. */
    public Map<AnalysisMethod, RestrictionInfo> getCallerMap() {
        return calleeToCallerMap;
    }

    /**
     * Builds the lookup table. Every method in {@code methods} that has a
     * {@link RestrictHeapAccess} annotation with an access other than {@link Access#UNRESTRICTED},
     * or that is annotated {@link Uninterruptible}, is used as a root: each of its implementations
     * is walked with a {@link MethodAggregator}. The previous table is replaced by an unmodifiable
     * view of the aggregation.
     *
     * @param methods the candidate root methods
     */
    public void aggregateMethods(Collection<AnalysisMethod> methods) {
        assert !initialized : "RestrictHeapAccessCallees.aggregateMethods: Should only initialize once.";
        final Map<AnalysisMethod, RestrictionInfo> aggregation = new HashMap<>();
        final MethodAggregator visitor = new MethodAggregator(aggregation, assertionErrorConstructorList);
        final AnalysisMethodCalleeWalker walker = new AnalysisMethodCalleeWalker();
        for (AnalysisMethod method : methods) {
            final RestrictHeapAccess annotation = method.getAnnotation(RestrictHeapAccess.class);
            final boolean restrictsHeapAccess = annotation != null && annotation.access() != Access.UNRESTRICTED;
            if (restrictsHeapAccess || method.isAnnotationPresent(Uninterruptible.class)) {
                for (AnalysisMethod calleeImpl : method.getImplementations()) {
                    walker.walkMethod(calleeImpl, visitor);
                }
            }
        }
        calleeToCallerMap = Collections.unmodifiableMap(aggregation);
        initialized = true;
    }

    /**
     * Converts a method into the key type of the lookup table: an {@link AnalysisMethod} is used
     * as is, a {@link HostedMethod} is unwrapped. Any other type is a fatal error.
     */
    private static AnalysisMethod methodToKey(ResolvedJavaMethod method) {
        if (method instanceof AnalysisMethod) {
            return (AnalysisMethod) method;
        }
        if (method instanceof HostedMethod) {
            return ((HostedMethod) method).getWrapped();
        }
        throw VMError.shouldNotReachHere("RestrictHeapAccessCallees.methodToKey: ResolvedJavaMethod is neither an AnalysisMethod nor a HostedMethod: " + method);
    }

    /** Visitor that fills the callee-to-restriction map while the call tree is walked. */
    static class MethodAggregator extends AnalysisMethodCalleeWalker.CallPathVisitor {

        /** The map being filled; shared with the creator of this visitor. */
        private final Map<AnalysisMethod, RestrictionInfo> calleeToCallerMap;

        /** Constructors to skip under NO_ALLOCATION; may be {@code null}. */
        private final List<ResolvedJavaMethod> assertionErrorConstructorList;

        MethodAggregator(Map<AnalysisMethod, RestrictionInfo> calleeToCallerMap, List<ResolvedJavaMethod> assertionErrorConstructorList) {
            this.calleeToCallerMap = calleeToCallerMap;
            this.assertionErrorConstructorList = assertionErrorConstructorList;
        }

        /**
         * Determines the access restriction that applies to {@code callee} when reached from
         * {@code caller} and records it unless an entry that is at least as restrictive already
         * exists.
         *
         * @param callee the method being visited
         * @param caller the calling method, or {@code null} when {@code callee} is a root
         * @param invokePosition position of the invoke in the caller, may be {@code null}
         * @param depth unused here
         * @return {@code CONTINUE} after a new entry was recorded, {@code CUT} otherwise
         *         (presumably CUT stops the walker from descending below this method -- confirm
         *         against AnalysisMethodCalleeWalker)
         */
        @Override
        public VisitResult visitMethod(AnalysisMethod callee, AnalysisMethod caller, BytecodePosition invokePosition, int depth) {
            /* Start from the restriction the callee declares itself. */
            Access access = Access.UNRESTRICTED;
            boolean overridesCallers = false;
            boolean fromUninterruptible = false;
            if (callee.isAnnotationPresent(Uninterruptible.class)) {
                access = Access.NO_ALLOCATION;
                fromUninterruptible = true;
            }
            /* An explicit @RestrictHeapAccess takes precedence over @Uninterruptible. */
            RestrictHeapAccess annotation = callee.getAnnotation(RestrictHeapAccess.class);
            if (annotation != null) {
                access = annotation.access();
                overridesCallers = annotation.overridesCallers();
                fromUninterruptible = false;
            }

            if (overridesCallers || caller == null) {
                /* Only the callee's own declaration counts; nothing to record if unrestricted. */
                if (access == Access.UNRESTRICTED) {
                    return VisitResult.CUT;
                }
            } else {
                /*
                 * NOTE(review): assumes the caller already has an entry, i.e. that the walker only
                 * descends into callees after this visitor returned CONTINUE for the caller --
                 * confirm against AnalysisMethodCalleeWalker.
                 */
                RestrictionInfo callerInfo = calleeToCallerMap.get(caller);
                Access callerAccess = callerInfo.getAccess();
                if (callerAccess.equals(access) || callerAccess.isMoreRestrictiveThan(access)) {
                    if (callerInfo.isFromUninterruptible()) {
                        /*
                         * A restriction that stems from @Uninterruptible is only handed down when
                         * the caller itself is annotated and demands it of its callees.
                         */
                        Uninterruptible callerUninterruptible = caller.getAnnotation(Uninterruptible.class);
                        if (callerUninterruptible != null && callerUninterruptible.calleeMustBe()) {
                            access = callerAccess;
                            fromUninterruptible = true;
                        } else if (access == Access.UNRESTRICTED) {
                            return VisitResult.CUT;
                        }
                    } else {
                        access = callerAccess;
                        fromUninterruptible = false;
                    }
                }
            }

            /* Ignore the known AssertionError constructors under NO_ALLOCATION. */
            if (access == Access.NO_ALLOCATION && assertionErrorConstructorList != null && assertionErrorConstructorList.contains(callee)) {
                return VisitResult.CUT;
            }

            /* Keep an existing entry unless the new restriction is strictly more restrictive. */
            RestrictionInfo restrictionInfo = calleeToCallerMap.get(callee);
            if (restrictionInfo != null && !access.isMoreRestrictiveThan(restrictionInfo.getAccess())) {
                return VisitResult.CUT;
            }

            StackTraceElement callerStackTraceElement = (invokePosition != null) ? invokePosition.getMethod().asStackTraceElement(invokePosition.getBCI()) : null;
            restrictionInfo = new RestrictionInfo(access, caller, callerStackTraceElement, callee, fromUninterruptible);
            calleeToCallerMap.put(callee, restrictionInfo);
            return VisitResult.CONTINUE;
        }
    }

    /** Immutable record of the restriction that applies to a method and where it came from. */
    public static class RestrictionInfo {

        private final RestrictHeapAccess.Access access;
        private final AnalysisMethod caller;
        private final StackTraceElement invocationStackTraceElement;
        private final AnalysisMethod method;
        private final boolean fromUninterruptible;

        RestrictionInfo(Access access, AnalysisMethod caller, StackTraceElement stackTraceElement, AnalysisMethod method, boolean fromUninterruptible) {
            this.access = access;
            this.caller = caller;
            this.invocationStackTraceElement = stackTraceElement;
            this.method = method;
            this.fromUninterruptible = fromUninterruptible;
        }

        /** The access restriction that applies to {@link #getMethod()}. */
        public Access getAccess() {
            return access;
        }

        /** The caller through which the method was reached, or {@code null} for a root. */
        public AnalysisMethod getCaller() {
            return caller;
        }

        /** The invoke location in the caller, or {@code null} if no position was available. */
        public StackTraceElement getInvocationStackTraceElement() {
            return invocationStackTraceElement;
        }

        /** The restricted method itself. */
        public AnalysisMethod getMethod() {
            return method;
        }

        /** Whether the restriction stems from an {@link Uninterruptible} annotation. */
        public boolean isFromUninterruptible() {
            return fromUninterruptible;
        }
    }
}
/**
 * Feature that registers the {@link RestrictHeapAccessCalleesImpl} singleton and supplies it with
 * the resolved {@link AssertionError} constructors during analysis.
 */
@AutomaticFeature
class RestrictHeapAccessCalleesFeature implements Feature {

    /**
     * Parameter lists of the {@link AssertionError} constructors that are looked up, in the order
     * in which they are placed into the resulting list.
     */
    private static final Class<?>[][] ASSERTION_ERROR_CONSTRUCTOR_SIGNATURES = {
                    {},
                    {boolean.class},
                    {char.class},
                    {int.class},
                    {long.class},
                    {float.class},
                    {double.class},
                    {Object.class},
                    {String.class, Throwable.class},
    };

    /** Registers the implementation under the {@link RestrictHeapAccessCallees} key. */
    @Override
    public void afterRegistration(AfterRegistrationAccess access) {
        ImageSingletons.add(RestrictHeapAccessCallees.class, new RestrictHeapAccessCalleesImpl());
    }

    /** Resolves the {@link AssertionError} constructors and hands them to the singleton. */
    @Override
    public void duringAnalysis(DuringAnalysisAccess access) {
        List<ResolvedJavaMethod> assertionErrorConstructorList = initializeAssertionErrorConstructors(access);
        ((RestrictHeapAccessCalleesImpl) ImageSingletons.lookup(RestrictHeapAccessCallees.class)).setAssertionErrorConstructors(assertionErrorConstructorList);
    }

    /**
     * Resolves every constructor listed in {@link #ASSERTION_ERROR_CONSTRUCTOR_SIGNATURES}.
     *
     * @param access the analysis access used to reach the meta access
     * @return a new mutable list holding the resolved constructors in table order
     */
    private static List<ResolvedJavaMethod> initializeAssertionErrorConstructors(DuringAnalysisAccess access) {
        final List<ResolvedJavaMethod> result = new ArrayList<>(ASSERTION_ERROR_CONSTRUCTOR_SIGNATURES.length);
        for (Class<?>[] signature : ASSERTION_ERROR_CONSTRUCTOR_SIGNATURES) {
            result.add(findAssertionConstructor(access, signature));
        }
        return result;
    }

    /**
     * Finds the public {@link AssertionError} constructor with the given parameter types via
     * reflection and resolves it through the meta access of {@code access}.
     *
     * @param access must be a {@link DuringAnalysisAccessImpl}
     * @param parameterTypes the parameter types of the wanted constructor
     * @return the resolved constructor
     */
    private static ResolvedJavaMethod findAssertionConstructor(DuringAnalysisAccess access, Class<?>... parameterTypes) {
        try {
            final Constructor<AssertionError> reflectiveConstructor = AssertionError.class.getConstructor(parameterTypes);
            return ((DuringAnalysisAccessImpl) access).getMetaAccess().lookupJavaMethod(reflectiveConstructor);
        } catch (NoSuchMethodException | SecurityException ex) {
            /* Keep the original exception as the cause so its stack trace is not lost. */
            throw VMError.shouldNotReachHere("Should have found AssertionError constructor: " + ex, ex);
        }
    }
}