package org.bouncycastle.pqc.crypto.xmss;

import java.util.ArrayList;
import java.util.List;

import org.bouncycastle.util.Arrays;

WOTS+.
/** * WOTS+. */
final class WOTSPlus {
WOTS+ parameters.
/** * WOTS+ parameters. */
private final WOTSPlusParameters params;
Randomization functions.
/** * Randomization functions. */
private final KeyedHashFunctions khf;
WOTS+ secret key seed.
/** * WOTS+ secret key seed. */
private byte[] secretKeySeed;
WOTS+ public seed.
/** * WOTS+ public seed. */
private byte[] publicSeed;
Constructs a new WOTS+ one-time signature system based on the given WOTS+ parameters.
Params:
  • params – Parameters for WOTSPlus object.
/** * Constructs a new WOTS+ one-time signature system based on the given WOTS+ * parameters. * * @param params Parameters for WOTSPlus object. */
protected WOTSPlus(WOTSPlusParameters params) { super(); if (params == null) { throw new NullPointerException("params == null"); } this.params = params; int n = params.getDigestSize(); khf = new KeyedHashFunctions(params.getDigest(), n); secretKeySeed = new byte[n]; publicSeed = new byte[n]; }
Import keys to WOTS+ instance.
Params:
  • secretKeySeed – Secret key seed.
  • publicSeed – Public seed.
/** * Import keys to WOTS+ instance. * * @param secretKeySeed Secret key seed. * @param publicSeed Public seed. */
void importKeys(byte[] secretKeySeed, byte[] publicSeed) { if (secretKeySeed == null) { throw new NullPointerException("secretKeySeed == null"); } if (secretKeySeed.length != params.getDigestSize()) { throw new IllegalArgumentException("size of secretKeySeed needs to be equal to size of digest"); } if (publicSeed == null) { throw new NullPointerException("publicSeed == null"); } if (publicSeed.length != params.getDigestSize()) { throw new IllegalArgumentException("size of publicSeed needs to be equal to size of digest"); } this.secretKeySeed = secretKeySeed; this.publicSeed = publicSeed; }
Creates a signature for the n-byte messageDigest.
Params:
  • messageDigest – Digest to sign.
  • otsHashAddress – OTS hash address for randomization.
Returns:WOTS+ signature.
/** * Creates a signature for the n-byte messageDigest. * * @param messageDigest Digest to sign. * @param otsHashAddress OTS hash address for randomization. * @return WOTS+ signature. */
protected WOTSPlusSignature sign(byte[] messageDigest, OTSHashAddress otsHashAddress) { if (messageDigest == null) { throw new NullPointerException("messageDigest == null"); } if (messageDigest.length != params.getDigestSize()) { throw new IllegalArgumentException("size of messageDigest needs to be equal to size of digest"); } if (otsHashAddress == null) { throw new NullPointerException("otsHashAddress == null"); } List<Integer> baseWMessage = convertToBaseW(messageDigest, params.getWinternitzParameter(), params.getLen1()); /* create checksum */ int checksum = 0; for (int i = 0; i < params.getLen1(); i++) { checksum += params.getWinternitzParameter() - 1 - baseWMessage.get(i); } checksum <<= (8 - ((params.getLen2() * XMSSUtil.log2(params.getWinternitzParameter())) % 8)); int len2Bytes = (int)Math .ceil((double)(params.getLen2() * XMSSUtil.log2(params.getWinternitzParameter())) / 8); List<Integer> baseWChecksum = convertToBaseW(XMSSUtil.toBytesBigEndian(checksum, len2Bytes), params.getWinternitzParameter(), params.getLen2()); /* msg || checksum */ baseWMessage.addAll(baseWChecksum); /* create signature */ byte[][] signature = new byte[params.getLen()][]; for (int i = 0; i < params.getLen(); i++) { otsHashAddress = (OTSHashAddress)new OTSHashAddress.Builder() .withLayerAddress(otsHashAddress.getLayerAddress()).withTreeAddress(otsHashAddress.getTreeAddress()) .withOTSAddress(otsHashAddress.getOTSAddress()).withChainAddress(i) .withHashAddress(otsHashAddress.getHashAddress()).withKeyAndMask(otsHashAddress.getKeyAndMask()) .build(); signature[i] = chain(expandSecretKeySeed(i), 0, baseWMessage.get(i), otsHashAddress); } return new WOTSPlusSignature(params, signature); }
Verifies signature on message.
Params:
  • messageDigest – The digest that was signed.
  • signature – Signature on digest.
  • otsHashAddress – OTS hash address for randomization.
Returns:true if signature was correct false else.
/** * Verifies signature on message. * * @param messageDigest The digest that was signed. * @param signature Signature on digest. * @param otsHashAddress OTS hash address for randomization. * @return true if signature was correct false else. */
protected boolean verifySignature(byte[] messageDigest, WOTSPlusSignature signature, OTSHashAddress otsHashAddress) { if (messageDigest == null) { throw new NullPointerException("messageDigest == null"); } if (messageDigest.length != params.getDigestSize()) { throw new IllegalArgumentException("size of messageDigest needs to be equal to size of digest"); } if (signature == null) { throw new NullPointerException("signature == null"); } if (otsHashAddress == null) { throw new NullPointerException("otsHashAddress == null"); } byte[][] tmpPublicKey = getPublicKeyFromSignature(messageDigest, signature, otsHashAddress).toByteArray(); /* compare values */ return XMSSUtil.areEqual(tmpPublicKey, getPublicKey(otsHashAddress).toByteArray()) ? true : false; }
Calculates a public key based on digest and signature.
Params:
  • messageDigest – The digest that was signed.
  • signature – Signarure on digest.
  • otsHashAddress – OTS hash address for randomization.
Returns:WOTS+ public key derived from digest and signature.
/** * Calculates a public key based on digest and signature. * * @param messageDigest The digest that was signed. * @param signature Signarure on digest. * @param otsHashAddress OTS hash address for randomization. * @return WOTS+ public key derived from digest and signature. */
protected WOTSPlusPublicKeyParameters getPublicKeyFromSignature(byte[] messageDigest, WOTSPlusSignature signature, OTSHashAddress otsHashAddress) { if (messageDigest == null) { throw new NullPointerException("messageDigest == null"); } if (messageDigest.length != params.getDigestSize()) { throw new IllegalArgumentException("size of messageDigest needs to be equal to size of digest"); } if (signature == null) { throw new NullPointerException("signature == null"); } if (otsHashAddress == null) { throw new NullPointerException("otsHashAddress == null"); } List<Integer> baseWMessage = convertToBaseW(messageDigest, params.getWinternitzParameter(), params.getLen1()); /* create checksum */ int checksum = 0; for (int i = 0; i < params.getLen1(); i++) { checksum += params.getWinternitzParameter() - 1 - baseWMessage.get(i); } checksum <<= (8 - ((params.getLen2() * XMSSUtil.log2(params.getWinternitzParameter())) % 8)); int len2Bytes = (int)Math .ceil((double)(params.getLen2() * XMSSUtil.log2(params.getWinternitzParameter())) / 8); List<Integer> baseWChecksum = convertToBaseW(XMSSUtil.toBytesBigEndian(checksum, len2Bytes), params.getWinternitzParameter(), params.getLen2()); /* msg || checksum */ baseWMessage.addAll(baseWChecksum); byte[][] publicKey = new byte[params.getLen()][]; for (int i = 0; i < params.getLen(); i++) { otsHashAddress = (OTSHashAddress)new OTSHashAddress.Builder() .withLayerAddress(otsHashAddress.getLayerAddress()).withTreeAddress(otsHashAddress.getTreeAddress()) .withOTSAddress(otsHashAddress.getOTSAddress()).withChainAddress(i) .withHashAddress(otsHashAddress.getHashAddress()).withKeyAndMask(otsHashAddress.getKeyAndMask()) .build(); publicKey[i] = chain(signature.toByteArray()[i], baseWMessage.get(i), params.getWinternitzParameter() - 1 - baseWMessage.get(i), otsHashAddress); } return new WOTSPlusPublicKeyParameters(params, publicKey); }
Computes an iteration of F on an n-byte input using outputs of PRF.
Params:
  • startHash – Starting point.
  • startIndex – Start index.
  • steps – Steps to take.
  • otsHashAddress – OTS hash address for randomization.
Returns:Value obtained by iterating F for steps times on input startHash, using the outputs of PRF.
/** * Computes an iteration of F on an n-byte input using outputs of PRF. * * @param startHash Starting point. * @param startIndex Start index. * @param steps Steps to take. * @param otsHashAddress OTS hash address for randomization. * @return Value obtained by iterating F for steps times on input startHash, * using the outputs of PRF. */
private byte[] chain(byte[] startHash, int startIndex, int steps, OTSHashAddress otsHashAddress) { int n = params.getDigestSize(); if (startHash == null) { throw new NullPointerException("startHash == null"); } if (startHash.length != n) { throw new IllegalArgumentException("startHash needs to be " + n + "bytes"); } if (otsHashAddress == null) { throw new NullPointerException("otsHashAddress == null"); } if (otsHashAddress.toByteArray() == null) { throw new NullPointerException("otsHashAddress byte array == null"); } if ((startIndex + steps) > params.getWinternitzParameter() - 1) { throw new IllegalArgumentException("max chain length must not be greater than w"); } if (steps == 0) { return startHash; } byte[] tmp = chain(startHash, startIndex, steps - 1, otsHashAddress); otsHashAddress = (OTSHashAddress)new OTSHashAddress.Builder() .withLayerAddress(otsHashAddress.getLayerAddress()).withTreeAddress(otsHashAddress.getTreeAddress()) .withOTSAddress(otsHashAddress.getOTSAddress()).withChainAddress(otsHashAddress.getChainAddress()) .withHashAddress(startIndex + steps - 1).withKeyAndMask(0).build(); byte[] key = khf.PRF(publicSeed, otsHashAddress.toByteArray()); otsHashAddress = (OTSHashAddress)new OTSHashAddress.Builder() .withLayerAddress(otsHashAddress.getLayerAddress()).withTreeAddress(otsHashAddress.getTreeAddress()) .withOTSAddress(otsHashAddress.getOTSAddress()).withChainAddress(otsHashAddress.getChainAddress()) .withHashAddress(otsHashAddress.getHashAddress()).withKeyAndMask(1).build(); byte[] bitmask = khf.PRF(publicSeed, otsHashAddress.toByteArray()); byte[] tmpMasked = new byte[n]; for (int i = 0; i < n; i++) { tmpMasked[i] = (byte)(tmp[i] ^ bitmask[i]); } tmp = khf.F(key, tmpMasked); return tmp; }
Obtain base w values from Input.
Params:
  • messageDigest – Input data.
  • w – Base.
  • outLength – Length of output.
Returns:outLength-length list of base w integers.
/** * Obtain base w values from Input. * * @param messageDigest Input data. * @param w Base. * @param outLength Length of output. * @return outLength-length list of base w integers. */
private List<Integer> convertToBaseW(byte[] messageDigest, int w, int outLength) { if (messageDigest == null) { throw new NullPointerException("msg == null"); } if (w != 4 && w != 16) { throw new IllegalArgumentException("w needs to be 4 or 16"); } int logW = XMSSUtil.log2(w); if (outLength > ((8 * messageDigest.length) / logW)) { throw new IllegalArgumentException("outLength too big"); } ArrayList<Integer> res = new ArrayList<Integer>(); for (int i = 0; i < messageDigest.length; i++) { for (int j = 8 - logW; j >= 0; j -= logW) { res.add((messageDigest[i] >> j) & (w - 1)); if (res.size() == outLength) { return res; } } } return res; }
Derive WOTS+ secret key for specific index as in XMSS ref impl Andreas Huelsing.
Params:
  • otsHashAddress –
Returns:WOTS+ secret key at index.
/** * Derive WOTS+ secret key for specific index as in XMSS ref impl Andreas * Huelsing. * * @param otsHashAddress * @return WOTS+ secret key at index. */
protected byte[] getWOTSPlusSecretKey(byte[] secretKeySeed, OTSHashAddress otsHashAddress) { otsHashAddress = (OTSHashAddress)new OTSHashAddress.Builder() .withLayerAddress(otsHashAddress.getLayerAddress()).withTreeAddress(otsHashAddress.getTreeAddress()) .withOTSAddress(otsHashAddress.getOTSAddress()).build(); return khf.PRF(secretKeySeed, otsHashAddress.toByteArray()); }
Derive private key at index from secret key seed.
Params:
  • index – Index.
Returns:Private key at index.
/** * Derive private key at index from secret key seed. * * @param index Index. * @return Private key at index. */
private byte[] expandSecretKeySeed(int index) { if (index < 0 || index >= params.getLen()) { throw new IllegalArgumentException("index out of bounds"); } return khf.PRF(secretKeySeed, XMSSUtil.toBytesBigEndian(index, 32)); }
Getter parameters.
Returns:params.
/** * Getter parameters. * * @return params. */
protected WOTSPlusParameters getParams() { return params; }
Getter keyed hash functions.
Returns:keyed hash functions.
/** * Getter keyed hash functions. * * @return keyed hash functions. */
protected KeyedHashFunctions getKhf() { return khf; }
Getter secret key seed.
Returns:secret key seed.
/** * Getter secret key seed. * * @return secret key seed. */
protected byte[] getSecretKeySeed() { return Arrays.clone(secretKeySeed); }
Getter public seed.
Returns:public seed.
/** * Getter public seed. * * @return public seed. */
protected byte[] getPublicSeed() { return Arrays.clone(publicSeed); }
Getter private key.
Returns:WOTS+ private key.
/** * Getter private key. * * @return WOTS+ private key. */
protected WOTSPlusPrivateKeyParameters getPrivateKey() { byte[][] privateKey = new byte[params.getLen()][]; for (int i = 0; i < privateKey.length; i++) { privateKey[i] = expandSecretKeySeed(i); } return new WOTSPlusPrivateKeyParameters(params, privateKey); }
Calculates a new public key based on the state of secretKeySeed, publicSeed and otsHashAddress.
Params:
  • otsHashAddress – OTS hash address for randomization.
Returns:WOTS+ public key.
/** * Calculates a new public key based on the state of secretKeySeed, * publicSeed and otsHashAddress. * * @param otsHashAddress OTS hash address for randomization. * @return WOTS+ public key. */
protected WOTSPlusPublicKeyParameters getPublicKey(OTSHashAddress otsHashAddress) { if (otsHashAddress == null) { throw new NullPointerException("otsHashAddress == null"); } byte[][] publicKey = new byte[params.getLen()][]; /* derive public key from secretKeySeed */ for (int i = 0; i < params.getLen(); i++) { otsHashAddress = (OTSHashAddress)new OTSHashAddress.Builder() .withLayerAddress(otsHashAddress.getLayerAddress()).withTreeAddress(otsHashAddress.getTreeAddress()) .withOTSAddress(otsHashAddress.getOTSAddress()).withChainAddress(i) .withHashAddress(otsHashAddress.getHashAddress()).withKeyAndMask(otsHashAddress.getKeyAndMask()) .build(); publicKey[i] = chain(expandSecretKeySeed(i), 0, params.getWinternitzParameter() - 1, otsHashAddress); } return new WOTSPlusPublicKeyParameters(params, publicKey); } }