package org.graalvm.compiler.nodes;
import static org.graalvm.compiler.nodeinfo.InputType.Condition;
import static org.graalvm.compiler.nodeinfo.NodeCycles.CYCLES_0;
import static org.graalvm.compiler.nodeinfo.NodeSize.SIZE_0;
import org.graalvm.compiler.core.common.spi.ConstantFieldProvider;
import org.graalvm.compiler.core.common.spi.MetaAccessExtensionProvider;
import org.graalvm.compiler.core.common.type.IntegerStamp;
import org.graalvm.compiler.core.common.type.Stamp;
import org.graalvm.compiler.graph.IterableNodeType;
import org.graalvm.compiler.graph.NodeClass;
import org.graalvm.compiler.graph.spi.Canonicalizable;
import org.graalvm.compiler.graph.spi.CanonicalizerTool;
import org.graalvm.compiler.nodeinfo.NodeInfo;
import org.graalvm.compiler.nodes.calc.CompareNode;
import org.graalvm.compiler.nodes.calc.IntegerBelowNode;
import org.graalvm.compiler.nodes.calc.IntegerLessThanNode;
import org.graalvm.compiler.options.OptionValues;
import jdk.vm.ci.meta.Assumptions;
import jdk.vm.ci.meta.ConstantReflectionProvider;
import jdk.vm.ci.meta.MetaAccessProvider;
import jdk.vm.ci.meta.TriState;
/**
 * A short-circuiting logical "or" of two conditions, each of which may be negated. The node
 * evaluates to {@code (xNegated ? !x : x) || (yNegated ? !y : y)}; the y term only matters when the
 * x term is false.
 */
@NodeInfo(cycles = CYCLES_0, size = SIZE_0)
public final class ShortCircuitOrNode extends LogicNode implements IterableNodeType, Canonicalizable.Binary<LogicNode> {

    public static final NodeClass<ShortCircuitOrNode> TYPE = NodeClass.create(ShortCircuitOrNode.class);

    @Input(Condition) LogicNode x;
    @Input(Condition) LogicNode y;
    /** Whether {@link #x} is negated before taking part in the disjunction. */
    protected boolean xNegated;
    /** Whether {@link #y} is negated before taking part in the disjunction. */
    protected boolean yNegated;
    /** Probability that the x term is true, i.e. that the y term is not evaluated. */
    protected double shortCircuitProbability;

    /**
     * Creates the node {@code (xNegated ? !x : x) || (yNegated ? !y : y)}.
     *
     * @param x first condition
     * @param xNegated whether {@code x} is negated
     * @param y second condition, only relevant when the x term is false
     * @param yNegated whether {@code y} is negated
     * @param shortCircuitProbability probability that the x term is true, so that {@code y} is
     *            not evaluated
     */
    public ShortCircuitOrNode(LogicNode x, boolean xNegated, LogicNode y, boolean yNegated, double shortCircuitProbability) {
        super(TYPE);
        this.x = x;
        this.xNegated = xNegated;
        this.y = y;
        this.yNegated = yNegated;
        this.shortCircuitProbability = shortCircuitProbability;
    }

    /**
     * Factory equivalent to the constructor. It always returns a new node and performs no
     * simplification itself.
     */
    public static LogicNode create(LogicNode x, boolean xNegated, LogicNode y, boolean yNegated, double shortCircuitProbability) {
        return new ShortCircuitOrNode(x, xNegated, y, yNegated, shortCircuitProbability);
    }

    @Override
    public LogicNode getX() {
        return x;
    }

    @Override
    public LogicNode getY() {
        return y;
    }

    /** @return whether the x input is negated */
    public boolean isXNegated() {
        return xNegated;
    }

    /** @return whether the y input is negated */
    public boolean isYNegated() {
        return yNegated;
    }

    /**
     * Gets the probability that the (possibly negated) x term is true, i.e. that this operator
     * short-circuits and the y term is not evaluated.
     */
    public double getShortCircuitProbability() {
        return shortCircuitProbability;
    }

    /**
     * Strips any chain of {@link LogicNegationNode}s from the inputs and toggles the corresponding
     * negation flag once per stripped negation.
     *
     * @return a new node over the un-negated inputs if anything was stripped, otherwise
     *         {@code this}
     */
    protected ShortCircuitOrNode canonicalizeNegation(LogicNode forX, LogicNode forY) {
        LogicNode xCond = forX;
        boolean xNeg = xNegated;
        while (xCond instanceof LogicNegationNode) {
            xCond = ((LogicNegationNode) xCond).getValue();
            xNeg = !xNeg;
        }

        LogicNode yCond = forY;
        boolean yNeg = yNegated;
        while (yCond instanceof LogicNegationNode) {
            yCond = ((LogicNegationNode) yCond).getValue();
            yNeg = !yNeg;
        }

        if (xCond != forX || yCond != forY) {
            return new ShortCircuitOrNode(xCond, xNeg, yCond, yNeg, shortCircuitProbability);
        } else {
            return this;
        }
    }

