package org.apache.http.nio.reactor.ssl;
import java.io.IOException;
import java.net.Socket;
import java.net.SocketAddress;
import java.nio.ByteBuffer;
import java.nio.channels.ByteChannel;
import java.nio.channels.CancelledKeyException;
import java.nio.channels.ClosedChannelException;
import java.nio.channels.SelectionKey;
import java.util.concurrent.atomic.AtomicInteger;
import javax.net.ssl.SSLContext;
import javax.net.ssl.SSLEngine;
import javax.net.ssl.SSLEngineResult;
import javax.net.ssl.SSLEngineResult.HandshakeStatus;
import javax.net.ssl.SSLEngineResult.Status;
import javax.net.ssl.SSLException;
import javax.net.ssl.SSLSession;
import org.apache.http.HttpHost;
import org.apache.http.annotation.Contract;
import org.apache.http.annotation.ThreadingBehavior;
import org.apache.http.nio.reactor.EventMask;
import org.apache.http.nio.reactor.IOSession;
import org.apache.http.nio.reactor.SessionBufferStatus;
import org.apache.http.nio.reactor.SocketAccessor;
import org.apache.http.util.Args;
import org.apache.http.util.Asserts;
/**
 * {@link IOSession} decorator that layers SSL/TLS over an underlying I/O session by means of
 * an {@link SSLEngine}.
 * <p>
 * The application reads and writes plain (decrypted) data through {@link #channel()}.
 * Outgoing plain data is wrapped into an encrypted output buffer that is flushed to the
 * underlying session's channel. Encrypted data read from the underlying channel is unwrapped
 * into a plain input buffer that the application then reads.
 * <p>
 * Transport progress is driven by {@link #isAppInputReady()}, {@link #inboundTransport()},
 * {@link #isAppOutputReady()} and {@link #outboundTransport()}. These are presumably invoked
 * by the I/O event dispatch layer, which is not shown here.
 * <p>
 * State-mutating operations are {@code synchronized} on this instance.
 */
@Contract(threading = ThreadingBehavior.SAFE_CONDITIONAL)
public class SSLIOSession implements IOSession, SessionBufferStatus, SocketAccessor {

    /**
     * Attribute name associated with SSL sessions. It is not referenced within this class
     * itself; callers presumably use it as a session attribute key.
     */
    public static final String SESSION_KEY = "http.session.ssl";

    /** Shared zero-capacity buffer used for "empty" writes and handshake wraps. */
    private static final ByteBuffer EMPTY_BUFFER = ByteBuffer.allocate(0);

    /** Underlying session that carries the encrypted bytes. */
    private final IOSession session;
    private final SSLEngine sslEngine;
    /** Encrypted bytes read from the underlying channel and not yet unwrapped. */
    private final SSLBuffer inEncrypted;
    /** Encrypted bytes produced by wrap() and not yet written to the underlying channel. */
    private final SSLBuffer outEncrypted;
    /** Decrypted bytes produced by unwrap() and not yet consumed by the application. */
    private final SSLBuffer inPlain;
    /** Plain-text channel handed to the application via {@link #channel()}. */
    private final InternalByteChannel channel;
    /** Optional callback for engine setup and session verification; may be {@code null}. */
    private final SSLSetupHandler handler;
    /** Number of closeOutbound() calls made; a non-zero value forces a wrap in doHandshake(). */
    private final AtomicInteger outboundClosedCount;

    /** Event interest requested by the application (distinct from the underlying session's mask). */
    private int appEventMask;
    /** Buffer status of the application layer, if one has been registered. */
    private SessionBufferStatus appBufferStatus;
    /** Set once the underlying channel reports EOF or the engine's inbound side is done. */
    private boolean endOfStream;

    private volatile SSLMode sslMode;
    /** Session state, compared against the IOSession ACTIVE / CLOSING / CLOSED constants. */
    private volatile int status;
    private volatile boolean initialized;

