package org.springframework.data.mapping.callback;
import java.lang.reflect.Method;
import java.lang.reflect.Modifier;
import java.util.ArrayList;
import java.util.Collection;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.function.BiFunction;
import org.springframework.aop.framework.AopProxyUtils;
import org.springframework.beans.factory.BeanFactory;
import org.springframework.beans.factory.NoSuchBeanDefinitionException;
import org.springframework.beans.factory.config.ConfigurableBeanFactory;
import org.springframework.core.ResolvableType;
import org.springframework.core.annotation.AnnotationAwareOrderComparator;
import org.springframework.lang.Nullable;
import org.springframework.util.Assert;
import org.springframework.util.ClassUtils;
import org.springframework.util.ConcurrentReferenceHashMap;
import org.springframework.util.ObjectUtils;
import org.springframework.util.ReflectionUtils;
import org.springframework.util.comparator.Comparators;
/**
 * Discovers, caches and orders {@link EntityCallback} instances that apply to a given entity type and callback type.
 * <p>
 * Callbacks are registered either as instances ({@link #addEntityCallback(EntityCallback)}) or by bean name
 * ({@link #addEntityCallbackBean(String)}). When a {@link BeanFactory} is set, every {@link EntityCallback} it provides
 * is registered as an instance. Resolved callback lists are cached per {@link CallbackCacheKey} when the entity type is
 * cache-safe for the bean class loader; every registration change clears that cache.
 */
class EntityCallbackDiscoverer {

	/** Holds all registered callback instances and callback bean names (unfiltered). */
	private final CallbackRetriever defaultRetriever = new CallbackRetriever(false);

	/** Pre-filtered retrievers keyed by callback type and entity type. */
	private final Map<CallbackCacheKey, CallbackRetriever> retrieverCache = new ConcurrentHashMap<>(64);

	/** Memoizes the generic entity type declared by a callback class ({@link ResolvableType#NONE} if absent). */
	private final Map<Class<?>, ResolvableType> entityTypeCache = new ConcurrentReferenceHashMap<>(64);

	@Nullable private ClassLoader beanClassLoader;

	@Nullable private BeanFactory beanFactory;

	/**
	 * Monitor guarding registration changes and cache population. Replaced by the singleton mutex of a
	 * {@link ConfigurableBeanFactory} in {@link #setBeanFactory(BeanFactory)}.
	 */
	private Object retrievalMutex = this.defaultRetriever;

	/** Creates a discoverer without a {@link BeanFactory}; only directly registered instances are available. */
	EntityCallbackDiscoverer() {}

	/**
	 * Creates a discoverer backed by the given {@link BeanFactory} and registers its {@link EntityCallback} beans.
	 *
	 * @param beanFactory the bean factory to obtain callbacks from.
	 */
	EntityCallbackDiscoverer(BeanFactory beanFactory) {
		setBeanFactory(beanFactory);
	}

	/**
	 * Registers a callback instance. If the callback is an AOP proxy, its singleton target is unregistered first so the
	 * same callback is not registered twice.
	 *
	 * @param callback must not be {@literal null}.
	 */
	void addEntityCallback(EntityCallback<?> callback) {

		Assert.notNull(callback, "Callback must not be null!");

		synchronized (this.retrievalMutex) {

			// Explicitly remove the target of a proxy, if already registered, to avoid double invocation.
			Object singletonTarget = AopProxyUtils.getSingletonTarget(callback);
			if (singletonTarget instanceof EntityCallback) {
				this.defaultRetriever.entityCallbacks.remove(singletonTarget);
			}

			this.defaultRetriever.entityCallbacks.add(callback);
			this.retrieverCache.clear();
		}
	}

	/**
	 * Registers a callback by bean name; the bean is looked up from the {@link BeanFactory} on retrieval.
	 *
	 * @param callbackBeanName name of the callback bean.
	 */
	void addEntityCallbackBean(String callbackBeanName) {

		synchronized (this.retrievalMutex) {
			this.defaultRetriever.entityCallbackBeans.add(callbackBeanName);
			this.retrieverCache.clear();
		}
	}

