package com.oracle.truffle.js.nodes.control;
import java.util.Set;
import com.oracle.truffle.api.CompilerAsserts;
import com.oracle.truffle.api.CompilerDirectives.CompilationFinal;
import com.oracle.truffle.api.frame.VirtualFrame;
import com.oracle.truffle.api.instrumentation.InstrumentableNode;
import com.oracle.truffle.api.instrumentation.Tag;
import com.oracle.truffle.api.nodes.ExplodeLoop;
import com.oracle.truffle.api.nodes.NodeInfo;
import com.oracle.truffle.api.profiles.ConditionProfile;
import com.oracle.truffle.js.nodes.JSNodeUtil;
import com.oracle.truffle.js.nodes.JavaScriptNode;
import com.oracle.truffle.js.nodes.instrumentation.JSTaggedExecutionNode;
import com.oracle.truffle.js.nodes.instrumentation.JSTags;
import com.oracle.truffle.js.nodes.instrumentation.JSTags.ControlFlowBlockTag;
import com.oracle.truffle.js.nodes.instrumentation.JSTags.ControlFlowBranchTag;
import com.oracle.truffle.js.nodes.instrumentation.JSTags.ControlFlowRootTag;
@NodeInfo(shortName = "switch")
public final class SwitchNode extends StatementNode {

    /** Case test expressions, evaluated in array order until one of them yields {@code true}. */
    @Children private final JavaScriptNode[] caseExpressions;
    /**
     * Statements of all clauses, flattened into one array. Once a start index is selected,
     * every statement from that index to the end of the array is executed.
     */
    @Children private final JavaScriptNode[] statements;
    /**
     * Maps case index {@code i} to the index in {@link #statements} where execution starts if that
     * case matches. It has {@code caseExpressions.length + 1} entries; the last entry is the start
     * index used when no case expression matches (presumably the default clause, or
     * {@code statements.length} when there is none -- TODO confirm against the parser).
     */
    @CompilationFinal(dimensions = 1) private final int[] jumptable;
    /** One counting profile per case expression; only used by {@link #executeOrdered}. */
    @CompilationFinal(dimensions = 1) private final ConditionProfile[] conditionProfiles;
    /** {@code true} if {@link #jumptable} is non-decreasing, which enables {@link #executeOrdered}. */
    private final boolean ordered;

    private SwitchNode(JavaScriptNode[] caseExpressions, int[] jumptable, JavaScriptNode[] statements) {
        assert caseExpressions.length == jumptable.length - 1;
        this.caseExpressions = caseExpressions;
        this.statements = statements;
        this.jumptable = jumptable;
        this.ordered = isMonotonicallyIncreasing(jumptable);
        this.conditionProfiles = createConditionProfiles(caseExpressions.length);
    }

    /**
     * Returns {@code true} if no entry of {@code table} is greater than its successor, i.e. the table
     * is non-decreasing (equal neighbours, e.g. cases with empty bodies, are allowed).
     */
    private static boolean isMonotonicallyIncreasing(int[] table) {
        for (int i = 0; i < table.length - 1; i++) {
            int start = table[i];
            int end = table[i + 1];
            if (start > end) {
                return false;
            }
        }
        return true;
    }

    /** Creates {@code length} independent counting condition profiles. */
    private static ConditionProfile[] createConditionProfiles(int length) {
        ConditionProfile[] a = new ConditionProfile[length];
        for (int i = 0; i < length; i++) {
            a[i] = ConditionProfile.createCountingProfile();
        }
        return a;
    }

    /**
     * Creates a switch node.
     *
     * @param caseExpressions case test expressions, in evaluation order
     * @param jumptable statement start index per case, plus one trailing entry for the no-match
     *            start index; must have {@code caseExpressions.length + 1} entries
     * @param statements flattened statements of all clauses
     */
    public static SwitchNode create(JavaScriptNode[] caseExpressions, int[] jumptable, JavaScriptNode[] statements) {
        return new SwitchNode(caseExpressions, jumptable, statements);
    }

