package com.oracle.security.ucrypto;
import java.util.Set;
import java.util.Arrays;
import java.util.concurrent.ConcurrentSkipListSet;
import java.lang.ref.*;
import java.math.BigInteger;
import java.nio.ByteBuffer;
import java.security.SignatureSpi;
import java.security.NoSuchAlgorithmException;
import java.security.InvalidParameterException;
import java.security.InvalidKeyException;
import java.security.SignatureException;
import java.security.Key;
import java.security.PrivateKey;
import java.security.PublicKey;
import java.security.*;
import java.security.interfaces.*;
import java.security.spec.*;
import sun.nio.ch.DirectBuffer;
import java.nio.ByteBuffer;
/**
 * RSA signature engine (the PKCS-style mechanisms named in the nested
 * subclasses) backed by native ucrypto calls.
 *
 * <p>At most one native context exists per instance. It is created by
 * {@link #init}. It is tracked by a {@link SignatureContextRef} phantom
 * reference so it can still be released if this object is collected
 * mid-operation. It is given up by {@link #reset} once an operation finishes
 * or is abandoned.
 *
 * <p>Every engine method that touches the native context is
 * {@code synchronized} on this instance.
 */
class NativeRSASignature extends SignatureSpi {

    // Bytes subtracted from the key size in checkRSAKeyLength to get the
    // largest encodable payload (11 is the PKCS#1 v1.5 minimum padding
    // overhead, per the constant's name).
    private static final int PKCS1PADDING_LEN = 11;

    // Mechanism passed to nativeInit.
    private final UcryptoMech mech;

    // Encoded digest length checked against the key size in
    // checkRSAKeyLength. Presumably the DER-encoded DigestInfo size for the
    // subclass's digest -- TODO confirm.
    private final int encodedLen;

    // Reference owning the current native context; null when none is active.
    private SignatureContextRef pCtxt = null;

    // True while pCtxt refers to a successfully initialized native context.
    private boolean initialized = false;

    // Mode of the most recent init: true = sign, false = verify.
    private boolean sign = true;

    // Signature length in bytes (modulus size) for the current key.
    private int sigLength;

    // Native key handed to the most recent init call.
    private NativeKey key;

    // Converts caller-supplied RSA keys into NativeKey instances.
    private final NativeRSAKeyFactory keyFactory;

    /** Signature using mechanism {@code CRYPTO_MD5_RSA_PKCS}. */
    public static final class MD5 extends NativeRSASignature {
        public MD5() throws NoSuchAlgorithmException {
            super(UcryptoMech.CRYPTO_MD5_RSA_PKCS, 34);
        }
    }

    /** Signature using mechanism {@code CRYPTO_SHA1_RSA_PKCS}. */
    public static final class SHA1 extends NativeRSASignature {
        public SHA1() throws NoSuchAlgorithmException {
            super(UcryptoMech.CRYPTO_SHA1_RSA_PKCS, 35);
        }
    }

    /** Signature using mechanism {@code CRYPTO_SHA256_RSA_PKCS}. */
    public static final class SHA256 extends NativeRSASignature {
        public SHA256() throws NoSuchAlgorithmException {
            super(UcryptoMech.CRYPTO_SHA256_RSA_PKCS, 51);
        }
    }

    /** Signature using mechanism {@code CRYPTO_SHA384_RSA_PKCS}. */
    public static final class SHA384 extends NativeRSASignature {
        public SHA384() throws NoSuchAlgorithmException {
            super(UcryptoMech.CRYPTO_SHA384_RSA_PKCS, 67);
        }
    }

    /** Signature using mechanism {@code CRYPTO_SHA512_RSA_PKCS}. */
    public static final class SHA512 extends NativeRSASignature {
        public SHA512() throws NoSuchAlgorithmException {
            super(UcryptoMech.CRYPTO_SHA512_RSA_PKCS, 83);
        }
    }