	/**
	 * Unregisters a callback instance.
	 *
	 * @param callback the callback to remove.
	 */
	void removeEntityCallback(EntityCallback<?> callback) {

		synchronized (this.retrievalMutex) {
			this.defaultRetriever.entityCallbacks.remove(callback);
			this.retrieverCache.clear();
		}
	}

	/**
	 * Unregisters a callback bean name.
	 *
	 * @param callbackBeanName name of the callback bean to remove.
	 */
	void removeEntityCallbackBean(String callbackBeanName) {

		synchronized (this.retrievalMutex) {
			this.defaultRetriever.entityCallbackBeans.remove(callbackBeanName);
			this.retrieverCache.clear();
		}
	}

	/** Removes all registered callback instances and bean names and clears the retrieval cache. */
	void clear() {

		synchronized (this.retrievalMutex) {
			this.defaultRetriever.entityCallbacks.clear();
			this.defaultRetriever.entityCallbackBeans.clear();
			this.retrieverCache.clear();
		}
	}

	/**
	 * Returns the ordered callbacks that accept the given entity type and are assignable to {@code callbackType}.
	 *
	 * @param entity the entity type, must not be {@literal null}.
	 * @param callbackType the requested callback type, must not be {@literal null}.
	 * @return callbacks sorted via {@link AnnotationAwareOrderComparator}.
	 */
	@SuppressWarnings({ "unchecked", "rawtypes" })
	<T extends S, S> Collection<EntityCallback<S>> getEntityCallbacks(Class<T> entity, ResolvableType callbackType) {

		Class<?> sourceType = entity;
		CallbackCacheKey cacheKey = new CallbackCacheKey(callbackType, sourceType);

		// Fast path: lock-free lookup of an already built retriever.
		CallbackRetriever retriever = this.retrieverCache.get(cacheKey);
		if (retriever != null) {
			return (Collection<EntityCallback<S>>) (Collection) retriever.getEntityCallbacks();
		}

		ResolvableType entityType = ResolvableType.forClass(sourceType);

		// Only cache results when ClassUtils reports the entity type as cache-safe for the bean class loader.
		// (sourceType is known to be non-null here: CallbackCacheKey asserts it.)
		if (this.beanClassLoader == null || ClassUtils.isCacheSafe(sourceType, this.beanClassLoader)) {

			synchronized (this.retrievalMutex) {

				// Re-check under the lock: another thread may have populated the entry meanwhile.
				retriever = this.retrieverCache.get(cacheKey);
				if (retriever != null) {
					return (Collection<EntityCallback<S>>) (Collection) retriever.getEntityCallbacks();
				}

				retriever = new CallbackRetriever(true);
				Collection<EntityCallback<?>> callbacks = retrieveEntityCallbacks(entityType, callbackType, retriever);
				this.retrieverCache.put(cacheKey, retriever);
				return (Collection<EntityCallback<S>>) (Collection) callbacks;
			}
		}

		// Not cache-safe: resolve against the actual entity type without caching.
		return (Collection<EntityCallback<S>>) (Collection) retrieveEntityCallbacks(entityType, callbackType, null);
	}

	/**
	 * Resolves the generic entity type a callback class declares for {@link EntityCallback}.
	 *
	 * @param callbackType the callback implementation class.
	 * @return the declared entity type, or {@literal null} if it resolves to {@link ResolvableType#NONE}.
	 */
	@Nullable
	ResolvableType resolveDeclaredEntityType(Class<?> callbackType) {

		ResolvableType eventType = entityTypeCache.get(callbackType);

		if (eventType == null) {
			eventType = ResolvableType.forClass(callbackType).as(EntityCallback.class).getGeneric();
			entityTypeCache.put(callbackType, eventType);
		}

		return (eventType != ResolvableType.NONE ? eventType : null);
	}