    /** A switch statement is reported as a control-flow root in addition to its inherited tags. */
    @Override
    public boolean hasTag(Class<? extends Tag> tag) {
        if (tag == ControlFlowRootTag.class) {
            return true;
        }
        return super.hasTag(tag);
    }

    /** Describes this node to instrumentation as a conditional control-flow root. */
    @Override
    public Object getNodeObject() {
        return JSTags.createNodeObjectDescriptor("type", ControlFlowRootTag.Type.Conditional.name());
    }

    /**
     * When {@link ControlFlowRootTag} is requested, wraps each case expression as a
     * {@link ControlFlowBranchTag} condition input and each statement as a {@link ControlFlowBlockTag}
     * block. Returns a new node if any child changed, otherwise {@code this}.
     */
    @Override
    public InstrumentableNode materializeInstrumentableNodes(Set<Class<? extends Tag>> materializedTags) {
        if (!materializedTags.contains(ControlFlowRootTag.class) || !needsMaterialization()) {
            return this;
        }
        boolean wasChanged = false;

        JavaScriptNode[] newCaseExpressions = new JavaScriptNode[caseExpressions.length];
        for (int i = 0; i < caseExpressions.length; i++) {
            InstrumentableNode materialized = caseExpressions[i].materializeInstrumentableNodes(materializedTags);
            newCaseExpressions[i] = JSTaggedExecutionNode.createForInput((JavaScriptNode) materialized, ControlFlowBranchTag.class,
                            JSTags.createNodeObjectDescriptor("type", ControlFlowBranchTag.Type.Condition.name()), materializedTags);
            if (newCaseExpressions[i] != caseExpressions[i]) {
                wasChanged = true;
            }
        }

        JavaScriptNode[] newStatements = new JavaScriptNode[statements.length];
        for (int i = 0; i < statements.length; i++) {
            InstrumentableNode materialized = statements[i].materializeInstrumentableNodes(materializedTags);
            newStatements[i] = JSTaggedExecutionNode.createFor((JavaScriptNode) materialized, ControlFlowBlockTag.class, materializedTags);
            if (newStatements[i] != statements[i]) {
                wasChanged = true;
            }
        }

        if (!wasChanged) {
            return this;
        }

        // Children that came back unchanged are still adopted by this node. A Truffle node has a
        // single parent, so hand the new node uninitialized clones rather than sharing them.
        for (int i = 0; i < caseExpressions.length; i++) {
            if (newCaseExpressions[i] == caseExpressions[i]) {
                newCaseExpressions[i] = cloneUninitialized(caseExpressions[i], materializedTags);
            }
        }
        for (int i = 0; i < statements.length; i++) {
            if (newStatements[i] == statements[i]) {
                newStatements[i] = cloneUninitialized(statements[i], materializedTags);
            }
        }
        SwitchNode materialized = SwitchNode.create(newCaseExpressions, jumptable, newStatements);
        transferSourceSectionAndTags(this, materialized);
        return materialized;
    }

    /** Returns {@code true} if at least one case expression or statement is not yet a tagged node. */
    private boolean needsMaterialization() {
        for (JavaScriptNode caseExpression : caseExpressions) {
            if (!JSNodeUtil.isTaggedNode(caseExpression)) {
                return true;
            }
        }
        for (JavaScriptNode statement : statements) {
            if (!JSNodeUtil.isTaggedNode(statement)) {
                return true;
            }
        }
        return false;
    }

    /**
     * Executes the switch and returns the value of the last statement executed, or {@code EMPTY}
     * if none ran. Uses the interleaved strategy when the jumptable is non-decreasing.
     */
    @Override
    public Object execute(VirtualFrame frame) {
        if (ordered) {
            return executeOrdered(frame);
        } else {
            return executeDefault(frame);
        }
    }

