package io.undertow.protocols.ssl;
import io.undertow.UndertowLogger;
import java.io.ByteArrayOutputStream;
import java.lang.reflect.Field;
import java.lang.reflect.Method;
import java.nio.BufferUnderflowException;
import java.nio.ByteBuffer;
import java.security.MessageDigest;
import java.util.List;
import java.util.concurrent.atomic.AtomicReference;
import javax.net.ssl.SSLEngine;
import javax.net.ssl.SSLEngineResult;
import javax.net.ssl.SSLException;
import javax.net.ssl.SSLSession;
/**
 * {@link SSLEngine} wrapper that layers ALPN negotiation on top of the JDK's internal
 * {@code sun.security.ssl.SSLEngineImpl} (the "JDK8 ALPN hack").
 * <p>
 * Handshake records are rewritten as they pass through {@link #wrap} and {@link #unwrap}:
 * <ul>
 * <li>as a client, the outgoing ClientHello is rewritten to advertise
 * {@link #getApplicationProtocols()}, and the ALPN extension is removed from the incoming
 * ServerHello before the delegate sees it;</li>
 * <li>as a server, the incoming ClientHello is inspected to pick a protocol, and the outgoing
 * ServerHello is replaced with the one produced by the server-side hash stream.</li>
 * </ul>
 * The engine's private handshake-hash output stream is swapped out via reflection, and
 * {@link #regenerateHashes} can reset and re-feed that hash with replacement bytes.
 * Every other operation is forwarded to the wrapped engine.
 * <p>
 * Instances hold mutable per-handshake state and perform no synchronization.
 */
public class ALPNHackSSLEngine extends SSLEngine {

    /**
     * {@code true} when every reflective lookup succeeded and the hack has not been disabled
     * with the {@code io.undertow.disable-jdk8-alpn} system property.
     */
    public static final boolean ENABLED;

    // Reflective handles into sun.security.ssl internals; all null when lookup failed.
    private static final Field HANDSHAKER;                 // SSLEngineImpl.handshaker
    private static final Field HANDSHAKER_PROTOCOL_VERSION; // handshaker.protocolVersion
    private static final Field HANDSHAKE_HASH;             // handshaker.handshakeHash
    private static final Field HANDSHAKE_HASH_VERSION;     // handshakeHash.version
    private static final Method HANDSHAKE_HASH_UPDATE;     // handshakeHash.update(byte[], int, int)
    private static final Method HANDSHAKE_HASH_PROTOCOL_DETERMINED; // handshakeHash.protocolDetermined(ProtocolVersion)
    private static final Field HANDSHAKE_HASH_DATA;        // handshakeHash.data
    private static final Field HANDSHAKE_HASH_FIN_MD;      // handshakeHash.finMD
    private static final Class<?> SSL_ENGINE_IMPL_CLASS;   // sun.security.ssl.SSLEngineImpl

