package com.mongodb.internal.connection.tlschannel.async;
import com.mongodb.diagnostics.logging.Logger;
import com.mongodb.diagnostics.logging.Loggers;
import com.mongodb.internal.connection.tlschannel.NeedsReadException;
import com.mongodb.internal.connection.tlschannel.NeedsTaskException;
import com.mongodb.internal.connection.tlschannel.NeedsWriteException;
import com.mongodb.internal.connection.tlschannel.TlsChannel;
import com.mongodb.internal.connection.tlschannel.impl.ByteBufferSet;
import com.mongodb.internal.connection.tlschannel.util.Util;
import java.io.IOException;
import java.nio.channels.CancelledKeyException;
import java.nio.channels.ClosedChannelException;
import java.nio.channels.InterruptedByTimeoutException;
import java.nio.channels.ReadPendingException;
import java.nio.channels.SelectionKey;
import java.nio.channels.Selector;
import java.nio.channels.ShutdownChannelGroupException;
import java.nio.channels.SocketChannel;
import java.nio.channels.WritePendingException;
import java.util.Iterator;
import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Future;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.ScheduledThreadPoolExecutor;
import java.util.concurrent.ThreadFactory;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.LongAdder;
import java.util.concurrent.locks.Lock;
import java.util.concurrent.locks.ReentrantLock;
import java.util.function.Consumer;
import java.util.function.IntBinaryOperator;
import java.util.function.LongConsumer;
import static java.lang.String.format;
/**
 * Multiplexes the I/O of many {@link TlsChannel}-backed sockets using a single selector thread, a handler executor
 * that performs the actual TLS reads and writes, and a single-threaded scheduler that enforces operation timeouts.
 *
 * <p>Each registered socket may have at most one pending read and one pending write at a time. Success and I/O failure
 * callbacks are invoked from the handler executor (or from the submitting thread when the executor's queue is full)
 * while the socket's corresponding read or write lock is held. Timeout failures are reported from the timeout thread
 * after that lock has been released.</p>
 */
public class AsynchronousTlsChannelGroup {

    private static final Logger LOGGER = Loggers.getLogger("connection.tls");

    /**
     * The handler executor's work queue holds up to this many tasks per thread; when it is full the submitting thread
     * runs the task itself ({@link ThreadPoolExecutor.CallerRunsPolicy}).
     */
    private static final int QUEUE_LENGTH_MULTIPLIER = 32;

    /** Source of the per-group ids used in thread names. */
    private static final AtomicInteger GLOBAL_GROUP_COUNT = new AtomicInteger();

    /** Combines interest sets accumulated in {@link RegisteredSocket#pendingOps}. */
    private static final IntBinaryOperator BITWISE_OR = new IntBinaryOperator() {
        @Override
        public int applyAsInt(final int a, final int b) {
            return a | b;
        }
    };

    /**
     * A socket registered with this group, together with its (at most one) pending read and write operations.
     */
    class RegisteredSocket {

        final TlsChannel tlsChannel;
        final SocketChannel socketChannel;

        /** Released by the selector thread once {@link #key} has been assigned. */
        final CountDownLatch registered = new CountDownLatch(1);

        /** Assigned by the selector thread when the socket channel is registered with the selector. */
        SelectionKey key;

        /** Guards {@link #readOperation}. */
        final Lock readLock = new ReentrantLock();

        /** Guards {@link #writeOperation}. */
        final Lock writeLock = new ReentrantLock();

        ReadOperation readOperation;
        WriteOperation writeOperation;

        /** Interest ops that the selector thread will OR into {@link #key} on its next iteration. */
        final AtomicInteger pendingOps = new AtomicInteger();

        /** Ensures the registration is released only once, however many times {@link #close()} is called. */
        private final AtomicBoolean closed = new AtomicBoolean();

        RegisteredSocket(final TlsChannel tlsChannel, final SocketChannel socketChannel) {
            this.tlsChannel = tlsChannel;
            this.socketChannel = socketChannel;
        }

        /**
         * Cancels any pending read and write (without invoking their callbacks), cancels the selection key and
         * releases this socket's registration. Subsequent calls have no effect.
         *
         * <p>NOTE(review): {@link #key} is assigned asynchronously by the selector thread; calling this method before
         * registration has completed throws a {@link NullPointerException}.</p>
         */
        public void close() {
            if (!closed.compareAndSet(false, true)) {
                return;
            }
            doCancelRead(this, null);
            doCancelWrite(this, null);
            key.cancel();
            currentRegistrations.getAndDecrement();
            // the cancelled key is only deregistered during the selector's next selection operation
            selector.wakeup();
        }
    }

