package io.vertx.core.net.impl;
import io.netty.bootstrap.ServerBootstrap;
import io.netty.channel.*;
import io.netty.channel.group.ChannelGroup;
import io.netty.channel.group.ChannelGroupFuture;
import io.netty.channel.group.DefaultChannelGroup;
import io.netty.handler.logging.LoggingHandler;
import io.netty.handler.ssl.SniHandler;
import io.netty.handler.ssl.SslHandler;
import io.netty.handler.stream.ChunkedWriteHandler;
import io.netty.handler.timeout.IdleStateHandler;
import io.netty.util.concurrent.GlobalEventExecutor;
import io.vertx.core.*;
import io.vertx.core.impl.ContextInternal;
import io.vertx.core.impl.VertxInternal;
import io.vertx.core.logging.Logger;
import io.vertx.core.logging.LoggerFactory;
import io.vertx.core.net.NetServer;
import io.vertx.core.net.NetServerOptions;
import io.vertx.core.net.NetSocket;
import io.vertx.core.net.SocketAddress;
import io.vertx.core.spi.metrics.Metrics;
import io.vertx.core.spi.metrics.MetricsProvider;
import io.vertx.core.spi.metrics.TCPMetrics;
import io.vertx.core.spi.metrics.VertxMetrics;
import io.vertx.core.streams.ReadStream;
import java.net.InetSocketAddress;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.stream.Collectors;
/**
 * {@link NetServer} implementation that binds a TCP or domain socket through a Netty
 * {@link ServerBootstrap}.
 * <p>
 * Instances listening on the same {@link ServerID} share a single bound "actual" server, looked up
 * in {@code vertx.sharedNetServers()}. Each instance registers its connect and exception handlers
 * with the actual server's {@link HandlerManager}. That manager picks a handler for every accepted
 * channel based on the channel's event loop. The actual server's channels are only closed once its
 * last registered handler has been removed.
 */
public class NetServerImpl implements Closeable, MetricsProvider, NetServer {

  private static final Logger log = LoggerFactory.getLogger(NetServerImpl.class);

  protected final VertxInternal vertx;
  protected final NetServerOptions options;
  protected final ContextInternal creatingContext;
  protected final SSLHelper sslHelper;
  protected final boolean logEnabled;
  // Sockets accepted by this server, keyed by Netty channel; all are closed by actualClose().
  private final Map<Channel, NetSocketImpl> socketMap = new ConcurrentHashMap<>();
  private final VertxEventLoopGroup availableWorkers = new VertxEventLoopGroup();
  private final HandlerManager<Handlers> handlerManager = new HandlerManager<>(availableWorkers);
  private final NetSocketStream connectStream = new NetSocketStream();
  private ChannelGroup serverChannelGroup;
  // Remaining connections that may be accepted; Long.MAX_VALUE means "unbounded" (flowing mode).
  private long demand = Long.MAX_VALUE;
  private volatile boolean listening;
  private Handler<NetSocket> registeredHandler;
  private volatile ServerID id;
  // Server that owns the bound channel: either this instance or a shared one bound earlier.
  private NetServerImpl actualServer;
  private AsyncResolveConnectHelper bindFuture;
  private volatile int actualPort;
  private ContextInternal listenContext;
  private TCPMetrics metrics;
  private Handler<NetSocket> handler;
  private Handler<Void> endHandler;
  private Handler<Throwable> exceptionHandler;

  /**
   * Creates a server from a defensive copy of {@code options}.
   *
   * @throws IllegalStateException when created from a multi-threaded worker context
   */
  public NetServerImpl(VertxInternal vertx, NetServerOptions options) {
    this.vertx = vertx;
    this.options = new NetServerOptions(options);
    this.sslHelper = new SSLHelper(options, options.getKeyCertOptions(), options.getTrustOptions());
    this.creatingContext = vertx.getContext();
    this.logEnabled = options.getLogActivity();
    if (creatingContext != null) {
      if (creatingContext.isMultiThreadedWorkerContext()) {
        throw new IllegalStateException("Cannot use NetServer in a multi-threaded worker verticle");
      }
      // Ensure the server is closed when the creating context (e.g. its verticle) is closed.
      creatingContext.addCloseHook(this);
    }
  }

  /** Stops accepting new connections until demand is restored. */
  private synchronized void pauseAccepting() {
    demand = 0L;
  }

  /** Switches back to accepting every incoming connection. */
  private synchronized void resumeAccepting() {
    demand = Long.MAX_VALUE;
  }

