/*
 * Copyright (c) 2011-2017 Contributors to the Eclipse Foundation
 *
 * This program and the accompanying materials are made available under the
 * terms of the Eclipse Public License 2.0 which is available at
 * http://www.eclipse.org/legal/epl-2.0, or the Apache License, Version 2.0
 * which is available at https://www.apache.org/licenses/LICENSE-2.0.
 *
 * SPDX-License-Identifier: EPL-2.0 OR Apache-2.0
 */

package io.vertx.core.http.impl;

import io.netty.buffer.ByteBuf;
import io.vertx.codegen.annotations.Nullable;
import io.vertx.core.AsyncResult;
import io.vertx.core.Handler;
import io.vertx.core.buffer.Buffer;
import io.vertx.core.eventbus.EventBus;
import io.vertx.core.eventbus.Message;
import io.vertx.core.eventbus.MessageConsumer;
import io.vertx.core.http.WebSocketBase;
import io.vertx.core.http.WebSocketFrame;
import io.vertx.core.http.impl.ws.WebSocketFrameImpl;
import io.vertx.core.http.impl.ws.WebSocketFrameInternal;
import io.vertx.core.impl.ContextInternal;
import io.vertx.core.net.SocketAddress;
import io.vertx.core.streams.impl.InboundBuffer;

import javax.net.ssl.SSLPeerUnverifiedException;
import javax.net.ssl.SSLSession;
import javax.security.cert.X509Certificate;
import java.util.UUID;

/**
 * This class is optimised for performance when used on the same event loop. However it can be used safely from
 * other threads.
 * <p>
 * The internal state is protected using the synchronized keyword. If always used on the same event loop, then
 * we benefit from biased locking which makes the overhead of synchronized near zero.
 *
 * @author <a href="http://tfox.org">Tim Fox</a>
 * @param <S> self return type
 */
public abstract class WebSocketImplBase<S extends WebSocketBase> implements WebSocketBase {

  // When true, fragments after the first one of a split message are sent as continuation frames.
  private final boolean supportsContinuation;
  // Addresses on which registerHandler(EventBus) installs local consumers that forward to this socket.
  private final String textHandlerID;
  private final String binaryHandlerID;
  // Upper bound for a single outgoing frame; also checked by writePing/writePong.
  private final int maxWebSocketFrameSize;
  // Upper bound for an aggregated incoming message.
  private final int maxWebSocketMessageSize;
  // Inbound data queue; its drain handler resumes the connection.
  private final InboundBuffer<Buffer> pending;
  private MessageConsumer binaryHandlerRegistration;
  private MessageConsumer textHandlerRegistration;
  private String subProtocol;
  private Object metric;
  private Handler<WebSocketFrameInternal> frameHandler;
  private Handler<Buffer> pongHandler;
  private Handler<Void> drainHandler;
  private Handler<Throwable> exceptionHandler;
  private Handler<Void> closeHandler;
  private Handler<Void> endHandler;
  // The connection also serves as the monitor guarding the mutable state of this class.
  protected final Http1xConnectionBase conn;
  protected boolean closed;

  /**
   * Creates the WebSocket state on top of the given connection.
   *
   * @param conn the underlying connection, also used as the lock for this object's state
   * @param supportsContinuation whether continuation frames may be emitted when splitting messages
   * @param maxWebSocketFrameSize the maximum size of a single frame
   * @param maxWebSocketMessageSize the maximum size of an aggregated message
   */
  WebSocketImplBase(Http1xConnectionBase conn, boolean supportsContinuation,
                    int maxWebSocketFrameSize, int maxWebSocketMessageSize) {
    this.supportsContinuation = supportsContinuation;
    this.textHandlerID = "__vertx.ws." + UUID.randomUUID();
    this.binaryHandlerID = "__vertx.ws." + UUID.randomUUID();
    this.conn = conn;
    this.maxWebSocketFrameSize = maxWebSocketFrameSize;
    this.maxWebSocketMessageSize = maxWebSocketMessageSize;
    this.pending = new InboundBuffer<>(conn.getContext());
    // Once the inbound queue has drained, let the connection deliver data again.
    pending.drainHandler(v -> conn.doResume());
  }