    @Override
    public LogicNode canonical(CanonicalizerTool tool, LogicNode forX, LogicNode forY) {
        // Fold negation wrappers into the flags first; below, neither input is a LogicNegationNode.
        ShortCircuitOrNode ret = canonicalizeNegation(forX, forY);
        if (ret != this) {
            return ret;
        }
        NodeView view = NodeView.from(tool);

        if (forX == forY) {
            //  a ||  a = a
            //  a || !a = true
            // !a ||  a = true
            // !a || !a = !a
            if (isXNegated()) {
                if (isYNegated()) {
                    return LogicNegationNode.create(forX);
                } else {
                    return LogicConstantNode.tautology();
                }
            } else {
                if (isYNegated()) {
                    return LogicConstantNode.tautology();
                } else {
                    return forX;
                }
            }
        }

        // Constant x term: true || y = true, false || y = y.
        if (forX instanceof LogicConstantNode) {
            if (((LogicConstantNode) forX).getValue() ^ isXNegated()) {
                return LogicConstantNode.tautology();
            } else {
                if (isYNegated()) {
                    // create() (unlike the constructor) folds the negation if forY is a constant.
                    return LogicNegationNode.create(forY);
                } else {
                    return forY;
                }
            }
        }
        // Constant y term: x || true = true, x || false = x.
        if (forY instanceof LogicConstantNode) {
            if (((LogicConstantNode) forY).getValue() ^ isYNegated()) {
                return LogicConstantNode.tautology();
            } else {
                if (isXNegated()) {
                    return LogicNegationNode.create(forX);
                } else {
                    return forX;
                }
            }
        }

        // One side is itself a ShortCircuitOrNode that shares an input with the other side.
        // When the nested or is on the y side, the operands are passed as if swapped
        // (X || inner treated like inner || X).
        if (forX instanceof ShortCircuitOrNode) {
            ShortCircuitOrNode inner = (ShortCircuitOrNode) forX;
            if (forY == inner.getX()) {
                return optimizeShortCircuit(inner, this.xNegated, this.yNegated, true);
            } else if (forY == inner.getY()) {
                return optimizeShortCircuit(inner, this.xNegated, this.yNegated, false);
            }
        } else if (forY instanceof ShortCircuitOrNode) {
            ShortCircuitOrNode inner = (ShortCircuitOrNode) forY;
            if (inner.getX() == forX) {
                return optimizeShortCircuit(inner, this.yNegated, this.xNegated, true);
            } else if (inner.getY() == forX) {
                return optimizeShortCircuit(inner, this.yNegated, this.xNegated, false);
            }
        }

        // The y term is only evaluated when the x term is false. If that fact determines the y
        // term, the node reduces to either true or just the x term.
        TriState impliedForY = forX.implies(!isXNegated(), forY);
        if (impliedForY.isKnown()) {
            boolean yResult = impliedForY.toBoolean() ^ isYNegated();
            return yResult
                            ? LogicConstantNode.tautology()
                            : (isXNegated()
                                            ? LogicNegationNode.create(forX)
                                            : forX);
        }

        // if X >= 0:
        // u < 0 || X < u ==>> X |<| u (in either operand order)
        if (!isXNegated() && !isYNegated()) {
            LogicNode sym = simplifyComparison(forX, forY, view);
            if (sym != null) {
                return sym;
            }
        }

        // if X >= 0:
        // X |<| u || X < u ==>> X |<| u
        if (forX instanceof IntegerBelowNode && forY instanceof IntegerLessThanNode && !isXNegated() && !isYNegated()) {
            IntegerBelowNode xNode = (IntegerBelowNode) forX;
            IntegerLessThanNode yNode = (IntegerLessThanNode) forY;
            ValueNode xxNode = xNode.getX(); // X
            ValueNode yxNode = yNode.getX(); // X
            if (xxNode == yxNode && ((IntegerStamp) xxNode.stamp(view)).isPositive()) {
                ValueNode xyNode = xNode.getY(); // u
                ValueNode yyNode = yNode.getY(); // u
                if (xyNode == yyNode) {
                    return forX;
                }
            }
        }

        // if X >= 0:
        // u < 0 || (X < u || tail) ==>> X |<| u || tail
        if (forY instanceof ShortCircuitOrNode && !isXNegated() && !isYNegated()) {
            ShortCircuitOrNode yNode = (ShortCircuitOrNode) forY;
            if (!yNode.isXNegated()) {
                LogicNode sym = simplifyComparison(forX, yNode.getX(), view);
                if (sym != null) {
                    double p1 = getShortCircuitProbability();
                    double p2 = yNode.getShortCircuitProbability();
                    // The merged x term is true if either of the two original first terms is:
                    // P = p1 + (1 - p1) * p2.
                    return new ShortCircuitOrNode(sym, isXNegated(), yNode.getY(), yNode.isYNegated(), p1 + (1 - p1) * p2);
                }
            }
        }

        if (forX instanceof CompareNode && forY instanceof CompareNode) {
            CompareNode xCompare = (CompareNode) forX;
            CompareNode yCompare = (CompareNode) forY;

            if (xCompare.getX() == yCompare.getX() || xCompare.getX() == yCompare.getY()) {
                // The y comparison is only evaluated when the x term is false. Compute the stamp
                // of the shared operand under that assumption and try to canonicalize the y
                // comparison with it.
                Stamp succeedingStampX = xCompare.getSucceedingStampForX(!xNegated, xCompare.getX().stamp(view), xCompare.getY().stamp(view));
                if (succeedingStampX != null && !succeedingStampX.isUnrestricted()) {
                    CanonicalizerTool proxyTool = new ProxyCanonicalizerTool(succeedingStampX, xCompare.getX(), tool, view);
                    ValueNode result = yCompare.canonical(proxyTool);
                    if (result != yCompare) {
                        return ShortCircuitOrNode.create(forX, xNegated, (LogicNode) result, yNegated, this.shortCircuitProbability);
                    }
                }
            }
        }

        return this;
    }

