package com.mongodb.connection.netty;
import com.mongodb.MongoClientException;
import com.mongodb.MongoException;
import com.mongodb.MongoInternalException;
import com.mongodb.MongoInterruptedException;
import com.mongodb.MongoSocketException;
import com.mongodb.MongoSocketOpenException;
import com.mongodb.MongoSocketReadTimeoutException;
import com.mongodb.ServerAddress;
import com.mongodb.connection.AsyncCompletionHandler;
import com.mongodb.connection.SocketSettings;
import com.mongodb.connection.SslSettings;
import com.mongodb.connection.Stream;
import io.netty.bootstrap.Bootstrap;
import io.netty.buffer.ByteBufAllocator;
import io.netty.buffer.CompositeByteBuf;
import io.netty.buffer.PooledByteBufAllocator;
import io.netty.channel.Channel;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelFutureListener;
import io.netty.channel.ChannelHandler;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInitializer;
import io.netty.channel.ChannelOption;
import io.netty.channel.EventLoopGroup;
import io.netty.channel.SimpleChannelInboundHandler;
import io.netty.channel.socket.SocketChannel;
import io.netty.handler.ssl.SslHandler;
import io.netty.handler.timeout.ReadTimeoutException;
import io.netty.util.concurrent.EventExecutor;
import org.bson.ByteBuf;
import javax.net.ssl.SSLContext;
import javax.net.ssl.SSLEngine;
import javax.net.ssl.SSLParameters;
import java.io.IOException;
import java.net.SocketAddress;
import java.security.NoSuchAlgorithmException;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.Queue;
import java.util.concurrent.CountDownLatch;
import static com.mongodb.internal.connection.SslHelper.enableHostNameVerification;
import static com.mongodb.internal.connection.SslHelper.enableSni;
import static java.util.concurrent.TimeUnit.MILLISECONDS;
final class NettyStream implements Stream {
/** Pipeline name under which the read-timeout handler is registered and later looked up. */
private static final String READ_HANDLER_NAME = "ReadTimeoutHandler";
/** Address of the server this stream connects to. */
private final ServerAddress address;
/** Socket options applied when connecting: connect/read timeouts, keep-alive and buffer sizes. */
private final SocketSettings settings;
/** TLS configuration; when enabled, an SSL handler is placed first in the channel pipeline. */
private final SslSettings sslSettings;
/** Netty event loop group passed to the bootstrap used for connecting. */
private final EventLoopGroup workerGroup;
/** Concrete Netty channel implementation passed to the bootstrap. */
private final Class<? extends SocketChannel> socketChannelClass;
/** Allocator used by getBuffer, for the composite buffers built in readAsync, and as the channel's ALLOCATOR option. */
private final ByteBufAllocator allocator;
/** Set to true by close(). */
private volatile boolean isClosed;
/** The connected channel; assigned when a connect attempt succeeds and reset to null by close(). */
private volatile Channel channel;
/** Inbound buffers received from the channel that have not yet been handed to a reader, in arrival order. */
private final LinkedList<io.netty.buffer.ByteBuf> pendingInboundBuffers = new LinkedList<io.netty.buffer.ByteBuf>();
/** A read request that could not be satisfied yet; re-issued by handleReadResponse when data or an error arrives. */
private volatile PendingReader pendingReader;
/** Failure recorded by handleReadResponse; once set, readAsync reports it instead of returning data. */
private volatile Throwable pendingException;
/**
 * Creates a stream for the given server. No connection is attempted here; the arguments are only stored.
 *
 * @param address            the server to connect to
 * @param settings           socket options (timeouts, keep-alive, buffer sizes)
 * @param sslSettings        TLS configuration
 * @param workerGroup        the Netty event loop group used for the connection
 * @param socketChannelClass the Netty channel implementation to instantiate
 * @param allocator          the allocator used for buffers created by this stream
 */
NettyStream(final ServerAddress address, final SocketSettings settings, final SslSettings sslSettings, final EventLoopGroup workerGroup,
        final Class<? extends SocketChannel> socketChannelClass, final ByteBufAllocator allocator) {
    // Netty plumbing.
    this.workerGroup = workerGroup;
    this.socketChannelClass = socketChannelClass;
    this.allocator = allocator;

    // Connection target and its configuration.
    this.address = address;
    this.settings = settings;
    this.sslSettings = sslSettings;
}
/**
 * Allocates a buffer from this stream's allocator and wraps it for use by the driver.
 *
 * @param size requested as both the initial and the maximum capacity of the underlying Netty buffer
 * @return the wrapped buffer
 */
@Override
public ByteBuf getBuffer(final int size) {
    final io.netty.buffer.ByteBuf fixedCapacityBuffer = allocator.buffer(size, size);
    return new NettyByteBuf(fixedCapacityBuffer);
}
/**
 * Opens the connection and blocks until the attempt has finished.
 *
 * @throws IOException if the asynchronous open fails with an {@code IOException}
 */
@Override
public void open() throws IOException {
    final FutureAsyncCompletionHandler<Void> openOutcome = new FutureAsyncCompletionHandler<Void>();
    openAsync(openOutcome);
    // Wait for the asynchronous open to report back; a reported failure is rethrown here.
    openOutcome.get();
}
/**
 * Starts connecting to the server without blocking.
 *
 * @param handler notified when a connection has been established or all socket addresses of the server have failed
 */
@Override
public void openAsync(final AsyncCompletionHandler<Void> handler) {
    // Work on a private copy of the candidate addresses: the connect logic consumes the queue as it goes.
    final Queue<SocketAddress> candidates = new LinkedList<SocketAddress>(address.getSocketAddresses());
    initializeChannel(handler, candidates);
}
/**
 * Starts a connection attempt to the next candidate address.
 *
 * <p>The head of {@code socketAddressQueue} is removed and a connect to it is started. The outcome is handled by
 * {@link OpenChannelFutureListener}, which calls this method again with the same queue when the attempt fails and
 * addresses remain. If the queue is already empty the handler is failed immediately.</p>
 *
 * @param handler            notified of the final outcome of opening the stream
 * @param socketAddressQueue the addresses still to be tried; this method removes the one it attempts
 */
@SuppressWarnings("deprecation")
private void initializeChannel(final AsyncCompletionHandler<Void> handler, final Queue<SocketAddress> socketAddressQueue) {
    if (socketAddressQueue.isEmpty()) {
        handler.failed(new MongoSocketException("Exception opening socket", getAddress()));
        return;
    }

    final SocketAddress candidate = socketAddressQueue.poll();

    final Bootstrap connector = new Bootstrap()
            .group(workerGroup)
            .channel(socketChannelClass)
            .option(ChannelOption.CONNECT_TIMEOUT_MILLIS, settings.getConnectTimeout(MILLISECONDS))
            .option(ChannelOption.TCP_NODELAY, true)
            .option(ChannelOption.SO_KEEPALIVE, settings.isKeepAlive());

    // Buffer sizes are only applied when explicitly configured (positive values).
    final int receiveBufferSize = settings.getReceiveBufferSize();
    if (receiveBufferSize > 0) {
        connector.option(ChannelOption.SO_RCVBUF, receiveBufferSize);
    }
    final int sendBufferSize = settings.getSendBufferSize();
    if (sendBufferSize > 0) {
        connector.option(ChannelOption.SO_SNDBUF, sendBufferSize);
    }
    connector.option(ChannelOption.ALLOCATOR, allocator);

    connector.handler(new ChannelInitializer<SocketChannel>() {
        @Override
        public void initChannel(final SocketChannel ch) {
            // Pipeline order: [ssl] -> [read timeout] -> inbound buffer handler.
            if (sslSettings.isEnabled()) {
                final String host = address.getHost();
                final SSLEngine sslEngine = getSslContext().createSSLEngine(host, address.getPort());
                sslEngine.setUseClientMode(true);

                final SSLParameters parameters = sslEngine.getSSLParameters();
                enableSni(host, parameters);
                if (!sslSettings.isInvalidHostNameAllowed()) {
                    enableHostNameVerification(parameters);
                }
                sslEngine.setSSLParameters(parameters);

                ch.pipeline().addFirst("ssl", new SslHandler(sslEngine, false));
            }

            final int readTimeoutMillis = settings.getReadTimeout(MILLISECONDS);
            if (readTimeoutMillis > 0) {
                ch.pipeline().addLast(READ_HANDLER_NAME, new ReadTimeoutHandler(readTimeoutMillis));
            }

            ch.pipeline().addLast(new InboundBufferHandler());
        }
    });

    final ChannelFuture connectFuture = connector.connect(candidate);
    connectFuture.addListener(new OpenChannelFutureListener(socketAddressQueue, connectFuture, handler));
}
/**
 * Writes the given buffers to the channel and blocks until the write has completed.
 *
 * @param buffers the buffers to send, in order
 * @throws IOException if the asynchronous write fails with an {@code IOException}
 */
@Override
public void write(final List<ByteBuf> buffers) throws IOException {
    final FutureAsyncCompletionHandler<Void> writeOutcome = new FutureAsyncCompletionHandler<Void>();
    writeAsync(buffers, writeOutcome);
    // Wait for the write listener to report back; a reported failure is rethrown here.
    writeOutcome.get();
}
/**
 * Reads {@code numBytes} bytes, blocking until they are available or the read fails.
 *
 * @param numBytes the number of bytes to read
 * @return a buffer holding the requested bytes
 * @throws IOException if the asynchronous read fails with an {@code IOException}
 */
@Override
public ByteBuf read(final int numBytes) throws IOException {
    final FutureAsyncCompletionHandler<ByteBuf> readOutcome = new FutureAsyncCompletionHandler<ByteBuf>();
    readAsync(numBytes, readOutcome);
    final ByteBuf received = readOutcome.get();
    return received;
}
/**
 * Sends the given buffers as a single composite buffer and reports the outcome to {@code handler}.
 *
 * @param buffers the buffers to send, in order; each must be a {@code NettyByteBuf}
 * @param handler completed with {@code null} when the write and flush succeed, failed with the cause otherwise
 */
@Override
public void writeAsync(final List<ByteBuf> buffers, final AsyncCompletionHandler<Void> handler) {
    final CompositeByteBuf outbound = PooledByteBufAllocator.DEFAULT.compositeBuffer();
    for (final ByteBuf part : buffers) {
        final io.netty.buffer.ByteBuf nettyPart = ((NettyByteBuf) part).asByteBuf();
        // 'true' advances the composite's writer index so the added bytes are part of what gets written.
        outbound.addComponent(true, nettyPart);
    }

    final ChannelFuture writeFuture = channel.writeAndFlush(outbound);
    writeFuture.addListener(new ChannelFutureListener() {
        @Override
        public void operationComplete(final ChannelFuture future) throws Exception {
            if (future.isSuccess()) {
                handler.completed(null);
            } else {
                handler.failed(future.cause());
            }
        }
    });
}
/**
 * Asynchronously reads {@code numBytes} bytes from the data received on the channel.
 *
 * <p>If a failure has already been recorded, the handler is failed with it. If enough bytes are buffered, they are
 * assembled into one composite buffer and the handler is completed. Otherwise the request is stored as the pending
 * reader; handleReadResponse re-issues it when more data or an error arrives.</p>
 *
 * @param numBytes the number of bytes to read
 * @param handler  receives the bytes, or the failure
 */
@Override
public void readAsync(final int numBytes, final AsyncCompletionHandler<ByteBuf> handler) {
// Arm the read timeout; this does nothing when no timeout handler is installed in the pipeline.
scheduleReadTimeout();
ByteBuf buffer = null;
Throwable exceptionResult = null;
// Pending state is inspected and modified under this stream's monitor; the handler is invoked after it is released.
synchronized (this) {
exceptionResult = pendingException;
if (exceptionResult == null) {
if (!hasBytesAvailable(numBytes)) {
// Not enough data yet: park the request until handleReadResponse is called again.
pendingReader = new PendingReader(numBytes, handler);
} else {
CompositeByteBuf composite = allocator.compositeBuffer(pendingInboundBuffers.size());
int bytesNeeded = numBytes;
for (Iterator<io.netty.buffer.ByteBuf> iter = pendingInboundBuffers.iterator(); iter.hasNext();) {
io.netty.buffer.ByteBuf next = iter.next();
int bytesNeededFromCurrentBuffer = Math.min(next.readableBytes(), bytesNeeded);
if (bytesNeededFromCurrentBuffer == next.readableBytes()) {
// The whole buffer is consumed: hand it to the composite and drop it from the pending list.
composite.addComponent(next);
iter.remove();
} else {
// Only part of this buffer is needed, so it stays in the pending list and only a slice goes into the composite.
// NOTE(review): the extra retain() presumably keeps the queued buffer alive when the composite (and its slice)
// is released — confirm against Netty's reference-counting rules for slices.
next.retain();
composite.addComponent(next.readSlice(bytesNeededFromCurrentBuffer));
}
// addComponent is called without the increaseWriterIndex flag, so the writer index is advanced by hand.
composite.writerIndex(composite.writerIndex() + bytesNeededFromCurrentBuffer);
bytesNeeded -= bytesNeededFromCurrentBuffer;
if (bytesNeeded == 0) {
break;
}
}
// NOTE(review): flip() presumably prepares the wrapped buffer for reading from its start — see NettyByteBuf.
buffer = new NettyByteBuf(composite).flip();
}
}
}
// The request finished one way or the other: cancel the timeout before notifying the handler.
if (exceptionResult != null) {
disableReadTimeout();
handler.failed(exceptionResult);
}
if (buffer != null) {
disableReadTimeout();
handler.completed(buffer);
}
}
/**
 * Tells whether the pending inbound buffers together hold at least {@code numBytes} readable bytes.
 * Stops counting as soon as the threshold is reached; returns {@code false} when nothing is buffered.
 *
 * @param numBytes the number of bytes required
 * @return {@code true} if enough bytes are buffered
 */
private boolean hasBytesAvailable(final int numBytes) {
    int readableSoFar = 0;
    final Iterator<io.netty.buffer.ByteBuf> buffered = pendingInboundBuffers.iterator();
    while (buffered.hasNext()) {
        readableSoFar += buffered.next().readableBytes();
        if (readableSoFar >= numBytes) {
            return true;
        }
    }
    return false;
}
/**
 * Records new input from the channel, either data or a failure, and resumes a waiting read if there is one.
 *
 * @param buffer newly received data, or {@code null} when a failure is being reported
 * @param t      the failure to record when {@code buffer} is {@code null}
 */
private void handleReadResponse(final io.netty.buffer.ByteBuf buffer, final Throwable t) {
    final PendingReader waitingReader;
    synchronized (this) {
        if (buffer == null) {
            pendingException = t;
        } else {
            // Retained because the buffer is kept beyond this call.
            pendingInboundBuffers.add(buffer.retain());
        }
        // Claim the waiting reader (if any) so that it is resumed exactly once.
        waitingReader = pendingReader;
        pendingReader = null;
    }

    if (waitingReader == null) {
        return;
    }
    // Re-run the read outside the monitor; it now either completes, fails, or parks itself again.
    readAsync(waitingReader.numBytes, waitingReader.handler);
}
/**
 * Returns the address of the server this stream was created for.
 *
 * @return the server address
 */
@Override
public ServerAddress getAddress() {
return address;
}
/**
 * Closes the stream: marks it closed, closes the channel if there is one, and releases every inbound buffer that
 * has been received but not yet handed to a reader.
 *
 * <p>The pending buffers are drained while holding this stream's monitor, the same lock {@code readAsync} and
 * {@code handleReadResponse} hold when they touch the list. The list is a plain {@code LinkedList} that the
 * Netty event loop may be appending to concurrently, so an unguarded drain could corrupt it.</p>
 */
@Override
public void close() {
    isClosed = true;

    // Read the volatile field once so the null check and the close() call act on the same channel.
    final Channel channelToClose = channel;
    if (channelToClose != null) {
        // Closing happens outside the monitor: close listeners call back into this stream.
        channelToClose.close();
        channel = null;
    }

    synchronized (this) {
        for (Iterator<io.netty.buffer.ByteBuf> iterator = pendingInboundBuffers.iterator(); iterator.hasNext();) {
            final io.netty.buffer.ByteBuf unreadBuffer = iterator.next();
            iterator.remove();
            unreadBuffer.release();
        }
    }
}
/**
 * Tells whether close() has been called on this stream.
 *
 * @return {@code true} once close() has been called
 */
@Override
public boolean isClosed() {
return isClosed;
}
/**
 * Returns the socket settings this stream was created with.
 *
 * @return the socket settings
 */
public SocketSettings getSettings() {
return settings;
}
/**
 * Returns the SSL settings this stream was created with.
 *
 * @return the SSL settings
 */
public SslSettings getSslSettings() {
return sslSettings;
}
/**
 * Returns the event loop group this stream was created with.
 *
 * @return the event loop group
 */
public EventLoopGroup getWorkerGroup() {
return workerGroup;
}
/**
 * Returns the channel implementation class this stream was created with.
 *
 * @return the socket channel class
 */
public Class<? extends SocketChannel> getSocketChannelClass() {
return socketChannelClass;
}
/**
 * Returns the buffer allocator this stream was created with.
 *
 * @return the allocator
 */
public ByteBufAllocator getAllocator() {
return allocator;
}
/**
 * Resolves the SSL context to use: the one configured in the SSL settings, or the JVM default when none is set.
 *
 * @return the SSL context
 * @throws MongoClientException if the default context is needed but cannot be created
 */
private SSLContext getSslContext() {
    final SSLContext configuredContext = sslSettings.getContext();
    if (configuredContext != null) {
        return configuredContext;
    }
    try {
        return SSLContext.getDefault();
    } catch (NoSuchAlgorithmException e) {
        throw new MongoClientException("Unable to create default SSLContext", e);
    }
}
/**
 * Last handler in the pipeline: forwards received buffers and pipeline errors to the enclosing stream.
 */
private class InboundBufferHandler extends SimpleChannelInboundHandler<io.netty.buffer.ByteBuf> {
    /**
     * Passes a received buffer on to the stream.
     */
    @Override
    protected void channelRead0(final ChannelHandlerContext ctx, final io.netty.buffer.ByteBuf buffer) throws Exception {
        handleReadResponse(buffer, null);
    }

