package io.undertow.websockets.client;
import io.undertow.UndertowMessages;
import io.undertow.client.ClientCallback;
import io.undertow.client.ClientConnection;
import io.undertow.client.ClientExchange;
import io.undertow.client.ClientRequest;
import io.undertow.client.UndertowClient;
import io.undertow.protocols.ssl.UndertowXnioSsl;
import io.undertow.util.Headers;
import io.undertow.util.Methods;
import io.undertow.util.Protocols;
import io.undertow.websockets.core.WebSocketChannel;
import io.undertow.websockets.core.WebSocketLogger;
import io.undertow.websockets.core.WebSocketVersion;
import io.undertow.websockets.extensions.ExtensionHandshake;
import org.xnio.Cancellable;
import org.xnio.ChannelListener;
import org.xnio.FutureResult;
import org.xnio.IoFuture;
import org.xnio.OptionMap;
import io.undertow.connector.ByteBufferPool;
import org.xnio.StreamConnection;
import org.xnio.XnioWorker;
import org.xnio.http.HttpUpgrade;
import org.xnio.http.RedirectException;
import org.xnio.ssl.XnioSsl;
import java.io.IOException;
import java.net.InetSocketAddress;
import java.net.URI;
import java.net.URISyntaxException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
/**
 * Static entry points for opening client-side websocket connections.
 * <p>
 * New code should obtain a {@link ConnectionBuilder} through
 * {@code connectionBuilder(worker, bufferPool, uri)}; the static {@code connect(...)}
 * overloads are deprecated and funnel into that builder.
 */
