package io.vertx.core.net.impl;
import io.netty.buffer.Unpooled;
import io.netty.channel.*;
import io.netty.handler.ssl.SslHandler;
import io.netty.handler.stream.ChunkedFile;
import io.netty.util.AttributeKey;
import io.netty.util.concurrent.EventExecutor;
import io.netty.util.concurrent.FutureListener;
import io.vertx.core.*;
import io.vertx.core.impl.ContextInternal;
import io.vertx.core.impl.future.PromiseInternal;
import io.vertx.core.impl.VertxInternal;
import io.vertx.core.impl.logging.Logger;
import io.vertx.core.impl.logging.LoggerFactory;
import io.vertx.core.net.SocketAddress;
import io.vertx.core.spi.metrics.NetworkMetrics;
import io.vertx.core.spi.metrics.TCPMetrics;
import javax.net.ssl.SSLPeerUnverifiedException;
import javax.net.ssl.SSLSession;
import javax.security.cert.X509Certificate;
import java.io.IOException;
import java.io.RandomAccessFile;
import java.net.InetSocketAddress;
import static io.vertx.core.spi.metrics.Metrics.METRICS_ENABLED;
/**
 * Base class for connections backed by a Netty {@link ChannelHandlerContext} and bound to a
 * Vert.x {@link ContextInternal}.
 * <p>
 * Provides write batching and flushing, close tracking, exception/close handler dispatch,
 * byte-count metrics reporting, file sending and SSL / address accessors. Writes and closes
 * requested from a thread other than the channel's event loop are re-dispatched onto the
 * channel's executor. Handler and metric accessors are {@code synchronized}.
 */
public abstract class ConnectionBase {

  /** Low 12 bits of a byte counter: amounts below 4096 are carried over rather than reported. */
  private static final long METRICS_REPORTED_BYTES_LOW_MASK = 0xFFF;
  /** Remaining high bits: the portion of a byte counter that is reported to metrics (multiples of 4096). */
  private static final long METRICS_REPORTED_BYTES_HIGH_MASK = ~METRICS_REPORTED_BYTES_LOW_MASK;

  public static final VertxException CLOSED_EXCEPTION = new VertxException("Connection was closed", true);
  /** Channel attribute that, when present, supplies the remote address instead of the channel's own. */
  public static final AttributeKey<SocketAddress> REMOTE_ADDRESS_OVERRIDE = AttributeKey.valueOf("RemoteAddressOverride");
  /** Channel attribute that, when present, supplies the local address instead of the channel's own. */
  public static final AttributeKey<SocketAddress> LOCAL_ADDRESS_OVERRIDE = AttributeKey.valueOf("LocalAddressOverride");
  private static final Logger log = LoggerFactory.getLogger(ConnectionBase.class);
  /** Largest {@link DefaultFileRegion} written in one go by {@link #sendFile}: 1 MiB. */
  private static final int MAX_REGION_SIZE = 1024 * 1024;

  public final VoidChannelPromise voidPromise;
  protected final VertxInternal vertx;
  protected final ChannelHandlerContext chctx;
  protected final ContextInternal context;
  private Handler<Throwable> exceptionHandler;
  private Handler<Void> closeHandler;
  /** Number of writes queued on the executor and not yet performed; guarded by {@code this}. */
  private int writeInProgress;
  private Object metric;
  private SocketAddress remoteAddress;
  private SocketAddress localAddress;
  /** Netty promise completed by {@link #handleClosed()}. */
  private ChannelPromise closePromise;
  /** Vert.x view of {@link #closePromise}, completed on {@link #context}. */
  private Future<Void> closeFuture;
  /** Bytes read but not yet reported to metrics. */
  private long remainingBytesRead;
  /** Bytes written but not yet reported to metrics. */
  private long remainingBytesWritten;
  /** True while an inbound read is being processed; non-forced writes are then not flushed immediately. */
  private boolean read;
  /** True when a write was issued without a flush and still needs one at the end of the read. */
  private boolean needsFlush;
  private boolean closed;

