package com.oracle.svm.hosted;
import java.lang.reflect.Executable;
import java.lang.reflect.Field;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.IdentityHashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.function.BiConsumer;
import java.util.function.Consumer;
import org.graalvm.nativeimage.ImageSingletons;
import org.graalvm.nativeimage.hosted.Feature;
import com.oracle.graal.pointsto.meta.AnalysisField;
import com.oracle.graal.pointsto.meta.AnalysisMetaAccess;
import com.oracle.graal.pointsto.meta.AnalysisMethod;
import com.oracle.graal.pointsto.meta.AnalysisType;
import com.oracle.svm.core.annotate.AutomaticFeature;
import com.oracle.svm.core.util.UserError;
import com.oracle.svm.core.util.VMError;
import com.oracle.svm.hosted.FeatureImpl.BeforeAnalysisAccessImpl;
import com.oracle.svm.hosted.FeatureImpl.DuringAnalysisAccessImpl;
/**
 * Feature that stores reachability callbacks and invokes them from {@link #duringAnalysis}.
 *
 * Two kinds of callbacks are supported:
 * <ul>
 * <li>{@link Consumer} callbacks run once, as soon as any one of their triggers is found to be
 * triggered, and are then dropped from the set of active handlers.</li>
 * <li>{@link BiConsumer} callbacks stay active and are invoked once for every reachable subtype
 * (for type triggers) or reachable method override (for method triggers) that has not been reported
 * to them before.</li>
 * </ul>
 *
 * Callbacks are tracked by object identity, never by {@code equals()}/{@code hashCode()}.
 */
@AutomaticFeature
public class ReachabilityHandlerFeature implements Feature {

    /** Callbacks that can still fire, mapped to their analysis-level triggers. Keyed by identity. */
    private final IdentityHashMap<Object, Set<Object>> activeHandlers = new IdentityHashMap<>();

    /**
     * Callbacks that have been invoked or processed at least once. {@link Consumer} callbacks map
     * to {@code null}; {@link BiConsumer} callbacks map each trigger to the set of elements that
     * were already reported for it. Keyed by identity.
     */
    private final IdentityHashMap<Object, Map<Object, Set<Object>>> triggeredHandlers = new IdentityHashMap<>();

    /** Returns the instance registered in {@link ImageSingletons}. */
    public static ReachabilityHandlerFeature singleton() {
        return ImageSingletons.lookup(ReachabilityHandlerFeature.class);
    }

    /**
     * Registers {@code callback} to be invoked for every reachable override of {@code baseMethod}
     * that has not yet been reported to it.
     *
     * @param a the analysis access used to resolve {@code baseMethod}
     * @param callback receives the analysis access and the reachable overriding method
     * @param baseMethod the method whose overrides are observed
     */
    public void registerMethodOverrideReachabilityHandler(BeforeAnalysisAccessImpl a, BiConsumer<DuringAnalysisAccess, Executable> callback, Executable baseMethod) {
        registerReachabilityHandler(a, callback, new Executable[]{baseMethod}, false);
    }

    /**
     * Registers {@code callback} to be invoked for every reachable subtype of {@code baseClass}
     * that has not yet been reported to it.
     *
     * @param a the analysis access; must be a {@link BeforeAnalysisAccessImpl}
     * @param callback receives the analysis access and the reachable subtype
     * @param baseClass the class whose subtypes are observed
     */
    public void registerSubtypeReachabilityHandler(BeforeAnalysisAccess a, BiConsumer<DuringAnalysisAccess, Class<?>> callback, Class<?> baseClass) {
        registerReachabilityHandler(a, callback, new Class<?>[]{baseClass}, false);
    }

    /**
     * Registers {@code callback} to be invoked once, using the class initializer of {@code clazz}
     * as the trigger instead of the type itself.
     *
     * @param a the analysis access; must be a {@link BeforeAnalysisAccessImpl}
     * @param callback receives the analysis access when the trigger fires
     * @param clazz the class whose class initializer is observed
     */
    public void registerClassInitializerReachabilityHandler(BeforeAnalysisAccess a, Consumer<DuringAnalysisAccess> callback, Class<?> clazz) {
        registerReachabilityHandler(a, callback, new Class<?>[]{clazz}, true);
    }

