package org.glassfish.grizzly.http2;
import java.io.IOException;
import java.lang.reflect.Method;
import java.util.Map;
import java.util.WeakHashMap;
import java.util.concurrent.locks.ReadWriteLock;
import java.util.concurrent.locks.ReentrantReadWriteLock;
import java.util.logging.Level;
import java.util.logging.Logger;
import javax.net.ssl.SSLEngine;
import org.glassfish.grizzly.CloseListener;
import org.glassfish.grizzly.CloseType;
import org.glassfish.grizzly.Closeable;
import org.glassfish.grizzly.Connection;
import org.glassfish.grizzly.Grizzly;
import org.glassfish.grizzly.Transport;
import org.glassfish.grizzly.npn.AlpnClientNegotiator;
import org.glassfish.grizzly.npn.AlpnServerNegotiator;
import org.glassfish.grizzly.npn.NegotiationSupport;
import org.glassfish.grizzly.ssl.SSLBaseFilter;
import org.glassfish.grizzly.ssl.SSLBaseFilter.HandshakeListener;
import org.glassfish.grizzly.ssl.SSLUtils;
/**
 * Wires TLS ALPN negotiators into Grizzly SSL handshakes.
 *
 * <p>A single shared instance is created only when {@link AplnExtensionCompatibility} reports that
 * the ALPN extension is available. Otherwise {@link #isEnabled()} returns {@code false} and
 * {@link #getInstance()} throws {@link IllegalStateException}.
 *
 * <p>Negotiators can be registered either per {@link Transport} or per {@link Connection}. When a
 * handshake starts, a connection-level registration is looked up first and the transport-level one
 * is used as a fallback. Registrations are stored in {@link WeakHashMap}s, so the registry itself
 * holds its keys only weakly.
 */
public class AlpnSupport {

    private static final Logger LOGGER = Grizzly.logger(AlpnSupport.class);

    /**
     * Associates an SSL engine with the connection it was started on, while ALPN negotiation is in
     * play. {@link WeakHashMap} is not thread-safe, so every access must synchronize on the map
     * itself.
     */
    private static final Map<SSLEngine, Connection<?>> SSL_TO_CONNECTION_MAP = new WeakHashMap<>();

    /** Shared instance, or {@code null} when ALPN support is not available. */
    private static final AlpnSupport INSTANCE;

    /** Detected runtime ALPN capabilities (e.g. whether a protocol-selector setter exists). */
    private static final AplnExtensionCompatibility COMPATIBILITY;

    static {
        COMPATIBILITY = AplnExtensionCompatibility.getInstance();
        LOGGER.config(() -> "Detected ALPN compatibility info: " + COMPATIBILITY);
        INSTANCE = COMPATIBILITY.isAlpnExtensionAvailable() ? new AlpnSupport() : null;
    }

    /**
     * Tells whether TLS ALPN support is available in this runtime.
     *
     * @return {@code true} if {@link #getInstance()} can be called
     */
    public static boolean isEnabled() {
        return INSTANCE != null;
    }

    /**
     * Returns the shared ALPN support instance.
     *
     * @return the singleton instance
     * @throws IllegalStateException if ALPN is not available (see {@link #isEnabled()})
     */
    public static AlpnSupport getInstance() {
        if (!isEnabled()) {
            throw new IllegalStateException("TLS ALPN is disabled");
        }
        return INSTANCE;
    }

    /**
     * Looks up the connection that was registered for the given SSL engine.
     *
     * @param engine the SSL engine
     * @return the associated connection, or {@code null} if none is registered
     */
    public static Connection<?> getConnection(final SSLEngine engine) {
        synchronized (SSL_TO_CONNECTION_MAP) {
            return SSL_TO_CONNECTION_MAP.get(engine);
        }
    }

    /** Records the engine-to-connection association under the map's monitor. */
    private static void setConnection(final SSLEngine engine, final Connection<?> connection) {
        synchronized (SSL_TO_CONNECTION_MAP) {
            SSL_TO_CONNECTION_MAP.put(engine, connection);
        }
    }

    /**
     * Drops the engine-to-connection association. It takes the same monitor as
     * {@link #getConnection} and {@link #setConnection}. Close listeners may run on a different
     * thread than the handshake, and unsynchronized access to a {@link WeakHashMap} can corrupt it.
     */
    private static void removeConnection(final SSLEngine engine) {
        synchronized (SSL_TO_CONNECTION_MAP) {
            SSL_TO_CONNECTION_MAP.remove(engine);
        }
    }

    /** Server-side negotiators keyed by {@link Transport} or {@link Connection}; guarded by {@link #serverSideLock}. */
    private final Map<Object, AlpnServerNegotiator> serverSideNegotiators = new WeakHashMap<>();
    private final ReadWriteLock serverSideLock = new ReentrantReadWriteLock();

    /** Client-side negotiators keyed by {@link Transport} or {@link Connection}; guarded by {@link #clientSideLock}. */
    private final Map<Object, AlpnClientNegotiator> clientSideNegotiators = new WeakHashMap<>();
    private final ReadWriteLock clientSideLock = new ReentrantReadWriteLock();

