package org.junit.jupiter.engine.descriptor;
import static java.util.stream.Collectors.toList;
import static org.junit.platform.commons.util.AnnotationUtils.findAnnotatedFields;
import static org.junit.platform.commons.util.AnnotationUtils.findAnnotation;
import static org.junit.platform.commons.util.AnnotationUtils.findRepeatableAnnotations;
import static org.junit.platform.commons.util.ReflectionUtils.isNotPrivate;
import static org.junit.platform.commons.util.ReflectionUtils.tryToReadFieldValue;
import java.lang.reflect.AnnotatedElement;
import java.lang.reflect.Field;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Comparator;
import java.util.List;
import java.util.function.Predicate;
import org.junit.jupiter.api.Order;
import org.junit.jupiter.api.extension.ExtendWith;
import org.junit.jupiter.api.extension.Extension;
import org.junit.jupiter.api.extension.RegisterExtension;
import org.junit.jupiter.engine.extension.ExtensionRegistrar;
import org.junit.jupiter.engine.extension.MutableExtensionRegistry;
import org.junit.platform.commons.util.Preconditions;
import org.junit.platform.commons.util.ReflectionUtils;
final class ExtensionUtils {
private ExtensionUtils() {
}
static MutableExtensionRegistry populateNewExtensionRegistryFromExtendWithAnnotation(
MutableExtensionRegistry parentRegistry, AnnotatedElement annotatedElement) {
Preconditions.notNull(parentRegistry, "Parent ExtensionRegistry must not be null");
Preconditions.notNull(annotatedElement, "AnnotatedElement must not be null");
List<Class<? extends Extension>> extensionTypes = findRepeatableAnnotations(annotatedElement, ExtendWith.class).stream()
.map(ExtendWith::value)
.flatMap(Arrays::stream)
.collect(toList());
return MutableExtensionRegistry.createRegistryFrom(parentRegistry, extensionTypes);
}
static void registerExtensionsFromFields(ExtensionRegistrar registrar, Class<?> clazz, Object instance) {
Preconditions.notNull(registrar, "ExtensionRegistrar must not be null");
Preconditions.notNull(clazz, "Class must not be null");
Predicate<Field> predicate = (instance == null ? ReflectionUtils::isStatic : ReflectionUtils::isNotStatic);
List<Field> fields = new ArrayList<>(findAnnotatedFields(clazz, RegisterExtension.class, predicate));
fields.sort(orderComparator);
fields.forEach(field -> {
Preconditions.condition(isNotPrivate(field),
() -> String.format(
"Failed to register extension via @RegisterExtension field [%s]: field must not be private.",
field));
tryToReadFieldValue(field, instance).ifSuccess(value -> {
Preconditions.condition(value instanceof Extension, () -> String.format(
"Failed to register extension via @RegisterExtension field [%s]: field value's type [%s] must implement an [%s] API.",
field, (value != null ? value.getClass().getName() : null), Extension.class.getName()));
registrar.registerExtension((Extension) value, field);
});
});
}
private static final Comparator<Field> orderComparator =
(field1, field2) -> Integer.compare(getOrder(field1), getOrder(field2));
private static int getOrder(Field field) {
return findAnnotation(field, Order.class).map(Order::value).orElse(Order.DEFAULT);
}
}