  /**
   * Registers local event bus consumers on the binary and text handler addresses; every message body
   * received there is written to this WebSocket as a binary or text frame respectively.
   *
   * @param eventBus the event bus to register the consumers on
   */
  void registerHandler(EventBus eventBus) {
    binaryHandlerRegistration = eventBus.<Buffer>localConsumer(binaryHandlerID)
      .handler(msg -> writeBinaryFrameInternal(msg.body()));
    textHandlerRegistration = eventBus.<String>localConsumer(textHandlerID)
      .handler(msg -> writeTextFrameInternal(msg.body()));
  }

  /**
   * @return the event bus address of the binary handler
   */
  public String binaryHandlerID() {
    return binaryHandlerID;
  }

  /**
   * @return the event bus address of the text handler
   */
  public String textHandlerID() {
    return textHandlerID;
  }

  /**
   * @return {@code true} when the connection reports it is not writable
   * @throws IllegalStateException if the WebSocket is closed
   */
  public boolean writeQueueFull() {
    synchronized (conn) {
      checkClosed();
      return conn.isNotWritable();
    }
  }

  // The close overloads below funnel into close(short, String, Handler); the delegation chain is kept
  // as is so that overriding one overload in a subclass keeps affecting the shorter ones.

  @Override
  public void close() {
    close(null);
  }

  @Override
  public void close(Handler<AsyncResult<Void>> handler) {
    // 1000 is the status code used when the caller does not supply one.
    close((short) 1000, null, handler);
  }

  @Override
  public void close(short statusCode) {
    close(statusCode, (Handler<AsyncResult<Void>>) null);
  }

  @Override
  public void close(short statusCode, Handler<AsyncResult<Void>> handler) {
    this.close(statusCode, null, handler);
  }

  @Override
  public void close(short statusCode, String reason) {
    close(statusCode, reason, null);
  }

  /**
   * Marks this WebSocket closed, unregisters the event bus handlers and closes the connection with the
   * given status code and reason. When the WebSocket is already marked closed the call returns
   * immediately; note that {@code handler} is not invoked in that case.
   *
   * @param statusCode the close status code
   * @param reason the close reason, may be {@code null}
   * @param handler passed on to the connection close, may be {@code null}
   */
  @Override
  public void close(short statusCode, @Nullable String reason, Handler<AsyncResult<Void>> handler) {
    boolean alreadyClosed;
    synchronized (conn) {
      alreadyClosed = closed;
      closed = true;
    }
    if (alreadyClosed) {
      return;
    }
    unregisterHandlers();
    conn.closeWithPayload(statusCode, reason, handler);
  }

  @Override
  public boolean isSsl() {
    return conn.isSsl();
  }

  @Override
  public SSLSession sslSession() {
    return conn.sslSession();
  }

  @Override
  public X509Certificate[] peerCertificateChain() throws SSLPeerUnverifiedException {
    return conn.peerCertificateChain();
  }

  @Override
  public SocketAddress localAddress() {
    return conn.localAddress();
  }

  @Override
  public SocketAddress remoteAddress() {
    return conn.remoteAddress();
  }

  @Override
  public S writeFinalTextFrame(String text) {
    return writeFinalTextFrame(text, null);
  }

  @Override
  public S writeFinalTextFrame(String text, Handler<AsyncResult<Void>> handler) {
    WebSocketFrame frame = WebSocketFrame.textFrame(text, true);
    return writeFrame(frame, handler);
  }

  @Override
  public S writeFinalBinaryFrame(Buffer data) {
    return writeFinalBinaryFrame(data, null);
  }

  @Override
  public S writeFinalBinaryFrame(Buffer data, Handler<AsyncResult<Void>> handler) {
    WebSocketFrame frame = WebSocketFrame.binaryFrame(data, true);
    return writeFrame(frame, handler);
  }

  @Override
  public String subProtocol() {
    synchronized (conn) {
      return subProtocol;
    }
  }

  /**
   * Sets the sub protocol returned by {@link #subProtocol()}.
   *
   * @param subProtocol the sub protocol value to store
   */
  void subProtocol(String subProtocol) {
    synchronized (conn) {
      this.subProtocol = subProtocol;
    }
  }

  @Override
  public S writeBinaryMessage(Buffer data) {
    return writeBinaryMessage(data, null);
  }

  /**
   * Writes {@code data} as a binary message, split into frames by {@code writePartialMessage}.
   *
   * @throws IllegalStateException if the WebSocket is closed
   */
  @Override
  public S writeBinaryMessage(Buffer data, Handler<AsyncResult<Void>> handler) {
    synchronized (conn) {
      checkClosed();
      writePartialMessage(FrameType.BINARY, data, 0, handler);
    }
    return (S) this;
  }

