package org.junit.jupiter.engine.execution;
import static java.util.stream.Collectors.joining;
import static org.apiguardian.api.API.Status.INTERNAL;
import java.util.List;
import java.util.ListIterator;
import java.util.concurrent.atomic.AtomicBoolean;
import org.apiguardian.api.API;
import org.junit.jupiter.api.extension.InvocationInterceptor;
import org.junit.jupiter.api.extension.InvocationInterceptor.Invocation;
import org.junit.jupiter.engine.extension.ExtensionRegistry;
import org.junit.platform.commons.JUnitException;
import org.junit.platform.commons.logging.Logger;
import org.junit.platform.commons.logging.LoggerFactory;
import org.junit.platform.commons.util.ExceptionUtils;
/**
 * Wraps an {@link Invocation} in all {@link InvocationInterceptor InvocationInterceptors}
 * obtained from an {@link ExtensionRegistry} and runs the resulting chain.
 *
 * <p>When interceptors are present, the chain is validated: the wrapped invocation
 * must be proceeded with or skipped exactly once, otherwise a {@link JUnitException}
 * naming the interceptor classes is thrown.
 */
@API(status = INTERNAL, since = "5.5")
public class InvocationInterceptorChain {

	/**
	 * Runs {@code invocation}, routed through every registered {@link InvocationInterceptor}.
	 *
	 * <p>If the registry yields no interceptors, the invocation is proceeded with directly
	 * and no validation takes place.
	 *
	 * @param invocation the invocation to run
	 * @param extensionRegistry the registry queried for {@link InvocationInterceptor} extensions
	 * @param call adapter that hands an interceptor the next invocation in the chain
	 * @param <T> the result type of the invocation
	 * @return the value produced by the (possibly intercepted) invocation
	 */
	public <T> T invoke(Invocation<T> invocation, ExtensionRegistry extensionRegistry, InterceptorCall<T> call) {
		List<InvocationInterceptor> registered = extensionRegistry.getExtensions(InvocationInterceptor.class);
		return registered.isEmpty() ? proceed(invocation) : chainAndInvoke(invocation, call, registered);
	}

	/**
	 * Builds the validated interceptor chain, runs it, and checks afterwards that the
	 * original invocation was actually reached (proceeded with or skipped).
	 */
	private <T> T chainAndInvoke(Invocation<T> invocation, InterceptorCall<T> call,
			List<InvocationInterceptor> interceptors) {
		ValidatingInvocation<T> guarded = new ValidatingInvocation<>(invocation, interceptors);
		T outcome = proceed(chainInterceptors(guarded, call, interceptors));
		// Only reached when the chain completed without throwing.
		guarded.verifyInvokedAtLeastOnce();
		return outcome;
	}

	/**
	 * Wraps {@code invocation} in one {@link InterceptedInvocation} per interceptor.
	 *
	 * <p>The list is walked from its last element to its first, so the first interceptor
	 * in the list becomes the outermost wrapper and is therefore called first.
	 */
	private <T> Invocation<T> chainInterceptors(Invocation<T> invocation, InterceptorCall<T> call,
			List<InvocationInterceptor> interceptors) {
		Invocation<T> outermost = invocation;
		for (ListIterator<InvocationInterceptor> cursor = interceptors.listIterator(
			interceptors.size()); cursor.hasPrevious();) {
			outermost = new InterceptedInvocation<>(outermost, call, cursor.previous());
		}
		return outermost;
	}

	/**
	 * Proceeds with {@code invocation}, passing anything it throws to
	 * {@link ExceptionUtils#throwAsUncheckedException(Throwable)} so that this method
	 * declares no checked exceptions.
	 */
	private <T> T proceed(Invocation<T> invocation) {
		try {
			return invocation.proceed();
		}
		catch (Throwable failure) {
			throw ExceptionUtils.throwAsUncheckedException(failure);
		}
	}

	/**
	 * Adapter that invokes the appropriate interceptor method, handing it the
	 * next {@link Invocation} in the chain.
	 *
	 * @param <T> the result type of the invocation
	 */
	@FunctionalInterface
	public interface InterceptorCall<T> {

