package org.ehcache.impl.internal.store.loaderwriter;
import org.ehcache.Cache;
import org.ehcache.core.CacheConfigurationChangeListener;
import org.ehcache.core.Ehcache;
import org.ehcache.core.exceptions.StorePassThroughException;
import org.ehcache.core.util.CollectionUtil;
import org.ehcache.core.spi.store.Store;
import org.ehcache.core.spi.store.WrapperStore;
import org.ehcache.core.spi.store.events.StoreEventSource;
import org.ehcache.expiry.ExpiryPolicy;
import org.ehcache.spi.loaderwriter.BulkCacheLoadingException;
import org.ehcache.spi.loaderwriter.BulkCacheWritingException;
import org.ehcache.spi.loaderwriter.CacheLoaderWriter;
import org.ehcache.spi.loaderwriter.WriteBehindConfiguration;
import org.ehcache.spi.resilience.StoreAccessException;
import org.ehcache.spi.service.ServiceConfiguration;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.terracotta.context.ContextManager;
import java.time.Duration;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.function.BiFunction;
import java.util.function.Consumer;
import java.util.function.Function;
import java.util.function.Supplier;
import static org.ehcache.core.exceptions.ExceptionFactory.newCacheLoadingException;
import static org.ehcache.core.exceptions.ExceptionFactory.newCacheWritingException;
import static org.ehcache.core.spi.service.ServiceUtils.findSingletonAmongst;
/**
 * A {@link WrapperStore} that couples a delegate {@link Store} with a {@link CacheLoaderWriter}.
 * <p>
 * Read paths that miss in the delegate consult the loader, and mutating paths invoke the writer
 * (write or delete) from inside the function handed to the delegate's compute methods. Loader and
 * writer failures are wrapped in a {@link StorePassThroughException}.
 *
 * @param <K> the key type
 * @param <V> the value type
 */