  /**
   * Allows {@code amount} more connections to be accepted; non-positive amounts are ignored and an
   * overflowing sum saturates at {@link Long#MAX_VALUE}.
   */
  private synchronized void fetchAccepting(long amount) {
    if (amount > 0L) {
      demand += amount;
      if (demand < 0L) {
        demand = Long.MAX_VALUE;
      }
    }
  }

  /**
   * Decides whether an incoming connection may be accepted, consuming one unit of demand unless
   * demand is unbounded.
   *
   * @return {@code true} if the connection should be accepted
   */
  protected synchronized boolean accept() {
    boolean allowed = demand > 0L;
    if (allowed && demand != Long.MAX_VALUE) {
      demand--;
    }
    return allowed;
  }

  protected boolean isListening() {
    return listening;
  }

  @Override
  public synchronized Handler<NetSocket> connectHandler() {
    return handler;
  }

  @Override
  public synchronized NetServer connectHandler(Handler<NetSocket> handler) {
    if (isListening()) {
      throw new IllegalStateException("Cannot set connectHandler when server is listening");
    }
    this.handler = handler;
    return this;
  }

  @Override
  public synchronized NetServer exceptionHandler(Handler<Throwable> handler) {
    if (isListening()) {
      throw new IllegalStateException("Cannot set exceptionHandler when server is listening");
    }
    this.exceptionHandler = handler;
    return this;
  }

  /**
   * Adds the optional pipeline handlers to a newly connected channel: activity logging, a chunked
   * writer when SSL is enabled, and an idle-state handler when an idle timeout is configured.
   */
  protected void initChannel(ChannelPipeline pipeline) {
    if (logEnabled) {
      pipeline.addLast("logging", new LoggingHandler());
    }
    if (sslHelper.isSSL()) {
      pipeline.addLast("chunkedWriter", new ChunkedWriteHandler());
    }
    if (options.getIdleTimeout() > 0) {
      pipeline.addLast("idle", new IdleStateHandler(0, 0, options.getIdleTimeout(), options.getIdleTimeoutUnit()));
    }
  }

