package jdk.nashorn.internal.codegen;
import static jdk.nashorn.internal.ir.Node.NO_FINISH;
import static jdk.nashorn.internal.ir.Node.NO_LINE_NUMBER;
import static jdk.nashorn.internal.ir.Node.NO_TOKEN;
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Deque;
import java.util.List;
import java.util.Objects;
import jdk.nashorn.internal.ir.AccessNode;
import jdk.nashorn.internal.ir.BinaryNode;
import jdk.nashorn.internal.ir.Block;
import jdk.nashorn.internal.ir.BlockLexicalContext;
import jdk.nashorn.internal.ir.BreakNode;
import jdk.nashorn.internal.ir.CallNode;
import jdk.nashorn.internal.ir.CaseNode;
import jdk.nashorn.internal.ir.ContinueNode;
import jdk.nashorn.internal.ir.Expression;
import jdk.nashorn.internal.ir.ExpressionStatement;
import jdk.nashorn.internal.ir.FunctionNode;
import jdk.nashorn.internal.ir.GetSplitState;
import jdk.nashorn.internal.ir.IdentNode;
import jdk.nashorn.internal.ir.IfNode;
import jdk.nashorn.internal.ir.JumpStatement;
import jdk.nashorn.internal.ir.JumpToInlinedFinally;
import jdk.nashorn.internal.ir.LiteralNode;
import jdk.nashorn.internal.ir.Node;
import jdk.nashorn.internal.ir.ReturnNode;
import jdk.nashorn.internal.ir.SetSplitState;
import jdk.nashorn.internal.ir.SplitNode;
import jdk.nashorn.internal.ir.SplitReturn;
import jdk.nashorn.internal.ir.Statement;
import jdk.nashorn.internal.ir.SwitchNode;
import jdk.nashorn.internal.ir.VarNode;
import jdk.nashorn.internal.ir.visitor.NodeVisitor;
import jdk.nashorn.internal.parser.Token;
import jdk.nashorn.internal.parser.TokenType;
/**
 * Node visitor that removes every {@link SplitNode} from the tree. Each split node's body becomes the body of a
 * synthesized anonymous function, and the split node is replaced by a call to that function. The call is
 * followed by code that re-issues any {@code return}, {@code break}, {@code continue} or jump-to-inlined-finally
 * that had to leave the split body.
 * <p>
 * Inside a split body, a jump whose target lies outside the split node works in two steps:
 * <ol>
 * <li>It is replaced by recording a numeric "split state" ({@link SetSplitState}) and returning from the
 * synthesized function.</li>
 * <li>After the call, the state is tested ({@link GetSplitState}) and the corresponding jump is emitted again in
 * the enclosing code.</li>
 * </ol>
 * Non-block-scoped {@code var} declarations inside split bodies are handled as follows:
 * <ul>
 * <li>Their initializer-free form is moved to the top of the enclosing function body.</li>
 * <li>Their initializers become plain assignment statements.</li>
 * </ul>
 */