    /**
     * A {@link CanonicalizerTool} and {@link NodeView} that forwards every query to the wrapped
     * tool and view, except that {@link #stamp(ValueNode)} returns the supplied {@code stamp} for
     * the single node {@code node}.
     */
    private static final class ProxyCanonicalizerTool implements CanonicalizerTool, NodeView {

        private final Stamp stamp;
        private final ValueNode node;
        private final CanonicalizerTool tool;
        private final NodeView view;

        ProxyCanonicalizerTool(Stamp stamp, ValueNode node, CanonicalizerTool tool, NodeView view) {
            this.stamp = stamp;
            this.node = node;
            this.tool = tool;
            this.view = view;
        }

        @Override
        public Stamp stamp(ValueNode n) {
            if (n == node) {
                return stamp;
            }
            return view.stamp(n);
        }

        @Override
        public Assumptions getAssumptions() {
            return tool.getAssumptions();
        }

        @Override
        public MetaAccessProvider getMetaAccess() {
            return tool.getMetaAccess();
        }

        @Override
        public ConstantReflectionProvider getConstantReflection() {
            return tool.getConstantReflection();
        }

        @Override
        public ConstantFieldProvider getConstantFieldProvider() {
            return tool.getConstantFieldProvider();
        }

        @Override
        public MetaAccessExtensionProvider getMetaAccessExtensionProvider() {
            return tool.getMetaAccessExtensionProvider();
        }

        @Override
        public boolean canonicalizeReads() {
            return tool.canonicalizeReads();
        }

        @Override
        public boolean allUsagesAvailable() {
            return tool.allUsagesAvailable();
        }

        @Override
        public Integer smallestCompareWidth() {
            return tool.smallestCompareWidth();
        }

        @Override
        public OptionValues getOptions() {
            return tool.getOptions();
        }
    }

    /**
     * Tries {@link #simplifyComparisonOrdered} with the operands in both orders.
     *
     * @param view the view used to read operand stamps
     * @return the simplified condition, or {@code null} if neither order matches
     */
    private static LogicNode simplifyComparison(LogicNode forX, LogicNode forY, NodeView view) {
        LogicNode sym = simplifyComparisonOrdered(forX, forY, view);
        if (sym == null) {
            return simplifyComparisonOrdered(forY, forX, view);
        }
        return sym;
    }

