package org.glassfish.grizzly.ssl;
import static org.glassfish.grizzly.ssl.SSLUtils.isHandshaking;
import java.io.IOException;
import java.net.InetSocketAddress;
import java.util.LinkedList;
import java.util.List;
import java.util.logging.Filter;
import java.util.logging.Level;
import java.util.logging.Logger;
import javax.net.ssl.SSLEngine;
import org.glassfish.grizzly.Buffer;
import org.glassfish.grizzly.CloseType;
import org.glassfish.grizzly.Closeable;
import org.glassfish.grizzly.CompletionHandler;
import org.glassfish.grizzly.Connection;
import org.glassfish.grizzly.FileTransfer;
import org.glassfish.grizzly.GenericCloseListener;
import org.glassfish.grizzly.Grizzly;
import org.glassfish.grizzly.PendingWriteQueueLimitExceededException;
import org.glassfish.grizzly.attributes.Attribute;
import org.glassfish.grizzly.filterchain.FilterChainContext;
import org.glassfish.grizzly.filterchain.FilterChainContext.Operation;
import org.glassfish.grizzly.filterchain.NextAction;
import org.glassfish.grizzly.utils.Exceptions;
import org.glassfish.grizzly.utils.JdkVersion;
/**
 * SSL/TLS filter that adds client-side handshake initiation on top of {@link SSLBaseFilter}.
 * <p>
 * When an outbound write reaches this filter and the connection's {@link SSLEngine} is missing or
 * still handshaking, the write is queued until the handshake completes or fails. If no handshake is
 * currently tracked for the connection, a client handshake is started first. The number of bytes
 * that may be queued per connection is bounded by {@link #setMaxPendingBytesPerConnection(int)}.
 */
public class SSLFilter extends SSLBaseFilter {
    private static final Logger LOGGER = Grizzly.logger(SSLFilter.class);

    /** Whether the running JDK supports host-aware {@code SSLEngine} creation (1.7.0 or newer). */
    private static final boolean IS_JDK7_OR_HIGHER = JdkVersion.getJdkVersion().compareTo("1.7.0") >= 0;

    /** Per-connection state of an in-flight client handshake, including writes queued behind it. */
    private final Attribute<SSLHandshakeContext> handshakeContextAttr;

    /** Configurator used to create and configure client-side engines. */
    private final SSLEngineConfigurator clientSSLEngineConfigurator;

    /** Fails the tracked handshake when a connection closes before the handshake finishes. */
    private final ConnectionCloseListener closeListener = new ConnectionCloseListener();

    /** Upper bound on bytes queued per connection while a handshake is pending. */
    protected volatile int maxPendingBytes = Integer.MAX_VALUE;

    /**
     * Creates a filter with no explicit server configurator and the default client configurator.
     */
    public SSLFilter() {
        this(null, null);
    }

    /**
     * Creates a filter with the given configurators, passing {@code true} for
     * {@code renegotiateOnClientAuthWant}.
     *
     * @param serverSSLEngineConfigurator configurator forwarded to {@link SSLBaseFilter}
     * @param clientSSLEngineConfigurator configurator for client-side engines; if {@code null}, a
     *        default one built from {@code SSLContextConfigurator.DEFAULT_CONFIG} is used
     */
    public SSLFilter(SSLEngineConfigurator serverSSLEngineConfigurator, SSLEngineConfigurator clientSSLEngineConfigurator) {
        this(serverSSLEngineConfigurator, clientSSLEngineConfigurator, true);
    }

    /**
     * Creates a filter with the given configurators.
     *
     * @param serverSSLEngineConfigurator configurator forwarded to {@link SSLBaseFilter}
     * @param clientSSLEngineConfigurator configurator for client-side engines; if {@code null}, a
     *        default one built from {@code SSLContextConfigurator.DEFAULT_CONFIG} is used
     * @param renegotiateOnClientAuthWant flag forwarded unchanged to {@link SSLBaseFilter}
     */
    public SSLFilter(SSLEngineConfigurator serverSSLEngineConfigurator, SSLEngineConfigurator clientSSLEngineConfigurator,
            boolean renegotiateOnClientAuthWant) {
        super(serverSSLEngineConfigurator, renegotiateOnClientAuthWant);
        if (clientSSLEngineConfigurator == null) {
            this.clientSSLEngineConfigurator = new SSLEngineConfigurator(SSLContextConfigurator.DEFAULT_CONFIG.createSSLContext(true), true, false, false);
        } else {
            this.clientSSLEngineConfigurator = clientSSLEngineConfigurator;
        }
        handshakeContextAttr = Grizzly.DEFAULT_ATTRIBUTE_BUILDER.createAttribute("SSLFilter-SSLHandshakeContextAttr");
    }

