package org.junit.jupiter.engine.execution;
import static org.apiguardian.api.API.Status.INTERNAL;
import static org.junit.jupiter.engine.support.JupiterThrowableCollectorFactory.createThrowableCollector;
import static org.junit.platform.commons.util.ReflectionUtils.getWrapperType;
import static org.junit.platform.commons.util.ReflectionUtils.isAssignableTo;
import java.util.Comparator;
import java.util.Objects;
import java.util.Optional;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.locks.Lock;
import java.util.concurrent.locks.ReentrantLock;
import java.util.function.Function;
import java.util.function.Supplier;
import org.apiguardian.api.API;
import org.junit.jupiter.api.extension.ExtensionContext;
import org.junit.jupiter.api.extension.ExtensionContext.Namespace;
import org.junit.jupiter.api.extension.ExtensionContext.Store.CloseableResource;
import org.junit.jupiter.api.extension.ExtensionContextException;
import org.junit.platform.engine.support.hierarchical.ThrowableCollector;
/**
 * Key/value store backing an extension context's {@code Store}.
 *
 * <p>Entries are keyed by a {@link Namespace} together with a user-supplied key.
 * Lookups that miss in this store fall back to the (optional) parent store;
 * mutations ({@code put}, {@code remove}, and insertion by
 * {@code getOrComputeIfAbsent}) only ever affect this store.
 */
@API(status = INTERNAL, since = "5.0")
public class ExtensionValuesStore {

	/** Orders stored values so that the most recently inserted one comes first. */
	private static final Comparator<StoredValue> REVERSE_INSERT_ORDER = Comparator.comparingInt(
		(StoredValue it) -> it.order).reversed();

	/** Source of the insertion sequence numbers used by {@link #REVERSE_INSERT_ORDER}. */
	private final AtomicInteger insertOrderSequence = new AtomicInteger();

	private final ConcurrentMap<CompositeKey, StoredValue> storedValues = new ConcurrentHashMap<>(4);

	/** Store consulted when a lookup misses in this store; may be {@code null}. */
	private final ExtensionValuesStore parentStore;

	/**
	 * @param parentStore the store to fall back to for lookups; may be {@code null}
	 */
	public ExtensionValuesStore(ExtensionValuesStore parentStore) {
		this.parentStore = parentStore;
	}

	/**
	 * Close every value held directly by this store (not by its parents) that
	 * is a {@link CloseableResource}, most recently inserted first.
	 *
	 * <p>Every resource gets a close attempt even if an earlier one fails; the
	 * collected failures are reported via {@code throwableCollector.assertEmpty()}
	 * once all attempts are done. Entries whose {@code getOrComputeIfAbsent}
	 * computation failed hold no resource and are skipped, so their creator is
	 * not re-invoked here.
	 */
	public void closeAllStoredCloseableValues() {
		ThrowableCollector throwableCollector = createThrowableCollector();
		storedValues.values().stream() //
				.sorted(REVERSE_INSERT_ORDER) //
				// evaluate each entry exactly once; failed computations yield null
				.map(StoredValue::evaluateSafely) //
				.filter(CloseableResource.class::isInstance) //
				.map(CloseableResource.class::cast) //
				.forEach(resource -> throwableCollector.execute(resource::close));
		throwableCollector.assertEmpty();
	}

	/**
	 * Look up the value for {@code key} in {@code namespace}, searching this
	 * store first and then its ancestors.
	 *
	 * @return the stored value, or {@code null} if absent (a stored
	 *     {@code null} is indistinguishable from an absent entry)
	 */
	Object get(Namespace namespace, Object key) {
		StoredValue storedValue = getStoredValue(new CompositeKey(namespace, key));
		return (storedValue != null ? storedValue.evaluate() : null);
	}

	/**
	 * Typed variant of {@link #get(Namespace, Object)}.
	 *
	 * @throws ExtensionContextException if a non-null value is not assignable
	 *     to {@code requiredType}
	 */
	<T> T get(Namespace namespace, Object key, Class<T> requiredType) {
		Object value = get(namespace, key);
		return castToRequiredType(key, value, requiredType);
	}

	/**
	 * Return the value for {@code key} in {@code namespace} (searching this store
	 * and its ancestors); if there is none, store a lazily computed entry in this
	 * store and return its value.
	 *
	 * <p>The entry is inserted with {@code putIfAbsent}, so when several callers
	 * race, all of them end up evaluating the same entry. The entry memoizes its
	 * result: {@code defaultCreator} is invoked at most once per entry, and if it
	 * throws a {@link RuntimeException}, that same exception is rethrown on every
	 * later evaluation of the entry.
	 *
	 * @return the existing or newly computed value; may be {@code null} if the
	 *     creator returned {@code null}
	 */
	<K, V> Object getOrComputeIfAbsent(Namespace namespace, K key, Function<K, V> defaultCreator) {
		CompositeKey compositeKey = new CompositeKey(namespace, key);
		StoredValue storedValue = getStoredValue(compositeKey);
		if (storedValue == null) {
			StoredValue newValue = storedValue(new MemoizingSupplier(() -> defaultCreator.apply(key)));
			// if another thread inserted first, use its entry instead of ours
			storedValue = Optional.ofNullable(storedValues.putIfAbsent(compositeKey, newValue)).orElse(newValue);
		}
		return storedValue.evaluate();
	}

	/**
	 * Typed variant of {@link #getOrComputeIfAbsent(Namespace, Object, Function)}.
	 *
	 * @throws ExtensionContextException if a non-null value is not assignable
	 *     to {@code requiredType}
	 */
	<K, V> V getOrComputeIfAbsent(Namespace namespace, K key, Function<K, V> defaultCreator, Class<V> requiredType) {
		Object value = getOrComputeIfAbsent(namespace, key, defaultCreator);
		return castToRequiredType(key, value, requiredType);
	}

