/*
 * Copyright (c) 2018, Oracle and/or its affiliates. All rights reserved.
 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
 *
 * This code is free software; you can redistribute it and/or modify it
 * under the terms of the GNU General Public License version 2 only, as
 * published by the Free Software Foundation.  Oracle designates this
 * particular file as subject to the "Classpath" exception as provided
 * by Oracle in the LICENSE file that accompanied this code.
 *
 * This code is distributed in the hope that it will be useful, but WITHOUT
 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
 * version 2 for more details (a copy is included in the LICENSE file that
 * accompanied this code).
 *
 * You should have received a copy of the GNU General Public License version
 * 2 along with this work; if not, write to the Free Software Foundation,
 * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 *
 * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 * or visit www.oracle.com if you need additional information or have any
 * questions.
 */

package sun.security.util.math.intpoly;

import sun.security.util.math.*;

import java.math.BigInteger;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.util.Arrays;

/**
 * A large number polynomial representation using sparse limbs of signed
 * long (64-bit) values. Limb values will always fit within a long, so inputs
 * to multiplication must be less than 32 bits. All IntegerPolynomial
 * implementations allow at most one addition before multiplication. Additions
 * after that will result in an ArithmeticException.
 *
 * The following element operations are branch-free for all subclasses:
 *
 * fixed
 * mutable
 * add
 * additiveInverse
 * multiply
 * square
 * subtract
 * conditionalSwapWith
 * setValue (may branch on high-order byte parameter only)
 * setSum
 * setDifference
 * setProduct
 * setSquare
 * addModPowerTwo
 * asByteArray
 *
 * All other operations may branch in some subclasses.
 *
 */
public abstract class IntegerPolynomial implements IntegerFieldModuloP {

    // The constant 2 as a BigInteger.
    protected static final BigInteger TWO = BigInteger.valueOf(2);

