package sun.security.ssl;
import java.io.EOFException;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.nio.ByteBuffer;
import java.security.GeneralSecurityException;
import java.util.ArrayList;
import javax.crypto.BadPaddingException;
import javax.net.ssl.SSLException;
import javax.net.ssl.SSLHandshakeException;
import javax.net.ssl.SSLProtocolException;
import sun.security.ssl.SSLCipher.SSLReadCipher;
final class SSLSocketInputRecord extends InputRecord implements SSLRecord {
private InputStream is = null;
private OutputStream os = null;
private final byte[] temporary = new byte[1024];
private boolean formatVerified = false;
private ByteBuffer handshakeBuffer = null;
private boolean = false;
SSLSocketInputRecord(HandshakeHash handshakeHash) {
super(handshakeHash, SSLReadCipher.nullTlsReadCipher());
}
@Override
int bytesInCompletePacket() throws IOException {
if (!hasHeader) {
try {
int really = read(is, temporary, 0, headerSize);
if (really < 0) {
return -1;
}
} catch (EOFException eofe) {
return -1;
}
hasHeader = true;
}
byte byteZero = temporary[0];
int len = 0;
if (formatVerified ||
(byteZero == ContentType.HANDSHAKE.id) ||
(byteZero == ContentType.ALERT.id)) {
if (!ProtocolVersion.isNegotiable(
temporary[1], temporary[2], false, false)) {
throw new SSLException("Unrecognized record version " +
ProtocolVersion.nameOf(temporary[1], temporary[2]) +
" , plaintext connection?");
}
formatVerified = true;
len = ((temporary[3] & 0xFF) << 8) +
(temporary[4] & 0xFF) + headerSize;
} else {
boolean isShort = ((byteZero & 0x80) != 0);
if (isShort && ((temporary[2] == 1) || (temporary[2] == 4))) {
if (!ProtocolVersion.isNegotiable(
temporary[3], temporary[4], false, false)) {
throw new SSLException("Unrecognized record version " +
ProtocolVersion.nameOf(temporary[3], temporary[4]) +
" , plaintext connection?");
}
len = ((byteZero & 0x7F) << 8) + (temporary[1] & 0xFF) + 2;
} else {
throw new SSLException(
"Unrecognized SSL message, plaintext connection?");
}
}
return len;
}
@Override
Plaintext[] decode(ByteBuffer[] srcs, int srcsOffset,
int srcsLength) throws IOException, BadPaddingException {
if (isClosed) {
return null;
}
if (!hasHeader) {
int really = read(is, temporary, 0, headerSize);
if (really < 0) {
throw new EOFException("SSL peer shut down incorrectly");
}
hasHeader = true;
}
Plaintext plaintext = null;
if (!formatVerified) {
formatVerified = true;
if ((temporary[0] != ContentType.HANDSHAKE.id) &&
(temporary[0] != ContentType.ALERT.id)) {
hasHeader = false;
return handleUnknownRecord(temporary);
}
}
hasHeader = false;
return decodeInputRecord(temporary);
}
@Override
void setReceiverStream(InputStream inputStream) {
this.is = inputStream;
}
@Override
void setDeliverStream(OutputStream outputStream) {
this.os = outputStream;
}
private Plaintext[] decodeInputRecord(
byte[] header) throws IOException, BadPaddingException {
byte contentType = header[0];
byte majorVersion = header[1];
byte minorVersion = header[2];
int contentLen = ((header[3] & 0xFF) << 8) +
(header[4] & 0xFF);
if (SSLLogger.isOn && SSLLogger.isOn("record")) {
SSLLogger.fine(
"READ: " +
ProtocolVersion.nameOf(majorVersion, minorVersion) +
" " + ContentType.nameOf(contentType) + ", length = " +
contentLen);
}
if (contentLen < 0 || contentLen > maxLargeRecordSize - headerSize) {
throw new SSLProtocolException(
"Bad input record size, TLSCiphertext.length = " + contentLen);
}
ByteBuffer destination = ByteBuffer.allocate(headerSize + contentLen);
int dstPos = destination.position();
destination.put(temporary, 0, headerSize);
while (contentLen > 0) {
int howmuch = Math.min(temporary.length, contentLen);
int really = read(is, temporary, 0, howmuch);
if (really < 0) {
throw new EOFException("SSL peer shut down incorrectly");
}
destination.put(temporary, 0, howmuch);
contentLen -= howmuch;
}
destination.flip();
destination.position(dstPos + headerSize);
if (SSLLogger.isOn && SSLLogger.isOn("record")) {
SSLLogger.fine(
"READ: " +
ProtocolVersion.nameOf(majorVersion, minorVersion) +
" " + ContentType.nameOf(contentType) + ", length = " +
destination.remaining());
}
ByteBuffer fragment;
try {
Plaintext plaintext =
readCipher.decrypt(contentType, destination, null);
fragment = plaintext.fragment;
contentType = plaintext.contentType;
} catch (BadPaddingException bpe) {
throw bpe;
} catch (GeneralSecurityException gse) {
throw (SSLProtocolException)(new SSLProtocolException(
"Unexpected exception")).initCause(gse);
}
if (contentType != ContentType.HANDSHAKE.id &&
handshakeBuffer != null && handshakeBuffer.hasRemaining()) {
throw new SSLProtocolException(
"Expecting a handshake fragment, but received " +
ContentType.nameOf(contentType));
}
if (contentType == ContentType.HANDSHAKE.id) {
ByteBuffer handshakeFrag = fragment;
if ((handshakeBuffer != null) &&
(handshakeBuffer.remaining() != 0)) {
ByteBuffer bb = ByteBuffer.wrap(new byte[
handshakeBuffer.remaining() + fragment.remaining()]);
bb.put(handshakeBuffer);
bb.put(fragment);
handshakeFrag = bb.rewind();
handshakeBuffer = null;
}
ArrayList<Plaintext> plaintexts = new ArrayList<>(5);
while (handshakeFrag.hasRemaining()) {
int remaining = handshakeFrag.remaining();
if (remaining < handshakeHeaderSize) {
handshakeBuffer = ByteBuffer.wrap(new byte[remaining]);
handshakeBuffer.put(handshakeFrag);
handshakeBuffer.rewind();
break;
}
handshakeFrag.mark();
byte handshakeType = handshakeFrag.get();
int handshakeBodyLen = Record.getInt24(handshakeFrag);
handshakeFrag.reset();
int handshakeMessageLen =
handshakeHeaderSize + handshakeBodyLen;
if (remaining < handshakeMessageLen) {
handshakeBuffer = ByteBuffer.wrap(new byte[remaining]);
handshakeBuffer.put(handshakeFrag);
handshakeBuffer.rewind();
break;
} if (remaining == handshakeMessageLen) {
if (handshakeHash.isHashable(handshakeType)) {
handshakeHash.receive(handshakeFrag);
}
plaintexts.add(
new Plaintext(contentType,
majorVersion, minorVersion, -1, -1L, handshakeFrag)
);
break;
} else {
int fragPos = handshakeFrag.position();
int fragLim = handshakeFrag.limit();
int nextPos = fragPos + handshakeMessageLen;
handshakeFrag.limit(nextPos);
if (handshakeHash.isHashable(handshakeType)) {
handshakeHash.receive(handshakeFrag);
}
plaintexts.add(
new Plaintext(contentType, majorVersion, minorVersion,
-1, -1L, handshakeFrag.slice())
);
handshakeFrag.position(nextPos);
handshakeFrag.limit(fragLim);
}
}
return plaintexts.toArray(new Plaintext[0]);
}
return new Plaintext[] {
new Plaintext(contentType,
majorVersion, minorVersion, -1, -1L, fragment)
};
}
private Plaintext[] handleUnknownRecord(
byte[] header) throws IOException, BadPaddingException {
byte firstByte = header[0];
byte thirdByte = header[2];
if (((firstByte & 0x80) != 0) && (thirdByte == 1)) {
if (helloVersion != ProtocolVersion.SSL20Hello) {
throw new SSLHandshakeException("SSLv2Hello is not enabled");
}
byte majorVersion = header[3];
byte minorVersion = header[4];
if ((majorVersion == ProtocolVersion.SSL20Hello.major) &&
(minorVersion == ProtocolVersion.SSL20Hello.minor)) {
os.write(SSLRecord.v2NoCipher);
if (SSLLogger.isOn) {
if (SSLLogger.isOn("record")) {
SSLLogger.fine(
"Requested to negotiate unsupported SSLv2!");
}
if (SSLLogger.isOn("packet")) {
SSLLogger.fine("Raw write", SSLRecord.v2NoCipher);
}
}
throw new SSLException("Unsupported SSL v2.0 ClientHello");
}
int msgLen = ((header[0] & 0x7F) << 8) | (header[1] & 0xFF);
ByteBuffer destination = ByteBuffer.allocate(headerSize + msgLen);
destination.put(temporary, 0, headerSize);
msgLen -= 3;
while (msgLen > 0) {
int howmuch = Math.min(temporary.length, msgLen);
int really = read(is, temporary, 0, howmuch);
if (really < 0) {
throw new EOFException("SSL peer shut down incorrectly");
}
destination.put(temporary, 0, howmuch);
msgLen -= howmuch;
}
destination.flip();
destination.position(2);
handshakeHash.receive(destination);
destination.position(0);
ByteBuffer converted = convertToClientHello(destination);
if (SSLLogger.isOn && SSLLogger.isOn("packet")) {
SSLLogger.fine(
"[Converted] ClientHello", converted);
}
return new Plaintext[] {
new Plaintext(ContentType.HANDSHAKE.id,
majorVersion, minorVersion, -1, -1L, converted)
};
} else {
if (((firstByte & 0x80) != 0) && (thirdByte == 4)) {
throw new SSLException("SSL V2.0 servers are not supported.");
}
throw new SSLException("Unsupported or unrecognized SSL message");
}
}
private static int read(InputStream is,
byte[] buffer, int offset, int len) throws IOException {
int n = 0;
while (n < len) {
int readLen = is.read(buffer, offset + n, len - n);
if (readLen < 0) {
if (SSLLogger.isOn && SSLLogger.isOn("packet")) {
SSLLogger.fine("Raw read: EOF");
}
return -1;
}
if (SSLLogger.isOn && SSLLogger.isOn("packet")) {
ByteBuffer bb = ByteBuffer.wrap(buffer, offset + n, readLen);
SSLLogger.fine("Raw read", bb);
}
n += readLen;
}
return n;
}
void deplete(boolean tryToRead) throws IOException {
int remaining = is.available();
if (tryToRead && (remaining == 0)) {
is.read();
}
while ((remaining = is.available()) != 0) {
is.skip(remaining);
}
}
}