    /** State shared by pending read and write operations. */
    private abstract static class Operation {
        final ByteBufferSet bufferSet;
        final LongConsumer onSuccess;
        final Consumer<Throwable> onFailure;

        /** The scheduled timeout task, or {@code null} if the operation has no timeout. Guarded by the socket lock. */
        Future<?> timeoutFuture;

        Operation(final ByteBufferSet bufferSet, final LongConsumer onSuccess, final Consumer<Throwable> onFailure) {
            this.bufferSet = bufferSet;
            this.onSuccess = onSuccess;
            this.onFailure = onFailure;
        }
    }

    /** A pending read; {@code onSuccess} receives the byte count returned by the channel read. */
    static final class ReadOperation extends Operation {
        ReadOperation(final ByteBufferSet bufferSet, final LongConsumer onSuccess, final Consumer<Throwable> onFailure) {
            super(bufferSet, onSuccess, onFailure);
        }
    }

    /** A pending write; {@code onSuccess} receives the total number of bytes consumed from the buffers. */
    static final class WriteOperation extends Operation {
        /** Bytes consumed so far, accumulated across attempts that had to wait for more I/O. */
        long consumesBytes = 0;

        WriteOperation(final ByteBufferSet bufferSet, final LongConsumer onSuccess, final Consumer<Throwable> onFailure) {
            super(bufferSet, onSuccess, onFailure);
        }
    }

    // must be initialized before the thread names below are formatted
    private final int id = GLOBAL_GROUP_COUNT.getAndIncrement();

    /** Ensures the "channel needs a task" warning is logged at most once per group. */
    private final AtomicBoolean loggedTaskWarning = new AtomicBoolean();

    private final Selector selector;

    final ExecutorService executor;

    private final ScheduledThreadPoolExecutor timeoutExecutor = new ScheduledThreadPoolExecutor(1, new ThreadFactory() {
        @Override
        public Thread newThread(final Runnable runnable) {
            return new Thread(runnable, format("async-channel-group-%d-timeout-thread", id));
        }
    });

    private final Thread selectorThread = new Thread(new Runnable() {
        @Override
        public void run() {
            AsynchronousTlsChannelGroup.this.loop();
        }
    }, format("async-channel-group-%d-selector", id));

    /** Sockets waiting for the selector thread to register them with {@link #selector}. */
    private final ConcurrentLinkedQueue<RegisteredSocket> pendingRegistrations = new ConcurrentLinkedQueue<RegisteredSocket>();

    private enum Shutdown {
        No, Wait, Immediate
    }

    private volatile Shutdown shutdown = Shutdown.No;

    private final LongAdder selectionCount = new LongAdder();

    private final LongAdder startedReads = new LongAdder();
    private final LongAdder startedWrites = new LongAdder();
    private final LongAdder successfulReads = new LongAdder();
    private final LongAdder successfulWrites = new LongAdder();
    private final LongAdder failedReads = new LongAdder();
    private final LongAdder failedWrites = new LongAdder();
    private final LongAdder cancelledReads = new LongAdder();
    private final LongAdder cancelledWrites = new LongAdder();

    private final AtomicInteger currentRegistrations = new AtomicInteger();
    private final LongAdder currentReads = new LongAdder();
    private final LongAdder currentWrites = new LongAdder();

    /**
     * Creates a group and starts its selector thread.
     *
     * @param nThreads the number of handler threads that run read and write operations
     * @throws RuntimeException wrapping the {@link IOException} if the selector cannot be opened
     */
    public AsynchronousTlsChannelGroup(final int nThreads) {
        try {
            selector = Selector.open();
        } catch (IOException e) {
            throw new RuntimeException(e);
        }
        // most timeouts are cancelled because the operation completes; drop them from the queue right away
        timeoutExecutor.setRemoveOnCancelPolicy(true);
        this.executor = new ThreadPoolExecutor(
                nThreads, nThreads,
                0, TimeUnit.MILLISECONDS,
                new LinkedBlockingQueue<Runnable>(nThreads * QUEUE_LENGTH_MULTIPLIER),
                new ThreadFactory() {
                    @Override
                    public Thread newThread(final Runnable runnable) {
                        return new Thread(runnable, format("async-channel-group-%d-handler-executor", id));
                    }
                },
                new ThreadPoolExecutor.CallerRunsPolicy());
        selectorThread.start();
    }