public class WebSocketClient {
/**
 * Name of the system property that is read when a connection is attempted. If it is set and
 * no bind address was configured on the builder, its value is used as the local host to bind
 * to, with port {@code 0}.
 */
public static final String BIND_PROPERTY = "io.undertow.websockets.BIND_ADDRESS";
/**
 * Number of redirects after which a connection attempt fails with a "too many redirects"
 * error. Read once from the {@code io.undertow.websockets.max-redirects} system property,
 * defaulting to 5.
 */
private static final int MAX_REDIRECTS = Integer.getInteger("io.undertow.websockets.max-redirects", 5);
/**
 * Connects to {@code uri} without a client negotiation.
 * <p>
 * Delegates to the overload that accepts a {@link WebSocketClientNegotiation}, passing
 * {@code null} for it.
 *
 * @param worker     the XNIO worker used for the connection
 * @param bufferPool the pool handed to the resulting channel
 * @param optionMap  connection options
 * @param uri        the websocket endpoint
 * @param version    the websocket protocol version
 * @return a future that yields the connected channel
 * @deprecated use {@link #connectionBuilder(XnioWorker, ByteBufferPool, URI)} instead
 */
@Deprecated
public static IoFuture<WebSocketChannel> connect(XnioWorker worker, final ByteBufferPool bufferPool, final OptionMap optionMap, final URI uri, WebSocketVersion version) {
    final WebSocketClientNegotiation noNegotiation = null;
    return connect(worker, bufferPool, optionMap, uri, version, noNegotiation);
}
/**
 * Connects to {@code uri} with the given SSL provider and no client negotiation.
 * <p>
 * Delegates to the overload that additionally accepts a
 * {@link WebSocketClientNegotiation}, passing {@code null} for it.
 *
 * @param worker     the XNIO worker used for the connection
 * @param ssl        the SSL provider, may be {@code null}
 * @param bufferPool the pool handed to the resulting channel
 * @param optionMap  connection options
 * @param uri        the websocket endpoint
 * @param version    the websocket protocol version
 * @return a future that yields the connected channel
 * @deprecated use {@link #connectionBuilder(XnioWorker, ByteBufferPool, URI)} instead
 */
@Deprecated
public static IoFuture<WebSocketChannel> connect(XnioWorker worker, XnioSsl ssl, final ByteBufferPool bufferPool, final OptionMap optionMap, final URI uri, WebSocketVersion version) {
    final WebSocketClientNegotiation noNegotiation = null;
    return connect(worker, ssl, bufferPool, optionMap, uri, version, noNegotiation);
}
/**
 * Connects to {@code uri} with a client negotiation but without an SSL provider.
 * <p>
 * Delegates to the overload that additionally accepts an {@link XnioSsl}, passing
 * {@code null} for it.
 *
 * @param worker            the XNIO worker used for the connection
 * @param bufferPool        the pool handed to the resulting channel
 * @param optionMap         connection options
 * @param uri               the websocket endpoint
 * @param version           the websocket protocol version
 * @param clientNegotiation the negotiation callback, may be {@code null}
 * @return a future that yields the connected channel
 * @deprecated use {@link #connectionBuilder(XnioWorker, ByteBufferPool, URI)} instead
 */
@Deprecated
public static IoFuture<WebSocketChannel> connect(XnioWorker worker, final ByteBufferPool bufferPool, final OptionMap optionMap, final URI uri, WebSocketVersion version, WebSocketClientNegotiation clientNegotiation) {
    final XnioSsl noSsl = null;
    return connect(worker, noSsl, bufferPool, optionMap, uri, version, clientNegotiation);
}
/**
 * Connects to {@code uri} with an SSL provider and client negotiation but no extensions.
 * <p>
 * Delegates to the overload that additionally accepts a set of
 * {@link ExtensionHandshake}s, passing {@code null} for it.
 *
 * @param worker            the XNIO worker used for the connection
 * @param ssl               the SSL provider, may be {@code null}
 * @param bufferPool        the pool handed to the resulting channel
 * @param optionMap         connection options
 * @param uri               the websocket endpoint
 * @param version           the websocket protocol version
 * @param clientNegotiation the negotiation callback, may be {@code null}
 * @return a future that yields the connected channel
 * @deprecated use {@link #connectionBuilder(XnioWorker, ByteBufferPool, URI)} instead
 */
@Deprecated
public static IoFuture<WebSocketChannel> connect(XnioWorker worker, XnioSsl ssl, final ByteBufferPool bufferPool, final OptionMap optionMap, final URI uri, WebSocketVersion version, WebSocketClientNegotiation clientNegotiation) {
    final Set<ExtensionHandshake> noExtensions = null;
    return connect(worker, ssl, bufferPool, optionMap, uri, version, clientNegotiation, noExtensions);
}
/**
 * Connects to {@code uri} with extensions but without an explicit local bind address.
 * <p>
 * Delegates to the overload that additionally accepts an {@link InetSocketAddress} to bind
 * to, passing {@code null} for it.
 *
 * @param worker            the XNIO worker used for the connection
 * @param ssl               the SSL provider, may be {@code null}
 * @param bufferPool        the pool handed to the resulting channel
 * @param optionMap         connection options
 * @param uri               the websocket endpoint
 * @param version           the websocket protocol version
 * @param clientNegotiation the negotiation callback, may be {@code null}
 * @param clientExtensions  the extensions offered by the client, may be {@code null}
 * @return a future that yields the connected channel
 * @deprecated use {@link #connectionBuilder(XnioWorker, ByteBufferPool, URI)} instead
 */
@Deprecated
public static IoFuture<WebSocketChannel> connect(XnioWorker worker, XnioSsl ssl, final ByteBufferPool bufferPool, final OptionMap optionMap, final URI uri, WebSocketVersion version, WebSocketClientNegotiation clientNegotiation, Set<ExtensionHandshake> clientExtensions) {
    final InetSocketAddress noBindAddress = null;
    return connect(worker, ssl, bufferPool, optionMap, noBindAddress, uri, version, clientNegotiation, clientExtensions);
}
/**
 * Connects to {@code uri} using every option the legacy API exposes.
 * <p>
 * Copies each argument onto a fresh {@link ConnectionBuilder} and starts the connection.
 *
 * @param worker            the XNIO worker used for the connection
 * @param ssl               the SSL provider, may be {@code null}
 * @param bufferPool        the pool handed to the resulting channel
 * @param optionMap         connection options
 * @param bindAddress       local address to bind to, may be {@code null}
 * @param uri               the websocket endpoint
 * @param version           the websocket protocol version
 * @param clientNegotiation the negotiation callback, may be {@code null}
 * @param clientExtensions  the extensions offered by the client, may be {@code null}
 * @return a future that yields the connected channel
 * @deprecated use {@link #connectionBuilder(XnioWorker, ByteBufferPool, URI)} instead
 */
@Deprecated
public static IoFuture<WebSocketChannel> connect(XnioWorker worker, XnioSsl ssl, final ByteBufferPool bufferPool, final OptionMap optionMap, InetSocketAddress bindAddress, final URI uri, WebSocketVersion version, WebSocketClientNegotiation clientNegotiation, Set<ExtensionHandshake> clientExtensions) {
    final ConnectionBuilder builder = connectionBuilder(worker, bufferPool, uri);
    builder.setSsl(ssl);
    builder.setOptionMap(optionMap);
    builder.setBindAddress(bindAddress);
    builder.setVersion(version);
    builder.setClientNegotiation(clientNegotiation);
    builder.setClientExtensions(clientExtensions);
    return builder.connect();
}
/**
 * Fluent builder that collects the settings for one client websocket connection and starts
 * the HTTP upgrade handshake when {@link #connect()} is called.
 * <p>
 * The worker, buffer pool and target URI are fixed at construction time. All other settings
 * are optional and default to: no SSL provider, {@link OptionMap#EMPTY}, no bind address,
 * {@link WebSocketVersion#V13}, no client negotiation, no extensions and no proxy.
 */
public static class ConnectionBuilder {
    private final XnioWorker worker;
    private final ByteBufferPool bufferPool;
    private final URI uri;
    private XnioSsl ssl;
    private OptionMap optionMap = OptionMap.EMPTY;
    private InetSocketAddress bindAddress;
    private WebSocketVersion version = WebSocketVersion.V13;
    private WebSocketClientNegotiation clientNegotiation;
    private Set<ExtensionHandshake> clientExtensions;
    private URI proxyUri;
    private XnioSsl proxySsl;

