package sun.security.rsa;
import java.math.BigInteger;
import java.security.*;
import java.security.spec.AlgorithmParameterSpec;
import java.security.spec.RSAKeyGenParameterSpec;
import static java.math.BigInteger.*;
import sun.security.jca.JCAUtil;
import sun.security.rsa.RSAUtil.KeyType;
import static sun.security.util.SecurityProviderConstants.DEF_RSA_KEY_SIZE;
import static sun.security.util.SecurityProviderConstants.DEF_RSASSA_PSS_KEY_SIZE;
/**
 * Key pair generator for RSA and RSASSA-PSS keys.
 *
 * <p>The concrete algorithms are the nested {@link Legacy} (plain RSA) and
 * {@link PSS} (RSASSA-PSS) classes. They differ only in the {@link KeyType}
 * and the default key size passed to the constructor.
 *
 * <p>A key pair is built from two probable primes p and q, whose bit lengths
 * add up to the requested key size. A stricter mode ({@code useNew}) is used
 * for even key sizes of at least 2048 bits with a suitable public exponent
 * (see {@link #initialize(AlgorithmParameterSpec, SecureRandom)}). In that
 * mode each prime must exceed floor(sqrt(2^(keySize-1))), which forces the
 * modulus n = p*q to have exactly {@code keySize} bits.
 */
public abstract class RSAKeyPairGenerator extends KeyPairGeneratorSpi {

    // floor(sqrt(2^(k-1))) for k = 2048, 3072, 4096, cached for getSqrt().
    private static final BigInteger SQRT_2048;
    private static final BigInteger SQRT_3072;
    private static final BigInteger SQRT_4096;

    static {
        SQRT_2048 = TWO.pow(2047).sqrt();
        SQRT_3072 = TWO.pow(3071).sqrt();
        SQRT_4096 = TWO.pow(4095).sqrt();
    }

    // Public exponent e of generated keys.
    private BigInteger publicExponent;

    // Requested modulus size in bits.
    private int keySize;

    // Kind of key produced (plain RSA or RSASSA-PSS); fixed per instance.
    private final KeyType type;

    // Key parameters as returned by RSAUtil.checkParamsAgainstType.
    private AlgorithmParameterSpec keyParams;

    // Randomness source used for prime generation.
    private SecureRandom random;

    // Whether generateKeyPair applies the stricter prime lower bound.
    private boolean useNew;

    /**
     * Creates a generator for the given key type, initialized to
     * {@code defKeySize} bits with exponent F4 and the default randomness.
     *
     * @param type the kind of key to generate
     * @param defKeySize the default modulus size in bits
     */
    RSAKeyPairGenerator(KeyType type, int defKeySize) {
        this.type = type;
        initialize(defKeySize, null);
    }

    /**
     * Initializes this generator for {@code keySize} bits, using the public
     * exponent {@link RSAKeyGenParameterSpec#F4} and no key parameters.
     *
     * @param keySize the modulus size in bits
     * @param random the randomness source, or {@code null} to use
     *        {@code JCAUtil.getSecureRandom()}
     * @throws InvalidParameterException if the resulting parameters are
     *         rejected; the originating
     *         {@link InvalidAlgorithmParameterException} is attached as cause
     */
    @Override
    public void initialize(int keySize, SecureRandom random) {
        try {
            initialize(new RSAKeyGenParameterSpec(keySize,
                    RSAKeyGenParameterSpec.F4), random);
        } catch (InvalidAlgorithmParameterException iape) {
            // Keep the original failure reachable for diagnostics. initCause
            // works on all JDK versions, unlike the (String, Throwable)
            // constructor that only newer JDKs provide.
            InvalidParameterException ipe =
                    new InvalidParameterException(iape.getMessage());
            ipe.initCause(iape);
            throw ipe;
        }
    }