    /**
     * Registers {@code callback} to be invoked once, as soon as any element of {@code triggers} is
     * triggered.
     *
     * @param a the analysis access; must be a {@link BeforeAnalysisAccessImpl}
     * @param callback receives the analysis access when a trigger fires
     * @param triggers elements of type {@link Class}, {@link Field}, or {@link Executable}
     */
    public void registerReachabilityHandler(BeforeAnalysisAccess a, Consumer<DuringAnalysisAccess> callback, Object[] triggers) {
        registerReachabilityHandler(a, callback, triggers, false);
    }

    /**
     * Shared registration logic: converts reflection-level triggers to analysis-level elements and
     * adds them to the trigger set of {@code callback}.
     *
     * @param a the analysis access; must be a {@link BeforeAnalysisAccessImpl}
     * @param callback a {@link Consumer} or {@link BiConsumer}; tracked by identity
     * @param triggers elements of type {@link Class}, {@link Field}, or {@link Executable}
     * @param triggerOnClassInitializer when {@code true}, a {@link Class} trigger is replaced by
     *            that class's class initializer method
     */
    private void registerReachabilityHandler(BeforeAnalysisAccess a, Object callback, Object[] triggers, boolean triggerOnClassInitializer) {
        if (triggeredHandlers.containsKey(callback)) {
            /* The callback was already invoked or processed; a repeated registration is ignored. */
            return;
        }
        BeforeAnalysisAccessImpl access = (BeforeAnalysisAccessImpl) a;
        AnalysisMetaAccess metaAccess = access.getMetaAccess();
        Set<Object> triggerSet = activeHandlers.computeIfAbsent(callback, c -> new HashSet<>());
        for (Object trigger : triggers) {
            if (trigger instanceof Class) {
                AnalysisType aType = metaAccess.lookupJavaType((Class<?>) trigger);
                /*
                 * NOTE(review): if getClassInitializer() can return null for a class without a
                 * static initializer, a null trigger ends up in the set - confirm against
                 * AnalysisType.
                 */
                triggerSet.add(triggerOnClassInitializer ? aType.getClassInitializer() : aType);
            } else if (trigger instanceof Field) {
                triggerSet.add(metaAccess.lookupJavaField((Field) trigger));
            } else if (trigger instanceof Executable) {
                triggerSet.add(metaAccess.lookupJavaMethod((Executable) trigger));
            } else {
                /* A null element must produce the user error too, not a NullPointerException. */
                String typeName = trigger == null ? "null" : trigger.getClass().getTypeName();
                throw UserError.abort("registerReachabilityHandler called with an element that is not a Class, Field, Method, or Constructor: %s", typeName);
            }
        }
        if (access instanceof DuringAnalysisAccess) {
            /* Registered while the analysis is running: request another analysis iteration. */
            ((DuringAnalysisAccess) access).requireAnalysisIteration();
        }
    }

    /**
     * Runs every active handler against the current analysis state. Handlers that get registered
     * while callbacks execute are picked up by repeating the pass until no unhandled handler is
     * left.
     */
    @Override
    public void duringAnalysis(DuringAnalysisAccess a) {
        DuringAnalysisAccessImpl access = (DuringAnalysisAccessImpl) a;
        /*
         * Identity-based bookkeeping, consistent with activeHandlers. An equals()-based set would
         * merge distinct callbacks that compare equal, so one of them would never be invoked.
         */
        Map<Object, Boolean> handledCallbacks = new IdentityHashMap<>();
        /* Iterate over a snapshot because callbacks may register new handlers while they run. */
        List<Object> callbacks = new ArrayList<>(activeHandlers.keySet());
        do {
            List<Object> completedCallbacks = new ArrayList<>();
            for (Object callback : callbacks) {
                Set<Object> triggers = activeHandlers.get(callback);
                if (callback instanceof Consumer) {
                    if (isTriggered(access, triggers)) {
                        /* Mark as triggered before the call so a re-registration is ignored. */
                        triggeredHandlers.put(callback, null);
                        toExactCallback(callback).accept(access);
                        completedCallbacks.add(callback);
                    }
                } else {
                    VMError.guarantee(callback instanceof BiConsumer);
                    processReachable(access, callback, triggers);
                }
                handledCallbacks.put(callback, Boolean.TRUE);
            }
            for (Object completed : completedCallbacks) {
                activeHandlers.remove(completed);
                handledCallbacks.remove(completed);
            }
            /* Next pass: only handlers that were added during this pass. */
            callbacks = new ArrayList<>();
            for (Object candidate : activeHandlers.keySet()) {
                if (!handledCallbacks.containsKey(candidate)) {
                    callbacks.add(candidate);
                }
            }
        } while (!callbacks.isEmpty());
    }

