package io.undertow.websockets.client;
import io.undertow.util.FlexBase64;
import io.undertow.util.Headers;
import io.undertow.websockets.WebSocketExtension;
import io.undertow.websockets.core.WebSocketChannel;
import io.undertow.websockets.core.WebSocketMessages;
import io.undertow.websockets.core.WebSocketVersion;
import io.undertow.websockets.core.protocol.version13.WebSocket13Channel;
import io.undertow.websockets.extensions.CompositeExtensionFunction;
import io.undertow.websockets.extensions.ExtensionFunction;
import io.undertow.websockets.extensions.ExtensionHandshake;
import io.undertow.websockets.extensions.NoopExtensionFunction;
import org.xnio.OptionMap;
import io.undertow.connector.ByteBufferPool;
import org.xnio.StreamConnection;
import org.xnio.http.ExtendedHandshakeChecker;
import java.io.IOException;
import java.net.URI;
import java.nio.charset.StandardCharsets;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import java.security.SecureRandom;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Set;
/**
 * Client side of the opening handshake for WebSocket protocol version 13.
 * <p>
 * This class builds the headers of the upgrade request ({@link #createHeaders()}), provides the
 * checker that validates the server's response ({@link #handshakeChecker(URI, Map)}) and creates the
 * {@link WebSocket13Channel} for the upgraded connection
 * ({@link #createChannel(StreamConnection, String, ByteBufferPool, OptionMap)}).
 */
public class WebSocket13ClientHandshake extends WebSocketClientHandshake {

    /** Fixed GUID that {@link #solve(String)} appends to the sent key before hashing it. */
    public static final String MAGIC_NUMBER = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";

    /** Optional sub protocol / extension negotiation; {@code null} when none was supplied. */
    private final WebSocketClientNegotiation negotiation;

    /** Extension handshakes this client can instantiate; never {@code null}. */
    private final Set<ExtensionHandshake> extensions;

    /**
     * Creates a handshake with optional negotiation and extension support.
     *
     * @param url         the URL handed to the super class
     * @param negotiation the negotiation to consult and notify, or {@code null} for none
     * @param extensions  the available extension handshakes, or {@code null} for none
     */
    public WebSocket13ClientHandshake(final URI url, WebSocketClientNegotiation negotiation, Set<ExtensionHandshake> extensions) {
        super(url);
        this.negotiation = negotiation;
        this.extensions = extensions == null ? Collections.<ExtensionHandshake>emptySet() : extensions;
    }

    /**
     * Creates a handshake without negotiation and without extensions.
     *
     * @param url the URL handed to the super class
     */
    public WebSocket13ClientHandshake(final URI url) {
        this(url, null, null);
    }

    /**
     * Creates the channel for an upgraded connection.
     * <p>
     * When the negotiation reports selected extensions, every selected extension is matched by name
     * against the extension handshakes given at construction time and the resulting functions are
     * composed into the channel. Otherwise the channel is created with the no-op extension function.
     *
     * @param channel    the underlying connection
     * @param wsUri      the WebSocket URI passed through to the channel
     * @param bufferPool the buffer pool passed through to the channel
     * @param options    the options passed through to the channel
     * @return a new {@link WebSocket13Channel}
     */
    @Override
    public WebSocketChannel createChannel(final StreamConnection channel, final String wsUri, final ByteBufferPool bufferPool, OptionMap options) {
        final List<WebSocketExtension> selected = negotiation == null ? null : negotiation.getSelectedExtensions();
        if (selected == null || selected.isEmpty()) {
            // Nothing was negotiated: plain channel, empty sub protocol when there is no negotiation at all.
            final String subProtocol = negotiation != null ? negotiation.getSelectedSubProtocol() : "";
            return new WebSocket13Channel(channel, bufferPool, wsUri, subProtocol, true, false, NoopExtensionFunction.INSTANCE, new HashSet<WebSocketChannel>(), options);
        }

        final List<ExtensionFunction> negotiated = new ArrayList<>();
        for (WebSocketExtension ext : selected) {
            for (ExtensionHandshake extHandshake : extensions) {
                if (ext.getName().equals(extHandshake.getName())) {
                    negotiated.add(extHandshake.create());
                }
            }
        }
        return new WebSocket13Channel(channel, bufferPool, wsUri, negotiation.getSelectedSubProtocol(), true, !negotiated.isEmpty(), CompositeExtensionFunction.compose(negotiated), new HashSet<WebSocketChannel>(), options);
    }

