package org.bouncycastle.pqc.crypto.sphincs;
/**
 * WOTS (Winternitz-style one-time signature) helper routines used by the
 * SPHINCS-256 code in this package.
 * <p>
 * A key or signature is laid out as {@link #WOTS_L} consecutive nodes of
 * {@code SPHINCS256Config.HASH_BYTES} bytes each. Every node is the result of
 * walking a hash chain (see {@link #gen_chain}) for a number of steps selected
 * by one base-16 digit derived from the message.
 */
class Wots
{
    /** Number of bits consumed per chain digit. */
    static final int WOTS_LOGW = 4;

    /** Winternitz parameter: number of distinct digit values (2^WOTS_LOGW = 16). */
    static final int WOTS_W = (1 << WOTS_LOGW);

    /** Number of digits taken directly from the 256-bit message (64 for WOTS_LOGW == 4). */
    static final int WOTS_L1 = ((256 + WOTS_LOGW - 1) / WOTS_LOGW);

    /** Total number of chains: the WOTS_L1 message digits followed by the checksum digits. */
    static final int WOTS_L = 67;

    static final int WOTS_LOG_L = 7;

    /** Size in bytes of a full set of WOTS_L chain nodes. */
    static final int WOTS_SIGBYTES = (WOTS_L * SPHINCS256Config.HASH_BYTES);

    /**
     * Zeroes a WOTS_SIGBYTES-long region of {@code outseeds} and then fills it
     * through {@code Seed.prg}, keyed by the seed found at {@code inseed[inOff..]}.
     *
     * @param outseeds destination buffer for the expanded chain start values
     * @param outOff   offset of the first byte written in {@code outseeds}
     * @param inseed   buffer holding the input seed
     * @param inOff    offset of the seed inside {@code inseed}
     */
    static void expand_seed(byte[] outseeds, int outOff, byte[] inseed, int inOff)
    {
        // WOTS_SIGBYTES is defined as WOTS_L * HASH_BYTES, i.e. one node per chain.
        clear(outseeds, outOff, WOTS_SIGBYTES);
        Seed.prg(outseeds, outOff, WOTS_SIGBYTES, inseed, inOff);
    }

    /**
     * Sets {@code length} bytes of {@code bytes}, starting at {@code offSet}, to zero.
     */
    private static void clear(byte[] bytes, int offSet, int length)
    {
        int done = 0;
        while (done != length)
        {
            bytes[offSet + done] = 0;
            done++;
        }
    }

    /**
     * Walks one hash chain.
     * <p>
     * The HASH_BYTES-long start value is first copied from {@code seed} into
     * {@code out}; {@code hs.hash_n_n_mask} is then applied in place
     * {@code min(chainlen, WOTS_W)} times (not at all when {@code chainlen} is
     * not positive). Step {@code i} uses the mask located at
     * {@code masksOff + i * HASH_BYTES}.
     *
     * @param hs       hash function provider
     * @param out      buffer receiving the chain node
     * @param outOff   offset of the node inside {@code out}
     * @param seed     buffer holding the chain start value (may alias {@code out})
     * @param seedOff  offset of the start value inside {@code seed}
     * @param masks    buffer holding the per-step masks
     * @param masksOff offset of the mask used for the first step
     * @param chainlen requested number of chain steps
     */
    static void gen_chain(HashFunctions hs, byte[] out, int outOff, byte[] seed, int seedOff, byte[] masks, int masksOff, int chainlen)
    {
        final int nodeBytes = SPHINCS256Config.HASH_BYTES;

        for (int b = 0; b < nodeBytes; b++)
        {
            out[outOff + b] = seed[seedOff + b];
        }

        // Never walk further than WOTS_W steps, whatever the caller asked for.
        final int steps = Math.min(chainlen, WOTS_W);
        int maskPos = masksOff;
        for (int step = 0; step < steps; step++)
        {
            hs.hash_n_n_mask(out, outOff, out, outOff, masks, maskPos);
            maskPos += nodeBytes;
        }
    }

