package org.apache.commons.crypto.cipher;
import java.io.ByteArrayOutputStream;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.security.InvalidAlgorithmParameterException;
import java.security.spec.AlgorithmParameterSpec;
import javax.crypto.AEADBadTagException;
import javax.crypto.BadPaddingException;
import javax.crypto.IllegalBlockSizeException;
import javax.crypto.ShortBufferException;
import javax.crypto.spec.GCMParameterSpec;
/**
 * Galois/Counter Mode (GCM) implementation of {@link OpenSslFeedbackCipher} that drives the native
 * OpenSSL cipher context.
 *
 * <p>Behavior follows the JCE GCM contract:
 * <ul>
 *   <li>Additional authenticated data (AAD) passed to {@link #updateAAD(byte[])} is buffered and
 *       handed to OpenSSL before the first data byte. Once an {@code update}/{@code doFinal} call
 *       has started processing data, further AAD is rejected until the next {@code init}.</li>
 *   <li>On encryption the authentication tag is written directly after the ciphertext.</li>
 *   <li>On decryption all input is buffered until {@code doFinal}, because the trailing
 *       {@code tagLen} bytes are the tag and must be split off before verification. The
 *       {@code update} methods therefore return 0 in decrypt mode.</li>
 * </ul>
 */
class OpenSslGaloisCounterMode extends OpenSslFeedbackCipher {

    /** Buffered AAD; {@code null} once AAD has been handed to OpenSSL or after {@link #clean()}. */
    private ByteArrayOutputStream aadBuffer = new ByteArrayOutputStream();

    /** Tag length in bits taken from the {@link GCMParameterSpec}; negative until {@code init}. */
    private int tagBitLen = -1;

    /** Tag length in bytes used when no tag length has been configured. */
    static final int DEFAULT_TAG_LEN = 16;

    /** Ciphertext accumulated across {@code update} calls in decrypt mode; unused when encrypting. */
    private ByteArrayOutputStream inBuffer = null;

    public OpenSslGaloisCounterMode(final long context, final int algorithmMode, final int padding) {
        super(context, algorithmMode, padding);
    }

    /**
     * Initializes the native context for a new GCM operation.
     *
     * @param mode   {@code OpenSsl.ENCRYPT_MODE} or {@code OpenSsl.DECRYPT_MODE}
     * @param key    the raw key bytes
     * @param params must be a {@link GCMParameterSpec} supplying the IV and tag length
     * @throws InvalidAlgorithmParameterException if {@code params} is not a {@link GCMParameterSpec}
     */
    @Override
    public void init(final int mode, final byte[] key, final AlgorithmParameterSpec params)
            throws InvalidAlgorithmParameterException {
        // Validate before touching any state so a rejected call leaves the cipher unchanged.
        if (!(params instanceof GCMParameterSpec)) {
            throw new InvalidAlgorithmParameterException("Illegal parameters");
        }
        final GCMParameterSpec gcmParam = (GCMParameterSpec) params;
        final byte[] iv = gcmParam.getIV();

        if (aadBuffer == null) {
            aadBuffer = new ByteArrayOutputStream();
        } else {
            aadBuffer.reset();
        }
        this.cipherMode = mode;
        this.tagBitLen = gcmParam.getTLen();
        // Decryption must hold back the trailing tag, so ciphertext is accumulated until doFinal.
        // Encryption never uses the buffer; drop any stale one from a previous decrypt operation.
        inBuffer = this.cipherMode == OpenSsl.DECRYPT_MODE ? new ByteArrayOutputStream() : null;

        context = OpenSslNative.init(context, mode, algorithmMode, padding, key, iv);
    }

