package com.mongodb.internal.connection;
import com.mongodb.MongoException;
import com.mongodb.MongoInternalException;
import com.mongodb.MongoInterruptedException;
import com.mongodb.MongoSocketReadException;
import com.mongodb.MongoSocketReadTimeoutException;
import com.mongodb.ServerAddress;
import com.mongodb.connection.AsyncCompletionHandler;
import com.mongodb.connection.BufferProvider;
import com.mongodb.connection.SocketSettings;
import com.mongodb.connection.Stream;
import org.bson.ByteBuf;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.channels.CompletionHandler;
import java.nio.channels.InterruptedByTimeoutException;
import java.util.Iterator;
import java.util.List;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.atomic.AtomicReference;
import static com.mongodb.assertions.Assertions.isTrue;
import static java.util.concurrent.TimeUnit.MILLISECONDS;
/**
 * Base {@link Stream} implementation that performs its I/O through an {@link ExtendedAsynchronousByteChannel}.
 *
 * <p>This class does not implement {@code openAsync}; subclasses do, and are expected to install the connected
 * channel through {@link #setChannel}. The blocking {@link #open()}, {@link #write(List)} and {@link #read(int)}
 * methods are built on their asynchronous counterparts by blocking the calling thread until the completion handler
 * is notified.</p>
 */
public abstract class AsynchronousChannelStream implements Stream {
    private final ServerAddress serverAddress;
    private final SocketSettings settings;
    private final BufferProvider bufferProvider;
    // volatile: written by setChannel/close and read from completion-handler callbacks.
    private volatile ExtendedAsynchronousByteChannel channel;
    private volatile boolean isClosed;

    /**
     * Creates a stream with no channel yet; a channel is installed later via {@link #setChannel}.
     *
     * @param serverAddress  the server address reported by {@link #getAddress()} and used in read exceptions
     * @param settings       socket settings; the read timeout is applied to every channel read
     * @param bufferProvider source of the buffers that reads are performed into
     */
    public AsynchronousChannelStream(final ServerAddress serverAddress, final SocketSettings settings,
                                     final BufferProvider bufferProvider) {
        this.serverAddress = serverAddress;
        this.settings = settings;
        this.bufferProvider = bufferProvider;
    }

    /**
     * @return the server address this stream was created for
     */
    public ServerAddress getServerAddress() {
        return serverAddress;
    }

    /**
     * @return the socket settings this stream was created with
     */
    public SocketSettings getSettings() {
        return settings;
    }

    /**
     * @return the provider used to allocate read buffers
     */
    public BufferProvider getBufferProvider() {
        return bufferProvider;
    }

    /**
     * @return the currently installed channel, or {@code null} if none has been set or the stream was closed
     */
    public ExtendedAsynchronousByteChannel getChannel() {
        return channel;
    }

    /**
     * Installs the channel used for all subsequent I/O.
     *
     * @param channel the channel to use
     * Fails the {@code isTrue} assertion if a channel is already installed.
     */
    protected void setChannel(final ExtendedAsynchronousByteChannel channel) {
        isTrue("current channel is null", this.channel == null);
        this.channel = channel;
    }

    /**
     * Writes every buffer in {@code buffers} to the channel, in order. Each buffer is fully drained before the next
     * one is started. {@code handler} is completed after the last buffer has been written, or failed with the first
     * error reported by the channel. An empty list completes {@code handler} immediately.
     *
     * @param buffers the buffers to write
     * @param handler notified once all buffers have been written or a write has failed
     */
    @Override
    public void writeAsync(final List<ByteBuf> buffers, final AsyncCompletionHandler<Void> handler) {
        final Iterator<ByteBuf> iter = buffers.iterator();
        if (!iter.hasNext()) {
            // Nothing to write; calling iter.next() here would throw NoSuchElementException.
            handler.completed(null);
            return;
        }
        final AsyncWritableByteChannelAdapter byteChannel = new AsyncWritableByteChannelAdapter();
        pipeOneBuffer(byteChannel, iter.next(), new AsyncCompletionHandler<Void>() {
            @Override
            public void completed(final Void t) {
                // Chain to the next buffer only once the current one has been fully written.
                if (iter.hasNext()) {
                    pipeOneBuffer(byteChannel, iter.next(), this);
                } else {
                    handler.completed(null);
                }
            }

            @Override
            public void failed(final Throwable t) {
                handler.failed(t);
            }
        });
    }