    /**
     * Creates a builder for a connection to {@code uri}.
     *
     * @param worker     the XNIO worker used for the connection
     * @param bufferPool the pool handed to the resulting channel
     * @param uri        the websocket endpoint
     */
    public ConnectionBuilder(XnioWorker worker, ByteBufferPool bufferPool, URI uri) {
        this.worker = worker;
        this.bufferPool = bufferPool;
        this.uri = uri;
    }

    /** @return the XNIO worker supplied at construction */
    public XnioWorker getWorker() {
        return worker;
    }

    /** @return the target URI supplied at construction */
    public URI getUri() {
        return uri;
    }

    /** @return the SSL provider for the target connection, or {@code null} if none was set */
    public XnioSsl getSsl() {
        return ssl;
    }

    /**
     * Sets the SSL provider for the target connection.
     *
     * @param ssl the SSL provider, may be {@code null}
     * @return this builder
     */
    public ConnectionBuilder setSsl(XnioSsl ssl) {
        this.ssl = ssl;
        return this;
    }

    /** @return the buffer pool supplied at construction */
    public ByteBufferPool getBufferPool() {
        return bufferPool;
    }

    /** @return the connection options */
    public OptionMap getOptionMap() {
        return optionMap;
    }

    /**
     * Sets the connection options.
     *
     * @param optionMap the options to use
     * @return this builder
     */
    public ConnectionBuilder setOptionMap(OptionMap optionMap) {
        this.optionMap = optionMap;
        return this;
    }

    /** @return the explicitly configured local bind address, or {@code null} */
    public InetSocketAddress getBindAddress() {
        return bindAddress;
    }

    /**
     * Sets the local address to bind to. When left {@code null}, the
     * {@link WebSocketClient#BIND_PROPERTY} system property is consulted at connect time.
     *
     * @param bindAddress the local address, may be {@code null}
     * @return this builder
     */
    public ConnectionBuilder setBindAddress(InetSocketAddress bindAddress) {
        this.bindAddress = bindAddress;
        return this;
    }

    /** @return the websocket protocol version */
    public WebSocketVersion getVersion() {
        return version;
    }

    /**
     * Sets the websocket protocol version.
     *
     * @param version the version to use for the handshake
     * @return this builder
     */
    public ConnectionBuilder setVersion(WebSocketVersion version) {
        this.version = version;
        return this;
    }

