package org.bouncycastle.crypto.tls;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import org.bouncycastle.util.io.SimpleOutputStream;
class RecordStream
{
private static int DEFAULT_PLAINTEXT_LIMIT = (1 << 14);
static final int = 5;
static final int = 0;
static final int = 1;
static final int = 3;
private TlsProtocol handler;
private InputStream input;
private OutputStream output;
private TlsCompression pendingCompression = null, readCompression = null, writeCompression = null;
private TlsCipher pendingCipher = null, readCipher = null, writeCipher = null;
private SequenceNumber readSeqNo = new SequenceNumber(), writeSeqNo = new SequenceNumber();
private ByteArrayOutputStream buffer = new ByteArrayOutputStream();
private TlsHandshakeHash handshakeHash = null;
private SimpleOutputStream handshakeHashUpdater = new SimpleOutputStream()
{
public void write(byte[] buf, int off, int len) throws IOException
{
handshakeHash.update(buf, off, len);
}
};
private ProtocolVersion readVersion = null, writeVersion = null;
private boolean restrictReadVersion = true;
private int plaintextLimit, compressedLimit, ciphertextLimit;
RecordStream(TlsProtocol handler, InputStream input, OutputStream output)
{
this.handler = handler;
this.input = input;
this.output = output;
this.readCompression = new TlsNullCompression();
this.writeCompression = this.readCompression;
}
void init(TlsContext context)
{
this.readCipher = new TlsNullCipher(context);
this.writeCipher = this.readCipher;
this.handshakeHash = new DeferredHash();
this.handshakeHash.init(context);
setPlaintextLimit(DEFAULT_PLAINTEXT_LIMIT);
}
int getPlaintextLimit()
{
return plaintextLimit;
}
void setPlaintextLimit(int plaintextLimit)
{
this.plaintextLimit = plaintextLimit;
this.compressedLimit = this.plaintextLimit + 1024;
this.ciphertextLimit = this.compressedLimit + 1024;
}
ProtocolVersion getReadVersion()
{
return readVersion;
}
void setReadVersion(ProtocolVersion readVersion)
{
this.readVersion = readVersion;
}
void setWriteVersion(ProtocolVersion writeVersion)
{
this.writeVersion = writeVersion;
}
void setRestrictReadVersion(boolean enabled)
{
this.restrictReadVersion = enabled;
}
void setPendingConnectionState(TlsCompression tlsCompression, TlsCipher tlsCipher)
{
this.pendingCompression = tlsCompression;
this.pendingCipher = tlsCipher;
}
void sentWriteCipherSpec()
throws IOException
{
if (pendingCompression == null || pendingCipher == null)
{
throw new TlsFatalAlert(AlertDescription.handshake_failure);
}
this.writeCompression = this.pendingCompression;
this.writeCipher = this.pendingCipher;
this.writeSeqNo = new SequenceNumber();
}
void receivedReadCipherSpec()
throws IOException
{
if (pendingCompression == null || pendingCipher == null)
{
throw new TlsFatalAlert(AlertDescription.handshake_failure);
}
this.readCompression = this.pendingCompression;
this.readCipher = this.pendingCipher;
this.readSeqNo = new SequenceNumber();
}
void finaliseHandshake()
throws IOException
{
if (readCompression != pendingCompression || writeCompression != pendingCompression
|| readCipher != pendingCipher || writeCipher != pendingCipher)
{
throw new TlsFatalAlert(AlertDescription.handshake_failure);
}
this.pendingCompression = null;
this.pendingCipher = null;
}
void (byte[] recordHeader) throws IOException
{
short type = TlsUtils.readUint8(recordHeader, TLS_HEADER_TYPE_OFFSET);
checkType(type, AlertDescription.unexpected_message);
if (!restrictReadVersion)
{
int version = TlsUtils.readVersionRaw(recordHeader, TLS_HEADER_VERSION_OFFSET);
if ((version & 0xffffff00) != 0x0300)
{
throw new TlsFatalAlert(AlertDescription.illegal_parameter);
}
}
else
{
ProtocolVersion version = TlsUtils.readVersion(recordHeader, TLS_HEADER_VERSION_OFFSET);
if (readVersion == null)
{
}
else if (!version.equals(readVersion))
{
throw new TlsFatalAlert(AlertDescription.illegal_parameter);
}
}
int length = TlsUtils.readUint16(recordHeader, TLS_HEADER_LENGTH_OFFSET);
checkLength(length, ciphertextLimit, AlertDescription.record_overflow);
}
boolean readRecord()
throws IOException
{
byte[] recordHeader = TlsUtils.readAllOrNothing(TLS_HEADER_SIZE, input);
if (recordHeader == null)
{
return false;
}
short type = TlsUtils.readUint8(recordHeader, TLS_HEADER_TYPE_OFFSET);
checkType(type, AlertDescription.unexpected_message);
if (!restrictReadVersion)
{
int version = TlsUtils.readVersionRaw(recordHeader, TLS_HEADER_VERSION_OFFSET);
if ((version & 0xffffff00) != 0x0300)
{
throw new TlsFatalAlert(AlertDescription.illegal_parameter);
}
}
else
{
ProtocolVersion version = TlsUtils.readVersion(recordHeader, TLS_HEADER_VERSION_OFFSET);
if (readVersion == null)
{
readVersion = version;
}
else if (!version.equals(readVersion))
{
throw new TlsFatalAlert(AlertDescription.illegal_parameter);
}
}
int length = TlsUtils.readUint16(recordHeader, TLS_HEADER_LENGTH_OFFSET);
checkLength(length, ciphertextLimit, AlertDescription.record_overflow);
byte[] plaintext = decodeAndVerify(type, input, length);
handler.processRecord(type, plaintext, 0, plaintext.length);
return true;
}
byte[] decodeAndVerify(short type, InputStream input, int len)
throws IOException
{
byte[] buf = TlsUtils.readFully(len, input);
long seqNo = readSeqNo.nextValue(AlertDescription.unexpected_message);
byte[] decoded = readCipher.decodeCiphertext(seqNo, type, buf, 0, buf.length);
checkLength(decoded.length, compressedLimit, AlertDescription.record_overflow);
OutputStream cOut = readCompression.decompress(buffer);
if (cOut != buffer)
{
cOut.write(decoded, 0, decoded.length);
cOut.flush();
decoded = getBufferContents();
}
checkLength(decoded.length, plaintextLimit, AlertDescription.decompression_failure);
if (decoded.length < 1 && type != ContentType.application_data)
{
throw new TlsFatalAlert(AlertDescription.illegal_parameter);
}
return decoded;
}
void writeRecord(short type, byte[] plaintext, int plaintextOffset, int plaintextLength)
throws IOException
{
if (writeVersion == null)
{
return;
}
checkType(type, AlertDescription.internal_error);
checkLength(plaintextLength, plaintextLimit, AlertDescription.internal_error);
if (plaintextLength < 1 && type != ContentType.application_data)
{
throw new TlsFatalAlert(AlertDescription.internal_error);
}
OutputStream cOut = writeCompression.compress(buffer);
long seqNo = writeSeqNo.nextValue(AlertDescription.internal_error);
byte[] ciphertext;
if (cOut == buffer)
{
ciphertext = writeCipher.encodePlaintext(seqNo, type, plaintext, plaintextOffset, plaintextLength);
}
else
{
cOut.write(plaintext, plaintextOffset, plaintextLength);
cOut.flush();
byte[] compressed = getBufferContents();
checkLength(compressed.length, plaintextLength + 1024, AlertDescription.internal_error);
ciphertext = writeCipher.encodePlaintext(seqNo, type, compressed, 0, compressed.length);
}
checkLength(ciphertext.length, ciphertextLimit, AlertDescription.internal_error);
byte[] record = new byte[ciphertext.length + TLS_HEADER_SIZE];
TlsUtils.writeUint8(type, record, TLS_HEADER_TYPE_OFFSET);
TlsUtils.writeVersion(writeVersion, record, TLS_HEADER_VERSION_OFFSET);
TlsUtils.writeUint16(ciphertext.length, record, TLS_HEADER_LENGTH_OFFSET);
System.arraycopy(ciphertext, 0, record, TLS_HEADER_SIZE, ciphertext.length);
output.write(record);
output.flush();
}
void notifyHelloComplete()
{
this.handshakeHash = handshakeHash.notifyPRFDetermined();
}
TlsHandshakeHash getHandshakeHash()
{
return handshakeHash;
}
OutputStream getHandshakeHashUpdater()
{
return handshakeHashUpdater;
}
TlsHandshakeHash prepareToFinish()
{
TlsHandshakeHash result = handshakeHash;
this.handshakeHash = handshakeHash.stopTracking();
return result;
}
void safeClose()
{
try
{
input.close();
}
catch (IOException e)
{
}
try
{
output.close();
}
catch (IOException e)
{
}
}
void flush()
throws IOException
{
output.flush();
}
private byte[] getBufferContents()
{
byte[] contents = buffer.toByteArray();
buffer.reset();
return contents;
}
private static void checkType(short type, short alertDescription)
throws IOException
{
switch (type)
{
case ContentType.application_data:
case ContentType.alert:
case ContentType.change_cipher_spec:
case ContentType.handshake:
break;
default:
throw new TlsFatalAlert(alertDescription);
}
}
private static void checkLength(int length, int limit, short alertDescription)
throws IOException
{
if (length > limit)
{
throw new TlsFatalAlert(alertDescription);
}
}
private static class SequenceNumber
{
private long value = 0L;
private boolean exhausted = false;
synchronized long nextValue(short alertDescription) throws TlsFatalAlert
{
if (exhausted)
{
throw new TlsFatalAlert(alertDescription);
}
long result = value;
if (++value == 0)
{
exhausted = true;
}
return result;
}
}
}