package lombok.javac.handlers;
import static lombok.core.handlers.HandlerUtil.*;
import static lombok.javac.Javac.*;
import static lombok.javac.handlers.JavacHandlerUtil.*;
import org.mangosdk.spi.ProviderFor;
import com.sun.tools.javac.tree.JCTree.JCAnnotation;
import com.sun.tools.javac.tree.JCTree.JCBinary;
import com.sun.tools.javac.tree.JCTree.JCBlock;
import com.sun.tools.javac.tree.JCTree.JCExpression;
import com.sun.tools.javac.tree.JCTree.JCIdent;
import com.sun.tools.javac.tree.JCTree.JCIf;
import com.sun.tools.javac.tree.JCTree.JCLiteral;
import com.sun.tools.javac.tree.JCTree.JCMethodDecl;
import com.sun.tools.javac.tree.JCTree.JCParens;
import com.sun.tools.javac.tree.JCTree.JCStatement;
import com.sun.tools.javac.tree.JCTree.JCSynchronized;
import com.sun.tools.javac.tree.JCTree.JCThrow;
import com.sun.tools.javac.tree.JCTree.JCTry;
import com.sun.tools.javac.tree.JCTree.JCVariableDecl;
import com.sun.tools.javac.util.List;
import lombok.ConfigurationKeys;
import lombok.NonNull;
import lombok.core.AnnotationValues;
import lombok.core.HandlerPriority;
import lombok.core.AST.Kind;
import lombok.javac.JavacAnnotationHandler;
import lombok.javac.JavacNode;
import static lombok.javac.JavacTreeMaker.TypeTag.*;
import static lombok.javac.JavacTreeMaker.TreeTag.*;
/**
 * javac handler for {@code lombok.NonNull}.
 *
 * <p>On a field, the handler only warns when the field's type is primitive. On a method or
 * constructor parameter (annotated directly, or via a type-use annotation on the parameter's
 * type), it inserts a generated null-check statement near the top of the method body, unless
 * a null check for that parameter is already present there.
 */
@ProviderFor(JavacAnnotationHandler.class)
@HandlerPriority(value = 512)
public class HandleNonNull extends JavacAnnotationHandler<NonNull> {
    /** Warning reported whenever the annotation ends up on something primitive. */
    private static final String PRIMITIVE_WARNING = "@NonNull is meaningless on a primitive.";

    /**
     * Processes one {@code @NonNull} annotation.
     *
     * @param annotation the parsed annotation values (not used by this handler)
     * @param ast the javac tree of the annotation; generated nodes are marked as generated by it
     * @param annotationNode the lombok AST node wrapping the annotation
     */
    @Override public void handle(AnnotationValues<NonNull> annotation, JCAnnotation ast, JavacNode annotationNode) {
        handleFlagUsage(annotationNode, ConfigurationKeys.NON_NULL_FLAG_USAGE, "@NonNull");

        if (annotationNode.up().getKind() == Kind.FIELD) {
            // Fields get no generated code here; the only thing done is the primitive warning.
            try {
                if (isPrimitive(((JCVariableDecl) annotationNode.up().get()).vartype)) {
                    annotationNode.addWarning(PRIMITIVE_WARNING);
                }
            } catch (Exception ignored) {
                // Deliberately best-effort: failing to produce a warning must never fail the handler.
            }
            return;
        }

        JavacNode paramNode = findAnnotatedParameter(annotationNode);
        if (paramNode == null) return;

        // Only parameters owned by a method/constructor declaration are handled. An explicit
        // type test replaces the former catch-all around the cast and also covers a null tree.
        JavacNode ownerNode = paramNode.up();
        Object ownerTree = ownerNode == null ? null : ownerNode.get();
        if (!(ownerTree instanceof JCMethodDecl)) return;
        JCMethodDecl declaration = (JCMethodDecl) ownerTree;

        // No body (e.g. abstract method): there is nowhere to put a check.
        if (declaration.body == null) return;

        JCStatement nullCheck = recursiveSetGeneratedBy(generateNullCheck(annotationNode.getTreeMaker(), paramNode, annotationNode), ast, annotationNode.getContext());
        if (nullCheck == null) {
            // No check could be generated; this handler reports that case as a primitive parameter.
            annotationNode.addWarning(PRIMITIVE_WARNING);
            return;
        }

        List<JCStatement> statements = declaration.body.stats;
        if (alreadyChecksForNull(statements, paramNode.getName())) return;

        declaration.body.stats = withNullCheckInserted(statements, nullCheck);
        annotationNode.getAst().setChanged();
    }

