/*
 * JBoss, Home of Professional Open Source.
 * Copyright 2014 Red Hat, Inc., and individual contributors
 * as indicated by the @author tags.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 *  Unless required by applicable law or agreed to in writing, software
 *  distributed under the License is distributed on an "AS IS" BASIS,
 *  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 *  See the License for the specific language governing permissions and
 *  limitations under the License.
 */

package io.undertow.protocols.ssl;

import io.undertow.UndertowMessages;

import javax.net.ssl.SSLException;
import java.io.ByteArrayOutputStream;
import java.nio.BufferUnderflowException;
import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.List;

This class is used to both read and write the ALPN protocol names in the ClientHello SSL message. If the out parameter is null then the read function is being used, while if it is present then the hello message is being rewritten to include ALPN. Even though this dual approach is not particularly clean, it does remove the need to have two versions of each function that do almost exactly the same thing.
/** * This class is used to both read and write the ALPN protocol names in the ClientHello SSL message. * * If the out parameter is not null then the read function is being used, while if it present then it is rewriting * the hello message to include ALPN. * * Even though this dual approach is not particularly clean it does remove the need to have two versions of each function, * that do almost exactly the same thing. * */
final class ALPNHackClientHelloExplorer { // Private constructor prevents construction outside this class. private ALPNHackClientHelloExplorer() { }
The header size of TLS/SSL records.

The value of this constant is 5.

/** * The header size of TLS/SSL records. * <P> * The value of this constant is {@value}. */
public static final int RECORD_HEADER_SIZE = 0x05; /** * * */ static List<String> exploreClientHello(ByteBuffer source) throws SSLException { ByteBuffer input = source.duplicate(); // Do we have a complete header? if (input.remaining() < RECORD_HEADER_SIZE) { throw new BufferUnderflowException(); } List<String> alpnProtocols = new ArrayList<>(); // Is it a handshake message? byte firstByte = input.get(); byte secondByte = input.get(); byte thirdByte = input.get(); if ((firstByte & 0x80) != 0 && thirdByte == 0x01) { // looks like a V2ClientHello, we ignore it. return null; } else if (firstByte == 22) { // 22: handshake record if(secondByte == 3 && thirdByte >= 1 && thirdByte <= 3) { exploreTLSRecord(input, firstByte, secondByte, thirdByte, alpnProtocols, null); return alpnProtocols; } return null; } else { throw UndertowMessages.MESSAGES.notHandshakeRecord(); } } static byte[] rewriteClientHello(byte[] source, List<String> alpnProtocols) throws SSLException { ByteBuffer input = ByteBuffer.wrap(source); ByteArrayOutputStream out = new ByteArrayOutputStream(); // Do we have a complete header? if (input.remaining() < RECORD_HEADER_SIZE) { throw new BufferUnderflowException(); } try { // Is it a handshake message? byte firstByte = input.get(); byte secondByte = input.get(); byte thirdByte = input.get(); out.write(firstByte & 0xFF); out.write(secondByte & 0xFF); out.write(thirdByte & 0xFF); if ((firstByte & 0x80) != 0 && thirdByte == 0x01) { // looks like a V2ClientHello, we ignore it. return null; } else if (firstByte == 22) { // 22: handshake record if (secondByte == 3 && thirdByte == 3) { //TLS1.2 is the only one we care about. 
Previous versions can't use HTTP/2, newer versions won't be backported to JDK8 exploreTLSRecord(input, firstByte, secondByte, thirdByte, alpnProtocols, out); //we need to adjust the record length; int clientHelloLength = out.size() - 9; byte[] data = out.toByteArray(); int newLength = data.length - 5; data[3] = (byte) ((newLength >> 8) & 0xFF); data[4] = (byte) (newLength & 0xFF); //now we need to adjust the handshake frame length data[6] = (byte) ((clientHelloLength >> 16) & 0xFF); data[7] = (byte) ((clientHelloLength >> 8) & 0xFF); data[8] = (byte) (clientHelloLength & 0xFF); return data; } return null; } else { throw UndertowMessages.MESSAGES.notHandshakeRecord(); } } catch (ALPNPresentException e) { return null; } } /* * struct { * uint8 major; * uint8 minor; * } ProtocolVersion; * * enum { * change_cipher_spec(20), alert(21), handshake(22), * application_data(23), (255) * } ContentType; * * struct { * ContentType type; * ProtocolVersion version; * uint16 length; * opaque fragment[TLSPlaintext.length]; * } TLSPlaintext; */ private static void exploreTLSRecord( ByteBuffer input, byte firstByte, byte secondByte, byte thirdByte, List<String> alpnProtocols, ByteArrayOutputStream out) throws SSLException { // Is it a handshake message? if (firstByte != 22) { // 22: handshake record throw UndertowMessages.MESSAGES.notHandshakeRecord(); } // Is there enough data for a full record? int recordLength = getInt16(input); if (recordLength > input.remaining()) { throw new BufferUnderflowException(); } if(out != null) { out.write(0); out.write(0); } // We have already had enough source bytes. 
try { exploreHandshake(input, secondByte, thirdByte, recordLength, alpnProtocols, out); } catch (BufferUnderflowException ignored) { throw UndertowMessages.MESSAGES.invalidHandshakeRecord(); } } /* * enum { * hello_request(0), client_hello(1), server_hello(2), * certificate(11), server_key_exchange (12), * certificate_request(13), server_hello_done(14), * certificate_verify(15), client_key_exchange(16), * finished(20) * (255) * } HandshakeType; * * struct { * HandshakeType msg_type; * uint24 length; * select (HandshakeType) { * case hello_request: HelloRequest; * case client_hello: ClientHello; * case server_hello: ServerHello; * case certificate: Certificate; * case server_key_exchange: ServerKeyExchange; * case certificate_request: CertificateRequest; * case server_hello_done: ServerHelloDone; * case certificate_verify: CertificateVerify; * case client_key_exchange: ClientKeyExchange; * case finished: Finished; * } body; * } Handshake; */ private static void exploreHandshake( ByteBuffer input, byte recordMajorVersion, byte recordMinorVersion, int recordLength, List<String> alpnProtocols, ByteArrayOutputStream out) throws SSLException { // What is the handshake type? byte handshakeType = input.get(); if (handshakeType != 0x01) { // 0x01: client_hello message throw UndertowMessages.MESSAGES.expectedClientHello(); } if(out != null) { out.write(handshakeType & 0xFF); } // What is the handshake body length? int handshakeLength = getInt24(input); if(out != null) { //placeholder out.write(0); out.write(0); out.write(0); } // Theoretically, a single handshake message might span multiple // records, but in practice this does not occur. 
if (handshakeLength > recordLength - 4) { // 4: handshake header size throw UndertowMessages.MESSAGES.multiRecordSSLHandshake(); } input = input.duplicate(); input.limit(handshakeLength + input.position()); exploreClientHello(input, alpnProtocols, out); } /* * struct { * uint32 gmt_unix_time; * opaque random_bytes[28]; * } Random; * * opaque SessionID<0..32>; * * uint8 CipherSuite[2]; * * enum { null(0), (255) } CompressionMethod; * * struct { * ProtocolVersion client_version; * Random random; * SessionID session_id; * CipherSuite cipher_suites<2..2^16-2>; * CompressionMethod compression_methods<1..2^8-1>; * select (extensions_present) { * case false: * struct {}; * case true: * Extension extensions<0..2^16-1>; * }; * } ClientHello; */ private static void exploreClientHello( ByteBuffer input, List<String> alpnProtocols, ByteArrayOutputStream out) throws SSLException { // client version byte helloMajorVersion = input.get(); byte helloMinorVersion = input.get(); if(out != null) { out.write(helloMajorVersion & 0xFF); out.write(helloMinorVersion & 0xFF); } if(helloMajorVersion != 3 && helloMinorVersion != 3) { //we only care about TLS 1.2 return; } // ignore random for(int i = 0; i < 32; ++i) {// 32: the length of Random byte d = input.get(); if(out != null) { out.write(d & 0xFF); } } // session id processByteVector8(input, out); // cipher_suites processByteVector16(input, out); // compression methods processByteVector8(input, out); if (input.remaining() > 0) { exploreExtensions(input, alpnProtocols, out); } else if(out != null) { byte[] data = generateAlpnExtension(alpnProtocols); writeInt16(out, data.length); out.write(data, 0, data.length); } } private static void writeInt16(ByteArrayOutputStream out, int l) { if(out == null) return; out.write((l >> 8) & 0xFF); out.write(l & 0xFF); } private static byte[] generateAlpnExtension(List<String> alpnProtocols) { ByteArrayOutputStream alpnBits = new ByteArrayOutputStream(); alpnBits.write(0); alpnBits.write(16); //ALPN 
type int length = 2; for(String p : alpnProtocols) { length++; length += p.length(); } writeInt16(alpnBits, length); length -= 2; writeInt16(alpnBits, length); for(String p : alpnProtocols) { alpnBits.write(p.length() & 0xFF); for (int i = 0; i < p.length(); ++i) { alpnBits.write(p.charAt(i) & 0xFF); } } return alpnBits.toByteArray(); } /* * struct { * ExtensionType extension_type; * opaque extension_data<0..2^16-1>; * } Extension; * * enum { * server_name(0), max_fragment_length(1), * client_certificate_url(2), trusted_ca_keys(3), * truncated_hmac(4), status_request(5), (65535) * } ExtensionType; */ private static void exploreExtensions(ByteBuffer input, List<String> alpnProtocols, ByteArrayOutputStream out) throws SSLException { ByteArrayOutputStream extensionOut = out == null ? null : new ByteArrayOutputStream(); int length = getInt16(input); // length of extensions writeInt16(extensionOut, 0); //placeholder while (length > 0) { int extType = getInt16(input); // extenson type writeInt16(extensionOut, extType); int extLen = getInt16(input); // length of extension data writeInt16(extensionOut, extLen); if (extType == 16) { // 0x00: ty if(out == null) { exploreALPNExt(input, alpnProtocols); } else { throw new ALPNPresentException(); } } else { // ignore other extensions processByteVector(input, extLen, extensionOut); } length -= extLen + 4; } if(out != null) { byte[] alpnBits = generateAlpnExtension(alpnProtocols); extensionOut.write(alpnBits,0 ,alpnBits.length); byte[] extensionsData = extensionOut.toByteArray(); int newLength = extensionsData.length - 2; extensionsData[0] = (byte) ((newLength >> 8) & 0xFF); extensionsData[1] = (byte) (newLength & 0xFF); out.write(extensionsData, 0, extensionsData.length); } } private static void exploreALPNExt(ByteBuffer input, List<String> alpnProtocols) { int length = getInt16(input); int end = input.position() + length; while (input.position() < end) { alpnProtocols.add(readByteVector8(input)); } } private static int 
getInt8(ByteBuffer input) { return input.get(); } private static int getInt16(ByteBuffer input) { return (input.get() & 0xFF) << 8 | input.get() & 0xFF; } private static int getInt24(ByteBuffer input) { return (input.get() & 0xFF) << 16 | (input.get() & 0xFF) << 8 | input.get() & 0xFF; } private static void processByteVector8(ByteBuffer input, ByteArrayOutputStream out) { int int8 = getInt8(input); if(out != null) { out.write(int8 & 0xFF); } processByteVector(input, int8, out); } private static void processByteVector(ByteBuffer input, int length, ByteArrayOutputStream out) { for (int i = 0; i < length; ++i) { byte b = input.get(); if(out != null) { out.write(b & 0xFF); } } } private static String readByteVector8(ByteBuffer input) { int length = getInt8(input); byte[] data = new byte[length]; input.get(data); return new String(data, StandardCharsets.US_ASCII); } private static void processByteVector16(ByteBuffer input, ByteArrayOutputStream out) { int int16 = getInt16(input); writeInt16(out, int16); processByteVector(input, int16, out); } private static final class ALPNPresentException extends RuntimeException { } }