    /**
     * Returns the configurator used for client-side engines.
     *
     * @return the client configurator; never {@code null}
     */
    public SSLEngineConfigurator getClientSSLEngineConfigurator() {
        return clientSSLEngineConfigurator;
    }

    /**
     * Passes an outbound message to {@link SSLBaseFilter}'s write handling, or queues it while a
     * handshake is pending.
     * <p>
     * If the connection's engine exists and is not handshaking:
     * <ul>
     * <li>server-mode connections write directly;</li>
     * <li>client-mode connections write directly unless a handshake context is still tracked.</li>
     * </ul>
     * Otherwise, a client handshake is started when the engine is missing or no handshake context
     * is tracked, and the write is queued behind it.
     *
     * @param ctx the filter chain context carrying the outbound message
     * @return the next action for the chain; a suspend action when the write was queued
     * @throws IllegalStateException if the message is a {@link FileTransfer}
     * @throws IOException if starting the handshake fails, the pending-write limit is exceeded, or
     *         a previously failed handshake is reported
     */
    @Override
    public NextAction handleWrite(final FilterChainContext ctx) throws IOException {
        final Connection connection = ctx.getConnection();
        if (ctx.getMessage() instanceof FileTransfer) {
            throw new IllegalStateException("TLS operations not supported with SendFile messages");
        }
        synchronized (connection) {
            final SSLConnectionContext sslCtx = obtainSslConnectionContext(connection);
            final SSLEngine sslEngine = sslCtx.getSslEngine();
            if (sslEngine != null && !isHandshaking(sslEngine)) {
                return sslCtx.isServerMode() ? super.handleWrite(ctx) : accurateWrite(ctx, true);
            }
            if (sslEngine == null || !handshakeContextAttr.isSet(connection)) {
                handshake(connection, null, null, clientSSLEngineConfigurator, ctx, false);
            }
            return accurateWrite(ctx, false);
        }
    }

    /**
     * Returns the maximum number of bytes that may be queued per connection while a handshake is
     * pending.
     *
     * @return the per-connection limit; {@link Integer#MAX_VALUE} by default
     */
    public int getMaxPendingBytesPerConnection() {
        return maxPendingBytes;
    }

    /**
     * Sets the maximum number of bytes that may be queued per connection while a handshake is
     * pending. Writes that would exceed it fail with {@link PendingWriteQueueLimitExceededException}.
     *
     * @param maxPendingBytes the new per-connection limit
     */
    public void setMaxPendingBytesPerConnection(final int maxPendingBytes) {
        this.maxPendingBytes = maxPendingBytes;
    }

    /**
     * Starts a client handshake on {@code connection} using the client configurator.
     *
     * @param connection the connection to secure
     * @param completionHandler notified when the handshake completes or fails; may be {@code null}
     * @throws IOException if the handshake cannot be started
     */
    public void handshake(final Connection connection, final CompletionHandler<SSLEngine> completionHandler) throws IOException {
        handshake(connection, completionHandler, null, clientSSLEngineConfigurator);
    }

    /**
     * Starts a client handshake on {@code connection} using the client configurator.
     *
     * @param connection the connection to secure
     * @param completionHandler notified when the handshake completes or fails; may be {@code null}
     * @param dstAddress passed through for API compatibility; not used by this implementation
     * @throws IOException if the handshake cannot be started
     */
    public void handshake(final Connection connection, final CompletionHandler<SSLEngine> completionHandler, final Object dstAddress) throws IOException {
        handshake(connection, completionHandler, dstAddress, clientSSLEngineConfigurator);
    }