    static {
        boolean enabled = true;
        Field handshaker;
        Field handshakeHash;
        Field handshakeHashVersion;
        Field handshakeHashData;
        Field handshakeHashFinMd;
        Field protocolVersion;
        Method handshakeHashUpdate;
        Method handshakeHashProtocolDetermined;
        Class<?> sslEngineImplClass;
        try {
            Class<?> protocolVersionClass = Class.forName("sun.security.ssl.ProtocolVersion", true, ClassLoader.getSystemClassLoader());
            sslEngineImplClass = Class.forName("sun.security.ssl.SSLEngineImpl", true, ClassLoader.getSystemClassLoader());
            handshaker = sslEngineImplClass.getDeclaredField("handshaker");
            handshaker.setAccessible(true);
            handshakeHash = handshaker.getType().getDeclaredField("handshakeHash");
            handshakeHash.setAccessible(true);
            protocolVersion = handshaker.getType().getDeclaredField("protocolVersion");
            protocolVersion.setAccessible(true);
            handshakeHashVersion = handshakeHash.getType().getDeclaredField("version");
            handshakeHashVersion.setAccessible(true);
            handshakeHashUpdate = handshakeHash.getType().getDeclaredMethod("update", byte[].class, int.class, int.class);
            handshakeHashUpdate.setAccessible(true);
            handshakeHashProtocolDetermined = handshakeHash.getType().getDeclaredMethod("protocolDetermined", protocolVersionClass);
            handshakeHashProtocolDetermined.setAccessible(true);
            handshakeHashData = handshakeHash.getType().getDeclaredField("data");
            handshakeHashData.setAccessible(true);
            handshakeHashFinMd = handshakeHash.getType().getDeclaredField("finMD");
            handshakeHashFinMd.setAccessible(true);
        } catch (Exception e) {
            // Any missing class/member or denied access disables the hack entirely.
            UndertowLogger.ROOT_LOGGER.debug("JDK8 ALPN Hack failed ", e);
            enabled = false;
            handshaker = null;
            handshakeHash = null;
            handshakeHashVersion = null;
            handshakeHashUpdate = null;
            handshakeHashProtocolDetermined = null;
            handshakeHashData = null;
            handshakeHashFinMd = null;
            protocolVersion = null;
            sslEngineImplClass = null;
        }
        ENABLED = enabled && !Boolean.getBoolean("io.undertow.disable-jdk8-alpn");
        HANDSHAKER = handshaker;
        HANDSHAKE_HASH = handshakeHash;
        HANDSHAKE_HASH_PROTOCOL_DETERMINED = handshakeHashProtocolDetermined;
        HANDSHAKE_HASH_VERSION = handshakeHashVersion;
        HANDSHAKE_HASH_UPDATE = handshakeHashUpdate;
        HANDSHAKE_HASH_DATA = handshakeHashData;
        HANDSHAKE_HASH_FIN_MD = handshakeHashFinMd;
        HANDSHAKER_PROTOCOL_VERSION = protocolVersion;
        SSL_ENGINE_IMPL_CLASS = sslEngineImplClass;
    }

    /** TLS record header length: content type (1), version (2), length (2). */
    private static final int RECORD_HEADER_LENGTH = 5;

    private final SSLEngine delegate;
    /** Set once the peer's hello has been inspected in {@link #unwrap}. */
    private boolean unwrapHelloSeen = false;
    /** Set once {@link #wrap} has produced any output. */
    private boolean ourHelloSent = false;
    private ALPNHackServerByteArrayOutputStream alpnHackServerByteArrayOutputStream;
    private ALPNHackClientByteArrayOutputStream alpnHackClientByteArrayOutputStream;
    private List<String> applicationProtocols;
    private String selectedApplicationProtocol;
    /** Rewritten output that did not fit in the caller's buffer; drained by the next wrap. */
    private ByteBuffer bufferedWrapData;

    /**
     * @param delegate the engine to wrap; expected to be an instance of the JDK's internal
     *                 {@code SSLEngineImpl} (see {@link #isEnabled(SSLEngine)})
     */
    public ALPNHackSSLEngine(SSLEngine delegate) {
        this.delegate = delegate;
    }

    /**
     * Returns whether the hack can be applied to {@code engine}. This requires that the
     * reflective setup succeeded, that the hack was not disabled by system property, and that
     * the engine is an instance of the JDK's internal {@code SSLEngineImpl}.
     */
    public static boolean isEnabled(SSLEngine engine) {
        if (!ENABLED) {
            return false;
        }
        return SSL_ENGINE_IMPL_CLASS.isAssignableFrom(engine.getClass());
    }

    /**
     * Wraps outbound data through the delegate.
     * <p>
     * On the first output-producing call, the hello just produced is rewritten to carry the
     * ALPN information. If a rewritten record does not fit in {@code byteBuffer}, the excess is
     * held back and {@code NEED_WRAP} is reported until it has been drained.
     */
    @Override
    public SSLEngineResult wrap(ByteBuffer[] byteBuffers, int i, int i1, ByteBuffer byteBuffer) throws SSLException {
        if (bufferedWrapData != null) {
            return flushBufferedWrapData(byteBuffer);
        }
        final int pos = byteBuffer.position();
        final int limit = byteBuffer.limit();
        SSLEngineResult res = delegate.wrap(byteBuffers, i, i1, byteBuffer);
        if (!ourHelloSent && res.bytesProduced() > 0) {
            if (delegate.getUseClientMode() && applicationProtocols != null && !applicationProtocols.isEmpty()) {
                ourHelloSent = true;
                res = rewriteOutgoingClientHello(res, byteBuffer, pos, limit);
            } else if (!getUseClientMode()
                    && selectedApplicationProtocol != null
                    && alpnHackServerByteArrayOutputStream != null) {
                res = rewriteOutgoingServerHello(res, byteBuffer, pos, limit);
            }
        }
        if (res.bytesProduced() > 0) {
            ourHelloSent = true;
        }
        return res;
    }