final class SplitIntoFunctions extends NodeVisitor<BlockLexicalContext> {
// Split-state values. A synthesized split function records one with SetSplitState before returning, and it is
// tested with GetSplitState after the call.
// The split body ran to its end without an external jump (recorded in leaveBlock).
private static final int FALLTHROUGH_STATE = -1;
// A return statement was executed inside the split body.
private static final int RETURN_STATE = 0;
// An unlabelled break whose target is outside the split node was executed.
private static final int BREAK_STATE = 1;
// First state number for all other external jumps (labelled breaks, continues, jumps to inlined finally).
// Each distinct jump (same class + same label name) gets FIRST_JUMP_STATE + its index in
// SplitState.jumpStatements.
private static final int FIRST_JUMP_STATE = 2;
// Symbol names of the compiler's "this" and return-value constants.
private static final String THIS_NAME = CompilerConstants.THIS.symbolName();
private static final String RETURN_NAME = CompilerConstants.RETURN.symbolName();
// Name of the formal parameter through which a split fragment of a *program* receives the current return value.
private static final String RETURN_PARAM_NAME = RETURN_NAME + "-in";
// Per-function state; the innermost function being visited is on top.
private final Deque<FunctionState> functionStates = new ArrayDeque<>();
// Per-split-node state; the innermost split node being visited is on top.
private final Deque<SplitState> splitStates = new ArrayDeque<>();
// Namespace handed to every synthesized split function; built from the script environment's namespace.
private final Namespace namespace;
// True only while enblockAndVisit visits its synthetic wrapper block, so that leaveBlock skips its
// function-body / split-body handling for that block.
private boolean artificialBlock = false;
// Encoded as the position of each synthesized function's FUNCTION token. It counts down from -2, so every split
// function gets a distinct negative value.
// NOTE(review): presumably negative so it cannot collide with real source positions; confirm why -1 is
// skipped (possibly reserved for the program).
private int nextFunctionId = -2;
/**
 * Creates the visitor.
 * <p>
 * Its lexical context asserts that no {@link SplitNode} is left in the statement list of any block whose
 * statements are set while this visitor runs. The check only fires when assertions are enabled.
 *
 * @param compiler supplies the script environment whose namespace seeds {@link #namespace}
 */
public SplitIntoFunctions(final Compiler compiler) {
super(new BlockLexicalContext() {
@Override
protected Block afterSetStatements(final Block block) {
for(final Statement stmt: block.getStatements()) {
assert !(stmt instanceof SplitNode);
}
return block;
}
});
namespace = new Namespace(compiler.getScriptEnvironment().getNamespace());
}
/** Starts tracking split depth and hoisted vars for the function being entered. */
@Override
public boolean enterFunctionNode(final FunctionNode functionNode) {
functionStates.push(new FunctionState(functionNode));
return true;
}
/** Discards the state of the function being left; the function node itself is returned unchanged. */
@Override
public Node leaveFunctionNode(final FunctionNode functionNode) {
functionStates.pop();
return functionNode;
}
/**
 * Default handling for nodes without a specific {@code leaveXxx}.
 * <p>
 * Statements are appended through the lexical context, so they end up in the statement list being rebuilt
 * for the current block. Non-statement nodes are returned as-is.
 */
@Override
protected Node leaveDefault(final Node node) {
if (node instanceof Statement) {
appendStatement((Statement)node);
}
return node;
}
/** Records that we are now inside one more split node of the current function. */
@Override
public boolean enterSplitNode(final SplitNode splitNode) {
getCurrentFunctionState().splitDepth++;
splitStates.push(new SplitState(splitNode));
return true;
}
/**
 * Replaces a split node with a call to a synthesized function containing its (already visited) body.
 * <p>
 * The call has one of these forms:
 * <ul>
 * <li>{@code fn.call(this)} for a regular function.</li>
 * <li>{@code fn.call(this, <return>)} when the enclosing function is the program; the synthesized function
 * then has one parameter named {@link #RETURN_PARAM_NAME}.</li>
 * </ul>
 * The call's value is assigned to the return symbol when the split body contained a return, and always for
 * programs.
 * <p>
 * After the call, a handler is appended that dispatches on the split state:
 * <ul>
 * <li>return: emits {@code return <return>};</li>
 * <li>unlabelled break: emits {@code break};</li>
 * <li>every other recorded external jump: re-emits that jump through this visitor.</li>
 * </ul>
 * Because the split node's statements are appended via the lexical context, the node itself is returned but
 * does not survive in the rebuilt block.
 */
@Override
public Node leaveSplitNode(final SplitNode splitNode) {
final FunctionState fnState = getCurrentFunctionState();
final String name = splitNode.getName();
final Block body = splitNode.getBody();
final int firstLineNumber = body.getFirstStatementLineNumber();
final long token = body.getToken();
final int finish = body.getFinish();
final FunctionNode originalFn = fnState.fn;
assert originalFn == lc.getCurrentFunction();
final boolean isProgram = originalFn.isProgram();
// Build the synthesized function:
// - its FUNCTION token carries the next negative id;
// - its name is "<original function name>$<split name>";
// - its body is the split node's body;
// - it is anonymous, uses an ancestor scope and is flagged IS_SPLIT;
// - it is assigned to the split node's compile unit.
final long newFnToken = Token.toDesc(TokenType.FUNCTION, nextFunctionId--, 0);
final FunctionNode fn = new FunctionNode(
originalFn.getSource(),
body.getFirstStatementLineNumber(),
newFnToken,
finish,
newFnToken,
NO_TOKEN,
namespace,
createIdent(name),
originalFn.getName() + "$" + name,
isProgram ? Collections.singletonList(createReturnParamIdent()) : Collections.<IdentNode>emptyList(),
null,
FunctionNode.Kind.NORMAL,
FunctionNode.IS_ANONYMOUS | FunctionNode.USES_ANCESTOR_SCOPE | FunctionNode.IS_SPLIT,
body,
null,
originalFn.getModule(),
originalFn.getDebugFlags()
)
.setCompileUnit(lc, splitNode.getCompileUnit());
// Invoke it via "fn.call(this)", or "fn.call(this, <return>)" for a program, so the current
// return value is passed in.
final IdentNode thisIdent = createIdent(THIS_NAME);
final CallNode callNode = new CallNode(firstLineNumber, token, finish, new AccessNode(NO_TOKEN, NO_FINISH, fn, "call"),
isProgram ? Arrays.<Expression>asList(thisIdent, createReturnIdent())
: Collections.<Expression>singletonList(thisIdent),
false);
// Leave the split node now. Jumps re-emitted below via enblockAndVisit are therefore judged against the
// enclosing split node (if any), not this one.
final SplitState splitState = splitStates.pop();
fnState.splitDepth--;
final Expression callWithReturn;
final boolean hasReturn = splitState.hasReturn;
// A return inside a nested split must also leave the enclosing split, so the parent needs its own
// return handling.
if (hasReturn && fnState.splitDepth > 0) {
final SplitState parentSplit = splitStates.peek();
if (parentSplit != null) {
parentSplit.hasReturn = true;
}
}
// Capture the call's value into the return symbol when it may carry a return value (always for
// programs, whose fragments thread the return value through).
if (hasReturn || isProgram) {
callWithReturn = new BinaryNode(Token.recast(token, TokenType.ASSIGN), createReturnIdent(), callNode);
} else {
callWithReturn = callNode;
}
appendStatement(new ExpressionStatement(firstLineNumber, token, finish, callWithReturn));
Statement splitStateHandler;
final List<JumpStatement> jumpStatements = splitState.jumpStatements;
final int jumpCount = jumpStatements.size();
// With recorded external jumps, dispatch via "switch (split state)". The return case is folded in as
// "case RETURN_STATE" when present; each jump gets "case FIRST_JUMP_STATE + index".
if (jumpCount > 0) {
final List<CaseNode> cases = new ArrayList<>(jumpCount + (hasReturn ? 1 : 0));
if (hasReturn) {
addCase(cases, RETURN_STATE, createReturnFromSplit());
}
int i = FIRST_JUMP_STATE;
for (final JumpStatement jump: jumpStatements) {
addCase(cases, i++, enblockAndVisit(jump));
}
splitStateHandler = new SwitchNode(NO_LINE_NUMBER, token, finish, GetSplitState.INSTANCE, cases, null);
} else {
splitStateHandler = null;
}
// An unlabelled break placed inside the switch would only break out of the switch itself. It is
// therefore tested in a separate "if", with the switch (if any) as the else branch.
if (splitState.hasBreak) {
splitStateHandler = makeIfStateEquals(firstLineNumber, token, finish, BREAK_STATE,
enblockAndVisit(new BreakNode(NO_LINE_NUMBER, token, finish, null)), splitStateHandler);
}
// Without a switch there is no case for the return, so test for it with its own outermost "if".
if (hasReturn && jumpCount == 0) {
splitStateHandler = makeIfStateEquals(NO_LINE_NUMBER, token, finish, RETURN_STATE,
createReturnFromSplit(), splitStateHandler);
}
if (splitStateHandler != null) {
appendStatement(splitStateHandler);
}
return splitNode;
}
/** Appends {@code case i: body} to {@code cases}. */
private static void addCase(final List<CaseNode> cases, final int i, final Block body) {
cases.add(new CaseNode(NO_TOKEN, NO_FINISH, intLiteral(i), body));
}
/** Creates a positionless integer literal node. */
private static LiteralNode<Number> intLiteral(final int i) {
return LiteralNode.newInstance(NO_TOKEN, NO_FINISH, i);
}
/** Creates a block containing {@code return <return symbol>}. */
private static Block createReturnFromSplit() {
return new Block(NO_TOKEN, NO_FINISH, createReturnReturn());
}
/** Creates {@code return <return symbol>} with no source position. */
private static ReturnNode createReturnReturn() {
return new ReturnNode(NO_LINE_NUMBER, NO_TOKEN, NO_FINISH, createReturnIdent());
}
/** Creates an identifier for the return-value symbol ({@link #RETURN_NAME}). */
private static IdentNode createReturnIdent() {
return createIdent(RETURN_NAME);
}
/** Creates an identifier for the program fragment's return-value parameter ({@link #RETURN_PARAM_NAME}). */
private static IdentNode createReturnParamIdent() {
return createIdent(RETURN_PARAM_NAME);
}
/** Creates a positionless identifier node with the given name. */
private static IdentNode createIdent(final String name) {
return new IdentNode(NO_TOKEN, NO_FINISH, name);
}
/**
 * Wraps {@code jump} in a fresh block and visits that block with this visitor. The jump therefore goes
 * through leaveJumpNode again in the current (enclosing) context. If its target is still outside an enclosing
 * split node, it is converted into a split return once more.
 * <p>
 * {@link #artificialBlock} is set during the visit so that {@link #leaveBlock} does not treat the wrapper as a
 * function or split body. The flag is reset to {@code false} afterwards rather than restored.
 *
 * @return the visited wrapper block
 */
private Block enblockAndVisit(final JumpStatement jump) {
artificialBlock = true;
final Block block = (Block)new Block(NO_TOKEN, NO_FINISH, jump).accept(this);
artificialBlock = false;
return block;
}
/**
 * Builds {@code if (<split state> === value) pass else fail} (strict equality, EQ_STRICT). The else branch is
 * omitted when {@code fail} is {@code null}; otherwise {@code fail} is wrapped in a block.
 */
private static IfNode makeIfStateEquals(final int lineNumber, final long token, final int finish,
final int value, final Block pass, final Statement fail) {
return new IfNode(lineNumber, token, finish,
new BinaryNode(Token.recast(token, TokenType.EQ_STRICT),
GetSplitState.INSTANCE, intLiteral(value)),
pass,
fail == null ? null : new Block(NO_TOKEN, NO_FINISH, fail));
}
/**
 * Inside a split node, rewrites a non-block-scoped {@code var} declaration:
 * <ul>
 * <li>Its initializer-free copy is queued in the current function's {@code varStatements}; leaveBlock
 * prepends these to the function body.</li>
 * <li>If it had an initializer, an assignment expression statement {@code name = init} is visited in its
 * place.</li>
 * </ul>
 * Outside split nodes, and for block-scoped declarations, the default handling applies.
 *
 * @return {@code false} after a rewrite, so the var node's children are not visited again
 */
@Override
public boolean enterVarNode(final VarNode varNode) {
if (!inSplitNode() || varNode.isBlockScoped()) {
return super.enterVarNode(varNode);
}
final Expression init = varNode.getInit();
getCurrentFunctionState().varStatements.add(varNode.setInit(null));
if (init != null) {
final long token = Token.recast(varNode.getToken(), TokenType.ASSIGN);
new ExpressionStatement(varNode.getLineNumber(), token, varNode.getFinish(),
new BinaryNode(token, varNode.getName(), varNode.getInit())).accept(this);
}
return false;
}
/**
 * Finishes function bodies and split bodies. Synthetic blocks created by {@link #enblockAndVisit} are skipped.
 * <ul>
 * <li>Function body: the var declarations hoisted out of split nodes are prepended.</li>
 * <li>Split body: a fall-through split return is appended. For programs, the assignment
 * {@code <return> = <return-in>} is also prepended so the fragment starts from the passed-in return
 * value.</li>
 * </ul>
 */
@Override
public Node leaveBlock(final Block block) {
if (!artificialBlock) {
if (lc.isFunctionBody()) {
lc.prependStatements(getCurrentFunctionState().varStatements);
} else if (lc.isSplitBody()) {
appendSplitReturn(FALLTHROUGH_STATE, NO_LINE_NUMBER);
if (getCurrentFunctionState().fn.isProgram()) {
lc.prependStatement(new ExpressionStatement(NO_LINE_NUMBER, NO_TOKEN, NO_FINISH,
new BinaryNode(Token.toDesc(TokenType.ASSIGN, 0, 0), createReturnIdent(), createReturnParamIdent())));
}
}
}
return block;
}
/** See {@link #leaveJumpNode}. */
@Override
public Node leaveBreakNode(final BreakNode breakNode) {
return leaveJumpNode(breakNode);
}
/** See {@link #leaveJumpNode}. */
@Override
public Node leaveContinueNode(final ContinueNode continueNode) {
return leaveJumpNode(continueNode);
}
/** See {@link #leaveJumpNode}. */
@Override
public Node leaveJumpToInlinedFinally(final JumpToInlinedFinally jumpToInlinedFinally) {
return leaveJumpNode(jumpToInlinedFinally);
}
/**
 * Handles a jump statement.
 * <ul>
 * <li>Inside a split node, with a target outside that split node (per {@code lc.isExternalTarget}): the jump
 * is replaced by setting its split-state number and returning from the split function. The jump itself is
 * not appended.</li>
 * <li>Otherwise: the jump is appended unchanged.</li>
 * </ul>
 */
private JumpStatement leaveJumpNode(final JumpStatement jump) {
if (inSplitNode()) {
final SplitState splitState = getCurrentSplitState();
final SplitNode splitNode = splitState.splitNode;
if (lc.isExternalTarget(splitNode, jump.getTarget(lc))) {
appendSplitReturn(splitState.getSplitStateIndex(jump), jump.getLineNumber());
return jump;
}
}
appendStatement(jump);
return jump;
}
/**
 * Appends {@code SetSplitState(splitState)} followed by an exit from the split function. In a program the exit
 * is {@code return <return symbol>}; otherwise it is {@link SplitReturn#INSTANCE}.
 */
private void appendSplitReturn(final int splitState, final int lineNumber) {
appendStatement(new SetSplitState(splitState, lineNumber));
if (getCurrentFunctionState().fn.isProgram()) {
appendStatement(createReturnReturn());
} else {
appendStatement(SplitReturn.INSTANCE);
}
}
/**
 * Keeps the return statement.
 * <p>
 * Inside a split node it is preceded by {@code SetSplitState(RETURN_STATE)}, and the split is marked as
 * containing a return. leaveSplitNode then captures the call's value and re-issues the return after the call.
 */
@Override
public Node leaveReturnNode(final ReturnNode returnNode) {
if(inSplitNode()) {
appendStatement(new SetSplitState(RETURN_STATE, returnNode.getLineNumber()));
getCurrentSplitState().hasReturn = true;
}
appendStatement(returnNode);
return returnNode;
}
/** Appends a statement to the block currently being rebuilt by the lexical context. */
private void appendStatement(final Statement statement) {
lc.appendStatement(statement);
}
/** Whether the visitor is currently inside at least one split node of the current function. */
private boolean inSplitNode() {
return getCurrentFunctionState().splitDepth > 0;
}
/** State of the innermost function being visited (top of {@link #functionStates}). */
private FunctionState getCurrentFunctionState() {
return functionStates.peek();
}
/** State of the innermost split node being visited (top of {@link #splitStates}), or null if none. */
private SplitState getCurrentSplitState() {
return splitStates.peek();
}
/** Per-function bookkeeping while the function's body is visited. */
private static class FunctionState {
// The function being visited.
final FunctionNode fn;
// Initializer-free var declarations moved out of split bodies; prepended to the function body.
final List<Statement> varStatements = new ArrayList<>();
// Number of split nodes of this function currently entered.
int splitDepth;
FunctionState(final FunctionNode fn) {
this.fn = fn;
}
}
/** Per-split-node bookkeeping of the control transfers that leave the split body. */
private static class SplitState {
// The split node being visited.
final SplitNode splitNode;
// Set when the split body (or a nested split body) contains a return.
boolean hasReturn;
// Set when an unlabelled break leaves the split body.
boolean hasBreak;
// Distinct external jumps other than unlabelled breaks, in order of first appearance; the index of each
// determines its split-state number.
final List<JumpStatement> jumpStatements = new ArrayList<>();
/**
 * Returns the split-state number for an external jump.
 * <ul>
 * <li>Unlabelled breaks map to {@code BREAK_STATE} and set {@link #hasBreak}.</li>
 * <li>Other jumps reuse the number of an earlier jump with the same class and label name.</li>
 * <li>Otherwise the jump is recorded and gets a new number.</li>
 * </ul>
 */
int getSplitStateIndex(final JumpStatement jump) {
if (jump instanceof BreakNode && jump.getLabelName() == null) {
hasBreak = true;
return BREAK_STATE;
}
int i = 0;
for(final JumpStatement exJump: jumpStatements) {
if (jump.getClass() == exJump.getClass() && Objects.equals(jump.getLabelName(), exJump.getLabelName())) {
return i + FIRST_JUMP_STATE;
}
++i;
}
jumpStatements.add(jump);
return i + FIRST_JUMP_STATE;
}
SplitState(final SplitNode splitNode) {
this.splitNode = splitNode;
}
}
}