    /** Handshake hooks that attach the registered negotiators to each SSL engine. */
    private final HandshakeListener handshakeListener = new HandshakeListener() {

        /**
         * For server-mode engines only: if the SSL implementation exposes a protocol-selector
         * setter, installs the server negotiator on the engine via reflection. Failures are logged
         * rather than propagated.
         */
        @Override
        public void onInit(final Connection<?> connection, final SSLEngine sslEngine) {
            assert sslEngine != null;
            if (sslEngine.getUseClientMode()) {
                return;
            }
            if (!COMPATIBILITY.isProtocolSelectorSetterInImpl()) {
                return;
            }
            final AlpnServerNegotiator negotiator = getServerNegotiator(connection);
            if (negotiator == null) {
                return;
            }
            final Method setter = COMPATIBILITY.getProtocolSelectorSetter(sslEngine);
            try {
                setter.invoke(sslEngine, negotiator);
            } catch (Exception ex) {
                LOGGER.log(Level.SEVERE, "Couldn't execute " + setter, ex);
            }
        }

        /**
         * Registers the applicable client or server negotiator with {@link NegotiationSupport}
         * and records the engine-to-connection mapping. It also adds a close listener that undoes
         * both when the connection closes. Nothing happens if no negotiator is registered.
         */
        @Override
        public void onStart(final Connection<?> connection) {
            final SSLEngine sslEngine = SSLUtils.getSSLEngine(connection);
            assert sslEngine != null;
            if (sslEngine.getUseClientMode()) {
                final AlpnClientNegotiator negotiator = getClientNegotiator(connection);
                if (negotiator != null) {
                    connection.addCloseListener(new CloseListener<Closeable, CloseType>() {
                        @Override
                        public void onClosed(Closeable closeable, CloseType type) throws IOException {
                            NegotiationSupport.removeAlpnClientNegotiator(sslEngine);
                            removeConnection(sslEngine);
                        }
                    });
                    setConnection(sslEngine, connection);
                    NegotiationSupport.addNegotiator(sslEngine, negotiator);
                }
            } else {
                final AlpnServerNegotiator negotiator = getServerNegotiator(connection);
                if (negotiator != null) {
                    connection.addCloseListener(new CloseListener<Closeable, CloseType>() {
                        @Override
                        public void onClosed(Closeable closeable, CloseType type) throws IOException {
                            NegotiationSupport.removeAlpnServerNegotiator(sslEngine);
                            removeConnection(sslEngine);
                        }
                    });
                    setConnection(sslEngine, connection);
                    NegotiationSupport.addNegotiator(sslEngine, negotiator);
                }
            }
        }

        @Override
        public void onComplete(final Connection<?> connection) {
            // Nothing to do: negotiators are already attached by onInit/onStart.
        }

        @Override
        public void onFailure(Connection<?> connection, Throwable t) {
            // Nothing to do: cleanup is handled by the close listener registered in onStart.
        }
    };

    private AlpnSupport() {
    }

    /**
     * Installs this instance's handshake listener on the given SSL filter so that ALPN negotiators
     * are attached during its handshakes.
     *
     * @param sslFilter the SSL filter to configure
     */
    public void configure(final SSLBaseFilter sslFilter) {
        sslFilter.addHandshakeListener(handshakeListener);
    }

    /**
     * Registers a server-side negotiator used for connections of the given transport. A
     * connection-specific registration, if any, takes precedence.
     */
    public void setServerSideNegotiator(final Transport transport, final AlpnServerNegotiator negotiator) {
        putServerSideNegotiator(transport, negotiator);
    }

    /** Registers a server-side negotiator for one specific connection. */
    public void setServerSideNegotiator(final Connection<?> connection, final AlpnServerNegotiator negotiator) {
        putServerSideNegotiator(connection, negotiator);
    }

    /**
     * Registers a client-side negotiator used for connections of the given transport. A
     * connection-specific registration, if any, takes precedence.
     */
    public void setClientSideNegotiator(final Transport transport, final AlpnClientNegotiator negotiator) {
        putClientSideNegotiator(transport, negotiator);
    }

    /** Registers a client-side negotiator for one specific connection. */
    public void setClientSideNegotiator(final Connection<?> connection, final AlpnClientNegotiator negotiator) {
        putClientSideNegotiator(connection, negotiator);
    }

    private void putServerSideNegotiator(final Object object, final AlpnServerNegotiator negotiator) {
        serverSideLock.writeLock().lock();
        try {
            serverSideNegotiators.put(object, negotiator);
        } finally {
            serverSideLock.writeLock().unlock();
        }
    }

    private void putClientSideNegotiator(final Object object, final AlpnClientNegotiator negotiator) {
        clientSideLock.writeLock().lock();
        try {
            clientSideNegotiators.put(object, negotiator);
        } finally {
            clientSideLock.writeLock().unlock();
        }
    }

    /** Resolves the client negotiator: connection-level first, then the connection's transport. */
    private AlpnClientNegotiator getClientNegotiator(Connection<?> connection) {
        clientSideLock.readLock().lock();
        try {
            final AlpnClientNegotiator negotiator = clientSideNegotiators.get(connection);
            return negotiator != null ? negotiator : clientSideNegotiators.get(connection.getTransport());
        } finally {
            clientSideLock.readLock().unlock();
        }
    }

    /** Resolves the server negotiator: connection-level first, then the connection's transport. */
    private AlpnServerNegotiator getServerNegotiator(Connection<?> connection) {
        serverSideLock.readLock().lock();
        try {
            final AlpnServerNegotiator negotiator = serverSideNegotiators.get(connection);
            return negotiator != null ? negotiator : serverSideNegotiators.get(connection.getTransport());
        } finally {
            serverSideLock.readLock().unlock();
        }
    }
}