package org.glassfish.grizzly.http2;
import java.io.IOException;
import java.util.Map;
import java.util.WeakHashMap;
import java.util.concurrent.locks.ReadWriteLock;
import java.util.concurrent.locks.ReentrantReadWriteLock;
import java.util.logging.Level;
import java.util.logging.Logger;
import javax.net.ssl.SSLEngine;
import org.glassfish.grizzly.CloseListener;
import org.glassfish.grizzly.Closeable;
import org.glassfish.grizzly.Connection;
import org.glassfish.grizzly.Grizzly;
import org.glassfish.grizzly.ICloseType;
import org.glassfish.grizzly.Transport;
import org.glassfish.grizzly.npn.AlpnClientNegotiator;
import org.glassfish.grizzly.npn.AlpnServerNegotiator;
import org.glassfish.grizzly.npn.NegotiationSupport;
import org.glassfish.grizzly.ssl.SSLBaseFilter;
import org.glassfish.grizzly.ssl.SSLFilter;
import org.glassfish.grizzly.ssl.SSLUtils;
/**
 * Glue between Grizzly's SSL handshake and the TLS ALPN negotiation support.
 *
 * <p>A single instance is created at class-initialization time if, and only
 * if, the class {@code sun.security.ssl.GrizzlyNPN} can be loaded through the
 * system class loader. Use {@link #isEnabled()} to test for availability
 * before calling {@link #getInstance()}.
 *
 * <p>Negotiators may be registered per {@link Connection} or per
 * {@link Transport}. When a handshake starts, the connection-specific
 * negotiator is preferred and the transport-wide one is used as a fallback.
 */
public class AlpnSupport {
    private static final Logger LOGGER = Grizzly.logger(AlpnSupport.class);

    /**
     * Maps an {@link SSLEngine} to the connection it was obtained from.
     * All access is guarded by synchronizing on the map itself.
     *
     * <p>Entries are removed explicitly when the connection closes. The weak
     * keys alone are not relied upon, because {@code WeakHashMap} holds its
     * values strongly and an entry whose value strongly reaches its own key
     * is never discarded.
     * NOTE(review): the engine is looked up from the connection, so the
     * connection presumably references it strongly - confirm.
     */
    private static final Map<SSLEngine, Connection> SSL_TO_CONNECTION_MAP =
            new WeakHashMap<>();

    /** The singleton, or {@code null} when the ALPN extension class is absent. */
    private static final AlpnSupport INSTANCE;

    static {
        boolean isExtensionFound = false;
        try {
            ClassLoader.getSystemClassLoader().loadClass("sun.security.ssl.GrizzlyNPN");
            isExtensionFound = true;
        } catch (Throwable e) {
            // Deliberately broad: any failure to load the class simply means
            // ALPN support is reported as disabled.
            LOGGER.log(Level.FINE, "TLS ALPN extension is not found:", e);
        }

        INSTANCE = isExtensionFound ? new AlpnSupport() : null;
    }

    /**
     * Tells whether ALPN support is available.
     *
     * @return {@code true} if the ALPN extension class was loaded during
     *         class initialization, {@code false} otherwise
     */
    public static boolean isEnabled() {
        return INSTANCE != null;
    }

    /**
     * Returns the shared {@code AlpnSupport} instance.
     *
     * @return the singleton instance, never {@code null}
     * @throws IllegalStateException if ALPN support is not enabled
     */
    public static AlpnSupport getInstance() {
        if (!isEnabled()) {
            throw new IllegalStateException("TLS ALPN is disabled");
        }

        return INSTANCE;
    }

    /**
     * Looks up the connection recorded for the given engine.
     *
     * @param engine the SSL engine to look up
     * @return the connection recorded for {@code engine} when its handshake
     *         started, or {@code null} if there is no such record (none was
     *         made, or the connection has since been closed)
     */
    public static Connection getConnection(final SSLEngine engine) {
        synchronized (SSL_TO_CONNECTION_MAP) {
            return SSL_TO_CONNECTION_MAP.get(engine);
        }
    }

    /** Records {@code connection} as the owner of {@code engine}. */
    private static void setConnection(final SSLEngine engine,
            final Connection connection) {
        synchronized (SSL_TO_CONNECTION_MAP) {
            SSL_TO_CONNECTION_MAP.put(engine, connection);
        }
    }

    /** Drops the engine-to-connection record for {@code engine}, if any. */
    private static void removeConnection(final SSLEngine engine) {
        synchronized (SSL_TO_CONNECTION_MAP) {
            SSL_TO_CONNECTION_MAP.remove(engine);
        }
    }

    /**
     * Finds the negotiator registered for the connection itself or, failing
     * that, for the connection's transport. The lookup runs under the read
     * lock of {@code lock}.
     *
     * <p>NOTE(review): {@code WeakHashMap} is documented as unsynchronized;
     * confirm that lookups under a shared read lock are acceptable here.
     *
     * @return the matching negotiator, or {@code null} if none is registered
     */
    private static <T> T findNegotiator(final Map<Object, T> negotiators,
            final ReadWriteLock lock, final Connection connection) {
        lock.readLock().lock();
        try {
            final T perConnection = negotiators.get(connection);
            return perConnection != null
                    ? perConnection
                    : negotiators.get(connection.getTransport());
        } finally {
            lock.readLock().unlock();
        }
    }

