package org.junit.jupiter.engine.extension;
import static java.util.stream.Collectors.toList;
import static java.util.stream.Stream.concat;
import static org.apiguardian.api.API.Status.INTERNAL;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.ServiceLoader;
import java.util.Set;
import java.util.stream.Stream;
import java.util.stream.StreamSupport;
import org.apiguardian.api.API;
import org.junit.jupiter.api.extension.Extension;
import org.junit.jupiter.engine.config.JupiterConfiguration;
import org.junit.platform.commons.logging.Logger;
import org.junit.platform.commons.logging.LoggerFactory;
import org.junit.platform.commons.util.ClassLoaderUtils;
import org.junit.platform.commons.util.Preconditions;
import org.junit.platform.commons.util.ReflectionUtils;
/**
 * Mutable {@link ExtensionRegistry} that also serves as an {@link ExtensionRegistrar}.
 *
 * <p>A registry optionally has a parent registry. {@link #stream(Class)} yields the
 * parent's matching extensions first, followed by the matching extensions that were
 * registered locally, each in registration order. Registration by type is skipped
 * when that type has already been recorded in this registry or in any ancestor.
 *
 * @since 5.5
 */
@API(status = INTERNAL, since = "5.5")
public class MutableExtensionRegistry implements ExtensionRegistry, ExtensionRegistrar {

    private static final Logger logger = LoggerFactory.getLogger(MutableExtensionRegistry.class);

    /** Extensions that are added to every registry created with default extensions. */
    private static final List<Extension> DEFAULT_EXTENSIONS = Collections.unmodifiableList(Arrays.asList(
        new DisabledCondition(),
        new TempDirectory(),
        new TimeoutExtension(),
        new RepeatedTestExtension(),
        new TestInfoParameterResolver(),
        new TestReporterParameterResolver()));

    /**
     * Create a new registry without a parent and populate it with the default
     * extensions and, if enabled in the supplied configuration, with extensions
     * discovered via {@link ServiceLoader}.
     *
     * @param configuration configuration consulted to decide whether extension
     * auto-detection is enabled; never {@code null}
     * @return a new, pre-populated registry
     */
    public static MutableExtensionRegistry createRegistryWithDefaultExtensions(JupiterConfiguration configuration) {
        // Fail fast with a descriptive message instead of a bare NPE after the
        // registry has already been built and populated.
        Preconditions.notNull(configuration, "configuration must not be null");

        MutableExtensionRegistry extensionRegistry = new MutableExtensionRegistry(null);

        logger.trace(() -> "Registering default extensions: " + DEFAULT_EXTENSIONS.stream()
                .map(extension -> extension.getClass().getName())
                .collect(toList()));

        DEFAULT_EXTENSIONS.forEach(extensionRegistry::registerDefaultExtension);

        if (configuration.isExtensionAutoDetectionEnabled()) {
            registerAutoDetectedExtensions(extensionRegistry);
        }

        return extensionRegistry;
    }

    /**
     * Load all {@link Extension} providers via {@link ServiceLoader}, using the
     * default class loader, and register each of them in the supplied registry.
     *
     * @param extensionRegistry the registry that receives the discovered extensions
     */
    private static void registerAutoDetectedExtensions(MutableExtensionRegistry extensionRegistry) {
        // ServiceLoader instantiates providers lazily while being iterated. Materialize
        // the providers once so that logging and registration work on the same list
        // rather than traversing the loader twice.
        List<Extension> extensions = StreamSupport.stream(
                ServiceLoader.load(Extension.class, ClassLoaderUtils.getDefaultClassLoader()).spliterator(), false)
                .collect(toList());

        logger.config(() -> "Registering auto-detected extensions: "
                + extensions.stream()
                        .map(extension -> extension.getClass().getName())
                        .collect(toList()));

        extensions.forEach(extensionRegistry::registerDefaultExtension);
    }

