package org.junit.jupiter.engine.descriptor;
import static java.util.stream.Collectors.toList;
import static org.junit.platform.commons.util.AnnotationUtils.findAnnotatedFields;
import static org.junit.platform.commons.util.AnnotationUtils.findAnnotation;
import static org.junit.platform.commons.util.AnnotationUtils.findRepeatableAnnotations;
import static org.junit.platform.commons.util.ReflectionUtils.isPrivate;
import static org.junit.platform.commons.util.ReflectionUtils.isStatic;
import static org.junit.platform.commons.util.ReflectionUtils.tryToReadFieldValue;

import java.lang.reflect.AnnotatedElement;
import java.lang.reflect.Field;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Comparator;
import java.util.List;
import java.util.function.Predicate;

import org.junit.jupiter.api.Order;
import org.junit.jupiter.api.extension.ExtendWith;
import org.junit.jupiter.api.extension.Extension;
import org.junit.jupiter.api.extension.RegisterExtension;
import org.junit.jupiter.engine.extension.ExtensionRegistry;
import org.junit.platform.commons.JUnitException;
import org.junit.platform.commons.util.Preconditions;
/**
 * Utilities for populating an {@link ExtensionRegistry}.
 *
 * <p>Extensions are collected from two sources:
 * <ul>
 * <li>{@link ExtendWith @ExtendWith} declarations on an annotated element, and</li>
 * <li>{@link RegisterExtension @RegisterExtension} fields of a class.</li>
 * </ul>
 */
final class ExtensionUtils {

	/** Admits non-private, static fields whose declared type is assignable to {@link Extension}. */
	private static final Predicate<Field> isStaticExtension = new IsStaticExtensionField();

	/** Admits non-private, non-static fields whose declared type is assignable to {@link Extension}. */
	private static final Predicate<Field> isNonStaticExtension = new IsNonStaticExtensionField();

	/**
	 * Orders fields ascending by their {@link Order @Order} value.
	 *
	 * <p>Fields without the annotation are ranked as {@link Integer#MAX_VALUE}, so they sort after
	 * any field with a smaller explicit order. Ties keep their incoming relative order because
	 * {@link List#sort} on an {@link ArrayList} is stable.
	 */
	private static final Comparator<Field> orderComparator = Comparator.comparingInt(ExtensionUtils::getOrder);

	private ExtensionUtils() {
		/* no-op: static utility class */
	}

	/**
	 * Creates a new registry, derived from {@code parentRegistry}, that contains every extension
	 * type declared via {@link ExtendWith @ExtendWith} on {@code annotatedElement}.
	 *
	 * @param parentRegistry the registry the new one is created from; never {@code null}
	 * @param annotatedElement the element whose {@code @ExtendWith} declarations are read;
	 * never {@code null}
	 * @return the registry returned by {@link ExtensionRegistry#createRegistryFrom}
	 */
	static ExtensionRegistry populateNewExtensionRegistryFromExtendWithAnnotation(ExtensionRegistry parentRegistry,
			AnnotatedElement annotatedElement) {

		Preconditions.notNull(annotatedElement, "AnnotatedElement must not be null");
		Preconditions.notNull(parentRegistry, "Parent ExtensionRegistry must not be null");

		// A single @ExtendWith may list several types. Flatten them, preserving the order in which
		// findRepeatableAnnotations returns the annotations.
		List<Class<? extends Extension>> extensionTypes = findRepeatableAnnotations(annotatedElement, ExtendWith.class)
				.stream()
				.map(ExtendWith::value)
				.flatMap(Arrays::stream)
				.collect(toList());
		return ExtensionRegistry.createRegistryFrom(parentRegistry, extensionTypes);
	}

	/**
	 * Registers the extensions held in {@link RegisterExtension @RegisterExtension} fields of
	 * {@code clazz}, in {@link Order @Order} order.
	 *
	 * <p>When {@code instance} is {@code null}, only static fields are considered. Otherwise only
	 * non-static fields are considered, and their values are read from {@code instance}. Private
	 * fields and fields whose declared type is not assignable to {@link Extension} are skipped.
	 *
	 * @param registry the registry to register the extensions with; never {@code null}
	 * @param clazz the class whose fields are scanned; never {@code null}
	 * @param instance the instance to read non-static fields from, or {@code null} to process
	 * static fields
	 * @throws JUnitException if the value of a matching field cannot be read
	 */
	static void registerExtensionsFromFields(ExtensionRegistry registry, Class<?> clazz, Object instance) {
		Preconditions.notNull(clazz, "Class must not be null");
		Preconditions.notNull(registry, "ExtensionRegistry must not be null");

		Predicate<Field> predicate = (instance == null) ? isStaticExtension : isNonStaticExtension;
		// Copy into a mutable list so it can be sorted in place.
		List<Field> fields = new ArrayList<>(findAnnotatedFields(clazz, RegisterExtension.class, predicate));
		fields.sort(orderComparator);

		for (Field field : fields) {
			registry.registerExtension(readExtensionFromField(field, instance), field);
		}
	}

	/**
	 * Reads the extension stored in {@code field}, failing loudly instead of skipping the field.
	 *
	 * @throws JUnitException if the field value cannot be read (the read failure is the cause)
	 */
	private static Extension readExtensionFromField(Field field, Object instance) {
		Object value = tryToReadFieldValue(field, instance).getOrThrow(cause -> new JUnitException(String.format(
			"Failed to register extension via @RegisterExtension field [%s]: field value could not be read.", field),
			cause));
		Preconditions.notNull(value, () -> String.format(
			"Failed to register extension via @RegisterExtension field [%s]: field must not be null when evaluated.",
			field));
		// Both field predicates only admit fields whose declared type is assignable to Extension,
		// so a non-null value read from such a field is an Extension.
		return (Extension) value;
	}

	/** Returns the field's {@code @Order} value, or {@link Integer#MAX_VALUE} if it has none. */
	private static int getOrder(Field field) {
		return findAnnotation(field, Order.class).map(Order::value).orElse(Integer.MAX_VALUE);
	}

	/**
	 * Criteria shared by both field predicates: the field is not private, and its declared type
	 * is assignable to {@link Extension}.
	 */
	private static boolean isNonPrivateExtensionTypedField(Field field) {
		return !isPrivate(field) && Extension.class.isAssignableFrom(field.getType());
	}

	/** Matches non-static, non-private fields whose declared type is assignable to {@link Extension}. */
	static class IsNonStaticExtensionField implements Predicate<Field> {

		@Override
		public boolean test(Field field) {
			return !isStatic(field) && isNonPrivateExtensionTypedField(field);
		}
	}

	/** Matches static, non-private fields whose declared type is assignable to {@link Extension}. */
	static class IsStaticExtensionField implements Predicate<Field> {

		@Override
		public boolean test(Field field) {
			return isStatic(field) && isNonPrivateExtensionTypedField(field);
		}
	}
}