  /**
   * Creates the connection and wires the channel close promise to a Vert.x future that invokes
   * the close handler on completion.
   *
   * @param context the Vert.x context the connection belongs to
   * @param chctx   the Netty handler context of the underlying channel
   */
  protected ConnectionBase(ContextInternal context, ChannelHandlerContext chctx) {
    this.vertx = context.owner();
    this.chctx = chctx;
    this.context = context;
    this.voidPromise = new VoidChannelPromise(chctx.channel(), false);
    this.closePromise = chctx.newPromise();
    PromiseInternal<Void> p = context.promise();
    closePromise.addListener(p);
    closeFuture = p.future();
    closeFuture.onComplete(this::checkCloseHandler);
  }

  /**
   * @return a future completed once the connection has been closed
   */
  public Future<Void> closeFuture() {
    return closeFuture;
  }

  /**
   * Fails the connection by firing {@code error} through the channel pipeline as an exception.
   *
   * @param error the failure to propagate
   */
  public void fail(Throwable error) {
    chctx.pipeline().fireExceptionCaught(error);
  }

  /**
   * Closes the connection and completes {@code promise} with the outcome of the close promise.
   *
   * @param promise the Netty promise to notify
   */
  void close(ChannelPromise promise) {
    closePromise.addListener(l -> {
      if (l.isSuccess()) {
        promise.setSuccess();
      } else {
        promise.setFailure(l.cause());
      }
    });
    close();
  }

  /**
   * Ends the current read and flushes any writes that were deferred while it was in progress.
   */
  final void endReadAndFlush() {
    if (read) {
      read = false;
      if (needsFlush) {
        needsFlush = false;
        chctx.flush();
      }
    }
  }

  /**
   * Handles an inbound message. Marks a read as in progress, so writes made from the handler
   * are batched until {@link #endReadAndFlush()}, then reports metrics and dispatches to
   * {@link #handleMessage(Object)} unless the connection is closed.
   *
   * @param msg the inbound message
   */
  final void read(Object msg) {
    read = true;
    if (!closed) {
      if (METRICS_ENABLED) {
        reportBytesRead(msg);
      }
      handleMessage(msg);
    }
  }

  /**
   * Performs a write on the event loop.
   *
   * @param msg     the message to write
   * @param flush   {@code true}/{@code false} to force flushing or not; {@code null} to flush only
   *                when no read is in progress
   * @param promise the promise notified of the write outcome
   */
  private void write(Object msg, Boolean flush, ChannelPromise promise) {
    if (METRICS_ENABLED) {
      reportsBytesWritten(msg);
    }
    boolean writeAndFlush;
    if (flush == null) {
      writeAndFlush = !read;
    } else {
      writeAndFlush = flush;
    }
    needsFlush = !writeAndFlush;
    if (writeAndFlush) {
      chctx.writeAndFlush(msg, promise);
    } else {
      chctx.write(msg, promise);
    }
  }

  /**
   * Closes the connection on the event loop. An empty buffer is written and flushed first, then
   * the channel is closed whatever the outcome of that flush. Only the first call does this;
   * later calls complete {@code promise} immediately.
   *
   * @param promise completed with the channel close outcome
   */
  private void writeClose(PromiseInternal<Void> promise) {
    if (closed) {
      promise.complete();
      return;
    }
    closed = true;
    ChannelPromise channelPromise = chctx
      .newPromise()
      .addListener((ChannelFutureListener) f -> {
        chctx.close().addListener(promise);
      });
    writeToChannel(Unpooled.EMPTY_BUFFER, true, channelPromise);
  }

  /**
   * @return a new channel promise with {@code handler} registered as a listener
   */
  private ChannelPromise wrap(FutureListener<Void> handler) {
    ChannelPromise promise = chctx.newPromise();
    promise.addListener(handler);
    return promise;
  }

  /**
   * Writes {@code msg}, notifying {@code listener} (if non-null) of the outcome.
   *
   * @param msg      the message to write
   * @param listener the completion listener, or {@code null} to use the void promise
   */
  public final void writeToChannel(Object msg, FutureListener<Void> listener) {
    writeToChannel(msg, listener == null ? voidPromise : wrap(listener));
  }