	/**
	 * Store {@code value} under {@code key} in {@code namespace} in this store,
	 * replacing any existing entry (and giving it a new insertion position).
	 */
	void put(Namespace namespace, Object key, Object value) {
		storedValues.put(new CompositeKey(namespace, key), storedValue(() -> value));
	}

	/** Wrap {@code value} with the next insertion sequence number. */
	private StoredValue storedValue(Supplier<Object> value) {
		return new StoredValue(insertOrderSequence.getAndIncrement(), value);
	}

	/**
	 * Remove the entry for {@code key} in {@code namespace} from this store only
	 * (ancestors are untouched).
	 *
	 * @return the removed value, or {@code null} if there was no entry
	 */
	Object remove(Namespace namespace, Object key) {
		StoredValue previous = storedValues.remove(new CompositeKey(namespace, key));
		return (previous != null ? previous.evaluate() : null);
	}

	/**
	 * Typed variant of {@link #remove(Namespace, Object)}.
	 *
	 * @throws ExtensionContextException if a non-null removed value is not
	 *     assignable to {@code requiredType}
	 */
	<T> T remove(Namespace namespace, Object key, Class<T> requiredType) {
		Object value = remove(namespace, key);
		return castToRequiredType(key, value, requiredType);
	}

	/** Find the entry for {@code compositeKey} here or in the nearest ancestor that has one. */
	private StoredValue getStoredValue(CompositeKey compositeKey) {
		StoredValue storedValue = storedValues.get(compositeKey);
		if (storedValue != null) {
			return storedValue;
		}
		if (parentStore != null) {
			return parentStore.getStoredValue(compositeKey);
		}
		return null;
	}

	/**
	 * Cast {@code value} to {@code requiredType}, boxing through the wrapper
	 * type when {@code requiredType} is primitive.
	 *
	 * @return {@code null} if {@code value} is {@code null}
	 * @throws ExtensionContextException if {@code value} is not assignable to
	 *     {@code requiredType}
	 */
	@SuppressWarnings("unchecked")
	private <T> T castToRequiredType(Object key, Object value, Class<T> requiredType) {
		if (value == null) {
			return null;
		}
		if (isAssignableTo(value, requiredType)) {
			if (requiredType.isPrimitive()) {
				// Class#cast on a primitive class always fails, so cast via its wrapper
				return (T) getWrapperType(requiredType).cast(value);
			}
			return requiredType.cast(value);
		}
		throw new ExtensionContextException(
			String.format("Object stored under key [%s] is not of required type [%s]", key, requiredType.getName()));
	}

	/** Map key combining a namespace with a user-supplied key. */
	private static class CompositeKey {

		private final Namespace namespace;
		private final Object key;

		private CompositeKey(Namespace namespace, Object key) {
			this.namespace = namespace;
			this.key = key;
		}

		@Override
		public boolean equals(Object o) {
			if (this == o) {
				return true;
			}
			if (o == null || getClass() != o.getClass()) {
				return false;
			}
			CompositeKey that = (CompositeKey) o;
			// null-tolerant, consistent with Objects.hash in hashCode()
			return Objects.equals(this.namespace, that.namespace) && Objects.equals(this.key, that.key);
		}

		@Override
		public int hashCode() {
			return Objects.hash(namespace, key);
		}
	}

	/** A stored entry: its insertion sequence number plus a supplier of its value. */
	private static class StoredValue {

		private final int order;
		private final Supplier<Object> supplier;

		StoredValue(int order, Supplier<Object> supplier) {
			this.order = order;
			this.supplier = supplier;
		}

		/**
		 * Evaluate this entry, returning {@code null} instead of throwing if its
		 * computation failed with a {@link RuntimeException}.
		 */
		private Object evaluateSafely() {
			try {
				return evaluate();
			}
			catch (RuntimeException ignored) {
				// The failure was already rethrown to the caller that computed this
				// entry and there is no value to process, so skip it here rather
				// than aborting the caller's iteration over all entries.
				return null;
			}
		}

		private Object evaluate() {
			return supplier.get();
		}
	}

	/**
	 * Supplier that invokes its delegate at most once and caches the outcome:
	 * either the returned value or a {@link RuntimeException} that is rethrown
	 * on every subsequent call. Computation is guarded by a lock with a recheck,
	 * so concurrent first calls do not invoke the delegate twice.
	 */
	private static class MemoizingSupplier implements Supplier<Object> {

		private static final Object NO_VALUE_SET = new Object();

		private final Lock lock = new ReentrantLock();
		private final Supplier<Object> delegate;
		private volatile Object value = NO_VALUE_SET;

		private MemoizingSupplier(Supplier<Object> delegate) {
			this.delegate = delegate;
		}

		@Override
		public Object get() {
			if (value == NO_VALUE_SET) {
				computeValue();
			}
			// read the volatile field once so the check and the return agree
			Object result = value;
			if (result instanceof Failure) {
				throw ((Failure) result).exception;
			}
			return result;
		}

		/** Invoke the delegate under the lock unless another thread already did. */
		private void computeValue() {
			lock.lock();
			try {
				if (value == NO_VALUE_SET) {
					value = delegate.get();
				}
			}
			catch (RuntimeException ex) {
				// memoize the failure so the delegate is not re-invoked later
				value = new Failure(ex);
			}
			finally {
				lock.unlock();
			}
		}

		/** Marker for a delegate invocation that threw; never visible outside this class. */
		private static final class Failure {

			private final RuntimeException exception;

			private Failure(RuntimeException exception) {
				this.exception = exception;
			}
		}
	}
}