public class LocalLoaderWriterStore<K, V> implements WrapperStore<K, V> {
// Passed to newValueAlreadyExpired so that expiry-policy failures can be logged.
private static final Logger LOG = LoggerFactory.getLogger(LocalLoaderWriterStore.class);
// Constant supplier of FALSE, handed to delegate.computeAndGet for its Supplier<Boolean> arguments.
private static final Supplier<Boolean> SUPPLY_FALSE = () -> Boolean.FALSE;
// The wrapped store every operation is ultimately forwarded to.
private final Store<K, V> delegate;
// Loader-writer invoked for loads, writes and deletes.
private final CacheLoaderWriter<? super K, V> cacheLoaderWriter;
// When true, putIfAbsent and the conditional remove/replace operations call the loader on a miss.
private final boolean useLoaderInAtomics;
// Expiry policy used to detect values whose computed expiry duration is zero.
private final ExpiryPolicy<? super K, ? super V> expiry;
/**
 * Builds a loader-writer aware wrapper around the given store.
 *
 * @param delegate the store being wrapped
 * @param cacheLoaderWriter the loader-writer to invoke for loads, writes and deletes
 * @param useLoaderInAtomics whether conditional operations consult the loader on a miss
 * @param expiry the expiry policy used to detect immediately-expired values
 */
public LocalLoaderWriterStore(Store<K, V> delegate, CacheLoaderWriter<? super K, V> cacheLoaderWriter, boolean useLoaderInAtomics,
    ExpiryPolicy<? super K, ? super V> expiry) {
  this.expiry = expiry;
  this.useLoaderInAtomics = useLoaderInAtomics;
  this.cacheLoaderWriter = cacheLoaderWriter;
  this.delegate = delegate;
  // Register this wrapper as the parent of the delegate with the ContextManager.
  ContextManager.associate(delegate).withParent(this);
}
/**
 * Returns the mapping for {@code key}, asking the loader for a value when the delegate has none.
 *
 * @param key the key to look up
 * @return whatever the delegate's {@code computeIfAbsent} returns
 * @throws StoreAccessException if the delegate store fails
 */
@Override
public ValueHolder<V> get(K key) throws StoreAccessException {
  return delegate.computeIfAbsent(key, absentKey -> {
    try {
      return cacheLoaderWriter.load(absentKey);
    } catch (Exception loadFailure) {
      // Wrapped so that the loading failure can travel through the delegate's compute call.
      throw new StorePassThroughException(newCacheLoadingException(loadFailure));
    }
  });
}
/**
 * Reports whether the delegate store holds a mapping for {@code key}; the loader is not consulted.
 *
 * @param key the key to test
 * @return the delegate's answer
 * @throws StoreAccessException if the delegate store fails
 */
@Override
public boolean containsKey(K key) throws StoreAccessException {
  final boolean presentInDelegate = delegate.containsKey(key);
  return presentInDelegate;
}
/**
 * Writes {@code value} through the loader-writer and installs it in the delegate store.
 *
 * @param key the key to map
 * @param value the value to write and store
 * @return always {@code PutStatus.PUT}
 * @throws StoreAccessException if the delegate store fails
 */
@Override
public PutStatus put(K key, V value) throws StoreAccessException {
  delegate.getAndCompute(key, (mappedKey, ignoredPrevious) -> {
    try {
      cacheLoaderWriter.write(mappedKey, value);
    } catch (Exception writeFailure) {
      throw new StorePassThroughException(newCacheWritingException(writeFailure));
    }
    return value;
  });
  return PutStatus.PUT;
}
/**
 * Installs {@code value} for {@code key} when the delegate has no mapping.
 * <p>
 * When {@code useLoaderInAtomics} is set, the loader is asked first and a non-null loaded value is
 * installed instead of {@code value} (without writing). Otherwise {@code value} is written through
 * the writer and {@code put} is notified with {@code true}.
 *
 * @param key the key to map
 * @param value the value to install if nothing is present
 * @param put callback invoked with {@code true} once {@code value} has been written
 * @return whatever the delegate's {@code computeIfAbsent} returns
 * @throws StoreAccessException if the delegate store fails
 */
@Override
public ValueHolder<V> putIfAbsent(K key, V value, Consumer<Boolean> put) throws StoreAccessException {
  return delegate.computeIfAbsent(key, absentKey -> {
    if (useLoaderInAtomics) {
      final V fromLoader;
      try {
        fromLoader = cacheLoaderWriter.load(absentKey);
      } catch (Exception loadFailure) {
        throw new StorePassThroughException(newCacheLoadingException(loadFailure));
      }
      if (fromLoader != null) {
        // The backing system already has a value: populate the store with it, no write needed.
        return fromLoader;
      }
    }

    try {
      cacheLoaderWriter.write(absentKey, value);
    } catch (Exception writeFailure) {
      throw new StorePassThroughException(newCacheWritingException(writeFailure));
    }
    put.accept(true);
    return value;
  });
}
/**
 * Deletes {@code key} through the writer and removes it from the delegate store.
 *
 * @param key the key to remove
 * @return {@code true} if the delegate handed a non-null previous value to the remapping function
 * @throws StoreAccessException if the delegate store fails
 */
@Override
public boolean remove(K key) throws StoreAccessException {
  // One-element array so the lambda can report back to the enclosing method.
  final boolean[] hadMapping = new boolean[1];
  delegate.getAndCompute(key, (mappedKey, current) -> {
    hadMapping[0] = current != null;
    try {
      cacheLoaderWriter.delete(mappedKey);
    } catch (Exception deleteFailure) {
      throw new StorePassThroughException(newCacheWritingException(deleteFailure));
    }
    return null;
  });
  return hadMapping[0];
}
/**
 * Removes the mapping for {@code key} only if its current value equals {@code value}.
 * <p>
 * The current value is the one held by the delegate or, on a miss with {@code useLoaderInAtomics}
 * enabled, the one returned by the loader.
 *
 * @param key the key to remove
 * @param value the value the current mapping must equal
 * @return {@code REMOVED} when the mapping was deleted, {@code KEY_PRESENT} when a different value
 *         was found, {@code KEY_MISSING} when no value was found
 * @throws StoreAccessException if the delegate store fails
 */
@Override
public RemoveStatus remove(K key, V value) throws StoreAccessException {
  // One-element arrays so the lambda can report back to the enclosing method.
  final boolean[] keyFound = new boolean[1];
  final boolean[] mappingRemoved = new boolean[1];

  delegate.computeAndGet(key, (mappedKey, cached) -> {
    final V current = loadFromLoaderWriter(key, cached);
    if (current == null) {
      return null;
    }
    keyFound[0] = true;
    if (!value.equals(current)) {
      return current;
    }
    try {
      cacheLoaderWriter.delete(mappedKey);
    } catch (Exception deleteFailure) {
      throw new StorePassThroughException(newCacheWritingException(deleteFailure));
    }
    mappingRemoved[0] = true;
    return null;
  }, SUPPLY_FALSE, SUPPLY_FALSE);

  if (mappingRemoved[0]) {
    return RemoveStatus.REMOVED;
  }
  return keyFound[0] ? RemoveStatus.KEY_PRESENT : RemoveStatus.KEY_MISSING;
}
/**
 * Replaces the value mapped to {@code key} with {@code value}, but only if a current value exists.
 * <p>
 * The current value is the one held by the delegate or, on a miss with {@code useLoaderInAtomics}
 * enabled, the one returned by the loader. The new value is written through the writer; if the
 * expiry policy reports it as already expired the mapping is removed from the delegate instead.
 *
 * @param key the key whose mapping is replaced
 * @param value the replacement value
 * @return a holder for the value that was replaced, or {@code null} if there was none
 * @throws StoreAccessException if the delegate store fails
 */
@Override
public ValueHolder<V> replace(K key, V value) throws StoreAccessException {
  // One-element array so the lambda can hand the replaced value back to the enclosing method.
  @SuppressWarnings("unchecked")
  final V[] replaced = (V[]) new Object[1];

  delegate.getAndCompute(key, (mappedKey, cached) -> {
    final V current = loadFromLoaderWriter(key, cached);
    if (current == null) {
      return null;
    }
    try {
      cacheLoaderWriter.write(key, value);
    } catch (Exception writeFailure) {
      throw new StorePassThroughException(newCacheWritingException(writeFailure));
    }
    replaced[0] = current;
    if (newValueAlreadyExpired(LOG, expiry, key, current, value)) {
      return null;
    }
    return value;
  });

  final V previous = replaced[0];
  if (previous != null) {
    return new LoaderWriterValueHolder<>(previous);
  }
  return null;
}
/**
 * Replaces the value mapped to {@code key} with {@code newValue} only if the current value equals
 * {@code oldValue}.
 * <p>
 * The current value is the one held by the delegate or, on a miss with {@code useLoaderInAtomics}
 * enabled, the one returned by the loader. On a match the new value is written through the writer;
 * if the expiry policy reports it as already expired the mapping is removed from the delegate.
 *
 * @param key the key whose mapping is replaced
 * @param oldValue the value the current mapping must equal
 * @param newValue the replacement value
 * @return {@code HIT} on replacement, {@code MISS_PRESENT} when a different value was found,
 *         {@code MISS_NOT_PRESENT} when no value was found
 * @throws StoreAccessException if the delegate store fails
 */
@Override
public ReplaceStatus replace(K key, V oldValue, V newValue) throws StoreAccessException {
  // One-element arrays so the lambda can report back to the enclosing method.
  final boolean[] keyFound = new boolean[1];
  final boolean[] valueReplaced = new boolean[1];

  delegate.computeAndGet(key, (mappedKey, cached) -> {
    final V current = loadFromLoaderWriter(key, cached);
    if (current == null) {
      return null;
    }
    keyFound[0] = true;
    if (!oldValue.equals(current)) {
      return current;
    }
    try {
      cacheLoaderWriter.write(key, newValue);
    } catch (Exception writeFailure) {
      throw new StorePassThroughException(newCacheWritingException(writeFailure));
    }
    valueReplaced[0] = true;
    if (newValueAlreadyExpired(LOG, expiry, key, oldValue, newValue)) {
      return null;
    }
    return newValue;
  }, SUPPLY_FALSE, SUPPLY_FALSE);

  if (valueReplaced[0]) {
    return ReplaceStatus.HIT;
  }
  return keyFound[0] ? ReplaceStatus.MISS_PRESENT : ReplaceStatus.MISS_NOT_PRESENT;
}
/**
 * Clears the delegate store. The loader-writer is not invoked by this method.
 *
 * @throws StoreAccessException if the delegate store fails
 */
@Override
public void clear() throws StoreAccessException {
delegate.clear();
}
/**
 * Exposes the event source of the delegate store.
 *
 * @return the delegate's {@link StoreEventSource}
 */
@Override
public StoreEventSource<K, V> getStoreEventSource() {
  final StoreEventSource<K, V> delegateEvents = delegate.getStoreEventSource();
  return delegateEvents;
}
/**
 * Returns the delegate store's iterator; the loader is not consulted.
 *
 * @return an iterator over the delegate's entries
 */
@Override
public Iterator<Cache.Entry<K, ValueHolder<V>>> iterator() {
  final Iterator<Cache.Entry<K, ValueHolder<V>>> delegateIterator = delegate.iterator();
  return delegateIterator;
}
/**
 * Applies {@code mappingFunction} inside the delegate's {@code getAndCompute} and mirrors the
 * outcome to the writer: a {@code null} result triggers a delete, any other result a write.
 * <p>
 * A non-null result that the expiry policy reports as already expired is still written, but the
 * mapping is removed from the delegate.
 *
 * @param key the key to compute for
 * @param mappingFunction computes the new value from the key and the current value
 * @return whatever the delegate's {@code getAndCompute} returns
 * @throws StoreAccessException if the delegate store fails
 */
@Override
public ValueHolder<V> getAndCompute(K key, BiFunction<? super K, ? super V, ? extends V> mappingFunction) throws StoreAccessException {
  return delegate.getAndCompute(key, (mappedKey, mappedValue) -> {
    final V computed = mappingFunction.apply(mappedKey, mappedValue);
    try {
      if (computed == null) {
        cacheLoaderWriter.delete(mappedKey);
      } else {
        cacheLoaderWriter.write(mappedKey, computed);
      }
    } catch (Exception writerFailure) {
      throw new StorePassThroughException(newCacheWritingException(writerFailure));
    }
    if (computed == null || newValueAlreadyExpired(LOG, expiry, mappedKey, mappedValue, computed)) {
      return null;
    }
    return computed;
  });
}
/**
 * Applies {@code mappingFunction} inside the delegate's {@code computeAndGet}, mirroring the
 * outcome to the writer only when {@code invokeWriter} supplies {@code true}.
 * <p>
 * The delegate itself is always called with a supplier of {@code false} for its own
 * {@code invokeWriter} argument.
 *
 * @param key the key to compute for
 * @param mappingFunction computes the new value from the key and the current value
 * @param replaceEqual forwarded unchanged to the delegate
 * @param invokeWriter decides whether the computed value is written (or deleted when {@code null})
 * @return whatever the delegate's {@code computeAndGet} returns
 * @throws StoreAccessException if the delegate store fails
 */
@Override
public ValueHolder<V> computeAndGet(K key, BiFunction<? super K, ? super V, ? extends V> mappingFunction, Supplier<Boolean> replaceEqual, Supplier<Boolean> invokeWriter) throws StoreAccessException {
  return delegate.computeAndGet(key, (mappedKey, mappedValue) -> {
    final V computed = mappingFunction.apply(mappedKey, mappedValue);
    if (!invokeWriter.get()) {
      return computed;
    }
    try {
      if (computed == null) {
        cacheLoaderWriter.delete(mappedKey);
      } else {
        cacheLoaderWriter.write(mappedKey, computed);
      }
    } catch (Exception writerFailure) {
      throw new StorePassThroughException(newCacheWritingException(writerFailure));
    }
    return computed;
  }, replaceEqual, SUPPLY_FALSE);
}
/**
 * Not supported by this wrapper.
 *
 * @throws UnsupportedOperationException always
 */
@Override
public ValueHolder<V> computeIfAbsent(K key, Function<? super K, ? extends V> mappingFunction) throws StoreAccessException {
throw new UnsupportedOperationException("Not supported");
}
/**
 * Dispatches a bulk computation according to the kind of {@code remappingFunction}.
 * <p>
 * An {@code Ehcache.PutAllFunction} is handled as a bulk write, an {@code Ehcache.RemoveAllFunction}
 * as a bulk delete; any other function is forwarded to the delegate untouched.
 *
 * @param keys the keys to compute
 * @param remappingFunction the bulk remapping function
 * @return the resulting mappings
 * @throws StoreAccessException if the delegate store fails
 */
@Override
public Map<K, ValueHolder<V>> bulkCompute(Set<? extends K> keys, Function<Iterable<? extends Map.Entry<? extends K, ? extends V>>, Iterable<? extends Map.Entry<? extends K, ? extends V>>> remappingFunction) throws StoreAccessException {
  if (remappingFunction instanceof Ehcache.PutAllFunction) {
    // Initial capacity 1: these collections are handed to the put-all path as its success/failure accumulators.
    final Set<K> written = new HashSet<>(1);
    final Map<K, Exception> writeFailures = new HashMap<>(1);
    return getkValueHolderMap((Ehcache.PutAllFunction<K, V>) remappingFunction, written, writeFailures);
  }
  if (remappingFunction instanceof Ehcache.RemoveAllFunction) {
    return getkValueHolderMap(keys);
  }
  return delegate.bulkCompute(keys, remappingFunction);
}
/**
 * Bulk-delete path of {@code bulkCompute}: deletes the given keys through the writer and removes
 * them from the delegate store.
 * <p>
 * For each batch the delegate presents, the keys still pending are passed to the writer's
 * {@code deleteAll}. Keys whose deletion succeeded, and keys whose outcome is unknown because the
 * writer threw a non-bulk exception, are removed from the delegate; keys reported as failed by a
 * {@link BulkCacheWritingException} keep their existing value.
 *
 * @param keys the keys to delete
 * @return the map returned by the delegate's {@code bulkCompute}
 * @throws StoreAccessException if the delegate store fails
 * @throws BulkCacheWritingException if the writer failed for at least one key
 */
private Map<K, ValueHolder<V>> getkValueHolderMap(Set<? extends K> keys) throws StoreAccessException {
  // Initial capacity 1: these are only filled from the writer's delete outcomes.
  Set<K> deleteSuccesses = new HashSet<>(1);
  Map<K, Exception> deleteFailures = new HashMap<>(1);

  // Keys still awaiting deletion; only the key set matters, the values are always null.
  Map<K, ? extends V> entriesToRemove = new HashMap<>(keys.size());
  for (K key : keys) {
    entriesToRemove.put(key, null);
  }

  Function<Iterable<? extends Map.Entry<? extends K, ? extends V>>, Iterable<? extends Map.Entry<? extends K, ? extends V>>> removalFunction =
    entries -> {
      Set<K> unknowns = cacheLoaderWriterDeleteAllCall(entries, entriesToRemove, deleteSuccesses, deleteFailures);

      int size = CollectionUtil.findBestCollectionSize(entries, 1);
      Map<K, V> results = new LinkedHashMap<>(size);

      for (Map.Entry<? extends K, ? extends V> entry : entries) {
        K key = entry.getKey();
        if (deleteSuccesses.contains(key)) {
          results.put(key, null);
          // Deleted: no longer pending, so later batches will not delete it again.
          entriesToRemove.remove(key);
        } else if (unknowns.contains(key)) {
          // Outcome unknown: drop the mapping from the store.
          results.put(key, null);
        } else {
          // Deletion failed (or key was not pending): keep the existing value.
          results.put(key, entry.getValue());
        }
      }
      return results.entrySet();
    };

  Map<K, ValueHolder<V>> map = delegate.bulkCompute(keys, removalFunction);
  if (!deleteFailures.isEmpty()) {
    throw new BulkCacheWritingException(deleteFailures, deleteSuccesses);
  }
  return map;
}
/**
 * Bulk-write path of {@code bulkCompute}: writes the entries carried by a put-all function through
 * the writer and installs them in the delegate store.
 * <p>
 * For each batch the delegate presents, the pending entries are passed to the writer's
 * {@code writeAll}. An entry whose new value the expiry policy reports as already expired is
 * removed from the delegate; an entry whose write succeeded gets its new value; any other entry
 * keeps its existing value.
 *
 * @param remappingFunction the put-all function supplying the entries to write
 * @param successes accumulator for keys whose write succeeded
 * @param failures accumulator for keys whose write failed, with the cause
 * @return the map returned by the delegate's {@code bulkCompute}
 * @throws StoreAccessException if the delegate store fails
 * @throws BulkCacheWritingException if the writer failed for at least one key
 */
private Map<K, ValueHolder<V>> getkValueHolderMap(Ehcache.PutAllFunction<K, V> remappingFunction, Set<K> successes, Map<K, Exception> failures) throws StoreAccessException {
  // Private copy of the entries to write; entries are removed from it as batches are handled.
  Map<K, V> entriesToRemap = CollectionUtil.copyMapButFailOnNull(remappingFunction.getEntriesToRemap());

  Function<Iterable<? extends Map.Entry<? extends K, ? extends V>>, Iterable<? extends Map.Entry<? extends K, ? extends V>>> computeFunction =
    entries1 -> {
      // Write this batch first, so that 'successes' is known before deciding what to install.
      cacheLoaderWriterWriteAllCall(entries1, entriesToRemap, successes, failures);

      int size = CollectionUtil.findBestCollectionSize(entries1, 1);
      Map<K, V> mutations = new LinkedHashMap<>(size);

      for (Map.Entry<? extends K, ? extends V> entry : entries1) {
        K key = entry.getKey();
        V existingValue = entry.getValue();
        V newValue = entriesToRemap.remove(key);

        if (newValueAlreadyExpired(LOG, expiry, key, existingValue, newValue)) {
          mutations.put(key, null);
        } else if (successes.contains(key)) {
          mutations.put(key, newValue);
        } else {
          mutations.put(key, existingValue);
        }
      }
      return mutations.entrySet();
    };

  Map<K, ValueHolder<V>> computedMap = delegate.bulkCompute(remappingFunction.getEntriesToRemap().keySet(), computeFunction);
  if (!failures.isEmpty()) {
    throw new BulkCacheWritingException(failures, successes);
  }
  return computedMap;
}
/**
 * Not supported by this wrapper.
 * <p>
 * Fails fast instead of returning {@code null}: a {@code null} map would silently drop the
 * requested computation and surface later as a {@link NullPointerException} in the caller. This
 * mirrors how {@code computeIfAbsent} is handled in this class.
 *
 * @throws UnsupportedOperationException always
 */
@Override
public Map<K, ValueHolder<V>> bulkCompute(Set<? extends K> keys, Function<Iterable<? extends Map.Entry<? extends K, ? extends V>>, Iterable<? extends Map.Entry<? extends K, ? extends V>>> remappingFunction, Supplier<Boolean> replaceEqual) throws StoreAccessException {
  throw new UnsupportedOperationException("Not supported");
}
/**
 * Bulk read: for a get-all function, keys missing from the delegate are loaded through the
 * loader's {@code loadAll} and installed in the delegate; any other function is forwarded to the
 * delegate untouched.
 * <p>
 * If loading failed for at least one key, a {@link BulkCacheLoadingException} is thrown. Its
 * successes contain the values reported by the loader as well as every value the delegate store
 * returned, so entries that were already cached are not lost to the caller.
 *
 * @param keys the keys to read
 * @param mappingFunction the bulk mapping function
 * @return the map returned by the delegate's {@code bulkComputeIfAbsent}
 * @throws StoreAccessException if the delegate store fails
 * @throws BulkCacheLoadingException if the loader failed for at least one key
 */
@Override
public Map<K, ValueHolder<V>> bulkComputeIfAbsent(Set<? extends K> keys, Function<Iterable<? extends K>, Iterable<? extends Map.Entry<? extends K, ? extends V>>> mappingFunction) throws StoreAccessException {
  if (!(mappingFunction instanceof Ehcache.GetAllFunction)) {
    return delegate.bulkComputeIfAbsent(keys, mappingFunction);
  }

  // Initial capacity 1: these are only consulted when the loader reports a failure.
  Map<K, V> successes = new HashMap<>(1);
  Map<K, Exception> failures = new HashMap<>(1);

  Function<Iterable<? extends K>, Iterable<? extends Map.Entry<? extends K, ? extends V>>> computeFunction =
    keys1 -> {
      Map<? super K, ? extends V> loaded;
      try {
        loaded = cacheLoaderWriter.loadAll(keys1);
      } catch (BulkCacheLoadingException bcle) {
        loaded = Collections.emptyMap();
        collectSuccessesAndFailures(bcle, successes, failures);
      } catch (Exception e) {
        loaded = Collections.emptyMap();
        for (K key : keys1) {
          failures.put(key, e);
        }
      }

      int size = CollectionUtil.findBestCollectionSize(keys1, 1);
      Map<K, V> computeResult = new LinkedHashMap<>(size);

      // Seed every requested key first so the result follows the order of keys1.
      for (K key : keys1) {
        computeResult.put(key, null);
      }

      if (!loaded.isEmpty()) {
        for (K key : keys1) {
          V value = loaded.get(key);
          successes.put(key, value);
          computeResult.put(key, value);
        }
      }
      return computeResult.entrySet();
    };

  Map<K, ValueHolder<V>> computedMap = delegate.bulkComputeIfAbsent(keys, computeFunction);
  if (failures.isEmpty()) {
    return computedMap;
  }

  // Report what the store did return (cache hits included) alongside the loader's successes.
  for (Map.Entry<K, ValueHolder<V>> entry : computedMap.entrySet()) {
    ValueHolder<V> holder = entry.getValue();
    if (holder != null) {
      successes.put(entry.getKey(), holder.get());
    }
  }
  throw new BulkCacheLoadingException(failures, successes);
}
/**
 * Exposes the configuration change listeners of the delegate store.
 *
 * @return the delegate's listeners
 */
@Override
public List<CacheConfigurationChangeListener> getConfigurationChangeListeners() {
  final List<CacheConfigurationChangeListener> delegateListeners = delegate.getConfigurationChangeListeners();
  return delegateListeners;
}
/**
 * Resolves the current value for a conditional operation.
 *
 * @param key the key being operated on
 * @param inCache the value held by the store, possibly {@code null}
 * @return {@code inCache} when non-null; otherwise the loader's value when
 *         {@code useLoaderInAtomics} is set, or {@code null}
 */
private V loadFromLoaderWriter(K key, V inCache) {
  if (inCache != null) {
    return inCache;
  }
  if (!useLoaderInAtomics) {
    return null;
  }
  try {
    return cacheLoaderWriter.load(key);
  } catch (Exception loadFailure) {
    throw new StorePassThroughException(newCacheLoadingException(loadFailure));
  }
}
/**
 * Writes the pending entries of one batch through the writer's {@code writeAll}.
 * <p>
 * Only batch keys that still have a non-null value in {@code entriesToRemap} are written. The
 * writer is not called when nothing is pending.
 *
 * @param entries the batch presented by the store
 * @param entriesToRemap the entries still to be written, keyed by cache key
 * @param successes receives the keys whose write succeeded
 * @param failures receives the keys whose write failed, with the cause
 */
private void cacheLoaderWriterWriteAllCall(Iterable<? extends Map.Entry<? extends K, ? extends V>> entries, Map<K, V> entriesToRemap, Set<K> successes, Map<K, Exception> failures) throws IllegalStateException {
  final Map<K, V> pending = new HashMap<>();
  for (Map.Entry<? extends K, ? extends V> batchEntry : entries) {
    final K batchKey = batchEntry.getKey();
    final V remapped = entriesToRemap.get(batchKey);
    if (remapped != null) {
      pending.put(batchKey, remapped);
    }
  }
  if (pending.isEmpty()) {
    return;
  }

  try {
    cacheLoaderWriter.writeAll(pending.entrySet());
    successes.addAll(pending.keySet());
  } catch (BulkCacheWritingException bulkFailure) {
    collectSuccessesAndFailures(bulkFailure, successes, failures);
  } catch (Exception writeFailure) {
    // A non-bulk failure says nothing per key: mark every pending key as failed.
    pending.keySet().forEach(failedKey -> failures.put(failedKey, writeFailure));
  }
}
/**
 * Deletes the pending keys of one batch through the writer's {@code deleteAll}.
 * <p>
 * Only batch keys still present in {@code entriesToRemove} are deleted.
 *
 * @param entries the batch presented by the store
 * @param entriesToRemove the keys still to be deleted
 * @param successes receives the keys whose deletion succeeded
 * @param failures receives the keys whose deletion failed, with the cause
 * @return the keys whose outcome is unknown because the writer threw a non-bulk exception;
 *         empty otherwise
 */
private Set<K> cacheLoaderWriterDeleteAllCall(Iterable<? extends Map.Entry<? extends K, ? extends V>> entries, Map<K, ? extends V> entriesToRemove, Set<K> successes, Map<K, Exception> failures) {
  final Set<K> undetermined = new HashSet<>();
  final Set<K> pending = new HashSet<>();
  for (Map.Entry<? extends K, ? extends V> batchEntry : entries) {
    final K batchKey = batchEntry.getKey();
    if (entriesToRemove.containsKey(batchKey)) {
      pending.add(batchKey);
    }
  }

  try {
    cacheLoaderWriter.deleteAll(pending);
    successes.addAll(pending);
  } catch (BulkCacheWritingException bulkFailure) {
    collectSuccessesAndFailures(bulkFailure, successes, failures);
  } catch (Exception deleteFailure) {
    // A non-bulk failure says nothing per key: record each pending key as failed and undetermined.
    pending.forEach(failedKey -> {
      failures.put(failedKey, deleteFailure);
      undetermined.add(failedKey);
    });
  }
  return undetermined;
}
/**
 * Copies the per-key outcome carried by a bulk writing exception into the given accumulators.
 *
 * @param bcwe the bulk writing exception raised by the writer
 * @param successes receives the keys reported as successful
 * @param failures receives the keys reported as failed, with the cause
 */
@SuppressWarnings({ "unchecked" })
private static <K> void collectSuccessesAndFailures(BulkCacheWritingException bcwe, Set<K> successes, Map<K, Exception> failures) {
  final Collection<K> reportedSuccesses = (Collection<K>) bcwe.getSuccesses();
  successes.addAll(reportedSuccesses);
  final Map<K, Exception> reportedFailures = (Map<K, Exception>) bcwe.getFailures();
  failures.putAll(reportedFailures);
}
/**
 * Copies the per-key outcome carried by a bulk loading exception into the given accumulators.
 *
 * @param bcle the bulk loading exception raised by the loader
 * @param successes receives the key/value pairs reported as loaded
 * @param failures receives the keys reported as failed, with the cause
 */
@SuppressWarnings({ "unchecked" })
private void collectSuccessesAndFailures(BulkCacheLoadingException bcle, Map<K, V> successes, Map<K, Exception> failures) {
  final Map<K, V> reportedSuccesses = (Map<K, V>) bcle.getSuccesses();
  successes.putAll(reportedSuccesses);
  final Map<K, Exception> reportedFailures = (Map<K, Exception>) bcle.getFailures();
  failures.putAll(reportedFailures);
}
/**
 * Tells whether a value about to be installed would expire immediately.
 * <p>
 * The creation expiry is used when there is no old value, the update expiry otherwise. If the
 * expiry policy throws a {@link RuntimeException}, the failure is logged and the value is treated
 * as expired.
 *
 * @param logger where expiry-policy failures are logged
 * @param expiry the expiry policy to consult
 * @param key the key being mapped
 * @param oldValue the value currently mapped, or {@code null} if none
 * @param newValue the value about to be installed, possibly {@code null}
 * @return {@code false} for a {@code null} new value; otherwise {@code true} when the computed
 *         duration equals {@link Duration#ZERO} or the policy threw
 */
private static <K, V> boolean newValueAlreadyExpired(Logger logger, ExpiryPolicy<? super K, ? super V> expiry, K key, V oldValue, V newValue) {
  if (newValue == null) {
    return false;
  }

  try {
    final Duration timeToLive = (oldValue == null)
        ? expiry.getExpiryForCreation(key, newValue)
        : expiry.getExpiryForUpdate(key, () -> oldValue, newValue);
    return Duration.ZERO.equals(timeToLive);
  } catch (RuntimeException re) {
    logger.error("Expiry computation caused an exception - Expiry duration will be 0 ", re);
    return true;
  }
}
}