	/**
	 * Filters all registered callbacks (instances and beans) for the given entity and callback type, sorts them, and
	 * optionally records the result in {@code retriever} for caching.
	 *
	 * @param entityType the entity type callbacks must accept.
	 * @param callbackType the callback type callbacks must be assignable to.
	 * @param retriever retriever to populate for caching, or {@literal null} when not caching.
	 * @return the sorted matching callbacks.
	 */
	private Collection<EntityCallback<?>> retrieveEntityCallbacks(ResolvableType entityType, ResolvableType callbackType,
			@Nullable CallbackRetriever retriever) {

		List<EntityCallback<?>> allCallbacks = new ArrayList<>();
		Set<EntityCallback<?>> callbacks;
		Set<String> callbackBeans;

		// Snapshot the registrations so filtering below can run without holding the mutex.
		synchronized (this.retrievalMutex) {
			callbacks = new LinkedHashSet<>(this.defaultRetriever.entityCallbacks);
			callbackBeans = new LinkedHashSet<>(this.defaultRetriever.entityCallbackBeans);
		}

		for (EntityCallback<?> callback : callbacks) {
			if (supportsEvent(callback, entityType, callbackType)) {
				if (retriever != null) {
					retriever.getEntityCallbacks().add(callback);
				}
				allCallbacks.add(callback);
			}
		}

		if (!callbackBeans.isEmpty()) {

			BeanFactory beanFactory = getRequiredBeanFactory();

			for (String callbackBeanName : callbackBeans) {
				try {

					// Check the bean type first so non-matching beans are not obtained via getBean.
					Class<?> callbackImplType = beanFactory.getType(callbackBeanName);
					if (callbackImplType == null || supportsEvent(callbackImplType, entityType)) {

						EntityCallback<?> callback = beanFactory.getBean(callbackBeanName, EntityCallback.class);
						if (!allCallbacks.contains(callback) && supportsEvent(callback, entityType, callbackType)) {

							if (retriever != null) {
								// Singletons can be cached as instances; other scopes must be looked up on each retrieval.
								if (beanFactory.isSingleton(callbackBeanName)) {
									retriever.entityCallbacks.add(callback);
								} else {
									retriever.entityCallbackBeans.add(callbackBeanName);
								}
							}
							allCallbacks.add(callback);
						}
					}
				} catch (NoSuchBeanDefinitionException ignored) {
					// Best-effort: a bean whose definition is gone is skipped rather than failing the whole lookup.
				}
			}
		}

		AnnotationAwareOrderComparator.sort(allCallbacks);

		// With only instances cached, store them pre-sorted so CallbackRetriever#getEntityCallbacks can skip sorting.
		if (retriever != null && retriever.entityCallbackBeans.isEmpty()) {
			retriever.entityCallbacks.clear();
			retriever.entityCallbacks.addAll(allCallbacks);
		}

		return allCallbacks;
	}

	/**
	 * Checks whether the entity type declared by a callback class accepts the given entity type.
	 *
	 * @param callback the callback implementation class.
	 * @param entityType the entity type to check.
	 * @return {@literal true} if no entity type is declared or the declared type is assignable from {@code entityType}.
	 */
	protected boolean supportsEvent(Class<?> callback, ResolvableType entityType) {

		ResolvableType declaredEventType = resolveDeclaredEntityType(callback);
		return (declaredEventType == null || declaredEventType.isAssignableFrom(entityType));
	}

	/**
	 * Checks whether a callback instance accepts the entity type and is assignable to the requested callback type.
	 *
	 * @param callback the callback instance.
	 * @param entityType the entity type to check.
	 * @param callbackType the requested callback type.
	 * @return {@literal true} if both checks pass.
	 */
	protected boolean supportsEvent(EntityCallback<?> callback, ResolvableType entityType, ResolvableType callbackType) {

		return supportsEvent(callback.getClass(), entityType)
				&& callbackType.isAssignableFrom(ResolvableType.forInstance(callback));
	}