    /**
     * Builds the headers of the upgrade request.
     * <p>
     * The upgrade, connection, key and version headers are always present. The sub protocol and
     * extensions headers are only added when a negotiation is configured and reports a non-empty list
     * for them.
     *
     * @return a new, mutable map from header name to header value
     */
    public Map<String, String> createHeaders() {
        final Map<String, String> headers = new HashMap<>();
        headers.put(Headers.UPGRADE_STRING, "websocket");
        headers.put(Headers.CONNECTION_STRING, "upgrade");
        headers.put(Headers.SEC_WEB_SOCKET_KEY_STRING, createSecKey());
        headers.put(Headers.SEC_WEB_SOCKET_VERSION_STRING, getVersion().toHttpHeaderValue());
        if (negotiation == null) {
            return headers;
        }

        final List<String> subProtocols = negotiation.getSupportedSubProtocols();
        if (subProtocols != null && !subProtocols.isEmpty()) {
            headers.put(Headers.SEC_WEB_SOCKET_PROTOCOL_STRING, joinSubProtocols(subProtocols));
        }
        // Named "offered" so that it does not shadow the "extensions" field.
        final List<WebSocketExtension> offered = negotiation.getSupportedExtensions();
        if (offered != null && !offered.isEmpty()) {
            headers.put(Headers.SEC_WEB_SOCKET_EXTENSIONS_STRING, formatExtensions(offered));
        }
        return headers;
    }

    /** Joins the sub protocol names with {@code ", "}. */
    private static String joinSubProtocols(final List<String> subProtocols) {
        final StringBuilder sb = new StringBuilder();
        final Iterator<String> it = subProtocols.iterator();
        while (it.hasNext()) {
            sb.append(it.next());
            if (it.hasNext()) {
                sb.append(", ");
            }
        }
        return sb.toString();
    }

    /**
     * Renders the extensions as {@code name; param=value; flag, name2 ...}: extensions are separated
     * by {@code ", "}, parameters are introduced by {@code "; "} and a parameter whose value is
     * {@code null} or empty is written without {@code =value}.
     */
    private static String formatExtensions(final List<WebSocketExtension> offered) {
        final StringBuilder sb = new StringBuilder();
        final Iterator<WebSocketExtension> it = offered.iterator();
        while (it.hasNext()) {
            final WebSocketExtension extension = it.next();
            sb.append(extension.getName());
            for (WebSocketExtension.Parameter param : extension.getParameters()) {
                sb.append("; ").append(param.getName());
                final String value = param.getValue();
                if (value != null && !value.isEmpty()) {
                    sb.append("=").append(value);
                }
            }
            if (it.hasNext()) {
                sb.append(", ");
            }
        }
        return sb.toString();
    }

    /**
     * Creates the value of the key header: 16 random bytes, Base64 encoded.
     *
     * @return the encoded key
     */
    protected String createSecKey() {
        final byte[] data = new byte[16];
        new SecureRandom().nextBytes(data);
        return FlexBase64.encodeString(data, false);
    }