    /**
     * Copies as much of {@link #bufferedWrapData} into {@code dst} as fits.
     * Returns {@code BUFFER_OVERFLOW} rather than throwing when {@code dst} has no room.
     */
    private SSLEngineResult flushBufferedWrapData(ByteBuffer dst) {
        int pending = bufferedWrapData.remaining();
        int room = dst.remaining();
        if (pending <= room) {
            dst.put(bufferedWrapData);
            bufferedWrapData = null;
            return new SSLEngineResult(SSLEngineResult.Status.OK, SSLEngineResult.HandshakeStatus.NEED_WRAP, 0, pending);
        }
        if (room == 0) {
            return new SSLEngineResult(SSLEngineResult.Status.BUFFER_OVERFLOW, SSLEngineResult.HandshakeStatus.NEED_WRAP, 0, 0);
        }
        ByteBuffer chunk = bufferedWrapData.duplicate();
        chunk.limit(chunk.position() + room);
        dst.put(chunk);
        bufferedWrapData.position(chunk.position());
        return new SSLEngineResult(SSLEngineResult.Status.OK, SSLEngineResult.HandshakeStatus.NEED_WRAP, 0, room);
    }

    /**
     * Client side: rewrites the ClientHello the delegate just wrote into {@code dst[pos, position)}
     * so it advertises {@link #applicationProtocols}. Also installs the client hash stream that
     * is told about the rewritten message.
     *
     * @return the result to report; {@code res} unchanged when nothing was rewritten
     */
    private SSLEngineResult rewriteOutgoingClientHello(SSLEngineResult res, ByteBuffer dst, int pos, int limit) throws SSLException {
        alpnHackClientByteArrayOutputStream = replaceClientByteOutput(delegate);
        if (alpnHackClientByteArrayOutputStream == null) {
            // Reflection failed (already logged); send the delegate's hello untouched.
            return res;
        }
        ByteBuffer hello = producedRegion(dst, pos);
        byte[] data = new byte[hello.remaining()];
        hello.get(data);
        byte[] newData = ALPNHackClientHelloExplorer.rewriteClientHello(data, applicationProtocols);
        if (newData == null) {
            return res;
        }
        // The stream is given the handshake message without its TLS record header.
        byte[] clientHelloMessage = new byte[newData.length - RECORD_HEADER_LENGTH];
        System.arraycopy(newData, RECORD_HEADER_LENGTH, clientHelloMessage, 0, clientHelloMessage.length);
        alpnHackClientByteArrayOutputStream.setSentClientHello(clientHelloMessage);
        return writeRewrittenOutput(res, dst, pos, limit, ByteBuffer.wrap(newData));
    }

    /**
     * Server side: replaces the ServerHello records the delegate just wrote into
     * {@code dst[pos, position)} with the ServerHello supplied by the server hash stream.
     *
     * @return the result to report; {@code res} unchanged when no replacement hello exists
     */
    private SSLEngineResult rewriteOutgoingServerHello(SSLEngineResult res, ByteBuffer dst, int pos, int limit) throws SSLException {
        byte[] newServerHello = alpnHackServerByteArrayOutputStream.getServerHello();
        if (newServerHello == null) {
            return res;
        }
        List<ByteBuffer> records = ALPNHackServerHelloExplorer.extractRecords(producedRegion(dst, pos));
        ByteBuffer newData = ALPNHackServerHelloExplorer.createNewOutputRecords(newServerHello, records);
        return writeRewrittenOutput(res, dst, pos, limit, newData);
    }