    /**
     * Returns {@code true} if at least one of the given triggers is reachable. A class initializer
     * trigger additionally counts as triggered when its declaring class is already initialized.
     */
    private static boolean isTriggered(DuringAnalysisAccessImpl access, Set<Object> triggers) {
        for (Object trigger : triggers) {
            if (trigger instanceof AnalysisType) {
                if (access.isReachable((AnalysisType) trigger)) {
                    return true;
                }
            } else if (trigger instanceof AnalysisField) {
                if (access.isReachable((AnalysisField) trigger)) {
                    return true;
                }
            } else if (trigger instanceof AnalysisMethod) {
                AnalysisMethod triggerMethod = (AnalysisMethod) trigger;
                if (access.isReachable(triggerMethod)) {
                    return true;
                }
                if (triggerMethod.isClassInitializer() && triggerMethod.getDeclaringClass().isInitialized()) {
                    return true;
                }
            } else {
                throw VMError.shouldNotReachHere("Unexpected trigger: " + trigger.getClass().getTypeName());
            }
        }
        return false;
    }

    /** Unchecked view of a callback that was registered as a one-shot {@link Consumer}. */
    @SuppressWarnings("unchecked")
    private static Consumer<DuringAnalysisAccess> toExactCallback(Object callback) {
        return (Consumer<DuringAnalysisAccess>) callback;
    }

    /**
     * Invokes a {@link BiConsumer} callback for every reachable subtype or method override of its
     * triggers that was not reported in an earlier call, and records the reported elements in
     * {@link #triggeredHandlers}.
     */
    private void processReachable(DuringAnalysisAccessImpl access, Object callback, Set<Object> triggers) {
        Map<Object, Set<Object>> handledTriggers = triggeredHandlers.computeIfAbsent(callback, c -> new IdentityHashMap<>());
        for (Object trigger : triggers) {
            if (trigger instanceof AnalysisType) {
                /* NOTE(review): the returned set is mutated below - assumes it is a private copy; confirm. */
                Set<AnalysisType> newReachable = access.reachableSubtypes(((AnalysisType) trigger));
                Set<Object> prevReachable = handledTriggers.computeIfAbsent(trigger, c -> new HashSet<>());
                newReachable.removeAll(prevReachable);
                for (AnalysisType reachable : newReachable) {
                    toSubtypeCallback(callback).accept(access, reachable.getJavaClass());
                    prevReachable.add(reachable);
                }
            } else if (trigger instanceof AnalysisMethod) {
                /* NOTE(review): same assumption as above about the returned set being a private copy. */
                Set<AnalysisMethod> newReachable = access.reachableMethodOverrides((AnalysisMethod) trigger);
                Set<Object> prevReachable = handledTriggers.computeIfAbsent(trigger, c -> new HashSet<>());
                newReachable.removeAll(prevReachable);
                for (AnalysisMethod reachable : newReachable) {
                    toOverrideCallback(callback).accept(access, reachable.getJavaMethod());
                    prevReachable.add(reachable);
                }
            } else {
                throw VMError.shouldNotReachHere("Unexpected subtype/override trigger: " + trigger.getClass().getTypeName());
            }
        }
    }

    /** Unchecked view of a callback that was registered as a subtype handler. */
    @SuppressWarnings("unchecked")
    private static BiConsumer<DuringAnalysisAccess, Class<?>> toSubtypeCallback(Object callback) {
        return (BiConsumer<DuringAnalysisAccess, Class<?>>) callback;
    }

    /** Unchecked view of a callback that was registered as a method override handler. */
    @SuppressWarnings("unchecked")
    private static BiConsumer<DuringAnalysisAccess, Executable> toOverrideCallback(Object callback) {
        return (BiConsumer<DuringAnalysisAccess, Executable>) callback;
    }
}