package io.vertx.core.http.impl;
import io.netty.channel.ChannelHandler;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.channel.ChannelPipeline;
import io.netty.handler.codec.http.DefaultFullHttpResponse;
import io.netty.handler.codec.http.FullHttpResponse;
import io.netty.handler.codec.http.HttpContent;
import io.netty.handler.codec.http.HttpContentDecompressor;
import io.netty.handler.codec.http.HttpHeaders;
import io.netty.handler.codec.http.HttpResponse;
import io.netty.handler.codec.http.LastHttpContent;
import io.netty.handler.codec.http.websocketx.WebSocket00FrameDecoder;
import io.netty.handler.codec.http.websocketx.WebSocket07FrameDecoder;
import io.netty.handler.codec.http.websocketx.WebSocket08FrameDecoder;
import io.netty.handler.codec.http.websocketx.WebSocket13FrameDecoder;
import io.netty.handler.codec.http.websocketx.WebSocketClientHandshaker;
import io.netty.handler.codec.http.websocketx.WebSocketClientHandshaker00;
import io.netty.handler.codec.http.websocketx.WebSocketClientHandshaker07;
import io.netty.handler.codec.http.websocketx.WebSocketClientHandshaker08;
import io.netty.handler.codec.http.websocketx.WebSocketClientHandshaker13;
import io.netty.handler.codec.http.websocketx.WebSocketClientHandshakerFactory;
import io.netty.handler.codec.http.websocketx.WebSocketDecoderConfig;
import io.netty.handler.codec.http.websocketx.WebSocketFrameDecoder;
import io.netty.handler.codec.http.websocketx.WebSocketHandshakeException;
import io.netty.handler.codec.http.websocketx.WebSocketVersion;
import io.vertx.core.AsyncResult;
import io.vertx.core.Future;
import io.vertx.core.Handler;
import io.vertx.core.http.UpgradeRejectedException;
import io.vertx.core.http.impl.headers.HeadersAdaptor;
import java.net.URI;
import static io.netty.handler.codec.http.websocketx.WebSocketVersion.V00;
import static io.netty.handler.codec.http.websocketx.WebSocketVersion.V07;
import static io.netty.handler.codec.http.websocketx.WebSocketVersion.V08;
import static io.netty.handler.codec.http.websocketx.WebSocketVersion.V13;
/**
 * Inbound handler that collects the HTTP response to a client WebSocket upgrade request,
 * completes the handshake through a Netty {@link WebSocketClientHandshaker}, and reports the
 * outcome to a Vert.x handler.
 *
 * <p>Once the last part of the response has been received, this handler removes itself (and any
 * {@link HttpContentDecompressor}) from the pipeline before finishing the handshake.
 */
class WebSocketHandshakeInboundHandler extends ChannelInboundHandlerAdapter {

  /** Receives the handshake outcome: the response headers on success, or the failure cause. */
  private final Handler<AsyncResult<HeadersAdaptor>> wsHandler;
  /** Netty handshaker used to validate the response and switch the pipeline to WebSocket frames. */
  private final WebSocketClientHandshaker handshaker;
  /** Context captured in {@link #handlerAdded}; used to reach the channel when finishing. */
  private ChannelHandlerContext chctx;
  /** Response being aggregated; created on the response head, filled by each content part. */
  private FullHttpResponse response;

  /**
   * @param handshaker the handshaker that issued the upgrade request
   * @param wsHandler  notified with the handshake result
   */
  WebSocketHandshakeInboundHandler(WebSocketClientHandshaker handshaker, Handler<AsyncResult<HeadersAdaptor>> wsHandler) {
    this.handshaker = handshaker;
    this.wsHandler = wsHandler;
  }

  @Override
  public void handlerAdded(ChannelHandlerContext ctx) throws Exception {
    super.handlerAdded(ctx);
    chctx = ctx;
  }

  /**
   * Fails the handshake if the connection closes while this handler is still in the pipeline,
   * i.e. before the complete upgrade response has been received.
   */
  @Override
  public void channelInactive(ChannelHandlerContext ctx) throws Exception {
    super.channelInactive(ctx);
    wsHandler.handle(Future.failedFuture(new WebSocketHandshakeException("Connection closed while handshake in process")));
  }

  /**
   * Aggregates the upgrade response. A response head starts a new {@link FullHttpResponse}.
   * Content parts are appended to it. The last content part also copies trailing headers and
   * triggers handshake completion.
   *
   * <p>Messages are consumed here and never forwarded with {@code fireChannelRead}, so this
   * handler owns the reference of every {@link HttpContent} it receives and must release it.
   * That includes a {@code FullHttpResponse}, which is both a response head and content.
   */
  @Override
  public void channelRead(ChannelHandlerContext ctx, Object msg) {
    if (msg instanceof HttpResponse) {
      HttpResponse resp = (HttpResponse) msg;
      response = new DefaultFullHttpResponse(resp.protocolVersion(), resp.status());
      response.headers().add(resp.headers());
    }
    if (msg instanceof HttpContent) {
      HttpContent content = (HttpContent) msg;
      try {
        // Content that arrives before any response head is discarded (but still released below).
        if (response != null) {
          response.content().writeBytes(content.content());
          if (content instanceof LastHttpContent) {
            response.trailingHeaders().add(((LastHttpContent) content).trailingHeaders());
            onResponseComplete(ctx);
          }
        }
      } finally {
        // The bytes have been copied into 'response', so the incoming buffer is no longer needed.
        content.release();
      }
    }
  }

