package io.undertow.server.protocol.http2;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;
import java.util.function.Supplier;
import static java.nio.charset.StandardCharsets.ISO_8859_1;
import javax.net.ssl.SSLSession;
import io.undertow.UndertowLogger;
import io.undertow.UndertowOptions;
import io.undertow.conduits.HeadStreamSinkConduit;
import io.undertow.protocols.http2.AbstractHttp2StreamSourceChannel;
import io.undertow.protocols.http2.Http2Channel;
import io.undertow.protocols.http2.Http2DataStreamSinkChannel;
import io.undertow.protocols.http2.Http2HeadersStreamSinkChannel;
import io.undertow.protocols.http2.Http2StreamSourceChannel;
import io.undertow.server.ConduitWrapper;
import io.undertow.server.ConnectorStatisticsImpl;
import io.undertow.server.Connectors;
import io.undertow.server.HttpHandler;
import io.undertow.server.HttpServerExchange;
import io.undertow.server.protocol.http.HttpAttachments;
import io.undertow.server.protocol.http.HttpContinue;
import io.undertow.server.protocol.http.HttpRequestParser;
import io.undertow.util.ConduitFactory;
import io.undertow.util.HeaderMap;
import io.undertow.util.HeaderValues;
import io.undertow.util.Headers;
import io.undertow.util.ImmediatePooledByteBuffer;
import io.undertow.util.Methods;
import io.undertow.util.ParameterLimitException;
import io.undertow.util.Protocols;
import io.undertow.util.StatusCodes;
import org.xnio.ChannelListener;
import org.xnio.IoUtils;
import org.xnio.OptionMap;
import org.xnio.channels.Channels;
import org.xnio.conduits.StreamSinkConduit;
import static io.undertow.protocols.http2.Http2Channel.AUTHORITY;
import static io.undertow.protocols.http2.Http2Channel.METHOD;
import static io.undertow.protocols.http2.Http2Channel.PATH;
import static io.undertow.protocols.http2.Http2Channel.SCHEME;
public class Http2ReceiveListener implements ChannelListener<Http2Channel> {
private final HttpHandler rootHandler;
private final long maxEntitySize;
private final OptionMap undertowOptions;
private final String encoding;
private final boolean decode;
private final StringBuilder decodeBuffer = new StringBuilder();
private final boolean allowEncodingSlash;
private final int bufferSize;
private final int maxParameters;
private final boolean recordRequestStartTime;
private final ConnectorStatisticsImpl connectorStatistics;
public Http2ReceiveListener(HttpHandler rootHandler, OptionMap undertowOptions, int bufferSize, ConnectorStatisticsImpl connectorStatistics) {
this.rootHandler = rootHandler;
this.undertowOptions = undertowOptions;
this.bufferSize = bufferSize;
this.connectorStatistics = connectorStatistics;
this.maxEntitySize = undertowOptions.get(UndertowOptions.MAX_ENTITY_SIZE, UndertowOptions.DEFAULT_MAX_ENTITY_SIZE);
this.allowEncodingSlash = undertowOptions.get(UndertowOptions.ALLOW_ENCODED_SLASH, false);
this.decode = undertowOptions.get(UndertowOptions.DECODE_URL, true);
this.maxParameters = undertowOptions.get(UndertowOptions.MAX_PARAMETERS, UndertowOptions.DEFAULT_MAX_PARAMETERS);
this.recordRequestStartTime = undertowOptions.get(UndertowOptions.RECORD_REQUEST_START_TIME, false);
if (undertowOptions.get(UndertowOptions.DECODE_URL, true)) {
this.encoding = undertowOptions.get(UndertowOptions.URL_CHARSET, StandardCharsets.UTF_8.name());
} else {
this.encoding = null;
}
}
@Override
public void handleEvent(Http2Channel channel) {
try {
final AbstractHttp2StreamSourceChannel frame = channel.receive();
if (frame == null) {
return;
}
if (frame instanceof Http2StreamSourceChannel) {
handleRequests(channel, (Http2StreamSourceChannel) frame);
}
} catch (IOException e) {
UndertowLogger.REQUEST_IO_LOGGER.ioException(e);
IoUtils.safeClose(channel);
} catch (Throwable t) {
UndertowLogger.REQUEST_IO_LOGGER.handleUnexpectedFailure(t);
IoUtils.safeClose(channel);
}
}
private void handleRequests(Http2Channel channel, Http2StreamSourceChannel frame) {
final Http2StreamSourceChannel dataChannel = frame;
final Http2ServerConnection connection = new Http2ServerConnection(channel, dataChannel, undertowOptions, bufferSize, rootHandler);
if (!checkRequestHeaders(dataChannel.getHeaders())) {
channel.sendRstStream(frame.getStreamId(), Http2Channel.ERROR_PROTOCOL_ERROR);
try {
Channels.drain(frame, Long.MAX_VALUE);
} catch (IOException e) {
}
return;
}
final HttpServerExchange exchange = new HttpServerExchange(connection, dataChannel.getHeaders(), dataChannel.getResponseChannel().getHeaders(), maxEntitySize);
dataChannel.setTrailersHandler(new Http2StreamSourceChannel.TrailersHandler() {
@Override
public void handleTrailers(HeaderMap headerMap) {
exchange.putAttachment(HttpAttachments.REQUEST_TRAILERS, headerMap);
}
});
connection.setExchange(exchange);
dataChannel.setMaxStreamSize(maxEntitySize);
exchange.setRequestScheme(exchange.getRequestHeaders().getFirst(SCHEME));
exchange.setRequestMethod(Methods.fromString(exchange.getRequestHeaders().getFirst(METHOD)));
exchange.getRequestHeaders().put(Headers.HOST, exchange.getRequestHeaders().getFirst(AUTHORITY));
if(!Connectors.areRequestHeadersValid(exchange.getRequestHeaders())) {
UndertowLogger.REQUEST_IO_LOGGER.debugf("Invalid headers in HTTP/2 request, closing connection. Remote peer %s", connection.getPeerAddress());
channel.sendGoAway(Http2Channel.ERROR_PROTOCOL_ERROR);
return;
}
final String path = exchange.getRequestHeaders().getFirst(PATH);
if(path == null || path.isEmpty()) {
UndertowLogger.REQUEST_IO_LOGGER.debugf("No :path header sent in HTTP/2 request, closing connection. Remote peer %s", connection.getPeerAddress());
channel.sendGoAway(Http2Channel.ERROR_PROTOCOL_ERROR);
return;
}
if (recordRequestStartTime) {
Connectors.setRequestStartTime(exchange);
}
handleCommonSetup(dataChannel.getResponseChannel(), exchange, connection);
if(!dataChannel.isOpen()) {
Connectors.terminateRequest(exchange);
} else {
dataChannel.setCompletionListener(new ChannelListener<Http2StreamSourceChannel>() {
@Override
public void handleEvent(Http2StreamSourceChannel channel) {
Connectors.terminateRequest(exchange);
}
});
}
if(connectorStatistics != null) {
connectorStatistics.setup(exchange);
}
try {
Connectors.setExchangeRequestPath(exchange, path, encoding, decode, allowEncodingSlash, decodeBuffer, maxParameters);
} catch (ParameterLimitException e) {
UndertowLogger.REQUEST_IO_LOGGER.debug("Failed to set request path", e);
exchange.setStatusCode(StatusCodes.BAD_REQUEST);
exchange.endExchange();
return;
}
exchange.getRequestHeaders().remove(AUTHORITY);
exchange.getRequestHeaders().remove(PATH);
exchange.getRequestHeaders().remove(SCHEME);
exchange.getRequestHeaders().remove(METHOD);
Connectors.executeRootHandler(rootHandler, exchange);
}
void handleInitialRequest(HttpServerExchange initial, Http2Channel channel, byte[] data) {
Http2HeadersStreamSinkChannel sink = channel.createInitialUpgradeResponseStream();
final Http2ServerConnection connection = new Http2ServerConnection(channel, sink, undertowOptions, bufferSize, rootHandler);
HeaderMap requestHeaders = new HeaderMap();
for(HeaderValues hv : initial.getRequestHeaders()) {
requestHeaders.putAll(hv.getHeaderName(), hv);
}
final HttpServerExchange exchange = new HttpServerExchange(connection, requestHeaders, sink.getHeaders(), maxEntitySize);
if(initial.getRequestHeaders().contains(Headers.EXPECT)) {
HttpContinue.markContinueResponseSent(exchange);
}
if(initial.getAttachment(HttpAttachments.REQUEST_TRAILERS) != null) {
exchange.putAttachment(HttpAttachments.REQUEST_TRAILERS, initial.getAttachment(HttpAttachments.REQUEST_TRAILERS));
}
Connectors.setRequestStartTime(initial, exchange);
connection.setExchange(exchange);
exchange.setRequestScheme(initial.getRequestScheme());
exchange.setRequestMethod(initial.getRequestMethod());
exchange.setQueryString(initial.getQueryString());
if (data != null) {
Connectors.ungetRequestBytes(exchange, new ImmediatePooledByteBuffer(ByteBuffer.wrap(data)));
}
Connectors.terminateRequest(exchange);
String uri = exchange.getQueryString().isEmpty() ? initial.getRequestURI() : initial.getRequestURI() + '?' + exchange.getQueryString();
try {
Connectors.setExchangeRequestPath(exchange, uri, encoding, decode, allowEncodingSlash, decodeBuffer, maxParameters);
} catch (ParameterLimitException e) {
exchange.setStatusCode(StatusCodes.BAD_REQUEST);
exchange.endExchange();
return;
}
handleCommonSetup(sink, exchange, connection);
Connectors.executeRootHandler(rootHandler, exchange);
}
private void handleCommonSetup(Http2HeadersStreamSinkChannel sink, HttpServerExchange exchange, Http2ServerConnection connection) {
Http2Channel channel = sink.getChannel();
SSLSession session = channel.getSslSession();
if(session != null) {
connection.setSslSessionInfo(new Http2SslSessionInfo(channel));
}
sink.setTrailersProducer(new Http2DataStreamSinkChannel.TrailersProducer() {
@Override
public HeaderMap getTrailers() {
Supplier<HeaderMap> supplier = exchange.getAttachment(HttpAttachments.RESPONSE_TRAILER_SUPPLIER);
if(supplier != null) {
return supplier.get();
}
return exchange.getAttachment(HttpAttachments.RESPONSE_TRAILERS);
}
});
sink.setCompletionListener(new ChannelListener<Http2DataStreamSinkChannel>() {
@Override
public void handleEvent(Http2DataStreamSinkChannel channel) {
Connectors.terminateResponse(exchange);
}
});
exchange.setProtocol(Protocols.HTTP_2_0);
if(exchange.getRequestMethod().equals(Methods.HEAD)) {
exchange.addResponseWrapper(new ConduitWrapper<StreamSinkConduit>() {
@Override
public StreamSinkConduit wrap(ConduitFactory<StreamSinkConduit> factory, HttpServerExchange exchange) {
return new HeadStreamSinkConduit(factory.create(), null, true);
}
});
}
}
private boolean (HeaderMap headers) {
if (headers.count(METHOD) != 1 || headers.contains(Headers.CONNECTION)) {
return false;
}
if (headers.get(METHOD).contains(Methods.CONNECT_STRING)) {
if (headers.contains(SCHEME) || headers.contains(PATH) || headers.count(AUTHORITY) != 1) {
return false;
}
} else if (headers.count(SCHEME) != 1 || headers.count(PATH) != 1) {
return false;
}
if (headers.contains(Headers.TE)) {
for (String value : headers.get(Headers.TE)) {
if (!value.equals("trailers")) {
return false;
}
}
}
if (headers.contains(PATH)) {
for (byte b: headers.get(PATH).getFirst().getBytes(ISO_8859_1)) {
if (!HttpRequestParser.isTargetCharacterAllowed((char)b)){
return false;
}
}
}
if (headers.contains(SCHEME)) {
for (byte b: headers.get(SCHEME).getFirst().getBytes(ISO_8859_1)) {
if (!Connectors.isValidSchemeCharacter(b)){
return false;
}
}
}
if (headers.contains(AUTHORITY)) {
for (byte b: headers.get(AUTHORITY).getFirst().getBytes(ISO_8859_1)) {
if (!HttpRequestParser.isTargetCharacterAllowed((char)b)){
return false;
}
}
}
if (headers.contains(METHOD)) {
for (byte b: headers.get(METHOD).getFirst().getBytes(ISO_8859_1)) {
if (!Connectors.isValidTokenCharacter(b)){
return false;
}
}
}
return true;
}
}