    /**
     * Overwrites {@code dst} from {@code pos} (bounded by the caller's original {@code limit})
     * with {@code newData}. Whatever does not fit is kept in {@link #bufferedWrapData}, and
     * {@code NEED_WRAP} is reported so that the caller comes back to drain it.
     */
    private SSLEngineResult writeRewrittenOutput(SSLEngineResult res, ByteBuffer dst, int pos, int limit, ByteBuffer newData) {
        dst.limit(limit);
        dst.position(pos);
        if (newData.remaining() <= dst.remaining()) {
            int produced = newData.remaining();
            dst.put(newData);
            return new SSLEngineResult(res.getStatus(), res.getHandshakeStatus(), res.bytesConsumed(), produced);
        }
        int fullLimit = newData.limit();
        newData.limit(newData.position() + dst.remaining());
        int produced = newData.remaining();
        dst.put(newData);
        newData.limit(fullLimit);
        bufferedWrapData = newData;
        return new SSLEngineResult(res.getStatus(), SSLEngineResult.HandshakeStatus.NEED_WRAP, res.bytesConsumed(), produced);
    }

    /** Returns a zero-based view of the bytes written to {@code dst} since position {@code start}. */
    private static ByteBuffer producedRegion(ByteBuffer dst, int start) {
        ByteBuffer view = dst.duplicate();
        view.limit(view.position());
        view.position(start);
        return view.slice();
    }

    /**
     * Unwraps inbound data through the delegate.
     * <p>
     * Before that, the peer's first hello is inspected (server side) or has its ALPN extension
     * stripped (client side). Returns {@code BUFFER_UNDERFLOW} when the hello is incomplete.
     */
    @Override
    public SSLEngineResult unwrap(ByteBuffer dataToUnwrap, ByteBuffer[] byteBuffers, int i, int i1) throws SSLException {
        if (!unwrapHelloSeen) {
            if (!delegate.getUseClientMode() && applicationProtocols != null) {
                try {
                    selectProtocolFromClientHello(dataToUnwrap);
                } catch (BufferUnderflowException e) {
                    return bufferUnderflow();
                }
            } else if (delegate.getUseClientMode() && alpnHackClientByteArrayOutputStream != null) {
                if (!dataToUnwrap.hasRemaining()) {
                    return delegate.unwrap(dataToUnwrap, byteBuffers, i, i1);
                }
                try {
                    stripAlpnFromServerHello(dataToUnwrap);
                } catch (BufferUnderflowException e) {
                    return bufferUnderflow();
                }
            }
        }
        SSLEngineResult res = delegate.unwrap(dataToUnwrap, byteBuffers, i, i1);
        if (!delegate.getUseClientMode() && selectedApplicationProtocol != null && alpnHackServerByteArrayOutputStream == null) {
            alpnHackServerByteArrayOutputStream = replaceServerByteOutput(delegate, selectedApplicationProtocol);
        }
        return res;
    }

    /**
     * Server side: reads the protocols offered in the client's hello, without consuming
     * {@code dataToUnwrap}. Selects the first of our {@link #applicationProtocols} that the
     * client also offers.
     *
     * @throws BufferUnderflowException if the hello is not fully available yet
     */
    private void selectProtocolFromClientHello(ByteBuffer dataToUnwrap) throws SSLException {
        List<String> offered = ALPNHackClientHelloExplorer.exploreClientHello(dataToUnwrap.duplicate());
        if (offered != null) {
            for (String protocol : applicationProtocols) {
                if (offered.contains(protocol)) {
                    selectedApplicationProtocol = protocol;
                    break;
                }
            }
        }
        unwrapHelloSeen = true;
    }