    /**
     * Phantom reference that owns one native signature context.
     *
     * <p>If the owning NativeRSASignature becomes phantom-reachable before the
     * context is disposed, this reference is enqueued on {@code refQueue}.
     * The context is then released the next time any new
     * SignatureContextRef is constructed.
     */
    private static class SignatureContextRef extends PhantomReference<NativeRSASignature>
        implements Comparable<SignatureContextRef> {

        private static final ReferenceQueue<NativeRSASignature> refQueue =
            new ReferenceQueue<NativeRSASignature>();

        // Holds every live reference strongly until dispose() runs. A
        // reference object that is itself unreachable is never enqueued.
        private static final Set<SignatureContextRef> refList =
            new ConcurrentSkipListSet<SignatureContextRef>();

        // Native context handle returned by nativeInit.
        private final long id;
        // Mode the context was created for; forwarded to nativeFinal.
        private final boolean sign;

        /**
         * Disposes (with cancellation) every reference whose owner has
         * already been collected. Runs until the queue is empty.
         */
        private static void drainRefQueueBounded() {
            SignatureContextRef next;
            while ((next = (SignatureContextRef) refQueue.poll()) != null) {
                next.dispose(true);
            }
        }

        SignatureContextRef(NativeRSASignature ns, long id, boolean sign) {
            super(ns, refQueue);
            this.id = id;
            this.sign = sign;
            refList.add(this);
            UcryptoProvider.debug("Resource: track Signature Ctxt " + this.id);
            // Opportunistically clean up contexts leaked by collected owners.
            drainRefQueueBounded();
        }

        /** Orders references by native context handle (used by refList). */
        public int compareTo(SignatureContextRef other) {
            return Long.compare(this.id, other.id);
        }

        /**
         * Stops tracking this context.
         *
         * @param doCancel when true, also calls {@code nativeFinal} with a
         *        null buffer to free the native context. When false, the
         *        caller asserts the context was already consumed (e.g. by a
         *        completed nativeFinal) and it is only untracked.
         */
        void dispose(boolean doCancel) {
            refList.remove(this);
            try {
                if (doCancel) {
                    UcryptoProvider.debug("Resource: free Signature Ctxt " + this.id);
                    NativeRSASignature.nativeFinal(id, sign, null, 0, 0);
                } else {
                    UcryptoProvider.debug("Resource: stop tracking Signature Ctxt " + this.id);
                }
            } finally {
                this.clear();
            }
        }
    }

    /**
     * @param mech       native mechanism used for every operation
     * @param encodedLen encoded digest length, used to reject keys that are
     *                   too short for this algorithm
     */
    NativeRSASignature(UcryptoMech mech, int encodedLen)
        throws NoSuchAlgorithmException {
        this.mech = mech;
        this.encodedLen = encodedLen;
        this.keyFactory = new NativeRSAKeyFactory();
    }

    /** Not supported; always throws UnsupportedOperationException. */
    @Override
    @SuppressWarnings("deprecation")
    protected Object engineGetParameter(String param) throws InvalidParameterException {
        throw new UnsupportedOperationException("getParameter() not supported");
    }

    /** This engine has no parameters; always returns null. */
    @Override
    protected AlgorithmParameters engineGetParameters() {
        return null;
    }

    /**
     * Initializes for signing.
     *
     * <p>The RSA private key is converted to a NativeKey, preserving CRT
     * components when present. A fresh native context is then created.
     *
     * @throws InvalidKeyException if the key is null, is not an
     *         RSAPrivateKey, is too short for this algorithm, or cannot be
     *         converted
     */
    @Override
    protected synchronized void engineInitSign(PrivateKey privateKey)
        throws InvalidKeyException {
        if (privateKey == null) {
            throw new InvalidKeyException("Key must not be null");
        }
        NativeKey newKey = key;
        int newSigLength = sigLength;
        // Only convert when the caller did not hand back our own native key.
        if (privateKey != key) {
            if (!(privateKey instanceof RSAPrivateKey)) {
                throw new InvalidKeyException("RSAPrivateKey required. " +
                    "Received: " + privateKey.getClass().getName());
            }
            RSAPrivateKey rsaPrivKey = (RSAPrivateKey) privateKey;
            BigInteger mod = rsaPrivKey.getModulus();
            newSigLength = checkRSAKeyLength(mod);
            BigInteger pe = rsaPrivKey.getPrivateExponent();
            try {
                if (rsaPrivKey instanceof RSAPrivateCrtKey) {
                    RSAPrivateCrtKey rsaPrivCrtKey = (RSAPrivateCrtKey) rsaPrivKey;
                    newKey = (NativeKey) keyFactory.engineGeneratePrivate
                        (new RSAPrivateCrtKeySpec(mod,
                                                  rsaPrivCrtKey.getPublicExponent(),
                                                  pe,
                                                  rsaPrivCrtKey.getPrimeP(),
                                                  rsaPrivCrtKey.getPrimeQ(),
                                                  rsaPrivCrtKey.getPrimeExponentP(),
                                                  rsaPrivCrtKey.getPrimeExponentQ(),
                                                  rsaPrivCrtKey.getCrtCoefficient()));
                } else {
                    newKey = (NativeKey) keyFactory.engineGeneratePrivate
                        (new RSAPrivateKeySpec(mod, pe));
                }
            } catch (InvalidKeySpecException ikse) {
                throw new InvalidKeyException(ikse);
            }
        }
        init(true, newKey, newSigLength);
    }