    /**
     * General strategy: first evaluate case expressions until one matches, then run all statements
     * from the selected start index to the end.
     */
    private Object executeDefault(VirtualFrame frame) {
        int statementStartIndex = identifyTargetCase(frame);
        return executeStatements(frame, statementStartIndex);
    }

    /**
     * Evaluates case expressions in order and returns the statement start index of the first one
     * that is {@code true}, or the trailing jumptable entry if none matches.
     */
    @ExplodeLoop
    private int identifyTargetCase(VirtualFrame frame) {
        int i;
        for (i = 0; i < caseExpressions.length; i++) {
            if (executeConditionAsBoolean(frame, caseExpressions[i])) {
                break;
            }
        }
        // Every break of the exploded loop, plus its normal exit, merges here, so i (and hence the
        // returned index) is NOT a partial-evaluation constant. Asserting that it is would make
        // compilation of this method fail; executeStatements compares against it at run time.
        return jumptable[i];
    }

    /**
     * Runs every statement whose index is {@code >= statementStartIndex}. The loop bound is constant
     * so it can be exploded; the start index is only compared at run time.
     */
    @ExplodeLoop
    private Object executeStatements(VirtualFrame frame, int statementStartIndex) {
        Object result = EMPTY;
        for (int statementIndex = 0; statementIndex < statements.length; statementIndex++) {
            if (statementIndex >= statementStartIndex) {
                result = statements[statementIndex].execute(frame);
            }
        }
        return result;
    }

    /**
     * Interleaved strategy for a non-decreasing jumptable. Here case {@code i}'s body is the
     * contiguous range {@code [jumptable[i], jumptable[i + 1])}. Once a case matches,
     * {@code caseFound} stays {@code true}: later case expressions are skipped and later bodies run
     * (fall-through). Statements from the trailing jumptable entry onward always run.
     */
    @ExplodeLoop
    private Object executeOrdered(VirtualFrame frame) {
        final JavaScriptNode[] caseExpressionsLocal = caseExpressions;
        final JavaScriptNode[] statementsLocal = statements;
        final int[] jumptableLocal = jumptable;
        final ConditionProfile[] conditionProfilesLocal = conditionProfiles;
        boolean caseFound = false;
        Object result = EMPTY;
        int jumptableIdx;
        for (jumptableIdx = 0; jumptableIdx < caseExpressionsLocal.length; jumptableIdx++) {
            // Short-circuit: once a case has matched, remaining case expressions are not evaluated.
            if (caseFound || executeConditionAsBoolean(frame, caseExpressionsLocal[jumptableIdx])) {
                caseFound = true;
            }
            // jumptableIdx is constant per unrolled iteration, so these bounds are PE constants.
            int statementStartIndex = jumptableLocal[jumptableIdx];
            int statementEndIndex = jumptableLocal[jumptableIdx + 1];
            CompilerAsserts.partialEvaluationConstant(statementStartIndex);
            CompilerAsserts.partialEvaluationConstant(statementEndIndex);
            if (statementStartIndex != statementEndIndex) {
                if (conditionProfilesLocal[jumptableIdx].profile(caseFound)) {
                    for (int statementIndex = statementStartIndex; statementIndex < statementEndIndex; statementIndex++) {
                        result = statementsLocal[statementIndex].execute(frame);
                    }
                }
            }
        }
        // The loop above has no break, so jumptableIdx == caseExpressionsLocal.length here: constant.
        int statementStartIndex = jumptableLocal[jumptableIdx];
        CompilerAsserts.partialEvaluationConstant(statementStartIndex);
        for (int statementIndex = statementStartIndex; statementIndex < statementsLocal.length; statementIndex++) {
            result = statementsLocal[statementIndex].execute(frame);
        }
        return result;
    }

    /** Copies this node with uninitialized clones of all children; the jumptable is shared as-is. */
    @Override
    protected JavaScriptNode copyUninitialized(Set<Class<? extends Tag>> materializedTags) {
        return create(cloneUninitialized(caseExpressions, materializedTags), jumptable, cloneUninitialized(statements, materializedTags));
    }
}