    /**
     * Creates an SSL session that uses a {@link PermanentSSLBufferManagementStrategy}.
     *
     * @param session underlying I/O session; must not be {@code null}
     * @param sslMode client or server mode
     * @param host in client mode, passed to {@link SSLContext#createSSLEngine(String, int)} as the
     *             peer host and port; may be {@code null}
     * @param sslContext context used to create the engine; must not be {@code null}
     * @param handler optional setup/verification handler; may be {@code null}
     */
    public SSLIOSession(
            final IOSession session,
            final SSLMode sslMode,
            final HttpHost host,
            final SSLContext sslContext,
            final SSLSetupHandler handler) {
        this(session, sslMode, host, sslContext, handler, new PermanentSSLBufferManagementStrategy());
    }

    /**
     * Creates an SSL session with an explicit buffer management strategy.
     *
     * @param session underlying I/O session; must not be {@code null}
     * @param sslMode client or server mode
     * @param host in client mode, passed to {@link SSLContext#createSSLEngine(String, int)} as the
     *             peer host and port; may be {@code null}
     * @param sslContext context used to create the engine; must not be {@code null}
     * @param handler optional setup/verification handler; may be {@code null}
     * @param bufferManagementStrategy factory for the internal buffers; must not be {@code null}
     */
    public SSLIOSession(
            final IOSession session,
            final SSLMode sslMode,
            final HttpHost host,
            final SSLContext sslContext,
            final SSLSetupHandler handler,
            final SSLBufferManagementStrategy bufferManagementStrategy) {
        super();
        Args.notNull(session, "IO session");
        Args.notNull(sslContext, "SSL context");
        Args.notNull(bufferManagementStrategy, "Buffer management strategy");
        this.session = session;
        this.sslMode = sslMode;
        this.appEventMask = session.getEventMask();
        this.channel = new InternalByteChannel();
        this.handler = handler;

        if (this.sslMode == SSLMode.CLIENT && host != null) {
            this.sslEngine = sslContext.createSSLEngine(host.getHostName(), host.getPort());
        } else {
            this.sslEngine = sslContext.createSSLEngine();
        }

        // Network (encrypted) buffers are sized to the engine's packet buffer size.
        final int netBuffersize = this.sslEngine.getSession().getPacketBufferSize();
        this.inEncrypted = bufferManagementStrategy.constructBuffer(netBuffersize);
        this.outEncrypted = bufferManagementStrategy.constructBuffer(netBuffersize);

        // Application (plain) buffer is sized to the engine's application buffer size.
        final int appBuffersize = this.sslEngine.getSession().getApplicationBufferSize();
        this.inPlain = bufferManagementStrategy.constructBuffer(appBuffersize);
        this.outboundClosedCount = new AtomicInteger(0);

        // Register as the underlying session's buffer status only after every field is assigned,
        // so the session can never call back into a partially constructed object.
        this.session.setBufferStatus(this);
    }

    /**
     * Creates an SSL session without a target host.
     *
     * @param session underlying I/O session; must not be {@code null}
     * @param sslMode client or server mode
     * @param sslContext context used to create the engine; must not be {@code null}
     * @param handler optional setup/verification handler; may be {@code null}
     */
    public SSLIOSession(
            final IOSession session,
            final SSLMode sslMode,
            final SSLContext sslContext,
            final SSLSetupHandler handler) {
        this(session, sslMode, null, sslContext, handler);
    }

    /** @return the setup handler supplied at construction, possibly {@code null} */
    protected SSLSetupHandler getSSLSetupHandler() {
        return this.handler;
    }

    /** @return {@code true} once {@link #initialize()} has configured the engine */
    public boolean isInitialized() {
        return this.initialized;
    }

    /**
     * Overrides the SSL mode and then initializes the session.
     *
     * @param sslMode the mode to use
     * @throws SSLException if engine setup or the initial handshake step fails
     * @deprecated supply the mode to the constructor and call {@link #initialize()}.
     */
    @Deprecated
    public synchronized void initialize(final SSLMode sslMode) throws SSLException {
        this.sslMode = sslMode;
        initialize();
    }