	/**
	 * Sets the class loader used to decide whether entity types are cache-safe.
	 *
	 * @param classLoader the bean class loader.
	 */
	public void setBeanClassLoader(ClassLoader classLoader) {
		this.beanClassLoader = classLoader;
	}

	/**
	 * Sets the {@link BeanFactory} and registers every {@link EntityCallback} it provides. For a
	 * {@link ConfigurableBeanFactory}, its class loader (if none was set) and its singleton mutex are adopted.
	 *
	 * @param beanFactory the bean factory; {@code getBeanProvider} is invoked on it, so it must be non-null.
	 */
	public void setBeanFactory(BeanFactory beanFactory) {

		this.beanFactory = beanFactory;

		if (beanFactory instanceof ConfigurableBeanFactory) {

			ConfigurableBeanFactory cbf = (ConfigurableBeanFactory) beanFactory;

			if (this.beanClassLoader == null) {
				this.beanClassLoader = cbf.getBeanClassLoader();
			}
			this.retrievalMutex = cbf.getSingletonMutex();
		}

		defaultRetriever.discoverEntityCallbacks(this.beanFactory);
		this.retrieverCache.clear();
	}

	/**
	 * Finds the single public, non-bridge, non-{@link Object} method of {@code callbackType} that takes
	 * {@code args.length + 1} parameters and whose first parameter type is assignable from {@code entityType}.
	 *
	 * @param callbackType the callback class to inspect (including its hierarchy).
	 * @param entityType the entity type the first parameter must accept.
	 * @param args the additional arguments that follow the entity.
	 * @return the matching method.
	 * @throws IllegalStateException if zero or more than one method matches.
	 */
	@Nullable
	static Method lookupCallbackMethod(Class<?> callbackType, Class<?> entityType, Object[] args) {

		Collection<Method> methods = new ArrayList<>(1);

		ReflectionUtils.doWithMethods(callbackType, methods::add, method -> {

			if (!Modifier.isPublic(method.getModifiers()) || method.getParameterCount() != args.length + 1
					|| method.isBridge() || ReflectionUtils.isObjectMethod(method)) {
				return false;
			}

			return ClassUtils.isAssignable(method.getParameterTypes()[0], entityType);
		});

		if (methods.size() == 1) {
			return methods.iterator().next();
		}

		throw new IllegalStateException(
				String.format("%s does not define a callback method accepting %s and %s additional arguments.",
						ClassUtils.getShortName(callbackType), ClassUtils.getShortName(entityType), args.length));
	}

	/**
	 * Creates a function that invokes {@code callbackMethod} with the entity followed by {@code args}.
	 * <p>
	 * Note: the method is invoked on the captured {@code callback}; the function's first argument is not used.
	 *
	 * @param callback the callback instance to invoke the method on.
	 * @param callbackMethod the method to invoke.
	 * @param args the additional arguments appended after the entity.
	 * @return the invoker function.
	 */
	static <T> BiFunction<EntityCallback<T>, T, Object> computeCallbackInvokerFunction(EntityCallback<T> callback,
			Method callbackMethod, Object[] args) {

		return (entityCallback, entity) -> {

			Object[] invocationArgs = new Object[args.length + 1];
			invocationArgs[0] = entity;
			if (args.length > 0) {
				System.arraycopy(args, 0, invocationArgs, 1, args.length);
			}

			return ReflectionUtils.invokeMethod(callbackMethod, callback, invocationArgs);
		};
	}

	/**
	 * Returns the configured {@link BeanFactory}.
	 *
	 * @throws IllegalStateException if no bean factory is set.
	 */
	private BeanFactory getRequiredBeanFactory() {

		Assert.state(beanFactory != null,
				"EntityCallbacks cannot retrieve callback beans because it is not associated with a BeanFactory");
		return beanFactory;
	}