    /**
     * Reads from the channel until a buffer obtained from {@code bufferProvider.getBuffer(numBytes)} has no
     * remaining space. On success {@code handler} receives the flipped buffer. If the channel reports end of stream
     * first, it fails with {@link MongoSocketReadException}. If a read times out (socket read timeout), it fails
     * with {@link MongoSocketReadTimeoutException}. Otherwise it fails with the channel's error.
     *
     * @param numBytes the number of bytes to read
     * @param handler  notified with the filled buffer or the failure
     */
    @Override
    public void readAsync(final int numBytes, final AsyncCompletionHandler<ByteBuf> handler) {
        ByteBuf buffer = bufferProvider.getBuffer(numBytes);
        channel.read(buffer.asNIO(), settings.getReadTimeout(MILLISECONDS), MILLISECONDS, null,
                new BasicCompletionHandler(buffer, handler));
    }

    /**
     * Blocking form of {@code openAsync}: waits until the asynchronous open completes.
     *
     * @throws IOException if opening failed with an {@link IOException}
     */
    @Override
    public void open() throws IOException {
        FutureAsyncCompletionHandler<Void> handler = new FutureAsyncCompletionHandler<Void>();
        openAsync(handler);
        handler.getOpen();
    }

    /**
     * Blocking form of {@link #writeAsync}: waits until all buffers are written.
     *
     * @param buffers the buffers to write
     * @throws IOException if a write failed with an {@link IOException}
     */
    @Override
    public void write(final List<ByteBuf> buffers) throws IOException {
        FutureAsyncCompletionHandler<Void> handler = new FutureAsyncCompletionHandler<Void>();
        writeAsync(buffers, handler);
        handler.getWrite();
    }

    /**
     * Blocking form of {@link #readAsync}: waits until {@code numBytes} bytes have been read.
     *
     * @param numBytes the number of bytes to read
     * @return the flipped buffer holding the bytes read
     * @throws IOException if the read failed with an {@link IOException}
     */
    @Override
    public ByteBuf read(final int numBytes) throws IOException {
        FutureAsyncCompletionHandler<ByteBuf> handler = new FutureAsyncCompletionHandler<ByteBuf>();
        readAsync(numBytes, handler);
        return handler.getRead();
    }

    @Override
    public ServerAddress getAddress() {
        return serverAddress;
    }

    /**
     * Closes the underlying channel, if any, then clears it and marks this stream closed. Failures to close the
     * channel are ignored.
     */
    @Override
    public void close() {
        // Read the volatile field once so it cannot become null between the check and the call.
        ExtendedAsynchronousByteChannel localChannel = channel;
        try {
            if (localChannel != null) {
                localChannel.close();
            }
        } catch (IOException ignored) {
            // Deliberately best-effort: close() does not report failures, and the stream is marked closed regardless.
        } finally {
            channel = null;
            isClosed = true;
        }
    }

    @Override
    public boolean isClosed() {
        return isClosed;
    }

    @Override
    public ByteBuf getBuffer(final int size) {
        return bufferProvider.getBuffer(size);
    }

    /**
     * Writes {@code byteBuffer} to the channel, re-issuing writes until it has no remaining bytes, then completes
     * {@code outerHandler}. A channel write can be partial, hence the loop. The first failure is forwarded to
     * {@code outerHandler}.
     */
    private void pipeOneBuffer(final AsyncWritableByteChannelAdapter byteChannel, final ByteBuf byteBuffer,
                               final AsyncCompletionHandler<Void> outerHandler) {
        byteChannel.write(byteBuffer.asNIO(), new AsyncCompletionHandler<Void>() {
            @Override
            public void completed(final Void t) {
                if (byteBuffer.hasRemaining()) {
                    byteChannel.write(byteBuffer.asNIO(), this);
                } else {
                    outerHandler.completed(null);
                }
            }

            @Override
            public void failed(final Throwable t) {
                outerHandler.failed(t);
            }
        });
    }

    /**
     * Adapts the channel's NIO {@link CompletionHandler}-based write to an {@link AsyncCompletionHandler}. The
     * byte count of the write is discarded.
     */
    private class AsyncWritableByteChannelAdapter {
        void write(final ByteBuffer src, final AsyncCompletionHandler<Void> handler) {
            channel.write(src, null, new WriteCompletionHandler(handler));
        }

        /**
         * Forwards a single write's outcome to the wrapped handler.
         */
        private class WriteCompletionHandler extends BaseCompletionHandler<Void, Integer, Object> {
            WriteCompletionHandler(final AsyncCompletionHandler<Void> handler) {
                super(handler);
            }

            @Override
            public void completed(final Integer result, final Object attachment) {
                AsyncCompletionHandler<Void> localHandler = getHandlerAndClear();
                localHandler.completed(null);
            }

            @Override
            public void failed(final Throwable exc, final Object attachment) {
                AsyncCompletionHandler<Void> localHandler = getHandlerAndClear();
                localHandler.failed(exc);
            }
        }
    }