    /**
     * Derives the WOTS public key: the seed at {@code sk[skOff..]} is expanded
     * into {@code pk} and every one of the WOTS_L nodes is then advanced
     * {@code WOTS_W - 1} chain steps in place.
     *
     * @param hs       hash function provider
     * @param pk       buffer receiving the WOTS_SIGBYTES-long public key
     * @param pkOff    offset of the public key inside {@code pk}
     * @param sk       buffer holding the secret seed
     * @param skOff    offset of the seed inside {@code sk}
     * @param masks    buffer holding the chain masks
     * @param masksOff offset of the first mask inside {@code masks}
     */
    void wots_pkgen(HashFunctions hs, byte[] pk, int pkOff, byte[] sk, int skOff, byte[] masks, int masksOff)
    {
        expand_seed(pk, pkOff, sk, skOff);

        int nodeOff = pkOff;
        for (int chain = 0; chain < WOTS_L; chain++)
        {
            gen_chain(hs, pk, nodeOff, pk, nodeOff, masks, masksOff, WOTS_W - 1);
            nodeOff += SPHINCS256Config.HASH_BYTES;
        }
    }

    /**
     * Produces a signature for {@code msg}: the seed at the start of {@code sk}
     * is expanded into {@code sig} and chain {@code i} is advanced by the
     * {@code i}-th message/checksum digit, using the masks from offset 0.
     *
     * @param hs     hash function provider
     * @param sig    buffer receiving the WOTS_SIGBYTES-long signature
     * @param sigOff offset of the signature inside {@code sig}
     * @param msg    message digest; its first WOTS_L1 / 2 bytes are read
     * @param sk     buffer holding the secret seed at offset 0
     * @param masks  buffer holding the chain masks, starting at offset 0
     */
    void wots_sign(HashFunctions hs, byte[] sig, int sigOff, byte[] msg, byte[] sk, byte[] masks)
    {
        final int[] digits = chainLengths(msg);

        expand_seed(sig, sigOff, sk, 0);

        int nodeOff = sigOff;
        for (int chain = 0; chain < WOTS_L; chain++)
        {
            gen_chain(hs, sig, nodeOff, sig, nodeOff, masks, 0, digits[chain]);
            nodeOff += SPHINCS256Config.HASH_BYTES;
        }
    }

    /**
     * Recomputes a public key candidate from a signature: every signature node
     * is advanced by the {@code WOTS_W - 1 - digit} steps that remain in its
     * chain, starting with the mask that belongs to step {@code digit}. The
     * result is written to {@code pk} from offset 0; this method performs no
     * comparison itself.
     *
     * @param hs     hash function provider
     * @param pk     buffer receiving the recomputed WOTS_SIGBYTES-long key at offset 0
     * @param sig    buffer holding the signature
     * @param sigOff offset of the signature inside {@code sig}
     * @param msg    message digest; its first WOTS_L1 / 2 bytes are read
     * @param masks  buffer holding the chain masks, starting at offset 0
     */
    void wots_verify(HashFunctions hs, byte[] pk, byte[] sig, int sigOff, byte[] msg, byte[] masks)
    {
        final int[] digits = chainLengths(msg);
        final int nodeBytes = SPHINCS256Config.HASH_BYTES;

        for (int chain = 0; chain < WOTS_L; chain++)
        {
            final int digit = digits[chain];
            final int rel = chain * nodeBytes;
            gen_chain(hs, pk, rel, sig, sigOff + rel, masks, digit * nodeBytes, WOTS_W - 1 - digit);
        }
    }

    /**
     * Splits {@code msg} into WOTS_L1 base-16 digits (low nibble of each byte
     * first) and appends the checksum digits, least significant digit first.
     * The checksum is the sum of {@code WOTS_W - 1 - digit} over all message
     * digits.
     *
     * @param msg message digest; its first WOTS_L1 / 2 bytes are read
     * @return a fresh array of WOTS_L digits, one per chain
     */
    private static int[] chainLengths(byte[] msg)
    {
        final int[] digits = new int[WOTS_L];
        int checksum = 0;
        int pos = 0;

        while (pos < WOTS_L1)
        {
            final int octet = msg[pos / 2] & 0xff;
            final int low = octet & 0xf;
            final int high = octet >>> 4;

            digits[pos] = low;
            digits[pos + 1] = high;
            checksum += WOTS_W - 1 - low;
            checksum += WOTS_W - 1 - high;
            pos += 2;
        }

        // Remaining slots carry the checksum, four bits at a time.
        while (pos < WOTS_L)
        {
            digits[pos] = checksum & 0xf;
            checksum >>>= 4;
            pos++;
        }

        return digits;
    }
}