    /** @return the client negotiation callback, or {@code null} */
    public WebSocketClientNegotiation getClientNegotiation() {
        return clientNegotiation;
    }

    /**
     * Sets the client negotiation. When present, its {@code beforeRequest} hook is invoked
     * with the handshake headers before the upgrade request is sent.
     *
     * @param clientNegotiation the negotiation callback, may be {@code null}
     * @return this builder
     */
    public ConnectionBuilder setClientNegotiation(WebSocketClientNegotiation clientNegotiation) {
        this.clientNegotiation = clientNegotiation;
        return this;
    }

    /** @return the extensions passed to the handshake, or {@code null} */
    public Set<ExtensionHandshake> getClientExtensions() {
        return clientExtensions;
    }

    /**
     * Sets the extensions passed to the handshake.
     *
     * @param clientExtensions the extensions, may be {@code null}
     * @return this builder
     */
    public ConnectionBuilder setClientExtensions(Set<ExtensionHandshake> clientExtensions) {
        this.clientExtensions = clientExtensions;
        return this;
    }

    /** @return the proxy URI, or {@code null} when connecting directly */
    public URI getProxyUri() {
        return proxyUri;
    }

    /**
     * Sets the proxy to tunnel through. When non-null, the connection is established by
     * sending an HTTP {@code CONNECT} request to this URI first.
     *
     * @param proxyUri the proxy address, may be {@code null}
     * @return this builder
     */
    public ConnectionBuilder setProxyUri(URI proxyUri) {
        this.proxyUri = proxyUri;
        return this;
    }

    /** @return the SSL provider used for the connection to the proxy, or {@code null} */
    public XnioSsl getProxySsl() {
        return proxySsl;
    }

    /**
     * Sets the SSL provider used for the connection to the proxy itself.
     *
     * @param proxySsl the SSL provider, may be {@code null}
     * @return this builder
     */
    public ConnectionBuilder setProxySsl(XnioSsl proxySsl) {
        this.proxySsl = proxySsl;
        return this;
    }

    /**
     * Starts connecting to the configured URI.
     *
     * @return a future that is completed with the channel once the handshake succeeds, or
     *         failed with the cause otherwise
     */
    public IoFuture<WebSocketChannel> connect() {
        return connectImpl(uri, new FutureResult<WebSocketChannel>(), 0);
    }

    /**
     * Tells whether {@code uri} designates a TLS-protected endpoint, i.e. its scheme is
     * {@code wss} or {@code https}. The comparison ignores case and tolerates a missing
     * scheme (treated as not secure).
     */
    private static boolean isSecure(URI uri) {
        final String scheme = uri.getScheme();
        return "wss".equalsIgnoreCase(scheme) || "https".equalsIgnoreCase(scheme);
    }

