package io.dropwizard.testing.junit5;
import org.junit.jupiter.api.extension.AfterAllCallback;
import org.junit.jupiter.api.extension.AfterEachCallback;
import org.junit.jupiter.api.extension.BeforeAllCallback;
import org.junit.jupiter.api.extension.BeforeEachCallback;
import org.junit.jupiter.api.extension.ExtensionContext;
import org.junit.platform.commons.util.ReflectionUtils;
import javax.annotation.Nullable;
import java.lang.reflect.Field;
import java.lang.reflect.Modifier;
import java.util.Arrays;
import java.util.Set;
import java.util.stream.Collectors;
public class DropwizardExtensionsSupport implements BeforeAllCallback, BeforeEachCallback, AfterAllCallback, AfterEachCallback {
private static Set<Field> findAnnotatedFields(Class<?> testClass, boolean isStaticMember) {
final Set<Field> set = Arrays.stream(testClass.getDeclaredFields()).
filter(m -> isStaticMember == Modifier.isStatic(m.getModifiers())).
filter(m -> DropwizardExtension.class.isAssignableFrom(m.getType())).
collect(Collectors.toSet());
if (!testClass.getSuperclass().equals(Object.class)) {
set.addAll(findAnnotatedFields(testClass.getSuperclass(), isStaticMember));
}
return set;
}
@Override
public void afterAll(ExtensionContext extensionContext) throws Exception {
try {
afterAll(extensionContext.getRequiredTestClass());
} catch (Exception e) {
throw e;
} catch (Throwable e) {
throw new Exception(e);
}
}
private void afterAll(Class<?> cls) throws Throwable {
final Class<?> enclosingClass = cls.getEnclosingClass();
if (enclosingClass != null) {
afterAll(enclosingClass);
}
for (Field member : findAnnotatedFields(cls, true)) {
getDropwizardExtension(member, null).after();
}
}
@Override
public void afterEach(ExtensionContext extensionContext) throws Exception {
final Object testInstance = extensionContext.getTestInstance()
.orElseThrow(() -> new IllegalStateException("Unable to get the current test instance"));
try {
afterEach(testInstance, testInstance.getClass());
} catch (Exception e) {
throw e;
} catch (Throwable e) {
throw new Exception(e);
}
}
private void afterEach(Object testInstance, Class<?> cls) throws Throwable {
final Class<?> enclosingClass = testInstance.getClass().getEnclosingClass();
if (enclosingClass != null) {
final Object enclosing = getEnclosingInstance(testInstance);
if (enclosing != null) {
afterEach(enclosing, enclosingClass);
}
}
for (Field member : findAnnotatedFields(cls, false)) {
getDropwizardExtension(member, testInstance).after();
}
}
@Override
public void beforeAll(ExtensionContext extensionContext) throws Exception {
try {
beforeAll(extensionContext.getRequiredTestClass());
} catch (Exception e) {
throw e;
} catch (Throwable e) {
throw new Exception(e);
}
}
private void beforeAll(Class<?> cls) throws Throwable {
final Class<?> enclosingClass = cls.getEnclosingClass();
if (enclosingClass != null) {
beforeAll(enclosingClass);
}
for (Field member : findAnnotatedFields(cls, true)) {
getDropwizardExtension(member, null).before();
}
}
@Override
public void beforeEach(ExtensionContext extensionContext) throws Exception {
final Object testInstance = extensionContext.getTestInstance()
.orElseThrow(() -> new IllegalStateException("Unable to get the current test instance"));
try {
beforeEach(testInstance, testInstance.getClass());
} catch (Exception e) {
throw e;
} catch (Throwable e) {
throw new Exception(e);
}
}
private void beforeEach(Object testInstance, Class<?> cls) throws Throwable {
final Class<?> enclosingClass = cls.getEnclosingClass();
if (enclosingClass != null) {
final Object enclosing = getEnclosingInstance(testInstance);
if (enclosing != null) {
beforeEach(enclosing, cls);
}
}
for (Field member : findAnnotatedFields(testInstance.getClass(), false)) {
getDropwizardExtension(member, testInstance).before();
}
}
@Nullable
private Object getEnclosingInstance(Object o) throws IllegalAccessException {
final Class<?> innerClass = o.getClass();
if (innerClass.getEnclosingClass() == null) {
return null;
}
for (int i = 0; i < 10; i++) {
try {
final Field field = innerClass.getDeclaredField("this$" + i);
field.setAccessible(true);
return field.get(o);
} catch (NoSuchFieldException e) {
}
}
return null;
}
private DropwizardExtension getDropwizardExtension(Field member, @Nullable Object o) throws IllegalAccessException {
return (DropwizardExtension) ReflectionUtils.makeAccessible(member).get(o);
}
}