package org.graalvm.compiler.core.match;
import static org.graalvm.compiler.core.common.cfg.AbstractControlFlowGraph.dominates;
import static org.graalvm.compiler.core.common.cfg.AbstractControlFlowGraph.strictlyDominates;
import static org.graalvm.compiler.debug.DebugOptions.LogVerbose;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import jdk.internal.vm.compiler.collections.EconomicMap;
import jdk.internal.vm.compiler.collections.Equivalence;
import org.graalvm.compiler.core.common.cfg.BlockMap;
import org.graalvm.compiler.core.gen.NodeLIRBuilder;
import org.graalvm.compiler.core.match.MatchPattern.Result;
import org.graalvm.compiler.debug.CounterKey;
import org.graalvm.compiler.debug.DebugContext;
import org.graalvm.compiler.debug.GraalError;
import org.graalvm.compiler.graph.Node;
import org.graalvm.compiler.graph.NodeMap;
import org.graalvm.compiler.nodes.PhiNode;
import org.graalvm.compiler.nodes.StructuredGraph;
import org.graalvm.compiler.nodes.calc.FloatingNode;
import org.graalvm.compiler.nodes.cfg.Block;
import org.graalvm.compiler.nodes.virtual.VirtualObjectNode;
/**
 * State for one attempt to apply a {@link MatchStatement} to the graph rooted at a given node.
 * It records the nodes consumed by the match, the values captured under names, and the position
 * (block and index within that block's scheduled node list) at which the match result is emitted.
 */
public class MatchContext {
/** Incremented by {@code verifyInputs} when a match whose emit block differs from the root block validates. */
private static final CounterKey MatchContextSuccessDifferentBlocks = DebugContext.counter("MatchContextSuccessDifferentBlocks");
/** The node passed to the constructor as the root of the match. */
private final Node root;
/** The rule being applied; its pattern is attached to every failure {@link Result}. */
private final MatchStatement rule;
/** Schedule providing the node-to-block map and the per-block node lists used for all position checks. */
private final StructuredGraph.ScheduleResult schedule;
/** Captured values keyed by name; created lazily by {@code captureNamedValue}, so it may be {@code null}. */
private EconomicMap<String, NamedNode> namedNodes;
/**
 * Immutable pair of a node consumed by a match and the flag that was supplied when it was
 * consumed. Nodes whose flag is {@code false} are the ones {@code findEarlyPosition} requires to
 * share a block with no unconsumed side-effecting node scheduled between them.
 */
static final class ConsumedNode {
    /** The consumed node. */
    final Node node;
    /** Flag recorded at consumption time; see the class comment for how it is used. */
    final boolean ignoresSideEffects;

    /**
     * @param node the node being consumed
     * @param ignoresSideEffects flag to record alongside the node
     */
    ConsumedNode(Node node, boolean ignoresSideEffects) {
        this.ignoresSideEffects = ignoresSideEffects;
        this.node = node;
    }
}
/**
 * Insertion-ordered collection of {@link ConsumedNode} entries, looked up by node identity.
 *
 * The backing list is only allocated by the first {@link #add}. Every query method treats the
 * unallocated state as an empty collection, so calling {@link #contains}, {@link #find},
 * {@link #iterator} or {@link #toString} before anything was added is safe.
 */
static final class ConsumedNodes implements Iterable<ConsumedNode> {
    /** Backing list; {@code null} until the first {@link #add} (presumably to avoid allocating for matches that fail early). */
    private ArrayList<ConsumedNode> nodes;

    ConsumedNodes() {
        this.nodes = null;
    }

    /**
     * Appends a new entry for {@code node}, allocating the backing list on first use.
     *
     * @param node the node to record
     * @param ignoresSideEffects flag stored with the node
     */
    public void add(Node node, boolean ignoresSideEffects) {
        if (nodes == null) {
            nodes = new ArrayList<>(2);
        }
        nodes.add(new ConsumedNode(node, ignoresSideEffects));
    }

    /**
     * @return {@code true} if an entry for exactly this node instance (compared with {@code ==})
     *         has been added
     */
    public boolean contains(Node node) {
        return find(node) != null;
    }

