package org.graalvm.compiler.replacements.verifier;
import java.io.IOException;
import java.io.PrintWriter;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.function.Function;
import javax.annotation.processing.ProcessingEnvironment;
import javax.lang.model.element.Element;
import javax.lang.model.element.ElementKind;
import javax.lang.model.element.PackageElement;
import javax.lang.model.element.TypeElement;
import javax.lang.model.element.VariableElement;
import javax.lang.model.type.ArrayType;
import javax.lang.model.type.DeclaredType;
import javax.lang.model.type.IntersectionType;
import javax.lang.model.type.TypeMirror;
import javax.lang.model.type.TypeVariable;
import javax.lang.model.type.WildcardType;
import javax.tools.Diagnostic;
import javax.tools.JavaFileObject;
/**
 * Collects {@link GeneratedPlugin}s and emits one {@code PluginFactory_<TopLevelClass>} source
 * file per top-level class that (transitively) declares the plugins' intrinsic methods. Each
 * generated factory implements {@code NodeIntrinsicPluginFactory} and is annotated with
 * {@code @ServiceProvider(NodeIntrinsicPluginFactory.class)}.
 */
public class PluginGenerator {

    /** Plugins grouped by the top-level class enclosing their intrinsic method. */
    private final Map<Element, List<GeneratedPlugin>> plugins;

    public PluginGenerator() {
        this.plugins = new HashMap<>();
    }

    /**
     * Registers {@code plugin} for generation, grouping it under the top-level class that encloses
     * its intrinsic method.
     */
    public void addPlugin(GeneratedPlugin plugin) {
        Element topLevel = getTopLevelClass(plugin.intrinsicMethod);
        plugins.computeIfAbsent(topLevel, k -> new ArrayList<>()).add(plugin);
    }

    /**
     * Writes one plugin factory source file per collected top-level class. Plugin names within a
     * group are first made unique, because every plugin of a group is emitted into the same
     * generated factory class.
     */
    public void generateAll(ProcessingEnvironment env) {
        for (Entry<Element, List<GeneratedPlugin>> entry : plugins.entrySet()) {
            disambiguateNames(entry.getValue());
            createPluginFactory(env, entry.getKey(), entry.getValue());
        }
    }

    /**
     * Walks outwards from {@code element} and returns the outermost enclosing element that is not a
     * package, i.e. the element whose enclosing element is a package (or {@code null}).
     */
    private static Element getTopLevelClass(Element element) {
        Element prev = element;
        Element enclosing = element.getEnclosingElement();
        while (enclosing != null && enclosing.getKind() != ElementKind.PACKAGE) {
            prev = enclosing;
            enclosing = enclosing.getEnclosingElement();
        }
        return prev;
    }

    /**
     * Renames every plugin whose name collides with another plugin's name, deriving the new name
     * with {@code genName}. Plugins whose name is already unique are left unchanged. As a side
     * effect {@code plugins} is sorted by (original) plugin name so that equal names are adjacent.
     *
     * @param plugins the plugins to disambiguate; may be empty
     * @param genName computes the replacement name for a colliding plugin
     */
    private static void disambiguateWith(List<GeneratedPlugin> plugins, Function<GeneratedPlugin, String> genName) {
        if (plugins.isEmpty()) {
            return;
        }
        plugins.sort(Comparator.comparing(GeneratedPlugin::getPluginName));
        GeneratedPlugin current = plugins.get(0);
        String currentName = current.getPluginName();
        for (int i = 1; i < plugins.size(); i++) {
            GeneratedPlugin next = plugins.get(i);
            if (currentName.equals(next.getPluginName())) {
                // The first plugin of a run of equal names is renamed exactly once; clearing
                // 'current' prevents renaming it again for later duplicates in the same run.
                if (current != null) {
                    current.setPluginName(genName.apply(current));
                    current = null;
                }
                next.setPluginName(genName.apply(next));
            } else {
                current = next;
                currentName = current.getPluginName();
            }
        }
    }

    /**
     * Appends a short, identifier-safe name for {@code type} to {@code ret}. Declared types map to
     * their simple name. Type variables, wildcards and intersection types map to their (first)
     * upper bound; a wildcard without an explicit upper bound maps to {@code Object}. Arrays map to
     * the component name followed by {@code "Array"}. Any other type (e.g. a primitive) maps to its
     * string form, with every character that is not legal in a Java identifier replaced by
     * {@code '_'}.
     */
    private static void appendSimpleTypeName(StringBuilder ret, TypeMirror type) {
        switch (type.getKind()) {
            case DECLARED:
                DeclaredType declared = (DeclaredType) type;
                TypeElement element = (TypeElement) declared.asElement();
                ret.append(element.getSimpleName());
                break;
            case TYPEVAR:
                appendSimpleTypeName(ret, ((TypeVariable) type).getUpperBound());
                break;
            case WILDCARD:
                TypeMirror extendsBound = ((WildcardType) type).getExtendsBound();
                if (extendsBound == null) {
                    // '?' and '? super X' have no explicit upper bound (getExtendsBound() returns
                    // null); their effective upper bound is Object.
                    ret.append("Object");
                } else {
                    appendSimpleTypeName(ret, extendsBound);
                }
                break;
            case INTERSECTION:
                // Arises from type variables such as <T extends A & B>. Use the first bound, which
                // is also what erasure uses, instead of the non-identifier string "A&B".
                appendSimpleTypeName(ret, ((IntersectionType) type).getBounds().get(0));
                break;
            case ARRAY:
                appendSimpleTypeName(ret, ((ArrayType) type).getComponentType());
                ret.append("Array");
                break;
            default:
                // The result becomes part of a generated Java identifier, so characters such as
                // '.', '<', '@' or spaces in the type's string form must not leak into it.
                type.toString().codePoints().forEach(cp -> ret.appendCodePoint(Character.isJavaIdentifierPart(cp) ? cp : '_'));
        }
    }