    /**
     * Validates {@code params} and, if every check passes, stores the key
     * size, public exponent, key parameters and randomness source for use
     * by {@link #generateKeyPair()}.
     *
     * <p>A {@code null} public exponent in the spec defaults to F4. An
     * explicit exponent must be odd, at least {@link RSAKeyGenParameterSpec#F0},
     * and no longer than the key size in bits.
     *
     * @param params must be an {@link RSAKeyGenParameterSpec}
     * @param random the randomness source, or {@code null} to use
     *        {@code JCAUtil.getSecureRandom()}
     * @throws InvalidAlgorithmParameterException if {@code params} has the
     *         wrong type, the exponent is out of range, the key size is
     *         rejected by {@code RSAKeyFactory.checkKeyLengths} (called with
     *         bounds 512 and 64 * 1024), or the key parameters do not match
     *         this generator's key type
     */
    @Override
    public void initialize(AlgorithmParameterSpec params, SecureRandom random)
            throws InvalidAlgorithmParameterException {
        if (!(params instanceof RSAKeyGenParameterSpec)) {
            throw new InvalidAlgorithmParameterException
                ("Params must be instance of RSAKeyGenParameterSpec");
        }

        RSAKeyGenParameterSpec rsaSpec = (RSAKeyGenParameterSpec) params;
        int tmpKeySize = rsaSpec.getKeysize();
        BigInteger tmpPubExp = rsaSpec.getPublicExponent();
        AlgorithmParameterSpec tmpParams = rsaSpec.getKeyParams();

        // Strict prime bounds only apply to even key sizes >= 2048; the
        // exponent may narrow this further below.
        boolean useNew = (tmpKeySize >= 2048 && ((tmpKeySize & 1) == 0));

        if (tmpPubExp == null) {
            tmpPubExp = RSAKeyGenParameterSpec.F4;
        } else {
            if (!tmpPubExp.testBit(0)) {
                throw new InvalidAlgorithmParameterException
                    ("Public exponent must be an odd number");
            }
            // Accepted exponents: F0 <= e and bitLength(e) <= keySize.
            BigInteger minValue = RSAKeyGenParameterSpec.F0;
            int maxBitLength = tmpKeySize;
            if (tmpPubExp.compareTo(minValue) < 0) {
                throw new InvalidAlgorithmParameterException
                    ("Public exponent must be " + minValue + " or larger");
            }
            if (tmpPubExp.bitLength() > maxBitLength) {
                throw new InvalidAlgorithmParameterException
                    ("Public exponent must be no longer than " +
                     maxBitLength + " bits");
            }
            // Strict mode additionally needs F4 <= e and bitLength(e) < 256.
            // NOTE(review): bitLength() < 256 admits only e < 2^255; if the
            // intended bound is e < 2^256 this should be <= 256 -- confirm.
            useNew &= ((tmpPubExp.compareTo(RSAKeyGenParameterSpec.F4) >= 0) &&
                       (tmpPubExp.bitLength() < 256));
        }

        // Reject unreasonably small or large key sizes.
        try {
            RSAKeyFactory.checkKeyLengths(tmpKeySize, tmpPubExp, 512,
                    64 * 1024);
        } catch (InvalidKeyException e) {
            throw new InvalidAlgorithmParameterException(
                    "Invalid key sizes", e);
        }

        AlgorithmParameterSpec checkedParams;
        try {
            checkedParams = RSAUtil.checkParamsAgainstType(type, tmpParams);
        } catch (ProviderException e) {
            throw new InvalidAlgorithmParameterException(
                    "Invalid key parameters", e);
        }

        // All checks passed: update the generator state together.
        this.keyParams = checkedParams;
        this.keySize = tmpKeySize;
        this.publicExponent = tmpPubExp;
        this.random = (random == null ? JCAUtil.getSecureRandom() : random);
        this.useNew = useNew;
    }

    /**
     * Generates a new RSA key pair using the current configuration.
     *
     * <p>Candidate primes are drawn with {@link BigInteger#probablePrime}.
     * A candidate for p is accepted when p - 1 is coprime to e and, in
     * strict mode, p exceeds the square-root bound. A candidate for q must
     * also satisfy |p - q| > 2^(lp - 100), where lp is the bit length of p.
     * The whole attempt is repeated if the modulus has the wrong length or
     * if {@code createKeyPair} rejects the private exponent.
     *
     * @return the generated key pair
     * @throws ProviderException if no acceptable p is found within
     *         10 * lp candidates, or no acceptable q within 20 * lq candidates
     */
    @Override
    public KeyPair generateKeyPair() {
        BigInteger e = publicExponent;
        // Lower bound for each prime in strict mode; unused otherwise.
        BigInteger minValue = (useNew ? getSqrt(keySize) : ZERO);
        // p gets the extra bit when keySize is odd.
        int lp = (keySize + 1) >> 1;
        int lq = keySize - lp;
        // Minimum required distance between p and q. It is loop-invariant,
        // so it is computed once instead of once per candidate q.
        BigInteger minPQDiff = TWO.pow(lp - 100);

        while (true) {
            BigInteger p = null;
            for (int i = 0; i < 10 * lp; i++) {
                BigInteger tmpP = BigInteger.probablePrime(lp, random);
                if ((!useNew || tmpP.compareTo(minValue) > 0) &&
                        isRelativePrime(e, tmpP.subtract(ONE))) {
                    p = tmpP;
                    break;
                }
            }
            if (p == null) {
                throw new ProviderException("Cannot find prime P");
            }

            BigInteger q = null;
            for (int i = 0; i < 20 * lq; i++) {
                BigInteger tmpQ = BigInteger.probablePrime(lq, random);
                if ((!useNew || tmpQ.compareTo(minValue) > 0) &&
                        p.subtract(tmpQ).abs().compareTo(minPQDiff) > 0 &&
                        isRelativePrime(e, tmpQ.subtract(ONE))) {
                    q = tmpQ;
                    break;
                }
            }
            if (q == null) {
                throw new ProviderException("Cannot find prime Q");
            }

            BigInteger n = p.multiply(q);
            if (n.bitLength() != keySize) {
                // Wrong modulus length: draw new primes. In strict mode both
                // primes exceed sqrt(2^(keySize-1)), so n > 2^(keySize-1) and
                // this branch is not taken.
                continue;
            }

            KeyPair kp = createKeyPair(type, keyParams, n, e, p, q);
            if (kp != null) {
                return kp;
            }
            // createKeyPair rejected d as too small; start over.
        }
    }