  /**
   * Starts listening on {@code socketAddress}.
   * <p>
   * If another server is already registered under the same {@link ServerID} (and the port is not
   * 0), this instance's handlers are added to that shared server. Otherwise a new channel is bound.
   *
   * @param handler       connect handler invoked for every accepted socket; must not be null
   * @param socketAddress address (host/port or domain-socket path) to listen on
   * @param listenHandler notified on the listen context once the bind completes; may be null, in
   *                      which case bind failures are only logged
   * @throws IllegalStateException if {@code handler} is null or listen was already called
   */
  public synchronized void listen(Handler<NetSocket> handler, SocketAddress socketAddress, Handler<AsyncResult<Void>> listenHandler) {
    if (handler == null) {
      throw new IllegalStateException("Set connect handler first");
    }
    if (listening) {
      throw new IllegalStateException("Listen already called");
    }
    listening = true;
    listenContext = vertx.getOrCreateContext();
    registeredHandler = handler;

    synchronized (vertx.sharedNetServers()) {
      this.actualPort = socketAddress.port();
      String hostOrPath = socketAddress.host() != null ? socketAddress.host() : socketAddress.path();
      id = new ServerID(actualPort, hostOrPath);
      NetServerImpl shared = vertx.sharedNetServers().get(id);
      // Port 0 means "any free port", so such a server can never be shared.
      if (shared == null || actualPort == 0) {
        serverChannelGroup = new DefaultChannelGroup("vertx-acceptor-channels", GlobalEventExecutor.INSTANCE);
        ServerBootstrap bootstrap = new ServerBootstrap();
        bootstrap.group(availableWorkers);
        try {
          sslHelper.validate(vertx);
        } catch (RuntimeException e) {
          // Do not leave the server flagged as listening without an actual server behind it;
          // callers still receive the validation exception as before.
          listening = false;
          throw e;
        }

        bootstrap.childHandler(new ChannelInitializer<Channel>() {
          @Override
          protected void initChannel(Channel ch) throws Exception {
            if (!accept()) {
              ch.close();
              return;
            }
            HandlerHolder<Handlers> holder = handlerManager.chooseHandler(ch.eventLoop());
            if (holder != null) {
              if (sslHelper.isSSL()) {
                // The socket is only handed to user code once the TLS handshake has succeeded.
                ch.pipeline().addFirst("handshaker", new SslHandshakeCompletionHandler(ar -> {
                  if (ar.succeeded()) {
                    connected(holder, ch);
                  } else {
                    Handler<Throwable> onFailure = holder.handler.exceptionHandler;
                    if (onFailure != null) {
                      holder.context.executeFromIO(v -> {
                        onFailure.handle(ar.cause());
                      });
                    } else {
                      log.error("Client from origin " + ch.remoteAddress() + " failed to connect over ssl: " + ar.cause());
                    }
                  }
                }));
                // "ssl" is added first so it sits ahead of the handshake-completion handler.
                if (options.isSni()) {
                  SniHandler sniHandler = new SniHandler(sslHelper.serverNameMapper(vertx));
                  ch.pipeline().addFirst("ssl", sniHandler);
                } else {
                  SslHandler sslHandler = new SslHandler(sslHelper.createEngine(vertx));
                  sslHandler.setHandshakeTimeout(sslHelper.getSslHandshakeTimeout(), sslHelper.getSslHandshakeTimeoutUnit());
                  ch.pipeline().addFirst("ssl", sslHandler);
                }
              } else {
                connected(holder, ch);
              }
            }
          }
        });

        applyConnectionOptions(socketAddress.path() != null, bootstrap);
        handlerManager.addHandler(new Handlers(this, handler, exceptionHandler), listenContext);

        try {
          bindFuture = AsyncResolveConnectHelper.doBind(vertx, socketAddress, bootstrap);
          bindFuture.addListener(res -> {
            if (res.succeeded()) {
              Channel ch = res.result();
              log.trace("Net server listening on " + (hostOrPath) + ":" + ch.localAddress());
              // Replace a wildcard port (0) with the port actually bound; -1 is left untouched.
              if (NetServerImpl.this.actualPort != -1) {
                NetServerImpl.this.actualPort = ((InetSocketAddress) ch.localAddress()).getPort();
              }
              NetServerImpl.this.id = new ServerID(NetServerImpl.this.actualPort, id.host);
              serverChannelGroup.add(ch);
              vertx.sharedNetServers().put(id, NetServerImpl.this);
              VertxMetrics metricsSPI = vertx.metricsSPI();
              if (metricsSPI != null) {
                this.metrics = metricsSPI.createNetServerMetrics(options, new SocketAddressImpl(id.port, id.host));
              }
            } else {
              vertx.sharedNetServers().remove(id);
            }
          });
        } catch (Throwable t) {
          // Setting up the bind failed synchronously: report it and reset the listening flag.
          if (listenHandler != null) {
            vertx.runOnContext(v -> listenHandler.handle(Future.failedFuture(t)));
          } else {
            log.error(t);
          }
          listening = false;
          return;
        }
        if (actualPort != 0) {
          vertx.sharedNetServers().put(id, this);
        }
        actualServer = this;
      } else {
        // Reuse the already-bound server: register this instance's handlers with it.
        actualServer = shared;
        this.actualPort = shared.actualPort();
        VertxMetrics metrics = vertx.metricsSPI();
        this.metrics = metrics != null ? metrics.createNetServerMetrics(options, new SocketAddressImpl(id.port, id.host)) : null;
        actualServer.handlerManager.addHandler(new Handlers(this, handler, exceptionHandler), listenContext);
      }

      // Report the (possibly shared) bind outcome on the context that called listen.
      actualServer.bindFuture.addListener(res -> {
        if (listenHandler != null) {
          AsyncResult<Void> outcome;
          if (res.succeeded()) {
            outcome = Future.succeededFuture();
          } else {
            listening = false;
            outcome = Future.failedFuture(res.cause());
          }
          listenContext.runOnContext(v -> listenHandler.handle(outcome));
        } else if (res.failed()) {
          log.error("Failed to listen", res.cause());
          listening = false;
        }
      });
    }
  }

  /** Closes the server without a completion handler. */
  public synchronized void close() {
    close(null);
  }

  @Override
  public NetServer listen(int port, String host) {
    return listen(port, host, null);
  }

  @Override
  public NetServer listen(int port) {
    return listen(port, "0.0.0.0", null);
  }

  @Override
  public NetServer listen(int port, Handler<AsyncResult<NetServer>> listenHandler) {
    return listen(port, "0.0.0.0", listenHandler);
  }

  @Override
  public NetServer listen(SocketAddress localAddress) {
    return listen(localAddress, null);
  }

