package com.oracle.security.ucrypto;
import java.util.Arrays;
import java.util.WeakHashMap;
import java.util.Collections;
import java.util.Map;
import java.security.AlgorithmParameters;
import java.security.InvalidAlgorithmParameterException;
import java.security.InvalidKeyException;
import java.security.Key;
import java.security.PublicKey;
import java.security.PrivateKey;
import java.security.spec.RSAPrivateCrtKeySpec;
import java.security.spec.RSAPrivateKeySpec;
import java.security.spec.RSAPublicKeySpec;
import java.security.interfaces.RSAKey;
import java.security.interfaces.RSAPrivateCrtKey;
import java.security.interfaces.RSAPrivateKey;
import java.security.interfaces.RSAPublicKey;
import java.security.KeyFactory;
import java.security.NoSuchAlgorithmException;
import java.security.SecureRandom;
import java.security.spec.AlgorithmParameterSpec;
import java.security.spec.InvalidParameterSpecException;
import java.security.spec.InvalidKeySpecException;
import javax.crypto.BadPaddingException;
import javax.crypto.Cipher;
import javax.crypto.CipherSpi;
import javax.crypto.SecretKey;
import javax.crypto.IllegalBlockSizeException;
import javax.crypto.NoSuchPaddingException;
import javax.crypto.ShortBufferException;
import javax.crypto.spec.SecretKeySpec;
import sun.security.internal.spec.TlsRsaPremasterSecretParameterSpec;
import sun.security.jca.JCAUtil;
import sun.security.util.KeyUtil;
/**
 * RSA {@link CipherSpi} backed by the native ucrypto library.
 *
 * <p>Two concrete variants are provided as nested classes:
 * {@link NoPadding} uses mechanism {@code CRYPTO_RSA_X_509} with no padding
 * overhead, and {@link PKCS1Padding} uses {@code CRYPTO_RSA_PKCS} and
 * reserves 11 bytes of each block for padding when encrypting.
 *
 * <p>Input is accumulated in an internal buffer whose size equals the key's
 * modulus length in bytes. The whole buffer is processed by a single native
 * call in {@code doFinal}. Also {@code wrap}/{@code unwrap} end in that call,
 * via {@code engineDoFinal}. Engine methods that touch per-instance state are
 * {@code synchronized}.
 */
public class NativeRSACipher extends CipherSpi {
    /** Native mechanism used for every operation of this instance. */
    private final UcryptoMech mech;
    /** Bytes reserved for padding when encrypting (0 for NoPadding, 11 for PKCS1Padding). */
    private final int padLen;
    /** Converts JCA RSA keys into native key objects. */
    private final NativeRSAKeyFactory keyFactory;
    /**
     * Parameters supplied at init time. Only a
     * {@link TlsRsaPremasterSecretParameterSpec} is ever accepted.
     */
    private AlgorithmParameterSpec spec;
    /** Randomness source handed to the TLS premaster-secret check in unwrap. */
    private SecureRandom random;

    /**
     * Process-wide cache of already-converted keys, shared by all instances.
     * The WeakHashMap lets an entry be dropped once its key object is
     * otherwise unreachable. Access goes through a synchronized wrapper.
     */
    private static final Map<Key, NativeKey> keyList =
        Collections.synchronizedMap(new WeakHashMap<Key, NativeKey>());

    /** Native form of the key passed to the most recent successful init. */
    private NativeKey key = null;
    /** Modulus length of the current key in bytes; also the output size. */
    private int outputSize = 0;
    /** True for ENCRYPT/WRAP mode, false for DECRYPT/UNWRAP mode. */
    private boolean encrypt = true;
    /** Accumulates input until doFinal; sized to the modulus length. */
    private byte[] buffer;
    /**
     * Number of bytes currently buffered. It is set to
     * {@code buffer.length + 1} as an overflow marker when too much input
     * has been supplied.
     */
    private int bufOfs = 0;

    /** Raw RSA (no padding) using mechanism CRYPTO_RSA_X_509. */
    public static final class NoPadding extends NativeRSACipher {
        public NoPadding() throws NoSuchAlgorithmException {
            super(UcryptoMech.CRYPTO_RSA_X_509, 0);
        }
    }

