package org.bouncycastle.crypto.tls;
import java.io.IOException;
class DTLSRecordLayer
implements DatagramTransport
{
private static final int = 13;
private static final int MAX_FRAGMENT_LENGTH = 1 << 14;
private static final long TCP_MSL = 1000L * 60 * 2;
private static final long RETRANSMIT_TIMEOUT = TCP_MSL * 2;
private final DatagramTransport transport;
private final TlsContext context;
private final TlsPeer peer;
private final ByteQueue recordQueue = new ByteQueue();
private volatile boolean closed = false;
private volatile boolean failed = false;
private volatile ProtocolVersion readVersion = null, writeVersion = null;
private volatile boolean inHandshake;
private volatile int plaintextLimit;
private DTLSEpoch currentEpoch, pendingEpoch;
private DTLSEpoch readEpoch, writeEpoch;
private DTLSHandshakeRetransmit retransmit = null;
private DTLSEpoch retransmitEpoch = null;
private long retransmitExpiry = 0;
DTLSRecordLayer(DatagramTransport transport, TlsContext context, TlsPeer peer, short contentType)
{
this.transport = transport;
this.context = context;
this.peer = peer;
this.inHandshake = true;
this.currentEpoch = new DTLSEpoch(0, new TlsNullCipher(context));
this.pendingEpoch = null;
this.readEpoch = currentEpoch;
this.writeEpoch = currentEpoch;
setPlaintextLimit(MAX_FRAGMENT_LENGTH);
}
void setPlaintextLimit(int plaintextLimit)
{
this.plaintextLimit = plaintextLimit;
}
int getReadEpoch()
{
return readEpoch.getEpoch();
}
ProtocolVersion getReadVersion()
{
return readVersion;
}
void setReadVersion(ProtocolVersion readVersion)
{
this.readVersion = readVersion;
}
void setWriteVersion(ProtocolVersion writeVersion)
{
this.writeVersion = writeVersion;
}
void initPendingEpoch(TlsCipher pendingCipher)
{
if (pendingEpoch != null)
{
throw new IllegalStateException();
}
this.pendingEpoch = new DTLSEpoch(writeEpoch.getEpoch() + 1, pendingCipher);
}
void handshakeSuccessful(DTLSHandshakeRetransmit retransmit)
{
if (readEpoch == currentEpoch || writeEpoch == currentEpoch)
{
throw new IllegalStateException();
}
if (retransmit != null)
{
this.retransmit = retransmit;
this.retransmitEpoch = currentEpoch;
this.retransmitExpiry = System.currentTimeMillis() + RETRANSMIT_TIMEOUT;
}
this.inHandshake = false;
this.currentEpoch = pendingEpoch;
this.pendingEpoch = null;
}
void resetWriteEpoch()
{
if (retransmitEpoch != null)
{
this.writeEpoch = retransmitEpoch;
}
else
{
this.writeEpoch = currentEpoch;
}
}
public int getReceiveLimit()
throws IOException
{
return Math.min(this.plaintextLimit,
readEpoch.getCipher().getPlaintextLimit(transport.getReceiveLimit() - RECORD_HEADER_LENGTH));
}
public int getSendLimit()
throws IOException
{
return Math.min(this.plaintextLimit,
writeEpoch.getCipher().getPlaintextLimit(transport.getSendLimit() - RECORD_HEADER_LENGTH));
}
public int receive(byte[] buf, int off, int len, int waitMillis)
throws IOException
{
byte[] record = null;
for (;;)
{
int receiveLimit = Math.min(len, getReceiveLimit()) + RECORD_HEADER_LENGTH;
if (record == null || record.length < receiveLimit)
{
record = new byte[receiveLimit];
}
try
{
if (retransmit != null && System.currentTimeMillis() > retransmitExpiry)
{
retransmit = null;
retransmitEpoch = null;
}
int received = receiveRecord(record, 0, receiveLimit, waitMillis);
if (received < 0)
{
return received;
}
if (received < RECORD_HEADER_LENGTH)
{
continue;
}
int length = TlsUtils.readUint16(record, 11);
if (received != (length + RECORD_HEADER_LENGTH))
{
continue;
}
short type = TlsUtils.readUint8(record, 0);
switch (type)
{
case ContentType.alert:
case ContentType.application_data:
case ContentType.change_cipher_spec:
case ContentType.handshake:
case ContentType.heartbeat:
break;
default:
continue;
}
int epoch = TlsUtils.readUint16(record, 3);
DTLSEpoch recordEpoch = null;
if (epoch == readEpoch.getEpoch())
{
recordEpoch = readEpoch;
}
else if (type == ContentType.handshake && retransmitEpoch != null
&& epoch == retransmitEpoch.getEpoch())
{
recordEpoch = retransmitEpoch;
}
if (recordEpoch == null)
{
continue;
}
long seq = TlsUtils.readUint48(record, 5);
if (recordEpoch.getReplayWindow().shouldDiscard(seq))
{
continue;
}
ProtocolVersion version = TlsUtils.readVersion(record, 1);
if (!version.isDTLS())
{
continue;
}
if (readVersion != null && !readVersion.equals(version))
{
continue;
}
byte[] plaintext = recordEpoch.getCipher().decodeCiphertext(
getMacSequenceNumber(recordEpoch.getEpoch(), seq), type, record, RECORD_HEADER_LENGTH,
received - RECORD_HEADER_LENGTH);
recordEpoch.getReplayWindow().reportAuthenticated(seq);
if (plaintext.length > this.plaintextLimit)
{
continue;
}
if (readVersion == null)
{
readVersion = version;
}
switch (type)
{
case ContentType.alert:
{
if (plaintext.length == 2)
{
short alertLevel = plaintext[0];
short alertDescription = plaintext[1];
peer.notifyAlertReceived(alertLevel, alertDescription);
if (alertLevel == AlertLevel.fatal)
{
failed();
throw new TlsFatalAlert(alertDescription);
}
if (alertDescription == AlertDescription.close_notify)
{
closeTransport();
}
}
continue;
}
case ContentType.application_data:
{
if (inHandshake)
{
continue;
}
break;
}
case ContentType.change_cipher_spec:
{
for (int i = 0; i < plaintext.length; ++i)
{
short message = TlsUtils.readUint8(plaintext, i);
if (message != ChangeCipherSpec.change_cipher_spec)
{
continue;
}
if (pendingEpoch != null)
{
readEpoch = pendingEpoch;
}
}
continue;
}
case ContentType.handshake:
{
if (!inHandshake)
{
if (retransmit != null)
{
retransmit.receivedHandshakeRecord(epoch, plaintext, 0, plaintext.length);
}
continue;
}
break;
}
case ContentType.heartbeat:
{
continue;
}
}
if (!inHandshake && retransmit != null)
{
this.retransmit = null;
this.retransmitEpoch = null;
}
System.arraycopy(plaintext, 0, buf, off, plaintext.length);
return plaintext.length;
}
catch (IOException e)
{
throw e;
}
}
}
public void send(byte[] buf, int off, int len)
throws IOException
{
short contentType = ContentType.application_data;
if (this.inHandshake || this.writeEpoch == this.retransmitEpoch)
{
contentType = ContentType.handshake;
short handshakeType = TlsUtils.readUint8(buf, off);
if (handshakeType == HandshakeType.finished)
{
DTLSEpoch nextEpoch = null;
if (this.inHandshake)
{
nextEpoch = pendingEpoch;
}
else if (this.writeEpoch == this.retransmitEpoch)
{
nextEpoch = currentEpoch;
}
if (nextEpoch == null)
{
throw new IllegalStateException();
}
byte[] data = new byte[]{ 1 };
sendRecord(ContentType.change_cipher_spec, data, 0, data.length);
writeEpoch = nextEpoch;
}
}
sendRecord(contentType, buf, off, len);
}
public void close()
throws IOException
{
if (!closed)
{
if (inHandshake)
{
warn(AlertDescription.user_canceled, "User canceled handshake");
}
closeTransport();
}
}
void fail(short alertDescription)
{
if (!closed)
{
try
{
raiseAlert(AlertLevel.fatal, alertDescription, null, null);
}
catch (Exception e)
{
}
failed = true;
closeTransport();
}
}
void failed()
{
if (!closed)
{
failed = true;
closeTransport();
}
}
void warn(short alertDescription, String message)
throws IOException
{
raiseAlert(AlertLevel.warning, alertDescription, message, null);
}
private void closeTransport()
{
if (!closed)
{
try
{
if (!failed)
{
warn(AlertDescription.close_notify, null);
}
transport.close();
}
catch (Exception e)
{
}
closed = true;
}
}
private void raiseAlert(short alertLevel, short alertDescription, String message, Throwable cause)
throws IOException
{
peer.notifyAlertRaised(alertLevel, alertDescription, message, cause);
byte[] error = new byte[2];
error[0] = (byte)alertLevel;
error[1] = (byte)alertDescription;
sendRecord(ContentType.alert, error, 0, 2);
}
private int receiveRecord(byte[] buf, int off, int len, int waitMillis)
throws IOException
{
if (recordQueue.available() > 0)
{
int length = 0;
if (recordQueue.available() >= RECORD_HEADER_LENGTH)
{
byte[] lengthBytes = new byte[2];
recordQueue.read(lengthBytes, 0, 2, 11);
length = TlsUtils.readUint16(lengthBytes, 0);
}
int received = Math.min(recordQueue.available(), RECORD_HEADER_LENGTH + length);
recordQueue.removeData(buf, off, received, 0);
return received;
}
int received = transport.receive(buf, off, len, waitMillis);
if (received >= RECORD_HEADER_LENGTH)
{
int fragmentLength = TlsUtils.readUint16(buf, off + 11);
int recordLength = RECORD_HEADER_LENGTH + fragmentLength;
if (received > recordLength)
{
recordQueue.addData(buf, off + recordLength, received - recordLength);
received = recordLength;
}
}
return received;
}
private void sendRecord(short contentType, byte[] buf, int off, int len)
throws IOException
{
if (writeVersion == null)
{
return;
}
if (len > this.plaintextLimit)
{
throw new TlsFatalAlert(AlertDescription.internal_error);
}
if (len < 1 && contentType != ContentType.application_data)
{
throw new TlsFatalAlert(AlertDescription.internal_error);
}
int recordEpoch = writeEpoch.getEpoch();
long recordSequenceNumber = writeEpoch.allocateSequenceNumber();
byte[] ciphertext = writeEpoch.getCipher().encodePlaintext(
getMacSequenceNumber(recordEpoch, recordSequenceNumber), contentType, buf, off, len);
byte[] record = new byte[ciphertext.length + RECORD_HEADER_LENGTH];
TlsUtils.writeUint8(contentType, record, 0);
TlsUtils.writeVersion(writeVersion, record, 1);
TlsUtils.writeUint16(recordEpoch, record, 3);
TlsUtils.writeUint48(recordSequenceNumber, record, 5);
TlsUtils.writeUint16(ciphertext.length, record, 11);
System.arraycopy(ciphertext, 0, record, RECORD_HEADER_LENGTH, ciphertext.length);
transport.send(record, 0, record.length);
}
private static long getMacSequenceNumber(int epoch, long sequence_number)
{
return ((epoch & 0xFFFFFFFFL) << 48) | sequence_number;
}
}