    /**
     * Reports a pipeline error to the stream and closes the channel. A Netty read timeout is translated into
     * the driver's own read-timeout exception first; every other error is passed on as is.
     */
    @Override
    public void exceptionCaught(final ChannelHandlerContext ctx, final Throwable t) {
        final Throwable failure = (t instanceof ReadTimeoutException)
                ? new MongoSocketReadTimeoutException("Timeout while receiving message", address, t)
                : t;
        handleReadResponse(null, failure);
        ctx.close();
    }
}
/**
 * Immutable record of a read request that is waiting for more data: how many bytes were asked for and whom to notify.
 */
private static final class PendingReader {
    /** Receives the outcome once the request can be served. */
    private final AsyncCompletionHandler<ByteBuf> handler;
    /** Number of bytes the reader asked for. */
    private final int numBytes;

    private PendingReader(final int numBytes, final AsyncCompletionHandler<ByteBuf> handler) {
        this.handler = handler;
        this.numBytes = numBytes;
    }
}
/**
 * Bridges the callback-style {@link AsyncCompletionHandler} to a blocking call: the asynchronous operation reports
 * its outcome through {@link #completed} or {@link #failed}, and the caller waits for it in {@link #get()}.
 *
 * @param <T> the type of the operation's result
 */
private static final class FutureAsyncCompletionHandler<T> implements AsyncCompletionHandler<T> {
    /** Released exactly when an outcome (result or failure) has been recorded. */
    private final CountDownLatch latch = new CountDownLatch(1);
    private volatile T result;
    private volatile Throwable throwable;