    /**
     * Starts a client handshake on {@code connection} with an explicit configurator, always
     * calling {@link SSLEngine#beginHandshake()}.
     *
     * @param connection the connection to secure
     * @param completionHandler notified when the handshake completes or fails; may be {@code null}
     * @param dstAddress passed through for API compatibility; not used by this implementation
     * @param sslEngineConfigurator configurator used to create or reconfigure the engine
     * @throws IOException if the handshake cannot be started
     */
    public void handshake(final Connection connection, final CompletionHandler<SSLEngine> completionHandler, final Object dstAddress,
            final SSLEngineConfigurator sslEngineConfigurator) throws IOException {
        handshake(connection, completionHandler, dstAddress, sslEngineConfigurator, createContext(connection, Operation.WRITE), true);
    }

    /**
     * Prepares the engine, registers handshake tracking, and performs the first handshake step.
     * <p>
     * A new client engine is created when the connection has none. An existing engine is
     * reconfigured only when it is not already handshaking. The tracked handshake context is
     * replaced with a fresh one bound to {@code completionHandler}.
     *
     * @param connection the connection to secure
     * @param completionHandler notified when the handshake completes or fails; may be {@code null}
     * @param dstAddress not used by this implementation
     * @param sslEngineConfigurator configurator used to create or reconfigure the engine
     * @param context filter chain context used for the handshake step
     * @param forceBeginHandshake if {@code true}, {@link SSLEngine#beginHandshake()} is always
     *        called; otherwise only when the engine's current session is invalid
     * @throws IOException if a handshake step fails
     */
    protected void handshake(final Connection<?> connection, final CompletionHandler<SSLEngine> completionHandler, final Object dstAddress,
            final SSLEngineConfigurator sslEngineConfigurator, final FilterChainContext context, final boolean forceBeginHandshake) throws IOException {
        final SSLConnectionContext sslCtx = obtainSslConnectionContext(connection);
        SSLEngine sslEngine = sslCtx.getSslEngine();
        if (sslEngine == null) {
            sslEngine = createClientSSLEngine(sslCtx, sslEngineConfigurator);
            sslCtx.configure(sslEngine);
        } else if (!isHandshaking(sslEngine)) {
            sslEngineConfigurator.configure(sslEngine);
        }
        notifyHandshakeStart(connection);
        if (forceBeginHandshake || !sslEngine.getSession().isValid()) {
            sslEngine.beginHandshake();
        }
        handshakeContextAttr.set(connection, new SSLHandshakeContext(connection, completionHandler));
        // Ensure a close during the handshake fails the context and releases queued writes.
        connection.addCloseListener(closeListener);
        synchronized (connection) {
            final Buffer buffer = doHandshakeStep(sslCtx, context, null);
            assert buffer == null;
        }
    }

    /**
     * Writes directly when the handshake is done, or queues the write in the handshake context.
     *
     * @param ctx the filter chain context carrying the outbound message
     * @param isHandshakeComplete whether the caller observed a non-handshaking engine
     * @return the result of {@link SSLBaseFilter}'s write handling, or a suspend action if queued
     * @throws IOException if the pending-write limit is exceeded or the handshake already failed
     */
    private NextAction accurateWrite(final FilterChainContext ctx, final boolean isHandshakeComplete) throws IOException {
        final Connection connection = ctx.getConnection();
        SSLHandshakeContext handshakeContext = handshakeContextAttr.get(connection);
        if (isHandshakeComplete && handshakeContext == null) {
            return super.handleWrite(ctx);
        }
        if (handshakeContext == null) {
            handshakeContext = new SSLHandshakeContext(connection, null);
            handshakeContextAttr.set(connection, handshakeContext);
        }
        // add() returns false once the handshake has completed; the write can then go straight through.
        if (!handshakeContext.add(ctx)) {
            return super.handleWrite(ctx);
        }
        return ctx.getSuspendAction();
    }