    /**
     * Create a new registry whose parent is the supplied registry and register the
     * supplied extension types in it. Types that are already recorded in the parent
     * chain (or that occur more than once in the list) are registered only once.
     *
     * @param parentRegistry the parent registry; never {@code null}
     * @param extensionTypes the extension types to instantiate and register;
     * never {@code null}
     * @return a new child registry
     */
    public static MutableExtensionRegistry createRegistryFrom(MutableExtensionRegistry parentRegistry,
            List<Class<? extends Extension>> extensionTypes) {

        Preconditions.notNull(parentRegistry, "parentRegistry must not be null");
        // Validate consistently with parentRegistry instead of failing later in forEach.
        Preconditions.notNull(extensionTypes, "extensionTypes must not be null");

        MutableExtensionRegistry registry = new MutableExtensionRegistry(parentRegistry);
        extensionTypes.forEach(registry::registerExtension);
        return registry;
    }

    /** Parent registry, or {@code null} for a root registry. */
    private final MutableExtensionRegistry parent;

    /** Extension types recorded locally; consulted to avoid duplicate type-based registration. */
    private final Set<Class<? extends Extension>> registeredExtensionTypes = new LinkedHashSet<>();

    /** Locally registered extension instances, in registration order. */
    private final List<Extension> registeredExtensions = new ArrayList<>();

    private MutableExtensionRegistry(MutableExtensionRegistry parent) {
        this.parent = parent;
    }

    /**
     * Stream all extensions of the requested type: those of the parent chain first,
     * followed by the ones registered in this registry.
     *
     * @param extensionType the type of extensions to select
     * @return a stream of the matching extensions
     */
    @Override
    public <E extends Extension> Stream<E> stream(Class<E> extensionType) {
        if (this.parent == null) {
            return streamLocal(extensionType);
        }
        return concat(this.parent.stream(extensionType), streamLocal(extensionType));
    }

    /**
     * Stream only the locally registered extensions that are instances of the
     * requested type.
     */
    private <E extends Extension> Stream<E> streamLocal(Class<E> extensionType) {
        return this.registeredExtensions.stream()
                .filter(extensionType::isInstance)
                .map(extensionType::cast);
    }

    /**
     * Determine whether the supplied type has been recorded in this registry or in
     * any ancestor registry.
     */
    private boolean isAlreadyRegistered(Class<? extends Extension> extensionType) {
        return (this.registeredExtensionTypes.contains(extensionType)
                || (this.parent != null && this.parent.isAlreadyRegistered(extensionType)));
    }

    /**
     * Instantiate the supplied extension type and register the new instance, unless
     * the type has already been recorded in this registry or an ancestor.
     *
     * @param extensionType the extension type to instantiate and register
     */
    void registerExtension(Class<? extends Extension> extensionType) {
        if (!isAlreadyRegistered(extensionType)) {
            registerExtension(ReflectionUtils.newInstance(extensionType));
            this.registeredExtensionTypes.add(extensionType);
        }
    }

    /**
     * Register the supplied extension instance and record its runtime class,
     * without logging and without a duplicate check.
     */
    private void registerDefaultExtension(Extension extension) {
        this.registeredExtensions.add(extension);
        this.registeredExtensionTypes.add(extension.getClass());
    }

    /** Register the supplied extension instance, using the extension itself as its source. */
    private void registerExtension(Extension extension) {
        registerExtension(extension, extension);
    }

    /**
     * Register the supplied extension instance in this registry.
     *
     * <p>NOTE(review): this method does not record the extension's class in
     * {@code registeredExtensionTypes}, so registering an instance here does not
     * prevent a later type-based registration of the same class from adding a
     * second instance. Confirm that this is intended.
     *
     * @param extension the extension to register; never {@code null}
     * @param source the source of the extension, used for logging only;
     * never {@code null}
     */
    @Override
    public void registerExtension(Extension extension, Object source) {
        Preconditions.notNull(extension, "Extension must not be null");
        Preconditions.notNull(source, "source must not be null");

        logger.trace(() -> String.format("Registering extension [%s] from source [%s].", extension, source));

        this.registeredExtensions.add(extension);
    }

}