/*
* Copyright (c) 2011-2019 Contributors to the Eclipse Foundation
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0, or the Apache License, Version 2.0
* which is available at https://www.apache.org/licenses/LICENSE-2.0.
*
* SPDX-License-Identifier: EPL-2.0 OR Apache-2.0
*/
package io.vertx.core.http.impl;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelPromise;
import io.netty.handler.codec.http.websocketx.BinaryWebSocketFrame;
import io.netty.handler.codec.http.websocketx.CloseWebSocketFrame;
import io.netty.handler.codec.http.websocketx.ContinuationWebSocketFrame;
import io.netty.handler.codec.http.websocketx.PingWebSocketFrame;
import io.netty.handler.codec.http.websocketx.PongWebSocketFrame;
import io.netty.handler.codec.http.websocketx.TextWebSocketFrame;
import io.vertx.codegen.annotations.Nullable;
import io.vertx.core.AsyncResult;
import io.vertx.core.Promise;
import io.vertx.core.Future;
import io.vertx.core.Handler;
import io.vertx.core.MultiMap;
import io.vertx.core.buffer.Buffer;
import io.vertx.core.eventbus.EventBus;
import io.vertx.core.eventbus.Message;
import io.vertx.core.eventbus.MessageConsumer;
import io.vertx.core.http.HttpConnection;
import io.vertx.core.http.WebSocketBase;
import io.vertx.core.http.WebSocketFrame;
import io.vertx.core.http.impl.ws.WebSocketFrameImpl;
import io.vertx.core.http.impl.ws.WebSocketFrameInternal;
import io.vertx.core.impl.ContextInternal;
import io.vertx.core.impl.future.PromiseInternal;
import io.vertx.core.net.SocketAddress;
import io.vertx.core.net.impl.ConnectionBase;
import io.vertx.core.streams.impl.InboundBuffer;
import javax.net.ssl.SSLPeerUnverifiedException;
import javax.net.ssl.SSLSession;
import javax.security.cert.X509Certificate;
import java.util.UUID;
import static io.vertx.core.net.impl.VertxHandler.safeBuffer;
// This class is optimised for performance when used on the same event loop. However, it can be used safely from other threads.
// The internal state is protected using the synchronized keyword. If always used on the same event loop, then
// we benefit from biased locking, which makes the overhead of synchronized near zero.
// Author: Tim Fox. Type parameters: <S> - self return type.
/**
* This class is optimised for performance when used on the same event loop. However it can be used safely from other threads.
* <p>
* The internal state is protected using the synchronized keyword. If always used on the same event loop, then
* we benefit from biased locking which makes the overhead of synchronized near zero.
*
* @author <a href="http://tfox.org">Tim Fox</a>
* @param <S> self return type
*/
public abstract class WebSocketImplBase<S extends WebSocketBase> implements WebSocketInternal {
private final boolean supportsContinuation;
private final String textHandlerID;
private final String binaryHandlerID;
private final int maxWebSocketFrameSize;
private final int maxWebSocketMessageSize;
private final InboundBuffer<WebSocketFrameInternal> pending;
private ChannelHandlerContext chctx;
protected final ContextInternal context;
private MessageConsumer binaryHandlerRegistration;
private MessageConsumer textHandlerRegistration;
private String subProtocol;
private Object metric;
private Handler<Buffer> handler;
private Handler<WebSocketFrameInternal> frameHandler;
private Handler<Buffer> pongHandler;
private Handler<Void> drainHandler;
private Handler<Throwable> exceptionHandler;
private Handler<Void> closeHandler;
private Handler<Void> endHandler;
protected final Http1xConnectionBase conn;
private boolean writable;
private boolean closed;
private Short closeStatusCode;
private String closeReason;
private MultiMap headers;
WebSocketImplBase(ContextInternal context, Http1xConnectionBase conn, boolean supportsContinuation,
int maxWebSocketFrameSize, int maxWebSocketMessageSize) {
this.supportsContinuation = supportsContinuation;
this.textHandlerID = "__vertx.ws." + UUID.randomUUID().toString();
this.binaryHandlerID = "__vertx.ws." + UUID.randomUUID().toString();
this.conn = conn;
this.context = context;
this.maxWebSocketFrameSize = maxWebSocketFrameSize;
this.maxWebSocketMessageSize = maxWebSocketMessageSize;
this.pending = new InboundBuffer<>(context);
this.writable = !conn.isNotWritable();
this.chctx = conn.channelHandlerContext();
pending.handler(this::receiveFrame);
pending.drainHandler(v -> conn.doResume());
}
void registerHandler(EventBus eventBus) {
Handler<Message<Buffer>> binaryHandler = msg -> writeBinaryFrameInternal(msg.body());
Handler<Message<String>> textHandler = msg -> writeTextFrameInternal(msg.body());
binaryHandlerRegistration = eventBus.<Buffer>localConsumer(binaryHandlerID).handler(binaryHandler);
textHandlerRegistration = eventBus.<String>localConsumer(textHandlerID).handler(textHandler);
}
@Override
public ChannelHandlerContext channelHandlerContext() {
return chctx;
}
@Override
public HttpConnection connection() {
return conn;
}
public String binaryHandlerID() {
return binaryHandlerID;
}
public String textHandlerID() {
return textHandlerID;
}
public boolean writeQueueFull() {
synchronized (conn) {
checkClosed();
return conn.isNotWritable();
}
}
@Override
public Future<Void> close() {
return close((short) 1000, (String) null);
}
@Override
public void close(Handler<AsyncResult<Void>> handler) {
Future<Void> future = close();
if (handler != null) {
future.onComplete(handler);
}
}
@Override
public Future<Void> close(short statusCode) {
return close(statusCode, (String) null);
}
@Override
public void close(short statusCode, Handler<AsyncResult<Void>> handler) {
Future<Void> future = close(statusCode, (String) null);
if (handler != null) {
future.onComplete(handler);
}
}
@Override
public void close(short statusCode, @Nullable String reason, Handler<AsyncResult<Void>> handler) {
Future<Void> fut = close(statusCode, reason);
if (handler != null) {
fut.onComplete(handler);
}
}
@Override
public Future<Void> close(short statusCode, String reason) {
boolean sendCloseFrame;
synchronized (conn) {
if (sendCloseFrame = closeStatusCode == null) {
closeStatusCode = statusCode;
closeReason = reason;
}
}
if (sendCloseFrame) {
// Close the WebSocket by sending a close frame with specified payload
ByteBuf byteBuf = HttpUtils.generateWSCloseFrameByteBuf(statusCode, reason);
CloseWebSocketFrame frame = new CloseWebSocketFrame(true, 0, byteBuf);
PromiseInternal<Void> promise = context.promise();
conn.writeToChannel(frame, promise);
return promise;
} else {
return context.succeededFuture();
}
}
@Override
public boolean isSsl() {
return conn.isSsl();
}
@Override
public SSLSession sslSession() {
return conn.sslSession();
}
@Override
public X509Certificate[] peerCertificateChain() throws SSLPeerUnverifiedException {
return conn.peerCertificateChain();
}
@Override
public SocketAddress localAddress() {
return conn.localAddress();
}
@Override
public SocketAddress remoteAddress() {
return conn.remoteAddress();
}
@Override
public Future<Void> writeFinalTextFrame(String text) {
Promise<Void> promise = context.promise();
writeFinalTextFrame(text, promise);
return promise.future();
}
@Override
public S writeFinalTextFrame(String text, Handler<AsyncResult<Void>> handler) {
return writeFrame(WebSocketFrame.textFrame(text, true), handler);
}
@Override
public Future<Void> writeFinalBinaryFrame(Buffer data) {
Promise<Void> promise = context.promise();
writeFinalBinaryFrame(data, promise);
return promise.future();
}
@Override
public S writeFinalBinaryFrame(Buffer data, Handler<AsyncResult<Void>> handler) {
return writeFrame(WebSocketFrame.binaryFrame(data, true), handler);
}
@Override
public String subProtocol() {
synchronized(conn) {
return subProtocol;
}
}
void subProtocol(String subProtocol) {
synchronized (conn) {
this.subProtocol = subProtocol;
}
}
@Override
public Short closeStatusCode() {
synchronized (conn) {
return closeStatusCode;
}
}
@Override
public String closeReason() {
synchronized (conn) {
return closeReason;
}
}
@Override
public MultiMap headers() {
synchronized(conn) {
return headers;
}
}
void headers(MultiMap responseHeaders) {
synchronized(conn) {
this.headers = responseHeaders;
}
}
@Override
public Future<Void> writeBinaryMessage(Buffer data) {
return writePartialMessage(FrameType.BINARY, data, 0);
}
@Override
public final S writeBinaryMessage(Buffer data, Handler<AsyncResult<Void>> handler) {
Future<Void> fut = writeBinaryMessage(data);
if (handler != null) {
fut.onComplete(handler);
}
return (S) this;
}
@Override
public Future<Void> writeTextMessage(String text) {
return writePartialMessage(FrameType.TEXT, Buffer.buffer(text), 0);
}
@Override
public final S writeTextMessage(String text, Handler<AsyncResult<Void>> handler) {
Future<Void> fut = writeTextMessage(text);
if (handler != null) {
fut.onComplete(handler);
}
return (S) this;
}
@Override
public Future<Void> write(Buffer data) {
return writeFrame(WebSocketFrame.binaryFrame(data, true));
}
@Override
public final void write(Buffer data, Handler<AsyncResult<Void>> handler) {
Future<Void> fut = write(data);
if (handler != null) {
fut.onComplete(handler);
}
}
@Override
public Future<Void> writePing(Buffer data) {
if (data.length() > maxWebSocketFrameSize || data.length() > 125) {
return context.failedFuture("Ping cannot exceed maxWebSocketFrameSize or 125 bytes");
}
return writeFrame(WebSocketFrame.pingFrame(data));
}
@Override
public final WebSocketBase writePing(Buffer data, Handler<AsyncResult<Void>> handler) {
Future<Void> fut = writePing(data);
if (handler != null) {
fut.onComplete(handler);
}
return (S) this;
}
@Override
public Future<Void> writePong(Buffer data) {
if (data.length() > maxWebSocketFrameSize || data.length() > 125) {
return context.failedFuture("Pong cannot exceed maxWebSocketFrameSize or 125 bytes");
}
return writeFrame(WebSocketFrame.pongFrame(data));
}
@Override
public final WebSocketBase writePong(Buffer data, Handler<AsyncResult<Void>> handler) {
Future<Void> fut = writePong(data);
if (handler != null) {
fut.onComplete(handler);
}
return (S) this;
}
Splits the provided buffer into multiple frames (which do not exceed the maximum web socket frame size)
and writes them in order to the socket.
/**
* Splits the provided buffer into multiple frames (which do not exceed the maximum web socket frame size)
* and writes them in order to the socket.
*/
private Future<Void> writePartialMessage(FrameType frameType, Buffer data, int offset) {
int end = offset + maxWebSocketFrameSize;
boolean isFinal;
if (end >= data.length()) {
end = data.length();
isFinal = true;
} else {
isFinal = false;
}
Buffer slice = data.slice(offset, end);
WebSocketFrame frame;
if (offset == 0 || !supportsContinuation) {
frame = new WebSocketFrameImpl(frameType, slice.getByteBuf(), isFinal);
} else {
frame = WebSocketFrame.continuationFrame(slice, isFinal);
}
int newOffset = offset + maxWebSocketFrameSize;
if (isFinal) {
return writeFrame(frame);
} else {
writeFrame(frame);
return writePartialMessage(frameType, data, newOffset);
}
}
@Override
public Future<Void> writeFrame(WebSocketFrame frame) {
synchronized (conn) {
if (isClosed()) {
return context.failedFuture("WebSocket is closed");
}
PromiseInternal<Void> promise = context.promise();
conn.writeToChannel(encodeFrame((WebSocketFrameImpl) frame), promise);
return promise.future();
}
}
public final S writeFrame(WebSocketFrame frame, Handler<AsyncResult<Void>> handler) {
Future<Void> fut = writeFrame(frame);
if (handler != null) {
fut.onComplete(handler);
}
return (S) this;
}
private void writeBinaryFrameInternal(Buffer data) {
writeFrame(new WebSocketFrameImpl(FrameType.BINARY, data.getByteBuf()));
}
private void writeTextFrameInternal(String str) {
writeFrame(new WebSocketFrameImpl(str));
}
private io.netty.handler.codec.http.websocketx.WebSocketFrame encodeFrame(WebSocketFrameImpl frame) {
ByteBuf buf = frame.getBinaryData();
if (buf != Unpooled.EMPTY_BUFFER) {
buf = safeBuffer(buf, chctx.alloc());
}
switch (frame.type()) {
case BINARY:
return new BinaryWebSocketFrame(frame.isFinal(), 0, buf);
case TEXT:
return new TextWebSocketFrame(frame.isFinal(), 0, buf);
case CLOSE:
return new CloseWebSocketFrame(true, 0, buf);
case CONTINUATION:
return new ContinuationWebSocketFrame(frame.isFinal(), 0, buf);
case PONG:
return new PongWebSocketFrame(buf);
case PING:
return new PingWebSocketFrame(buf);
default:
throw new IllegalStateException("Unsupported WebSocket msg " + frame);
}
}
void checkClosed() {
if (isClosed()) {
throw new IllegalStateException("WebSocket is closed");
}
}
public boolean isClosed() {
synchronized (conn) {
return closed || closeStatusCode != null;
}
}
private WebSocketFrameInternal decodeFrame(io.netty.handler.codec.http.websocketx.WebSocketFrame msg) {
ByteBuf payload = safeBuffer(msg, chctx.alloc());
boolean isFinal = msg.isFinalFragment();
FrameType frameType;
if (msg instanceof BinaryWebSocketFrame) {
frameType = FrameType.BINARY;
} else if (msg instanceof CloseWebSocketFrame) {
frameType = FrameType.CLOSE;
} else if (msg instanceof PingWebSocketFrame) {
frameType = FrameType.PING;
} else if (msg instanceof PongWebSocketFrame) {
frameType = FrameType.PONG;
} else if (msg instanceof TextWebSocketFrame) {
frameType = FrameType.TEXT;
} else if (msg instanceof ContinuationWebSocketFrame) {
frameType = FrameType.CONTINUATION;
} else {
throw new IllegalStateException("Unsupported WebSocket msg " + msg);
}
return new WebSocketFrameImpl(frameType, payload, isFinal);
}
void handleFrame(io.netty.handler.codec.http.websocketx.WebSocketFrame msg) {
WebSocketFrameInternal frame = decodeFrame(msg);
switch (frame.type()) {
case PING:
// Echo back the content of the PING frame as PONG frame as specified in RFC 6455 Section 5.5.2
conn.writeToChannel(new PongWebSocketFrame(frame.getBinaryData().copy()));
break;
case PONG:
Handler<Buffer> pongHandler = pongHandler();
if (pongHandler != null) {
context.dispatch(frame.binaryData(), pongHandler);
}
break;
case CLOSE:
handleCloseFrame((CloseWebSocketFrame) msg);
break;
}
if (!pending.write(frame)) {
conn.doPause();
}
}
private void handleCloseFrame(CloseWebSocketFrame closeFrame) {
boolean echo;
synchronized (conn) {
echo = closeStatusCode == null;
closed = true;
closeStatusCode = (short)closeFrame.statusCode();
closeReason = closeFrame.reasonText();
}
handleClose(true);
if (echo) {
ChannelPromise fut = conn.channelFuture();
conn.writeToChannel(closeFrame.retainedDuplicate(), fut);
fut.addListener(v -> closeConnection());
} else {
closeConnection();
}
}
protected void handleClose(boolean graceful) {
MessageConsumer<?> binaryConsumer;
MessageConsumer<?> textConsumer;
Handler<Void> closeHandler;
Handler<Throwable> exceptionHandler;
synchronized (conn) {
closeHandler = this.closeHandler;
exceptionHandler = this.exceptionHandler;
binaryConsumer = this.binaryHandlerRegistration;
textConsumer = this.textHandlerRegistration;
this.binaryHandlerRegistration = null;
this.textHandlerRegistration = null;
this.closeHandler = null;
this.exceptionHandler = null;
}
if (binaryConsumer != null) {
binaryConsumer.unregister();
}
if (textConsumer != null) {
textConsumer.unregister();
}
if (exceptionHandler != null && !graceful) {
context.dispatch(ConnectionBase.CLOSED_EXCEPTION, exceptionHandler);
}
if (closeHandler != null) {
context.dispatch(null, closeHandler);
}
}
private void receiveFrame(WebSocketFrameInternal frame) {
Handler<WebSocketFrameInternal> frameHandler;
synchronized (conn) {
frameHandler = this.frameHandler;
}
if (frameHandler != null) {
context.dispatch(frame, frameHandler);
}
switch(frame.type()) {
case CLOSE:
Handler<Void> endHandler = endHandler();
if (endHandler != null) {
context.dispatch(endHandler);
}
break;
case TEXT:
case BINARY:
case CONTINUATION:
Handler<Buffer> handler = handler();
if (handler != null) {
context.dispatch(frame.binaryData(), handler);
}
break;
}
}
protected abstract void closeConnection();
private class FrameAggregator implements Handler<WebSocketFrameInternal> {
private Handler<String> textMessageHandler;
private Handler<Buffer> binaryMessageHandler;
private Buffer textMessageBuffer;
private Buffer binaryMessageBuffer;
@Override
public void handle(WebSocketFrameInternal frame) {
switch (frame.type()) {
case TEXT:
handleTextFrame(frame);
break;
case BINARY:
handleBinaryFrame(frame);
break;
case CONTINUATION:
if (textMessageBuffer != null && textMessageBuffer.length() > 0) {
handleTextFrame(frame);
} else if (binaryMessageBuffer != null && binaryMessageBuffer.length() > 0) {
handleBinaryFrame(frame);
}
break;
}
}
private void handleTextFrame(WebSocketFrameInternal frame) {
Buffer frameBuffer = Buffer.buffer(frame.getBinaryData());
if (textMessageBuffer == null) {
textMessageBuffer = frameBuffer;
} else {
textMessageBuffer.appendBuffer(frameBuffer);
}
if (textMessageBuffer.length() > maxWebSocketMessageSize) {
int len = textMessageBuffer.length() - frameBuffer.length();
textMessageBuffer = null;
String msg = "Cannot process text frame of size " + frameBuffer.length() + ", it would cause message buffer (size " +
len + ") to overflow max message size of " + maxWebSocketMessageSize;
handleException(new IllegalStateException(msg));
return;
}
if (frame.isFinal()) {
String fullMessage = textMessageBuffer.toString();
textMessageBuffer = null;
if (textMessageHandler != null) {
textMessageHandler.handle(fullMessage);
}
}
}
private void handleBinaryFrame(WebSocketFrameInternal frame) {
Buffer frameBuffer = Buffer.buffer(frame.getBinaryData());
if (binaryMessageBuffer == null) {
binaryMessageBuffer = frameBuffer;
} else {
binaryMessageBuffer.appendBuffer(frameBuffer);
}
if (binaryMessageBuffer.length() > maxWebSocketMessageSize) {
int len = binaryMessageBuffer.length() - frameBuffer.length();
binaryMessageBuffer = null;
String msg = "Cannot process binary frame of size " + frameBuffer.length() + ", it would cause message buffer (size " +
len + ") to overflow max message size of " + maxWebSocketMessageSize;
handleException(new IllegalStateException(msg));
return;
}
if (frame.isFinal()) {
Buffer fullMessage = binaryMessageBuffer.copy();
binaryMessageBuffer = null;
if (binaryMessageHandler != null) {
binaryMessageHandler.handle(fullMessage);
}
}
}
}
@Override
public S frameHandler(Handler<WebSocketFrame> handler) {
synchronized (conn) {
checkClosed();
this.frameHandler = (Handler)handler;
return (S) this;
}
}
@Override
public WebSocketBase textMessageHandler(Handler<String> handler) {
synchronized (conn) {
checkClosed();
if (frameHandler == null || frameHandler.getClass() != FrameAggregator.class) {
frameHandler = new FrameAggregator();
}
((FrameAggregator) frameHandler).textMessageHandler = handler;
return this;
}
}
@Override
public S binaryMessageHandler(Handler<Buffer> handler) {
synchronized (conn) {
checkClosed();
if (frameHandler == null || frameHandler.getClass() != FrameAggregator.class) {
frameHandler = new FrameAggregator();
}
((FrameAggregator) frameHandler).binaryMessageHandler = handler;
return (S) this;
}
}
@Override
public WebSocketBase pongHandler(Handler<Buffer> handler) {
synchronized (conn) {
checkClosed();
this.pongHandler = handler;
return (S) this;
}
}
private Handler<Buffer> pongHandler() {
synchronized (conn) {
return pongHandler;
}
}
void handleWritabilityChanged(boolean writable) {
Handler<Void> handler;
synchronized (conn) {
boolean skip = this.writable && !writable;
this.writable = writable;
handler = drainHandler;
if (handler == null || skip) {
return;
}
}
context.dispatch(null, handler);
}
void handleException(Throwable t) {
Handler<Throwable> handler;
synchronized (conn) {
handler = this.exceptionHandler;
if (handler == null) {
return;
}
}
context.dispatch(t, handler);
}
void handleConnectionClosed() {
synchronized (conn) {
if (closed) {
return;
}
closed = true;
}
handleClose(false);
}
synchronized void setMetric(Object metric) {
this.metric = metric;
}
synchronized Object getMetric() {
return metric;
}
@Override
public S handler(Handler<Buffer> handler) {
synchronized (conn) {
if (handler != null) {
checkClosed();
}
this.handler = handler;
return (S) this;
}
}
private Handler<Buffer> handler() {
synchronized (conn) {
return handler;
}
}
@Override
public S endHandler(Handler<Void> handler) {
synchronized (conn) {
if (handler != null) {
checkClosed();
}
this.endHandler = handler;
return (S) this;
}
}
private Handler<Void> endHandler() {
synchronized (conn) {
return endHandler;
}
}
@Override
public S exceptionHandler(Handler<Throwable> handler) {
synchronized (conn) {
if (handler != null) {
checkClosed();
}
this.exceptionHandler = handler;
return (S) this;
}
}
@Override
public S closeHandler(Handler<Void> handler) {
synchronized (conn) {
checkClosed();
this.closeHandler = handler;
return (S) this;
}
}
@Override
public S drainHandler(Handler<Void> handler) {
synchronized (conn) {
checkClosed();
this.drainHandler = handler;
return (S) this;
}
}
@Override
public S pause() {
pending.pause();
return (S) this;
}
@Override
public S resume() {
pending.resume();
return (S) this;
}
@Override
public S fetch(long amount) {
pending.fetch(amount);
return (S) this;
}
@Override
public S setWriteQueueMaxSize(int maxSize) {
synchronized (conn) {
checkClosed();
conn.doSetWriteQueueMaxSize(maxSize);
return (S) this;
}
}
@Override
public Future<Void> end() {
return close();
}
@Override
public void end(Handler<AsyncResult<Void>> handler) {
close(handler);
}
}