package com.oracle.truffle.api.test.polyglot;
import java.io.ByteArrayOutputStream;
import java.io.Closeable;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.net.JarURLConnection;
import java.net.MalformedURLException;
import java.net.URL;
import java.net.URLConnection;
import java.net.URLStreamHandler;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.security.CodeSource;
import java.util.Arrays;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.Enumeration;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.function.Function;
import java.util.function.Supplier;
import java.util.jar.JarFile;
import java.util.stream.Collectors;
import java.util.zip.ZipEntry;
import java.util.zip.ZipFile;
import org.junit.Assert;
import org.junit.Assume;
import org.junit.Test;
import com.oracle.truffle.api.TruffleLanguage;
public class LanguageCacheTest {
/**
 * Checks that {@code LanguageCache.createLanguages} fails with an {@link IllegalStateException}
 * when it is given two loaders: the loader of this test class and a {@link TestClassLoader}
 * whose resources are resolved again from this test's own code source (a jar file or a class
 * directory). The test is skipped when the code source is unavailable.
 */
@Test
public void testDuplicateLanguageIds() throws Throwable {
    CodeSource origin = LanguageCacheTest.class.getProtectionDomain().getCodeSource();
    Assume.assumeNotNull(origin);
    Path location = Paths.get(origin.getLocation().toURI());

    // Maps a resource name to exactly one URL below the code source location.
    Function<String, List<URL>> loader = (binaryName) -> {
        boolean archive = Files.isRegularFile(location);
        String base = location.toUri().toString();
        // A regular file is addressed as a jar; anything else is treated as a directory root.
        String spec = archive ? "jar:" + base + "!/" + binaryName : base + binaryName;
        try {
            return Collections.singletonList(new URL(spec));
        } catch (MalformedURLException e) {
            throw new RuntimeException(e);
        }
    };

    ClassLoader duplicatingLoader = new TestClassLoader(loader);
    try {
        invokeLanguageCacheCreateLanguages(LanguageCacheTest.class.getClassLoader(), duplicatingLoader);
        Assert.fail("Expected IllegalStateException");
    } catch (IllegalStateException expected) {
        // Intentionally empty: this exception is the outcome the test asserts.
    }
}
/**
 * Runs {@code LanguageCache.createLanguages} with a {@link TestClassLoader} whose resource URLs
 * have the shape {@code jar:<code source>!/inner.jar!/<entry>}, while the bytes are actually
 * read from the test's own archive. The test is skipped unless the code source is a regular
 * file. It passes when the call completes without throwing.
 */
@Test
public void testNestedArchives() throws Throwable {
    CodeSource origin = LanguageCacheTest.class.getProtectionDomain().getCodeSource();
    Assume.assumeNotNull(origin);
    URL archiveUrl = origin.getLocation();
    Path archivePath = Paths.get(archiveUrl.toURI());
    Assume.assumeTrue(Files.isRegularFile(archivePath));

    // Prefix that makes every produced URL look like an entry of an archive nested in the jar.
    String relocation = archiveUrl.toString() + "!/inner.jar!/";
    try (NestedJarLoader nested = new NestedJarLoader(archivePath, relocation)) {
        invokeLanguageCacheCreateLanguages(new TestClassLoader(nested));
    }
}
/**
 * Reflectively invokes the static {@code createLanguages(List)} method declared by
 * {@code com.oracle.truffle.polyglot.LanguageCache}, handing it one {@link Supplier} per loader.
 *
 * @param loaders class loaders to expose to the language cache, in the given order
 * @return the value returned by {@code createLanguages}, cast to a map
 * @throws Throwable the original exception thrown by {@code createLanguages}; any other
 *             reflective failure is wrapped in a {@link RuntimeException}
 */
@SuppressWarnings("unchecked")
private static Map<String, Object> invokeLanguageCacheCreateLanguages(ClassLoader... loaders) throws Throwable {
    try {
        Class<?> cacheType = Class.forName("com.oracle.truffle.polyglot.LanguageCache", true, LanguageCacheTest.class.getClassLoader());
        Method factory = cacheType.getDeclaredMethod("createLanguages", List.class);
        // The method is looked up as a declared member, so make it callable from this test.
        factory.setAccessible(true);

        // Each supplier simply hands back the loader it was created for.
        List<Supplier<ClassLoader>> suppliers = Arrays.stream(loaders)
                .<Supplier<ClassLoader>> map(loader -> () -> loader)
                .collect(Collectors.toList());
        Object languages = factory.invoke(null, suppliers);
        return (Map<String, Object>) languages;
    } catch (InvocationTargetException ite) {
        // Unwrap so callers can catch the exception type thrown by createLanguages itself.
        throw ite.getCause();
    } catch (ReflectiveOperationException re) {
        throw new RuntimeException(re);
    }
}
/**
 * Minimal language registered under the id and name {@code "DuplicateIdLanguage"}. Its class
 * file is one of the resources that {@code TestClassLoader} serves itself instead of delegating
 * to its parent.
 *
 * NOTE(review): {@code LanguageCacheTestDuplicateIdLanguageProvider}, referenced by
 * {@code TestClassLoader}, is presumably generated from this registration by the annotation
 * processor - confirm before renaming or restructuring this class.
 */
@TruffleLanguage.Registration(id = DuplicateIdLanguage.ID, name = DuplicateIdLanguage.ID, version = "1.0")
public static final class DuplicateIdLanguage extends TruffleLanguage<Void> {
// Used as both the registration id and the registration name.
static final String ID = "DuplicateIdLanguage";
/**
 * Returns {@code null}: the context type is {@link Void}, so there is no context object to
 * create.
 */
@Override
protected Void createContext(Env env) {
return null;
}
}
/**
 * Class loader that handles a fixed set of class files itself, reading them from URLs produced by
 * a caller-supplied function, and delegates everything else to its parent. For those files the
 * classes are defined by this loader rather than by the parent, and resource enumeration reports
 * both this loader's URLs and the parent's.
 */
private static final class TestClassLoader extends ClassLoader {