  @Override
  public synchronized NetServer listen(SocketAddress localAddress, Handler<AsyncResult<NetServer>> listenHandler) {
    listen(handler, localAddress, ar -> {
      if (listenHandler != null) {
        listenHandler.handle(ar.map(this));
      }
    });
    return this;
  }

  @Override
  public NetServer listen() {
    listen((Handler<AsyncResult<NetServer>>) null);
    return this;
  }

  @Override
  public NetServer listen(int port, String host, Handler<AsyncResult<NetServer>> listenHandler) {
    return listen(SocketAddress.inetSocketAddress(port, host), listenHandler);
  }

  @Override
  public synchronized NetServer listen(Handler<AsyncResult<NetServer>> listenHandler) {
    return listen(options.getPort(), options.getHost(), listenHandler);
  }

  @Override
  public ReadStream<NetSocket> connectStream() {
    return connectStream;
  }

  /**
   * Closes every server whose handlers are registered with this server's handler manager, and
   * completes {@code handler} once all of those closes have finished.
   */
  public void closeAll(Handler<AsyncResult<Void>> handler) {
    List<Handlers> list = handlerManager.handlers();
    List<Future> futures = list.stream()
      .<Future<Void>>map(handlers -> Future.future(handlers.server::close))
      .collect(Collectors.toList());
    CompositeFuture fut = CompositeFuture.all(futures);
    fut.setHandler(ar -> handler.handle(ar.mapEmpty()));
  }

  /**
   * Stops this instance from listening.
   * <p>
   * This instance's handlers are removed from the actual server. The actual server is closed only
   * if no handlers remain. The end handler (if any) runs on success, followed by
   * {@code completionHandler}. Both run on the context that called close.
   *
   * @param completionHandler notified when closing completes; may be null
   */
  @Override
  public synchronized void close(Handler<AsyncResult<Void>> completionHandler) {
    if (creatingContext != null) {
      creatingContext.removeCloseHook(this);
    }
    Handler<AsyncResult<Void>> done;
    if (endHandler != null) {
      Handler<Void> onEnd = endHandler;
      endHandler = null;
      done = event -> {
        if (event.succeeded()) {
          onEnd.handle(event.result());
        }
        if (completionHandler != null) {
          completionHandler.handle(event);
        }
      };
    } else {
      done = completionHandler;
    }
    ContextInternal context = vertx.getOrCreateContext();
    if (!listening) {
      if (done != null) {
        executeCloseDone(context, done, null);
      }
      return;
    }
    listening = false;
    synchronized (vertx.sharedNetServers()) {
      if (actualServer != null) {
        actualServer.handlerManager.removeHandler(new Handlers(this, registeredHandler, exceptionHandler), listenContext);
        if (actualServer.handlerManager.hasHandlers()) {
          // Other instances still use the actual server, so keep it bound.
          if (done != null) {
            executeCloseDone(context, done, null);
          }
        } else {
          // Last handler gone: close the actual server, completing on the caller's context.
          actualServer.actualClose(context, done);
        }
      } else {
        // Listening was flagged but no actual server exists; nothing to close. executeCloseDone
        // tolerates a null 'done' (the original code dereferenced it unconditionally).
        executeCloseDone(context, done, null);
      }
    }
  }

  public synchronized boolean isClosed() {
    return !listening;
  }

  public synchronized int actualPort() {
    return actualPort;
  }

  @Override
  public boolean isMetricsEnabled() {
    return metrics != null;
  }

  @Override
  public Metrics getMetrics() {
    return metrics;
  }

  /**
   * Shuts down the bound server.
   * <p>
   * The server is removed from the shared registry and every accepted socket is closed, followed by
   * the server channel group. After that the metrics are closed and {@code done} is completed on
   * {@code closeContext}.
   *
   * @throws IllegalStateException if closing the sockets changed the current context
   */
  private void actualClose(ContextInternal closeContext, Handler<AsyncResult<Void>> done) {
    if (id != null) {
      vertx.sharedNetServers().remove(id);
    }
    ContextInternal contextBefore = vertx.getContext();
    for (NetSocketImpl sock : socketMap.values()) {
      sock.close();
    }
    // Sanity check: closing sockets must not switch the calling thread's context.
    if (vertx.getContext() != contextBefore) {
      throw new IllegalStateException("Context was changed");
    }
    ChannelGroupFuture fut = serverChannelGroup.close();
    fut.addListener(cg -> {
      if (metrics != null) {
        metrics.close();
      }
      executeCloseDone(closeContext, done, fut.cause());
    });
  }

