package io.undertow.websockets.core.protocol;
import io.undertow.util.Headers;
import io.undertow.websockets.WebSocketExtension;
import io.undertow.websockets.core.WebSocketChannel;
import io.undertow.websockets.core.WebSocketVersion;
import io.undertow.websockets.extensions.ExtensionFunction;
import io.undertow.websockets.extensions.ExtensionHandshake;
import io.undertow.websockets.spi.WebSocketHttpExchange;
import org.xnio.IoFuture;
import io.undertow.connector.ByteBufferPool;
import org.xnio.StreamConnection;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.regex.Pattern;
/**
 * Base class for the server side of a WebSocket opening handshake.
 * <p>
 * Concrete subclasses implement one protocol version: they decide whether a request
 * {@link #matches(WebSocketHttpExchange) matches} that version, write the version-specific
 * response headers in {@link #handshakeInternal(WebSocketHttpExchange)}, and create the
 * {@link WebSocketChannel} for the upgraded connection. This class provides the shared helpers
 * for completing the upgrade and for negotiating subprotocols and extensions.
 */
public abstract class Handshake {

    /** Zero-length payload used when the upgrade response carries no body. */
    private static final byte[] EMPTY = new byte[0];

    /** Splits a comma-separated header value, swallowing whitespace around each comma. */
    private static final Pattern PATTERN = Pattern.compile("\\s*,\\s*");

    private final WebSocketVersion version;
    private final String hashAlgorithm;
    private final String magicNumber;

    /** Subprotocols this handshake is willing to select for a client. */
    protected final Set<String> subprotocols;

    /** Extension handshakes registered through {@link #addExtension(ExtensionHandshake)}. */
    protected Set<ExtensionHandshake> availableExtensions = new HashSet<>();

    /** Becomes {@code true} once at least one extension has been registered. */
    protected boolean allowExtensions;

    /**
     * Creates a handshake for the given protocol version.
     *
     * @param version       the WebSocket version this handshake implements
     * @param hashAlgorithm hash algorithm name exposed via {@link #getHashAlgorithm()}; not used
     *                      by this class itself (presumably consumed by subclasses)
     * @param magicNumber   value exposed via {@link #getMagicNumber()}; not used by this class
     *                      itself (presumably consumed by subclasses)
     * @param subprotocols  the subprotocols this handshake may select; stored by reference
     */
    protected Handshake(WebSocketVersion version, String hashAlgorithm, String magicNumber, final Set<String> subprotocols) {
        this.version = version;
        this.hashAlgorithm = hashAlgorithm;
        this.magicNumber = magicNumber;
        this.subprotocols = subprotocols;
    }

    /**
     * Returns the WebSocket version implemented by this handshake.
     *
     * @return the protocol version passed to the constructor
     */
    public WebSocketVersion getVersion() {
        return version;
    }

    /**
     * Returns the hash algorithm name supplied at construction.
     *
     * @return the hash algorithm name
     */
    public String getHashAlgorithm() {
        return hashAlgorithm;
    }

    /**
     * Returns the magic number supplied at construction.
     *
     * @return the magic number string
     */
    public String getMagicNumber() {
        return magicNumber;
    }

    /**
     * Builds the {@code ws://} or {@code wss://} URL of the current request.
     * <p>
     * The scheme is {@code wss} only when the request scheme is exactly {@code "https"}. The host
     * part is taken verbatim from the {@code Host} request header (rendered as {@code "null"} if
     * that header is absent), followed by the request URI.
     *
     * @param exchange the HTTP exchange being upgraded
     * @return the WebSocket location URL
     */
    protected static String getWebSocketLocation(WebSocketHttpExchange exchange) {
        String scheme;
        if ("https".equals(exchange.getRequestScheme())) {
            scheme = "wss";
        } else {
            scheme = "ws";
        }
        return scheme + "://" + exchange.getRequestHeader(Headers.HOST_STRING) + exchange.getRequestURI();
    }

    /**
     * Performs the handshake for this version.
     * <p>
     * Records {@link #getVersion()} on the exchange under {@link WebSocketVersion#ATTACHMENT_KEY}
     * before delegating to {@link #handshakeInternal(WebSocketHttpExchange)}.
     *
     * @param exchange the HTTP exchange being upgraded
     */
    public final void handshake(final WebSocketHttpExchange exchange) {
        exchange.putAttachment(WebSocketVersion.ATTACHMENT_KEY, version);
        handshakeInternal(exchange);
    }

    /**
     * Version-specific handshake logic, invoked by {@link #handshake(WebSocketHttpExchange)}.
     *
     * @param exchange the HTTP exchange being upgraded
     */
    protected abstract void handshakeInternal(final WebSocketHttpExchange exchange);