    /**
     * Processes (encrypt) or buffers (decrypt) the remaining bytes of {@code input}.
     *
     * @return the number of bytes written to {@code output}; always 0 in decrypt mode
     * @throws ShortBufferException if the native layer reports insufficient output space
     */
    @Override
    public int update(final ByteBuffer input, final ByteBuffer output) throws ShortBufferException {
        checkState();
        processAAD();

        if (this.cipherMode == OpenSsl.DECRYPT_MODE) {
            // The last tagLen bytes are the tag, which is only known at doFinal: hold everything back.
            final byte[] inputBuf = new byte[input.remaining()];
            input.get(inputBuf);
            inBuffer.write(inputBuf, 0, inputBuf.length);
            return 0;
        }

        final int len = OpenSslNative.update(context, input, input.position(),
                input.remaining(), output, output.position(),
                output.remaining());
        input.position(input.limit());
        output.position(output.position() + len);
        return len;
    }

    /**
     * Processes (encrypt) or buffers (decrypt) {@code inputLen} bytes of {@code input}.
     *
     * @return the number of bytes written to {@code output}; always 0 in decrypt mode
     * @throws ShortBufferException if the native layer reports insufficient output space
     */
    @Override
    public int update(final byte[] input, final int inputOffset, final int inputLen, final byte[] output, final int outputOffset)
            throws ShortBufferException {
        checkState();
        processAAD();

        if (this.cipherMode == OpenSsl.DECRYPT_MODE) {
            // Hold back all ciphertext until doFinal so the trailing tag can be split off.
            inBuffer.write(input, inputOffset, inputLen);
            return 0;
        }

        return OpenSslNative.updateByteArray(context, input, inputOffset,
                inputLen, output, outputOffset, output.length - outputOffset);
    }

    /**
     * Finishes the GCM operation.
     *
     * <p>When encrypting, the ciphertext is followed by the tag, written at
     * {@code outputOffset + ciphertextLength}. When decrypting, the last {@code tagLen} bytes of the
     * complete input (buffered data plus this chunk) are used as the expected tag.
     *
     * @return the number of bytes written to {@code output}, including the tag when encrypting
     * @throws AEADBadTagException  if the decrypt input is shorter than the tag
     * @throws ShortBufferException if {@code output} cannot hold the ciphertext and tag when encrypting
     */
    @Override
    public int doFinal(final byte[] input, final int inputOffset, final int inputLen, final byte[] output, final int outputOffset)
            throws ShortBufferException, IllegalBlockSizeException, BadPaddingException {
        checkState();
        processAAD();

        final int tagLen = getTagLen();
        int len;
        if (this.cipherMode == OpenSsl.DECRYPT_MODE) {
            // The logical ciphertext is whatever update() buffered, followed by this final chunk.
            byte[] inputFinal = input;
            int inputOffsetFinal = inputOffset;
            int inputLenFinal = inputLen;
            if (inBuffer != null && inBuffer.size() > 0) {
                inBuffer.write(input, inputOffset, inputLen);
                inputFinal = inBuffer.toByteArray();
                inputOffsetFinal = 0;
                inputLenFinal = inputFinal.length;
                inBuffer.reset();
            }

            // Compare the logical length, not the backing array length, or inputDataLen could go negative.
            if (inputLenFinal < tagLen) {
                throw new AEADBadTagException("Input too short - need tag");
            }

            final int inputDataLen = inputLenFinal - tagLen;
            len = OpenSslNative.updateByteArray(context, inputFinal, inputOffsetFinal,
                    inputDataLen, output, outputOffset, output.length - outputOffset);

            // The tag is the last tagLen bytes of the logical input. It is taken from inputFinal,
            // which also covers a tag straddling buffered data and this chunk.
            final ByteBuffer tag = ByteBuffer.allocate(tagLen);
            tag.put(inputFinal, inputOffsetFinal + inputDataLen, tagLen);
            tag.flip();
            evpCipherCtxCtrl(context, OpenSslEvpCtrlValues.AEAD_SET_TAG.getValue(), tagLen, tag);
        } else {
            // GCM is a stream mode: ciphertext is as long as the plaintext, so the space needed
            // (ciphertext + tag) is known up front. Reject before consuming any cipher state.
            final int available = output.length - outputOffset;
            if (this.cipherMode == OpenSsl.ENCRYPT_MODE && available < (long) inputLen + tagLen) {
                throw new ShortBufferException("Output buffer too small: need "
                        + ((long) inputLen + tagLen) + " bytes but only " + available + " available");
            }
            len = OpenSslNative.updateByteArray(context, input, inputOffset,
                    inputLen, output, outputOffset, available);
        }

        len += OpenSslNative.doFinalByteArray(context, output, outputOffset + len,
                output.length - outputOffset - len);

        if (this.cipherMode == OpenSsl.ENCRYPT_MODE) {
            if (output.length - outputOffset - len < tagLen) {
                throw new ShortBufferException("Output buffer has no room for the " + tagLen + "-byte tag");
            }
            final ByteBuffer tag = ByteBuffer.allocate(tagLen);
            evpCipherCtxCtrl(context, OpenSslEvpCtrlValues.AEAD_GET_TAG.getValue(), tagLen, tag);
            // Append the tag immediately after the ciphertext, not at the end of the array.
            tag.get(output, outputOffset + len, tagLen);
            len += tagLen;
        }

        return len;
    }