    /** Returns the ARGUMENT node the annotation applies to, or {@code null} if it is not on a parameter. */
    private static JavacNode findAnnotatedParameter(JavacNode annotationNode) {
        JavacNode candidate;
        switch (annotationNode.up().getKind()) {
        case ARGUMENT:
            candidate = annotationNode.up();
            break;
        case TYPE_USE:
            // Two direct-parent hops: first to the annotated type node, then to whatever declares it.
            JavacNode typeNode = annotationNode.directUp();
            candidate = typeNode.directUp();
            break;
        default:
            return null;
        }
        return candidate.getKind() == Kind.ARGUMENT ? candidate : null;
    }

    /**
     * Returns whether the leading statements already contain a null check for {@code expectedName}.
     *
     * <p>Constructor calls are skipped. A {@code try} or {@code synchronized} statement redirects
     * the scan into its body (statements after it are not examined). The scan stops at the first
     * statement that is not a null check.
     */
    private boolean alreadyChecksForNull(List<JCStatement> statements, String expectedName) {
        // javac's List is a linked list; walking head/tail avoids repeated size()/get(i) traversals.
        List<JCStatement> remaining = statements;
        while (!remaining.isEmpty()) {
            JCStatement stat = remaining.head;
            remaining = remaining.tail;

            if (JavacHandlerUtil.isConstructorCall(stat)) continue;
            if (stat instanceof JCTry) {
                remaining = ((JCTry) stat).body.stats;
                continue;
            }
            if (stat instanceof JCSynchronized) {
                remaining = ((JCSynchronized) stat).body.stats;
                continue;
            }

            String checkedName = returnVarNameIfNullCheck(stat);
            if (checkedName == null) return false;
            if (checkedName.equals(expectedName)) return true;
        }
        return false;
    }

    /**
     * Returns a new statement list with {@code nullCheck} placed after any leading constructor
     * calls and generated null checks, and before everything else.
     */
    private List<JCStatement> withNullCheckInserted(List<JCStatement> statements, JCStatement nullCheck) {
        List<JCStatement> tail = statements;
        List<JCStatement> head = List.nil(); // leading statements, collected in reverse order
        for (JCStatement stat : statements) {
            boolean staysInFront = JavacHandlerUtil.isConstructorCall(stat)
                    || (JavacHandlerUtil.isGenerated(stat) && isNullCheck(stat));
            if (!staysInFront) break;
            tail = tail.tail;
            head = head.prepend(stat);
        }

        List<JCStatement> result = tail.prepend(nullCheck);
        // Prepending the reversed head restores the original order in front of the new check.
        for (JCStatement stat : head) result = result.prepend(stat);
        return result;
    }

    /**
     * Returns whether {@code stat} is recognised as a null check by
     * {@link #returnVarNameIfNullCheck(JCStatement)}.
     *
     * @param stat the statement to inspect
     * @return {@code true} if a checked variable name could be extracted
     */
    public boolean isNullCheck(JCStatement stat) {
        return returnVarNameIfNullCheck(stat) != null;
    }

    /**
     * Checks whether the statement has the form {@code if (x == null) throw ...;}, where the
     * throw may be wrapped in a block (only the block's first statement is inspected) and the
     * condition may be wrapped in parentheses.
     *
     * @param stat the statement to inspect
     * @return the name of the identifier compared against {@code null}, or {@code null} if the
     *         statement does not have this form
     */
    public String returnVarNameIfNullCheck(JCStatement stat) {
        if (!(stat instanceof JCIf)) return null;
        JCIf ifStat = (JCIf) stat;

        // The 'then' part must be a throw, either directly or as the first statement of a block.
        JCStatement then = ifStat.thenpart;
        if (then instanceof JCBlock) {
            List<JCStatement> stats = ((JCBlock) then).stats;
            if (stats.isEmpty()) return null;
            then = stats.head;
        }
        if (!(then instanceof JCThrow)) return null;

        // The condition must be '<identifier> == null' once surrounding parentheses are removed.
        JCExpression cond = ifStat.cond;
        while (cond instanceof JCParens) cond = ((JCParens) cond).expr;
        if (!(cond instanceof JCBinary)) return null;
        JCBinary bin = (JCBinary) cond;
        if (!CTC_EQUAL.equals(treeTag(bin))) return null;
        if (!(bin.lhs instanceof JCIdent)) return null;
        if (!(bin.rhs instanceof JCLiteral)) return null;
        if (!CTC_BOT.equals(typeTag(bin.rhs))) return null;
        return ((JCIdent) bin.lhs).name.toString();
    }
}