  /**
   * Detaches this handler and any {@link HttpContentDecompressor} from the pipeline, then
   * finishes the handshake with the aggregated response and notifies {@link #wsHandler}.
   */
  private void onResponseComplete(ChannelHandlerContext ctx) {
    ChannelPipeline pipeline = ctx.pipeline();
    pipeline.remove(this);
    ChannelHandler decompressor = pipeline.get(HttpContentDecompressor.class);
    if (decompressor != null) {
      // Presumably only relevant to HTTP bodies, not to the upgraded connection — confirm.
      pipeline.remove(decompressor);
    }
    Future<HeadersAdaptor> fut = handshakeComplete(response);
    wsHandler.handle(fut);
  }

  /**
   * Validates the aggregated upgrade response.
   *
   * @param response the complete HTTP response to the upgrade request
   * @return a succeeded future with the response headers when the status is 101 and the
   *     handshaker accepts the response; a future failed with {@link UpgradeRejectedException}
   *     for any other status; or a future failed with the handshaker's
   *     {@link WebSocketHandshakeException}
   */
  private Future<HeadersAdaptor> handshakeComplete(FullHttpResponse response) {
    int sc = response.status().code();
    if (sc != 101) {
      UpgradeRejectedException failure = new UpgradeRejectedException("WebSocket connection attempt returned HTTP status code " + sc, sc);
      return Future.failedFuture(failure);
    }
    try {
      handshaker.finishHandshake(chctx.channel(), response);
      return Future.succeededFuture(new HeadersAdaptor(response.headers()));
    } catch (WebSocketHandshakeException e) {
      return Future.failedFuture(e);
    }
  }

  /**
   * Creates a client handshaker for the requested WebSocket protocol version.
   *
   * <p>For V13, V08 and V07, the handshaker's frame decoder is built from a shared
   * {@link WebSocketDecoderConfig} with these settings:
   * <ul>
   *   <li>frames from the server are expected to be unmasked;</li>
   *   <li>a mask mismatch is not allowed;</li>
   *   <li>a protocol violation does not close the connection.</li>
   * </ul>
   * V00 uses Netty's default handshaker and decoder, and ignores {@code allowExtensions} and
   * {@code performMasking}.
   *
   * @param webSocketURL          target URL of the upgrade request
   * @param version               protocol version; only V13, V08, V07 and V00 are supported
   * @param subprotocol           requested subprotocol(s), passed through to Netty
   * @param allowExtensions       whether extensions (reserved bits) are allowed in frames
   * @param customHeaders         extra headers to send with the upgrade request
   * @param maxFramePayloadLength maximum accepted frame payload length
   * @param performMasking        whether outgoing frames are masked
   * @return a new handshaker for {@code version}
   * @throws WebSocketHandshakeException if {@code version} is not supported
   */
  static WebSocketClientHandshaker newHandshaker(
    URI webSocketURL, WebSocketVersion version, String subprotocol,
    boolean allowExtensions, HttpHeaders customHeaders, int maxFramePayloadLength,
    boolean performMasking) {
    WebSocketDecoderConfig config = WebSocketDecoderConfig.newBuilder()
      .expectMaskedFrames(false)
      .allowExtensions(allowExtensions)
      .maxFramePayloadLength(maxFramePayloadLength)
      .allowMaskMismatch(false)
      .closeOnProtocolViolation(false)
      .build();
    // NOTE(review): the trailing 'false, -1' arguments look like Netty's allowMaskMismatch and
    // forceCloseTimeoutMillis constructor parameters — confirm against the Netty version in use.
    if (version == V13) {
      return new WebSocketClientHandshaker13(
        webSocketURL, V13, subprotocol, allowExtensions, customHeaders,
        maxFramePayloadLength, performMasking, false, -1) {
        @Override
        protected WebSocketFrameDecoder newWebsocketDecoder() {
          return new WebSocket13FrameDecoder(config);
        }
      };
    }
    if (version == V08) {
      return new WebSocketClientHandshaker08(
        webSocketURL, V08, subprotocol, allowExtensions, customHeaders,
        maxFramePayloadLength, performMasking, false, -1) {
        @Override
        protected WebSocketFrameDecoder newWebsocketDecoder() {
          return new WebSocket08FrameDecoder(config);
        }
      };
    }
    if (version == V07) {
      return new WebSocketClientHandshaker07(
        webSocketURL, V07, subprotocol, allowExtensions, customHeaders,
        maxFramePayloadLength, performMasking, false, -1) {
        @Override
        protected WebSocketFrameDecoder newWebsocketDecoder() {
          return new WebSocket07FrameDecoder(config);
        }
      };
    }
    if (version == V00) {
      return new WebSocketClientHandshaker00(
        webSocketURL, V00, subprotocol, customHeaders, maxFramePayloadLength, -1);
    }
    throw new WebSocketHandshakeException("Protocol version " + version + " not supported.");
  }
}