package org.springframework.data.mapping.callback;
import java.lang.reflect.Method;
import java.util.Map;
import java.util.function.BiFunction;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.springframework.beans.factory.BeanFactory;
import org.springframework.core.ResolvableType;
import org.springframework.util.Assert;
import org.springframework.util.ClassUtils;
import org.springframework.util.ConcurrentReferenceHashMap;
import org.springframework.util.ReflectionUtils;
/**
 * Default {@link EntityCallbacks} implementation.
 * <p>
 * Callbacks are obtained from an {@link EntityCallbackDiscoverer} and invoked one after another; the value
 * returned by one callback is handed to the next one, and the value returned by the last callback is the
 * overall result.
 */
class DefaultEntityCallbacks implements EntityCallbacks {

	/** Logger used by {@link SimpleEntityCallbackInvoker}; created once instead of on every failed invocation. */
	private static final Log LOGGER = LogFactory.getLog(SimpleEntityCallbackInvoker.class);

	/**
	 * Cache of resolved callback methods, keyed by callback type only.
	 * NOTE(review): the method is looked up using the entity type and arguments of the first invocation for a
	 * given callback type and reused afterwards - confirm a callback type never needs more than one method.
	 */
	private final Map<Class<?>, Method> callbackMethodCache = new ConcurrentReferenceHashMap<>(64);
	private final SimpleEntityCallbackInvoker callbackInvoker = new SimpleEntityCallbackInvoker();
	private final EntityCallbackDiscoverer callbackDiscoverer;

	/**
	 * Creates a new instance backed by a default {@link EntityCallbackDiscoverer}.
	 */
	DefaultEntityCallbacks() {
		this.callbackDiscoverer = new EntityCallbackDiscoverer();
	}

	/**
	 * Creates a new instance whose {@link EntityCallbackDiscoverer} is created with the given {@link BeanFactory}.
	 *
	 * @param beanFactory passed as-is to the {@link EntityCallbackDiscoverer} constructor.
	 */
	DefaultEntityCallbacks(BeanFactory beanFactory) {
		this.callbackDiscoverer = new EntityCallbackDiscoverer(beanFactory);
	}

	/**
	 * Invokes all callbacks of the given type that apply to the given entity, threading the entity through them.
	 *
	 * @param callbackType the callback type to invoke.
	 * @param entity the entity handed to the first callback; must not be {@literal null}.
	 * @param args additional arguments used for method lookup and invocation.
	 * @return the value returned by the last callback, or {@code entity} itself if no callback applied.
	 * @throws IllegalArgumentException if {@code entity} is {@literal null} or a callback returns {@literal null}.
	 */
	@Override
	@SuppressWarnings({ "unchecked", "rawtypes" }) // raw type dictated by the interface; cast is safe for T's own class
	public <T> T callback(Class<? extends EntityCallback> callbackType, T entity, Object... args) {

		Assert.notNull(entity, "Entity must not be null!");

		// The entity is guaranteed to be non-null here, so its user class is always available.
		Class<T> entityType = (Class<T>) ClassUtils.getUserClass(entity.getClass());

		Method callbackMethod = callbackMethodCache.computeIfAbsent(callbackType, it -> {

			Method method = EntityCallbackDiscoverer.lookupCallbackMethod(it, entityType, args);
			ReflectionUtils.makeAccessible(method);
			return method;
		});

		T value = entity;

		for (EntityCallback<T> callback : callbackDiscoverer.getEntityCallbacks(entityType,
				ResolvableType.forClass(callbackType))) {

			BiFunction<EntityCallback<T>, T, Object> callbackFunction = EntityCallbackDiscoverer
					.computeCallbackInvokerFunction(callback, callbackMethod, args);

			// Each callback receives the result of the previous one.
			value = callbackInvoker.invokeCallback(callback, value, callbackFunction);
		}

		return value;
	}

	/**
	 * Registers the given callback with the underlying {@link EntityCallbackDiscoverer}.
	 *
	 * @param callback the callback to add.
	 */
	@Override
	public void addEntityCallback(EntityCallback<?> callback) {
		this.callbackDiscoverer.addEntityCallback(callback);
	}

	/**
	 * {@link EntityCallbackInvoker} that applies the invoker function directly and tolerates callbacks whose
	 * type does not match the entity.
	 */
	class SimpleEntityCallbackInvoker implements EntityCallbackInvoker {

		/**
		 * Applies {@code callbackInvokerFunction} to the callback and entity.
		 *
		 * @param callback the callback to invoke.
		 * @param entity the entity to pass to the callback.
		 * @param callbackInvokerFunction function performing the actual invocation.
		 * @return the non-null value produced by the function, or {@code entity} unchanged if the invocation
		 *         failed with a {@link ClassCastException} that has no message or whose message is accepted by
		 *         {@code EntityCallbackInvoker.matchesClassCastMessage} for the entity's class.
		 * @throws IllegalArgumentException if the function returns {@literal null}.
		 * @throws ClassCastException if the invocation raised one that does not satisfy the condition above.
		 */
		@Override
		@SuppressWarnings("unchecked") // the callback contract is to return the (possibly modified) entity of type T
		public <T> T invokeCallback(EntityCallback<T> callback, T entity,
				BiFunction<EntityCallback<T>, T, Object> callbackInvokerFunction) {

			try {

				Object value = callbackInvokerFunction.apply(callback, entity);

				if (value == null) {
					throw new IllegalArgumentException(
							String.format("Callback invocation on %s returned null value for %s", callback.getClass(), entity));
				}

				return (T) value;

			} catch (ClassCastException ex) {

				String msg = ex.getMessage();

				// Anything else is a genuine failure inside the callback and must propagate.
				if (msg != null && !EntityCallbackInvoker.matchesClassCastMessage(msg, entity.getClass())) {
					throw ex;
				}

				if (LOGGER.isDebugEnabled()) {
					LOGGER.debug("Non-matching callback type for entity callback: " + callback, ex);
				}

				return entity;
			}
		}
	}
}