package com.google.crypto.tink.subtle.prf;
import static java.lang.Math.min;
import com.google.crypto.tink.subtle.EngineFactory;
import com.google.crypto.tink.subtle.Enums.HashType;
import com.google.errorprone.annotations.Immutable;
import java.io.IOException;
import java.io.InputStream;
import java.nio.ByteBuffer;
import java.security.GeneralSecurityException;
import java.util.Arrays;
import javax.crypto.spec.SecretKeySpec;
/**
 * A {@link StreamingPrf} based on HKDF (RFC 5869).
 *
 * <p>The input keying material ({@code ikm}) and {@code salt} are fixed at construction. Each call
 * to {@link #computePrf} returns a fresh {@link InputStream} whose bytes are the HKDF output for
 * the given {@code input}, which is used as the HKDF-Expand "info" value. A stream yields at most
 * {@code 255 * L} bytes, where {@code L} is the HMAC output length of the chosen hash; after that
 * it reports end of stream ({@code -1}).
 */
@Immutable
public class HkdfStreamingPrf implements StreamingPrf {
  /** Largest HKDF-Expand block counter; the counter is encoded as a single byte. */
  private static final int MAX_BLOCK_COUNTER = 255;

  private final HashType hashType;

  // Defensive copies that are never mutated after construction.
  @SuppressWarnings("Immutable")
  private final byte[] ikm;

  @SuppressWarnings("Immutable")
  private final byte[] salt;

  /**
   * Maps a hash type to the corresponding JCA HMAC algorithm name.
   *
   * @throws GeneralSecurityException if {@code hashType} has no supported HMAC mapping
   */
  private static String getJavaxHmacName(HashType hashType) throws GeneralSecurityException {
    switch (hashType) {
      case SHA1:
        return "HmacSha1";
      case SHA256:
        return "HmacSha256";
      case SHA384:
        return "HmacSha384";
      case SHA512:
        return "HmacSha512";
      default:
        throw new GeneralSecurityException(
            "No getJavaxHmacName for given hash " + hashType + " known");
    }
  }

  /**
   * Creates an HKDF-based streaming PRF.
   *
   * @param hashType hash function underlying the HMAC
   * @param ikm input keying material; copied
   * @param salt HKDF salt; copied. A {@code null} or empty salt is replaced by a string of zero
   *     bytes as long as the HMAC output.
   */
  public HkdfStreamingPrf(final HashType hashType, final byte[] ikm, final byte[] salt) {
    this.hashType = hashType;
    this.ikm = Arrays.copyOf(ikm, ikm.length);
    // A null salt previously caused an NPE here; treat it like an empty (default) salt instead.
    this.salt = (salt == null) ? new byte[0] : Arrays.copyOf(salt, salt.length);
  }

  /** Lazily computes HKDF output blocks for one {@code input} as they are read. */
  private class HkdfInputStream extends InputStream {
    private final byte[] input;
    private javax.crypto.Mac mac;
    // HKDF-Extract output, used as the HMAC key for every expand block.
    private byte[] prk;
    // Current block T(ctr); its mark is at position 0 so the whole block can be re-read.
    private ByteBuffer buffer;
    // -1 before initialize(); afterwards the index of the block currently held in buffer.
    private int ctr;

    public HkdfInputStream(final byte[] input) {
      ctr = -1;
      this.input = Arrays.copyOf(input, input.length);
    }

    /** Runs HKDF-Extract and sets up the empty block T(0). */
    private void initialize() throws GeneralSecurityException, IOException {
      String macName;
      try {
        macName = getJavaxHmacName(hashType);
        mac = EngineFactory.MAC.getInstance(macName);
      } catch (GeneralSecurityException e) {
        throw new IOException("Creating HMac failed", e);
      }
      // PRK = HMAC(salt, IKM). An empty key is not allowed by SecretKeySpec, so a missing salt
      // is replaced by zero bytes as long as the MAC output.
      byte[] extractKey = (salt.length == 0) ? new byte[mac.getMacLength()] : salt;
      mac.init(new SecretKeySpec(extractKey, macName));
      mac.update(ikm);
      prk = mac.doFinal();
      // T(0) is the empty string.
      buffer = ByteBuffer.allocate(0);
      buffer.mark();
      ctr = 0;
    }

    /** Computes the next block T(ctr + 1) = HMAC(PRK, T(ctr) || input || (ctr + 1)). */
    private void updateBuffer() throws GeneralSecurityException {
      mac.init(new SecretKeySpec(prk, getJavaxHmacName(hashType)));
      // Rewind to the mark so the complete previous block is fed to the MAC.
      buffer.reset();
      mac.update(buffer);
      mac.update(input);
      ctr = ctr + 1;
      mac.update((byte) ctr);
      buffer = ByteBuffer.wrap(mac.doFinal());
      buffer.mark();
    }

    @Override
    public int read() throws IOException {
      byte[] oneByte = new byte[1];
      int ret = read(oneByte, 0, 1);
      if (ret == 1) {
        return oneByte[0] & 0xff;
      } else if (ret == -1) {
        return ret;
      } else {
        throw new IOException("Reading failed");
      }
    }

    @Override
    public int read(byte[] dst) throws IOException {
      return read(dst, 0, dst.length);
    }

    /**
     * Reads up to {@code len} bytes of HKDF output into {@code b} starting at {@code off}.
     *
     * @return the number of bytes read, or {@code -1} if the maximum output length has already
     *     been produced and {@code len > 0}
     * @throws IndexOutOfBoundsException if {@code off}/{@code len} do not describe a range of
     *     {@code b}
     * @throws IOException if the HMAC cannot be created or computed, or a previous read failed
     */
    @Override
    public int read(byte[] b, int off, int len) throws IOException {
      if (off < 0 || len < 0 || len > b.length - off) {
        throw new IndexOutOfBoundsException(
            "off=" + off + ", len=" + len + ", b.length=" + b.length);
      }
      // A failure in updateBuffer() clears mac; continuing would throw a NullPointerException.
      if (ctr != -1 && mac == null) {
        throw new IOException("HkdfInputStream failed previously");
      }
      int totalRead = 0;
      try {
        if (ctr == -1) {
          initialize();
        }
        while (totalRead < len) {
          if (!buffer.hasRemaining()) {
            if (ctr == MAX_BLOCK_COUNTER) {
              // Output exhausted. Per the InputStream contract, end of stream is -1, not 0;
              // returning 0 would make read() throw and readNBytes/readAllBytes spin forever.
              return totalRead == 0 ? -1 : totalRead;
            }
            updateBuffer();
          }
          int toRead = min(len - totalRead, buffer.remaining());
          buffer.get(b, off, toRead);
          off += toRead;
          totalRead += toRead;
        }
      } catch (GeneralSecurityException e) {
        mac = null;
        throw new IOException("HkdfInputStream failed", e);
      }
      return totalRead;
    }
  }

  /**
   * Returns a new stream of HKDF output bytes for {@code input}.
   *
   * @param input the HKDF "info" value; copied
   */
  @Override
  public InputStream computePrf(final byte[] input) {
    return new HkdfInputStream(input);
  }
}