    /**
     * Creates the checker that validates the server's handshake response.
     * <p>
     * The checker first passes the response headers to the negotiation (if any), then verifies the
     * upgrade header, the connection header and the accept key, and finally validates the selected
     * sub protocol and reports the outcome to the negotiation.
     *
     * @param uri            the request URI (not used by this implementation)
     * @param requestHeaders the headers that were sent; the key header is read from this map
     * @return the checker; its check method throws {@link IOException} when the response is rejected
     */
    @Override
    public ExtendedHandshakeChecker handshakeChecker(final URI uri, final Map<String, List<String>> requestHeaders) {
        // Null-safe lookup: a missing entry or an empty value list yields null instead of an
        // IndexOutOfBoundsException / NullPointerException escaping from this method.
        final List<String> sentKeys = requestHeaders.get(Headers.SEC_WEB_SOCKET_KEY_STRING);
        final String sentKey = sentKeys == null || sentKeys.isEmpty() ? null : sentKeys.get(0);
        return new ExtendedHandshakeChecker() {
            @Override
            public void checkHandshakeExtended(Map<String, List<String>> headers) throws IOException {
                try {
                    if (negotiation != null) {
                        negotiation.afterRequest(headers);
                    }
                    final String upgrade = getFirst(Headers.UPGRADE_STRING, headers);
                    if (upgrade == null || !upgrade.trim().equalsIgnoreCase("websocket")) {
                        throw WebSocketMessages.MESSAGES.noWebSocketUpgradeHeader();
                    }
                    // The connection header is a token list; it only has to contain "upgrade".
                    final String connHeader = getFirst(Headers.CONNECTION_STRING, headers);
                    if (connHeader == null || !containsToken(connHeader, "upgrade")) {
                        throw WebSocketMessages.MESSAGES.noWebSocketConnectionHeader();
                    }
                    final String acceptKey = getFirst(Headers.SEC_WEB_SOCKET_ACCEPT_STRING, headers);
                    // Without a sent key there is nothing to verify against, so never hash the text "null".
                    final String expectedKey = sentKey == null ? null : solve(sentKey);
                    if (expectedKey == null || !expectedKey.equals(acceptKey)) {
                        throw WebSocketMessages.MESSAGES.webSocketAcceptKeyMismatch(expectedKey, acceptKey);
                    }
                    if (negotiation != null) {
                        final String subProto = getFirst(Headers.SEC_WEB_SOCKET_PROTOCOL_STRING, headers);
                        if (subProto != null && !subProto.isEmpty()) {
                            List<String> supported = negotiation.getSupportedSubProtocols();
                            if (supported == null) {
                                // No sub protocol was offered, so any selected one is unsupported.
                                supported = Collections.emptyList();
                            }
                            if (!supported.contains(subProto)) {
                                throw WebSocketMessages.MESSAGES.unsupportedProtocol(subProto, supported);
                            }
                        }
                        List<WebSocketExtension> selectedExtensions = Collections.emptyList();
                        final String extHeader = getFirst(Headers.SEC_WEB_SOCKET_EXTENSIONS_STRING, headers);
                        if (extHeader != null) {
                            selectedExtensions = WebSocketExtension.parse(extHeader);
                        }
                        negotiation.handshakeComplete(subProto, selectedExtensions);
                    }
                } catch (IOException e) {
                    throw e;
                } catch (Exception e) {
                    // Anything unexpected is reported as a handshake failure, keeping the cause.
                    throw new IOException(e);
                }
            }
        };
    }

    /**
     * Tells whether a comma separated header value contains the given token, ignoring case and
     * whitespace around each element.
     */
    private static boolean containsToken(final String headerValue, final String token) {
        for (String element : headerValue.split(",")) {
            if (element.trim().equalsIgnoreCase(token)) {
                return true;
            }
        }
        return false;
    }

    /**
     * Returns the first value stored for a header, or {@code null} when there is none.
     * The key is lower-cased before the lookup, so this assumes the map is keyed by lower-case
     * header names.
     */
    private String getFirst(String key, Map<String, List<String>> map) {
        final List<String> values = map.get(key.toLowerCase(Locale.ENGLISH));
        return values == null || values.isEmpty() ? null : values.get(0);
    }

    /**
     * Computes the accept key expected for a sent key: the Base64 encoded SHA-1 digest of the sent
     * key concatenated with {@link #MAGIC_NUMBER}.
     *
     * @param nonceBase64 the Base64 encoded key that was sent
     * @return the expected accept key
     * @throws RuntimeException if the SHA-1 digest is not available
     */
    protected final String solve(final String nonceBase64) {
        try {
            final String concat = nonceBase64 + MAGIC_NUMBER;
            final MessageDigest digest = MessageDigest.getInstance("SHA1");
            digest.update(concat.getBytes(StandardCharsets.UTF_8));
            return FlexBase64.encodeString(digest.digest(), false);
        } catch (NoSuchAlgorithmException e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Returns the protocol version implemented by this handshake.
     *
     * @return {@link WebSocketVersion#V13}
     */
    public WebSocketVersion getVersion() {
        return WebSocketVersion.V13;
    }
}