    /**
     * Completes the tracked handshake context (callback and queued writes), stops tracking it,
     * then delegates to {@link SSLBaseFilter}.
     */
    @Override
    protected void notifyHandshakeComplete(final Connection<?> connection, final SSLEngine sslEngine) {
        final SSLHandshakeContext handshakeContext = handshakeContextAttr.get(connection);
        if (handshakeContext != null) {
            connection.removeCloseListener(closeListener);
            handshakeContext.completed(sslEngine);
            handshakeContextAttr.remove(connection);
        }
        super.notifyHandshakeComplete(connection, sslEngine);
    }

    /**
     * Fails the tracked handshake context, then delegates to {@link SSLBaseFilter}. The failed
     * context stays attached so that later writes on the connection report the error.
     */
    @Override
    protected void notifyHandshakeFailed(Connection connection, Throwable t) {
        final SSLHandshakeContext handshakeContext = handshakeContextAttr.get(connection);
        if (handshakeContext != null) {
            connection.removeCloseListener(closeListener);
            handshakeContext.failed(t);
        }
        super.notifyHandshakeFailed(connection, t);
    }

    /**
     * Delegates to {@link SSLBaseFilter}. If that throws an {@link IOException}, the tracked
     * handshake context (if any) is failed before the exception is rethrown.
     */
    @Override
    protected Buffer doHandshakeStep(final SSLConnectionContext sslCtx, final FilterChainContext ctx, final Buffer inputBuffer, final Buffer tmpAppBuffer0)
            throws IOException {
        try {
            return super.doHandshakeStep(sslCtx, ctx, inputBuffer, tmpAppBuffer0);
        } catch (IOException ioe) {
            final SSLHandshakeContext context = handshakeContextAttr.get(ctx.getConnection());
            if (context != null) {
                context.failed(ioe);
            }
            throw ioe;
        }
    }

    /**
     * Creates a client engine. On JDK 1.7+ the peer host name (if resolvable) is passed with port
     * {@code -1}; on older JDKs the no-arg factory is used.
     *
     * @param sslCtx the connection's SSL context, used to look up the peer address
     * @param sslEngineConfigurator the configurator that creates the engine
     * @return the new engine
     */
    protected SSLEngine createClientSSLEngine(final SSLConnectionContext sslCtx, final SSLEngineConfigurator sslEngineConfigurator) {
        return IS_JDK7_OR_HIGHER ? sslEngineConfigurator.createSSLEngine(HostNameResolver.getPeerHostName(sslCtx.getConnection()), -1)
                : sslEngineConfigurator.createSSLEngine();
    }

    /**
     * Tracks one in-flight handshake on a connection.
     * <p>
     * It holds the user completion handler and the write contexts suspended behind the handshake.
     * The completion and failure paths synchronize on the connection.
     */
    private final class SSLHandshakeContext {
        private CompletionHandler<SSLEngine> completionHandler;
        private final Connection connection;
        /** Lazily created queue of write contexts suspended until the handshake ends. */
        private List<FilterChainContext> pendingWriteContexts;
        /** Total bytes currently queued in {@link #pendingWriteContexts}. */
        private int sizeInBytes = 0;
        /** First failure reported for this handshake; once set, further writes are rejected. */
        private Throwable error;
        private boolean isComplete;

        public SSLHandshakeContext(final Connection connection, final CompletionHandler<SSLEngine> completionHandler) {
            this.connection = connection;
            this.completionHandler = completionHandler;
        }

        /**
         * Queues a write until the handshake ends.
         *
         * @param context the suspended write context; its message must be a {@link Buffer}
         * @return {@code true} if queued, {@code false} if the handshake has already completed
         *         and the caller should write directly
         * @throws IOException if the handshake failed, or (as
         *         {@link PendingWriteQueueLimitExceededException}) if queueing would exceed the
         *         per-connection limit
         */
        public boolean add(final FilterChainContext context) throws IOException {
            if (error != null) {
                throw Exceptions.makeIOException(error);
            }
            if (isComplete) {
                return false;
            }
            final Buffer buffer = context.getMessage();
            final int limit = maxPendingBytes;
            // Sum in long: with the default limit of Integer.MAX_VALUE an int sum could wrap
            // negative, bypass the limit check and corrupt the running total.
            final long newSize = (long) sizeInBytes + buffer.remaining();
            if (newSize > limit) {
                throw new PendingWriteQueueLimitExceededException("Max queued data limit exceeded: " + newSize + '>' + limit);
            }
            // Safe narrowing: newSize <= limit <= Integer.MAX_VALUE.
            sizeInBytes = (int) newSize;
            if (pendingWriteContexts == null) {
                pendingWriteContexts = new LinkedList<>();
            }
            pendingWriteContexts.add(context);
            return true;
        }

