package org.graalvm.compiler.hotspot.test;
import java.math.BigInteger;
import java.util.Random;
import org.graalvm.compiler.api.test.Graal;
import org.graalvm.compiler.hotspot.GraalHotSpotVMConfig;
import org.graalvm.compiler.hotspot.HotSpotGraalRuntimeProvider;
import org.graalvm.compiler.replacements.test.MethodSubstitutionTest;
import org.graalvm.compiler.runtime.RuntimeProvider;
import org.junit.Test;
import jdk.vm.ci.amd64.AMD64;
import jdk.vm.ci.code.InstalledCode;
import jdk.vm.ci.meta.ResolvedJavaMethod;
/**
 * Compares results of {@link BigInteger} operations computed by the Java implementation with the
 * results of Graal-compiled helper methods that call the same operations. Each test only runs when
 * the corresponding HotSpot BigInteger intrinsic is enabled and the target is AMD64.
 */
public final class BigIntegerIntrinsicsTest extends MethodSubstitutionTest {

    /** Number of random inputs checked per test. */
    static final int N = 100;

    /**
     * Seed for the per-test random generators. Every test starts from the same seed, so its inputs
     * do not depend on which other tests ran (or were skipped) before it.
     */
    private static final long SEED = 17;

    private static final GraalHotSpotVMConfig CONFIG = ((HotSpotGraalRuntimeProvider) Graal.getRequiredCapability(RuntimeProvider.class)).getVMConfig();

    private static final BigInteger BIG_TWO = BigInteger.valueOf(2);

    @Test
    public void testMultiplyToLen() throws ClassNotFoundException {
        // The intrinsic must be enabled; the test is (currently) AMD64 only.
        org.junit.Assume.assumeTrue(CONFIG.useMultiplyToLenIntrinsic());
        org.junit.Assume.assumeTrue(getTarget().arch instanceof AMD64);

        Class<?> javaclass = Class.forName("java.math.BigInteger");
        TestIntrinsic tin = new TestIntrinsic("testMultiplyAux", javaclass,
                        "multiply", BigInteger.class);

        Random rnd = new Random(SEED);
        for (int i = 0; i < N; i++) {
            BigInteger big1 = randomBig(rnd, i);
            BigInteger big2 = randomBig(rnd, i);
            tin.check(big1, big2);
        }
    }

    @Test
    public void testMulAdd() throws ClassNotFoundException {
        // One of the intrinsics must be enabled; the test is (currently) AMD64 only.
        org.junit.Assume.assumeTrue(CONFIG.useMulAddIntrinsic() ||
                        CONFIG.useSquareToLenIntrinsic());
        org.junit.Assume.assumeTrue(getTarget().arch instanceof AMD64);

        Class<?> javaclass = Class.forName("java.math.BigInteger");
        TestIntrinsic tin = new TestIntrinsic("testMultiplyAux", javaclass,
                        "multiply", BigInteger.class);

        Random rnd = new Random(SEED);
        for (int i = 0; i < N; i++) {
            BigInteger big1 = randomBig(rnd, i);
            // Same instance as receiver and argument, i.e. a squaring; presumably this routes
            // BigInteger through its squareToLen/mulAdd code path.
            tin.check(big1, big1);
        }
    }

    @Test
    public void testMontgomery() throws ClassNotFoundException {
        // One of the intrinsics must be enabled; the test is (currently) AMD64 only.
        org.junit.Assume.assumeTrue(CONFIG.useMontgomeryMultiplyIntrinsic() ||
                        CONFIG.useMontgomerySquareIntrinsic());
        org.junit.Assume.assumeTrue(getTarget().arch instanceof AMD64);

        Class<?> javaclass = Class.forName("java.math.BigInteger");
        TestIntrinsic tin = new TestIntrinsic("testMontgomeryAux", javaclass,
                        "modPow", BigInteger.class, BigInteger.class);

        Random rnd = new Random(SEED);
        for (int i = 0; i < N; i++) {
            BigInteger base = randomBig(rnd, i);
            // randomBig can return zero, and modPow throws ArithmeticException for a modulus <= 0.
            // Setting the low bit keeps the modulus positive and odd (odd moduli are presumably
            // where BigInteger applies Montgomery multiplication) without consuming extra draws.
            BigInteger modulus = randomBig(rnd, i).setBit(0);
            tin.check(base, BIG_TWO, modulus);
        }
    }

    /** Helper compiled by {@link #testMultiplyToLen()} and {@link #testMulAdd()}; looked up by name. */
    public static BigInteger testMultiplyAux(BigInteger a, BigInteger b) {
        return a.multiply(b);
    }

    /** Helper compiled by {@link #testMontgomery()}; looked up by name. */
    public static BigInteger testMontgomeryAux(BigInteger a, BigInteger exp, BigInteger b) {
        return a.modPow(exp, b);
    }

    /**
     * Bundles a {@link BigInteger} method, a static helper of this class that calls it, and the
     * installed code obtained by compiling that helper.
     */
    private class TestIntrinsic {
        private final ResolvedJavaMethod javamethod;
        private final ResolvedJavaMethod testmethod;
        private final InstalledCode testcode;

        TestIntrinsic(String testmname, Class<?> javaclass, String javamname, Class<?>... params) {
            javamethod = getResolvedJavaMethod(javaclass, javamname, params);
            testmethod = getResolvedJavaMethod(testmname);
            assert javamethod != null;
            assert testmethod != null;
            testcode = getCode(testmethod);
            assert testcode != null;
        }

        /** Invokes the {@link BigInteger} method on {@code big} with {@code args}. */
        Object invokeJava(BigInteger big, Object... args) {
            return invokeSafe(javamethod, big, args);
        }

        /** Invokes the static helper method; {@code args} includes the BigInteger receiver first. */
        Object invokeTest(Object... args) {
            return invokeSafe(testmethod, null, args);
        }

        /** Executes the installed code of the helper; {@code args} as for {@link #invokeTest}. */
        Object invokeCode(Object... args) {
            return executeVarargsSafe(testcode, args);
        }

        /**
         * Runs the Java method, the helper method, and the helper's installed code on the same
         * inputs, and asserts that both helper results deep-equal the Java result.
         *
         * @param receiver the {@link BigInteger} the Java method is invoked on
         * @param args the remaining arguments of the Java method
         */
        void check(BigInteger receiver, Object... args) {
            // The static helper takes the receiver as its first parameter.
            Object[] testArgs = new Object[args.length + 1];
            testArgs[0] = receiver;
            System.arraycopy(args, 0, testArgs, 1, args.length);

            Object expected = invokeJava(receiver, args);
            assertDeepEquals(expected, invokeTest(testArgs));
            assertDeepEquals(expected, invokeCode(testArgs));
        }
    }

    /**
     * Returns a non-negative random BigInteger drawn from {@code rnd}, with a bit length bound of
     * {@code rnd.nextInt(4096) + i2sz(i)}. The bound grows with the iteration index {@code i}.
     * The result may be zero.
     */
    private static BigInteger randomBig(Random rnd, int i) {
        return new BigInteger(rnd.nextInt(4096) + i2sz(i), rnd);
    }

    /** Maps an iteration index to a minimum bit-length contribution (always at least 1). */
    private static int i2sz(int i) {
        return i * 3 + 1;
    }
}