    /** Resource paths ('/' separated, with ".class" suffix) that this loader serves itself. */
    private static final Set<String> IMPORTANT_RESOURCES;
    static {
        IMPORTANT_RESOURCES = new HashSet<>();
        IMPORTANT_RESOURCES.add(binaryName(DuplicateIdLanguage.class.getName()) + ".class");
        IMPORTANT_RESOURCES.add(binaryName(LanguageCacheTestDuplicateIdLanguageProvider.class.getName()) + ".class");
    }

    /** Maps a resource path to the URLs under which this loader exposes it. */
    private final Function<String, List<URL>> loader;

    /**
     * @param loader function resolving a resource path to zero or more URLs
     */
    TestClassLoader(Function<String, List<URL>> loader) {
        super(TestClassLoader.class.getClassLoader());
        this.loader = loader;
    }

    /**
     * Loads classes listed in {@link #IMPORTANT_RESOURCES} directly through {@link #findClass}
     * without asking the parent first; all other classes use the inherited delegation.
     */
    @Override
    protected Class<?> loadClass(String name, boolean resolve) throws ClassNotFoundException {
        if (!IMPORTANT_RESOURCES.contains(binaryName(name) + ".class")) {
            return super.loadClass(name, resolve);
        }
        synchronized (getClassLoadingLock(name)) {
            Class<?> c = findLoadedClass(name);
            if (c == null) {
                c = findClass(name);
            }
            if (resolve) {
                resolveClass(c);
            }
            return c;
        }
    }

