package org.graalvm.compiler.hotspot.test;
import java.io.ByteArrayOutputStream;
import java.io.DataInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.security.AlgorithmParameters;
import java.security.SecureRandom;
import javax.crypto.Cipher;
import javax.crypto.KeyGenerator;
import javax.crypto.SecretKey;
import org.junit.Assert;
import org.junit.Test;
import org.graalvm.compiler.code.CompilationResult;
import org.graalvm.compiler.hotspot.meta.HotSpotGraphBuilderPlugins;
import jdk.vm.ci.code.InstalledCode;
import jdk.vm.ci.meta.ResolvedJavaMethod;
/**
 * Exercises the AES and CBC crypto substitutions by running encrypt/decrypt round trips after the
 * corresponding {@code com.sun.crypto.provider} methods have been compiled and installed, and
 * comparing the round-trip output against results recorded in the constructor.
 */
public class HotSpotCryptoSubstitutionTest extends HotSpotGraalCompilerTest {

    /**
     * Installs the compiled code through the backend's {@code createDefaultInstalledCode}.
     */
    @Override
    protected InstalledCode addMethod(ResolvedJavaMethod method, CompilationResult compResult) {
        return getBackend().createDefaultInstalledCode(method, compResult);
    }

    /** 128-bit AES key generated in the constructor. */
    SecretKey aesKey;

    /** 168-bit DESede key generated in the constructor. */
    SecretKey desKey;

    /** Plaintext fed to every round trip: this class's own class file, cut to a multiple of 16 bytes. */
    byte[] input;

    /** Round-trip results for the two AES/CBC transformations, recorded in the constructor. */
    ByteArrayOutputStream aesExpected = new ByteArrayOutputStream();

    /** Round-trip results for the two DESede/CBC transformations, recorded in the constructor. */
    ByteArrayOutputStream desExpected = new ByteArrayOutputStream();

    /**
     * Generates the keys, loads the input data and records the expected round-trip results.
     *
     * @throws Exception if key generation, class file loading or a round trip fails
     */
    public HotSpotCryptoSubstitutionTest() throws Exception {
        byte[] seed = {0x4, 0x7, 0x1, 0x1};
        SecureRandom random = new SecureRandom(seed);
        KeyGenerator aesKeyGen = KeyGenerator.getInstance("AES");
        KeyGenerator desKeyGen = KeyGenerator.getInstance("DESede");
        aesKeyGen.init(128, random);
        desKeyGen.init(168, random);
        aesKey = aesKeyGen.generateKey();
        desKey = desKeyGen.generateKey();

        input = readClassfile16(getClass());

        // Recorded at construction time, i.e. before a test method calls compileAndInstall.
        aesExpected.write(runEncryptDecrypt(aesKey, "AES/CBC/NoPadding"));
        aesExpected.write(runEncryptDecrypt(aesKey, "AES/CBC/PKCS5Padding"));

        desExpected.write(runEncryptDecrypt(desKey, "DESede/CBC/NoPadding"));
        desExpected.write(runEncryptDecrypt(desKey, "DESede/CBC/PKCS5Padding"));
    }

    /**
     * Compiles and installs the {@code AESCrypt} encrypt/decrypt substitutions and checks that the
     * AES round trips still produce the recorded results. Does nothing when
     * {@link #compileAndInstall} reports that no substitution was installed.
     */
    @Test
    public void testAESCryptIntrinsics() throws Exception {
        if (compileAndInstall("com.sun.crypto.provider.AESCrypt", HotSpotGraphBuilderPlugins.aesEncryptName, HotSpotGraphBuilderPlugins.aesDecryptName)) {
            ByteArrayOutputStream actual = new ByteArrayOutputStream();
            actual.write(runEncryptDecrypt(aesKey, "AES/CBC/NoPadding"));
            actual.write(runEncryptDecrypt(aesKey, "AES/CBC/PKCS5Padding"));
            Assert.assertArrayEquals(aesExpected.toByteArray(), actual.toByteArray());
        }
    }

    /**
     * Compiles and installs the {@code CipherBlockChaining} encrypt/decrypt substitutions and
     * checks that both the AES and the DESede round trips still produce the recorded results. Does
     * nothing when {@link #compileAndInstall} reports that no substitution was installed.
     */
    @Test
    public void testCipherBlockChainingIntrinsics() throws Exception {
        if (compileAndInstall("com.sun.crypto.provider.CipherBlockChaining", HotSpotGraphBuilderPlugins.cbcEncryptName, HotSpotGraphBuilderPlugins.cbcDecryptName)) {
            ByteArrayOutputStream actual = new ByteArrayOutputStream();
            actual.write(runEncryptDecrypt(aesKey, "AES/CBC/NoPadding"));
            actual.write(runEncryptDecrypt(aesKey, "AES/CBC/PKCS5Padding"));
            Assert.assertArrayEquals(aesExpected.toByteArray(), actual.toByteArray());

            // Reuse the same buffer for the DESede comparison.
            actual.reset();
            actual.write(runEncryptDecrypt(desKey, "DESede/CBC/NoPadding"));
            actual.write(runEncryptDecrypt(desKey, "DESede/CBC/PKCS5Padding"));
            Assert.assertArrayEquals(desExpected.toByteArray(), actual.toByteArray());
        }
    }