    /**
     * @return the first entry whose node is this exact instance, or {@code null} if there is none
     */
    public ConsumedNode find(Node node) {
        for (ConsumedNode c : entries()) {
            if (c.node == node) {
                return c;
            }
        }
        return null;
    }

    /**
     * @return the consumed nodes formatted like {@link Arrays#toString(Object[])}; {@code "[]"}
     *         when nothing has been added
     */
    @Override
    public String toString() {
        List<ConsumedNode> entries = entries();
        Node[] arr = new Node[entries.size()];
        int i = 0;
        for (ConsumedNode c : entries) {
            arr[i++] = c.node;
        }
        return Arrays.toString(arr);
    }

    /**
     * @return an iterator over the entries in insertion order; empty when nothing has been added
     */
    @Override
    public Iterator<ConsumedNode> iterator() {
        return entries().iterator();
    }

    /** Null-safe view of the backing list: the list itself once allocated, otherwise an empty list. */
    private List<ConsumedNode> entries() {
        if (nodes == null) {
            return Collections.<ConsumedNode> emptyList();
        }
        return nodes;
    }
}
/** Nodes recorded by {@code consume}, in the order they were consumed. */
private ConsumedNodes consumed = new ConsumedNodes();
/** Block passed to the constructor as the block containing {@code root}. */
private Block rootBlock;
/** Index of {@code root} in the scheduled node list of {@code rootBlock} (asserted in the constructor). */
private int rootIndex;
/** Index in {@code emitBlock}'s node list where the match is emitted; set by {@code findEarlyPosition} and possibly advanced by {@code findLatePosition}. */
private int emitIndex;
/** Block in which the match is emitted; {@code null} until {@code findEarlyPosition} has run. */
private Block emitBlock;
/** Builder that is asked whether a node already has an operand ({@code consume}) and that receives the match results ({@code setResult}). */
private final NodeLIRBuilder builder;
/**
 * Immutable binding created by {@code captureNamedValue}: the node captured under a name together
 * with the node class that was supplied for that capture.
 */
private static class NamedNode {
    /** Node class supplied when the value was captured. */
    final Class<? extends Node> type;
    /** The captured node. */
    final Node value;