    /**
     * Client side: when {@code dataToUnwrap} starts with a handshake record (type 22) of
     * version 3.3, removes the ALPN extension from the ServerHello in the first record and
     * records the selected protocol. {@code dataToUnwrap} is rewritten in place, starting at
     * index 0 and flipped. The original ServerHello message is handed to the client hash stream.
     *
     * @throws BufferUnderflowException if the record data is not fully available yet
     */
    private void stripAlpnFromServerHello(ByteBuffer dataToUnwrap) throws SSLException {
        ByteBuffer header = dataToUnwrap.duplicate();
        int type = header.get();
        int major = header.get();
        int minor = header.get();
        if (type != 22 || major != 3 || minor != 3) {
            return;
        }
        List<ByteBuffer> records = ALPNHackServerHelloExplorer.extractRecords(dataToUnwrap.duplicate());
        if (records.isEmpty()) {
            return;
        }
        ByteBuffer firstRecord = records.get(0);
        final AtomicReference<String> alpnResult = new AtomicReference<>();
        ByteBuffer helloBody = firstRecord.duplicate();
        helloBody.position(firstRecord.position() + RECORD_HEADER_LENGTH);
        ByteBuffer originalHello = helloBody.duplicate();
        byte[] rewrittenHello = ALPNHackServerHelloExplorer.removeAlpnExtensionsFromServerHello(helloBody, alpnResult);
        // helloBody now sits just past the ServerHello; bound the original copy there.
        originalHello.limit(helloBody.position());
        unwrapHelloSeen = true;
        if (rewrittenHello == null) {
            return;
        }
        selectedApplicationProtocol = alpnResult.get();

        // New first-record body = rewritten ServerHello + the rest of the original record.
        byte[] newFirstRecord = new byte[rewrittenHello.length + helloBody.remaining()];
        System.arraycopy(rewrittenHello, 0, newFirstRecord, 0, rewrittenHello.length);
        helloBody.get(newFirstRecord, rewrittenHello.length, helloBody.remaining());

        // Copy everything out of dataToUnwrap before it is overwritten below.
        byte[] originalServerHello = new byte[originalHello.remaining()];
        originalHello.get(originalServerHello);
        ByteBuffer newData = ALPNHackServerHelloExplorer.createNewOutputRecords(newFirstRecord, records);

        dataToUnwrap.clear();
        dataToUnwrap.put(newData);
        dataToUnwrap.flip();
        alpnHackClientByteArrayOutputStream.setReceivedServerHello(originalServerHello);
    }

    /** Result telling the caller to supply more inbound data before retrying. */
    private static SSLEngineResult bufferUnderflow() {
        return new SSLEngineResult(SSLEngineResult.Status.BUFFER_UNDERFLOW, SSLEngineResult.HandshakeStatus.NEED_UNWRAP, 0, 0);
    }

    @Override
    public Runnable getDelegatedTask() {
        return delegate.getDelegatedTask();
    }

    @Override
    public void closeInbound() throws SSLException {
        delegate.closeInbound();
    }

    @Override
    public boolean isInboundDone() {
        return delegate.isInboundDone();
    }

    @Override
    public void closeOutbound() {
        delegate.closeOutbound();
    }

    @Override
    public boolean isOutboundDone() {
        return delegate.isOutboundDone();
    }

    @Override
    public String[] getSupportedCipherSuites() {
        return delegate.getSupportedCipherSuites();
    }

    @Override
    public String[] getEnabledCipherSuites() {
        return delegate.getEnabledCipherSuites();
    }

    @Override
    public void setEnabledCipherSuites(String[] strings) {
        delegate.setEnabledCipherSuites(strings);
    }

    @Override
    public String[] getSupportedProtocols() {
        return delegate.getSupportedProtocols();
    }

    @Override
    public String[] getEnabledProtocols() {
        return delegate.getEnabledProtocols();
    }

    @Override
    public void setEnabledProtocols(String[] strings) {
        delegate.setEnabledProtocols(strings);
    }

    @Override
    public SSLSession getSession() {
        return delegate.getSession();
    }

    @Override
    public void beginHandshake() throws SSLException {
        delegate.beginHandshake();
    }

    /**
     * Reports {@code NEED_WRAP} while rewritten handshake output is still buffered. Otherwise
     * defers to the delegate.
     */
    @Override
    public SSLEngineResult.HandshakeStatus getHandshakeStatus() {
        if (bufferedWrapData != null) {
            return SSLEngineResult.HandshakeStatus.NEED_WRAP;
        }
        return delegate.getHandshakeStatus();
    }

    @Override
    public void setUseClientMode(boolean b) {
        delegate.setUseClientMode(b);
    }

    @Override
    public boolean getUseClientMode() {
        return delegate.getUseClientMode();
    }

    @Override
    public void setNeedClientAuth(boolean b) {
        delegate.setNeedClientAuth(b);
    }

    @Override
    public boolean getNeedClientAuth() {
        return delegate.getNeedClientAuth();
    }