	/**
	 * Holds a set of callback instances and callback bean names. A pre-filtered retriever holds the result of a
	 * filtered lookup for one {@link CallbackCacheKey}; the default retriever holds all registrations.
	 */
	class CallbackRetriever {

		private final Set<EntityCallback<?>> entityCallbacks = new LinkedHashSet<>();

		private final Set<String> entityCallbackBeans = new LinkedHashSet<>();

		private final boolean preFiltered;

		CallbackRetriever(boolean preFiltered) {
			this.preFiltered = preFiltered;
		}

		/**
		 * Returns the held instances plus the beans looked up by name. The result is sorted unless the retriever is
		 * pre-filtered and holds only instances (which are then already stored in sorted order).
		 */
		Collection<EntityCallback<?>> getEntityCallbacks() {

			List<EntityCallback<?>> allCallbacks = new ArrayList<>(
					this.entityCallbacks.size() + this.entityCallbackBeans.size());
			allCallbacks.addAll(this.entityCallbacks);

			if (!this.entityCallbackBeans.isEmpty()) {

				BeanFactory beanFactory = getRequiredBeanFactory();

				for (String callbackBeanName : this.entityCallbackBeans) {
					try {
						EntityCallback<?> callback = beanFactory.getBean(callbackBeanName, EntityCallback.class);
						if (this.preFiltered || !allCallbacks.contains(callback)) {
							allCallbacks.add(callback);
						}
					} catch (NoSuchBeanDefinitionException ignored) {
						// Best-effort: a bean whose definition is gone is skipped rather than failing the whole lookup.
					}
				}
			}

			if (!this.preFiltered || !this.entityCallbackBeans.isEmpty()) {
				AnnotationAwareOrderComparator.sort(allCallbacks);
			}

			return allCallbacks;
		}

		/** Registers every {@link EntityCallback} provided by the given bean factory as an instance. */
		void discoverEntityCallbacks(BeanFactory beanFactory) {
			beanFactory.getBeanProvider(EntityCallback.class).stream().forEach(entityCallbacks::add);
		}
	}

	/** Cache key combining the requested callback type and the entity type. */
	static final class CallbackCacheKey implements Comparable<CallbackCacheKey> {

		private final ResolvableType callbackType;

		private final Class<?> entityType;

		/**
		 * @param callbackType must not be {@literal null}.
		 * @param entityType must not be {@literal null}.
		 */
		public CallbackCacheKey(ResolvableType callbackType, Class<?> entityType) {

			Assert.notNull(callbackType, "Callback type must not be null");
			Assert.notNull(entityType, "Entity type must not be null");

			this.callbackType = callbackType;
			this.entityType = entityType;
		}

		@Override
		public boolean equals(Object other) {

			if (this == other) {
				return true;
			}

			// Also covers null: foreign types are never equal instead of failing with a ClassCastException.
			if (!(other instanceof CallbackCacheKey)) {
				return false;
			}

			CallbackCacheKey otherKey = (CallbackCacheKey) other;
			return (this.callbackType.equals(otherKey.callbackType)
					&& ObjectUtils.nullSafeEquals(this.entityType, otherKey.entityType));
		}

		@Override
		public int hashCode() {
			return this.callbackType.hashCode() * 17 + ObjectUtils.nullSafeHashCode(this.entityType);
		}

		/**
		 * Orders keys by the callback type's string form, then by the entity type's name. Both sides' fields are
		 * compared directly; the previous comparator chain re-entered this method and compared only {@code this}.
		 */
		@Override
		public int compareTo(CallbackCacheKey other) {

			int result = this.callbackType.toString().compareTo(other.callbackType.toString());
			if (result != 0) {
				return result;
			}

			return this.entityType.getName().compareTo(other.entityType.getName());
		}
	}
}