package org.xnio.ssl;
import static org.xnio.IoUtils.safeClose;
import java.io.IOException;
import java.net.SocketAddress;
import java.nio.ByteBuffer;
import java.util.Set;
import javax.net.ssl.SSLEngine;
import javax.net.ssl.SSLSession;
import org.xnio.ChannelListener;
import org.xnio.ChannelListeners;
import org.xnio.Option;
import org.xnio.Options;
import org.xnio.Pool;
import org.xnio.SslClientAuthMode;
import org.xnio.StreamConnection;
import org.xnio.conduits.StreamSinkConduit;
import org.xnio.conduits.StreamSourceConduit;
/**
 * An {@link SslConnection} that runs JSSE-based TLS on top of an underlying {@link StreamConnection}.
 * <p>
 * The underlying connection's sink and source conduits are wrapped in a
 * {@link JsseSslStreamSinkConduit} and a {@link JsseSslStreamSourceConduit}, both driven by a single
 * shared {@link JsseSslConduitEngine}. When constructed in STARTTLS mode, TLS stays disabled on those
 * conduits until {@link #startHandshake()} is called.
 */
public final class JsseSslStreamConnection extends SslConnection {

    /** Options answered by this class itself rather than by the underlying connection. */
    private static final Set<Option<?>> SUPPORTED_OPTIONS = Option.setBuilder().add(Options.SECURE, Options.SSL_CLIENT_AUTH_MODE).create();

    /** The underlying connection; address, option and close calls are delegated to it. */
    private final StreamConnection connection;
    /** Engine shared by the SSL sink and source conduits. */
    private final JsseSslConduitEngine sslConduitEngine;
    /** {@code true} once TLS is enabled on the SSL conduits. */
    private volatile boolean tls;
    /** Holds the listener invoked by {@link #handleHandshakeFinished()}. */
    private final ChannelListener.SimpleSetter<SslConnection> handshakeSetter = new ChannelListener.SimpleSetter<SslConnection>();

    /**
     * Creates a new SSL connection using {@code JsseXnioSsl.bufferPool} for both socket and
     * application buffers.
     *
     * @param connection the underlying connection to wrap
     * @param sslEngine the SSL engine to use
     * @param startTls {@code true} to leave TLS disabled until {@link #startHandshake()} is called
     */
    public JsseSslStreamConnection(StreamConnection connection, SSLEngine sslEngine, final boolean startTls) {
        this(connection, sslEngine, JsseXnioSsl.bufferPool, JsseXnioSsl.bufferPool, startTls);
    }

    /**
     * Creates a new SSL connection.
     *
     * @param connection the underlying connection to wrap
     * @param sslEngine the SSL engine to use
     * @param socketBufferPool pool passed to the conduit engine for socket-side buffers
     * @param applicationBufferPool pool passed to the conduit engine for application-side buffers
     * @param startTls {@code true} to leave TLS disabled until {@link #startHandshake()} is called
     */
    JsseSslStreamConnection(StreamConnection connection, SSLEngine sslEngine, final Pool<ByteBuffer> socketBufferPool, final Pool<ByteBuffer> applicationBufferPool, final boolean startTls) {
        super(connection.getIoThread());
        this.connection = connection;
        final StreamSinkConduit sinkConduit = connection.getSinkChannel().getConduit();
        final StreamSourceConduit sourceConduit = connection.getSourceChannel().getConduit();
        sslConduitEngine = new JsseSslConduitEngine(this, sinkConduit, sourceConduit, sslEngine, socketBufferPool, applicationBufferPool);
        tls = ! startTls;
        setSinkConduit(new JsseSslStreamSinkConduit(sinkConduit, sslConduitEngine, tls));
        setSourceConduit(new JsseSslStreamSourceConduit(sourceConduit, sslConduitEngine, tls));
    }

    /**
     * Enables TLS on both SSL conduits if it is not yet enabled (STARTTLS mode), then asks the
     * conduit engine to begin a handshake.
     *
     * @throws IOException if beginning the handshake fails
     */
    @Override
    public synchronized void startHandshake() throws IOException {
        if (! tls) {
            tls = true;
            ((JsseSslStreamSourceConduit) getSourceChannel().getConduit()).enableTls();
            ((JsseSslStreamSinkConduit) getSinkChannel().getConduit()).enableTls();
        }
        sslConduitEngine.beginHandshake();
    }

    @Override
    public SocketAddress getPeerAddress() {
        return connection.getPeerAddress();
    }

    @Override
    public SocketAddress getLocalAddress() {
        return connection.getLocalAddress();
    }