    /**
     * Makes the plugin names of one factory unique. Colliding names first get the simple names of
     * the intrinsic method's parameter types appended ({@code name_Type1_Type2}). Names that still
     * collide, e.g. overloads whose parameter types share simple names, then get a numeric suffix
     * from a counter that runs across the whole list.
     */
    private static void disambiguateNames(List<GeneratedPlugin> plugins) {
        disambiguateWith(plugins, plugin -> {
            StringBuilder ret = new StringBuilder(plugin.getPluginName());
            for (VariableElement param : plugin.intrinsicMethod.getParameters()) {
                ret.append('_');
                appendSimpleTypeName(ret, param.asType());
            }
            return ret.toString();
        });
        disambiguateWith(plugins, new Function<GeneratedPlugin, String>() {
            private int idx = 0;

            @Override
            public String apply(GeneratedPlugin plugin) {
                return plugin.getPluginName() + "_" + (idx++);
            }
        });
    }

    /**
     * Writes the {@code PluginFactory_<TopLevelClass>} source file for {@code plugins} into the
     * package of {@code topLevelClass}. I/O failures, including write errors that
     * {@link PrintWriter} records instead of throwing, are reported as errors through the
     * environment's {@link javax.annotation.processing.Messager} and attached to
     * {@code topLevelClass}; they are not thrown.
     */
    private static void createPluginFactory(ProcessingEnvironment env, Element topLevelClass, List<GeneratedPlugin> plugins) {
        PackageElement pkg = (PackageElement) topLevelClass.getEnclosingElement();
        String genClassName = "PluginFactory_" + topLevelClass.getSimpleName();
        // A class in the unnamed package gets neither a qualifying prefix nor a package declaration.
        String qualifiedName = pkg.isUnnamed() ? genClassName : pkg.getQualifiedName() + "." + genClassName;
        try {
            JavaFileObject factory = env.getFiler().createSourceFile(qualifiedName, topLevelClass);
            try (PrintWriter out = new PrintWriter(factory.openWriter())) {
                out.printf("// CheckStyle: stop header check\n");
                out.printf("// CheckStyle: stop line length check\n");
                out.printf("// GENERATED CONTENT - DO NOT EDIT\n");
                out.printf("// GENERATORS: %s, %s\n", VerifierAnnotationProcessor.class.getName(), PluginGenerator.class.getName());
                if (!pkg.isUnnamed()) {
                    out.printf("package %s;\n", pkg.getQualifiedName());
                    out.printf("\n");
                }
                createImports(out, plugins);
                out.printf("\n");
                out.printf("@ServiceProvider(NodeIntrinsicPluginFactory.class)\n");
                out.printf("public class %s implements NodeIntrinsicPluginFactory {\n", genClassName);
                for (GeneratedPlugin plugin : plugins) {
                    out.printf("\n");
                    plugin.generate(env, out);
                }
                out.printf("\n");
                createPluginFactoryMethod(out, plugins);
                out.printf("}\n");
                // PrintWriter never throws IOException from its print methods; it only sets an
                // error flag. checkError() flushes and reports that flag.
                if (out.checkError()) {
                    throw new IOException("error while writing " + qualifiedName);
                }
            }
        } catch (IOException e) {
            env.getMessager().printMessage(Diagnostic.Kind.ERROR, "Failed to generate " + qualifiedName + ": " + e, topLevelClass);
        }
    }

    /**
     * Writes the import statements of a generated factory: a fixed set of compiler imports,
     * followed by the extra imports requested by {@code plugins}. The extra imports are written in
     * sorted order so that the generated source does not depend on hash-set iteration order.
     */
    protected static void createImports(PrintWriter out, List<GeneratedPlugin> plugins) {
        out.printf("import jdk.vm.ci.meta.ResolvedJavaMethod;\n");
        out.printf("import org.graalvm.compiler.serviceprovider.ServiceProvider;\n");
        out.printf("\n");
        out.printf("import org.graalvm.compiler.nodes.ValueNode;\n");
        out.printf("import org.graalvm.compiler.nodes.graphbuilderconf.GraphBuilderContext;\n");
        out.printf("import org.graalvm.compiler.nodes.graphbuilderconf.GeneratedInvocationPlugin;\n");
        out.printf("import org.graalvm.compiler.nodes.graphbuilderconf.InvocationPlugin;\n");
        out.printf("import org.graalvm.compiler.nodes.graphbuilderconf.InvocationPlugins;\n");
        out.printf("import org.graalvm.compiler.nodes.graphbuilderconf.NodeIntrinsicPluginFactory;\n");
        HashSet<String> extra = new HashSet<>();
        for (GeneratedPlugin plugin : plugins) {
            plugin.extraImports(extra);
        }
        if (!extra.isEmpty()) {
            List<String> sortedExtra = new ArrayList<>(extra);
            sortedExtra.sort(Comparator.naturalOrder());
            out.printf("\n");
            for (String i : sortedExtra) {
                out.printf("import %s;\n", i);
            }
        }
    }

    /**
     * Writes the factory's {@code registerPlugins} override, which delegates to each plugin's
     * {@code register} code generator in list order.
     */
    private static void createPluginFactoryMethod(PrintWriter out, List<GeneratedPlugin> plugins) {
        out.printf(" @Override\n");
        out.printf(" public void registerPlugins(InvocationPlugins plugins, InjectionProvider injection) {\n");
        for (GeneratedPlugin plugin : plugins) {
            plugin.register(out);
        }
        out.printf(" }\n");
    }
}