    /**
     * Finishes the GCM operation using {@link ByteBuffer}s.
     *
     * <p>When encrypting, the tag is put into {@code output} after the ciphertext. When decrypting,
     * the last {@code tagLen} bytes of the complete input (buffered data plus the remaining bytes of
     * {@code input}) are used as the expected tag. On return, {@code input} has been fully consumed
     * and {@code output}'s position has advanced past everything written.
     *
     * @return the number of bytes written to {@code output}, including the tag when encrypting
     * @throws AEADBadTagException  if the decrypt input is shorter than the tag
     * @throws ShortBufferException if {@code output} cannot hold the ciphertext and tag when encrypting
     */
    @Override
    public int doFinal(final ByteBuffer input, final ByteBuffer output)
            throws ShortBufferException, IllegalBlockSizeException, BadPaddingException {
        checkState();
        processAAD();

        final int tagLen = getTagLen();
        int totalLen = 0;
        int len;
        if (this.cipherMode == OpenSsl.DECRYPT_MODE) {
            final ByteBuffer tag = ByteBuffer.allocate(tagLen);
            if (inBuffer != null && inBuffer.size() > 0) {
                // Combine buffered ciphertext with this final chunk, then split off the trailing tag.
                final byte[] inputBytes = new byte[input.remaining()];
                input.get(inputBytes);
                inBuffer.write(inputBytes, 0, inputBytes.length);
                final byte[] inputFinal = inBuffer.toByteArray();
                inBuffer.reset();

                if (inputFinal.length < tagLen) {
                    throw new AEADBadTagException("Input too short - need tag");
                }

                final int inputDataLen = inputFinal.length - tagLen;
                len = OpenSslNative.updateByteArrayByteBuffer(context, inputFinal, 0,
                        inputDataLen, output, output.position(), output.remaining());
                tag.put(inputFinal, inputDataLen, tagLen);
            } else {
                if (input.remaining() < tagLen) {
                    throw new AEADBadTagException("Input too short - need tag");
                }

                final int inputDataLen = input.remaining() - tagLen;
                len = OpenSslNative.update(context, input, input.position(),
                        inputDataLen, output, output.position(),
                        output.remaining());
                // Skip exactly the ciphertext consumed (not the plaintext produced) so that only
                // the tag remains in input.
                input.position(input.position() + inputDataLen);
                tag.put(input);
            }
            tag.flip();
            evpCipherCtxCtrl(context, OpenSslEvpCtrlValues.AEAD_SET_TAG.getValue(), tagLen, tag);
        } else {
            // GCM ciphertext is as long as the plaintext; make sure ciphertext + tag will fit
            // before consuming any cipher state.
            if (this.cipherMode == OpenSsl.ENCRYPT_MODE
                    && output.remaining() < (long) input.remaining() + tagLen) {
                throw new ShortBufferException("Output buffer too small: need "
                        + ((long) input.remaining() + tagLen) + " bytes but only "
                        + output.remaining() + " available");
            }
            len = OpenSslNative.update(context, input, input.position(),
                    input.remaining(), output, output.position(),
                    output.remaining());
            input.position(input.limit());
        }

        totalLen += len;
        output.position(output.position() + len);

        len = OpenSslNative.doFinal(context, output, output.position(),
                output.remaining());
        output.position(output.position() + len);
        totalLen += len;

        if (this.cipherMode == OpenSsl.ENCRYPT_MODE) {
            if (output.remaining() < tagLen) {
                throw new ShortBufferException("Output buffer has no room for the " + tagLen + "-byte tag");
            }
            final ByteBuffer tag = ByteBuffer.allocate(tagLen);
            evpCipherCtxCtrl(context, OpenSslEvpCtrlValues.AEAD_GET_TAG.getValue(), tagLen, tag);
            output.put(tag);
            totalLen += tagLen;
        }

        return totalLen;
    }