    /**
     * Configures the engine's client/server mode and invokes the setup handler. Then begins the
     * handshake and advances it as far as currently buffered data allows.
     * <p>
     * Fails via {@code Asserts.check} if already initialized. Does nothing if the session is
     * already closing or closed.
     *
     * @throws SSLException if the setup handler or the handshake fails
     */
    public synchronized void initialize() throws SSLException {
        Asserts.check(!this.initialized, "SSL I/O session already initialized");
        if (this.status >= IOSession.CLOSING) {
            return;
        }
        switch (this.sslMode) {
        case CLIENT:
            this.sslEngine.setUseClientMode(true);
            break;
        case SERVER:
            this.sslEngine.setUseClientMode(false);
            break;
        }
        if (this.handler != null) {
            try {
                this.handler.initalize(this.sslEngine);
            } catch (final RuntimeException ex) {
                throw convert(ex);
            }
        }
        this.initialized = true;
        this.sslEngine.beginHandshake();

        this.inEncrypted.release();
        this.outEncrypted.release();
        this.inPlain.release();

        doHandshake();
    }

    /** @return the engine's current SSL session */
    public synchronized SSLSession getSSLSession() {
        return this.sslEngine.getSession();
    }

    /**
     * Converts an unchecked engine failure into an {@link SSLException}. The exception's cause is
     * wrapped when present; otherwise the exception itself is wrapped.
     */
    private SSLException convert(final RuntimeException ex) {
        Throwable cause = ex.getCause();
        if (cause == null) {
            cause = ex;
        }
        return new SSLException(cause);
    }

    /** {@link SSLEngine#wrap} with runtime failures converted to {@link SSLException}. */
    private SSLEngineResult doWrap(final ByteBuffer src, final ByteBuffer dst) throws SSLException {
        try {
            return this.sslEngine.wrap(src, dst);
        } catch (final RuntimeException ex) {
            throw convert(ex);
        }
    }

    /** {@link SSLEngine#unwrap} with runtime failures converted to {@link SSLException}. */
    private SSLEngineResult doUnwrap(final ByteBuffer src, final ByteBuffer dst) throws SSLException {
        try {
            return this.sslEngine.unwrap(src, dst);
        } catch (final RuntimeException ex) {
            throw convert(ex);
        }
    }

    /** Runs one delegated engine task (if any) on the calling thread. */
    private void doRunTask() throws SSLException {
        try {
            final Runnable r = this.sslEngine.getDelegatedTask();
            if (r != null) {
                r.run();
            }
        } catch (final RuntimeException ex) {
            throw convert(ex);
        }
    }

    /**
     * Advances the handshake using only data already buffered. It stops when the engine needs
     * more network input or output space, or is no longer handshaking. If the last wrap/unwrap
     * reported FINISHED, the setup handler is asked to verify the session.
     */
    private void doHandshake() throws SSLException {
        boolean handshaking = true;

        SSLEngineResult result = null;
        while (handshaking) {
            HandshakeStatus handshakeStatus = this.sslEngine.getHandshakeStatus();

            // After closeOutbound() has been called, force a wrap even if the engine reports
            // NOT_HANDSHAKING, so that pending close data gets produced. This is presumably a
            // work-around for engines that do not transition to NEED_WRAP on closeOutbound().
            if (handshakeStatus == HandshakeStatus.NOT_HANDSHAKING && outboundClosedCount.get() > 0) {
                handshakeStatus = HandshakeStatus.NEED_WRAP;
            }

            switch (handshakeStatus) {
            case NEED_WRAP:
                // Produce outgoing handshake data; there is no application data to send.
                final ByteBuffer outEncryptedBuf = this.outEncrypted.acquire();
                result = doWrap(EMPTY_BUFFER, outEncryptedBuf);

                if (result.getStatus() != Status.OK || result.getHandshakeStatus() == HandshakeStatus.NEED_WRAP) {
                    handshaking = false;
                }
                break;
            case NEED_UNWRAP:
                // Consume incoming handshake data.
                final ByteBuffer inEncryptedBuf = this.inEncrypted.acquire();
                final ByteBuffer inPlainBuf = this.inPlain.acquire();

                inEncryptedBuf.flip();
                try {
                    result = doUnwrap(inEncryptedBuf, inPlainBuf);
                } finally {
                    inEncryptedBuf.compact();
                }

                try {
                    // A full input buffer that still cannot satisfy NEED_UNWRAP can never make progress.
                    if (!inEncryptedBuf.hasRemaining() && result.getHandshakeStatus() == HandshakeStatus.NEED_UNWRAP) {
                        throw new SSLException("Input buffer is full");
                    }
                } finally {
                    if (inEncryptedBuf.position() == 0) {
                        this.inEncrypted.release();
                    }
                }

                if (this.status >= IOSession.CLOSING) {
                    this.inPlain.release();
                }
                if (result.getStatus() != Status.OK) {
                    handshaking = false;
                }
                break;
            case NEED_TASK:
                doRunTask();
                break;
            case NOT_HANDSHAKING:
                handshaking = false;
                break;
            case FINISHED:
                break;
            }
        }

        // FINISHED is only ever reported by a wrap()/unwrap() result, never by getHandshakeStatus().
        if (result != null && result.getHandshakeStatus() == HandshakeStatus.FINISHED) {
            if (this.handler != null) {
                this.handler.verify(this.session, this.sslEngine.getSession());
            }
        }
    }