    /**
     * Closes this connection.
     * <p>
     * When TLS is enabled, writes are truncated and then reads are terminated on the SSL conduits
     * before the underlying connection is closed. If either step fails, the underlying connection is
     * closed quietly and the original exception is rethrown.
     *
     * @throws IOException if shutting down the SSL conduits or closing the underlying connection fails
     */
    @Override
    protected void closeAction() throws IOException {
        if (tls) {
            try {
                getSinkChannel().getConduit().truncateWrites();
            } catch (IOException e) {
                try {
                    getSourceChannel().getConduit().terminateReads();
                } catch (IOException ignored) {
                    // best effort: the truncateWrites failure is the one reported to the caller
                }
                safeClose(connection);
                throw e;
            }
            try {
                getSourceChannel().getConduit().terminateReads();
            } catch (IOException e) {
                safeClose(connection);
                throw e;
            }
        }
        connection.close();
    }

    /** No-op. */
    @Override
    protected void notifyWriteClosed() {}

    /** No-op. */
    @Override
    protected void notifyReadClosed() {}

    /**
     * Sets an option value.
     * <p>
     * {@link Options#SSL_CLIENT_AUTH_MODE} is applied to the SSL engine and the previous mode is
     * returned. {@link Options#SECURE} is read-only and is rejected. All other options are delegated
     * to the underlying connection.
     *
     * @throws IllegalArgumentException if {@code option} is {@link Options#SECURE}
     * @throws IOException if the underlying connection fails to set the option
     */
    @Override
    public <T> T setOption(final Option<T> option, final T value) throws IllegalArgumentException, IOException {
        if (option == Options.SSL_CLIENT_AUTH_MODE) {
            final SSLEngine engine = sslConduitEngine.getEngine();
            final SslClientAuthMode previous = clientAuthMode(engine);
            // SSLEngine.setNeedClientAuth and setWantClientAuth each override the other's setting,
            // so exactly one of them must be called. Calling setNeedClientAuth(true) followed by
            // setWantClientAuth(false) would silently reset the engine to "no client auth".
            if (value == SslClientAuthMode.REQUIRED) {
                engine.setNeedClientAuth(true);
            } else if (value == SslClientAuthMode.REQUESTED) {
                engine.setWantClientAuth(true);
            } else {
                // NOT_REQUESTED (or null): clears both "need" and "want"
                engine.setNeedClientAuth(false);
            }
            return option.cast(previous);
        } else if (option == Options.SECURE) {
            throw new IllegalArgumentException("Option SECURE is read-only");
        } else {
            return connection.setOption(option, value);
        }
    }

    /**
     * Gets an option value.
     * <p>
     * {@link Options#SSL_CLIENT_AUTH_MODE} is read from the SSL engine. {@link Options#SECURE}
     * reports whether TLS is currently enabled. All other options are delegated to the underlying
     * connection.
     *
     * @throws IOException if the underlying connection fails to read the option
     */
    @Override
    public <T> T getOption(final Option<T> option) throws IOException {
        if (option == Options.SSL_CLIENT_AUTH_MODE) {
            return option.cast(clientAuthMode(sslConduitEngine.getEngine()));
        } else {
            return option == Options.SECURE ? option.cast(Boolean.valueOf(tls)) : connection.getOption(option);
        }
    }

    @Override
    public boolean supportsOption(final Option<?> option) {
        return SUPPORTED_OPTIONS.contains(option) || connection.supportsOption(option);
    }

    /**
     * Returns the SSL session from the conduit engine.
     *
     * @return the session, or {@code null} if TLS has not been enabled yet
     */
    @Override
    public SSLSession getSslSession() {
        return tls ? sslConduitEngine.getSession() : null;
    }

    @Override
    public ChannelListener.Setter<? extends SslConnection> getHandshakeSetter() {
        return handshakeSetter;
    }

    /** Returns the SSL engine used by the conduit engine. */
    SSLEngine getEngine() {
        return sslConduitEngine.getEngine();
    }

    // NOTE(review): readClosed/writeClosed only delegate to super; presumably they are redeclared
    // here so that classes in this package can call them. Confirm before removing.
    @Override
    protected boolean readClosed() {
        return super.readClosed();
    }

    @Override
    protected boolean writeClosed() {
        return super.writeClosed();
    }

    /** Invokes the handshake listener registered through {@link #getHandshakeSetter()}, if any. */
    protected void handleHandshakeFinished() {
        final ChannelListener<? super SslConnection> listener = handshakeSetter.get();
        if (listener == null) {
            return;
        }
        ChannelListeners.<SslConnection>invokeChannelListener(this, listener);
    }

    /**
     * Maps the engine's current client-authentication settings to a {@link SslClientAuthMode}.
     * A "need" setting takes precedence over a "want" setting.
     */
    private static SslClientAuthMode clientAuthMode(final SSLEngine engine) {
        if (engine.getNeedClientAuth()) {
            return SslClientAuthMode.REQUIRED;
        }
        return engine.getWantClientAuth() ? SslClientAuthMode.REQUESTED : SslClientAuthMode.NOT_REQUESTED;
    }
}