    /** Releases native resources via the superclass and drops any buffered AAD. */
    @Override
    public void clean() {
        super.clean();
        aadBuffer = null;
    }

    /**
     * Appends additional authenticated data for the current operation.
     *
     * @param aad the AAD bytes to append
     * @throws IllegalStateException if data processing has already started for this operation
     *                               (or {@link #clean()} was called); call {@code init} again first
     */
    @Override
    public void updateAAD(final byte[] aad) {
        if (aadBuffer != null) {
            aadBuffer.write(aad, 0, aad.length);
        } else {
            throw new IllegalStateException
                    ("Update has been called; no more AAD data");
        }
    }

    /**
     * Hands buffered AAD (if any) to OpenSSL and closes the AAD phase.
     *
     * <p>The buffer is discarded even when it is empty. GCM does not accept AAD once data
     * processing has begun, so any later {@link #updateAAD(byte[])} must fail rather than being
     * silently applied after the ciphertext.
     */
    private void processAAD() {
        if (aadBuffer == null) {
            return;
        }
        if (aadBuffer.size() > 0) {
            final byte[] aad = aadBuffer.toByteArray();
            OpenSslNative.updateByteArray(context, aad, 0, aad.length, null, 0, 0);
        }
        aadBuffer = null;
    }

    /** Returns the tag length in bytes: the configured bit length / 8, or {@link #DEFAULT_TAG_LEN}. */
    private int getTagLen() {
        return tagBitLen < 0 ? DEFAULT_TAG_LEN : (tagBitLen >> 3);
    }

    /**
     * Issues an {@code EVP_CIPHER_CTX_ctrl}-style request on the native context.
     *
     * <p>{@code bb} must be a heap buffer, because its backing array is passed to native code.
     * Failures are propagated rather than swallowed: ignoring a failed AEAD_GET_TAG would
     * otherwise emit an all-zero authentication tag.
     *
     * @param ctx  the native context handle
     * @param type the control operation code
     * @param arg  the control argument (the tag length for the AEAD tag operations)
     * @param bb   buffer whose backing array is read or filled by native code; may be {@code null}
     * @throws IllegalStateException if the native control call fails
     */
    private void evpCipherCtxCtrl(final long ctx, final int type, final int arg, final ByteBuffer bb) {
        checkState();

        try {
            if (bb != null) {
                bb.order(ByteOrder.nativeOrder());
                OpenSslNative.ctrl(ctx, type, arg, bb.array());
            } else {
                OpenSslNative.ctrl(ctx, type, arg, null);
            }
        } catch (final Exception e) {
            throw new IllegalStateException("Native cipher ctrl failed (type=" + type + ", arg=" + arg + "): "
                    + e.getMessage(), e);
        }
    }
}