    /**
     * Returns whether the request in {@code exchange} targets the version implemented here.
     *
     * @param exchange the incoming HTTP exchange
     * @return {@code true} if this handshake should handle the request
     */
    public abstract boolean matches(WebSocketHttpExchange exchange);

    /**
     * Creates the WebSocket channel for an upgraded connection.
     *
     * @param exchange the exchange that was upgraded
     * @param channel  the underlying stream connection
     * @param pool     the buffer pool for the channel
     * @return the new WebSocket channel
     */
    public abstract WebSocketChannel createChannel(WebSocketHttpExchange exchange, final StreamConnection channel, final ByteBufferPool pool);

    /**
     * Writes the upgrade response headers and completes the upgrade with {@code data} as the body.
     * <p>
     * Sets {@code Content-Length} to {@code data.length}, {@code Upgrade: WebSocket} and
     * {@code Connection: Upgrade}, then calls {@link #upgradeChannel(WebSocketHttpExchange, byte[])}.
     *
     * @param exchange the exchange being upgraded
     * @param data     the response body (may be empty, must not be {@code null})
     */
    protected final void performUpgrade(final WebSocketHttpExchange exchange, final byte[] data) {
        exchange.setResponseHeader(Headers.CONTENT_LENGTH_STRING, String.valueOf(data.length));
        exchange.setResponseHeader(Headers.UPGRADE_STRING, "WebSocket");
        exchange.setResponseHeader(Headers.CONNECTION_STRING, "Upgrade");
        upgradeChannel(exchange, data);
    }

    /**
     * Sends {@code data} (if any) and ends the exchange.
     * <p>
     * With a non-empty payload, the exchange is ended once the asynchronous send completes
     * successfully, or closed if the send does not complete. With an empty payload, the exchange
     * is ended immediately.
     *
     * @param exchange the exchange being upgraded
     * @param data     the response body
     */
    protected void upgradeChannel(final WebSocketHttpExchange exchange, final byte[] data) {
        if (data.length > 0) {
            writePayload(exchange, ByteBuffer.wrap(data));
        } else {
            exchange.endExchange();
        }
    }

    /**
     * Sends {@code payload} asynchronously, ending the exchange on success and closing it on any
     * other outcome (failure or cancellation).
     */
    private static void writePayload(final WebSocketHttpExchange exchange, final ByteBuffer payload) {
        exchange.sendData(payload).addNotifier(new IoFuture.Notifier<Void, Object>() {
            @Override
            public void notify(final IoFuture<? extends Void> ioFuture, final Object attachment) {
                if (ioFuture.getStatus() == IoFuture.Status.DONE) {
                    exchange.endExchange();
                } else {
                    // Anything other than DONE means the payload was not fully written.
                    exchange.close();
                }
            }
        }, null);
    }

    /**
     * Completes the upgrade with an empty response body.
     *
     * @param exchange the exchange being upgraded
     */
    protected final void performUpgrade(final WebSocketHttpExchange exchange) {
        performUpgrade(exchange, EMPTY);
    }

    /**
     * Selects a subprotocol from the client's {@code Sec-WebSocket-Protocol} header.
     * <p>
     * Does nothing when the client sent no such header. Otherwise the value chosen by
     * {@link #supportedSubprotols(String[])} is echoed back in the response header, unless it is
     * {@code null} or empty.
     *
     * @param exchange the exchange being upgraded
     */
    protected final void selectSubprotocol(final WebSocketHttpExchange exchange) {
        String requestedSubprotocols = exchange.getRequestHeader(Headers.SEC_WEB_SOCKET_PROTOCOL_STRING);
        if (requestedSubprotocols == null) {
            return;
        }
        String[] requestedSubprotocolArray = PATTERN.split(requestedSubprotocols);
        String subProtocol = supportedSubprotols(requestedSubprotocolArray);
        if (subProtocol != null && !subProtocol.isEmpty()) {
            exchange.setResponseHeader(Headers.SEC_WEB_SOCKET_PROTOCOL_STRING, subProtocol);
        }
    }