    /**
     * Advances the close state machine and recomputes the underlying session's event mask from
     * the handshake state, the application's interest and the buffer contents.
     */
    private void updateEventMask() {
        // Graceful termination: peer EOF or inbound done starts closing.
        if (this.status == ACTIVE
                && (this.endOfStream || this.sslEngine.isInboundDone())) {
            this.status = CLOSING;
        }
        // Once all pending encrypted output is flushed, close the outbound side.
        if (this.status == CLOSING && !this.outEncrypted.hasData()) {
            this.sslEngine.closeOutbound();
            this.outboundClosedCount.incrementAndGet();
        }
        // Both directions are done and no input is left for the application: fully closed.
        // A missing app buffer status means there is no app-buffered input (as in the checks
        // below); requiring it to be non-null would leave such sessions stuck in CLOSING.
        if (this.status == CLOSING && this.sslEngine.isOutboundDone()
                && (this.endOfStream || this.sslEngine.isInboundDone())
                && !this.inPlain.hasData()
                && (this.appBufferStatus == null || !this.appBufferStatus.hasBufferedInput())) {
            this.status = CLOSED;
        }
        // Abnormal termination: the peer hit EOF while the engine still awaits handshake data.
        if (this.status <= CLOSING && this.endOfStream
                && this.sslEngine.getHandshakeStatus() == HandshakeStatus.NEED_UNWRAP) {
            this.status = CLOSED;
        }
        if (this.status == CLOSED) {
            this.session.close();
            return;
        }

        final int oldMask = this.session.getEventMask();
        int newMask = oldMask;
        switch (this.sslEngine.getHandshakeStatus()) {
        case NEED_WRAP:
            newMask = EventMask.READ_WRITE;
            break;
        case NEED_UNWRAP:
            newMask = EventMask.READ;
            break;
        case NOT_HANDSHAKING:
            newMask = this.appEventMask;
            break;
        case NEED_TASK:
            break;
        case FINISHED:
            break;
        }

        // Stop reading after EOF once nothing is left for the app; keep reading while closing.
        if (this.endOfStream &&
                !this.inPlain.hasData() &&
                (this.appBufferStatus == null || !this.appBufferStatus.hasBufferedInput())) {
            newMask = newMask & ~EventMask.READ;
        } else if (this.status == CLOSING) {
            newMask = newMask | EventMask.READ;
        }

        // Pending encrypted output always needs WRITE interest.
        if (this.outEncrypted.hasData()) {
            newMask = newMask | EventMask.WRITE;
        } else if (this.sslEngine.isOutboundDone()) {
            newMask = newMask & ~EventMask.WRITE;
        }

        if (oldMask != newMask) {
            this.session.setEventMask(newMask);
        }
    }

    /**
     * Writes pending encrypted output to the underlying channel.
     *
     * @return bytes written by the underlying channel
     */
    private int sendEncryptedData() throws IOException {
        if (!this.outEncrypted.hasData()) {
            // Still issue an (empty) write, presumably so the underlying channel's own checks run,
            // without having to acquire and release an empty buffer.
            return this.session.channel().write(EMPTY_BUFFER);
        }

        final ByteBuffer outEncryptedBuf = this.outEncrypted.acquire();

        final int bytesWritten;
        outEncryptedBuf.flip();
        try {
            bytesWritten = this.session.channel().write(outEncryptedBuf);
        } finally {
            outEncryptedBuf.compact();
        }

        if (outEncryptedBuf.position() == 0) {
            this.outEncrypted.release();
        }
        return bytesWritten;
    }

