package com.google.crypto.tink.subtle;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.channels.ClosedChannelException;
import java.nio.channels.WritableByteChannel;
import java.security.GeneralSecurityException;
/**
 * A {@link WritableByteChannel} that encrypts the plaintext written to it, segment by segment,
 * and forwards the resulting ciphertext to an underlying channel.
 *
 * <p>The ciphertext starts with the header obtained from the segment encrypter, followed by the
 * encrypted segments. Plaintext is accumulated in {@code ptBuffer}; a segment is encrypted in
 * {@link #write} only once more plaintext is known to follow it, and whatever remains buffered is
 * encrypted as the final segment in {@link #close()}.
 *
 * <p>{@link #write} and {@link #close} are {@code synchronized}. {@code open} is {@code volatile}
 * so that the unsynchronized {@link #isOpen} observes the state written by {@link #close}.
 */
class StreamingAeadEncryptingChannel implements WritableByteChannel {
  private final WritableByteChannel ciphertextChannel;
  private final StreamSegmentEncrypter encrypter;
  /** Plaintext of the segment currently being assembled; kept ready for {@code put}. */
  ByteBuffer ptBuffer;
  /** Ciphertext not yet accepted by {@code ciphertextChannel}; kept ready for reading. */
  ByteBuffer ctBuffer;
  private final int plaintextSegmentSize;
  /** Whether the channel is still open; volatile because {@link #isOpen} reads it unlocked. */
  volatile boolean open = true;

  /**
   * Creates the channel and writes the ciphertext header to {@code ciphertextChannel}.
   *
   * <p>The header write may be partial; bytes the channel does not accept now are retried by
   * later calls to {@link #write} or {@link #close}.
   *
   * @param streamAead supplies the segment encrypter and the segment sizes
   * @param ciphertextChannel the channel that receives the ciphertext
   * @param associatedData associated data handed to the segment encrypter
   * @throws GeneralSecurityException if propagated while setting up the segment encrypter
   * @throws IOException if writing the header to {@code ciphertextChannel} fails
   */
  public StreamingAeadEncryptingChannel(
      NonceBasedStreamingAead streamAead,
      WritableByteChannel ciphertextChannel,
      byte[] associatedData)
      throws GeneralSecurityException, IOException {
    this.ciphertextChannel = ciphertextChannel;
    encrypter = streamAead.newStreamSegmentEncrypter(associatedData);
    plaintextSegmentSize = streamAead.getPlaintextSegmentSize();
    ptBuffer = ByteBuffer.allocate(plaintextSegmentSize);
    // The first segment holds getCiphertextOffset() fewer plaintext bytes than later segments,
    // which are reset to the full plaintextSegmentSize in write().
    ptBuffer.limit(plaintextSegmentSize - streamAead.getCiphertextOffset());
    ctBuffer = ByteBuffer.allocate(streamAead.getCiphertextSegmentSize());
    // Stage the header in ctBuffer and flip it for reading, so any bytes the channel does not
    // accept right away are flushed by subsequent write()/close() calls.
    ctBuffer.put(encrypter.getHeader());
    ctBuffer.flip();
    ciphertextChannel.write(ctBuffer);
  }

  /**
   * Consumes plaintext from {@code pt}, encrypting and forwarding every complete segment that is
   * known not to be the last one.
   *
   * <p>If the underlying channel has not yet accepted all ciphertext of an earlier segment, this
   * method stops early and returns the number of plaintext bytes consumed so far (possibly 0).
   *
   * @param pt the plaintext to encrypt; its position is advanced past the consumed bytes
   * @return the number of plaintext bytes consumed from {@code pt}
   * @throws ClosedChannelException if the channel has been closed
   * @throws IOException if encryption fails (wrapping the {@link GeneralSecurityException}) or if
   *     the underlying channel fails
   */
  @Override
  public synchronized int write(ByteBuffer pt) throws IOException {
    if (!open) {
      throw new ClosedChannelException();
    }
    // Try to flush ciphertext left over from a previous partial write (or the header).
    if (ctBuffer.remaining() > 0) {
      ciphertextChannel.write(ctBuffer);
    }
    int startPosition = pt.position();
    // Strict '>': a segment is encrypted here only when more plaintext follows it, so the last
    // segment always stays buffered for close().
    while (pt.remaining() > ptBuffer.remaining()) {
      // ctBuffer is reused for every segment; do not overwrite ciphertext that is still pending.
      if (ctBuffer.remaining() > 0) {
        return pt.position() - startPosition;
      }
      // Take just enough input to complete the current segment, without copying it.
      int sliceSize = ptBuffer.remaining();
      ByteBuffer slice = pt.slice();
      slice.limit(sliceSize);
      pt.position(pt.position() + sliceSize);
      try {
        ptBuffer.flip();
        ctBuffer.clear();
        if (slice.remaining() != 0) {
          encrypter.encryptSegment(ptBuffer, slice, false, ctBuffer);
        } else {
          encrypter.encryptSegment(ptBuffer, false, ctBuffer);
        }
      } catch (GeneralSecurityException ex) {
        throw new IOException(ex);
      }
      ctBuffer.flip();
      ciphertextChannel.write(ctBuffer);
      ptBuffer.clear();
      // Every segment after the first holds a full plaintextSegmentSize bytes.
      ptBuffer.limit(plaintextSegmentSize);
    }
    ptBuffer.put(pt);
    return pt.position() - startPosition;
  }

  /**
   * Flushes pending ciphertext, encrypts the buffered plaintext as the final segment (passing
   * {@code true} to the encrypter), writes it, and closes the underlying channel.
   *
   * <p>Calling this on an already closed channel has no effect. If an exception is thrown, the
   * channel is left open and the underlying channel is not closed.
   *
   * @throws IOException if the underlying channel stops accepting ciphertext, if encryption fails
   *     (wrapping the {@link GeneralSecurityException}), or if closing the channel fails
   */
  @Override
  public synchronized void close() throws IOException {
    if (!open) {
      return;
    }
    writeAllCiphertext();
    try {
      ctBuffer.clear();
      ptBuffer.flip();
      encrypter.encryptSegment(ptBuffer, true, ctBuffer);
    } catch (GeneralSecurityException ex) {
      throw new IOException(ex);
    }
    ctBuffer.flip();
    writeAllCiphertext();
    ciphertextChannel.close();
    open = false;
  }

  /**
   * Writes all remaining bytes of {@code ctBuffer}; fails instead of spinning when the underlying
   * channel makes no progress. Only called from {@link #close()}, under its lock.
   */
  private void writeAllCiphertext() throws IOException {
    while (ctBuffer.remaining() > 0) {
      int n = ciphertextChannel.write(ctBuffer);
      if (n <= 0) {
        throw new IOException("Failed to write ciphertext before closing");
      }
    }
  }

  /** Returns whether {@link #close()} has not yet completed successfully. */
  @Override
  public boolean isOpen() {
    return open;
  }
}