    /** PKCS#1 v1.5 padded RSA using mechanism CRYPTO_RSA_PKCS. */
    public static final class PKCS1Padding extends NativeRSACipher {
        public PKCS1Padding() throws NoSuchAlgorithmException {
            super(UcryptoMech.CRYPTO_RSA_PKCS, 11);
        }
    }

    /**
     * @param mech   native mechanism to use
     * @param padLen number of bytes of each block reserved for padding
     *               when encrypting
     */
    NativeRSACipher(UcryptoMech mech, int padLen)
        throws NoSuchAlgorithmException {
        this.mech = mech;
        this.padLen = padLen;
        this.keyFactory = new NativeRSAKeyFactory();
    }

    /** Modes cannot be changed; every request is rejected. */
    @Override
    protected void engineSetMode(String mode) throws NoSuchAlgorithmException {
        throw new NoSuchAlgorithmException("Unsupported mode " + mode);
    }

    /** Padding is fixed by the concrete subclass; every request is rejected. */
    @Override
    protected void engineSetPadding(String padding)
        throws NoSuchPaddingException {
        throw new NoSuchPaddingException("Unsupported padding " + padding);
    }

    /** Returns 0: this cipher does not report a block size. */
    @Override
    protected int engineGetBlockSize() {
        return 0;
    }

    /**
     * Returns the modulus length in bytes of the current key, whatever
     * {@code inputLen} is. Returns 0 before the first init.
     */
    @Override
    protected synchronized int engineGetOutputSize(int inputLen) {
        return outputSize;
    }

    /** RSA uses no IV; always returns null. */
    @Override
    protected byte[] engineGetIV() {
        return null;
    }

    /** No algorithm parameters are exposed; always returns null. */
    @Override
    protected AlgorithmParameters engineGetParameters() {
        return null;
    }

    /**
     * Returns the key's modulus bit length, rounded up to a whole number of
     * bytes and expressed in bits.
     *
     * @throws InvalidKeyException if {@code key} is not an {@link RSAKey}
     */
    @Override
    protected int engineGetKeySize(Key key) throws InvalidKeyException {
        if (!(key instanceof RSAKey)) {
            throw new InvalidKeyException("RSAKey required. Got: " +
                key.getClass().getName());
        }
        int n = ((RSAKey)key).getModulus().bitLength();
        int realByteSize = (n + 7) >> 3;
        return realByteSize * 8;
    }

    /**
     * Initializes without parameters. A failure reported as
     * InvalidAlgorithmParameterException (e.g. an unsupported opmode) is
     * rethrown as InvalidKeyException.
     */
    @Override
    protected synchronized void engineInit(int opmode, Key key, SecureRandom random)
        throws InvalidKeyException {
        try {
            engineInit(opmode, key, (AlgorithmParameterSpec)null, random);
        } catch (InvalidAlgorithmParameterException e) {
            throw new InvalidKeyException("init() failed", e);
        }
    }