    /** Creates a group with one handler thread per available processor. */
    public AsynchronousTlsChannelGroup() {
        this(Runtime.getRuntime().availableProcessors());
    }

    /**
     * Queues a socket for registration with the selector; registration itself happens on the selector thread.
     *
     * @throws ShutdownChannelGroupException if the group has been shut down
     */
    RegisteredSocket registerSocket(final TlsChannel reader, final SocketChannel socketChannel) {
        if (shutdown != Shutdown.No) {
            throw new ShutdownChannelGroupException();
        }
        RegisteredSocket socket = new RegisteredSocket(reader, socketChannel);
        currentRegistrations.getAndIncrement();
        pendingRegistrations.add(socket);
        selector.wakeup();
        return socket;
    }

    /**
     * Cancels a pending read without invoking its callbacks.
     *
     * @param op the read to cancel, or {@code null} to cancel whichever read is pending
     * @return whether a read was cancelled
     */
    boolean doCancelRead(final RegisteredSocket socket, final ReadOperation op) {
        socket.readLock.lock();
        try {
            ReadOperation current = socket.readOperation;
            if (current == null || (op != null && current != op)) {
                return false;
            }
            socket.readOperation = null;
            // release the scheduled timeout (and the references it holds) now rather than when it fires
            cancelTimeout(current);
            cancelledReads.increment();
            currentReads.decrement();
            return true;
        } finally {
            socket.readLock.unlock();
        }
    }

    /**
     * Cancels a pending write without invoking its callbacks.
     *
     * @param op the write to cancel, or {@code null} to cancel whichever write is pending
     * @return whether a write was cancelled
     */
    boolean doCancelWrite(final RegisteredSocket socket, final WriteOperation op) {
        socket.writeLock.lock();
        try {
            WriteOperation current = socket.writeOperation;
            if (current == null || (op != null && current != op)) {
                return false;
            }
            socket.writeOperation = null;
            // release the scheduled timeout (and the references it holds) now rather than when it fires
            cancelTimeout(current);
            cancelledWrites.increment();
            currentWrites.decrement();
            return true;
        } finally {
            socket.writeLock.unlock();
        }
    }

    /**
     * Starts an asynchronous read into {@code buffer}.
     *
     * @param timeout the timeout, or 0 for none; on expiry {@code onFailure} receives an
     *                {@link InterruptedByTimeoutException}
     * @throws ReadPendingException if a read is already pending on the socket
     * @throws ShutdownChannelGroupException if the group has terminated
     */
    ReadOperation startRead(
            final RegisteredSocket socket,
            final ByteBufferSet buffer,
            final long timeout, final TimeUnit unit,
            final LongConsumer onSuccess, final Consumer<Throwable> onFailure)
            throws ReadPendingException {
        checkTerminated();
        Util.assertTrue(buffer.hasRemaining());
        waitForSocketRegistration(socket);
        final ReadOperation op = new ReadOperation(buffer, onSuccess, onFailure);
        socket.readLock.lock();
        try {
            if (socket.readOperation != null) {
                throw new ReadPendingException();
            }
            // a TLS read may need to write (and vice versa), so every new operation registers interest in both
            socket.pendingOps.set(SelectionKey.OP_WRITE | SelectionKey.OP_READ);
            if (timeout != 0) {
                op.timeoutFuture = timeoutExecutor.schedule(new Runnable() {
                    @Override
                    public void run() {
                        if (doCancelRead(socket, op)) {
                            op.onFailure.accept(new InterruptedByTimeoutException());
                        }
                    }
                }, timeout, unit);
            }
            socket.readOperation = op;
        } finally {
            socket.readLock.unlock();
        }
        selector.wakeup();
        startedReads.increment();
        currentReads.increment();
        return op;
    }