    /**
     * Initializes for verification.
     *
     * <p>The RSA public key is converted to a NativeKey and a fresh native
     * context is created.
     *
     * @throws InvalidKeyException if the key is null, is not an
     *         RSAPublicKey, is too short for this algorithm, or cannot be
     *         converted
     */
    @Override
    protected synchronized void engineInitVerify(PublicKey publicKey)
        throws InvalidKeyException {
        if (publicKey == null) {
            throw new InvalidKeyException("Key must not be null");
        }
        NativeKey newKey = key;
        int newSigLength = sigLength;
        // Only convert when the caller did not hand back our own native key.
        if (publicKey != key) {
            if (publicKey instanceof RSAPublicKey) {
                BigInteger mod = ((RSAPublicKey) publicKey).getModulus();
                newSigLength = checkRSAKeyLength(mod);
                try {
                    newKey = (NativeKey) keyFactory.engineGeneratePublic
                        (new RSAPublicKeySpec(mod, ((RSAPublicKey) publicKey).getPublicExponent()));
                } catch (InvalidKeySpecException ikse) {
                    throw new InvalidKeyException(ikse);
                }
            } else {
                throw new InvalidKeyException("RSAPublicKey required. " +
                    "Received: " + publicKey.getClass().getName());
            }
        }
        init(false, newKey, newSigLength);
    }

    /** Not supported; always throws UnsupportedOperationException. */
    @Override
    @SuppressWarnings("deprecation")
    protected void engineSetParameter(String param, Object value) throws InvalidParameterException {
        throw new UnsupportedOperationException("setParameter() not supported");
    }

    /**
     * Accepts only a null parameter spec.
     *
     * @throws InvalidAlgorithmParameterException if {@code params} is non-null
     */
    @Override
    protected void engineSetParameter(AlgorithmParameterSpec params)
        throws InvalidAlgorithmParameterException {
        if (params != null) {
            throw new InvalidAlgorithmParameterException("No parameter accepted");
        }
    }

    /**
     * Finishes signing and returns a new {@code sigLength}-byte signature.
     *
     * <p>The native context is always given up afterwards. If nativeFinal
     * never completed (an exception escaped first), the context is cancelled
     * rather than merely untracked, so it is not leaked.
     *
     * @throws SignatureException if the native final call reports an error
     */
    @Override
    protected synchronized byte[] engineSign() throws SignatureException {
        boolean doCancel = true;
        try {
            byte[] sig = new byte[sigLength];
            int rv = doFinal(sig, 0, sigLength);
            // nativeFinal ran, so (as in the other final operations) the
            // context only needs to be untracked, not cancelled.
            doCancel = false;
            if (rv < 0) {
                throw new SignatureException(new UcryptoException(-rv));
            }
            return sig;
        } finally {
            reset(doCancel);
        }
    }

