package org.bouncycastle.crypto.tls;

import java.io.IOException;

class DTLSRecordLayer
    implements DatagramTransport
{
    private static final int RECORD_HEADER_LENGTH = 13;
    private static final int MAX_FRAGMENT_LENGTH = 1 << 14;
    private static final long TCP_MSL = 1000L * 60 * 2;
    private static final long RETRANSMIT_TIMEOUT = TCP_MSL * 2;

    private final DatagramTransport transport;
    private final TlsContext context;
    private final TlsPeer peer;

    private final ByteQueue recordQueue = new ByteQueue();

    private volatile boolean closed = false;
    private volatile boolean failed = false;
    private volatile ProtocolVersion readVersion = null, writeVersion = null;
    private volatile boolean inHandshake;
    private volatile int plaintextLimit;
    private DTLSEpoch currentEpoch, pendingEpoch;
    private DTLSEpoch readEpoch, writeEpoch;

    private DTLSHandshakeRetransmit retransmit = null;
    private DTLSEpoch retransmitEpoch = null;
    private long retransmitExpiry = 0;

    DTLSRecordLayer(DatagramTransport transport, TlsContext context, TlsPeer peer, short contentType)
    {
        this.transport = transport;
        this.context = context;
        this.peer = peer;

        this.inHandshake = true;

        this.currentEpoch = new DTLSEpoch(0, new TlsNullCipher(context));
        this.pendingEpoch = null;
        this.readEpoch = currentEpoch;
        this.writeEpoch = currentEpoch;

        setPlaintextLimit(MAX_FRAGMENT_LENGTH);
    }

    void setPlaintextLimit(int plaintextLimit)
    {
        this.plaintextLimit = plaintextLimit;
    }

    int getReadEpoch()
    {
        return readEpoch.getEpoch();
    }

    ProtocolVersion getReadVersion()
    {
        return readVersion;
    }

    void setReadVersion(ProtocolVersion readVersion)
    {
        this.readVersion = readVersion;
    }

    void setWriteVersion(ProtocolVersion writeVersion)
    {
        this.writeVersion = writeVersion;
    }

    void initPendingEpoch(TlsCipher pendingCipher)
    {
        if (pendingEpoch != null)
        {
            throw new IllegalStateException();
        }

        /*
         * TODO "In order to ensure that any given sequence/epoch pair is unique, implementations
         * MUST NOT allow the same epoch value to be reused within two times the TCP maximum segment
         * lifetime."
         */

        // TODO Check for overflow
        this.pendingEpoch = new DTLSEpoch(writeEpoch.getEpoch() + 1, pendingCipher);
    }

    void handshakeSuccessful(DTLSHandshakeRetransmit retransmit)
    {
        if (readEpoch == currentEpoch || writeEpoch == currentEpoch)
        {
            // TODO
            throw new IllegalStateException();
        }

        if (retransmit != null)
        {
            this.retransmit = retransmit;
            this.retransmitEpoch = currentEpoch;
            this.retransmitExpiry = System.currentTimeMillis() + RETRANSMIT_TIMEOUT;
        }

        this.inHandshake = false;
        this.currentEpoch = pendingEpoch;
        this.pendingEpoch = null;
    }

    void resetWriteEpoch()
    {
        if (retransmitEpoch != null)
        {
            this.writeEpoch = retransmitEpoch;
        }
        else
        {
            this.writeEpoch = currentEpoch;
        }
    }

    public int getReceiveLimit()
        throws IOException
    {
        return Math.min(this.plaintextLimit,
            readEpoch.getCipher().getPlaintextLimit(transport.getReceiveLimit() - RECORD_HEADER_LENGTH));
    }

    public int getSendLimit()
        throws IOException
    {
        return Math.min(this.plaintextLimit,
            writeEpoch.getCipher().getPlaintextLimit(transport.getSendLimit() - RECORD_HEADER_LENGTH));
    }

    public int receive(byte[] buf, int off, int len, int waitMillis)
        throws IOException
    {
        byte[] record = null;

        for (;;)
        {
            int receiveLimit = Math.min(len, getReceiveLimit()) + RECORD_HEADER_LENGTH;
            if (record == null || record.length < receiveLimit)
            {
                record = new byte[receiveLimit];
            }

            try
            {
                if (retransmit != null && System.currentTimeMillis() > retransmitExpiry)
                {
                    retransmit = null;
                    retransmitEpoch = null;
                }

                int received = receiveRecord(record, 0, receiveLimit, waitMillis);
                if (received < 0)
                {
                    return received;
                }
                if (received < RECORD_HEADER_LENGTH)
                {
                    continue;
                }
                int length = TlsUtils.readUint16(record, 11);
                if (received != (length + RECORD_HEADER_LENGTH))
                {
                    continue;
                }

                short type = TlsUtils.readUint8(record, 0);

                // TODO Support user-specified custom protocols?
                switch (type)
                {
                case ContentType.alert:
                case ContentType.application_data:
                case ContentType.change_cipher_spec:
                case ContentType.handshake:
                case ContentType.heartbeat:
                    break;
                default:
                    // TODO Exception?
                    continue;
                }

                int epoch = TlsUtils.readUint16(record, 3);

                DTLSEpoch recordEpoch = null;
                if (epoch == readEpoch.getEpoch())
                {
                    recordEpoch = readEpoch;
                }
                else if (type == ContentType.handshake && retransmitEpoch != null
                    && epoch == retransmitEpoch.getEpoch())
                {
                    recordEpoch = retransmitEpoch;
                }

                if (recordEpoch == null)
                {
                    continue;
                }

                long seq = TlsUtils.readUint48(record, 5);
                if (recordEpoch.getReplayWindow().shouldDiscard(seq))
                {
                    continue;
                }

                ProtocolVersion version = TlsUtils.readVersion(record, 1);
                if (!version.isDTLS())
                {
                    continue;
                }

                if (readVersion != null && !readVersion.equals(version))
                {
                    continue;
                }

                byte[] plaintext = recordEpoch.getCipher().decodeCiphertext(
                    getMacSequenceNumber(recordEpoch.getEpoch(), seq), type, record, RECORD_HEADER_LENGTH,
                    received - RECORD_HEADER_LENGTH);

                recordEpoch.getReplayWindow().reportAuthenticated(seq);

                if (plaintext.length > this.plaintextLimit)
                {
                    continue;
                }

                if (readVersion == null)
                {
                    readVersion = version;
                }

                switch (type)
                {
                case ContentType.alert:
                {
                    if (plaintext.length == 2)
                    {
                        short alertLevel = plaintext[0];
                        short alertDescription = plaintext[1];

                        peer.notifyAlertReceived(alertLevel, alertDescription);

                        if (alertLevel == AlertLevel.fatal)
                        {
                            failed();
                            throw new TlsFatalAlert(alertDescription);
                        }

                        // TODO Can close_notify be a fatal alert?
                        if (alertDescription == AlertDescription.close_notify)
                        {
                            closeTransport();
                        }
                    }

                    continue;
                }
                case ContentType.application_data:
                {
                    if (inHandshake)
                    {
                        // TODO Consider buffering application data for new epoch that arrives
                        // out-of-order with the Finished message
                        continue;
                    }
                    break;
                }
                case ContentType.change_cipher_spec:
                {
                    // Implicitly receive change_cipher_spec and change to pending cipher state

                    for (int i = 0; i < plaintext.length; ++i)
                    {
                        short message = TlsUtils.readUint8(plaintext, i);
                        if (message != ChangeCipherSpec.change_cipher_spec)
                        {
                            continue;
                        }

                        if (pendingEpoch != null)
                        {
                            readEpoch = pendingEpoch;
                        }
                    }

                    continue;
                }
                case ContentType.handshake:
                {
                    if (!inHandshake)
                    {
                        if (retransmit != null)
                        {
                            retransmit.receivedHandshakeRecord(epoch, plaintext, 0, plaintext.length);
                        }

                        // TODO Consider support for HelloRequest
                        continue;
                    }
                    break;
                }
                case ContentType.heartbeat:
                {
                    // TODO[RFC 6520]
                    continue;
                }
                }

                /*
                 * NOTE: If we receive any non-handshake data in the new epoch implies the peer has
                 * received our final flight.
                 */
                if (!inHandshake && retransmit != null)
                {
                    this.retransmit = null;
                    this.retransmitEpoch = null;
                }

                System.arraycopy(plaintext, 0, buf, off, plaintext.length);
                return plaintext.length;
            }
            catch (IOException e)
            {
                // NOTE: Assume this is a timeout for the moment
                throw e;
            }
        }
    }

    public void send(byte[] buf, int off, int len)
        throws IOException
    {
        short contentType = ContentType.application_data;

        if (this.inHandshake || this.writeEpoch == this.retransmitEpoch)
        {
            contentType = ContentType.handshake;

            short handshakeType = TlsUtils.readUint8(buf, off);
            if (handshakeType == HandshakeType.finished)
            {
                DTLSEpoch nextEpoch = null;
                if (this.inHandshake)
                {
                    nextEpoch = pendingEpoch;
                }
                else if (this.writeEpoch == this.retransmitEpoch)
                {
                    nextEpoch = currentEpoch;
                }

                if (nextEpoch == null)
                {
                    // TODO
                    throw new IllegalStateException();
                }

                // Implicitly send change_cipher_spec and change to pending cipher state

                // TODO Send change_cipher_spec and finished records in single datagram?
                byte[] data = new byte[]{ 1 };
                sendRecord(ContentType.change_cipher_spec, data, 0, data.length);

                writeEpoch = nextEpoch;
            }
        }

        sendRecord(contentType, buf, off, len);
    }

    public void close()
        throws IOException
    {
        if (!closed)
        {
            if (inHandshake)
            {
                warn(AlertDescription.user_canceled, "User canceled handshake");
            }
            closeTransport();
        }
    }

    void fail(short alertDescription)
    {
        if (!closed)
        {
            try
            {
                raiseAlert(AlertLevel.fatal, alertDescription, null, null);
            }
            catch (Exception e)
            {
                // Ignore
            }

            failed = true;

            closeTransport();
        }
    }

    void failed()
    {
        if (!closed)
        {
            failed = true;

            closeTransport();
        }
    }

    void warn(short alertDescription, String message)
        throws IOException
    {
        raiseAlert(AlertLevel.warning, alertDescription, message, null);
    }

    private void closeTransport()
    {
        if (!closed)
        {
            /*
             * RFC 5246 7.2.1. Unless some other fatal alert has been transmitted, each party is
             * required to send a close_notify alert before closing the write side of the
             * connection. The other party MUST respond with a close_notify alert of its own and
             * close down the connection immediately, discarding any pending writes.
             */

            try
            {
                if (!failed)
                {
                    warn(AlertDescription.close_notify, null);
                }
                transport.close();
            }
            catch (Exception e)
            {
                // Ignore
            }

            closed = true;
        }
    }

    private void raiseAlert(short alertLevel, short alertDescription, String message, Throwable cause)
        throws IOException
    {
        peer.notifyAlertRaised(alertLevel, alertDescription, message, cause);

        byte[] error = new byte[2];
        error[0] = (byte)alertLevel;
        error[1] = (byte)alertDescription;

        sendRecord(ContentType.alert, error, 0, 2);
    }

    private int receiveRecord(byte[] buf, int off, int len, int waitMillis)
        throws IOException
    {
        if (recordQueue.available() > 0)
        {
            int length = 0;
            if (recordQueue.available() >= RECORD_HEADER_LENGTH)
            {
                byte[] lengthBytes = new byte[2];
                recordQueue.read(lengthBytes, 0, 2, 11);
                length = TlsUtils.readUint16(lengthBytes, 0);
            }

            int received = Math.min(recordQueue.available(), RECORD_HEADER_LENGTH + length);
            recordQueue.removeData(buf, off, received, 0);
            return received;
        }

        int received = transport.receive(buf, off, len, waitMillis);
        if (received >= RECORD_HEADER_LENGTH)
        {
            int fragmentLength = TlsUtils.readUint16(buf, off + 11);
            int recordLength = RECORD_HEADER_LENGTH + fragmentLength;
            if (received > recordLength)
            {
                recordQueue.addData(buf, off + recordLength, received - recordLength);
                received = recordLength;
            }
        }

        return received;
    }

    private void sendRecord(short contentType, byte[] buf, int off, int len)
        throws IOException
    {
        // Never send anything until a valid ClientHello has been received
        if (writeVersion == null)
        {
            return;
        }

        if (len > this.plaintextLimit)
        {
            throw new TlsFatalAlert(AlertDescription.internal_error);
        }

        /*
         * RFC 5246 6.2.1 Implementations MUST NOT send zero-length fragments of Handshake, Alert,
         * or ChangeCipherSpec content types.
         */
        if (len < 1 && contentType != ContentType.application_data)
        {
            throw new TlsFatalAlert(AlertDescription.internal_error);
        }

        int recordEpoch = writeEpoch.getEpoch();
        long recordSequenceNumber = writeEpoch.allocateSequenceNumber();

        byte[] ciphertext = writeEpoch.getCipher().encodePlaintext(
            getMacSequenceNumber(recordEpoch, recordSequenceNumber), contentType, buf, off, len);

        // TODO Check the ciphertext length?

        byte[] record = new byte[ciphertext.length + RECORD_HEADER_LENGTH];
        TlsUtils.writeUint8(contentType, record, 0);
        TlsUtils.writeVersion(writeVersion, record, 1);
        TlsUtils.writeUint16(recordEpoch, record, 3);
        TlsUtils.writeUint48(recordSequenceNumber, record, 5);
        TlsUtils.writeUint16(ciphertext.length, record, 11);
        System.arraycopy(ciphertext, 0, record, RECORD_HEADER_LENGTH, ciphertext.length);

        transport.send(record, 0, record.length);
    }

    private static long getMacSequenceNumber(int epoch, long sequence_number)
    {
        return ((epoch & 0xFFFFFFFFL) << 48) | sequence_number;
    }
}