    /**
     * Rewrites {@code u < 0 || X < u} into the unsigned comparison {@code X |<| u} when
     * {@code X >= 0} according to its stamp in {@code view}. If {@code u < 0}, it is a huge value
     * as unsigned, so {@code X |<| u} holds. Otherwise signed and unsigned comparison agree.
     *
     * @param view the view used to read the stamp of {@code X}
     * @return the rewritten condition, or {@code null} if the pattern does not match
     */
    private static LogicNode simplifyComparisonOrdered(LogicNode forX, LogicNode forY, NodeView view) {
        if (forX instanceof IntegerLessThanNode && forY instanceof IntegerLessThanNode) {
            IntegerLessThanNode xNode = (IntegerLessThanNode) forX;
            IntegerLessThanNode yNode = (IntegerLessThanNode) forY;
            ValueNode xyNode = xNode.getY(); // must be the constant 0 (neutral element of add)
            if (xyNode.isConstant() && IntegerStamp.OPS.getAdd().isNeutral(xyNode.asConstant())) {
                ValueNode yxNode = yNode.getX(); // X, must be >= 0
                IntegerStamp stamp = (IntegerStamp) yxNode.stamp(view);
                if (stamp.isPositive()) {
                    if (xNode.getX() == yNode.getY()) {
                        ValueNode u = xNode.getX();
                        return IntegerBelowNode.create(yxNode, u, NodeView.DEFAULT);
                    }
                }
            }
        }

        return null;
    }

    /**
     * Simplifies {@code (innerNegated ? !inner : inner) || (matchNegated ? !m : m)} where
     * {@code m} is one of the two inputs of {@code inner}.
     *
     * @param inner the nested or
     * @param innerNegated whether {@code inner} is negated in the outer or
     * @param matchNegated whether the shared input {@code m} is negated in the outer or
     * @param matchIsInnerX {@code true} if {@code m} is {@code inner.getX()}, {@code false} if it
     *            is {@code inner.getY()}
     * @return an equivalent, simpler condition
     */
    private static LogicNode optimizeShortCircuit(ShortCircuitOrNode inner, boolean innerNegated, boolean matchNegated, boolean matchIsInnerX) {
        boolean innerMatchNegated;
        if (matchIsInnerX) {
            innerMatchNegated = inner.isXNegated();
        } else {
            innerMatchNegated = inner.isYNegated();
        }
        // In the comments below, a is the shared input and b is the other input of inner. The
        // four-digit results use the truth table a = 1100, b = 1010. The cases are written with
        // a as inner x; the case where a is inner y is symmetric.
        if (!innerNegated) {
            if (innerMatchNegated == matchNegated) {
                // ( (!a || !b) || !a) => 0111 (!a || !b)
                // ( (!a ||  b) || !a) => 1011 (!a ||  b)
                // ( ( a || !b) ||  a) => 1101 ( a || !b)
                // ( ( a ||  b) ||  a) => 1110 ( a ||  b)
                // The outer or adds no information: the result is inner itself.
                return inner;
            } else {
                // ( ( a ||  b) || !a) => 1111
                // ( (!a || !b) ||  a) => 1111
                // ( (!a ||  b) ||  a) => 1111
                // ( ( a || !b) || !a) => 1111
                // The shared input appears with both polarities: always true.
                return LogicConstantNode.tautology();
            }
        } else {
            if (innerMatchNegated == matchNegated) {
                // (!(!a || !b) || !a) => 1011 (!a ||  b)
                // (!(!a ||  b) || !a) => 0111 (!a || !b)
                // (!( a || !b) ||  a) => 1110 ( a ||  b)
                // (!( a ||  b) ||  a) => 1101 ( a || !b)
                // The result is inner with the polarity of the non-shared input flipped.
                boolean newInnerXNegated = inner.isXNegated();
                boolean newInnerYNegated = inner.isYNegated();
                double newProbability = inner.getShortCircuitProbability();
                if (matchIsInnerX) {
                    newInnerYNegated = !newInnerYNegated;
                } else {
                    // The flip lands on inner x, so the probability that the x term is true
                    // becomes its complement.
                    newInnerXNegated = !newInnerXNegated;
                    newProbability = 1.0 - newProbability;
                }
                return new ShortCircuitOrNode(inner.getX(), newInnerXNegated, inner.getY(), newInnerYNegated, newProbability);
            } else {
                // (!(!a || !b) ||  a) => 1100 ( a)
                // (!(!a ||  b) ||  a) => 1100 ( a)
                // (!( a || !b) || !a) => 0011 (!a)
                // (!( a ||  b) || !a) => 0011 (!a)
                // Only the outer occurrence of the shared input matters.
                LogicNode result = inner.getY();
                if (matchIsInnerX) {
                    result = inner.getX();
                }
                if (matchNegated) {
                    return LogicNegationNode.create(result);
                } else {
                    return result;
                }
            }
        }
    }
}