    FutureAsyncCompletionHandler() {
    }

    /**
     * Records the result and wakes up the waiting caller.
     */
    @Override
    public void completed(final T t) {
        this.result = t;
        latch.countDown();
    }

    /**
     * Records the failure and wakes up the waiting caller.
     */
    @Override
    public void failed(final Throwable t) {
        this.throwable = t;
        latch.countDown();
    }

    /**
     * Blocks until the operation has reported its outcome.
     *
     * @return the result passed to {@link #completed}
     * @throws IOException               if the operation failed with an {@code IOException}
     * @throws MongoException            if the operation failed with a {@code MongoException}
     * @throws MongoInternalException    wrapping any other kind of failure
     * @throws MongoInterruptedException if the calling thread is interrupted while waiting; the thread's interrupt
     *                                   status is restored before this is thrown
     */
    public T get() throws IOException {
        try {
            latch.await();
        } catch (InterruptedException e) {
            // await() cleared the interrupt flag; set it again so code further up the stack can still observe it.
            Thread.currentThread().interrupt();
            throw new MongoInterruptedException("Interrupted", e);
        }

        final Throwable failure = throwable;
        if (failure == null) {
            return result;
        }
        if (failure instanceof IOException) {
            throw (IOException) failure;
        }
        if (failure instanceof MongoException) {
            throw (MongoException) failure;
        }
        throw new MongoInternalException("Exception thrown from Netty Stream", failure);
    }
}
/**
 * Handles the outcome of one connect attempt. On success it publishes the channel and completes the handler;
 * on failure it moves on to the next candidate address, or fails the handler when none is left.
 */
private class OpenChannelFutureListener implements ChannelFutureListener {
    private final Queue<SocketAddress> socketAddressQueue;
    private final ChannelFuture channelFuture;
    private final AsyncCompletionHandler<Void> handler;