    /**
     * Negotiates extensions from the client's {@code Sec-WebSocket-Extensions} header.
     * <p>
     * The extensions returned by {@link #selectedExtension(List)} are written to the response
     * {@code Sec-WebSocket-Extensions} header when that list is non-empty. A missing request
     * header is treated as an empty offer, so {@link #selectedExtension(List)} is still consulted.
     *
     * @param exchange the exchange being upgraded
     */
    protected final void selectExtensions(final WebSocketHttpExchange exchange) {
        String extensionHeader = exchange.getRequestHeader(Headers.SEC_WEB_SOCKET_EXTENSIONS_STRING);
        // Guard against handing a null header to the parser; an absent header means "no offers".
        List<WebSocketExtension> requestedExtensions = extensionHeader == null
                ? new ArrayList<WebSocketExtension>()
                : WebSocketExtension.parse(extensionHeader);
        List<WebSocketExtension> extensions = selectedExtension(requestedExtensions);
        if (extensions != null && !extensions.isEmpty()) {
            exchange.setResponseHeader(Headers.SEC_WEB_SOCKET_EXTENSIONS_STRING, WebSocketExtension.toExtensionHeader(extensions));
        }
    }

    /**
     * Picks the first client-requested subprotocol that this handshake supports.
     * <p>
     * The method name is misspelled but kept for compatibility with subclasses that override it.
     * Requests are checked in the client's order, so the client's most preferred supported value
     * wins.
     *
     * @param requestedSubprotocolArray the subprotocols offered by the client
     * @return the matching supported subprotocol, or {@code null} if none match or no
     *         subprotocols are configured
     */
    protected String supportedSubprotols(String[] requestedSubprotocolArray) {
        if (subprotocols == null || requestedSubprotocolArray == null) {
            return null;
        }
        for (String p : requestedSubprotocolArray) {
            // PATTERN only strips whitespace around commas; trim the outer edges of the value.
            String requestedSubprotocol = p.trim();
            for (String supportedSubprotocol : subprotocols) {
                if (requestedSubprotocol.equals(supportedSubprotocol)) {
                    return supportedSubprotocol;
                }
            }
        }
        return null;
    }

    /**
     * Negotiates the client's offered extensions against the registered extension handshakes.
     * <p>
     * For each offered extension, every registered handshake is asked to
     * {@link ExtensionHandshake#accept accept} it. An accepted result is kept only if the
     * handshake does not report itself incompatible with the handshakes already selected.
     *
     * @param extensionList the extensions offered by the client; {@code null} is treated as empty
     * @return the negotiated extensions, in offer order; never {@code null}
     */
    protected List<WebSocketExtension> selectedExtension(List<WebSocketExtension> extensionList) {
        List<WebSocketExtension> selected = new ArrayList<>();
        if (extensionList == null || extensionList.isEmpty()) {
            return selected;
        }
        List<ExtensionHandshake> configured = new ArrayList<>();
        for (WebSocketExtension ext : extensionList) {
            for (ExtensionHandshake extHandshake : availableExtensions) {
                WebSocketExtension negotiated = extHandshake.accept(ext);
                if (negotiated != null && !extHandshake.isIncompatible(configured)) {
                    selected.add(negotiated);
                    configured.add(extHandshake);
                }
            }
        }
        return selected;
    }

    /**
     * Registers an extension handshake and enables extension negotiation.
     *
     * @param extension the extension handshake to make available
     */
    public final void addExtension(ExtensionHandshake extension) {
        availableExtensions.add(extension);
        allowExtensions = true;
    }

    /**
     * Creates the extension functions for the extensions advertised in the response.
     * <p>
     * Reads the first value of the response {@code Sec-WebSocket-Extensions} header, parses it,
     * and creates a function from every registered handshake whose name matches a listed
     * extension.
     *
     * @param exchange the exchange whose response headers have already been negotiated
     * @return a mutable list of extension functions; empty if no extensions were negotiated
     */
    protected final List<ExtensionFunction> initExtensions(final WebSocketHttpExchange exchange) {
        List<ExtensionFunction> negotiated = new ArrayList<>();
        List<String> headerValues = exchange.getResponseHeaders().get(Headers.SEC_WEB_SOCKET_EXTENSIONS_STRING);
        // An empty header list would make get(0) throw; treat it like an absent header.
        if (headerValues == null || headerValues.isEmpty()) {
            return negotiated;
        }
        String extHeader = headerValues.get(0);
        if (extHeader == null) {
            return negotiated;
        }
        List<WebSocketExtension> extensions = WebSocketExtension.parse(extHeader);
        if (extensions == null || extensions.isEmpty()) {
            return negotiated;
        }
        for (WebSocketExtension ext : extensions) {
            for (ExtensionHandshake extHandshake : availableExtensions) {
                // NOTE(review): if several registered handshakes share a name, each creates a
                // function here even though only one was negotiated; confirm this is intended.
                if (extHandshake.getName().equals(ext.getName())) {
                    negotiated.add(extHandshake.create());
                }
            }
        }
        return negotiated;
    }
}