    /**
     * Initializes the cipher.
     *
     * <p>ENCRYPT/WRAP modes require an {@link RSAPublicKey}. DECRYPT/UNWRAP
     * modes require an {@link RSAPrivateKey}. The only accepted
     * {@code params} is a {@link TlsRsaPremasterSecretParameterSpec}.
     * Instance state is changed only after every check and the key
     * conversion have succeeded.
     *
     * @throws InvalidKeyException if the key is null, of the wrong type for
     *         the mode, or cannot be converted to a native key
     * @throws InvalidAlgorithmParameterException if the opmode is not one of
     *         the four Cipher modes, or {@code params} is of an unsupported type
     */
    @Override
    @SuppressWarnings("deprecation")
    protected synchronized void engineInit(int opmode, Key newKey,
        AlgorithmParameterSpec params, SecureRandom random)
        throws InvalidKeyException, InvalidAlgorithmParameterException {
        if (newKey == null) {
            throw new InvalidKeyException("Key cannot be null");
        }
        if (opmode != Cipher.ENCRYPT_MODE &&
            opmode != Cipher.DECRYPT_MODE &&
            opmode != Cipher.WRAP_MODE &&
            opmode != Cipher.UNWRAP_MODE) {
            throw new InvalidAlgorithmParameterException
                ("Unsupported mode: " + opmode);
        }
        if (params != null &&
            !(params instanceof TlsRsaPremasterSecretParameterSpec)) {
            throw new InvalidAlgorithmParameterException(
                "No Parameters can be specified");
        }
        boolean doEncrypt = (opmode == Cipher.ENCRYPT_MODE || opmode == Cipher.WRAP_MODE);
        if (doEncrypt && (!(newKey instanceof RSAPublicKey))) {
            throw new InvalidKeyException("RSAPublicKey required for encryption." +
                " Received: " + newKey.getClass().getName());
        } else if (!doEncrypt && (!(newKey instanceof RSAPrivateKey))) {
            throw new InvalidKeyException("RSAPrivateKey required for decryption." +
                " Received: " + newKey.getClass().getName());
        }

        NativeKey nativeKey = keyList.get(newKey);
        if (nativeKey == null) {
            nativeKey = toNativeKey(newKey, doEncrypt);
            // get-then-put is not atomic. A concurrent init may convert the
            // same key twice; the later put simply wins.
            keyList.put(newKey, nativeKey);
        }

        // Commit the TLS parameters only now, so a failed init does not leave
        // new parameters paired with the previous key.
        // NOTE(review): spec/random are not cleared when params is null, so a
        // spec from an earlier init stays in effect for unwrap. Confirm this
        // is intended.
        if (params != null) {
            spec = params;
            this.random = (random != null) ? random : JCAUtil.getSecureRandom();
        }
        init(doEncrypt, nativeKey);
    }

    /**
     * Converts a JCA RSA key into a native key via {@link #keyFactory}.
     * The caller must already have checked that {@code newKey} is an
     * RSAPublicKey when {@code doEncrypt} is true, and an RSAPrivateKey
     * otherwise. CRT private keys are converted with all CRT components.
     *
     * @throws InvalidKeyException wrapping any InvalidKeySpecException from
     *         the key factory
     */
    private NativeKey toNativeKey(Key newKey, boolean doEncrypt)
        throws InvalidKeyException {
        try {
            if (doEncrypt) {
                RSAPublicKey publicKey = (RSAPublicKey) newKey;
                return (NativeKey) keyFactory.engineGeneratePublic
                    (new RSAPublicKeySpec(publicKey.getModulus(),
                        publicKey.getPublicExponent()));
            }
            if (newKey instanceof RSAPrivateCrtKey) {
                RSAPrivateCrtKey privateKey = (RSAPrivateCrtKey) newKey;
                return (NativeKey) keyFactory.engineGeneratePrivate
                    (new RSAPrivateCrtKeySpec(privateKey.getModulus(),
                        privateKey.getPublicExponent(),
                        privateKey.getPrivateExponent(),
                        privateKey.getPrimeP(),
                        privateKey.getPrimeQ(),
                        privateKey.getPrimeExponentP(),
                        privateKey.getPrimeExponentQ(),
                        privateKey.getCrtCoefficient()));
            }
            RSAPrivateKey privateKey = (RSAPrivateKey) newKey;
            return (NativeKey) keyFactory.engineGeneratePrivate
                (new RSAPrivateKeySpec(privateKey.getModulus(),
                    privateKey.getPrivateExponent()));
        } catch (InvalidKeySpecException ikse) {
            throw new InvalidKeyException(ikse);
        }
    }

    /**
     * AlgorithmParameters are not supported; any non-null value is rejected.
     */
    @Override
    protected synchronized void engineInit(int opmode, Key key, AlgorithmParameters params,
        SecureRandom random)
        throws InvalidKeyException, InvalidAlgorithmParameterException {
        if (params != null) {
            throw new InvalidAlgorithmParameterException("No Parameters can be specified");
        }
        engineInit(opmode, key, (AlgorithmParameterSpec) null, random);
    }

    /** Buffers the input; RSA produces no output until doFinal, so returns null. */
    @Override
    protected synchronized byte[] engineUpdate(byte[] in, int inOfs, int inLen) {
        if (inLen > 0) {
            update(in, inOfs, inLen);
        }
        return null;
    }