  /**
   * Wires a newly accepted (and, for SSL, handshaken) channel into a {@link NetSocketImpl}.
   * <p>
   * The socket is tracked in {@code socketMap} while it is connected. Then, on the chosen handler's
   * context, metrics are reported, the event bus handler is registered and the user's connect
   * handler is invoked.
   */
  private void connected(HandlerHolder<Handlers> handler, Channel ch) {
    NetServerImpl.this.initChannel(ch.pipeline());
    VertxHandler<NetSocketImpl> nh = VertxHandler.<NetSocketImpl>create(handler.context, ctx -> new NetSocketImpl(vertx, ctx, handler.context, sslHelper, metrics));
    nh.addHandler(conn -> socketMap.put(ch, conn));
    nh.removeHandler(conn -> socketMap.remove(ch));
    ch.pipeline().addLast("handler", nh);
    NetSocketImpl sock = nh.getConnection();
    handler.context.executeFromIO(v -> {
      if (metrics != null) {
        sock.metric(metrics.connected(sock.remoteAddress(), sock.remoteName()));
      }
      sock.registerEventBusHandler();
      handler.handler.connectionHandler.handle(sock);
    });
  }

  /**
   * Completes {@code done} on {@code closeContext}: success when {@code e} is null, failure
   * otherwise. Does nothing when {@code done} is null.
   */
  private void executeCloseDone(ContextInternal closeContext, Handler<AsyncResult<Void>> done, Exception e) {
    if (done != null) {
      Future<Void> fut = e == null ? Future.succeededFuture() : Future.failedFuture(e);
      closeContext.runOnContext(v -> done.handle(fut));
    }
  }

  /** Delegates socket/bootstrap option configuration to the Vert.x transport. */
  private void applyConnectionOptions(boolean domainSocket, ServerBootstrap bootstrap) {
    vertx.transport().configure(options, domainSocket, bootstrap);
  }

  @Override
  protected void finalize() throws Throwable {
    // Best-effort release of the server if the user never closed it.
    close();
    super.finalize();
  }

  /**
   * {@link ReadStream} view of accepted sockets.
   * <p>
   * Pause, resume and fetch control how many connections the server accepts. Setting the handler is
   * the same as setting the connect handler.
   */
  private class NetSocketStream implements ReadStream<NetSocket> {

    @Override
    public NetSocketStream handler(Handler<NetSocket> handler) {
      connectHandler(handler);
      return this;
    }

    @Override
    public NetSocketStream pause() {
      pauseAccepting();
      return this;
    }

    @Override
    public NetSocketStream resume() {
      resumeAccepting();
      return this;
    }

    @Override
    public ReadStream<NetSocket> fetch(long amount) {
      fetchAccepting(amount);
      return this;
    }

    @Override
    public NetSocketStream endHandler(Handler<Void> handler) {
      synchronized (NetServerImpl.this) {
        endHandler = handler;
        return this;
      }
    }

    /** The supplied handler is ignored; this stream does not report exceptions. */
    @Override
    public NetSocketStream exceptionHandler(Handler<Throwable> handler) {
      return this;
    }
  }

  /**
   * Connect and exception handlers registered by one server instance.
   * <p>
   * Equality and hashing deliberately consider only the two handlers, not {@code server}, so an
   * equal instance can be used to remove a registration.
   */
  static class Handlers {
    final NetServer server;
    final Handler<NetSocket> connectionHandler;
    final Handler<Throwable> exceptionHandler;

    public Handlers(NetServer server, Handler<NetSocket> connectionHandler, Handler<Throwable> exceptionHandler) {
      this.server = server;
      this.connectionHandler = connectionHandler;
      this.exceptionHandler = exceptionHandler;
    }

    @Override
    public boolean equals(Object o) {
      if (this == o) return true;
      if (o == null || getClass() != o.getClass()) return false;
      Handlers that = (Handlers) o;
      return Objects.equals(connectionHandler, that.connectionHandler)
        && Objects.equals(exceptionHandler, that.exceptionHandler);
    }

    @Override
    public int hashCode() {
      int result = 0;
      if (connectionHandler != null) {
        result = 31 * result + connectionHandler.hashCode();
      }
      if (exceptionHandler != null) {
        result = 31 * result + exceptionHandler.hashCode();
      }
      return result;
    }
  }
}