  /**
   * Writes {@code msg} without forcing a flush.
   *
   * @param msg     the message to write
   * @param promise the promise notified of the write outcome
   */
  public final void writeToChannel(Object msg, ChannelPromise promise) {
    writeToChannel(msg, false, promise);
  }

  /**
   * Writes {@code msg}. The write happens directly when called on the event loop with no queued
   * writes pending. Otherwise it is queued on the channel's executor, which keeps it ordered after
   * earlier queued writes.
   *
   * @param msg        the message to write
   * @param forceFlush whether the write must be flushed right away
   * @param promise    the promise notified of the write outcome
   */
  public final void writeToChannel(Object msg, boolean forceFlush, ChannelPromise promise) {
    synchronized (this) {
      if (!chctx.executor().inEventLoop() || writeInProgress > 0) {
        queueForWrite(msg, forceFlush, promise);
        return;
      }
    }
    write(msg, forceFlush ? true : null, promise);
  }

  /**
   * Queues a write on the channel's executor. Must be called while holding the lock on
   * {@code this}. The queued write flushes when it is the last pending queued write, or when
   * {@code forceFlush} is set.
   */
  private void queueForWrite(Object msg, boolean forceFlush, ChannelPromise promise) {
    writeInProgress++;
    chctx.executor().execute(() -> {
      boolean flush;
      synchronized (ConnectionBase.this) {
        // Always release this write's slot, forced flush or not. Otherwise the counter never
        // returns to zero, every later event-loop write gets queued, and non-forced writes
        // never flush.
        flush = --writeInProgress == 0 || forceFlush;
      }
      write(msg, flush, promise);
    });
  }

  /**
   * Writes {@code obj} without forcing a flush, using the void promise.
   *
   * @param obj the message to write
   */
  public void writeToChannel(Object obj) {
    writeToChannel(obj, voidPromise);
  }

  /**
   * Flushes pending writes, using the void promise.
   */
  public final void flush() {
    flush(voidPromise);
  }

  /**
   * Flushes pending writes by writing an empty buffer with a forced flush. This keeps the flush
   * ordered with respect to queued writes.
   *
   * @param promise the promise notified once the flush write completes
   */
  public final void flush(ChannelPromise promise) {
    writeToChannel(Unpooled.EMPTY_BUFFER, true, promise);
  }

  /**
   * @return {@code true} when the channel reports it is not writable
   */
  public boolean isNotWritable() {
    return !chctx.channel().isWritable();
  }

  /**
   * Closes the connection, running the close on the channel's event loop.
   *
   * @return a future completed once the channel has been closed
   */
  public Future<Void> close() {
    PromiseInternal<Void> promise = context.promise();
    EventExecutor exec = chctx.executor();
    if (exec.inEventLoop()) {
      writeClose(promise);
    } else {
      exec.execute(() -> writeClose(promise));
    }
    return promise.future();
  }

  /**
   * Closes the connection and notifies {@code handler} of the outcome.
   *
   * @param handler the completion handler
   */
  public final void close(Handler<AsyncResult<Void>> handler) {
    close().onComplete(handler);
  }

  /**
   * Sets the handler called once the connection is closed.
   *
   * @param handler the close handler, may be {@code null}
   * @return this connection
   */
  public synchronized ConnectionBase closeHandler(Handler<Void> handler) {
    closeHandler = handler;
    return this;
  }

  /**
   * Sets the handler receiving exceptions raised on this connection.
   *
   * @param handler the exception handler, may be {@code null}
   * @return this connection
   */
  public synchronized ConnectionBase exceptionHandler(Handler<Throwable> handler) {
    this.exceptionHandler = handler;
    return this;
  }

  /**
   * @return the current exception handler, possibly {@code null}
   */
  protected synchronized Handler<Throwable> exceptionHandler() {
    return exceptionHandler;
  }

  /** Stops reading from the channel by disabling auto-read. */
  public void doPause() {
    chctx.channel().config().setAutoRead(false);
  }

  /** Resumes reading from the channel by enabling auto-read. */
  public void doResume() {
    chctx.channel().config().setAutoRead(true);
  }