    /**
     * Reads encrypted bytes from the underlying channel into the encrypted input buffer.
     *
     * @return bytes read, or -1 at end of stream (which also sets {@code endOfStream})
     */
    private int receiveEncryptedData() throws IOException {
        if (this.endOfStream) {
            return -1;
        }

        final ByteBuffer inEncryptedBuf = this.inEncrypted.acquire();

        final int bytesRead = this.session.channel().read(inEncryptedBuf);

        if (inEncryptedBuf.position() == 0) {
            this.inEncrypted.release();
        }
        if (bytesRead == -1) {
            this.endOfStream = true;
        }
        return bytesRead;
    }

    /**
     * Unwraps as much buffered encrypted input as possible into the plain input buffer.
     *
     * @return {@code true} if at least one unwrap completed with status OK
     * @throws SSLException if the input buffer is full yet the engine still needs input, or
     *         the stream ended in the middle of an SSL record
     */
    private boolean decryptData() throws SSLException {
        boolean decrypted = false;
        while (this.inEncrypted.hasData()) {
            final ByteBuffer inEncryptedBuf = this.inEncrypted.acquire();
            final ByteBuffer inPlainBuf = this.inPlain.acquire();

            final SSLEngineResult result;
            inEncryptedBuf.flip();
            try {
                result = doUnwrap(inEncryptedBuf, inPlainBuf);
            } finally {
                inEncryptedBuf.compact();
            }

            try {
                if (!inEncryptedBuf.hasRemaining() && result.getHandshakeStatus() == HandshakeStatus.NEED_UNWRAP) {
                    throw new SSLException("Unable to complete SSL handshake");
                }
                final Status status = result.getStatus();
                if (status == Status.OK) {
                    decrypted = true;
                } else {
                    if (status == Status.BUFFER_UNDERFLOW && this.endOfStream) {
                        throw new SSLException("Unable to decrypt incoming data due to unexpected end of stream");
                    }
                    break;
                }
            } finally {
                if (inEncryptedBuf.position() == 0) {
                    this.inEncrypted.release();
                }
            }
        }
        if (this.sslEngine.isInboundDone()) {
            this.endOfStream = true;
        }
        return decrypted;
    }

    /**
     * Pulls encrypted input from the network, advances the handshake and, once the engine is not
     * handshaking, decrypts buffered input.
     *
     * @return {@code true} if the application wants to read and there is plain data,
     *         application-buffered input, or an end of stream while the session is still ACTIVE
     */
    public synchronized boolean isAppInputReady() throws IOException {
        do {
            receiveEncryptedData();
            doHandshake();
            final HandshakeStatus status = this.sslEngine.getHandshakeStatus();
            if (status == HandshakeStatus.NOT_HANDSHAKING || status == HandshakeStatus.FINISHED) {
                decryptData();
            }
        } while (this.sslEngine.getHandshakeStatus() == HandshakeStatus.NEED_TASK);
        return (this.appEventMask & SelectionKey.OP_READ) > 0
            && (this.inPlain.hasData()
                    || (this.appBufferStatus != null && this.appBufferStatus.hasBufferedInput())
                    || (this.endOfStream && this.status == ACTIVE));
    }

    /**
     * @return {@code true} if the application wants to write, the session is ACTIVE and the
     *         engine is not handshaking
     */
    public synchronized boolean isAppOutputReady() throws IOException {
        return (this.appEventMask & SelectionKey.OP_WRITE) > 0
            && this.status == ACTIVE
            && this.sslEngine.getHandshakeStatus() == HandshakeStatus.NOT_HANDSHAKING;
    }

    /** Only recomputes the event masks; reading and decryption happen in {@link #isAppInputReady()}. */
    public synchronized void inboundTransport() throws IOException {
        updateEventMask();
    }

    /** Flushes encrypted output, advances the handshake and recomputes the event masks. */
    public synchronized void outboundTransport() throws IOException {
        sendEncryptedData();
        doHandshake();
        updateEventMask();
    }