    /**
     * Returns floor(sqrt(2^(keySize-1))), using a cached value for 2048,
     * 3072 and 4096 bits.
     */
    private static BigInteger getSqrt(int keySize) {
        switch (keySize) {
            case 2048:
                return SQRT_2048;
            case 3072:
                return SQRT_3072;
            case 4096:
                return SQRT_4096;
            default:
                return TWO.pow(keySize - 1).sqrt();
        }
    }

    /**
     * Returns whether {@code e} and {@code bi} are coprime.
     *
     * <p>F0 (3) and F4 (65537) are prime, so for those exponents a
     * divisibility test replaces the general gcd computation.
     */
    private static boolean isRelativePrime(BigInteger e, BigInteger bi) {
        if (e.compareTo(RSAKeyGenParameterSpec.F4) == 0 ||
                e.compareTo(RSAKeyGenParameterSpec.F0) == 0) {
            return !bi.mod(e).equals(ZERO);
        }
        return e.gcd(bi).equals(ONE);
    }

    /**
     * Derives the private key components from p, q and e and builds the
     * key pair.
     *
     * <p>The private exponent is d = e^-1 mod lcm(p - 1, q - 1). The CRT
     * values are d mod (p - 1), d mod (q - 1) and q^-1 mod p.
     *
     * @return the key pair, or {@code null} if d is not greater than
     *         2^bitLength(p), in which case the caller should retry with
     *         new primes
     */
    private static KeyPair createKeyPair(KeyType type,
            AlgorithmParameterSpec keyParams,
            BigInteger n, BigInteger e, BigInteger p, BigInteger q) {
        BigInteger p1 = p.subtract(ONE);
        BigInteger q1 = q.subtract(ONE);
        BigInteger phi = p1.multiply(q1);

        // lcm(p-1, q-1) = (p-1)(q-1) / gcd(p-1, q-1)
        BigInteger gcd = p1.gcd(q1);
        BigInteger lcm = (gcd.equals(ONE) ? phi : phi.divide(gcd));

        // e is coprime to both p-1 and q-1 (checked by the caller), so the
        // inverse exists.
        BigInteger d = e.modInverse(lcm);

        // Reject a private exponent that is too small.
        if (d.compareTo(TWO.pow(p.bitLength())) <= 0) {
            return null;
        }

        BigInteger pe = d.mod(p1);           // 1st CRT exponent
        BigInteger qe = d.mod(q1);           // 2nd CRT exponent
        BigInteger coeff = q.modInverse(p);  // CRT coefficient q^-1 mod p

        try {
            PublicKey publicKey = new RSAPublicKeyImpl(type, keyParams, n, e);
            PrivateKey privateKey = new RSAPrivateCrtKeyImpl(
                    type, keyParams, n, e, d, p, q, pe, qe, coeff);
            return new KeyPair(publicKey, privateKey);
        } catch (InvalidKeyException exc) {
            // Presumably unreachable, since initialize() already validated the
            // key size. Treated as an internal error with the cause preserved.
            throw new RuntimeException(exc);
        }
    }

    /** Generator for plain RSA keys, defaulting to {@code DEF_RSA_KEY_SIZE} bits. */
    public static final class Legacy extends RSAKeyPairGenerator {
        public Legacy() {
            super(KeyType.RSA, DEF_RSA_KEY_SIZE);
        }
    }

    /** Generator for RSASSA-PSS keys, defaulting to {@code DEF_RSASSA_PSS_KEY_SIZE} bits. */
    public static final class PSS extends RSAKeyPairGenerator {
        public PSS() {
            super(KeyType.PSS, DEF_RSASSA_PSS_KEY_SIZE);
        }
    }
}