    /**
     * Finishes signing into {@code outbuf} starting at {@code offset}.
     *
     * @return {@code sigLength}, the number of bytes written
     * @throws SignatureException if the buffer cannot hold the signature or
     *         the native final call reports an error
     */
    @Override
    protected synchronized int engineSign(byte[] outbuf, int offset, int len)
        throws SignatureException {
        boolean doCancel = true;
        try {
            if (outbuf == null || (offset < 0) ||
                ((outbuf.length - offset) < sigLength) ||
                (len < sigLength)) {
                throw new SignatureException("Invalid output buffer. offset: " +
                    offset + ". len: " + len + ". sigLength: " + sigLength);
            }
            int rv = doFinal(outbuf, offset, sigLength);
            doCancel = false;
            if (rv < 0) {
                throw new SignatureException(new UcryptoException(-rv));
            }
            return sigLength;
        } finally {
            reset(doCancel);
        }
    }

    /** Feeds a single byte into the native context. */
    @Override
    protected synchronized void engineUpdate(byte b) throws SignatureException {
        byte[] in = { b };
        int rv = update(in, 0, 1);
        if (rv < 0) {
            throw new SignatureException(new UcryptoException(-rv));
        }
    }

    /**
     * Feeds {@code in[inOfs, inOfs + inLen)} into the native context.
     *
     * <p>A null array, negative offset or zero length is silently ignored.
     * Callers going through java.security.Signature have already had their
     * arguments validated.
     */
    @Override
    protected synchronized void engineUpdate(byte[] in, int inOfs, int inLen)
        throws SignatureException {
        if (in == null || inOfs < 0 || inLen == 0) return;

        int rv = update(in, inOfs, inLen);
        if (rv < 0) {
            throw new SignatureException(new UcryptoException(-rv));
        }
    }

    /**
     * Feeds the remaining bytes of {@code in} into the native context.
     *
     * <p>Direct buffers are passed to native code by address. Other buffers
     * fall back to the SignatureSpi default. On success the buffer's
     * position is advanced to its limit.
     *
     * @throws UcryptoException if the native update reports an error
     */
    @Override
    protected synchronized void engineUpdate(ByteBuffer in) {
        if (in == null || in.remaining() == 0) return;

        if (in instanceof DirectBuffer == false) {
            // cannot do better than default impl
            super.engineUpdate(in);
            return;
        }
        long inAddr = ((DirectBuffer)in).address();
        int inOfs = in.position();
        int inLen = in.remaining();

        int rv = update((inAddr + inOfs), inLen);
        if (rv < 0) {
            throw new UcryptoException(-rv);
        }
        in.position(inOfs + inLen);
    }

    /**
     * Verifies the whole array as a signature.
     *
     * <p>A null array is routed to the three-argument overload, so it is
     * rejected with a SignatureException and the context is reset, rather
     * than failing with a NullPointerException here.
     */
    @Override
    protected synchronized boolean engineVerify(byte[] sigBytes) throws SignatureException {
        int sigLen = (sigBytes == null) ? 0 : sigBytes.length;
        return engineVerify(sigBytes, 0, sigLen);
    }

    /**
     * Finishes verification against {@code sigBytes[sigOfs, sigOfs + sigLen)}.
     *
     * @return true if the native final call returns 0, false for any other
     *         result (logged via UcryptoProvider.debug)
     * @throws SignatureException if the signature array or length is invalid
     */
    @Override
    protected synchronized boolean engineVerify(byte[] sigBytes, int sigOfs, int sigLen)
        throws SignatureException {
        boolean doCancel = true;
        try {
            if (sigBytes == null || (sigOfs < 0) ||
                ((sigBytes.length - sigOfs) < this.sigLength) ||
                (sigLen != this.sigLength)) {
                throw new SignatureException("Invalid signature length: got " +
                    sigLen + " but was expecting " + this.sigLength);
            }

            int rv = doFinal(sigBytes, sigOfs, sigLen);
            doCancel = false;
            if (rv == 0) {
                return true;
            } else {
                UcryptoProvider.debug("Signature: " + mech + " verification error " +
                             new UcryptoException(-rv).getMessage());
                return false;
            }
        } finally {
            reset(doCancel);
        }
    }

    /**
     * Marks the engine uninitialized and gives up the current native context.
     *
     * @param doCancel true to free the native context, false to only stop
     *        tracking it (it has already been consumed)
     */
    void reset(boolean doCancel) {
        initialized = false;
        if (pCtxt != null) {
            pCtxt.dispose(doCancel);
            pCtxt = null;
        }
    }