    /** @return whether the engine's inbound side is done */
    public synchronized boolean isInboundDone() {
        return this.sslEngine.isInboundDone();
    }

    /** @return whether the engine's outbound side is done */
    public synchronized boolean isOutboundDone() {
        return this.sslEngine.isOutboundDone();
    }

    /**
     * Encrypts as much of {@code src} as fits into the encrypted output buffer. The bytes reach
     * the network later, in {@link #sendEncryptedData()}.
     *
     * @return number of plain bytes consumed from {@code src}
     * @throws ClosedChannelException if the session is not ACTIVE
     */
    private synchronized int writePlain(final ByteBuffer src) throws IOException {
        Args.notNull(src, "Byte buffer");
        if (this.status != ACTIVE) {
            throw new ClosedChannelException();
        }
        final ByteBuffer outEncryptedBuf = this.outEncrypted.acquire();
        final SSLEngineResult result = doWrap(src, outEncryptedBuf);
        if (result.getStatus() == Status.CLOSED) {
            this.status = CLOSED;
        }
        return result.bytesConsumed();
    }

    /**
     * Copies decrypted bytes into {@code dst}.
     *
     * @return bytes copied, 0 if nothing is available yet, or -1 at end of stream
     */
    private synchronized int readPlain(final ByteBuffer dst) {
        Args.notNull(dst, "Byte buffer");
        if (!this.inPlain.hasData()) {
            return this.endOfStream ? -1 : 0;
        }
        final ByteBuffer inPlainBuf = this.inPlain.acquire();

        inPlainBuf.flip();
        final int n = Math.min(inPlainBuf.remaining(), dst.remaining());
        final int savedLimit = inPlainBuf.limit();
        try {
            if (n > 0) {
                // Narrow the source window so the bulk put transfers exactly n bytes.
                inPlainBuf.limit(inPlainBuf.position() + n);
                dst.put(inPlainBuf);
            }
        } finally {
            // Restore the full window and switch back to fill mode even if put() fails.
            inPlainBuf.limit(savedLimit);
            inPlainBuf.compact();
        }

        if (inPlainBuf.position() == 0) {
            this.inPlain.release();
        }
        return n;
    }

    /**
     * Starts a graceful SSL close. Sets a finite socket timeout if none is set, so waiting for
     * the peer cannot block forever, then recomputes the event masks. If the selection key has
     * already been cancelled, falls back to {@link #shutdown()}.
     */
    @Override
    public synchronized void close() {
        if (this.status >= CLOSING) {
            return;
        }
        this.status = CLOSING;
        if (this.session.getSocketTimeout() == 0) {
            // NOTE(review): IOSession timeouts are presumably in milliseconds — confirm.
            this.session.setSocketTimeout(1000);
        }
        try {
            updateEventMask();
        } catch (final CancelledKeyException ex) {
            shutdown();
        }
    }

    /** Immediately closes the underlying session and releases all internal buffers. */
    @Override
    public synchronized void shutdown() {
        if (this.status == CLOSED) {
            return;
        }
        this.status = CLOSED;
        this.session.shutdown();

        this.inEncrypted.release();
        this.outEncrypted.release();
        this.inPlain.release();
    }

    @Override
    public int getStatus() {
        return this.status;
    }

    /** @return {@code true} if this session is closing/closed or the underlying session is closed */
    @Override
    public boolean isClosed() {
        return this.status >= CLOSING || this.session.isClosed();
    }

    /** @return the plain-text channel through which the application reads and writes */
    @Override
    public ByteChannel channel() {
        return this.channel;
    }

    @Override
    public SocketAddress getLocalAddress() {
        return this.session.getLocalAddress();
    }

    @Override
    public SocketAddress getRemoteAddress() {
        return this.session.getRemoteAddress();
    }

    /** @return the application's event interest (not the underlying session's mask) */
    @Override
    public synchronized int getEventMask() {
        return this.appEventMask;
    }

    @Override
    public synchronized void setEventMask(final int ops) {
        this.appEventMask = ops;
        updateEventMask();
    }