  /**
   * Sets the channel write buffer water marks to {@code size / 2} (low) and {@code size} (high).
   *
   * @param size the high water mark in bytes
   */
  public void doSetWriteQueueMaxSize(int size) {
    ChannelConfig config = chctx.channel().config();
    config.setWriteBufferWaterMark(new WriteBufferWaterMark(size / 2, size));
  }

  /**
   * @return the underlying Netty channel
   */
  public final Channel channel() {
    return chctx.channel();
  }

  /**
   * @return the Netty handler context of this connection
   */
  public final ChannelHandlerContext channelHandlerContext() {
    return chctx;
  }

  /**
   * @return the Vert.x context of this connection
   */
  public final ContextInternal getContext() {
    return context;
  }

  /**
   * Sets the metrics object associated with this connection.
   *
   * @param metric the metric, may be {@code null}
   */
  public final synchronized void metric(Object metric) {
    this.metric = metric;
  }

  /**
   * @return the metrics object associated with this connection, possibly {@code null}
   */
  public final synchronized Object metric() {
    return metric;
  }

  /**
   * @return the metrics SPI for this connection, or {@code null} when metrics are not collected
   */
  public abstract NetworkMetrics metrics();

  /**
   * Reports {@code t} to metrics (if any), then, on the context, calls the exception handler.
   * When no handler is set, the error is logged: with its stack trace when debug logging is
   * enabled, otherwise message only.
   *
   * @param t the failure
   */
  protected void handleException(Throwable t) {
    NetworkMetrics metrics = metrics();
    if (metrics != null) {
      // metric() reads under the same lock used by the setter.
      metrics.exceptionOccurred(metric(), remoteAddress(), t);
    }
    context.emit(t, err -> {
      Handler<Throwable> handler;
      synchronized (ConnectionBase.this) {
        handler = exceptionHandler;
      }
      if (handler != null) {
        handler.handle(err);
      } else {
        if (log.isDebugEnabled()) {
          log.error(t.getMessage(), t);
        } else {
          log.error(t.getMessage());
        }
      }
    });
  }

  /**
   * Called when the channel has closed. It marks the connection closed, reports any remaining
   * byte counts and (for {@link TCPMetrics}) the disconnection, then completes the close promise.
   */
  protected void handleClosed() {
    closed = true;
    NetworkMetrics metrics = metrics();
    if (metrics != null) {
      flushBytesRead();
      flushBytesWritten();
      if (metrics instanceof TCPMetrics) {
        ((TCPMetrics) metrics).disconnected(metric(), remoteAddress());
      }
    }
    closePromise.setSuccess();
  }

  /**
   * Invokes the close handler, if one is set, once the close future completes.
   */
  private void checkCloseHandler(AsyncResult<Void> ar) {
    Handler<Void> handler;
    synchronized (ConnectionBase.this) {
      handler = closeHandler;
    }
    if (handler != null) {
      handler.handle(null);
    }
  }

  /** Called when the connection is idle; closes the channel. */
  protected void handleIdle() {
    chctx.close();
  }

  protected abstract void handleInterestedOpsChanged();

  /**
   * @return whether {@link #sendFile} may use zero-copy file regions; {@code false} when SSL is in use
   */
  protected boolean supportsFileRegion() {
    return !isSsl();
  }

  /**
   * Hook to report the size of an inbound message; does nothing by default.
   *
   * @param msg the inbound message
   */
  protected void reportBytesRead(Object msg) {
  }

  /**
   * Accumulates read bytes. Multiples of 4096 are reported to metrics; the remainder is carried
   * over until the next call or {@link #flushBytesRead()}.
   *
   * @param numberOfBytes the number of bytes read, must be non-negative
   * @throws IllegalArgumentException if {@code numberOfBytes} is negative
   */
  public void reportBytesRead(long numberOfBytes) {
    if (numberOfBytes < 0L) {
      throw new IllegalArgumentException();
    }
    long bytes = remainingBytesRead;
    bytes += numberOfBytes;
    NetworkMetrics metrics = metrics();
    long val = bytes & METRICS_REPORTED_BYTES_HIGH_MASK;
    if (metrics != null && val > 0) {
      bytes &= METRICS_REPORTED_BYTES_LOW_MASK;
      metrics.bytesRead(metric(), remoteAddress(), val);
    }
    remainingBytesRead = bytes;
  }

