/*
 * Copyright (c) 2018, Oracle and/or its affiliates. All rights reserved.
 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
 *
 * This code is free software; you can redistribute it and/or modify it
 * under the terms of the GNU General Public License version 2 only, as
 * published by the Free Software Foundation.  Oracle designates this
 * particular file as subject to the "Classpath" exception as provided
 * by Oracle in the LICENSE file that accompanied this code.
 *
 * This code is distributed in the hope that it will be useful, but WITHOUT
 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
 * version 2 for more details (a copy is included in the LICENSE file that
 * accompanied this code).
 *
 * You should have received a copy of the GNU General Public License version
 * 2 along with this work; if not, write to the Free Software Foundation,
 * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 *
 * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 * or visit www.oracle.com if you need additional information or have any
 * questions.
 */

package sun.security.ssl;

import java.io.IOException;
import java.nio.ByteBuffer;
import java.text.MessageFormat;
import java.util.*;

import sun.security.ssl.SSLHandshake.HandshakeMessage;
import sun.security.util.HexDumpEncoder;

SSL/(D)TLS extensions in a handshake message.
/** * SSL/(D)TLS extensions in a handshake message. */
final class SSLExtensions { private final HandshakeMessage handshakeMessage; private Map<SSLExtension, byte[]> extMap = new LinkedHashMap<>(); private int encodedLength; // Extension map for debug logging private final Map<Integer, byte[]> logMap = SSLLogger.isOn ? new LinkedHashMap<>() : null; SSLExtensions(HandshakeMessage handshakeMessage) { this.handshakeMessage = handshakeMessage; this.encodedLength = 2; // 2: the length of the extensions. } SSLExtensions(HandshakeMessage hm, ByteBuffer m, SSLExtension[] extensions) throws IOException { this.handshakeMessage = hm; int len = Record.getInt16(m); encodedLength = len + 2; // 2: the length of the extensions. while (len > 0) { int extId = Record.getInt16(m); int extLen = Record.getInt16(m); if (extLen > m.remaining()) { throw hm.handshakeContext.conContext.fatal( Alert.ILLEGAL_PARAMETER, "Error parsing extension (" + extId + "): no sufficient data"); } boolean isSupported = true; SSLHandshake handshakeType = hm.handshakeType(); if (SSLExtension.isConsumable(extId) && SSLExtension.valueOf(handshakeType, extId) == null) { if (extId == SSLExtension.CH_SUPPORTED_GROUPS.id && handshakeType == SSLHandshake.SERVER_HELLO) { // Note: It does not comply to the specification. However, // there are servers that send the supported_groups // extension in ServerHello handshake message. // // TLS 1.3 should not send this extension. We may want to // limit the workaround for TLS 1.2 and prior version only. // However, the implementation of the limit is complicated // and inefficient, and may not worthy the maintenance. isSupported = false; if (SSLLogger.isOn && SSLLogger.isOn("ssl,handshake")) { SSLLogger.warning( "Received buggy supported_groups extension " + "in the ServerHello handshake message"); } } else { throw hm.handshakeContext.conContext.fatal( Alert.UNSUPPORTED_EXTENSION, "extension (" + extId + ") should not be presented in " + handshakeType.name); } } if (isSupported) { isSupported = false; for (SSLExtension extension : extensions) { if ((extension.id != extId) || (extension.onLoadConsumer == null)) { continue; } if (extension.handshakeType != handshakeType) { throw hm.handshakeContext.conContext.fatal( Alert.UNSUPPORTED_EXTENSION, "extension (" + extId + ") should not be " + "presented in " + handshakeType.name); } byte[] extData = new byte[extLen]; m.get(extData); extMap.put(extension, extData); if (logMap != null) { logMap.put(extId, extData); } isSupported = true; break; } } if (!isSupported) { if (logMap != null) { // cache the extension for debug logging byte[] extData = new byte[extLen]; m.get(extData); logMap.put(extId, extData); if (SSLLogger.isOn && SSLLogger.isOn("ssl,handshake")) { SSLLogger.fine( "Ignore unknown or unsupported extension", toString(extId, extData)); } } else { // ignore the extension int pos = m.position() + extLen; m.position(pos); } } len -= extLen + 4; } } byte[] get(SSLExtension ext) { return extMap.get(ext); }
Consume the specified extensions.
/** * Consume the specified extensions. */
void consumeOnLoad(HandshakeContext context, SSLExtension[] extensions) throws IOException { for (SSLExtension extension : extensions) { if (context.negotiatedProtocol != null && !extension.isAvailable(context.negotiatedProtocol)) { if (SSLLogger.isOn && SSLLogger.isOn("ssl,handshake")) { SSLLogger.fine( "Ignore unsupported extension: " + extension.name); } continue; } if (!extMap.containsKey(extension)) { if (extension.onLoadAbsence != null) { extension.absentOnLoad(context, handshakeMessage); } else if (SSLLogger.isOn && SSLLogger.isOn("ssl,handshake")) { SSLLogger.fine( "Ignore unavailable extension: " + extension.name); } continue; } if (extension.onLoadConsumer == null) { if (SSLLogger.isOn && SSLLogger.isOn("ssl,handshake")) { SSLLogger.warning( "Ignore unsupported extension: " + extension.name); } continue; } ByteBuffer m = ByteBuffer.wrap(extMap.get(extension)); extension.consumeOnLoad(context, handshakeMessage, m); if (SSLLogger.isOn && SSLLogger.isOn("ssl,handshake")) { SSLLogger.fine("Consumed extension: " + extension.name); } } }
Consider impact of the specified extensions.
/** * Consider impact of the specified extensions. */
void consumeOnTrade(HandshakeContext context, SSLExtension[] extensions) throws IOException { for (SSLExtension extension : extensions) { if (!extMap.containsKey(extension)) { if (extension.onTradeAbsence != null) { extension.absentOnTrade(context, handshakeMessage); } else if (SSLLogger.isOn && SSLLogger.isOn("ssl,handshake")) { SSLLogger.fine( "Ignore unavailable extension: " + extension.name); } continue; } if (extension.onTradeConsumer == null) { if (SSLLogger.isOn && SSLLogger.isOn("ssl,handshake")) { SSLLogger.warning( "Ignore impact of unsupported extension: " + extension.name); } continue; } extension.consumeOnTrade(context, handshakeMessage); if (SSLLogger.isOn && SSLLogger.isOn("ssl,handshake")) { SSLLogger.fine("Populated with extension: " + extension.name); } } }
Produce extension values for the specified extensions.
/** * Produce extension values for the specified extensions. */
void produce(HandshakeContext context, SSLExtension[] extensions) throws IOException { for (SSLExtension extension : extensions) { if (extMap.containsKey(extension)) { if (SSLLogger.isOn && SSLLogger.isOn("ssl,handshake")) { SSLLogger.fine( "Ignore, duplicated extension: " + extension.name); } continue; } if (extension.networkProducer == null) { if (SSLLogger.isOn && SSLLogger.isOn("ssl,handshake")) { SSLLogger.warning( "Ignore, no extension producer defined: " + extension.name); } continue; } byte[] encoded = extension.produce(context, handshakeMessage); if (encoded != null) { extMap.put(extension, encoded); encodedLength += encoded.length + 4; // extension_type (2) // extension_data length(2) } else if (SSLLogger.isOn && SSLLogger.isOn("ssl,handshake")) { // The extension is not available in the context. SSLLogger.fine( "Ignore, context unavailable extension: " + extension.name); } } }
Produce extension values for the specified extensions, replacing if there is an existing extension value for a specified extension.
/** * Produce extension values for the specified extensions, replacing if * there is an existing extension value for a specified extension. */
void reproduce(HandshakeContext context, SSLExtension[] extensions) throws IOException { for (SSLExtension extension : extensions) { if (extension.networkProducer == null) { if (SSLLogger.isOn && SSLLogger.isOn("ssl,handshake")) { SSLLogger.warning( "Ignore, no extension producer defined: " + extension.name); } continue; } byte[] encoded = extension.produce(context, handshakeMessage); if (encoded != null) { if (extMap.containsKey(extension)) { byte[] old = extMap.replace(extension, encoded); if (old != null) { encodedLength -= old.length + 4; } encodedLength += encoded.length + 4; } else { extMap.put(extension, encoded); encodedLength += encoded.length + 4; // extension_type (2) // extension_data length(2) } } else if (SSLLogger.isOn && SSLLogger.isOn("ssl,handshake")) { // The extension is not available in the context. SSLLogger.fine( "Ignore, context unavailable extension: " + extension.name); } } } // Note that TLS 1.3 may use empty extensions. Please consider it while // using this method. int length() { if (extMap.isEmpty()) { return 0; } else { return encodedLength; } } // Note that TLS 1.3 may use empty extensions. Please consider it while // using this method. void send(HandshakeOutStream hos) throws IOException { int extsLen = length(); if (extsLen == 0) { return; } hos.putInt16(extsLen - 2); // extensions must be sent in the order they appear in the enum for (SSLExtension ext : SSLExtension.values()) { byte[] extData = extMap.get(ext); if (extData != null) { hos.putInt16(ext.id); hos.putBytes16(extData); } } } @Override public String toString() { if (extMap.isEmpty() && (logMap == null || logMap.isEmpty())) { return "<no extension>"; } else { StringBuilder builder = new StringBuilder(512); if (logMap != null && !logMap.isEmpty()) { for (Map.Entry<Integer, byte[]> en : logMap.entrySet()) { SSLExtension ext = SSLExtension.valueOf( handshakeMessage.handshakeType(), en.getKey()); if (builder.length() != 0) { builder.append(",\n"); } if (ext != null) { builder.append( ext.toString(ByteBuffer.wrap(en.getValue()))); } else { builder.append(toString(en.getKey(), en.getValue())); } } return builder.toString(); } else { for (Map.Entry<SSLExtension, byte[]> en : extMap.entrySet()) { if (builder.length() != 0) { builder.append(",\n"); } builder.append( en.getKey().toString(ByteBuffer.wrap(en.getValue()))); } return builder.toString(); } } } private static String toString(int extId, byte[] extData) { MessageFormat messageFormat = new MessageFormat( "\"unknown extension ({0})\": '{'\n" + "{1}\n" + "'}'", Locale.ENGLISH); HexDumpEncoder hexEncoder = new HexDumpEncoder(); String encoded = hexEncoder.encodeBuffer(extData); Object[] messageFields = { extId, Utilities.indent(encoded) }; return messageFormat.format(messageFields); } }