package org.graalvm.compiler.phases.common;
import jdk.internal.vm.compiler.collections.Pair;
import org.graalvm.compiler.core.common.type.IntegerStamp;
import org.graalvm.compiler.nodes.ConstantNode;
import org.graalvm.compiler.nodes.FixedNode;
import org.graalvm.compiler.nodes.NodeView;
import org.graalvm.compiler.nodes.StructuredGraph;
import org.graalvm.compiler.nodes.ValueNode;
import org.graalvm.compiler.nodes.calc.BinaryArithmeticNode;
import org.graalvm.compiler.nodes.calc.IntegerDivRemNode;
import org.graalvm.compiler.nodes.calc.IntegerMulHighNode;
import org.graalvm.compiler.nodes.calc.MulNode;
import org.graalvm.compiler.nodes.calc.NarrowNode;
import org.graalvm.compiler.nodes.calc.RightShiftNode;
import org.graalvm.compiler.nodes.calc.SignExtendNode;
import org.graalvm.compiler.nodes.calc.SignedDivNode;
import org.graalvm.compiler.nodes.calc.SignedRemNode;
import org.graalvm.compiler.nodes.calc.UnsignedRightShiftNode;
import org.graalvm.compiler.phases.Phase;
import jdk.vm.ci.code.CodeUtil;
public class OptimizeDivPhase extends Phase {
@Override
protected void run(StructuredGraph graph) {
for (IntegerDivRemNode rem : graph.getNodes(IntegerDivRemNode.TYPE)) {
if (rem instanceof SignedRemNode && divByNonZeroConstant(rem)) {
optimizeRem(rem);
}
}
for (IntegerDivRemNode div : graph.getNodes(IntegerDivRemNode.TYPE)) {
if (div instanceof SignedDivNode && divByNonZeroConstant(div)) {
optimizeSignedDiv((SignedDivNode) div);
}
}
}
@Override
public float codeSizeIncrease() {
return 5.0f;
}
protected static boolean divByNonZeroConstant(IntegerDivRemNode divRemNode) {
return divRemNode.getY().isConstant() && divRemNode.getY().asJavaConstant().asLong() != 0;
}
protected final void optimizeRem(IntegerDivRemNode rem) {
assert rem.getOp() == IntegerDivRemNode.Op.REM;
StructuredGraph graph = rem.graph();
ValueNode div = findDivForRem(rem);
ValueNode mul = BinaryArithmeticNode.mul(graph, div, rem.getY(), NodeView.DEFAULT);
ValueNode result = BinaryArithmeticNode.sub(graph, rem.getX(), mul, NodeView.DEFAULT);
graph.replaceFixedWithFloating(rem, result);
}
private ValueNode findDivForRem(IntegerDivRemNode rem) {
if (rem.next() instanceof IntegerDivRemNode) {
IntegerDivRemNode div = (IntegerDivRemNode) rem.next();
if (div.getOp() == IntegerDivRemNode.Op.DIV && div.getType() == rem.getType() && div.getX() == rem.getX() && div.getY() == rem.getY()) {
return div;
}
}
if (rem.predecessor() instanceof IntegerDivRemNode) {
IntegerDivRemNode div = (IntegerDivRemNode) rem.predecessor();
if (div.getOp() == IntegerDivRemNode.Op.DIV && div.getType() == rem.getType() && div.getX() == rem.getX() && div.getY() == rem.getY()) {
return div;
}
}
ValueNode div = rem.graph().addOrUniqueWithInputs(createDiv(rem));
if (div instanceof FixedNode) {
rem.graph().addAfterFixed(rem, (FixedNode) div);
}
return div;
}
protected ValueNode createDiv(IntegerDivRemNode rem) {
assert rem instanceof SignedRemNode;
return SignedDivNode.create(rem.getX(), rem.getY(), rem.getZeroCheck(), NodeView.DEFAULT);
}
protected static void optimizeSignedDiv(SignedDivNode div) {
ValueNode forX = div.getX();
long c = div.getY().asJavaConstant().asLong();
assert c != 1 && c != -1 && c != 0;
IntegerStamp dividendStamp = (IntegerStamp) forX.stamp(NodeView.DEFAULT);
int bitSize = dividendStamp.getBits();
Pair<Long, Integer> nums = magicDivideConstants(c, bitSize);
long magicNum = nums.getLeft().longValue();
int shiftNum = nums.getRight().intValue();
assert shiftNum >= 0;
ConstantNode m = ConstantNode.forLong(magicNum);
ValueNode value;
if (bitSize == 32) {
value = new MulNode(new SignExtendNode(forX, 64), m);
if ((c > 0 && magicNum < 0) || (c < 0 && magicNum > 0)) {
value = NarrowNode.create(new RightShiftNode(value, ConstantNode.forInt(32)), 32, NodeView.DEFAULT);
if (c > 0) {
value = BinaryArithmeticNode.add(value, forX, NodeView.DEFAULT);
} else {
value = BinaryArithmeticNode.sub(value, forX, NodeView.DEFAULT);
}
if (shiftNum > 0) {
value = new RightShiftNode(value, ConstantNode.forInt(shiftNum));
}
} else {
value = new RightShiftNode(value, ConstantNode.forInt(32 + shiftNum));
value = new NarrowNode(value, Integer.SIZE);
}
} else {
assert bitSize == 64;
value = new IntegerMulHighNode(forX, m);
if (c > 0 && magicNum < 0) {
value = BinaryArithmeticNode.add(value, forX, NodeView.DEFAULT);
} else if (c < 0 && magicNum > 0) {
value = BinaryArithmeticNode.sub(value, forX, NodeView.DEFAULT);
}
if (shiftNum > 0) {
value = new RightShiftNode(value, ConstantNode.forInt(shiftNum));
}
}
if (c < 0) {
ConstantNode s = ConstantNode.forInt(bitSize - 1);
ValueNode sign = UnsignedRightShiftNode.create(value, s, NodeView.DEFAULT);
value = BinaryArithmeticNode.add(value, sign, NodeView.DEFAULT);
} else if (dividendStamp.canBeNegative()) {
ConstantNode s = ConstantNode.forInt(bitSize - 1);
ValueNode sign = UnsignedRightShiftNode.create(forX, s, NodeView.DEFAULT);
value = BinaryArithmeticNode.add(value, sign, NodeView.DEFAULT);
}
StructuredGraph graph = div.graph();
graph.replaceFixed(div, graph.addOrUniqueWithInputs(value));
}
private static Pair<Long, Integer> magicDivideConstants(long divisor, int size) {
final long twoW = 1L << (size - 1);
long t = twoW + (divisor >>> 63);
long ad = Math.abs(divisor);
long anc = t - 1 - Long.remainderUnsigned(t, ad);
long q1 = Long.divideUnsigned(twoW, anc);
long r1 = Long.remainderUnsigned(twoW, anc);
long q2 = Long.divideUnsigned(twoW, ad);
long r2 = Long.remainderUnsigned(twoW, ad);
long delta;
int p = size - 1;
do {
p = p + 1;
q1 = 2 * q1;
r1 = 2 * r1;
if (Long.compareUnsigned(r1, anc) >= 0) {
q1 = q1 + 1;
r1 = r1 - anc;
}
q2 = 2 * q2;
r2 = 2 * r2;
if (Long.compareUnsigned(r2, ad) >= 0) {
q2 = q2 + 1;
r2 = r2 - ad;
}
delta = ad - r2;
} while (Long.compareUnsigned(q1, delta) < 0 || (q1 == delta && r1 == 0));
long magic = CodeUtil.signExtend(q2 + 1, size);
if (divisor < 0) {
magic = -magic;
}
return Pair.create(magic, p - size);
}
}