    /**
     * @param type node class associated with the capture
     * @param value the captured node
     */
    NamedNode(Class<? extends Node> type, Node value) {
        this.value = value;
        this.type = type;
    }
}
/**
 * Creates the context for matching {@code rule} at {@code node}.
 *
 * @param builder builder that will receive the match results
 * @param rule the rule being applied
 * @param index position of {@code node} in the scheduled node list of {@code rootBlock}
 * @param node root node of the match
 * @param rootBlock block in which {@code node} is scheduled
 * @param schedule schedule used for all block and position lookups
 */
public MatchContext(NodeLIRBuilder builder, MatchStatement rule, int index, Node node, Block rootBlock, StructuredGraph.ScheduleResult schedule) {
    // The caller-supplied index must agree with the schedule.
    assert index == schedule.getBlockToNodesMap().get(rootBlock).indexOf(node);
    this.builder = builder;
    this.rule = rule;
    this.schedule = schedule;
    this.root = node;
    this.rootBlock = rootBlock;
    this.rootIndex = index;
}
/**
 * @return the root node this context was created for
 */
public Node getRoot() {
    return this.root;
}
/**
 * Binds {@code value} to {@code name}, or checks it against an existing binding.
 *
 * @param name capture name
 * @param type node class associated with the capture
 * @param value node to bind
 * @return {@link Result#OK} if the name was unbound (the binding is then recorded) or is already
 *         bound to the same node instance and the same class; otherwise a named-value-mismatch
 *         result for {@code value}
 */
public Result captureNamedValue(String name, Class<? extends Node> type, Node value) {
    if (namedNodes == null) {
        namedNodes = EconomicMap.create(Equivalence.DEFAULT);
    }
    final NamedNode existing = namedNodes.get(name);
    if (existing == null) {
        // First capture under this name: record it.
        namedNodes.put(name, new NamedNode(type, value));
        return Result.OK;
    }
    final boolean sameBinding = existing.value == value && existing.type == type;
    if (sameBinding) {
        return Result.OK;
    }
    return Result.namedValueMismatch(value, rule.getPattern());
}
/**
 * Chooses the emit position for the match and checks that it is legal.
 *
 * Runs {@code findEarlyPosition}; if that fails its result is returned as is. Otherwise
 * {@code findLatePosition} may advance the emit index and the outcome of {@code verifyInputs} is
 * returned.
 *
 * @return {@link Result#OK} when the match can be emitted, otherwise the failure result
 */
public Result validate() {
    final Result early = findEarlyPosition();
    if (early == Result.OK) {
        findLatePosition();
        // Emitting anywhere other than at the root is only expected when the root ignores side effects.
        assert emitIndex == rootIndex || consumed.find(root).ignoresSideEffects;
        return verifyInputs();
    }
    return early;
}
/**
 * Determines the initial emit block and emit index.
 *
 * All consumed nodes that do not ignore side effects must be scheduled in one block, and every
 * node scheduled between the first and last of them must either be side-effect free or be
 * consumed itself. When such nodes exist, the emit position is the last of them; when none
 * exist, the emit position is the root.
 *
 * @return {@link Result#OK} on success, a not-in-block result when the affected nodes span
 *         several blocks, or a not-safe result when an unconsumed side-effecting node lies
 *         between them
 */
private Result findEarlyPosition() {
    final NodeMap<Block> blockOf = schedule.getNodeToBlockMap();
    final BlockMap<List<Node>> nodesOf = schedule.getBlockToNodesMap();
    int first = -1;
    int last = -1;

    for (ConsumedNode candidate : consumed) {
        if (candidate.ignoresSideEffects) {
            continue;
        }
        final Block block = blockOf.get(candidate.node);
        final boolean firstAffected = emitBlock == null;
        if (!firstAffected && emitBlock != block) {
            logFailedMatch("nodes affected by side effects in different blocks %s", candidate.node);
            return Result.notInBlock(candidate.node, rule.getPattern());
        }
        // Has an effect only for the first affected node; afterwards both are already equal.
        emitBlock = block;
        final int position = nodesOf.get(block).indexOf(candidate.node);
        if (firstAffected) {
            first = position;
            last = position;
        } else {
            first = Math.min(first, position);
            last = Math.max(last, position);
        }
    }

    if (emitBlock == null) {
        // No consumed node is affected by side effects: emit at the root.
        emitBlock = blockOf.get(root);
        emitIndex = rootIndex;
        return Result.OK;
    }

    assert first != -1 && last != -1;
    final List<Node> scheduled = nodesOf.get(emitBlock);
    for (int i = first; i <= last; i++) {
        final Node between = scheduled.get(i);
        if (sideEffectFree(between) || consumed.contains(between)) {
            continue;
        }
        logFailedMatch("unexpected side effect %s", between);
        return Result.notSafe(between, rule.getPattern());
    }
    // Emit at the last affected node in the block.
    emitIndex = last;
    return Result.OK;
}
/**
 * @return {@code true} if {@code node} is a {@link FloatingNode} or a {@link VirtualObjectNode},
 *         the two node kinds this class treats as having no side effect
 */
private static boolean sideEffectFree(Node node) {
    if (node instanceof VirtualObjectNode) {
        return true;
    }
    return node instanceof FloatingNode;
}
/**
 * Moves the emit index later within the emit block where that is possible.
 *
 * Scans forward from the node after the current emit index, up to the root index when the emit
 * block is the root block, otherwise up to the end of the emit block. Each consumed node
 * encountered becomes the new emit index. The scan stops at the first unconsumed node that is not
 * side-effect free.
 */
private void findLatePosition() {
    final List<Node> scheduled = schedule.getBlockToNodesMap().get(emitBlock);
    final int limit = (emitBlock == rootBlock) ? rootIndex : scheduled.size() - 1;
    for (int i = emitIndex + 1; i <= limit; i++) {
        final Node candidate = scheduled.get(i);
        final ConsumedNode entry = consumed.find(candidate);
        if (entry != null) {
            assert entry.ignoresSideEffects;
            emitIndex = i;
        } else if (!sideEffectFree(candidate)) {
            // An unrelated side effect blocks any later emit position.
            return;
        }
    }
}
/**
 * Checks that emitting the match at the chosen position does not precede any of its inputs.
 *
 * When the emit block differs from the root block the check is delegated to
 * {@code verifyInputsDifferentBlock}, and the success counter is incremented if it passes.
 * Otherwise each consumed node scheduled after the emit index, up to and including the root, is
 * inspected: if one of its unconsumed inputs is scheduled after the emit index and before that
 * node, the match fails.
 *
 * @return {@link Result#OK} if the emit position is legal, otherwise a too-late result
 */
private Result verifyInputs() {
    final DebugContext debug = root.getDebug();
    if (emitBlock != rootBlock) {
        assert consumed.find(root).ignoresSideEffects;
        final Result crossBlock = verifyInputsDifferentBlock(root);
        if (crossBlock == Result.OK) {
            MatchContextSuccessDifferentBlocks.increment(debug);
        }
        return crossBlock;
    }

    final List<Node> scheduled = schedule.getBlockToNodesMap().get(rootBlock);
    for (int i = emitIndex + 1; i <= rootIndex; i++) {
        final Node member = scheduled.get(i);
        final ConsumedNode entry = consumed.find(member);
        if (entry == null) {
            continue;
        }
        assert entry.ignoresSideEffects;
        for (Node input : member.inputs()) {
            if (consumed.contains(input)) {
                continue;
            }
            // Look for this unconsumed input strictly between the emit position and its user.
            for (int j = emitIndex + 1; j < i; j++) {
                if (scheduled.get(j) != input) {
                    continue;
                }
                logFailedMatch("Earliest position in block is too late %s", input);
                assert consumed.find(root).ignoresSideEffects;
                assert verifyInputsDifferentBlock(root) != Result.OK;
                return Result.tooLate(member, rule.getPattern());
            }
        }
    }
    // Cross-check with the general algorithm when assertions are enabled.
    assert verifyInputsDifferentBlock(root) == Result.OK;
    return Result.OK;
}
/**
 * Recursively checks that every input of {@code node} is either available at the emit position
 * or is itself a consumed node whose inputs satisfy the same condition.
 *
 * A {@link PhiNode} input counts as available when the block of its merge dominates the emit
 * block. Any other input counts as available when its block strictly dominates the emit block,
 * or when it is in the emit block at an index no greater than the emit index.
 *
 * @param node node whose inputs are checked
 * @return {@link Result#OK} if all inputs are acceptable, otherwise a too-late result
 */
private Result verifyInputsDifferentBlock(Node node) {
    for (Node input : node.inputs()) {
        final boolean availableAtEmit;
        if (input instanceof PhiNode) {
            final Block mergeBlock = schedule.getNodeToBlockMap().get(((PhiNode) input).merge());
            availableAtEmit = dominates(mergeBlock, emitBlock);
        } else {
            final Block inputBlock = schedule.getNodeToBlockMap().get(input);
            availableAtEmit = strictlyDominates(inputBlock, emitBlock) ||
                            (inputBlock == emitBlock && schedule.getBlockToNodesMap().get(emitBlock).indexOf(input) <= emitIndex);
        }
        if (availableAtEmit) {
            continue;
        }

        // Not available at the emit position: only acceptable if the match consumes it.
        final ConsumedNode entry = consumed.find(input);
        if (entry == null) {
            logFailedMatch("Earliest position in block is too late %s", input);
            return Result.tooLate(node, rule.getPattern());
        }
        assert entry.ignoresSideEffects;
        final Result nested = verifyInputsDifferentBlock(input);
        if (nested != Result.OK) {
            return nested;
        }
    }
    return Result.OK;
}
/**
 * When the {@code LogVerbose} option is set, logs why a match failed followed by a listing of
 * the scheduled nodes involved. Consumed nodes are marked with {@code *} and each line shows the
 * node's usage count.
 *
 * If the emit block differs from the root block, the emit block is listed from the emit index to
 * its end and the root block from its start to the root index. Otherwise the root block is
 * listed from the emit index to the root index.
 *
 * @param s format string for the failure message, taking {@code node} as its argument
 * @param node node that caused the failure
 */
private void logFailedMatch(String s, Node node) {
    if (!LogVerbose.getValue(root.getOptions())) {
        return;
    }
    final DebugContext debug = root.getDebug();
    debug.log(s, node);

    int firstInRootBlock = emitIndex;
    if (emitBlock != rootBlock) {
        final List<Node> emitBlockNodes = schedule.getBlockToNodesMap().get(emitBlock);
        final int lastInEmitBlock = emitBlockNodes.size() - 1;
        debug.log("%s:", emitBlock);
        for (int j = emitIndex; j <= lastInEmitBlock; j++) {
            final Node current = emitBlockNodes.get(j);
            final String marker = consumed.contains(current) ? "*" : " ";
            debug.log("%s(%s) %1s", marker, current.getUsageCount(), current);
        }
        firstInRootBlock = 0;
    }

    debug.log("%s:", rootBlock);
    final List<Node> rootBlockNodes = schedule.getBlockToNodesMap().get(rootBlock);
    for (int j = firstInRootBlock; j <= rootIndex; j++) {
        final Node current = rootBlockNodes.get(j);
        final String marker = consumed.contains(current) ? "*" : " ";
        debug.log("%s(%s) %1s", marker, current.getUsageCount(), current);
    }
}
/**
 * Registers the outcome of a successful match with the builder.
 *
 * Every consumed node other than the root and the emit node is marked as
 * {@code ComplexMatchValue.INTERIOR_MATCH}. The node at the emit position receives the value
 * wrapping {@code result}. If the root is not the emit node, the root receives a value that
 * yields the operand of the emit node.
 *
 * @param result the match result to attach at the emit position
 */
public void setResult(ComplexMatchResult result) {
    final ComplexMatchValue matchValue = new ComplexMatchValue(result);
    final List<Node> emitBlockNodes = schedule.getBlockToNodesMap().get(emitBlock);
    final Node emitNode = emitBlockNodes.get(emitIndex);

    final DebugContext debug = root.getDebug();
    if (debug.isLogEnabled()) {
        final String suffix = emitIndex != rootIndex ? " skipping side effects" : "";
        debug.log("matched %s %s%s", rule.getName(), rule.getPattern(), suffix);
        debug.log("with nodes %s", rule.formatMatch(root));
    }

    for (ConsumedNode entry : consumed) {
        final boolean interior = entry.node != root && entry.node != emitNode;
        if (interior) {
            builder.setMatchResult(entry.node, ComplexMatchValue.INTERIOR_MATCH);
        }
    }

    builder.setMatchResult(emitNode, matchValue);
    if (emitNode != root) {
        // Users of root still need a value: forward them to the operand produced at the emit node.
        builder.setMatchResult(root, new ComplexMatchValue(gen -> gen.operand(emitNode)));
    }
}
/**
 * Records {@code node} as consumed by the match.
 *
 * The root is always accepted. Any other node is rejected when the builder already has an
 * operand for it.
 *
 * @param node node to consume
 * @param ignoresSideEffects flag stored with the consumed node
 * @param atRoot whether {@code node} is the root of the match
 * @return {@link Result#OK} if the node was recorded, otherwise an already-used result
 */
public Result consume(Node node, boolean ignoresSideEffects, boolean atRoot) {
    if (!atRoot) {
        assert MatchPattern.isSingleValueUser(node) : "should have already been checked";
        if (builder.hasOperand(node)) {
            return Result.alreadyUsed(node, rule.getPattern());
        }
    }
    consumed.add(node, ignoresSideEffects);
    return Result.OK;
}
/**
 * Looks up the node captured under {@code name}.
 *
 * @param name capture name
 * @return the node bound to {@code name}
 * @throws GraalError if nothing has been captured under that name
 */
public Node namedNode(String name) {
    final NamedNode binding = (namedNodes == null) ? null : namedNodes.get(name);
    if (binding == null) {
        throw new GraalError("missing node %s", name);
    }
    return binding.value;
}
/**
 * @return a description containing the rule, the root, the root position (block/index), the emit
 *         position (block/index) and the consumed nodes
 */
@Override
public String toString() {
    final Object[] parts = {rule, root, rootBlock, rootIndex, emitBlock, emitIndex, consumed};
    return String.format("%s %s (%s/%d, %s/%d) consumed %s", parts);
}
}