    // Creates a native context; returns 0 on failure (see init).
    private native static long nativeInit(int mech, boolean sign,
                                          long keyValue, int keyLength);

    // Native updates; a negative return value is a negated error code.
    private native static int nativeUpdate(long pContext, boolean sign,
                                           byte[] in, int inOfs, int inLen);
    private native static int nativeUpdate(long pContext, boolean sign,
                                           long pIn, int inLen);

    // Signs into / verifies against sig. A null sig is used by dispose() to
    // free the context. A negative return value is a negated error code.
    private native static int nativeFinal(long pContext, boolean sign,
                                          byte[] sig, int sigOfs, int sigLen);

    /**
     * Cancels any existing context, records the new mode, key and length,
     * and creates a fresh native context.
     *
     * @throws UcryptoException if nativeInit fails
     */
    private void init(boolean sign, NativeKey key, int sigLength) {
        reset(true);
        this.sign = sign;
        this.sigLength = sigLength;
        this.key = key;
        long pCtxtVal = nativeInit(mech.value(), sign, key.value(),
                                   key.length());
        initialized = (pCtxtVal != 0L);
        if (initialized) {
            pCtxt = new SignatureContextRef(this, pCtxtVal, sign);
        } else {
            throw new UcryptoException("Cannot initialize Signature");
        }
    }

    /**
     * Re-creates the native context with the last mode, key and length if
     * the previous operation ended.
     */
    private void ensureInitialized() {
        if (!initialized) {
            init(sign, key, sigLength);
            if (!initialized) {
                throw new UcryptoException("Cannot initialize Signature");
            }
        }
    }

    /**
     * Bounds-checks the slice and feeds it to the native context. On a
     * native error the context is untracked via reset(false).
     *
     * @return the native result (negative on error)
     * @throws ArrayIndexOutOfBoundsException if the offset or length does
     *         not describe a valid slice of {@code in}
     */
    private int update(byte[] in, int inOfs, int inLen) {
        // Reject a negative length explicitly. Otherwise in.length - inLen
        // exceeds in.length and the offset check alone would pass.
        if (inOfs < 0 || inLen < 0 || inOfs > (in.length - inLen)) {
            throw new ArrayIndexOutOfBoundsException("inOfs :" + inOfs +
                ". inLen: " + inLen + ". in.length: " + in.length);
        }
        ensureInitialized();
        int k = nativeUpdate(pCtxt.id, sign, in, inOfs, inLen);
        if (k < 0) {
            reset(false);
        }
        return k;
    }

    /**
     * Feeds {@code inLen} bytes at native address {@code pIn} to the native
     * context. On a native error the context is untracked via reset(false).
     *
     * @return the native result (negative on error)
     */
    private int update(long pIn, int inLen) {
        ensureInitialized();
        int k = nativeUpdate(pCtxt.id, sign, pIn, inLen);
        if (k < 0) {
            reset(false);
        }
        return k;
    }

    /**
     * Runs the native final step. The caller is responsible for resetting
     * the engine afterwards.
     *
     * @return the native result (0 on success, negative on error)
     */
    private int doFinal(byte[] sigBytes, int sigOfs, int sigLen) {
        ensureInitialized();
        int k = nativeFinal(pCtxt.id, sign, sigBytes, sigOfs, sigLen);
        return k;
    }

    /**
     * Computes the key size in bytes and rejects keys too small to hold the
     * encoded digest plus padding.
     *
     * @return the modulus size in bytes (used as the signature length)
     * @throws InvalidKeyException if the key is too short for this algorithm
     */
    private int checkRSAKeyLength(BigInteger mod) throws InvalidKeyException {
        int keySize = (mod.bitLength() + 7) >> 3;
        int maxDataSize = keySize - PKCS1PADDING_LEN;
        if (maxDataSize < encodedLen) {
            throw new InvalidKeyException
                ("Key is too short for this signature algorithm. maxDataSize: " +
                 maxDataSize + ". encodedLen: " + encodedLen);
        }
        return keySize;
    }
}