    /**
     * Compiles and installs the substitutions for the given methods of the named class.
     *
     * @param className fully qualified name of the class declaring the methods
     * @param methodNames names of the methods whose substitutions should be installed
     * @return {@code false} if the VM configuration has {@code useAESIntrinsics} disabled or the
     *         class cannot be loaded; otherwise {@code true} iff
     *         {@code compileAndInstallSubstitution} returned a non-null result for at least one
     *         of the method names
     */
    private boolean compileAndInstall(String className, String... methodNames) {
        if (!runtime().getVMConfig().useAESIntrinsics) {
            return false;
        }
        Class<?> c;
        try {
            c = Class.forName(className);
        } catch (ClassNotFoundException e) {
            // The provider class is not available in this runtime, so there is nothing to test.
            return false;
        }
        boolean atLeastOneCompiled = false;
        for (String methodName : methodNames) {
            if (compileAndInstallSubstitution(c, methodName) != null) {
                atLeastOneCompiled = true;
            }
        }
        return atLeastOneCompiled;
    }

    /**
     * Parameters of the most recent {@link #encrypt} call; {@link #decrypt} passes them back to
     * the cipher so that it is initialized with the same parameters as the encryption.
     */
    AlgorithmParameters algorithmParameters;

    /**
     * Encrypts {@code indata} and records the cipher's parameters in {@link #algorithmParameters}.
     *
     * @param indata plaintext
     * @param key encryption key
     * @param algorithm transformation name passed to {@link Cipher#getInstance(String)}
     * @return the concatenated output of {@code update} and {@code doFinal}
     */
    private byte[] encrypt(byte[] indata, SecretKey key, String algorithm) throws Exception {
        Cipher c = Cipher.getInstance(algorithm);
        c.init(Cipher.ENCRYPT_MODE, key);
        algorithmParameters = c.getParameters();
        return updateAndFinish(c, indata);
    }

    /**
     * Decrypts {@code indata} using the parameters recorded by the preceding {@link #encrypt}
     * call.
     *
     * @param indata ciphertext
     * @param key decryption key
     * @param algorithm transformation name passed to {@link Cipher#getInstance(String)}
     * @return the concatenated output of {@code update} and {@code doFinal}
     */
    private byte[] decrypt(byte[] indata, SecretKey key, String algorithm) throws Exception {
        Cipher c = Cipher.getInstance(algorithm);
        c.init(Cipher.DECRYPT_MODE, key, algorithmParameters);
        return updateAndFinish(c, indata);
    }

    /**
     * Feeds {@code indata} through an initialized cipher with one {@code update} call followed by
     * {@code doFinal}, and returns both outputs joined together.
     */
    private static byte[] updateAndFinish(Cipher c, byte[] indata) throws Exception {
        byte[] head = c.update(indata);
        byte[] tail = c.doFinal();
        if (head == null) {
            // Cipher.update is documented to return null when it produces no output block.
            head = new byte[0];
        }
        byte[] result = new byte[head.length + tail.length];
        System.arraycopy(head, 0, result, 0, head.length);
        System.arraycopy(tail, 0, result, head.length, tail.length);
        return result;
    }

    /**
     * Reads the class file of {@code c} from its class path resource, truncated to a multiple of
     * 16 bytes.
     *
     * @param c class whose class file is read
     * @return the leading bytes of the class file; the length is a multiple of 16
     * @throws IOException if the resource cannot be found or read
     */
    private static byte[] readClassfile16(Class<? extends HotSpotCryptoSubstitutionTest> c) throws IOException {
        String classFilePath = "/" + c.getName().replace('.', '/') + ".class";
        // try-with-resources closes the stream on every path.
        try (InputStream stream = c.getResourceAsStream(classFilePath)) {
            if (stream == null) {
                throw new IOException("Class file resource not found: " + classFilePath);
            }
            // NOTE(review): available() is only an estimate of the bytes readable without
            // blocking; this assumes it reports the whole resource - confirm for jar-based
            // class paths.
            int bytesToRead = stream.available();
            bytesToRead -= bytesToRead % 16;
            byte[] classFile = new byte[bytesToRead];
            new DataInputStream(stream).readFully(classFile);
            return classFile;
        }
    }

    /**
     * Encrypts a copy of {@link #input}, decrypts the result and asserts that the decrypted data
     * equals the original.
     *
     * @param key key used for both directions
     * @param algorithm transformation name passed to {@link Cipher#getInstance(String)}
     * @return the decrypted data
     */
    public byte[] runEncryptDecrypt(SecretKey key, String algorithm) throws Exception {
        byte[] indata = input.clone();
        byte[] cipher = encrypt(indata, key, algorithm);
        byte[] plain = decrypt(cipher, key, algorithm);
        Assert.assertArrayEquals(indata, plain);
        return plain;
    }
}