    /**
     * Performs one connection attempt against {@code uri}, either directly or through the
     * configured proxy, and re-invokes itself when the server answers with a redirect.
     *
     * @param uri           the endpoint for this attempt (the original URI or a redirect target)
     * @param ioFuture      the result shared by all attempts of one {@link #connect()} call
     * @param redirectCount number of redirects already followed
     * @return the future view of {@code ioFuture}
     * @throws RuntimeException if the HTTP form of {@code uri} cannot be constructed
     */
    private IoFuture<WebSocketChannel> connectImpl(final URI uri, final FutureResult<WebSocketChannel> ioFuture, final int redirectCount) {
        WebSocketLogger.REQUEST_LOGGER.debugf("Opening websocket connection to %s", uri);
        // A redirect may point at an https:// location, so both wss and https must map to the
        // secure HTTP scheme and default port; checking only "wss" would downgrade such a
        // redirect to plain http on port 80.
        final boolean secure = isSecure(uri);
        final String scheme = secure ? "https" : "http";
        final URI newUri;
        try {
            newUri = new URI(scheme, uri.getUserInfo(), uri.getHost(), uri.getPort() == -1 ? (secure ? 443 : 80) : uri.getPort(), uri.getPath().isEmpty() ? "/" : uri.getPath(), uri.getQuery(), uri.getFragment());
        } catch (URISyntaxException e) {
            throw new RuntimeException(e);
        }
        final WebSocketClientHandshake handshake = WebSocketClientHandshake.create(version, newUri, clientNegotiation, clientExtensions);
        final Map<String, String> originalHeaders = handshake.createHeaders();
        originalHeaders.put(Headers.HOST_STRING, uri.getHost() + ":" + newUri.getPort());
        // The upgrade API expects multi-valued headers; wrap every value in its own list.
        final Map<String, List<String>> headers = new HashMap<>();
        for (Map.Entry<String, String> entry : originalHeaders.entrySet()) {
            List<String> list = new ArrayList<>();
            list.add(entry.getValue());
            headers.put(entry.getKey(), list);
        }
        if (clientNegotiation != null) {
            clientNegotiation.beforeRequest(headers);
        }
        // An explicit bind address wins; otherwise fall back to the system property.
        InetSocketAddress toBind = bindAddress;
        String sysBind = System.getProperty(BIND_PROPERTY);
        if (toBind == null && sysBind != null) {
            toBind = new InetSocketAddress(sysBind, 0);
        }
        if (proxyUri != null) {
            UndertowClient.getInstance().connect(new ClientCallback<ClientConnection>() {
                @Override
                public void completed(final ClientConnection connection) {
                    // Ask the proxy for a tunnel to the target host and port.
                    int port = uri.getPort() > 0 ? uri.getPort() : (secure ? 443 : 80);
                    ClientRequest cr = new ClientRequest()
                            .setMethod(Methods.CONNECT)
                            .setPath(uri.getHost() + ":" + port)
                            .setProtocol(Protocols.HTTP_1_1);
                    cr.getRequestHeaders().put(Headers.HOST, proxyUri.getHost() + ":" + (proxyUri.getPort() > 0 ? proxyUri.getPort() : 80));
                    connection.sendRequest(cr, new ClientCallback<ClientExchange>() {
                        @Override
                        public void completed(ClientExchange result) {
                            result.setResponseListener(new ClientCallback<ClientExchange>() {
                                @Override
                                public void completed(ClientExchange response) {
                                    try {
                                        if (response.getResponse().getResponseCode() == 200) {
                                            try {
                                                StreamConnection targetConnection = connection.performUpgrade();
                                                WebSocketLogger.REQUEST_LOGGER.debugf("Established websocket connection to %s", uri);
                                                if (secure) {
                                                    // The tunnel is raw TCP; layer TLS on top before the websocket handshake.
                                                    handleConnectionWithExistingConnection(((UndertowXnioSsl) ssl).wrapExistingConnection(targetConnection, optionMap, uri));
                                                } else {
                                                    handleConnectionWithExistingConnection(targetConnection);
                                                }
                                            } catch (IOException e) {
                                                ioFuture.setException(e);
                                            } catch (Exception e) {
                                                ioFuture.setException(new IOException(e));
                                            }
                                        } else {
                                            ioFuture.setException(UndertowMessages.MESSAGES.proxyConnectionFailed(response.getResponse().getResponseCode()));
                                        }
                                    } catch (Exception e) {
                                        ioFuture.setException(new IOException(e));
                                    }
                                }

                                /** Runs the websocket upgrade over the already established tunnel. */
                                private void handleConnectionWithExistingConnection(StreamConnection targetConnection) {
                                    final IoFuture<?> result;
                                    result = HttpUpgrade.performUpgrade(targetConnection, newUri, headers, new WebsocketConnectionListener(optionMap, handshake, newUri, ioFuture), handshake.handshakeChecker(newUri, headers));
                                    result.addNotifier(new IoFuture.Notifier<Object, Object>() {
                                        @Override
                                        public void notify(IoFuture<?> res, Object attachment) {
                                            if (res.getStatus() == IoFuture.Status.FAILED) {
                                                ioFuture.setException(res.getException());
                                            }
                                        }
                                    }, null);
                                    ioFuture.addCancelHandler(new Cancellable() {
                                        @Override
                                        public Cancellable cancel() {
                                            result.cancel();
                                            return null;
                                        }
                                    });
                                }

                                @Override
                                public void failed(IOException e) {
                                    ioFuture.setException(e);
                                }
                            });
                        }

                        @Override
                        public void failed(IOException e) {
                            ioFuture.setException(e);
                        }
                    });
                }

                @Override
                public void failed(IOException e) {
                    ioFuture.setException(e);
                }
            // Use the resolved bind address so the BIND_PROPERTY fallback also applies to
            // proxied connections, exactly as it does for direct ones.
            }, toBind, proxyUri, worker, proxySsl, bufferPool, optionMap);
        } else {
            final IoFuture<?> result;
            if (ssl != null) {
                result = HttpUpgrade.performUpgrade(worker, ssl, toBind, newUri, headers, new WebsocketConnectionListener(optionMap, handshake, newUri, ioFuture), null, optionMap, handshake.handshakeChecker(newUri, headers));
            } else {
                result = HttpUpgrade.performUpgrade(worker, toBind, newUri, headers, new WebsocketConnectionListener(optionMap, handshake, newUri, ioFuture), null, optionMap, handshake.handshakeChecker(newUri, headers));
            }
            result.addNotifier(new IoFuture.Notifier<Object, Object>() {
                @Override
                public void notify(IoFuture<?> res, Object attachment) {
                    if (res.getStatus() == IoFuture.Status.FAILED) {
                        IOException exception = res.getException();
                        if (exception instanceof RedirectException) {
                            if (redirectCount == MAX_REDIRECTS) {
                                ioFuture.setException(UndertowMessages.MESSAGES.tooManyRedirects(exception));
                            } else {
                                String location = ((RedirectException) exception).getLocation();
                                try {
                                    // The Location header may be relative; resolve it against the
                                    // request URI (which always has a non-empty path) so the next
                                    // attempt has a scheme and host. Absolute locations are
                                    // returned unchanged by resolve().
                                    connectImpl(newUri.resolve(new URI(location)), ioFuture, redirectCount + 1);
                                } catch (URISyntaxException e) {
                                    ioFuture.setException(new IOException(e));
                                } catch (RuntimeException e) {
                                    // Nobody is waiting on this callback's stack; fail the future
                                    // instead of letting it stay pending forever.
                                    ioFuture.setException(new IOException(e));
                                }
                            }
                        } else {
                            ioFuture.setException(exception);
                        }
                    }
                }
            }, null);
            ioFuture.addCancelHandler(new Cancellable() {
                @Override
                public Cancellable cancel() {
                    result.cancel();
                    return null;
                }
            });
        }
        return ioFuture.getIoFuture();
    }