		/**
		 * Calls {@code interceptor} with the given {@code invocation}.
		 *
		 * @param interceptor the interceptor to call
		 * @param invocation the next invocation in the chain
		 * @return the result of the interception
		 * @throws Throwable whatever the interceptor or invocation throws
		 */
		T apply(InvocationInterceptor interceptor, Invocation<T> invocation) throws Throwable;

		/**
		 * Adapts a {@link VoidInterceptorCall} to an {@code InterceptorCall<Void>}
		 * that always yields {@code null}.
		 *
		 * @param call the void call to adapt
		 * @return an equivalent {@code InterceptorCall<Void>}
		 */
		static InterceptorCall<Void> ofVoid(VoidInterceptorCall call) {
			return (interceptor, invocation) -> {
				call.apply(interceptor, invocation);
				return null;
			};
		}
	}

	/**
	 * Variant of {@link InterceptorCall} for interceptions that produce no result.
	 */
	@FunctionalInterface
	public interface VoidInterceptorCall {

		/**
		 * Calls {@code interceptor} with the given {@code invocation}.
		 *
		 * @param interceptor the interceptor to call
		 * @param invocation the next invocation in the chain
		 * @throws Throwable whatever the interceptor or invocation throws
		 */
		void apply(InvocationInterceptor interceptor, Invocation<Void> invocation) throws Throwable;
	}

	/**
	 * One link of the chain: proceeding hands the next invocation to a single
	 * interceptor through the {@link InterceptorCall}; skipping is forwarded unchanged.
	 */
	private static class InterceptedInvocation<T> implements Invocation<T> {

		private final Invocation<T> next;
		private final InterceptorCall<T> interceptorCall;
		private final InvocationInterceptor interceptor;

		InterceptedInvocation(Invocation<T> invocation, InterceptorCall<T> call, InvocationInterceptor interceptor) {
			this.next = invocation;
			this.interceptorCall = call;
			this.interceptor = interceptor;
		}

		@Override
		public T proceed() throws Throwable {
			return interceptorCall.apply(interceptor, next);
		}

		@Override
		public void skip() {
			next.skip();
		}
	}

	/**
	 * Innermost link of the chain, wrapping the original invocation. It records whether
	 * the invocation was proceeded with or skipped and fails on a second attempt.
	 */
	private static class ValidatingInvocation<T> implements Invocation<T> {

		private static final Logger logger = LoggerFactory.getLogger(ValidatingInvocation.class);

		// Flips to true on the first proceed() or skip().
		private final AtomicBoolean invokedOrSkipped = new AtomicBoolean();
		private final Invocation<T> target;
		// Kept only to name the interceptor classes in failure messages.
		private final List<InvocationInterceptor> interceptors;

		ValidatingInvocation(Invocation<T> delegate, List<InvocationInterceptor> interceptors) {
			this.target = delegate;
			this.interceptors = interceptors;
		}

		@Override
		public T proceed() throws Throwable {
			markInvokedOrSkipped();
			return target.proceed();
		}

		@Override
		public void skip() {
			logger.debug(() -> "The invocation is skipped");
			markInvokedOrSkipped();
			target.skip();
		}

		/**
		 * Records the first proceed/skip; any later call fails with a {@link JUnitException}.
		 */
		private void markInvokedOrSkipped() {
			if (invokedOrSkipped.compareAndSet(false, true)) {
				return;
			}
			fail("Chain of InvocationInterceptors called invocation multiple times instead of just once");
		}

		/**
		 * Fails with a {@link JUnitException} if the invocation was neither
		 * proceeded with nor skipped.
		 */
		void verifyInvokedAtLeastOnce() {
			if (invokedOrSkipped.get()) {
				return;
			}
			fail("Chain of InvocationInterceptors never called invocation");
		}

		/**
		 * Throws a {@link JUnitException} whose message is {@code prefix} followed by
		 * the comma-separated class names of all interceptors in the chain.
		 */
		private void fail(String prefix) {
			String interceptorClassNames = interceptors.stream() //
					.map(interceptor -> interceptor.getClass().getName()) //
					.collect(joining(", "));
			throw new JUnitException(prefix + ": " + interceptorClassNames);
		}
	}
}