    @Override
    public synchronized void setEvent(final int op) {
        this.appEventMask = this.appEventMask | op;
        updateEventMask();
    }

    @Override
    public synchronized void clearEvent(final int op) {
        this.appEventMask = this.appEventMask & ~op;
        updateEventMask();
    }

    @Override
    public int getSocketTimeout() {
        return this.session.getSocketTimeout();
    }

    @Override
    public void setSocketTimeout(final int timeout) {
        this.session.setSocketTimeout(timeout);
    }

    /** @return {@code true} if the application layer, encrypted input or plain input holds data */
    @Override
    public synchronized boolean hasBufferedInput() {
        return (this.appBufferStatus != null && this.appBufferStatus.hasBufferedInput())
            || this.inEncrypted.hasData()
            || this.inPlain.hasData();
    }

    /** @return {@code true} if the application layer or encrypted output holds data */
    @Override
    public synchronized boolean hasBufferedOutput() {
        return (this.appBufferStatus != null && this.appBufferStatus.hasBufferedOutput())
            || this.outEncrypted.hasData();
    }

    /** Registers the application layer's buffer status; {@code null} means no app buffering. */
    @Override
    public synchronized void setBufferStatus(final SessionBufferStatus status) {
        this.appBufferStatus = status;
    }

    @Override
    public Object getAttribute(final String name) {
        return this.session.getAttribute(name);
    }

    @Override
    public Object removeAttribute(final String name) {
        return this.session.removeAttribute(name);
    }

    @Override
    public void setAttribute(final String name, final Object obj) {
        this.session.setAttribute(name, obj);
    }

    /** Appends 'r' and/or 'w' for the read/write bits of {@code ops}. */
    private static void formatOps(final StringBuilder buffer, final int ops) {
        if ((ops & SelectionKey.OP_READ) > 0) {
            buffer.append('r');
        }
        if ((ops & SelectionKey.OP_WRITE) > 0) {
            buffer.append('w');
        }
    }

    /**
     * Diagnostic representation:
     * {@code session[status][ops][handshake][flags...][inEncrypted][inPlain][outEncrypted]}.
     */
    @Override
    public String toString() {
        final StringBuilder buffer = new StringBuilder();
        buffer.append(this.session);
        buffer.append("[");
        switch (this.status) {
        case ACTIVE:
            buffer.append("ACTIVE");
            break;
        case CLOSING:
            buffer.append("CLOSING");
            break;
        case CLOSED:
            buffer.append("CLOSED");
            break;
        }
        buffer.append("][");
        formatOps(buffer, this.appEventMask);
        buffer.append("][");
        buffer.append(this.sslEngine.getHandshakeStatus());
        // Each optional flag opens its own segment; the following "][" closes it.
        if (this.sslEngine.isInboundDone()) {
            buffer.append("][inbound done");
        }
        if (this.sslEngine.isOutboundDone()) {
            buffer.append("][outbound done");
        }
        if (this.endOfStream) {
            buffer.append("][EOF");
        }
        buffer.append("][");
        buffer.append(!this.inEncrypted.hasData() ? 0 : inEncrypted.acquire().position());
        buffer.append("][");
        buffer.append(!this.inPlain.hasData() ? 0 : inPlain.acquire().position());
        buffer.append("][");
        buffer.append(!this.outEncrypted.hasData() ? 0 : outEncrypted.acquire().position());
        buffer.append("]");
        return buffer.toString();
    }

    /** @return the underlying socket if the wrapped session exposes one, otherwise {@code null} */
    @Override
    public Socket getSocket(){
        return this.session instanceof SocketAccessor ? ((SocketAccessor) this.session).getSocket() : null;
    }

    /** Plain-text channel view that delegates to the enclosing session's read/write/close. */
    private class InternalByteChannel implements ByteChannel {

        @Override
        public int write(final ByteBuffer src) throws IOException {
            return SSLIOSession.this.writePlain(src);
        }

        @Override
        public int read(final ByteBuffer dst) throws IOException {
            return SSLIOSession.this.readPlain(dst);
        }

        @Override
        public void close() throws IOException {
            SSLIOSession.this.close();
        }

        @Override
        public boolean isOpen() {
            return !SSLIOSession.this.isClosed();
        }

    }

}