package com.sun.tools.javac.model;
import java.util.Collections;
import java.util.HashSet;
import java.util.LinkedHashSet;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;
import javax.lang.model.AnnotatedConstruct;
import javax.lang.model.SourceVersion;
import javax.lang.model.element.*;
import javax.lang.model.type.DeclaredType;
import javax.lang.model.util.Elements;
import javax.tools.JavaFileObject;
import static javax.lang.model.util.ElementFilter.methodsIn;
import com.sun.source.util.JavacTask;
import com.sun.tools.javac.api.JavacTaskImpl;
import com.sun.tools.javac.code.*;
import com.sun.tools.javac.code.Attribute.Compound;
import com.sun.tools.javac.code.Directive.ExportsDirective;
import com.sun.tools.javac.code.Directive.ExportsFlag;
import com.sun.tools.javac.code.Directive.OpensDirective;
import com.sun.tools.javac.code.Directive.OpensFlag;
import com.sun.tools.javac.code.Directive.RequiresDirective;
import com.sun.tools.javac.code.Directive.RequiresFlag;
import com.sun.tools.javac.code.Scope.WriteableScope;
import com.sun.tools.javac.code.Symbol.*;
import com.sun.tools.javac.comp.AttrContext;
import com.sun.tools.javac.comp.Enter;
import com.sun.tools.javac.comp.Env;
import com.sun.tools.javac.main.JavaCompiler;
import com.sun.tools.javac.processing.PrintingProcessor;
import com.sun.tools.javac.tree.JCTree;
import com.sun.tools.javac.tree.JCTree.*;
import com.sun.tools.javac.tree.TreeInfo;
import com.sun.tools.javac.tree.TreeScanner;
import com.sun.tools.javac.util.*;
import com.sun.tools.javac.util.DefinedBy.Api;
import com.sun.tools.javac.util.Name;
import static com.sun.tools.javac.code.Kinds.Kind.*;
import static com.sun.tools.javac.code.Scope.LookupKind.NON_RECURSIVE;
import static com.sun.tools.javac.code.TypeTag.CLASS;
import com.sun.tools.javac.comp.Modules;
import com.sun.tools.javac.comp.Resolve;
import com.sun.tools.javac.comp.Resolve.RecoveryLoadClass;
import com.sun.tools.javac.resources.CompilerProperties.Notes;
import static com.sun.tools.javac.tree.JCTree.Tag.*;
public class JavacElements implements Elements {
private final JavaCompiler javaCompiler;
private final Symtab syms;
private final Modules modules;
private final Names names;
private final Types types;
private final Enter enter;
private final Resolve resolve;
private final JavacTaskImpl javacTaskImpl;
private final Log log;
private final boolean allowModules;
public static JavacElements instance(Context context) {
JavacElements instance = context.get(JavacElements.class);
if (instance == null)
instance = new JavacElements(context);
return instance;
}
protected JavacElements(Context context) {
context.put(JavacElements.class, this);
javaCompiler = JavaCompiler.instance(context);
syms = Symtab.instance(context);
modules = Modules.instance(context);
names = Names.instance(context);
types = Types.instance(context);
enter = Enter.instance(context);
resolve = Resolve.instance(context);
JavacTask t = context.get(JavacTask.class);
javacTaskImpl = t instanceof JavacTaskImpl ? (JavacTaskImpl) t : null;
log = Log.instance(context);
Source source = Source.instance(context);
allowModules = source.allowModules();
}
@Override @DefinedBy(Api.LANGUAGE_MODEL)
public Set<? extends ModuleElement> getAllModuleElements() {
if (allowModules)
return Collections.unmodifiableSet(modules.allModules());
else
return Collections.emptySet();
}
@Override @DefinedBy(Api.LANGUAGE_MODEL)
public ModuleSymbol getModuleElement(CharSequence name) {
ensureEntered("getModuleElement");
if (modules.getDefaultModule() == syms.noModule)
return null;
String strName = name.toString();
if (strName.equals(""))
return syms.unnamedModule;
return modules.getObservableModule(names.fromString(strName));
}
@Override @DefinedBy(Api.LANGUAGE_MODEL)
public PackageSymbol getPackageElement(CharSequence name) {
return doGetPackageElement(null, name);
}
@Override @DefinedBy(Api.LANGUAGE_MODEL)
public PackageSymbol getPackageElement(ModuleElement module, CharSequence name) {
module.getClass();
return doGetPackageElement(module, name);
}
private PackageSymbol doGetPackageElement(ModuleElement module, CharSequence name) {
ensureEntered("getPackageElement");
return doGetElement(module, "getPackageElement", name, PackageSymbol.class);
}
@Override @DefinedBy(Api.LANGUAGE_MODEL)
public ClassSymbol getTypeElement(CharSequence name) {
return doGetTypeElement(null, name);
}
@Override @DefinedBy(Api.LANGUAGE_MODEL)
public ClassSymbol getTypeElement(ModuleElement module, CharSequence name) {
module.getClass();
return doGetTypeElement(module, name);
}
private ClassSymbol doGetTypeElement(ModuleElement module, CharSequence name) {
ensureEntered("getTypeElement");
return doGetElement(module, "getTypeElement", name, ClassSymbol.class);
}
private <S extends Symbol> S doGetElement(ModuleElement module, String methodName,
CharSequence name, Class<S> clazz) {
String strName = name.toString();
if (!SourceVersion.isName(strName) && (!strName.isEmpty() || clazz == ClassSymbol.class)) {
return null;
}
if (module == null) {
return unboundNameToSymbol(methodName, strName, clazz);
} else {
return nameToSymbol((ModuleSymbol) module, strName, clazz);
}
}
private final Set<String> alreadyWarnedDuplicates = new HashSet<>();
private <S extends Symbol> S unboundNameToSymbol(String methodName,
String nameStr,
Class<S> clazz) {
if (modules.getDefaultModule() == syms.noModule) {
return nameToSymbol(syms.noModule, nameStr, clazz);
}
Set<S> found = new LinkedHashSet<>();
for (ModuleSymbol msym : modules.allModules()) {
S sym = nameToSymbol(msym, nameStr, clazz);
if (sym != null) {
if (!allowModules || clazz == ClassSymbol.class || !sym.members().isEmpty()) {
found.add(sym);
}
}
}
if (found.size() == 1) {
return found.iterator().next();
} else if (found.size() > 1) {
if (alreadyWarnedDuplicates.add(methodName + ":" + nameStr)) {
String moduleNames = found.stream()
.map(s -> s.packge().modle)
.map(m -> m.toString())
.collect(Collectors.joining(", "));
log.note(Notes.MultipleElements(methodName, nameStr, moduleNames));
}
return null;
} else {
return null;
}
}
private <S extends Symbol> S nameToSymbol(ModuleSymbol module, String nameStr, Class<S> clazz) {
Name name = names.fromString(nameStr);
Symbol sym = (clazz == ClassSymbol.class)
? syms.getClass(module, name)
: syms.lookupPackage(module, name);
try {
if (sym == null)
sym = javaCompiler.resolveIdent(module, nameStr);
sym.complete();
return (sym.kind != ERR &&
sym.exists() &&
clazz.isInstance(sym) &&
name.equals(sym.getQualifiedName()))
? clazz.cast(sym)
: null;
} catch (CompletionFailure e) {
return null;
}
}
private JCTree matchAnnoToTree(AnnotationMirror findme,
Element e, JCTree tree) {
Symbol sym = cast(Symbol.class, e);
class Vis extends JCTree.Visitor {
List<JCAnnotation> result = null;
public void visitPackageDef(JCPackageDecl tree) {
result = tree.annotations;
}
public void visitClassDef(JCClassDecl tree) {
result = tree.mods.annotations;
}
public void visitMethodDef(JCMethodDecl tree) {
result = tree.mods.annotations;
}
public void visitVarDef(JCVariableDecl tree) {
result = tree.mods.annotations;
}
@Override
public void visitTypeParameter(JCTypeParameter tree) {
result = tree.annotations;
}
}
Vis vis = new Vis();
tree.accept(vis);
if (vis.result == null)
return null;
List<Attribute.Compound> annos = sym.getAnnotationMirrors();
return matchAnnoToTree(cast(Attribute.Compound.class, findme),
annos,
vis.result);
}
private JCTree matchAnnoToTree(Attribute.Compound findme,
List<Attribute.Compound> annos,
List<JCAnnotation> trees) {
for (Attribute.Compound anno : annos) {
for (JCAnnotation tree : trees) {
if (tree.type.tsym != anno.type.tsym)
continue;
JCTree match = matchAttributeToTree(findme, anno, tree);
if (match != null)
return match;
}
}
return null;
}
private JCTree matchAttributeToTree(final Attribute findme,
final Attribute attr,
final JCTree tree) {
if (attr == findme)
return tree;
class Vis implements Attribute.Visitor {
JCTree result = null;
public void visitConstant(Attribute.Constant value) {
}
public void visitClass(Attribute.Class clazz) {
}
public void visitCompound(Attribute.Compound anno) {
for (Pair<MethodSymbol, Attribute> pair : anno.values) {
JCExpression expr = scanForAssign(pair.fst, tree);
if (expr != null) {
JCTree match = matchAttributeToTree(findme, pair.snd, expr);
if (match != null) {
result = match;
return;
}
}
}
}
public void visitArray(Attribute.Array array) {
if (tree.hasTag(NEWARRAY)) {
List<JCExpression> elems = ((JCNewArray)tree).elems;
for (Attribute value : array.values) {
JCTree match = matchAttributeToTree(findme, value, elems.head);
if (match != null) {
result = match;
return;
}
elems = elems.tail;
}
} else if (array.values.length == 1) {
result = matchAttributeToTree(findme, array.values[0], tree);
}
}
public void visitEnum(Attribute.Enum e) {
}
public void visitError(Attribute.Error e) {
}
}
Vis vis = new Vis();
attr.accept(vis);
return vis.result;
}
private JCExpression scanForAssign(final MethodSymbol sym,
final JCTree tree) {
class TS extends TreeScanner {
JCExpression result = null;
public void scan(JCTree t) {
if (t != null && result == null)
t.accept(this);
}
public void visitAnnotation(JCAnnotation t) {
if (t == tree)
scan(t.args);
}
public void visitAssign(JCAssign t) {
if (t.lhs.hasTag(IDENT)) {
JCIdent ident = (JCIdent) t.lhs;
if (ident.sym == sym)
result = t.rhs;
}
}
}
TS scanner = new TS();
tree.accept(scanner);
return scanner.result;
}
public JCTree getTree(Element e) {
Pair<JCTree, ?> treeTop = getTreeAndTopLevel(e);
return (treeTop != null) ? treeTop.fst : null;
}
@DefinedBy(Api.LANGUAGE_MODEL)
public String (Element e) {
Pair<JCTree, JCCompilationUnit> treeTop = getTreeAndTopLevel(e);
if (treeTop == null)
return null;
JCTree tree = treeTop.fst;
JCCompilationUnit toplevel = treeTop.snd;
if (toplevel.docComments == null)
return null;
return toplevel.docComments.getCommentText(tree);
}
@DefinedBy(Api.LANGUAGE_MODEL)
public PackageElement getPackageOf(Element e) {
return cast(Symbol.class, e).packge();
}
@DefinedBy(Api.LANGUAGE_MODEL)
public ModuleElement getModuleOf(Element e) {
Symbol sym = cast(Symbol.class, e);
if (modules.getDefaultModule() == syms.noModule)
return null;
return (sym.kind == MDL) ? ((ModuleElement) e) : sym.packge().modle;
}
@DefinedBy(Api.LANGUAGE_MODEL)
public boolean isDeprecated(Element e) {
Symbol sym = cast(Symbol.class, e);
sym.complete();
return sym.isDeprecated();
}
@Override @DefinedBy(Api.LANGUAGE_MODEL)
public Origin getOrigin(Element e) {
Symbol sym = cast(Symbol.class, e);
if ((sym.flags() & Flags.GENERATEDCONSTR) != 0)
return Origin.MANDATED;
return Origin.EXPLICIT;
}
@Override @DefinedBy(Api.LANGUAGE_MODEL)
public Origin getOrigin(AnnotatedConstruct c, AnnotationMirror a) {
Compound ac = cast(Compound.class, a);
if (ac.isSynthesized())
return Origin.MANDATED;
return Origin.EXPLICIT;
}
@Override @DefinedBy(Api.LANGUAGE_MODEL)
public Origin getOrigin(ModuleElement m, ModuleElement.Directive directive) {
switch (directive.getKind()) {
case REQUIRES:
RequiresDirective rd = cast(RequiresDirective.class, directive);
if (rd.flags.contains(RequiresFlag.MANDATED))
return Origin.MANDATED;
if (rd.flags.contains(RequiresFlag.SYNTHETIC))
return Origin.SYNTHETIC;
return Origin.EXPLICIT;
case EXPORTS:
ExportsDirective ed = cast(ExportsDirective.class, directive);
if (ed.flags.contains(ExportsFlag.MANDATED))
return Origin.MANDATED;
if (ed.flags.contains(ExportsFlag.SYNTHETIC))
return Origin.SYNTHETIC;
return Origin.EXPLICIT;
case OPENS:
OpensDirective od = cast(OpensDirective.class, directive);
if (od.flags.contains(OpensFlag.MANDATED))
return Origin.MANDATED;
if (od.flags.contains(OpensFlag.SYNTHETIC))
return Origin.SYNTHETIC;
return Origin.EXPLICIT;
}
return Origin.EXPLICIT;
}
@DefinedBy(Api.LANGUAGE_MODEL)
public Name getBinaryName(TypeElement type) {
return cast(TypeSymbol.class, type).flatName();
}
@DefinedBy(Api.LANGUAGE_MODEL)
public Map<MethodSymbol, Attribute> getElementValuesWithDefaults(
AnnotationMirror a) {
Attribute.Compound anno = cast(Attribute.Compound.class, a);
DeclaredType annotype = a.getAnnotationType();
Map<MethodSymbol, Attribute> valmap = anno.getElementValues();
for (ExecutableElement ex :
methodsIn(annotype.asElement().getEnclosedElements())) {
MethodSymbol meth = (MethodSymbol) ex;
Attribute defaultValue = meth.getDefaultValue();
if (defaultValue != null && !valmap.containsKey(meth)) {
valmap.put(meth, defaultValue);
}
}
return valmap;
}
@DefinedBy(Api.LANGUAGE_MODEL)
public FilteredMemberList getAllMembers(TypeElement element) {
Symbol sym = cast(Symbol.class, element);
WriteableScope scope = sym.members().dupUnshared();
List<Type> closure = types.closure(sym.asType());
for (Type t : closure)
addMembers(scope, t);
return new FilteredMemberList(scope);
}
private void addMembers(WriteableScope scope, Type type) {
members:
for (Symbol e : type.asElement().members().getSymbols(NON_RECURSIVE)) {
for (Symbol overrider : scope.getSymbolsByName(e.getSimpleName())) {
if (overrider.kind == e.kind && (overrider.flags() & Flags.SYNTHETIC) == 0) {
if (overrider.getKind() == ElementKind.METHOD &&
overrides((ExecutableElement)overrider, (ExecutableElement)e, (TypeElement)type.asElement())) {
continue members;
}
}
}
boolean derived = e.getEnclosingElement() != scope.owner;
ElementKind kind = e.getKind();
boolean initializer = kind == ElementKind.CONSTRUCTOR
|| kind == ElementKind.INSTANCE_INIT
|| kind == ElementKind.STATIC_INIT;
if (!derived || (!initializer && e.isInheritedIn(scope.owner, types)))
scope.enter(e);
}
}
@Override @DefinedBy(Api.LANGUAGE_MODEL)
public List<Attribute.Compound> getAllAnnotationMirrors(Element e) {
Symbol sym = cast(Symbol.class, e);
List<Attribute.Compound> annos = sym.getAnnotationMirrors();
while (sym.getKind() == ElementKind.CLASS) {
Type sup = ((ClassSymbol) sym).getSuperclass();
if (!sup.hasTag(CLASS) || sup.isErroneous() ||
sup.tsym == syms.objectType.tsym) {
break;
}
sym = sup.tsym;
List<Attribute.Compound> oldAnnos = annos;
List<Attribute.Compound> newAnnos = sym.getAnnotationMirrors();
for (Attribute.Compound anno : newAnnos) {
if (isInherited(anno.type) &&
!containsAnnoOfType(oldAnnos, anno.type)) {
annos = annos.prepend(anno);
}
}
}
return annos;
}
private boolean isInherited(Type annotype) {
return annotype.tsym.attribute(syms.inheritedType.tsym) != null;
}
private static boolean containsAnnoOfType(List<Attribute.Compound> annos,
Type type) {
for (Attribute.Compound anno : annos) {
if (anno.type.tsym == type.tsym)
return true;
}
return false;
}
@DefinedBy(Api.LANGUAGE_MODEL)
public boolean hides(Element hiderEl, Element hideeEl) {
Symbol hider = cast(Symbol.class, hiderEl);
Symbol hidee = cast(Symbol.class, hideeEl);
if (hider == hidee ||
hider.kind != hidee.kind ||
hider.name != hidee.name) {
return false;
}
if (hider.kind == MTH) {
if (!hider.isStatic() ||
!types.isSubSignature(hider.type, hidee.type)) {
return false;
}
}
ClassSymbol hiderClass = hider.owner.enclClass();
ClassSymbol hideeClass = hidee.owner.enclClass();
if (hiderClass == null || hideeClass == null ||
!hiderClass.isSubClass(hideeClass, types)) {
return false;
}
return hidee.isInheritedIn(hiderClass, types);
}
@DefinedBy(Api.LANGUAGE_MODEL)
public boolean overrides(ExecutableElement riderEl,
ExecutableElement rideeEl, TypeElement typeEl) {
MethodSymbol rider = cast(MethodSymbol.class, riderEl);
MethodSymbol ridee = cast(MethodSymbol.class, rideeEl);
ClassSymbol origin = cast(ClassSymbol.class, typeEl);
return rider.name == ridee.name &&
rider != ridee &&
!rider.isStatic() &&
ridee.isMemberOf(origin, types) &&
rider.overrides(ridee, origin, types, false);
}
@DefinedBy(Api.LANGUAGE_MODEL)
public String getConstantExpression(Object value) {
return Constants.format(value);
}
@DefinedBy(Api.LANGUAGE_MODEL)
public void printElements(java.io.Writer w, Element... elements) {
for (Element element : elements)
(new PrintingProcessor.PrintingElementVisitor(w, this)).visit(element).flush();
}
@DefinedBy(Api.LANGUAGE_MODEL)
public Name getName(CharSequence cs) {
return names.fromString(cs.toString());
}
@Override @DefinedBy(Api.LANGUAGE_MODEL)
public boolean isFunctionalInterface(TypeElement element) {
if (element.getKind() != ElementKind.INTERFACE)
return false;
else {
TypeSymbol tsym = cast(TypeSymbol.class, element);
return types.isFunctionalInterface(tsym);
}
}
private Pair<JCTree, JCCompilationUnit> getTreeAndTopLevel(Element e) {
Symbol sym = cast(Symbol.class, e);
Env<AttrContext> enterEnv = getEnterEnv(sym);
if (enterEnv == null)
return null;
JCTree tree = TreeInfo.declarationFor(sym, enterEnv.tree);
if (tree == null || enterEnv.toplevel == null)
return null;
return new Pair<>(tree, enterEnv.toplevel);
}
public Pair<JCTree, JCCompilationUnit> getTreeAndTopLevel(
Element e, AnnotationMirror a, AnnotationValue v) {
if (e == null)
return null;
Pair<JCTree, JCCompilationUnit> elemTreeTop = getTreeAndTopLevel(e);
if (elemTreeTop == null)
return null;
if (a == null)
return elemTreeTop;
JCTree annoTree = matchAnnoToTree(a, e, elemTreeTop.fst);
if (annoTree == null)
return elemTreeTop;
if (v == null)
return new Pair<>(annoTree, elemTreeTop.snd);
JCTree valueTree = matchAttributeToTree(
cast(Attribute.class, v), cast(Attribute.class, a), annoTree);
if (valueTree == null)
return new Pair<>(annoTree, elemTreeTop.snd);
return new Pair<>(valueTree, elemTreeTop.snd);
}
private Env<AttrContext> getEnterEnv(Symbol sym) {
TypeSymbol ts = null;
switch (sym.kind) {
case PCK:
ts = (PackageSymbol)sym;
break;
case MDL:
ts = (ModuleSymbol)sym;
break;
default:
ts = sym.enclClass();
}
return (ts != null)
? enter.getEnv(ts)
: null;
}
private void ensureEntered(String methodName) {
if (javacTaskImpl != null) {
javacTaskImpl.ensureEntered();
}
if (!javaCompiler.isEnterDone()) {
throw new IllegalStateException("Cannot use Elements." + methodName + " before the TaskEvent.Kind.ENTER finished event.");
}
}
private static <T> T cast(Class<T> clazz, Object o) {
if (! clazz.isInstance(o))
throw new IllegalArgumentException(o.toString());
return clazz.cast(o);
}
}