    /**
     * Defines one of the {@link #IMPORTANT_RESOURCES} classes from the bytes behind the URL
     * returned by {@link #findResource}.
     *
     * @throws ClassNotFoundException if the class is not one of the important resources, if no
     *             URL is available for it, or if reading its bytes fails
     */
    @Override
    protected Class<?> findClass(String name) throws ClassNotFoundException {
        String filePath = binaryName(name) + ".class";
        if (!IMPORTANT_RESOURCES.contains(filePath)) {
            // This is reached through super.loadClass() whenever the parent cannot find a class.
            // Signal that as ClassNotFoundException, as the ClassLoader contract requires, so
            // callers probing for optional classes see the exception type they expect.
            throw new ClassNotFoundException("Only " + String.join(", ", IMPORTANT_RESOURCES) + " can be loaded. Requested: " + name);
        }
        URL location = findResource(filePath);
        if (location == null) {
            throw new ClassNotFoundException("Cannot load class: " + name);
        }
        byte[] content;
        try (InputStream in = location.openStream(); ByteArrayOutputStream out = new ByteArrayOutputStream()) {
            copy(in, out);
            content = out.toByteArray();
        } catch (IOException e) {
            throw new ClassNotFoundException("Cannot load class: " + name, e);
        }
        definePackage(name);
        return defineClass(name, content, 0, content.length);
    }

    /**
     * For important resources returns this loader's URL, falling back to the parent when there is
     * none; other resources use the inherited lookup.
     */
    @Override
    public URL getResource(String name) {
        if (!IMPORTANT_RESOURCES.contains(name)) {
            return super.getResource(name);
        }
        URL url = findResource(name);
        return url != null ? url : getParent().getResource(name);
    }

    /**
     * For important resources returns this loader's URLs followed by the parent's; other
     * resources use the inherited lookup.
     */
    @Override
    public Enumeration<URL> getResources(String name) throws IOException {
        if (!IMPORTANT_RESOURCES.contains(name)) {
            return super.getResources(name);
        }
        Enumeration<URL> own = findResources(name);
        Enumeration<URL> inherited = getParent().getResources(name);
        List<URL> result = new ArrayList<>();
        addAll(result, own);
        addAll(result, inherited);
        return Collections.enumeration(result);
    }

    /**
     * Returns the first URL reported by {@link #findResources}, or {@code null} when there is
     * none or the lookup fails with an {@link IOException}.
     */
    @Override
    protected URL findResource(String name) {
        try {
            Enumeration<URL> e = findResources(name);
            return e.hasMoreElements() ? e.nextElement() : null;
        } catch (IOException ioe) {
            return null;
        }
    }

    /** Returns the URLs that the configured function produces for {@code name}. */
    @Override
    protected Enumeration<URL> findResources(String name) throws IOException {
        return Collections.enumeration(loader.apply(name));
    }

    /** Appends all remaining elements of {@code src} to {@code dest}. */
    private static <T> void addAll(Collection<? super T> dest, Enumeration<? extends T> src) {
        while (src.hasMoreElements()) {
            dest.add(src.nextElement());
        }
    }

    /**
     * Defines the package of {@code className} with no manifest attributes, unless
     * {@code getPackage} already reports it. The suppression covers the deprecated
     * {@code getPackage} call.
     */
    @SuppressWarnings("deprecation")
    private void definePackage(String className) {
        String packageName = getPackageName(className);
        if (getPackage(packageName) == null) {
            definePackage(packageName, null, null, null, null, null, null, null);
        }
    }

    /** Copies {@code in} to {@code out} until end of stream; neither stream is closed here. */
    private static void copy(InputStream in, OutputStream out) throws IOException {
        byte[] buffer = new byte[4096];
        while (true) {
            int read = in.read(buffer, 0, buffer.length);
            if (read == -1) {
                return;
            }
            out.write(buffer, 0, read);
        }
    }

    /** Returns the part of {@code className} before its last '.', or "" if there is no '.'. */
    private static String getPackageName(String className) {
        int lastDot = className.lastIndexOf('.');
        return lastDot == -1 ? "" : className.substring(0, lastDot);
    }