    /**
     * Handles the outcome of one channel read into a buffer.
     *
     * <ul>
     *   <li>End of stream (-1): the buffer is released and the handler fails with {@link MongoSocketReadException}.</li>
     *   <li>Buffer full: the buffer is flipped and passed to the handler.</li>
     *   <li>Otherwise: a follow-up read is issued into the remaining space with a new instance of this handler.</li>
     * </ul>
     * On failure the buffer is released. A {@link InterruptedByTimeoutException} becomes a
     * {@link MongoSocketReadTimeoutException}; any other failure is passed through unchanged.
     */
    private final class BasicCompletionHandler extends BaseCompletionHandler<ByteBuf, Integer, Void> {
        private final AtomicReference<ByteBuf> byteBufReference;

        private BasicCompletionHandler(final ByteBuf dst, final AsyncCompletionHandler<ByteBuf> handler) {
            super(handler);
            this.byteBufReference = new AtomicReference<ByteBuf>(dst);
        }

        @Override
        public void completed(final Integer result, final Void attachment) {
            AsyncCompletionHandler<ByteBuf> localHandler = getHandlerAndClear();
            ByteBuf localByteBuf = byteBufReference.getAndSet(null);
            if (result == -1) {
                localByteBuf.release();
                localHandler.failed(new MongoSocketReadException("Prematurely reached end of stream", serverAddress));
            } else if (!localByteBuf.hasRemaining()) {
                localByteBuf.flip();
                localHandler.completed(localByteBuf);
            } else {
                // Partial read: keep reading into the same buffer, handing ownership to a fresh handler.
                channel.read(localByteBuf.asNIO(), settings.getReadTimeout(MILLISECONDS), MILLISECONDS, null,
                        new BasicCompletionHandler(localByteBuf, localHandler));
            }
        }

        @Override
        public void failed(final Throwable t, final Void attachment) {
            AsyncCompletionHandler<ByteBuf> localHandler = getHandlerAndClear();
            ByteBuf localByteBuf = byteBufReference.getAndSet(null);
            localByteBuf.release();
            if (t instanceof InterruptedByTimeoutException) {
                localHandler.failed(new MongoSocketReadTimeoutException("Timeout while receiving message", serverAddress, t));
            } else {
                localHandler.failed(t);
            }
        }
    }

    /**
     * Holds the caller's handler so it can be claimed exactly once. {@link #getHandlerAndClear()} returns the handler
     * on the first call and {@code null} afterwards.
     *
     * @param <T> the result type delivered to the wrapped {@link AsyncCompletionHandler}
     * @param <V> the NIO completion result type
     * @param <A> the NIO attachment type
     */
    private abstract static class BaseCompletionHandler<T, V, A> implements CompletionHandler<V, A> {
        private final AtomicReference<AsyncCompletionHandler<T>> handlerReference;

        BaseCompletionHandler(final AsyncCompletionHandler<T> handler) {
            this.handlerReference = new AtomicReference<AsyncCompletionHandler<T>>(handler);
        }

        AsyncCompletionHandler<T> getHandlerAndClear() {
            return handlerReference.getAndSet(null);
        }
    }

    /**
     * An {@link AsyncCompletionHandler} a thread can block on until it is notified. The {@code get*} methods
     * rethrow an {@link IOException} or {@link MongoException} failure unchanged. They wrap any other failure in a
     * {@link MongoInternalException}.
     *
     * @param <T> the result type
     */
    static class FutureAsyncCompletionHandler<T> implements AsyncCompletionHandler<T> {
        private final CountDownLatch latch = new CountDownLatch(1);
        private volatile T result;
        private volatile Throwable error;

        @Override
        public void completed(final T result) {
            this.result = result;
            latch.countDown();
        }

        @Override
        public void failed(final Throwable t) {
            this.error = t;
            latch.countDown();
        }

        void getOpen() throws IOException {
            get("Opening");
        }

        void getWrite() throws IOException {
            get("Writing to");
        }

        T getRead() throws IOException {
            return get("Reading from");
        }

        /**
         * Blocks until the handler is notified, then returns the result or rethrows the failure.
         *
         * @param prefix leading words of the exception message describing the operation
         * @return the result passed to {@link #completed}
         * @throws IOException if the operation failed with an {@link IOException}
         */
        private T get(final String prefix) throws IOException {
            try {
                latch.await();
            } catch (InterruptedException e) {
                // Restore the interrupt status so callers further up the stack can still observe the interruption.
                Thread.currentThread().interrupt();
                throw new MongoInterruptedException(prefix + " the AsynchronousSocketChannelStream failed", e);
            }
            if (error != null) {
                if (error instanceof IOException) {
                    throw (IOException) error;
                } else if (error instanceof MongoException) {
                    throw (MongoException) error;
                } else {
                    throw new MongoInternalException(prefix + " the AsynchronousSocketChannelStream failed", error);
                }
            }
            return result;
        }
    }
}