package org.graalvm.compiler.replacements.test;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import org.junit.Assert;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import org.junit.runners.Parameterized.Parameter;
import org.junit.runners.Parameterized.Parameters;
import org.graalvm.compiler.core.common.type.IntegerStamp;
import org.graalvm.compiler.core.common.type.StampFactory;
import org.graalvm.compiler.core.test.GraalCompilerTest;
import org.graalvm.compiler.nodes.ConstantNode;
import org.graalvm.compiler.nodes.ParameterNode;
import org.graalvm.compiler.nodes.PiNode;
import org.graalvm.compiler.nodes.StructuredGraph;
import org.graalvm.compiler.nodes.ValueNode;
import org.graalvm.compiler.nodes.StructuredGraph.AllowAssumptions;
import org.graalvm.compiler.nodes.calc.MulNode;
import org.graalvm.compiler.nodes.java.StoreFieldNode;
import org.graalvm.compiler.phases.common.CanonicalizerPhase;
import org.graalvm.compiler.phases.tiers.HighTierContext;
import org.graalvm.compiler.replacements.nodes.arithmetic.IntegerMulExactNode;
/**
 * Checks that canonicalization of {@link IntegerMulExactNode} agrees with the operand stamps.
 *
 * <p>For each parameter set, the operands of {@code Math.multiplyExact} in a small snippet are
 * narrowed with {@link PiNode}s to {@code [lowerBound1, upperBound1]} and
 * {@code [lowerBound2, upperBound2]}, and the graph is canonicalized. Every concrete product of
 * the two ranges is then checked against the stamp of the node left computing the product. If
 * the exact multiplication was removed, no product in the ranges may overflow.
 */
@RunWith(Parameterized.class)
public class IntegerMulExactFoldTest extends GraalCompilerTest {

    /** Written by {@link #snippetInt}; the graph keeps a store of the product to this field. */
    public static int SideEffectI;

    /** Written by {@link #snippetLong}; the graph keeps a store of the product to this field. */
    public static long SideEffectL;

    public static void snippetInt(int a, int b) {
        SideEffectI = Math.multiplyExact(a, b);
    }

    public static void snippetLong(long a, long b) {
        SideEffectL = Math.multiplyExact(a, b);
    }

    /**
     * Parses {@code snippet} eagerly and runs one canonicalization pass over it.
     *
     * @param snippet name of a snippet method in this class
     * @return the canonicalized graph
     */
    private StructuredGraph prepareGraph(String snippet) {
        StructuredGraph graph = parseEager(snippet, AllowAssumptions.NO);
        HighTierContext context = getDefaultHighTierContext();
        new CanonicalizerPhase().apply(graph, context);
        return graph;
    }

    /** Inclusive lower bound of the first operand. */
    @Parameter(0) public long lowerBound1;
    /** Inclusive upper bound of the first operand. */
    @Parameter(1) public long upperBound1;
    /** Inclusive lower bound of the second operand. */
    @Parameter(2) public long lowerBound2;
    /** Inclusive upper bound of the second operand. */
    @Parameter(3) public long upperBound2;
    /** Operand width in bits; only 32 and 64 are exercised. */
    @Parameter(4) public int bits;

    @Test
    public void tryFold() {
        // A JUnit check instead of a Java assert so it also fires without -ea.
        Assert.assertTrue("unsupported bit width: " + bits, bits == 32 || bits == 64);

        IntegerStamp a = StampFactory.forInteger(bits, lowerBound1, upperBound1);
        IntegerStamp b = StampFactory.forInteger(bits, lowerBound2, upperBound2);

        StructuredGraph g = prepareGraph(bits == 32 ? "snippetInt" : "snippetLong");
        List<ParameterNode> params = g.getNodes(ParameterNode.TYPE).snapshot();
        // Narrow only the inputs of the exact multiplication to the parameterized ranges.
        params.get(0).replaceAtMatchingUsages(g.addOrUnique(new PiNode(params.get(0), a)), x -> x instanceof IntegerMulExactNode);
        params.get(1).replaceAtMatchingUsages(g.addOrUnique(new PiNode(params.get(1), b)), x -> x instanceof IntegerMulExactNode);
        new CanonicalizerPhase().apply(g, getDefaultHighTierContext());

        // No surviving exact node means canonicalization treated the product as overflow-free.
        boolean optimized = g.getNodes().filter(IntegerMulExactNode.class).count() == 0;
        ValueNode resultNode = optimized ? g.getNodes().filter(MulNode.class).first() : g.getNodes().filter(IntegerMulExactNode.class).first();
        new CanonicalizerPhase().apply(g, getDefaultHighTierContext());
        if (resultNode == null) {
            // No multiplication node left: look for a constant being stored instead.
            StoreFieldNode store = g.getNodes().filter(StoreFieldNode.class).first();
            Assert.assertNotNull("snippet store of the product is missing from the graph", store);
            resultNode = store.inputs().filter(ConstantNode.class).first();
        }
        if (resultNode == null) {
            // NOTE(review): presumably the product simplified to one of the narrowed operands.
            resultNode = g.getNodes().filter(PiNode.class).first();
        }
        Assert.assertNotNull("no node computing the product was found in the graph", resultNode);
        IntegerStamp resultStamp = (IntegerStamp) resultNode.stamp();

        for (long l1 = lowerBound1; l1 <= upperBound1; l1++) {
            for (long l2 = lowerBound2; l2 <= upperBound2; l2++) {
                try {
                    long res = mulExact(l1, l2, bits);
                    // Build the failure message only when it is actually needed.
                    if (!resultStamp.contains(res)) {
                        Assert.fail(l1 + " * " + l2 + " = " + res + " is not contained in " + resultStamp);
                    }
                } catch (ArithmeticException overflow) {
                    if (optimized) {
                        Assert.fail(l1 + " * " + l2 + " overflows " + bits + " bits but the exact multiplication was removed");
                    }
                }
                // Stop before l2++ would wrap around and loop forever.
                if (l2 == Long.MAX_VALUE) {
                    break;
                }
            }
            // Stop before l1++ would wrap around and loop forever.
            if (l1 == Long.MAX_VALUE) {
                break;
            }
        }
    }