    /**
     * Buffers the input and returns 0, because RSA produces no output until
     * doFinal.
     *
     * <p>{@code out} is never written, so it is neither size-checked nor
     * dereferenced. A ShortBufferException is therefore never raised here.
     * The clause is kept for source compatibility.
     */
    @Override
    protected synchronized int engineUpdate(byte[] in, int inOfs, int inLen, byte[] out,
        int outOfs) throws ShortBufferException {
        if (inLen > 0) {
            update(in, inOfs, inLen);
        }
        return 0;
    }

    /**
     * Finishes the operation and returns the result. The result is trimmed
     * when the native call produced fewer than {@code outputSize} bytes.
     */
    @Override
    protected synchronized byte[] engineDoFinal(byte[] in, int inOfs, int inLen)
        throws IllegalBlockSizeException, BadPaddingException {
        byte[] out = new byte[outputSize];
        try {
            int actualLen = engineDoFinal(in, inOfs, inLen, out, 0);
            if (actualLen != outputSize) {
                return Arrays.copyOf(out, actualLen);
            } else {
                return out;
            }
        } catch (ShortBufferException e) {
            // out was sized to outputSize, so this should not happen.
            throw new UcryptoException("Internal Error", e);
        }
    }

    /**
     * Buffers any final input, then runs the native operation into
     * {@code out} starting at {@code outOfs}.
     *
     * @return number of bytes written to {@code out}
     */
    @Override
    protected synchronized int engineDoFinal(byte[] in, int inOfs, int inLen, byte[] out,
        int outOfs)
        throws ShortBufferException, IllegalBlockSizeException,
        BadPaddingException {
        if (inLen != 0) {
            update(in, inOfs, inLen);
        }
        return doFinal(out, outOfs, out.length - outOfs);
    }

    /**
     * Encrypts the encoded form of {@code key}.
     *
     * @throws InvalidKeyException if the key has no encoding, or its encoding
     *         is longer than the modulus
     */
    @Override
    protected synchronized byte[] engineWrap(Key key) throws IllegalBlockSizeException,
        InvalidKeyException {
        try {
            byte[] encodedKey = key.getEncoded();
            if ((encodedKey == null) || (encodedKey.length == 0)) {
                throw new InvalidKeyException("Cannot get an encoding of " +
                    "the key to be wrapped");
            }
            if (encodedKey.length > buffer.length) {
                throw new InvalidKeyException("Key is too long for wrapping. " +
                    "encodedKey.length: " + encodedKey.length +
                    ". buffer.length: " + buffer.length);
            }
            return engineDoFinal(encodedKey, 0, encodedKey.length);
        } catch (BadPaddingException e) {
            throw new UcryptoException("Internal Error", e);
        }
    }

    /**
     * Decrypts {@code wrappedKey} and rebuilds a key from it via
     * {@code NativeCipher.constructKey}.
     *
     * <p>For algorithm {@code "TlsRsaPremasterSecret"}, a BadPaddingException
     * is not thrown. Instead it is reported as a flag to
     * {@code KeyUtil.checkTlsPreMasterSecretKey}, together with the client
     * and server versions from the init-time spec and the stored
     * SecureRandom.
     *
     * @throws IllegalStateException if a TLS premaster secret is requested
     *         but no TlsRsaPremasterSecretParameterSpec was supplied at init
     */
    @Override
    @SuppressWarnings("deprecation")
    protected synchronized Key engineUnwrap(byte[] wrappedKey,
        String wrappedKeyAlgorithm, int wrappedKeyType)
        throws InvalidKeyException, NoSuchAlgorithmException {
        if (wrappedKey.length > buffer.length) {
            throw new InvalidKeyException("Key is too long for unwrapping." +
                " wrappedKey.length: " + wrappedKey.length +
                ". buffer.length: " + buffer.length);
        }

        boolean isTlsRsaPremasterSecret =
            wrappedKeyAlgorithm.equals("TlsRsaPremasterSecret");
        Exception failover = null;

        byte[] encodedKey = null;
        try {
            encodedKey = engineDoFinal(wrappedKey, 0, wrappedKey.length);
        } catch (BadPaddingException bpe) {
            if (isTlsRsaPremasterSecret) {
                // Remember the failure and let the TLS check below decide,
                // instead of throwing here.
                failover = bpe;
            } else {
                throw new InvalidKeyException("Unwrapping failed", bpe);
            }
        } catch (Exception e) {
            throw new InvalidKeyException("Unwrapping failed", e);
        }

        if (isTlsRsaPremasterSecret) {
            if (!(spec instanceof TlsRsaPremasterSecretParameterSpec)) {
                throw new IllegalStateException(
                    "No TlsRsaPremasterSecretParameterSpec specified");
            }
            encodedKey = KeyUtil.checkTlsPreMasterSecretKey(
                ((TlsRsaPremasterSecretParameterSpec)spec).getClientVersion(),
                ((TlsRsaPremasterSecretParameterSpec)spec).getServerVersion(),
                random, encodedKey, (failover != null));
        }

        return NativeCipher.constructKey(wrappedKeyType,
            encodedKey, wrappedKeyAlgorithm);
    }