    /**
     * Listener invoked with the upgraded connection; it turns the connection into a
     * {@link WebSocketChannel} via the handshake and completes the pending future with it.
     */
    private class WebsocketConnectionListener implements ChannelListener<StreamConnection> {
        private final OptionMap options;
        private final WebSocketClientHandshake handshake;
        private final URI newUri;
        private final FutureResult<WebSocketChannel> ioFuture;

        WebsocketConnectionListener(OptionMap options, WebSocketClientHandshake handshake, URI newUri, FutureResult<WebSocketChannel> ioFuture) {
            this.options = options;
            this.handshake = handshake;
            this.newUri = newUri;
            this.ioFuture = ioFuture;
        }

        @Override
        public void handleEvent(StreamConnection channel) {
            WebSocketChannel result = handshake.createChannel(channel, newUri.toString(), bufferPool, options);
            ioFuture.setResult(result);
        }
    }
}
/**
 * Creates a {@link ConnectionBuilder} for a websocket connection to {@code uri}.
 *
 * @param worker     the XNIO worker used for the connection
 * @param bufferPool the pool handed to the resulting channel
 * @param uri        the websocket endpoint
 * @return a new builder holding the three arguments
 */
public static ConnectionBuilder connectionBuilder(XnioWorker worker, ByteBufferPool bufferPool, URI uri) {
    final ConnectionBuilder builder = new ConnectionBuilder(worker, bufferPool, uri);
    return builder;
}
/** Private so the class cannot be instantiated from outside; it only exposes static members. */
private WebSocketClient() {
    // intentionally empty
}
}