    /**
     * Starts an asynchronous write from {@code buffer}.
     *
     * @param timeout the timeout, or 0 for none; on expiry {@code onFailure} receives an
     *                {@link InterruptedByTimeoutException}
     * @throws WritePendingException if a write is already pending on the socket
     * @throws ShutdownChannelGroupException if the group has terminated
     */
    WriteOperation startWrite(
            final RegisteredSocket socket,
            final ByteBufferSet buffer,
            final long timeout, final TimeUnit unit,
            final LongConsumer onSuccess, final Consumer<Throwable> onFailure)
            throws WritePendingException {
        checkTerminated();
        Util.assertTrue(buffer.hasRemaining());
        waitForSocketRegistration(socket);
        final WriteOperation op = new WriteOperation(buffer, onSuccess, onFailure);
        socket.writeLock.lock();
        try {
            if (socket.writeOperation != null) {
                throw new WritePendingException();
            }
            // a TLS write may need to read (and vice versa), so every new operation registers interest in both
            socket.pendingOps.set(SelectionKey.OP_WRITE | SelectionKey.OP_READ);
            if (timeout != 0) {
                op.timeoutFuture = timeoutExecutor.schedule(new Runnable() {
                    @Override
                    public void run() {
                        if (doCancelWrite(socket, op)) {
                            op.onFailure.accept(new InterruptedByTimeoutException());
                        }
                    }
                }, timeout, unit);
            }
            socket.writeOperation = op;
        } finally {
            socket.writeLock.unlock();
        }
        selector.wakeup();
        startedWrites.increment();
        currentWrites.increment();
        return op;
    }

    private void checkTerminated() {
        if (isTerminated()) {
            throw new ShutdownChannelGroupException();
        }
    }

    /** Blocks until the selector thread has registered {@code socket}. */
    private void waitForSocketRegistration(final RegisteredSocket socket) {
        try {
            socket.registered.await();
        } catch (InterruptedException e) {
            // preserve the interrupt for callers further up the stack
            Thread.currentThread().interrupt();
            throw new RuntimeException(e);
        }
    }

    /** Cancels the operation's timeout task, if it has one. */
    private static void cancelTimeout(final Operation op) {
        if (op.timeoutFuture != null) {
            op.timeoutFuture.cancel(false);
        }
    }

    /** Asks the selector thread to add {@code ops} to the socket's interest set on its next iteration. */
    private void addPendingInterest(final RegisteredSocket socket, final int ops) {
        socket.pendingOps.accumulateAndGet(ops, BITWISE_OR);
        selector.wakeup();
    }

    /**
     * Selector thread body. It runs until shut down (for {@link Shutdown#Wait}, until every socket has been closed),
     * then shuts down the executors, closes all sockets on an immediate shutdown, and closes the selector.
     */
    private void loop() {
        try {
            while (shutdown == Shutdown.No || shutdown == Shutdown.Wait && currentRegistrations.intValue() > 0) {
                int c = selector.select();
                selectionCount.increment();
                if (c > 0) {
                    Iterator<SelectionKey> it = selector.selectedKeys().iterator();
                    while (it.hasNext()) {
                        SelectionKey key = it.next();
                        it.remove();
                        try {
                            // interest is re-armed through pendingOps by whichever operation still needs it
                            key.interestOps(0);
                        } catch (CancelledKeyException e) {
                            // the socket was closed concurrently
                            continue;
                        }
                        RegisteredSocket socket = (RegisteredSocket) key.attachment();
                        processRead(socket);
                        processWrite(socket);
                    }
                }
                registerPendingSockets();
                processPendingInterests();
            }
        } catch (Throwable e) {
            LOGGER.error("error in selector loop", e);
        } finally {
            executor.shutdown();
            timeoutExecutor.shutdownNow();
            if (shutdown == Shutdown.Immediate) {
                for (SelectionKey key : selector.keys()) {
                    RegisteredSocket socket = (RegisteredSocket) key.attachment();
                    socket.close();
                }
            }
            try {
                selector.close();
            } catch (IOException e) {
                LOGGER.warn(format("error closing selector: %s", e.getMessage()));
            }
        }
    }

    /** Applies the interest ops requested through {@link RegisteredSocket#pendingOps} to each selection key. */
    private void processPendingInterests() {
        for (SelectionKey key : selector.keys()) {
            RegisteredSocket socket = (RegisteredSocket) key.attachment();
            int pending = socket.pendingOps.getAndSet(0);
            if (pending != 0) {
                try {
                    key.interestOps(key.interestOps() | pending);
                } catch (CancelledKeyException e) {
                    // the socket was closed concurrently; its key stays in keys() until the next select.
                    // Skipping it keeps one closed socket from terminating the selector loop.
                }
            }
        }
    }

