package org.graalvm.compiler.replacements.test;
import java.util.HashMap;
import org.graalvm.compiler.api.replacements.MethodSubstitution;
import org.graalvm.compiler.nodes.IfNode;
import org.graalvm.compiler.nodes.StructuredGraph;
import org.graalvm.compiler.nodes.calc.AbsNode;
import org.graalvm.compiler.nodes.calc.ReinterpretNode;
import org.graalvm.compiler.replacements.amd64.AMD64CountLeadingZerosNode;
import org.graalvm.compiler.replacements.amd64.AMD64CountTrailingZerosNode;
import org.graalvm.compiler.replacements.nodes.BitCountNode;
import org.graalvm.compiler.replacements.nodes.BitScanForwardNode;
import org.graalvm.compiler.replacements.nodes.BitScanReverseNode;
import org.graalvm.compiler.replacements.nodes.ReverseBytesNode;
import org.junit.Test;
import jdk.vm.ci.code.InstalledCode;
import jdk.vm.ci.meta.ResolvedJavaMethod;
public class StandardMethodSubstitutionsTest extends MethodSubstitutionTest {
@Test
public void testMathSubstitutions() {
assertInGraph(assertNotInGraph(testGraph("mathAbs"), IfNode.class), AbsNode.class);
double value = 34567.891D;
testGraph("mathCos");
testGraph("mathLog");
testGraph("mathLog10");
testGraph("mathSin");
testGraph("mathSqrt");
testGraph("mathTan");
testGraph("mathAll");
test("mathCos", value);
test("mathLog", value);
test("mathLog10", value);
test("mathSin", value);
test("mathSqrt", value);
test("mathTan", value);
test("mathAll", value);
}
@Test
public void testMathPow() {
double a = 34567.891D;
double b = 4.6D;
test("mathPow", a, b);
test("mathPow", a, 0.0D);
test("mathPow", a, -0.0D);
test("mathPow", a, 1.0D);
test("mathPow", a, Double.NaN);
test("mathPow", Double.NaN, b);
test("mathPow", Double.NaN, 0.0D);
test("mathPow", a, -1.0D);
test("mathPow", a, 2.0D);
test("mathPow", a, 0.5D);
}
public static double mathPow(double a, double b) {
return mathPow0(a, b);
}
public static double mathPow0(double a, double b) {
return Math.pow(a, b);
}
public static double mathAbs(double value) {
return Math.abs(value);
}
public static double mathSqrt(double value) {
return Math.sqrt(value);
}
public static double mathLog(double value) {
return Math.log(value);
}
public static double mathLog10(double value) {
return Math.log10(value);
}
public static double mathSin(double value) {
return Math.sin(value);
}
public static double mathCos(double value) {
return Math.cos(value);
}
public static double mathTan(double value) {
return Math.tan(value);
}
public static double mathAll(double value) {
return Math.sqrt(value) + Math.log(value) + Math.log10(value) + Math.sin(value) + Math.cos(value) + Math.tan(value);
}
public void testSubstitution(String testMethodName, Class<?> holder, String methodName, boolean optional, Object[] args, Class<?>... intrinsicClasses) {
ResolvedJavaMethod realJavaMethod = getResolvedJavaMethod(holder, methodName);
ResolvedJavaMethod testJavaMethod = getResolvedJavaMethod(testMethodName);
StructuredGraph graph = testGraph(testMethodName);
StructuredGraph replacement = getReplacements().getSubstitution(realJavaMethod, -1, false, null, graph.allowAssumptions(), graph.getOptions());
if (replacement == null && !optional) {
assertInGraph(graph, intrinsicClasses);
}
for (Object l : args) {
InstalledCode code = getCode(testJavaMethod);
assert optional || code != null;
Object expected = invokeSafe(realJavaMethod, null, l);
assertDeepEquals(expected, invokeSafe(testJavaMethod, null, l));
assertDeepEquals(expected, executeVarargsSafe(code, l));
}
}
@Test
public void testCharSubstitutions() {
Object[] args = new Character[]{Character.MIN_VALUE, (char) -1, (char) 0, (char) 1, Character.MAX_VALUE};
testSubstitution("charReverseBytes", Character.class, "reverseBytes", false, args, ReverseBytesNode.class);
}
public static char charReverseBytes(char value) {
return Character.reverseBytes(value);
}
@Test
public void testCharSubstitutionsNarrowing() {
Object[] args = new Integer[]{(int) Character.MIN_VALUE, -1, 0, 1, (int) Character.MAX_VALUE};
for (Object arg : args) {
test("charReverseBytesNarrowing", arg);
}
}
public static char charReverseBytesNarrowing(int value) {
return Character.reverseBytes((char) value);
}
@Test
public void testShortSubstitutions() {
Object[] args = new Short[]{Short.MIN_VALUE, -1, 0, 1, Short.MAX_VALUE};
testSubstitution("shortReverseBytes", Short.class, "reverseBytes", false, args, ReverseBytesNode.class);
}
public static short shortReverseBytes(short value) {
return Short.reverseBytes(value);
}
@Test
public void testShortSubstitutionsNarrowing() {
Object[] args = new Integer[]{(int) Short.MIN_VALUE, -1, 0, 1, (int) Short.MAX_VALUE};
for (Object arg : args) {
test("shortReverseBytesNarrowing", arg);
}
}
public static short shortReverseBytesNarrowing(int value) {
return Short.reverseBytes((short) value);
}
@Test
public void testIntegerSubstitutions() {
Object[] args = new Object[]{Integer.MIN_VALUE, -1, 0, 1, Integer.MAX_VALUE};
testSubstitution("integerReverseBytes", Integer.class, "reverseBytes", false, args, ReverseBytesNode.class);
testSubstitution("integerNumberOfLeadingZeros", Integer.class, "numberOfLeadingZeros", true, args, BitScanReverseNode.class, AMD64CountLeadingZerosNode.class);
testSubstitution("integerNumberOfTrailingZeros", Integer.class, "numberOfTrailingZeros", false, args, BitScanForwardNode.class, AMD64CountTrailingZerosNode.class);
testSubstitution("integerBitCount", Integer.class, "bitCount", true, args, BitCountNode.class);
}
public static int integerReverseBytes(int value) {
return Integer.reverseBytes(value);
}
public static int integerNumberOfLeadingZeros(int value) {
return Integer.numberOfLeadingZeros(value);
}
public static int integerNumberOfTrailingZeros(int value) {
return Integer.numberOfTrailingZeros(value);
}
public static int integerBitCount(int value) {
return Integer.bitCount(value);
}
@Test
public void testLongSubstitutions() {
Object[] args = new Object[]{Long.MIN_VALUE, -1L, 0L, 1L, Long.MAX_VALUE};
testSubstitution("longReverseBytes", Long.class, "reverseBytes", false, args, ReverseBytesNode.class);
testSubstitution("longNumberOfLeadingZeros", Long.class, "numberOfLeadingZeros", true, args, BitScanReverseNode.class, AMD64CountLeadingZerosNode.class);
testSubstitution("longNumberOfTrailingZeros", Long.class, "numberOfTrailingZeros", false, args, BitScanForwardNode.class, AMD64CountTrailingZerosNode.class);
testSubstitution("longBitCount", Long.class, "bitCount", true, args, BitCountNode.class);
}
public static long longReverseBytes(long value) {
return Long.reverseBytes(value);
}
public static int longNumberOfLeadingZeros(long value) {
return Long.numberOfLeadingZeros(value);
}
public static int longNumberOfTrailingZeros(long value) {
return Long.numberOfTrailingZeros(value);
}
public static int longBitCount(long value) {
return Long.bitCount(value);
}
@Test
public void testFloatSubstitutions() {
assertInGraph(testGraph("floatToIntBits"), ReinterpretNode.class);
testGraph("intBitsToFloat");
}
public static int floatToIntBits(float value) {
return Float.floatToIntBits(value);
}
public static float intBitsToFloat(int value) {
return Float.intBitsToFloat(value);
}
@Test
public void testDoubleSubstitutions() {
assertInGraph(testGraph("doubleToLongBits"), ReinterpretNode.class);
testGraph("longBitsToDouble");
}
public static long doubleToLongBits(double value) {
return Double.doubleToLongBits(value);
}
public static double longBitsToDouble(long value) {
return Double.longBitsToDouble(value);
}
public static boolean isInstance(Class<?> clazz, Object object) {
return clazz.isInstance(object);
}
public static boolean isInstance2(boolean cond, Object object) {
Class<?> clazz;
if (cond) {
clazz = String.class;
} else {
clazz = java.util.HashMap.class;
}
return clazz.isInstance(object);
}
public static boolean isAssignableFrom(Class<?> clazz, Class<?> other) {
return clazz.isAssignableFrom(other);
}
@Test
public void testClassSubstitutions() {
testGraph("isInstance");
testGraph("isInstance2");
testGraph("isAssignableFrom");
for (Class<?> c : new Class<?>[]{getClass(), Cloneable.class, int[].class, String[][].class}) {
for (Object o : new Object[]{this, new int[5], new String[2][], new Object()}) {
test("isInstance", c, o);
test("isAssignableFrom", c, o.getClass());
}
}
test("isInstance2", true, null);
test("isInstance2", false, null);
test("isInstance2", true, "string");
test("isInstance2", false, "string");
test("isInstance2", true, new HashMap<>());
test("isInstance2", false, new HashMap<>());
}
}