    /**
     * Single-shot native RSA operation.
     * NOTE(review): contract inferred from the caller in doFinal. A negative
     * return appears to be a negated error code; otherwise it is the number
     * of bytes written to {@code out}. Confirm against the native code.
     */
    private native static int nativeAtomic(int mech, boolean encrypt,
        long keyValue, int keyLength,
        byte[] in, int inLen,
        byte[] out, int outOfs, int outLen);

    /**
     * Records the direction and key, and sizes the output and the input
     * buffer to the key's modulus length in bytes. Any buffered input is
     * discarded.
     */
    private void init(boolean encrypt, NativeKey key) {
        this.encrypt = encrypt;
        this.key = key;
        try {
            this.outputSize = engineGetKeySize(key)/8;
        } catch (InvalidKeyException ike) {
            throw new UcryptoException("Internal Error", ike);
        }
        this.buffer = new byte[outputSize];
        this.bufOfs = 0;
    }

    /**
     * Appends input to the buffer. If the data would not fit (leaving room
     * for padding when encrypting), nothing is copied. Instead
     * {@code bufOfs} is set past the end, so that doFinal reports an
     * IllegalBlockSizeException.
     */
    private void update(byte[] in, int inOfs, int inLen) {
        if ((inLen <= 0) || (in == null)) {
            return;
        }
        if ((bufOfs + inLen + (encrypt? padLen:0)) > buffer.length) {
            bufOfs = buffer.length + 1;
            return;
        }
        System.arraycopy(in, inOfs, buffer, bufOfs, inLen);
        bufOfs += inLen;
    }

    /**
     * Runs the native operation over the buffered input.
     *
     * <p>The buffer is cleared only once the native call is attempted. After
     * an overflow or a short output buffer, the buffered state is kept.
     * After an overflow, the cipher therefore keeps failing until it is
     * re-initialized.
     *
     * @return number of bytes written to {@code out}
     * @throws IllegalBlockSizeException if too much input was buffered
     * @throws ShortBufferException if {@code outLen} is less than the modulus length
     * @throws BadPaddingException if the native call returns -16 or -64
     */
    private int doFinal(byte[] out, int outOfs, int outLen)
        throws ShortBufferException, IllegalBlockSizeException,
        BadPaddingException {
        if (bufOfs > buffer.length) {
            throw new IllegalBlockSizeException(
                "Data must not be longer than " +
                (buffer.length - (encrypt ? padLen : 0)) + " bytes");
        }
        if (outLen < outputSize) {
            throw new ShortBufferException();
        }
        try {
            long keyValue = key.value();
            int k = nativeAtomic(mech.value(), encrypt, keyValue,
                key.length(), buffer, bufOfs,
                out, outOfs, outLen);
            if (k < 0) {
                // These two native codes are surfaced as bad padding. Any
                // other negative code becomes a UcryptoException.
                if ( k == -16 || k == -64) {
                    UcryptoException ue = new UcryptoException(16);
                    BadPaddingException bpe =
                        new BadPaddingException("Invalid encryption data");
                    bpe.initCause(ue);
                    throw bpe;
                }
                throw new UcryptoException(-k);
            }

            return k;
        } finally {
            bufOfs = 0;
        }
    }
}