    /** Submits the socket's pending write, if any, to the handler executor. */
    private void processWrite(final RegisteredSocket socket) {
        socket.writeLock.lock();
        try {
            final WriteOperation op = socket.writeOperation;
            if (op != null) {
                executor.execute(new Runnable() {
                    @Override
                    public void run() {
                        try {
                            AsynchronousTlsChannelGroup.this.doWrite(socket, op);
                        } catch (Throwable e) {
                            LOGGER.error("error in operation", e);
                        }
                    }
                });
            }
        } finally {
            socket.writeLock.unlock();
        }
    }

    /** Submits the socket's pending read, if any, to the handler executor. */
    private void processRead(final RegisteredSocket socket) {
        socket.readLock.lock();
        try {
            final ReadOperation op = socket.readOperation;
            if (op != null) {
                executor.execute(new Runnable() {
                    @Override
                    public void run() {
                        try {
                            AsynchronousTlsChannelGroup.this.doRead(socket, op);
                        } catch (Throwable e) {
                            LOGGER.error("error in operation", e);
                        }
                    }
                });
            }
        } finally {
            socket.readLock.unlock();
        }
    }

    /**
     * Attempts the write. On completion or I/O failure it clears the operation and invokes the matching callback.
     * When the channel needs more I/O, it re-registers interest and leaves the operation pending.
     */
    private void doWrite(final RegisteredSocket socket, final WriteOperation op) {
        socket.writeLock.lock();
        try {
            if (socket.writeOperation != op) {
                // completed, cancelled or timed out after this task was submitted
                return;
            }
            try {
                long before = op.bufferSet.remaining();
                try {
                    writeHandlingTasks(socket, op);
                } finally {
                    // count consumed bytes even when the attempt throws, so a resumed write reports the total
                    long written = before - op.bufferSet.remaining();
                    Util.assertTrue(written >= 0);
                    op.consumesBytes += written;
                }
                socket.writeOperation = null;
                cancelTimeout(op);
                // update bookkeeping before running user code so a throwing callback cannot skip it
                successfulWrites.increment();
                currentWrites.decrement();
                op.onSuccess.accept(op.consumesBytes);
            } catch (NeedsReadException e) {
                addPendingInterest(socket, SelectionKey.OP_READ);
            } catch (NeedsWriteException e) {
                addPendingInterest(socket, SelectionKey.OP_WRITE);
            } catch (IOException e) {
                if (socket.writeOperation == op) {
                    socket.writeOperation = null;
                }
                cancelTimeout(op);
                failedWrites.increment();
                currentWrites.decrement();
                op.onFailure.accept(e);
            }
        } finally {
            socket.writeLock.unlock();
        }
    }

    /** Writes once, running any delegated tasks the channel asks for. */
    private void writeHandlingTasks(final RegisteredSocket socket, final WriteOperation op) throws IOException {
        while (true) {
            try {
                socket.tlsChannel.write(op.bufferSet.array, op.bufferSet.offset, op.bufferSet.length);
                return;
            } catch (NeedsTaskException e) {
                warnAboutNeedTask();
                e.getTask().run();
            }
        }
    }

    private void warnAboutNeedTask() {
        if (!loggedTaskWarning.getAndSet(true)) {
            LOGGER.warn(format(
                    "caught %s; channels used in asynchronous groups should run tasks themselves; "
                            + "although task is being dealt with anyway, consider configuring channels properly",
                    NeedsTaskException.class.getName()));
        }
    }

    /**
     * Attempts the read. On completion or I/O failure it clears the operation and invokes the matching callback.
     * When the channel needs more I/O, it re-registers interest and leaves the operation pending.
     */
    private void doRead(final RegisteredSocket socket, final ReadOperation op) {
        socket.readLock.lock();
        try {
            if (socket.readOperation != op) {
                // completed, cancelled or timed out after this task was submitted
                return;
            }
            try {
                Util.assertTrue(op.bufferSet.hasRemaining());
                long c = readHandlingTasks(socket, op);
                Util.assertTrue(c > 0 || c == -1);
                socket.readOperation = null;
                cancelTimeout(op);
                // update bookkeeping before running user code so a throwing callback cannot skip it
                successfulReads.increment();
                currentReads.decrement();
                op.onSuccess.accept(c);
            } catch (NeedsReadException e) {
                addPendingInterest(socket, SelectionKey.OP_READ);
            } catch (NeedsWriteException e) {
                addPendingInterest(socket, SelectionKey.OP_WRITE);
            } catch (IOException e) {
                if (socket.readOperation == op) {
                    socket.readOperation = null;
                }
                cancelTimeout(op);
                failedReads.increment();
                currentReads.decrement();
                op.onFailure.accept(e);
            }
        } finally {
            socket.readLock.unlock();
        }
    }