    @Override
    public void setWantClientAuth(boolean b) {
        delegate.setWantClientAuth(b);
    }

    @Override
    public boolean getWantClientAuth() {
        return delegate.getWantClientAuth();
    }

    @Override
    public void setEnableSessionCreation(boolean b) {
        delegate.setEnableSessionCreation(b);
    }

    @Override
    public boolean getEnableSessionCreation() {
        return delegate.getEnableSessionCreation();
    }

    /**
     * Sets the protocols to negotiate, in preference order. As a client these are advertised;
     * as a server the first one also offered by the client is selected.
     */
    public void setApplicationProtocols(List<String> applicationProtocols) {
        this.applicationProtocols = applicationProtocols;
    }

    public List<String> getApplicationProtocols() {
        return applicationProtocols;
    }

    /** Returns the negotiated protocol, or {@code null} if none has been selected (yet). */
    public String getSelectedApplicationProtocol() {
        return selectedApplicationProtocol;
    }

    /**
     * Replaces the engine's handshake-hash data stream with an
     * {@link ALPNHackServerByteArrayOutputStream} that is seeded with the bytes already recorded.
     *
     * @return the installed stream, or {@code null} if reflection failed (logged at debug)
     */
    static ALPNHackServerByteArrayOutputStream replaceServerByteOutput(SSLEngine sslEngine, String selectedAlpnProtocol) {
        try {
            Object handshaker = HANDSHAKER.get(sslEngine);
            Object hash = HANDSHAKE_HASH.get(handshaker);
            ByteArrayOutputStream existing = (ByteArrayOutputStream) HANDSHAKE_HASH_DATA.get(hash);
            ALPNHackServerByteArrayOutputStream out = new ALPNHackServerByteArrayOutputStream(sslEngine, existing.toByteArray(), selectedAlpnProtocol);
            HANDSHAKE_HASH_DATA.set(hash, out);
            return out;
        } catch (Exception e) {
            UndertowLogger.ROOT_LOGGER.debug("Failed to replace hash output stream ", e);
            return null;
        }
    }

    /**
     * Replaces the engine's handshake-hash data stream with a fresh
     * {@link ALPNHackClientByteArrayOutputStream}.
     *
     * @return the installed stream, or {@code null} if reflection failed (logged at debug)
     */
    static ALPNHackClientByteArrayOutputStream replaceClientByteOutput(SSLEngine sslEngine) {
        try {
            Object handshaker = HANDSHAKER.get(sslEngine);
            Object hash = HANDSHAKE_HASH.get(handshaker);
            ALPNHackClientByteArrayOutputStream out = new ALPNHackClientByteArrayOutputStream(sslEngine);
            HANDSHAKE_HASH_DATA.set(hash, out);
            return out;
        } catch (Exception e) {
            UndertowLogger.ROOT_LOGGER.debug("Failed to replace hash output stream ", e);
            return null;
        }
    }

    /**
     * Rebuilds the engine's handshake hash from {@code hashBytes}. This clears {@code data},
     * resets the hash's version to -1, re-applies the handshaker's protocol version, resets the
     * {@code finMD} digest, and then feeds each array to the hash's {@code update} method in order.
     *
     * @throws RuntimeException wrapping any reflective failure
     */
    static void regenerateHashes(SSLEngine sslEngineToHack, ByteArrayOutputStream data, byte[]... hashBytes) {
        try {
            Object handshaker = HANDSHAKER.get(sslEngineToHack);
            Object hash = HANDSHAKE_HASH.get(handshaker);
            data.reset();
            Object protocolVersion = HANDSHAKER_PROTOCOL_VERSION.get(handshaker);
            HANDSHAKE_HASH_VERSION.set(hash, -1);
            HANDSHAKE_HASH_PROTOCOL_DETERMINED.invoke(hash, protocolVersion);
            MessageDigest digest = (MessageDigest) HANDSHAKE_HASH_FIN_MD.get(hash);
            digest.reset();
            for (byte[] b : hashBytes) {
                HANDSHAKE_HASH_UPDATE.invoke(hash, b, 0, b.length);
            }
        } catch (Exception e) {
            throw new RuntimeException("Failed to regenerate SSL handshake hashes", e);
        }
    }
}