    /** Replaces every '.' in {@code name} with '/'. */
    private static String binaryName(String name) {
        return name.replace(".", "/");
    }
}
/**
 * Resource resolver that looks entries up in a real zip/jar file but reports them under URLs of
 * the form {@code jar:<relocation><entry>}, where {@code relocation} is a caller-chosen prefix
 * ending in {@code "!/"}. The returned URLs carry their own stream handler, which reads the
 * entry bytes from the backing zip file. Closing the loader closes that zip file.
 */
private static final class NestedJarLoader implements Function<String, List<URL>>, Closeable {

    /** Archive that actually contains the entries. */
    private final ZipFile zipFile;
    /** URL prefix (ending in "!/") placed in front of every entry name. */
    private final String relocation;

    /**
     * @param delegate archive file to read entries from
     * @param relocation URL prefix for the produced URLs; must end with {@code "!/"}
     * @throws IllegalArgumentException if {@code relocation} does not end with {@code "!/"}
     * @throws IOException if the archive cannot be opened
     */
    private NestedJarLoader(Path delegate, String relocation) throws IOException {
        // Validate before opening the archive so a bad argument cannot leak an open ZipFile.
        if (!relocation.endsWith("!/")) {
            throw new IllegalArgumentException("Relocation must point into an archive file.");
        }
        this.zipFile = new ZipFile(delegate.toFile());
        this.relocation = relocation;
    }

    /**
     * Resolves a resource name to its relocated URL.
     *
     * @param binaryName resource path, optionally starting with a single '/'
     * @return a one-element list when the archive contains the entry, otherwise an empty list
     *         (including for an empty name)
     */
    @Override
    public List<URL> apply(String binaryName) {
        // The lookup uses the name without a leading '/'. The same normalized name goes into
        // the URL so that it never contains "!//".
        String entryName = binaryName.startsWith("/") ? binaryName.substring(1) : binaryName;
        if (entryName.isEmpty()) {
            return Collections.emptyList();
        }
        ZipEntry entry = zipFile.getEntry(entryName);
        if (entry == null) {
            return Collections.emptyList();
        }
        try {
            URL url = new URL("jar", null, -1, relocation + entryName, new NestedJarURLStreamHandler(zipFile, entry));
            return Collections.singletonList(url);
        } catch (MalformedURLException murl) {
            throw new RuntimeException(murl);
        }
    }

    /** Closes the backing archive. */
    @Override
    public void close() throws IOException {
        zipFile.close();
    }

    /**
     * Stream handler bound to one zip entry: every connection it opens reads that entry from the
     * backing archive, regardless of what the URL text says.
     */
    private static final class NestedJarURLStreamHandler extends URLStreamHandler {

        private final ZipFile zipFile;
        private final ZipEntry entry;

        NestedJarURLStreamHandler(ZipFile zipFile, ZipEntry entry) {
            this.zipFile = zipFile;
            this.entry = entry;
        }

        @Override
        protected URLConnection openConnection(URL u) throws IOException {
            return new JarURLConnection(u) {

                /** Not available: the URL does not name a jar file that can be opened. */
                @Override
                public JarFile getJarFile() throws IOException {
                    throw new UnsupportedOperationException("Not supported.");
                }

                /** Returns the URL text up to, but excluding, its last {@code "!/"}. */
                @Override
                public URL getJarFileURL() {
                    try {
                        String surl = u.toString();
                        int index = surl.lastIndexOf("!/");
                        return new URL(surl.substring(0, index));
                    } catch (MalformedURLException mue) {
                        throw new IllegalArgumentException(mue);
                    }
                }

                /** Opens the entry captured by the enclosing handler. */
                @Override
                public InputStream getInputStream() throws IOException {
                    return zipFile.getInputStream(entry);
                }

                /** Nothing to connect to; the data comes from the already open archive. */
                @Override
                public void connect() throws IOException {
                }
            };
        }
    }
}
}