    /**
     * @param socketAddressQueue the addresses not yet tried
     * @param channelFuture      the future of the connect attempt this listener is attached to
     * @param handler            notified of the final outcome of opening the stream
     */
    OpenChannelFutureListener(final Queue<SocketAddress> socketAddressQueue, final ChannelFuture channelFuture,
            final AsyncCompletionHandler<Void> handler) {
        this.socketAddressQueue = socketAddressQueue;
        this.channelFuture = channelFuture;
        this.handler = handler;
    }

    @Override
    public void operationComplete(final ChannelFuture future) {
        if (!future.isSuccess()) {
            tryNextAddressOrFail(future);
            return;
        }

        channel = channelFuture.channel();
        // When the channel closes, report it through the read path so that readers are failed.
        channel.closeFuture().addListener(new ChannelFutureListener() {
            @Override
            public void operationComplete(final ChannelFuture closeFuture) {
                handleReadResponse(null, new IOException("The connection to the server was closed"));
            }
        });
        handler.completed(null);
    }

    /** Continues with the next candidate address, or fails the handler with the last cause when none remain. */
    private void tryNextAddressOrFail(final ChannelFuture failedAttempt) {
        if (!socketAddressQueue.isEmpty()) {
            initializeChannel(handler, socketAddressQueue);
            return;
        }
        handler.failed(new MongoSocketOpenException("Exception opening socket", getAddress(), failedAttempt.cause()));
    }
}
/**
 * Arms the read timeout by delegating to adjustTimeout with {@code disable == false}.
 */
private void scheduleReadTimeout() {
adjustTimeout(false);
}
/**
 * Cancels the read timeout by delegating to adjustTimeout with {@code disable == true}.
 */
private void disableReadTimeout() {
adjustTimeout(true);
}
/**
 * Arms or cancels the read timeout on the channel's timeout handler, if one is installed.
 *
 * <p>The handler is always invoked on its own executor: directly when the caller is already on that event loop,
 * otherwise by submitting a task to it.</p>
 *
 * <p>Nothing is done once the stream has been closed. This matters because closing the channel triggers a final
 * pass through the read path to fail a waiting reader, and by then {@code close()} has usually reset
 * {@code channel} to {@code null}. Dereferencing it would throw on the event loop and the waiting reader would
 * never be notified.</p>
 *
 * @param disable {@code true} to cancel the timeout, {@code false} to schedule it
 * @throws IllegalStateException if the stream has neither been opened nor closed
 */
private void adjustTimeout(final boolean disable) {
    // Read the volatile field once; close() may reset it concurrently.
    final Channel currentChannel = channel;
    if (currentChannel == null) {
        // close() sets isClosed before clearing the channel, so a null channel on an open stream means "never opened".
        if (isClosed) {
            return;
        }
        throw new IllegalStateException("The stream has not been opened");
    }

    final ChannelHandler timeoutHandler = currentChannel.pipeline().get(READ_HANDLER_NAME);
    if (timeoutHandler == null) {
        // No read timeout configured for this stream.
        return;
    }
    final ChannelHandlerContext handlerContext = currentChannel.pipeline().context(timeoutHandler);
    if (handlerContext == null) {
        // The handler was removed from the pipeline between the two lookups; nothing left to adjust.
        return;
    }

    final ReadTimeoutHandler readTimeoutHandler = (ReadTimeoutHandler) timeoutHandler;
    final Runnable adjustment = new Runnable() {
        @Override
        public void run() {
            if (disable) {
                readTimeoutHandler.removeTimeout(handlerContext);
            } else {
                readTimeoutHandler.scheduleTimeout(handlerContext);
            }
        }
    };

    final EventExecutor executor = handlerContext.executor();
    if (executor.inEventLoop()) {
        adjustment.run();
    } else {
        executor.submit(adjustment);
    }
}
}