  @Override
  public S writeTextMessage(String text) {
    return writeTextMessage(text, null);
  }

  /**
   * Writes {@code text} as a text message, split into frames by {@code writePartialMessage}.
   *
   * @throws IllegalStateException if the WebSocket is closed
   */
  @Override
  public S writeTextMessage(String text, @Nullable Handler<AsyncResult<Void>> handler) {
    synchronized (conn) {
      // The closed check comes first so that it wins over any failure converting the text.
      checkClosed();
      Buffer payload = Buffer.buffer(text);
      writePartialMessage(FrameType.TEXT, payload, 0, handler);
    }
    return (S) this;
  }

  @Override
  public S write(Buffer data) {
    return write(data, null);
  }

  /**
   * Writes {@code data} as a single final binary frame.
   *
   * @throws IllegalStateException if the WebSocket is closed
   */
  @Override
  public S write(Buffer data, Handler<AsyncResult<Void>> handler) {
    synchronized (conn) {
      checkClosed();
      writeFrame(WebSocketFrame.binaryFrame(data, true), handler);
    }
    return (S) this;
  }

  /**
   * Writes a ping frame carrying {@code data}.
   *
   * @throws IllegalStateException if {@code data} is longer than the maximum frame size or 125 bytes
   */
  @Override
  public S writePing(Buffer data) {
    int size = data.length();
    if (size > maxWebSocketFrameSize || size > 125) {
      throw new IllegalStateException("Ping cannot exceed maxWebSocketFrameSize or 125 bytes");
    }
    return writeFrame(WebSocketFrame.pingFrame(data));
  }