  /**
   * Hook to report the size of an outbound message; does nothing by default.
   *
   * @param msg the outbound message
   */
  protected void reportsBytesWritten(Object msg) {
  }

  /**
   * Accumulates written bytes. Multiples of 4096 are reported to metrics; the remainder is
   * carried over until the next call or {@link #flushBytesWritten()}.
   *
   * @param numberOfBytes the number of bytes written, must be non-negative
   * @throws IllegalArgumentException if {@code numberOfBytes} is negative
   */
  public void reportBytesWritten(long numberOfBytes) {
    if (numberOfBytes < 0L) {
      throw new IllegalArgumentException();
    }
    long bytes = remainingBytesWritten;
    bytes += numberOfBytes;
    NetworkMetrics metrics = metrics();
    long val = bytes & METRICS_REPORTED_BYTES_HIGH_MASK;
    if (metrics != null && val > 0) {
      bytes &= METRICS_REPORTED_BYTES_LOW_MASK;
      metrics.bytesWritten(metric(), remoteAddress(), val);
    }
    remainingBytesWritten = bytes;
  }

  /**
   * Reports any carried-over read bytes to metrics and resets the counter.
   */
  public void flushBytesRead() {
    long val = remainingBytesRead;
    if (val > 0L) {
      NetworkMetrics metrics = metrics();
      remainingBytesRead = 0L;
      if (metrics != null) {
        metrics.bytesRead(metric(), remoteAddress(), val);
      }
    }
  }

  /**
   * Reports any carried-over written bytes to metrics and resets the counter.
   */
  public void flushBytesWritten() {
    long val = remainingBytesWritten;
    if (val > 0L) {
      NetworkMetrics metrics = metrics();
      remainingBytesWritten = 0L;
      if (metrics != null) {
        metrics.bytesWritten(metric(), remoteAddress(), val);
      }
    }
  }

  /**
   * Writes {@code length} bytes of {@code file} starting at {@code offset} as a sequence of
   * {@link DefaultFileRegion}s of at most {@link #MAX_REGION_SIZE} bytes. Each region is written
   * after the previous one succeeds.
   *
   * @param writeFuture completed by the last region's write, or failed on the first failure
   */
  private void sendFileRegion(RandomAccessFile file, long offset, long length, ChannelPromise writeFuture) {
    if (length < MAX_REGION_SIZE) {
      writeToChannel(new DefaultFileRegion(file.getChannel(), offset, length), writeFuture);
    } else {
      ChannelPromise promise = chctx.newPromise();
      FileRegion region = new DefaultFileRegion(file.getChannel(), offset, MAX_REGION_SIZE);
      // NOTE(review): presumably retained so the release after this write does not close the
      // shared FileChannel that later regions still read from; the file is closed by sendFile.
      region.retain();
      writeToChannel(region, promise);
      promise.addListener(future -> {
        if (future.isSuccess()) {
          sendFileRegion(file, offset + MAX_REGION_SIZE, length - MAX_REGION_SIZE, writeFuture);
        } else {
          log.error(future.cause().getMessage(), future.cause());
          writeFuture.setFailure(future.cause());
        }
      });
    }
  }

  /**
   * Sends {@code length} bytes of {@code raf} starting at {@code offset}. Uses zero-copy file
   * regions when {@link #supportsFileRegion()} allows it, and a {@link ChunkedFile} with
   * 8192-byte chunks otherwise. {@code raf} is closed when the returned future completes, or
   * immediately if starting the transfer throws.
   *
   * @param raf    the file to send
   * @param offset the start offset in the file
   * @param length the number of bytes to send
   * @return a future completed when the transfer finishes
   * @throws IOException if the chunked transfer cannot be set up
   */
  public final ChannelFuture sendFile(RandomAccessFile raf, long offset, long length) throws IOException {
    ChannelPromise writeFuture = chctx.newPromise();
    try {
      if (!supportsFileRegion()) {
        writeToChannel(new ChunkedFile(raf, offset, length, 8192), writeFuture);
      } else {
        sendFileRegion(raf, offset, length, writeFuture);
      }
    } catch (IOException | RuntimeException e) {
      // The completion listener below is never registered on this path, so close the file here.
      try {
        raf.close();
      } catch (IOException closeFailure) {
        e.addSuppressed(closeFailure);
      }
      throw e;
    }
    writeFuture.addListener(fut -> raf.close());
    return writeFuture;
  }