    /** Reads once, running any delegated tasks the channel asks for, and returns the channel's result. */
    private long readHandlingTasks(final RegisteredSocket socket, final ReadOperation op) throws IOException {
        while (true) {
            try {
                return socket.tlsChannel.read(op.bufferSet.array, op.bufferSet.offset, op.bufferSet.length);
            } catch (NeedsTaskException e) {
                warnAboutNeedTask();
                e.getTask().run();
            }
        }
    }

    /**
     * Registers queued sockets with the selector (with no initial interest) and releases their registration latches.
     * A {@link ClosedChannelException} from registration propagates to, and ends, the selector loop.
     */
    private void registerPendingSockets() throws ClosedChannelException {
        RegisteredSocket socket;
        while ((socket = pendingRegistrations.poll()) != null) {
            socket.key = socket.socketChannel.register(selector, 0, socket);
            if (LOGGER.isTraceEnabled()) {
                LOGGER.trace(format("registered key: %s", socket.key));
            }
            socket.registered.countDown();
        }
    }

    /** Returns whether {@link #shutdown()} or {@link #shutdownNow()} has been called. */
    public boolean isShutdown() {
        return shutdown != Shutdown.No;
    }

    /**
     * Initiates an orderly shutdown. No new sockets are accepted. The selector keeps running until every registered
     * socket has been closed, and then the executors are shut down.
     */
    public void shutdown() {
        shutdown = Shutdown.Wait;
        selector.wakeup();
    }

    /**
     * Initiates an immediate shutdown. The selector loop stops, and every registered socket is closed with its pending
     * operations cancelled (their callbacks are not invoked).
     */
    public void shutdownNow() {
        shutdown = Shutdown.Immediate;
        selector.wakeup();
    }

    /** Returns whether the handler executor, which is shut down when the selector loop exits, has terminated. */
    public boolean isTerminated() {
        return executor.isTerminated();
    }

    /** Waits for the handler executor to terminate; see {@link ExecutorService#awaitTermination}. */
    public boolean awaitTermination(final long timeout, final TimeUnit unit) throws InterruptedException {
        return executor.awaitTermination(timeout, unit);
    }

    /** Returns the number of selection operations performed so far. */
    long getSelectionCount() {
        return selectionCount.longValue();
    }

    /** Returns the number of reads started. */
    public long getStartedReadCount() {
        return startedReads.longValue();
    }

    /** Returns the number of writes started. */
    public long getStartedWriteCount() {
        return startedWrites.longValue();
    }

    /** Returns the number of reads that completed successfully. */
    public long getSuccessfulReadCount() {
        return successfulReads.longValue();
    }

    /** Returns the number of writes that completed successfully. */
    public long getSuccessfulWriteCount() {
        return successfulWrites.longValue();
    }

    /** Returns the number of reads that failed with an I/O error. */
    public long getFailedReadCount() {
        return failedReads.longValue();
    }

    /** Returns the number of writes that failed with an I/O error. */
    public long getFailedWriteCount() {
        return failedWrites.longValue();
    }

    /** Returns the number of reads cancelled, by timeout or by closing the socket. */
    public long getCancelledReadCount() {
        return cancelledReads.longValue();
    }

    /** Returns the number of writes cancelled, by timeout or by closing the socket. */
    public long getCancelledWriteCount() {
        return cancelledWrites.longValue();
    }

    /** Returns the number of reads currently pending. */
    public long getCurrentReadCount() {
        return currentReads.longValue();
    }

    /** Returns the number of writes currently pending. */
    public long getCurrentWriteCount() {
        return currentWrites.longValue();
    }

    /** Returns the number of sockets currently registered (registered or queued, and not yet closed). */
    public long getCurrentRegistrationCount() {
        return currentRegistrations.longValue();
    }
}