        /**
         * Marks the handshake complete, notifies the completion handler once, and resumes queued
         * writes. If any of that throws an {@link Exception}, it is logged at FINE and routed to
         * {@link #failed(Throwable)}.
         *
         * @param engine the engine whose handshake completed
         */
        public void completed(final SSLEngine engine) {
            try {
                synchronized (connection) {
                    isComplete = true;
                    final CompletionHandler<SSLEngine> completionHandlerLocal = completionHandler;
                    completionHandler = null;
                    if (completionHandlerLocal != null) {
                        completionHandlerLocal.completed(engine);
                    }
                    resumePendingWrites();
                }
            } catch (Exception e) {
                LOGGER.log(Level.FINE, "Unexpected SSLHandshakeContext.completed() error", e);
                failed(e);
            }
        }

        /**
         * Records the first failure, notifies the completion handler once, closes the
         * connection, and resumes queued writes so they observe the error.
         * <p>
         * Later calls are ignored. The connection is closed and the queued writes are released
         * even if the completion handler throws.
         *
         * @param throwable the failure cause
         */
        public void failed(final Throwable throwable) {
            synchronized (connection) {
                if (error != null) {
                    return;
                }
                error = throwable;
                final CompletionHandler<SSLEngine> completionHandlerLocal = completionHandler;
                completionHandler = null;
                try {
                    if (completionHandlerLocal != null) {
                        completionHandlerLocal.failed(throwable);
                    }
                } finally {
                    // Without this, a throwing user callback would leave queued write contexts
                    // suspended forever and the connection open.
                    connection.closeWithReason(Exceptions.makeIOException(throwable));
                    resumePendingWrites();
                }
            }
        }

        /**
         * Resumes every queued write context, best-effort, and resets the queue and its byte count.
         */
        private void resumePendingWrites() {
            final List<FilterChainContext> pendingWriteContextsLocal = pendingWriteContexts;
            pendingWriteContexts = null;
            if (pendingWriteContextsLocal == null) {
                return;
            }
            for (final FilterChainContext ctx : pendingWriteContextsLocal) {
                try {
                    ctx.resume();
                } catch (Exception e) {
                    // Best-effort: one failing write must not stop the others from resuming,
                    // but the failure should not disappear silently either.
                    LOGGER.log(Level.FINE, "Unable to resume a write queued behind the SSL handshake", e);
                }
            }
            pendingWriteContextsLocal.clear();
            sizeInBytes = 0;
        }
    }

    /**
     * On connection close, fails the tracked handshake (if any) with an {@link java.io.EOFException}
     * and stops tracking it.
     */
    private final class ConnectionCloseListener implements GenericCloseListener {
        @Override
        public void onClosed(final Closeable closeable, final CloseType type) throws IOException {
            final Connection connection = (Connection) closeable;
            final SSLHandshakeContext handshakeContext = handshakeContextAttr.get(connection);
            if (handshakeContext != null) {
                handshakeContext.failed(new java.io.EOFException());
                handshakeContextAttr.remove(connection);
            }
        }
    }

    /** Resolves the peer host name used when creating client engines. */
    private static class HostNameResolver {
        /**
         * Returns the peer's host string without a reverse DNS lookup.
         *
         * @param connection the connection whose peer is inspected
         * @return the host string if the peer address is an {@link InetSocketAddress}; otherwise
         *         {@code null}
         */
        public static String getPeerHostName(final Connection<?> connection) {
            final Object addr = connection.getPeerAddress();
            return addr instanceof InetSocketAddress ? ((InetSocketAddress) addr).getHostString() : null;
        }
    }
}