/*
* Copyright Terracotta, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.ehcache.core.util;
import org.ehcache.core.osgi.SafeOsgi;
import org.ehcache.core.osgi.OsgiServiceLoader;
import java.io.IOException;
import java.net.URL;
import java.security.PrivilegedAction;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Enumeration;
import java.util.List;
import java.util.ServiceLoader;
import java.util.function.Supplier;
import static java.security.AccessController.doPrivileged;
import static java.util.Collections.enumeration;
import static java.util.Collections.list;
import static java.util.stream.Collectors.toList;
import static java.util.stream.Stream.concat;
import static java.util.stream.Stream.of;
public class ClassLoading {
private static final ClassLoader DEFAULT_CLASSLOADER;
static {
DEFAULT_CLASSLOADER = delegationChain(() -> Thread.currentThread().getContextClassLoader(), ChainedClassLoader.class.getClassLoader());
}
public static ClassLoader getDefaultClassLoader() {
return DEFAULT_CLASSLOADER;
}
public static <T> Iterable<T> servicesOfType(Class<T> serviceType) {
if (SafeOsgi.useOSGiServiceLoading()) {
return OsgiServiceLoader.load(serviceType);
} else {
return ServiceLoader.load(serviceType, ClassLoading.class.getClassLoader());
}
}
@SuppressWarnings("unchecked")
public static ClassLoader delegationChain(Supplier<ClassLoader> loader, ClassLoader ... loaders) {
return doPrivileged((PrivilegedAction<ClassLoader>) () -> new ChainedClassLoader(concat(of(loader), of(loaders).map(l -> () -> l)).collect(toList())));
}
@SuppressWarnings("unchecked")
public static ClassLoader delegationChain(ClassLoader ... loaders) {
return doPrivileged((PrivilegedAction<ClassLoader>) () -> new ChainedClassLoader(of(loaders).<Supplier<ClassLoader>>map(l -> () -> l).collect(toList())));
}
private static class ChainedClassLoader extends ClassLoader {
private final List<Supplier<ClassLoader>> loaders;
public ChainedClassLoader(List<Supplier<ClassLoader>> loaders) {
this.loaders = loaders;
}
@Override
public Class<?> loadClass(String name) throws ClassNotFoundException {
ClassNotFoundException lastFailure = new ClassNotFoundException(name);
for (Supplier<ClassLoader> loader : loaders) {
ClassLoader classLoader = loader.get();
if (classLoader != null) {
try {
return classLoader.loadClass(name);
} catch (ClassNotFoundException cnfe) {
lastFailure = cnfe;
}
}
}
throw lastFailure;
}
@Override
public URL getResource(String name) {
for (Supplier<ClassLoader> loader : loaders) {
ClassLoader classLoader = loader.get();
if (classLoader != null) {
URL resource = classLoader.getResource(name);
if (resource != null) {
return resource;
}
}
}
return null;
}
@Override
public Enumeration<URL> getResources(String name) throws IOException {
Collection<URL> aggregate = new ArrayList<>();
for (Supplier<ClassLoader> loader : loaders) {
ClassLoader classLoader = loader.get();
if (classLoader != null) {
aggregate.addAll(list(classLoader.getResources(name)));
}
}
return enumeration(aggregate);
}
}
}