  /**
   * @return whether an {@link SslHandler} is present in the channel pipeline
   */
  public boolean isSsl() {
    return chctx.pipeline().get(SslHandler.class) != null;
  }

  /**
   * @return the SSL session of the pipeline's {@link SslHandler}, or {@code null} without SSL
   */
  public SSLSession sslSession() {
    ChannelHandlerContext sslHandlerContext = chctx.pipeline().context(SslHandler.class);
    if (sslHandlerContext != null) {
      SslHandler sslHandler = (SslHandler) sslHandlerContext.handler();
      return sslHandler.engine().getSession();
    } else {
      return null;
    }
  }

  /**
   * @return the peer certificate chain of the SSL session, or {@code null} without SSL
   * @throws SSLPeerUnverifiedException if the peer's identity has not been verified
   */
  public X509Certificate[] peerCertificateChain() throws SSLPeerUnverifiedException {
    SSLSession session = sslSession();
    if (session != null) {
      return session.getPeerCertificateChain();
    } else {
      return null;
    }
  }

  /**
   * @return the server name recorded on the channel by {@link SslHandshakeCompletionHandler},
   *         or {@code null} if none
   */
  public String indicatedServerName() {
    if (chctx.channel().hasAttr(SslHandshakeCompletionHandler.SERVER_NAME_ATTR)) {
      return chctx.channel().attr(SslHandshakeCompletionHandler.SERVER_NAME_ATTR).get();
    } else {
      return null;
    }
  }

  /**
   * @return a new promise for this connection's channel
   */
  public ChannelPromise channelFuture() {
    return chctx.newPromise();
  }

  /**
   * @return the host string of the remote address when it is an {@link InetSocketAddress},
   *         otherwise {@code null}
   */
  public String remoteName() {
    java.net.SocketAddress addr = chctx.channel().remoteAddress();
    if (addr instanceof InetSocketAddress) {
      return ((InetSocketAddress) addr).getHostString();
    }
    return null;
  }

  /**
   * Returns the remote address, resolving it on first use and caching any non-null result. A
   * {@link #REMOTE_ADDRESS_OVERRIDE} channel attribute takes precedence and is consumed (reset to
   * {@code null}) when read.
   *
   * @return the remote address, or {@code null} if unknown
   */
  public SocketAddress remoteAddress() {
    SocketAddress address = remoteAddress;
    if (address == null) {
      if (chctx.channel().hasAttr(REMOTE_ADDRESS_OVERRIDE)) {
        address = chctx.channel().attr(REMOTE_ADDRESS_OVERRIDE).getAndSet(null);
      } else {
        java.net.SocketAddress addr = chctx.channel().remoteAddress();
        if (addr != null) {
          address = vertx.transport().convert(addr);
        }
      }
      if (address != null) {
        remoteAddress = address;
      }
    }
    return address;
  }

  /**
   * Returns the local address, resolving it on first use and caching any non-null result. A
   * {@link #LOCAL_ADDRESS_OVERRIDE} channel attribute takes precedence and is consumed (reset to
   * {@code null}) when read.
   *
   * @return the local address, or {@code null} if unknown
   */
  public SocketAddress localAddress() {
    SocketAddress address = localAddress;
    if (address == null) {
      if (chctx.channel().hasAttr(LOCAL_ADDRESS_OVERRIDE)) {
        address = chctx.channel().attr(LOCAL_ADDRESS_OVERRIDE).getAndSet(null);
      } else {
        java.net.SocketAddress addr = chctx.channel().localAddress();
        if (addr != null) {
          address = vertx.transport().convert(addr);
        }
      }
      if (address != null) {
        localAddress = address;
      }
    }
    return address;
  }

  /**
   * Handles an inbound message while the connection is open; does nothing by default.
   *
   * @param msg the inbound message
   */
  protected void handleMessage(Object msg) {
  }
}