    // Length of the limb arrays allocated by this class.
    protected final int numLimbs;
    // Field modulus: returned by getSize() and applied by evaluate().
    private final BigInteger modulus;
    // Shift distance between adjacent limbs; evaluate() weights limb i by
    // 2^(i * bitsPerLimb).
    protected final int bitsPerLimb;
    // Limbs of the modulus as written by setLimbsValuePositive; subtracted
    // from a value in finalReduce.
    private final long[] posModLimbs;
    // Bound that Element.isSummand() compares an element's numAdds against.
    private final int maxAdds;
/**
 * Reduce an IntegerPolynomial representation (a) and store the result
 * in a. Requires that a.length == numLimbs.
 */
// Subclass hook; within this file it is invoked by multByInt,
// postEncodeCarry and MutableElement.setReduced.
protected abstract void reduce(long[] a);
/**
 * Multiply an IntegerPolynomial representation (a) with a long (b) and
 * store the result in an IntegerPolynomial representation in a. Requires
 * that a.length == numLimbs.
 */
protected void multByInt(long[] a, long b) {
    // Scale every limb in place; the products are not reduced yet.
    int limb = 0;
    while (limb < a.length) {
        a[limb] = a[limb] * b;
        limb++;
    }
    // Hand the scaled limbs to the subclass reduction.
    reduce(a);
}
/**
 * Multiply two IntegerPolynomial representations (a and b) and store the
 * result in an IntegerPolynomial representation (r). Requires that
 * a.length == b.length == r.length == numLimbs. It is allowed for a and r
 * to be the same array.
 */
// Subclass hook; within this file it is invoked by Element.multiply and
// MutableElement.setProduct(IntegerModuloP).
protected abstract void mult(long[] a, long[] b, long[] r);
/**
 * Multiply an IntegerPolynomial representation (a) with itself and store
 * the result in an IntegerPolynomial representation (r). Requires that
 * a.length == r.length == numLimbs. It is allowed for a and r
 * to be the same array.
 */
protected abstract void square(long[] a, long[] r);

/**
 * Stores the representation parameters and precomputes the limbs of the
 * modulus.
 *
 * @param bitsPerLimb shift distance between adjacent limbs
 * @param numLimbs length of the limb arrays used by this field
 * @param maxAdds bound checked by Element.isSummand()
 * @param modulus the field modulus
 */
IntegerPolynomial(int bitsPerLimb, int numLimbs, int maxAdds,
                  BigInteger modulus) {

    this.numLimbs = numLimbs;
    this.modulus = modulus;
    this.bitsPerLimb = bitsPerLimb;
    this.maxAdds = maxAdds;

    // Must run last: it reads numLimbs, bitsPerLimb and modulus.
    posModLimbs = setPosModLimbs();
}

/**
 * Builds the limb array of the modulus using setLimbsValuePositive.
 *
 * @return a new array of numLimbs limbs holding the modulus
 */
private long[] setPosModLimbs() {
    long[] result = new long[numLimbs];
    setLimbsValuePositive(modulus, result);
    return result;
}

/**
 * @return the number of limbs in this representation
 */
protected int getNumLimbs() {
    return numLimbs;
}

/**
 * @return the maxAdds value this field was constructed with
 */
public int getMaxAdds() {
    return maxAdds;
}

/**
 * @return the field modulus
 */
@Override
public BigInteger getSize() {
    return modulus;
}

/**
 * @return a new immutable element whose limbs are all zero
 */
@Override
public ImmutableElement get0() {
    return new ImmutableElement(false);
}

/**
 * @return a new immutable element whose low limb is one
 */
@Override
public ImmutableElement get1() {
    return new ImmutableElement(true);
}

/**
 * @param v the value to represent
 * @return a new immutable element built from v
 */
@Override
public ImmutableElement getElement(BigInteger v) {
    return new ImmutableElement(v);
}

/**
 * Wraps a small integer for use with setProduct(SmallValue).
 *
 * @param value the multiplier; its magnitude must be below
 *        2^(bitsPerLimb - 1)
 * @return a Limb holding value
 * @throws IllegalArgumentException if the magnitude of value is too large
 */
@Override
public SmallValue getSmallValue(int value) {
    // Widen before shifting and before taking the absolute value:
    // Math.abs(Integer.MIN_VALUE) is still negative in int arithmetic and
    // would slip past the bound, and an int shift wraps at 32 bits.
    long maxMag = 1L << (bitsPerLimb - 1);
    if (Math.abs((long) value) >= maxMag) {
        throw new IllegalArgumentException(
            "max magnitude is " + maxMag);
    }
    return new Limb(value);
}
/**
 * This version of encode takes a ByteBuffer that is properly ordered, and
 * may extract larger values (e.g. long) from the ByteBuffer for better
 * performance. The implementation below only extracts bytes from the
 * buffer, but this method may be overridden in field-specific
 * implementations.
 */
protected void encode(ByteBuffer buf, int length, byte highByte, long[] result) { int numHighBits = 32 - Integer.numberOfLeadingZeros(highByte); int numBits = 8 * length + numHighBits; int requiredLimbs = (numBits + bitsPerLimb - 1) / bitsPerLimb; if (requiredLimbs > numLimbs) { long[] temp = new long[requiredLimbs]; encodeSmall(buf, length, highByte, temp); // encode does a full carry/reduce System.arraycopy(temp, 0, result, 0, result.length); } else { encodeSmall(buf, length, highByte, result); } } protected void encodeSmall(ByteBuffer buf, int length, byte highByte, long[] result) { int limbIndex = 0; long curLimbValue = 0; int bitPos = 0; for (int i = 0; i < length; i++) { long curV = buf.get() & 0xFF; if (bitPos + 8 >= bitsPerLimb) { int bitsThisLimb = bitsPerLimb - bitPos; curLimbValue += (curV & (0xFF >> (8 - bitsThisLimb))) << bitPos; result[limbIndex++] = curLimbValue; curLimbValue = curV >> bitsThisLimb; bitPos = 8 - bitsThisLimb; } else { curLimbValue += curV << bitPos; bitPos += 8; } } // one more for the high byte if (highByte != 0) { long curV = highByte & 0xFF; if (bitPos + 8 >= bitsPerLimb) { int bitsThisLimb = bitsPerLimb - bitPos; curLimbValue += (curV & (0xFF >> (8 - bitsThisLimb))) << bitPos; result[limbIndex++] = curLimbValue; curLimbValue = curV >> bitsThisLimb; } else { curLimbValue += curV << bitPos; } } if (limbIndex < result.length) { result[limbIndex++] = curLimbValue; } Arrays.fill(result, limbIndex, result.length, 0); postEncodeCarry(result); } protected void encode(byte[] v, int offset, int length, byte highByte, long[] result) { ByteBuffer buf = ByteBuffer.wrap(v, offset, length); buf.order(ByteOrder.LITTLE_ENDIAN); encode(buf, length, highByte, result); } // Encode does not produce compressed limbs. A simplified carry/reduce // operation can be used to compress the limbs. 
protected void postEncodeCarry(long[] v) { reduce(v); } public ImmutableElement getElement(byte[] v, int offset, int length, byte highByte) { long[] result = new long[numLimbs]; encode(v, offset, length, highByte, result); return new ImmutableElement(result, 0); } protected BigInteger evaluate(long[] limbs) { BigInteger result = BigInteger.ZERO; for (int i = limbs.length - 1; i >= 0; i--) { result = result.shiftLeft(bitsPerLimb) .add(BigInteger.valueOf(limbs[i])); } return result.mod(modulus); } protected long carryValue(long x) { // compressing carry operation // if large positive number, carry one more to make it negative // if large negative number (closer to zero), carry one fewer return (x + (1 << (bitsPerLimb - 1))) >> bitsPerLimb; } protected void carry(long[] limbs, int start, int end) { for (int i = start; i < end; i++) { long carry = carryOut(limbs, i); limbs[i + 1] += carry; } } protected void carry(long[] limbs) { carry(limbs, 0, limbs.length - 1); }
/**
 * Carry out of the specified position and return the carry value.
 */
protected long carryOut(long[] limbs, int index) { long carry = carryValue(limbs[index]); limbs[index] -= (carry << bitsPerLimb); return carry; } private void setLimbsValue(BigInteger v, long[] limbs) { // set all limbs positive, and then carry setLimbsValuePositive(v, limbs); carry(limbs); } protected void setLimbsValuePositive(BigInteger v, long[] limbs) { BigInteger mod = BigInteger.valueOf(1 << bitsPerLimb); for (int i = 0; i < limbs.length; i++) { limbs[i] = v.mod(mod).longValue(); v = v.shiftRight(bitsPerLimb); } }
/**
 * Carry out of the last limb and reduce back in. This method will be
 * called as part of the "finalReduce" operation that puts the
 * representation into a fully-reduced form. It is representation-specific,
 * because representations have different amounts of empty space in the
 * high-order limb. Requires that limbs.length == numLimbs.
 */
// Subclass hook; finalReduce invokes it at the start of each of its passes.
protected abstract void finalCarryReduceLast(long[] limbs);
/**
 * Convert reduced limbs into a number between 0 and MODULUS-1.
 * Requires that limbs.length == numLimbs. This method only works if the
 * modulus has at most three terms.
 */
protected void finalReduce(long[] limbs) { // This method works by doing several full carry/reduce operations. // Some representations have extra high bits, so the carry/reduce out // of the high position is implementation-specific. The "unsigned" // carry operation always carries some (negative) value out of a // position occupied by a negative value. So after a number of // passes, all negative values are removed. // The first pass may leave a negative value in the high position, but // this only happens if something was carried out of the previous // position. So the previous position must have a "small" value. The // next full carry is guaranteed not to carry out of that position. for (int pass = 0; pass < 2; pass++) { // unsigned carry out of last position and reduce in to // first position finalCarryReduceLast(limbs); // unsigned carry on all positions long carry = 0; for (int i = 0; i < numLimbs - 1; i++) { limbs[i] += carry; carry = limbs[i] >> bitsPerLimb; limbs[i] -= carry << bitsPerLimb; } limbs[numLimbs - 1] += carry; } // Limbs are positive and all less than 2^bitsPerLimb, and the // high-order limb may be even smaller due to the representation- // specific carry/reduce out of the high position. // The value may still be greater than the modulus. // Subtract the max limb values only if all limbs end up non-negative // This only works if there is at most one position where posModLimbs // is less than 2^bitsPerLimb - 1 (not counting the high-order limb, // if it has extra bits that are cleared by finalCarryReduceLast). int smallerNonNegative = 1; long[] smaller = new long[numLimbs]; for (int i = numLimbs - 1; i >= 0; i--) { smaller[i] = limbs[i] - posModLimbs[i]; // expression on right is 1 if smaller[i] is nonnegative, // 0 otherwise smallerNonNegative *= (int) (smaller[i] >> 63) + 1; } conditionalSwap(smallerNonNegative, limbs, smaller); }
/**
 * Decode the value in v and store it in dst. Requires that v is fully
 * reduced, i.e. all limbs in [0, 2^bitsPerLimb) and value in [0, modulus).
 */
protected void decode(long[] v, byte[] dst, int offset, int length) {

    int limbCursor = 0;
    long pending = v[limbCursor++];
    // Number of bits of the current limb that have already been emitted.
    int consumed = 0;

    for (int n = 0; n < length; n++) {
        final int at = offset + n;

        // The low eight pending bits always go out first.
        dst[at] = (byte) pending;

        if (consumed + 8 < bitsPerLimb) {
            // The whole byte came from the current limb.
            pending >>= 8;
            consumed += 8;
            continue;
        }

        // The current limb ran out inside this byte: fetch the next limb
        // (zero once v is exhausted) and top the byte up from it.
        pending = (limbCursor < v.length) ? v[limbCursor++] : 0;
        int fromOld = bitsPerLimb - consumed;
        int fromNew = 8 - fromOld;
        dst[at] = (byte) (dst[at]
            + ((pending & (0xFF >> fromOld)) << fromOld));
        pending >>= fromNew;
        consumed = fromNew;
    }
}
/**
 * Add two IntegerPolynomial representations (a and b) and store the result
 * in an IntegerPolynomial representation (dst). Requires that
 * a.length == b.length == dst.length. It is allowed for a and
 * dst to be the same array.
 */
protected void addLimbs(long[] a, long[] b, long[] dst) {
    // Limb-wise sum with no carrying; positions are visited in ascending
    // order, so dst may alias a.
    final int count = dst.length;
    int pos = 0;
    while (pos < count) {
        long sum = a[pos] + b[pos];
        dst[pos] = sum;
        pos++;
    }
}
/**
 * Branch-free conditional assignment of b to a. Requires that set is 0 or
 * 1, and that a.length == b.length. If set==0, then the values of a and b
 * will be unchanged. If set==1, then the values of b will be assigned to a.
 * The behavior is undefined if set has any value other than 0 or 1.
 */
protected static void conditionalAssign(int set, long[] a, long[] b) {
    // All ones when set == 1, all zeros when set == 0.
    final long mask = -set;
    for (int k = 0; k < a.length; k++) {
        // delta is a[k] ^ b[k] when assigning, 0 otherwise, so the XOR
        // below turns a[k] into b[k] or leaves it alone.
        final long delta = (a[k] ^ b[k]) & mask;
        a[k] ^= delta;
    }
}
/**
 * Branch-free conditional swap of a and b. Requires that swap is 0 or 1,
 * and that a.length == b.length. If swap==0, then the values of a and b
 * will be unchanged. If swap==1, then the values of a and b will be
 * swapped. The behavior is undefined if swap has any value other than
 * 0 or 1.
 */
protected static void conditionalSwap(int swap, long[] a, long[] b) {
    // All ones when swap == 1, all zeros when swap == 0.
    final long mask = -swap;
    for (int k = 0; k < a.length; k++) {
        // XOR-ing both sides with the masked difference exchanges them
        // when the mask is set and is a no-op otherwise.
        final long delta = (a[k] ^ b[k]) & mask;
        a[k] ^= delta;
        b[k] ^= delta;
    }
}
/**
 * Stores the reduced, little-endian value of limbs in result.
 */
protected void limbsToByteArray(long[] limbs, byte[] result) {
    // Reduce a private copy so the caller's limbs stay untouched.
    long[] scratch = Arrays.copyOf(limbs, limbs.length);
    finalReduce(scratch);

    // Fill the whole of result from the reduced limbs.
    decode(scratch, result, 0, result.length);
}
/**
 * Add the reduced number corresponding to limbs and other, and store
 * the low-order bytes of the sum in result. Requires that
 * limbs.length == other.length. The result array may have any length.
 */
protected void addLimbsModPowerTwo(long[] limbs, long[] other,
                                   byte[] result) {

    // Fully reduce private copies of both operands.
    long[] otherCopy = other.clone();
    long[] sum = limbs.clone();
    finalReduce(otherCopy);
    finalReduce(sum);

    addLimbs(sum, otherCopy, sum);

    // may carry out a value which can be ignored
    long pending = 0;
    for (int pos = 0; pos < numLimbs; pos++) {
        long total = sum[pos] + pending;
        pending = total >> bitsPerLimb;
        sum[pos] = total - (pending << bitsPerLimb);
    }

    decode(sum, result, 0, result.length);
}

/**
 * State and operations shared by the mutable and immutable element types:
 * a limb array plus a counter of additions used by isSummand().
 */
private abstract class Element implements IntegerModuloP {

    protected long[] limbs;
    protected int numAdds;

    /** Creates an element holding the value v. */
    public Element(BigInteger v) {
        limbs = new long[numLimbs];
        setValue(v);
    }

    /** Creates an element whose low limb is 1 when v is true, else 0. */
    public Element(boolean v) {
        this.limbs = new long[numLimbs];
        this.limbs[0] = v ? 1L : 0L;
        this.numAdds = 0;
    }

    /** Adopts the given limb array (no copy) and addition count. */
    private Element(long[] limbs, int numAdds) {
        this.limbs = limbs;
        this.numAdds = numAdds;
    }

    /** Stores v in the limbs and resets the addition count. */
    private void setValue(BigInteger v) {
        setLimbsValue(v, limbs);
        this.numAdds = 0;
    }

    /** Returns the enclosing field. */
    @Override
    public IntegerFieldModuloP getField() {
        return IntegerPolynomial.this;
    }

    /** Returns the value of this element as computed by evaluate(). */
    @Override
    public BigInteger asBigInteger() {
        return evaluate(limbs);
    }

    /** Returns a mutable element backed by a copy of the limbs. */
    @Override
    public MutableElement mutable() {
        long[] copy = limbs.clone();
        return new MutableElement(copy, numAdds);
    }

    /** True while numAdds is still below the field's maxAdds. */
    protected boolean isSummand() {
        return numAdds < maxAdds;
    }

    /**
     * Returns the limb-wise sum of this element and genB.
     *
     * @throws ArithmeticException if either operand is not a valid summand
     */
    @Override
    public ImmutableElement add(IntegerModuloP genB) {

        Element b = (Element) genB;
        if (!isSummand() || !b.isSummand()) {
            throw new ArithmeticException("Not a valid summand");
        }

        long[] total = new long[limbs.length];
        for (int pos = 0; pos < limbs.length; pos++) {
            total[pos] = limbs[pos] + b.limbs[pos];
        }

        // The result has one more addition than its deeper operand.
        int adds = Math.max(numAdds, b.numAdds) + 1;
        return new ImmutableElement(total, adds);
    }

    /** Returns a new element with every limb negated. */
    @Override
    public ImmutableElement additiveInverse() {

        long[] negated = new long[limbs.length];
        for (int pos = 0; pos < limbs.length; pos++) {
            negated[pos] = -limbs[pos];
        }

        return new ImmutableElement(negated, numAdds);
    }

    /** Returns a new numLimbs-long array holding the low limbs. */
    protected long[] cloneLow(long[] limbs) {
        long[] low = new long[numLimbs];
        copyLow(limbs, low);
        return low;
    }

    /** Copies the first out.length entries of limbs into out. */
    protected void copyLow(long[] limbs, long[] out) {
        System.arraycopy(limbs, 0, out, 0, out.length);
    }

    /** Returns the product of this element and genB via mult(). */
    @Override
    public ImmutableElement multiply(IntegerModuloP genB) {

        Element b = (Element) genB;

        long[] product = new long[limbs.length];
        mult(limbs, b.limbs, product);
        return new ImmutableElement(product, 0);
    }

    /** Returns the square of this element via the field's square(). */
    @Override
    public ImmutableElement square() {
        long[] squared = new long[limbs.length];
        IntegerPolynomial.this.square(limbs, squared);
        return new ImmutableElement(squared, 0);
    }

    /**
     * Writes the low-order bytes of (this + arg) into result, using
     * addLimbsModPowerTwo.
     *
     * @throws ArithmeticException if either operand is not a valid summand
     */
    public void addModPowerTwo(IntegerModuloP arg, byte[] result) {

        Element other = (Element) arg;
        if (!isSummand() || !other.isSummand()) {
            throw new ArithmeticException("Not a valid summand");
        }
        addLimbsModPowerTwo(limbs, other.limbs, result);
    }

    /**
     * Writes the reduced little-endian value of this element into result.
     *
     * @throws ArithmeticException if this is not a valid summand
     */
    public void asByteArray(byte[] result) {
        if (!isSummand()) {
            throw new ArithmeticException("Not a valid summand");
        }
        limbsToByteArray(limbs, result);
    }
}

/**
 * Element whose limbs are updated in place by the set* operations.
 */
protected class MutableElement extends Element
    implements MutableIntegerModuloP {

    protected MutableElement(long[] limbs, int numAdds) {
        super(limbs, numAdds);
    }

    /** Returns an immutable element backed by a copy of the limbs. */
    @Override
    public ImmutableElement fixed() {
        long[] copy = limbs.clone();
        return new ImmutableElement(copy, numAdds);
    }

    /**
     * Conditionally assigns b's limbs to this element (see
     * conditionalAssign); numAdds is taken from b regardless of set.
     */
    @Override
    public void conditionalSet(IntegerModuloP b, int set) {

        Element other = (Element) b;

        conditionalAssign(set, limbs, other.limbs);
        numAdds = other.numAdds;
    }

    /**
     * Conditionally swaps limbs with b (see conditionalSwap); the numAdds
     * values are exchanged regardless of swap.
     */
    @Override
    public void conditionalSwapWith(MutableIntegerModuloP b, int swap) {

        MutableElement other = (MutableElement) b;

        conditionalSwap(swap, limbs, other.limbs);
        int mine = numAdds;
        numAdds = other.numAdds;
        other.numAdds = mine;
    }

    /** Copies the limbs and addition count of v into this element. */
    @Override
    public MutableElement setValue(IntegerModuloP v) {
        Element other = (Element) v;

        System.arraycopy(other.limbs, 0, limbs, 0, other.limbs.length);
        numAdds = other.numAdds;
        return this;
    }

    /** Re-encodes this element from a byte range plus a high byte. */
    @Override
    public MutableElement setValue(byte[] arr, int offset,
                                   int length, byte highByte) {

        encode(arr, offset, length, highByte, limbs);
        this.numAdds = 0;

        return this;
    }

    /** Re-encodes this element from a buffer plus a high byte. */
    @Override
    public MutableElement setValue(ByteBuffer buf, int length,
                                   byte highByte) {

        encode(buf, length, highByte, limbs);
        numAdds = 0;

        return this;
    }

    /** Multiplies this element by genB in place. */
    @Override
    public MutableElement setProduct(IntegerModuloP genB) {
        Element b = (Element) genB;
        mult(limbs, b.limbs, limbs);
        numAdds = 0;
        return this;
    }

    /** Multiplies this element by a small value in place. */
    @Override
    public MutableElement setProduct(SmallValue v) {
        int factor = ((Limb) v).value;
        multByInt(limbs, factor);
        numAdds = 0;
        return this;
    }

    /**
     * Adds genB to this element in place.
     *
     * @throws ArithmeticException if either operand is not a valid summand
     */
    @Override
    public MutableElement setSum(IntegerModuloP genB) {

        Element b = (Element) genB;
        if (!isSummand() || !b.isSummand()) {
            throw new ArithmeticException("Not a valid summand");
        }

        for (int pos = 0; pos < limbs.length; pos++) {
            limbs[pos] += b.limbs[pos];
        }

        numAdds = Math.max(numAdds, b.numAdds) + 1;
        return this;
    }

    /**
     * Subtracts genB from this element in place.
     *
     * @throws ArithmeticException if either operand is not a valid summand
     */
    @Override
    public MutableElement setDifference(IntegerModuloP genB) {

        Element b = (Element) genB;
        if (!isSummand() || !b.isSummand()) {
            throw new ArithmeticException("Not a valid summand");
        }

        for (int pos = 0; pos < limbs.length; pos++) {
            limbs[pos] -= b.limbs[pos];
        }

        numAdds = Math.max(numAdds, b.numAdds) + 1;
        return this;
    }

    /** Squares this element in place. */
    @Override
    public MutableElement setSquare() {

        IntegerPolynomial.this.square(limbs, limbs);
        numAdds = 0;
        return this;
    }

    /** Negates every limb in place; numAdds is left as it was. */
    @Override
    public MutableElement setAdditiveInverse() {

        for (int pos = 0; pos < limbs.length; pos++) {
            limbs[pos] = -limbs[pos];
        }
        return this;
    }

    /** Applies reduce() to the limbs and resets the addition count. */
    @Override
    public MutableElement setReduced() {

        reduce(limbs);
        numAdds = 0;
        return this;
    }
}

/**
 * Element that exposes no in-place operations; fixed() returns itself.
 */
class ImmutableElement extends Element implements ImmutableIntegerModuloP {

    protected ImmutableElement(BigInteger v) {
        super(v);
    }

    protected ImmutableElement(boolean v) {
        super(v);
    }

    protected ImmutableElement(long[] limbs, int numAdds) {
        super(limbs, numAdds);
    }

    @Override
    public ImmutableElement fixed() {
        return this;
    }

}

/**
 * Holder for the int returned by getSmallValue and consumed by
 * MutableElement.setProduct(SmallValue).
 */
class Limb implements SmallValue {
    int value;

    Limb(int value) {
        this.value = value;
    }
}


}