    /**
     * Multiplies {@code x} and {@code y} as {@code bits}-wide signed integers.
     *
     * @param x first operand, expected to fit in {@code bits} bits
     * @param y second operand, expected to fit in {@code bits} bits
     * @param bits 8, 16, 32 or (any other value) 64
     * @return the exact product, sign-extended to {@code long}
     * @throws ArithmeticException if the product does not fit in {@code bits} bits
     */
    private static long mulExact(long x, long y, int bits) {
        long r = x * y;
        if (bits == 8) {
            if ((byte) r != r) {
                throw new ArithmeticException("overflow");
            }
        } else if (bits == 16) {
            if ((short) r != r) {
                throw new ArithmeticException("overflow");
            }
        } else if (bits == 32) {
            return Math.multiplyExact((int) x, (int) y);
        } else {
            return Math.multiplyExact(x, y);
        }
        return r;
    }

    /**
     * Supplies the operand ranges and bit widths.
     *
     * <p>The display name includes the bit width so that 32- and 64-bit variants of the same
     * ranges are distinguishable.
     */
    @Parameters(name = "a[{0} - {1}] b[{2} - {3}] bits={4}")
    public static Collection<Object[]> data() {
        List<Object[]> tests = new ArrayList<>();

        // Small ranges around zero and one.
        addTest(tests, -2, 2, 3, 3, 32);
        addTest(tests, 0, 0, 1, 1, 32);
        addTest(tests, 1, 1, 0, 0, 32);
        addTest(tests, -1, 1, 0, 1, 32);
        addTest(tests, -1, 1, 1, 1, 32);
        addTest(tests, -1, 1, -1, 1, 32);
        addTest(tests, -2, 2, 3, 3, 64);
        addTest(tests, 0, 0, 1, 1, 64);
        addTest(tests, 1, 1, 0, 0, 64);
        addTest(tests, -1, 1, 0, 1, 64);
        addTest(tests, -1, 1, 1, 1, 64);
        addTest(tests, -1, 1, -1, 1, 64);

        // Ranges at the edges of the int/long domains.
        addTest(tests, Integer.MIN_VALUE, Integer.MIN_VALUE + 0xFF, Integer.MAX_VALUE - 0xFF, Integer.MAX_VALUE, 32);
        addTest(tests, Integer.MIN_VALUE, Integer.MIN_VALUE + 0xFFF, -1, -1, 32);
        addTest(tests, Integer.MIN_VALUE, Integer.MIN_VALUE + 0xFF, Integer.MAX_VALUE - 0xFF, Integer.MAX_VALUE, 64);
        addTest(tests, Integer.MIN_VALUE, Integer.MIN_VALUE + 0xFFF, -1, -1, 64);
        addTest(tests, Long.MIN_VALUE, Long.MIN_VALUE + 0xFFF, -1, -1, 64);

        // Constant operands.
        addTest(tests, 2, 2, 2, 2, 32);
        addTest(tests, 1, 1, 2, 2, 32);
        addTest(tests, 2, 2, 4, 4, 32);
        addTest(tests, 3, 3, 3, 3, 32);
        addTest(tests, -4, -4, 3, 3, 32);
        addTest(tests, -4, -4, -3, -3, 32);
        addTest(tests, 4, 4, -3, -3, 32);
        addTest(tests, 2, 2, 2, 2, 64);
        addTest(tests, 1, 1, 2, 2, 64);
        addTest(tests, 3, 3, 3, 3, 64);
        addTest(tests, Long.MAX_VALUE, Long.MAX_VALUE, 1, 1, 64);
        addTest(tests, Long.MAX_VALUE, Long.MAX_VALUE, -1, -1, 64);
        addTest(tests, Long.MIN_VALUE, Long.MIN_VALUE, -1, -1, 64);
        addTest(tests, Long.MIN_VALUE, Long.MIN_VALUE, 1, 1, 64);
        return tests;
    }

    /** Appends one parameter row in {@link Parameter} index order. */
    private static void addTest(List<Object[]> tests, long lowerBound1, long upperBound1, long lowerBound2, long upperBound2, int bits) {
        tests.add(new Object[]{lowerBound1, upperBound1, lowerBound2, upperBound2, bits});
    }
}