    /** Stores {@code negotiator} under {@code key} while holding the write lock. */
    private static <T> void putNegotiator(final Map<Object, T> negotiators,
            final ReadWriteLock lock, final Object key, final T negotiator) {
        lock.writeLock().lock();
        try {
            negotiators.put(key, negotiator);
        } finally {
            lock.writeLock().unlock();
        }
    }

    /** Server-side negotiators keyed by Connection or Transport; guarded by serverSideLock. */
    private final Map<Object, AlpnServerNegotiator> serverSideNegotiators =
            new WeakHashMap<>();
    private final ReadWriteLock serverSideLock = new ReentrantReadWriteLock();

    /** Client-side negotiators keyed by Connection or Transport; guarded by clientSideLock. */
    private final Map<Object, AlpnClientNegotiator> clientSideNegotiators =
            new WeakHashMap<>();
    private final ReadWriteLock clientSideLock = new ReentrantReadWriteLock();

    /**
     * Handshake hook that, when a handshake starts, attaches the registered
     * negotiator (if any) to the connection's SSL engine and arranges for all
     * per-engine state to be released when the connection closes.
     */
    private final SSLFilter.HandshakeListener handshakeListener =
            new SSLFilter.HandshakeListener() {

        @Override
        public void onStart(final Connection connection) {
            final SSLEngine sslEngine = SSLUtils.getSSLEngine(connection);
            assert sslEngine != null;

            if (sslEngine.getUseClientMode()) {
                final AlpnClientNegotiator negotiator = findNegotiator(
                        clientSideNegotiators, clientSideLock, connection);
                if (negotiator == null) {
                    return;
                }

                // Register first, then install the cleanup hook, so that a
                // listener which is notified right away still undoes the
                // registration.
                // NOTE(review): presumably the listener fires immediately when
                // the connection is already closed - confirm.
                setConnection(sslEngine, connection);
                NegotiationSupport.addNegotiator(sslEngine, negotiator);
                connection.addCloseListener(new CloseListener() {
                    @Override
                    public void onClosed(Closeable closeable, ICloseType type) throws IOException {
                        NegotiationSupport.removeClientNegotiator(sslEngine);
                        // Release the engine -> connection record as well;
                        // otherwise it would outlive the connection.
                        removeConnection(sslEngine);
                    }
                });
            } else {
                final AlpnServerNegotiator negotiator = findNegotiator(
                        serverSideNegotiators, serverSideLock, connection);
                if (negotiator == null) {
                    return;
                }

                // Same ordering rationale as on the client side.
                setConnection(sslEngine, connection);
                NegotiationSupport.addNegotiator(sslEngine, negotiator);
                connection.addCloseListener(new CloseListener() {
                    @Override
                    public void onClosed(Closeable closeable, ICloseType type) throws IOException {
                        NegotiationSupport.removeServerNegotiator(sslEngine);
                        // Release the engine -> connection record as well;
                        // otherwise it would outlive the connection.
                        removeConnection(sslEngine);
                    }
                });
            }
        }

        @Override
        public void onComplete(final Connection connection) {
            // Nothing to do once the handshake has completed.
        }

        @Override
        public void onFailure(Connection connection, Throwable t) {
            // Nothing to do here; cleanup is driven by the close listener.
        }
    };

    /** Instantiated only by the static initializer. */
    private AlpnSupport() {
    }

    /**
     * Hooks ALPN negotiation into the given SSL filter by adding this
     * object's handshake listener to it.
     *
     * @param sslFilter the filter whose handshakes should trigger negotiation
     */
    public void configure(final SSLBaseFilter sslFilter) {
        sslFilter.addHandshakeListener(handshakeListener);
    }

    /**
     * Registers the server-side negotiator used for connections of the given
     * transport that have no connection-specific negotiator.
     *
     * @param transport the transport to associate the negotiator with
     * @param negotiator the negotiator to use
     */
    public void setServerSideNegotiator(final Transport transport,
            final AlpnServerNegotiator negotiator) {
        putNegotiator(serverSideNegotiators, serverSideLock, transport, negotiator);
    }

    /**
     * Registers the server-side negotiator for one specific connection; it
     * takes precedence over a transport-wide negotiator.
     *
     * @param connection the connection to associate the negotiator with
     * @param negotiator the negotiator to use
     */
    public void setServerSideNegotiator(final Connection connection,
            final AlpnServerNegotiator negotiator) {
        putNegotiator(serverSideNegotiators, serverSideLock, connection, negotiator);
    }

    /**
     * Registers the client-side negotiator used for connections of the given
     * transport that have no connection-specific negotiator.
     *
     * @param transport the transport to associate the negotiator with
     * @param negotiator the negotiator to use
     */
    public void setClientSideNegotiator(final Transport transport,
            final AlpnClientNegotiator negotiator) {
        putNegotiator(clientSideNegotiators, clientSideLock, transport, negotiator);
    }

    /**
     * Registers the client-side negotiator for one specific connection; it
     * takes precedence over a transport-wide negotiator.
     *
     * @param connection the connection to associate the negotiator with
     * @param negotiator the negotiator to use
     */
    public void setClientSideNegotiator(final Connection connection,
            final AlpnClientNegotiator negotiator) {
        putNegotiator(clientSideNegotiators, clientSideLock, connection, negotiator);
    }
}