  /**
   * Writes a pong frame carrying {@code data}.
   *
   * @throws IllegalStateException if {@code data} is longer than the maximum frame size or 125 bytes
   */
  @Override
  public S writePong(Buffer data) {
    int size = data.length();
    if (size > maxWebSocketFrameSize || size > 125) {
      throw new IllegalStateException("Pong cannot exceed maxWebSocketFrameSize or 125 bytes");
    }
    return writeFrame(WebSocketFrame.pongFrame(data));
  }
  /**
   * Splits the provided buffer into multiple frames (which do not exceed the maximum web socket frame size)
   * and writes them in order to the socket.
   */
private void writePartialMessage(FrameType frameType, Buffer data, int offset, Handler<AsyncResult<Void>> handler) { int end = offset + maxWebSocketFrameSize; boolean isFinal; if (end >= data.length()) { end = data.length(); isFinal = true; } else { isFinal = false; } Buffer slice = data.slice(offset, end); WebSocketFrame frame; if (offset == 0 || !supportsContinuation) { frame = new WebSocketFrameImpl(frameType, slice.getByteBuf(), isFinal); } else { frame = WebSocketFrame.continuationFrame(slice, isFinal); } int newOffset = offset + maxWebSocketFrameSize; if (isFinal) { writeFrame(frame, handler); } else { writeFrame(frame); writePartialMessage(frameType, data, newOffset, handler); } } private void writeBinaryFrameInternal(Buffer data) { ByteBuf buf = data.getByteBuf(); WebSocketFrame frame = new WebSocketFrameImpl(FrameType.BINARY, buf); writeFrame(frame); } private void writeTextFrameInternal(String str) { WebSocketFrame frame = new WebSocketFrameImpl(str); writeFrame(frame); } @Override public S writeFrame(WebSocketFrame frame) { return writeFrame(frame, null); } public S writeFrame(WebSocketFrame frame, Handler<AsyncResult<Void>> handler) { synchronized (conn) { checkClosed(); conn.reportBytesWritten(((WebSocketFrameInternal)frame).length()); conn.writeToChannel(conn.encodeFrame((WebSocketFrameImpl) frame), conn.toPromise(handler)); } return (S) this; } void checkClosed() { synchronized (conn) { if (closed) { throw new IllegalStateException("WebSocket is closed"); } } } public boolean isClosed() { synchronized (conn) { return closed; } } void handleFrame(WebSocketFrameInternal frame) { synchronized (conn) { if (frame.type() != FrameType.CLOSE) { conn.reportBytesRead(frame.length()); if (!pending.write(frame.binaryData())) { conn.doPause(); } } switch(frame.type()) { case PONG: if (pongHandler != null) { pongHandler.handle(frame.binaryData()); } break; case TEXT: case CLOSE: case BINARY: case CONTINUATION: if (frameHandler != null) { 
frameHandler.handle(frame); } break; } } } private class FrameAggregator implements Handler<WebSocketFrameInternal> { private Handler<String> textMessageHandler; private Handler<Buffer> binaryMessageHandler; private Buffer textMessageBuffer; private Buffer binaryMessageBuffer; @Override public void handle(WebSocketFrameInternal frame) { switch (frame.type()) { case TEXT: handleTextFrame(frame); break; case BINARY: handleBinaryFrame(frame); break; case CONTINUATION: if (textMessageBuffer != null && textMessageBuffer.length() > 0) { handleTextFrame(frame); } else if (binaryMessageBuffer != null && binaryMessageBuffer.length() > 0) { handleBinaryFrame(frame); } break; } } private void handleTextFrame(WebSocketFrameInternal frame) { Buffer frameBuffer = Buffer.buffer(frame.getBinaryData()); if (textMessageBuffer == null) { textMessageBuffer = frameBuffer; } else { textMessageBuffer.appendBuffer(frameBuffer); } if (textMessageBuffer.length() > maxWebSocketMessageSize) { int len = textMessageBuffer.length() - frameBuffer.length(); textMessageBuffer = null; String msg = "Cannot process text frame of size " + frameBuffer.length() + ", it would cause message buffer (size " + len + ") to overflow max message size of " + maxWebSocketMessageSize; handleException(new IllegalStateException(msg)); return; } if (frame.isFinal()) { String fullMessage = textMessageBuffer.toString(); textMessageBuffer = null; if (textMessageHandler != null) { textMessageHandler.handle(fullMessage); } } } private void handleBinaryFrame(WebSocketFrameInternal frame) { Buffer frameBuffer = Buffer.buffer(frame.getBinaryData()); if (binaryMessageBuffer == null) { binaryMessageBuffer = frameBuffer; } else { binaryMessageBuffer.appendBuffer(frameBuffer); } if (binaryMessageBuffer.length() > maxWebSocketMessageSize) { int len = binaryMessageBuffer.length() - frameBuffer.length(); binaryMessageBuffer = null; String msg = "Cannot process binary frame of size " + frameBuffer.length() + ", it would cause message 
buffer (size " + len + ") to overflow max message size of " + maxWebSocketMessageSize; handleException(new IllegalStateException(msg)); return; } if (frame.isFinal()) { Buffer fullMessage = binaryMessageBuffer.copy(); binaryMessageBuffer = null; if (binaryMessageHandler != null) { binaryMessageHandler.handle(fullMessage); } } } } @Override public S frameHandler(Handler<WebSocketFrame> handler) { synchronized (conn) { checkClosed(); this.frameHandler = (Handler)handler; return (S) this; } } @Override public WebSocketBase textMessageHandler(Handler<String> handler) { synchronized (conn) { checkClosed(); if (frameHandler == null || frameHandler.getClass() != FrameAggregator.class) { frameHandler = new FrameAggregator(); } ((FrameAggregator) frameHandler).textMessageHandler = handler; return this; } } @Override public S binaryMessageHandler(Handler<Buffer> handler) { synchronized (conn) { checkClosed(); if (frameHandler == null || frameHandler.getClass() != FrameAggregator.class) { frameHandler = new FrameAggregator(); } ((FrameAggregator) frameHandler).binaryMessageHandler = handler; return (S) this; } } @Override public WebSocketBase pongHandler(Handler<Buffer> handler) { synchronized (conn) { checkClosed(); this.pongHandler = handler; return (S) this; } } void handleDrained() { if (drainHandler != null) { Handler<Void> dh = drainHandler; drainHandler = null; dh.handle(null); } } void handleException(Throwable t) { synchronized (conn) { if (exceptionHandler != null) { exceptionHandler.handle(t); } } } void handleClosed() { unregisterHandlers(); Handler<Void> endHandler; Handler<Void> closeHandler; synchronized (conn) { endHandler = pending.isPaused() ? null : this.endHandler; closeHandler = this.closeHandler; closed = true; binaryHandlerRegistration = null; textHandlerRegistration = null; } if (closeHandler != null) { closeHandler.handle(null); } if (endHandler != null) { endHandler.handle(null); } }
  /**
   * Unregister the handlers if they are present.
   */
private void unregisterHandlers() {
    MessageConsumer binaryConsumer;
    MessageConsumer textConsumer;
    // Take and clear both registrations under the lock; unregister after releasing it.
    synchronized (conn) {
      binaryConsumer = this.binaryHandlerRegistration;
      textConsumer = this.textHandlerRegistration;
      binaryHandlerRegistration = null;
      textHandlerRegistration = null;
    }
    if (binaryConsumer != null) {
      binaryConsumer.unregister();
    }
    if (textConsumer != null) {
      textConsumer.unregister();
    }
  }

  /**
   * Stores the metric object associated with this WebSocket.
   */
  synchronized void setMetric(Object metric) {
    this.metric = metric;
  }

  /**
   * @return the metric object previously stored with {@link #setMetric(Object)}
   */
  synchronized Object getMetric() {
    return metric;
  }

  /**
   * Sets the handler of the inbound buffer.
   *
   * @throws IllegalStateException if {@code handler} is not {@code null} and the WebSocket is closed
   */
  @Override
  public S handler(Handler<Buffer> handler) {
    synchronized (conn) {
      if (handler != null) {
        checkClosed();
      }
      pending.handler(handler);
      return (S) this;
    }
  }

  /**
   * Sets the end handler.
   *
   * @throws IllegalStateException if {@code handler} is not {@code null} and the WebSocket is closed
   */
  @Override
  public S endHandler(Handler<Void> handler) {
    synchronized (conn) {
      if (handler != null) {
        checkClosed();
      }
      this.endHandler = handler;
      return (S) this;
    }
  }

  /**
   * Sets the exception handler.
   *
   * @throws IllegalStateException if {@code handler} is not {@code null} and the WebSocket is closed
   */
  @Override
  public S exceptionHandler(Handler<Throwable> handler) {
    synchronized (conn) {
      if (handler != null) {
        checkClosed();
      }
      this.exceptionHandler = handler;
      return (S) this;
    }
  }

  /**
   * Sets the close handler.
   *
   * @throws IllegalStateException if the WebSocket is closed
   */
  @Override
  public S closeHandler(Handler<Void> handler) {
    synchronized (conn) {
      checkClosed();
      this.closeHandler = handler;
      return (S) this;
    }
  }

  /**
   * Sets the drain handler.
   *
   * @throws IllegalStateException if the WebSocket is closed
   */
  @Override
  public S drainHandler(Handler<Void> handler) {
    synchronized (conn) {
      checkClosed();
      this.drainHandler = handler;
      return (S) this;
    }
  }

  /**
   * Pauses the inbound buffer unless the WebSocket is closed.
   */
  @Override
  public S pause() {
    if (!isClosed()) {
      pending.pause();
    }
    return (S) this;
  }

  /**
   * Resumes the inbound buffer. When the WebSocket is already closed, the end handler (if any) is
   * cleared and scheduled on the connection's context instead.
   */
  @Override
  public S resume() {
    // endHandler is guarded by the conn monitor in endHandler(Handler) and handleClosed(), so the
    // read-and-clear below must hold the same monitor rather than the monitor of this object.
    synchronized (conn) {
      if (isClosed()) {
        Handler<Void> handler = endHandler;
        endHandler = null;
        if (handler != null) {
          ContextInternal ctx = conn.getContext();
          ctx.runOnContext(v -> handler.handle(null));
        }
      } else {
        pending.resume();
      }
    }
    return (S) this;
  }

  /**
   * Fetches {@code amount} items from the inbound buffer unless the WebSocket is closed.
   */
  @Override
  public S fetch(long amount) {
    if (!isClosed()) {
      pending.fetch(amount);
    }
    return (S) this;
  }

  /**
   * Sets the maximum write queue size on the connection.
   *
   * @throws IllegalStateException if the WebSocket is closed
   */
  @Override
  public S setWriteQueueMaxSize(int maxSize) {
    synchronized (conn) {
      checkClosed();
      conn.doSetWriteQueueMaxSize(maxSize);
      return (S) this;
    }
  }

  /**
   * Same as {@link #close()}.
   */
  @Override
  public void end() {
    close();
  }

  /**
   * Same as {@link #close(Handler)}.
   */
  @Override
  public void end(Handler<AsyncResult<Void>> handler) {
    close(handler);
  }
}