package io.netty.handler.codec.http2;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.ByteBufAllocator;
import io.netty.channel.ChannelHandlerContext;
import io.netty.handler.codec.http2.Http2FrameReader.Configuration;
import io.netty.util.internal.PlatformDependent;
import io.netty.util.internal.UnstableApi;
import static io.netty.handler.codec.http2.Http2CodecUtil.DEFAULT_MAX_FRAME_SIZE;
import static io.netty.handler.codec.http2.Http2CodecUtil.FRAME_HEADER_LENGTH;
import static io.netty.handler.codec.http2.Http2CodecUtil.INT_FIELD_LENGTH;
import static io.netty.handler.codec.http2.Http2CodecUtil.PING_FRAME_PAYLOAD_LENGTH;
import static io.netty.handler.codec.http2.Http2CodecUtil.PRIORITY_ENTRY_LENGTH;
import static io.netty.handler.codec.http2.Http2CodecUtil.SETTINGS_INITIAL_WINDOW_SIZE;
import static io.netty.handler.codec.http2.Http2CodecUtil.SETTINGS_MAX_FRAME_SIZE;
import static io.netty.handler.codec.http2.Http2CodecUtil.SETTING_ENTRY_LENGTH;
import static io.netty.handler.codec.http2.Http2CodecUtil.headerListSizeExceeded;
import static io.netty.handler.codec.http2.Http2CodecUtil.isMaxFrameSizeValid;
import static io.netty.handler.codec.http2.Http2CodecUtil.readUnsignedInt;
import static io.netty.handler.codec.http2.Http2Error.FLOW_CONTROL_ERROR;
import static io.netty.handler.codec.http2.Http2Error.FRAME_SIZE_ERROR;
import static io.netty.handler.codec.http2.Http2Error.PROTOCOL_ERROR;
import static io.netty.handler.codec.http2.Http2Exception.connectionError;
import static io.netty.handler.codec.http2.Http2Exception.streamError;
import static io.netty.handler.codec.http2.Http2FrameTypes.CONTINUATION;
import static io.netty.handler.codec.http2.Http2FrameTypes.DATA;
import static io.netty.handler.codec.http2.Http2FrameTypes.GO_AWAY;
import static io.netty.handler.codec.http2.Http2FrameTypes.HEADERS;
import static io.netty.handler.codec.http2.Http2FrameTypes.PING;
import static io.netty.handler.codec.http2.Http2FrameTypes.PRIORITY;
import static io.netty.handler.codec.http2.Http2FrameTypes.PUSH_PROMISE;
import static io.netty.handler.codec.http2.Http2FrameTypes.RST_STREAM;
import static io.netty.handler.codec.http2.Http2FrameTypes.SETTINGS;
import static io.netty.handler.codec.http2.Http2FrameTypes.WINDOW_UPDATE;
@UnstableApi
public class DefaultHttp2FrameReader implements Http2FrameReader, Http2FrameSizePolicy, Configuration {
// Decoder handed the accumulated header block in HeadersBlockBuilder#headers().
private final Http2HeadersDecoder headersDecoder;
// true while the next bytes expected are a frame header; false while waiting for that frame's payload.
private boolean readingHeaders = true;
// Set by readFrame when a failure other than a stream-level Http2Exception escapes; once set,
// readFrame discards all further input.
private boolean readError;
// Values taken from the frame header most recently parsed by processHeaderState.
private byte frameType;
private int streamId;
private Http2Flags flags;
private int payloadLength;
// Non-null while a header block started by HEADERS or PUSH_PROMISE is still being assembled.
private HeadersContinuation headersContinuation;
// Largest payload length accepted by processHeaderState; initialised to DEFAULT_MAX_FRAME_SIZE.
private int maxFrameSize;
/**
 * Creates a reader by delegating to {@link #DefaultHttp2FrameReader(boolean)} with {@code true}.
 */
public DefaultHttp2FrameReader() {
    this(true);
}

/**
 * Creates a reader backed by a new {@link DefaultHttp2HeadersDecoder}.
 *
 * @param validateHeaders forwarded unchanged to the {@link DefaultHttp2HeadersDecoder} constructor
 */
public DefaultHttp2FrameReader(boolean validateHeaders) {
    this(new DefaultHttp2HeadersDecoder(validateHeaders));
}

/**
 * Creates a reader that decodes header blocks with the supplied decoder and starts with
 * {@code DEFAULT_MAX_FRAME_SIZE} as its frame size limit.
 *
 * @param headersDecoder the decoder used for every completed header block
 */
public DefaultHttp2FrameReader(Http2HeadersDecoder headersDecoder) {
    this.maxFrameSize = DEFAULT_MAX_FRAME_SIZE;
    this.headersDecoder = headersDecoder;
}
/**
 * Exposes the configuration object of the underlying headers decoder.
 *
 * @return whatever {@code headersDecoder.configuration()} returns
 */
@Override
public Http2HeadersDecoder.Configuration headersConfiguration() {
    final Http2HeadersDecoder.Configuration decoderConfiguration = headersDecoder.configuration();
    return decoderConfiguration;
}
/**
 * Returns this reader, which serves as its own {@link Configuration}.
 *
 * @return {@code this}
 */
@Override
public Configuration configuration() {
    final Configuration self = this;
    return self;
}
/**
 * Returns this reader, which serves as its own {@link Http2FrameSizePolicy}.
 *
 * @return {@code this}
 */
@Override
public Http2FrameSizePolicy frameSizePolicy() {
    final Http2FrameSizePolicy self = this;
    return self;
}
/**
 * Sets the largest frame payload length this reader accepts.
 *
 * @param max the new limit; must pass {@code isMaxFrameSizeValid(max)}
 * @throws Http2Exception a connection error of type {@code FRAME_SIZE_ERROR} when {@code max} is rejected
 */
@Override
public void maxFrameSize(int max) throws Http2Exception {
    if (!isMaxFrameSizeValid(max)) {
        // The limit applies to every frame this reader parses, whereas the streamId field only holds
        // the id from the most recently parsed frame header. Reporting the failure against that
        // unrelated stream would be misleading, so it is raised at connection level.
        throw connectionError(FRAME_SIZE_ERROR,
                "Invalid MAX_FRAME_SIZE specified in sent settings: %d", max);
    }
    maxFrameSize = max;
}
/**
 * Returns the largest frame payload length currently accepted.
 *
 * @return the current limit
 */
@Override
public int maxFrameSize() {
    return this.maxFrameSize;
}
/**
 * Releases any header block that is still being assembled.
 */
@Override
public void close() {
    this.closeHeadersContinuation();
}
/**
 * Closes the pending header-block continuation, if there is one, and forgets it.
 */
private void closeHeadersContinuation() {
    final HeadersContinuation pending = headersContinuation;
    if (pending == null) {
        return;
    }
    pending.close();
    headersContinuation = null;
}
/**
 * Parses as many complete frames as {@code input} currently contains and dispatches each one to
 * {@code listener}. Returns early, keeping the parser state for the next call, when either a frame
 * header or the payload of the current frame has not fully arrived yet.
 *
 * <p>After any failure other than an {@link Http2Exception} for which
 * {@code Http2Exception.isStreamError} is true, the reader latches into an error state and every
 * later call just discards the readable bytes.
 *
 * @param ctx      passed through to the listener callbacks
 * @param input    buffer to consume from
 * @param listener receives one callback per decoded frame
 * @throws Http2Exception if a frame fails verification or decoding
 */
@Override
public void readFrame(ChannelHandlerContext ctx, ByteBuf input, Http2FrameListener listener)
throws Http2Exception {
if (readError) {
// A previous non-stream error poisoned the reader: drop everything.
input.skipBytes(input.readableBytes());
return;
}
try {
do {
if (readingHeaders) {
processHeaderState(input);
if (readingHeaders) {
// Still no complete frame header; wait for more bytes.
return;
}
}
// Always attempt the payload right after the header, even when nothing is readable:
// a zero-length payload must be dispatched without waiting for more input.
processPayloadState(ctx, input, listener);
if (!readingHeaders) {
// Payload incomplete; wait for more bytes.
return;
}
} while (input.isReadable());
} catch (Http2Exception e) {
// Only non-stream errors latch the reader into the discard-everything state.
readError = !Http2Exception.isStreamError(e);
throw e;
} catch (RuntimeException e) {
readError = true;
throw e;
} catch (Throwable cause) {
readError = true;
PlatformDependent.throwException(cause);
}
}
/**
 * Reads one frame header from {@code in} once at least {@code FRAME_HEADER_LENGTH} bytes are
 * readable, stores its fields, switches the reader to payload state and runs the per-type
 * verification for the frame.
 *
 * @param in buffer positioned at the start of a frame header
 * @throws Http2Exception if the length exceeds {@code maxFrameSize} or type-specific verification fails
 */
private void processHeaderState(ByteBuf in) throws Http2Exception {
if (in.readableBytes() < FRAME_HEADER_LENGTH) {
// Header not complete yet; readingHeaders stays true.
return;
}
// Fields are consumed in wire order: 3-byte length, type byte, flags byte, stream id.
payloadLength = in.readUnsignedMedium();
if (payloadLength > maxFrameSize) {
throw connectionError(FRAME_SIZE_ERROR, "Frame length: %d exceeds maximum: %d", payloadLength,
maxFrameSize);
}
frameType = in.readByte();
flags = new Http2Flags(in.readUnsignedByte());
streamId = readUnsignedInt(in);
// The header bytes are consumed, so the next step is the payload.
// NOTE(review): this flag is flipped before verification runs; if a verify* method throws a
// stream-level error, the next readFrame call proceeds straight to payload dispatch for this
// frame - confirm that is intended.
readingHeaders = false;
switch (frameType) {
case DATA:
verifyDataFrame();
break;
case HEADERS:
verifyHeadersFrame();
break;
case PRIORITY:
verifyPriorityFrame();
break;
case RST_STREAM:
verifyRstStreamFrame();
break;
case SETTINGS:
verifySettingsFrame();
break;
case PUSH_PROMISE:
verifyPushPromiseFrame();
break;
case PING:
verifyPingFrame();
break;
case GO_AWAY:
verifyGoAwayFrame();
break;
case WINDOW_UPDATE:
verifyWindowUpdateFrame();
break;
case CONTINUATION:
verifyContinuationFrame();
break;
default:
// Any other type code is treated as an unknown frame.
verifyUnknownFrame();
break;
}
}
/**
 * Once the whole payload of the current frame is readable, slices it off {@code in}, switches the
 * reader back to header state and hands the slice to the reader method for the frame type.
 *
 * @param ctx      passed through to the listener callbacks
 * @param in       buffer positioned at the start of the payload
 * @param listener receives the decoded frame
 * @throws Http2Exception if decoding the payload fails
 */
private void processPayloadState(ChannelHandlerContext ctx, ByteBuf in, Http2FrameListener listener)
throws Http2Exception {
if (in.readableBytes() < payloadLength) {
// Payload not complete yet; readingHeaders stays false.
return;
}
// View over exactly this frame's payload, so the per-type readers cannot run into the next frame.
ByteBuf payload = in.readSlice(payloadLength);
// Payload bytes are consumed from 'in'; the next thing on the wire is a frame header.
readingHeaders = true;
switch (frameType) {
case DATA:
readDataFrame(ctx, payload, listener);
break;
case HEADERS:
readHeadersFrame(ctx, payload, listener);
break;
case PRIORITY:
readPriorityFrame(ctx, payload, listener);
break;
case RST_STREAM:
readRstStreamFrame(ctx, payload, listener);
break;
case SETTINGS:
readSettingsFrame(ctx, payload, listener);
break;
case PUSH_PROMISE:
readPushPromiseFrame(ctx, payload, listener);
break;
case PING:
// verifyPingFrame has already required payloadLength == PING_FRAME_PAYLOAD_LENGTH
// (presumably 8, matching readLong - confirm).
readPingFrame(ctx, payload.readLong(), listener);
break;
case GO_AWAY:
readGoAwayFrame(ctx, payload, listener);
break;
case WINDOW_UPDATE:
readWindowUpdateFrame(ctx, payload, listener);
break;
case CONTINUATION:
readContinuationFrame(payload, listener);
break;
default:
readUnknownFrame(ctx, payload, listener);
break;
}
}
/**
 * Checks the header of a DATA frame: it must name a stream, must not interrupt a header block,
 * must respect the frame size limit and must be long enough to hold the padding length field
 * when one is flagged.
 *
 * @throws Http2Exception if any check fails
 */
private void verifyDataFrame() throws Http2Exception {
    verifyAssociatedWithAStream();
    verifyNotProcessingHeaders();
    verifyPayloadLength(payloadLength);

    final int minimumLength = flags.getPaddingPresenceFieldLength();
    if (payloadLength >= minimumLength) {
        return;
    }
    throw streamError(streamId, FRAME_SIZE_ERROR,
            "Frame length %d too small.", payloadLength);
}
/**
 * Checks the header of a HEADERS frame: it must name a stream, must not interrupt a header block,
 * must respect the frame size limit and must be long enough for the padding length field and the
 * priority fields that its flags announce.
 *
 * @throws Http2Exception if any check fails
 */
private void verifyHeadersFrame() throws Http2Exception {
    verifyAssociatedWithAStream();
    verifyNotProcessingHeaders();
    verifyPayloadLength(payloadLength);

    int requiredLength = flags.getPaddingPresenceFieldLength() + flags.getNumPriorityBytes();
    if (payloadLength < requiredLength) {
        // Pass the length as a format argument (as the sibling verifiers do) instead of gluing it
        // onto the end of the message, which produced text such as "Frame length too small.5".
        throw streamError(streamId, FRAME_SIZE_ERROR,
                "Frame length %d too small.", payloadLength);
    }
}
/**
 * Checks the header of a PRIORITY frame: it must name a stream, must not interrupt a header block
 * and its payload must be exactly {@code PRIORITY_ENTRY_LENGTH} bytes.
 *
 * @throws Http2Exception if any check fails
 */
private void verifyPriorityFrame() throws Http2Exception {
    verifyAssociatedWithAStream();
    verifyNotProcessingHeaders();

    if (payloadLength == PRIORITY_ENTRY_LENGTH) {
        return;
    }
    throw streamError(streamId, FRAME_SIZE_ERROR,
            "Invalid frame length %d.", payloadLength);
}
/**
 * Checks the header of a RST_STREAM frame: it must name a stream, must not interrupt a header
 * block and its payload must be exactly {@code INT_FIELD_LENGTH} bytes.
 *
 * @throws Http2Exception if any check fails
 */
private void verifyRstStreamFrame() throws Http2Exception {
    verifyAssociatedWithAStream();
    verifyNotProcessingHeaders();

    if (payloadLength == INT_FIELD_LENGTH) {
        return;
    }
    throw connectionError(FRAME_SIZE_ERROR, "Invalid frame length %d.", payloadLength);
}
/**
 * Checks the header of a SETTINGS frame: it must not interrupt a header block, must respect the
 * frame size limit, must use stream id 0, must be empty when it is an ACK, and its length must be
 * a whole multiple of {@code SETTING_ENTRY_LENGTH}.
 *
 * @throws Http2Exception if any check fails
 */
private void verifySettingsFrame() throws Http2Exception {
    verifyNotProcessingHeaders();
    verifyPayloadLength(payloadLength);

    if (streamId != 0) {
        throw connectionError(PROTOCOL_ERROR, "A stream ID must be zero.");
    }
    if (flags.ack()) {
        if (payloadLength > 0) {
            throw connectionError(FRAME_SIZE_ERROR, "Ack settings frame must have an empty payload.");
        }
    }
    final int leftoverBytes = payloadLength % SETTING_ENTRY_LENGTH;
    if (leftoverBytes > 0) {
        throw connectionError(FRAME_SIZE_ERROR, "Frame length %d invalid.", payloadLength);
    }
}
/**
 * Checks the header of a PUSH_PROMISE frame: it must not interrupt a header block, must respect
 * the frame size limit and must be long enough for the optional padding length field plus the
 * promised stream id ({@code INT_FIELD_LENGTH} bytes).
 *
 * @throws Http2Exception if any check fails
 */
private void verifyPushPromiseFrame() throws Http2Exception {
    verifyNotProcessingHeaders();
    verifyPayloadLength(payloadLength);

    final int smallestValidLength = flags.getPaddingPresenceFieldLength() + INT_FIELD_LENGTH;
    if (payloadLength >= smallestValidLength) {
        return;
    }
    throw streamError(streamId, FRAME_SIZE_ERROR,
            "Frame length %d too small.", payloadLength);
}
/**
 * Checks the header of a PING frame: it must not interrupt a header block, must use stream id 0
 * and its payload must be exactly {@code PING_FRAME_PAYLOAD_LENGTH} bytes.
 *
 * @throws Http2Exception if any check fails
 */
private void verifyPingFrame() throws Http2Exception {
    verifyNotProcessingHeaders();

    if (streamId != 0) {
        throw connectionError(PROTOCOL_ERROR, "A stream ID must be zero.");
    }
    if (payloadLength == PING_FRAME_PAYLOAD_LENGTH) {
        return;
    }
    throw connectionError(FRAME_SIZE_ERROR,
            "Frame length %d incorrect size for ping.", payloadLength);
}
/**
 * Checks the header of a GOAWAY frame: it must not interrupt a header block, must respect the
 * frame size limit, must use stream id 0 and must carry at least 8 payload bytes (the two 4-byte
 * fields that {@code readGoAwayFrame} reads before the debug data).
 *
 * @throws Http2Exception if any check fails
 */
private void verifyGoAwayFrame() throws Http2Exception {
    verifyNotProcessingHeaders();
    verifyPayloadLength(payloadLength);

    if (streamId != 0) {
        throw connectionError(PROTOCOL_ERROR, "A stream ID must be zero.");
    }
    if (payloadLength >= 8) {
        return;
    }
    throw connectionError(FRAME_SIZE_ERROR, "Frame length %d too small.", payloadLength);
}
/**
 * Checks the header of a WINDOW_UPDATE frame: it must not interrupt a header block, its stream id
 * must be non-negative and its payload must be exactly {@code INT_FIELD_LENGTH} bytes.
 *
 * @throws Http2Exception if any check fails
 */
private void verifyWindowUpdateFrame() throws Http2Exception {
    verifyNotProcessingHeaders();
    verifyStreamOrConnectionId(streamId, "Stream ID");

    if (payloadLength == INT_FIELD_LENGTH) {
        return;
    }
    throw connectionError(FRAME_SIZE_ERROR, "Invalid frame length %d.", payloadLength);
}
/**
 * Checks the header of a CONTINUATION frame: it must name a stream, must respect the frame size
 * limit, must arrive while a header block is pending, must target the same stream as that pending
 * block and must be at least as long as the padding length field announced by its flags.
 *
 * @throws Http2Exception if any check fails
 */
private void verifyContinuationFrame() throws Http2Exception {
    verifyAssociatedWithAStream();
    verifyPayloadLength(payloadLength);

    final HeadersContinuation pending = headersContinuation;
    if (pending == null) {
        throw connectionError(PROTOCOL_ERROR, "Received %s frame but not currently processing headers.",
                frameType);
    }

    final int expectedStreamId = pending.getStreamId();
    if (expectedStreamId != streamId) {
        throw connectionError(PROTOCOL_ERROR, "Continuation stream ID does not match pending headers. "
                + "Expected %d, but received %d.", expectedStreamId, streamId);
    }

    final int paddingFieldLength = flags.getPaddingPresenceFieldLength();
    if (payloadLength < paddingFieldLength) {
        throw streamError(streamId, FRAME_SIZE_ERROR,
                "Frame length %d too small for padding.", payloadLength);
    }
}
/**
 * Checks the header of a frame whose type is not one of the known types: the only requirement is
 * that it does not interrupt a pending header block.
 *
 * @throws Http2Exception if a header block is pending
 */
private void verifyUnknownFrame() throws Http2Exception {
    this.verifyNotProcessingHeaders();
}
/**
 * Decodes a DATA frame payload and reports it through {@code listener.onDataRead}.
 *
 * @param ctx      passed through to the listener
 * @param payload  slice holding exactly this frame's payload
 * @param listener receives the data slice, the padding value and the end-of-stream flag
 * @throws Http2Exception if the padding does not fit in the payload or the listener fails
 */
private void readDataFrame(ChannelHandlerContext ctx, ByteBuf payload,
        Http2FrameListener listener) throws Http2Exception {
    final int padding = readPadding(payload);
    verifyPadding(padding);

    // Everything that remains, minus the trailing padding, is application data.
    final int readable = payload.readableBytes();
    final ByteBuf body = payload.readSlice(lengthWithoutTrailingPadding(readable, padding));
    listener.onDataRead(ctx, streamId, body, padding, flags.endOfStream());

    // Whatever is left in the slice is padding; drop it.
    payload.skipBytes(payload.readableBytes());
}
/**
 * Decodes a HEADERS frame payload. Installs a {@link HeadersContinuation} that accumulates the
 * header block, feeds it the fragment carried by this frame, and - when the END_HEADERS flag is
 * set - lets it invoke {@code listener.onHeadersRead} and clears the continuation again.
 * Which {@code onHeadersRead} overload is used depends on whether the priority flag is set.
 *
 * @param ctx      passed through to the listener; also supplies the allocator for the header block
 * @param payload  slice holding exactly this frame's payload
 * @param listener receives the decoded headers once the block is complete
 * @throws Http2Exception if padding is invalid, the stream depends on itself, or decoding fails
 */
private void readHeadersFrame(final ChannelHandlerContext ctx, ByteBuf payload,
Http2FrameListener listener) throws Http2Exception {
// Snapshot the frame header fields: the continuation may complete during a later frame, by
// which time the instance fields describe that later frame.
final int headersStreamId = streamId;
final Http2Flags headersFlags = flags;
final int padding = readPadding(payload);
verifyPadding(padding);
if (flags.priorityPresent()) {
// First 32 bits: top bit is the exclusive flag, low 31 bits the stream dependency.
long word1 = payload.readUnsignedInt();
final boolean exclusive = (word1 & 0x80000000L) != 0;
final int streamDependency = (int) (word1 & 0x7FFFFFFFL);
if (streamDependency == streamId) {
throw streamError(streamId, PROTOCOL_ERROR, "A stream cannot depend on itself.");
}
// The wire value is stored off by one, so add 1 to obtain the weight handed to the listener.
final short weight = (short) (payload.readUnsignedByte() + 1);
final ByteBuf fragment = payload.readSlice(lengthWithoutTrailingPadding(payload.readableBytes(), padding));
// Continuation that reports the priority fields along with the headers.
headersContinuation = new HeadersContinuation() {
@Override
public int getStreamId() {
return headersStreamId;
}
@Override
public void processFragment(boolean endOfHeaders, ByteBuf fragment,
Http2FrameListener listener) throws Http2Exception {
final HeadersBlockBuilder hdrBlockBuilder = headersBlockBuilder();
hdrBlockBuilder.addFragment(fragment, ctx.alloc(), endOfHeaders);
if (endOfHeaders) {
listener.onHeadersRead(ctx, headersStreamId, hdrBlockBuilder.headers(), streamDependency,
weight, exclusive, padding, headersFlags.endOfStream());
}
}
};
// Feed the fragment from this frame; fires the listener immediately if the block is complete.
headersContinuation.processFragment(flags.endOfHeaders(), fragment, listener);
resetHeadersContinuationIfEnd(flags.endOfHeaders());
return;
}
// No priority fields: continuation that uses the shorter onHeadersRead overload.
headersContinuation = new HeadersContinuation() {
@Override
public int getStreamId() {
return headersStreamId;
}
@Override
public void processFragment(boolean endOfHeaders, ByteBuf fragment,
Http2FrameListener listener) throws Http2Exception {
final HeadersBlockBuilder hdrBlockBuilder = headersBlockBuilder();
hdrBlockBuilder.addFragment(fragment, ctx.alloc(), endOfHeaders);
if (endOfHeaders) {
listener.onHeadersRead(ctx, headersStreamId, hdrBlockBuilder.headers(), padding,
headersFlags.endOfStream());
}
}
};
final ByteBuf fragment = payload.readSlice(lengthWithoutTrailingPadding(payload.readableBytes(), padding));
headersContinuation.processFragment(flags.endOfHeaders(), fragment, listener);
resetHeadersContinuationIfEnd(flags.endOfHeaders());
}
/**
 * Drops the pending header-block continuation when the block has just been completed.
 *
 * @param endOfHeaders whether the frame that was just processed ended the header block
 */
private void resetHeadersContinuationIfEnd(boolean endOfHeaders) {
    if (!endOfHeaders) {
        return;
    }
    closeHeadersContinuation();
}
/**
 * Decodes a PRIORITY frame payload and reports it through {@code listener.onPriorityRead}.
 *
 * @param ctx      passed through to the listener
 * @param payload  slice holding exactly this frame's payload
 * @param listener receives the stream dependency, weight and exclusive flag
 * @throws Http2Exception if the stream names itself as its dependency or the listener fails
 */
private void readPriorityFrame(ChannelHandlerContext ctx, ByteBuf payload,
        Http2FrameListener listener) throws Http2Exception {
    // Top bit of the first 32 bits is the exclusive flag; the low 31 bits are the dependency.
    final long dependencyWord = payload.readUnsignedInt();
    final boolean exclusive = (dependencyWord & 0x80000000L) != 0;
    final int parentStreamId = (int) (dependencyWord & 0x7FFFFFFFL);
    if (parentStreamId == streamId) {
        throw streamError(streamId, PROTOCOL_ERROR, "A stream cannot depend on itself.");
    }

    // The wire value is stored off by one.
    final short weight = (short) (payload.readUnsignedByte() + 1);
    listener.onPriorityRead(ctx, streamId, parentStreamId, weight, exclusive);
}
/**
 * Decodes a RST_STREAM frame payload (a single unsigned 32-bit error code) and reports it through
 * {@code listener.onRstStreamRead}.
 *
 * @param ctx      passed through to the listener
 * @param payload  slice holding exactly this frame's payload
 * @param listener receives the stream id and error code
 * @throws Http2Exception if the listener fails
 */
private void readRstStreamFrame(ChannelHandlerContext ctx, ByteBuf payload,
        Http2FrameListener listener) throws Http2Exception {
    listener.onRstStreamRead(ctx, streamId, payload.readUnsignedInt());
}
/**
 * Decodes a SETTINGS frame. An ACK is reported through {@code listener.onSettingsAckRead};
 * otherwise every (id, value) entry is collected into an {@link Http2Settings} instance that is
 * handed to {@code listener.onSettingsRead}.
 *
 * @param ctx      passed through to the listener
 * @param payload  slice holding exactly this frame's payload
 * @param listener receives the ACK or the parsed settings
 * @throws Http2Exception a connection error if {@link Http2Settings} rejects an entry
 *                        ({@code FLOW_CONTROL_ERROR} for the initial window size setting,
 *                        {@code PROTOCOL_ERROR} for every other setting), or if the listener fails
 */
private void readSettingsFrame(ChannelHandlerContext ctx, ByteBuf payload,
        Http2FrameListener listener) throws Http2Exception {
    if (flags.ack()) {
        listener.onSettingsAckRead(ctx);
        return;
    }

    final int entryCount = payloadLength / SETTING_ENTRY_LENGTH;
    final Http2Settings parsed = new Http2Settings();
    for (int remaining = entryCount; remaining > 0; --remaining) {
        // Each entry is a 16-bit identifier followed by an unsigned 32-bit value.
        final char key = (char) payload.readUnsignedShort();
        final long rawValue = payload.readUnsignedInt();
        try {
            parsed.put(key, Long.valueOf(rawValue));
        } catch (IllegalArgumentException rejected) {
            if (key == SETTINGS_INITIAL_WINDOW_SIZE) {
                throw connectionError(FLOW_CONTROL_ERROR, rejected, rejected.getMessage());
            }
            // SETTINGS_MAX_FRAME_SIZE and every other identifier map to the same error code.
            throw connectionError(PROTOCOL_ERROR, rejected, rejected.getMessage());
        }
    }
    listener.onSettingsRead(ctx, parsed);
}
/**
 * Decodes a PUSH_PROMISE frame payload. Installs a {@link HeadersContinuation} that accumulates
 * the header block, feeds it the fragment carried by this frame, and - when the END_HEADERS flag
 * is set - lets it invoke {@code listener.onPushPromiseRead} and clears the continuation again.
 *
 * @param ctx      passed through to the listener; also supplies the allocator for the header block
 * @param payload  slice holding exactly this frame's payload
 * @param listener receives the promised stream id and decoded headers once the block is complete
 * @throws Http2Exception if padding is invalid or decoding fails
 */
private void readPushPromiseFrame(final ChannelHandlerContext ctx, ByteBuf payload,
Http2FrameListener listener) throws Http2Exception {
// Snapshot the stream id: the continuation may complete during a later frame.
final int pushPromiseStreamId = streamId;
final int padding = readPadding(payload);
verifyPadding(padding);
final int promisedStreamId = readUnsignedInt(payload);
headersContinuation = new HeadersContinuation() {
@Override
public int getStreamId() {
return pushPromiseStreamId;
}
@Override
public void processFragment(boolean endOfHeaders, ByteBuf fragment,
Http2FrameListener listener) throws Http2Exception {
headersBlockBuilder().addFragment(fragment, ctx.alloc(), endOfHeaders);
if (endOfHeaders) {
listener.onPushPromiseRead(ctx, pushPromiseStreamId, promisedStreamId,
headersBlockBuilder().headers(), padding);
}
}
};
// The rest of the payload, minus trailing padding, is the first header block fragment.
final ByteBuf fragment = payload.readSlice(lengthWithoutTrailingPadding(payload.readableBytes(), padding));
headersContinuation.processFragment(flags.endOfHeaders(), fragment, listener);
resetHeadersContinuationIfEnd(flags.endOfHeaders());
}
/**
 * Reports a PING frame: through {@code listener.onPingAckRead} when the ACK flag is set, otherwise
 * through {@code listener.onPingRead}.
 *
 * @param ctx      passed through to the listener
 * @param data     the payload of the frame, already read as a {@code long}
 * @param listener receives the ping or ping ACK
 * @throws Http2Exception if the listener fails
 */
private void readPingFrame(ChannelHandlerContext ctx, long data,
        Http2FrameListener listener) throws Http2Exception {
    if (!flags.ack()) {
        listener.onPingRead(ctx, data);
        return;
    }
    listener.onPingAckRead(ctx, data);
}
/**
 * Decodes a GOAWAY frame payload - last stream id, unsigned 32-bit error code, then debug data
 * filling the remainder - and reports it through {@code listener.onGoAwayRead}.
 *
 * @param ctx      passed through to the listener
 * @param payload  slice holding exactly this frame's payload
 * @param listener receives the decoded fields
 * @throws Http2Exception if the listener fails
 */
private static void readGoAwayFrame(ChannelHandlerContext ctx, ByteBuf payload,
        Http2FrameListener listener) throws Http2Exception {
    final int finalStreamId = readUnsignedInt(payload);
    final long code = payload.readUnsignedInt();
    // Everything after the two fixed fields is opaque debug data.
    final ByteBuf diagnostics = payload.readSlice(payload.readableBytes());
    listener.onGoAwayRead(ctx, finalStreamId, code, diagnostics);
}
/**
 * Decodes a WINDOW_UPDATE frame payload and reports it through
 * {@code listener.onWindowUpdateRead}.
 *
 * @param ctx      passed through to the listener
 * @param payload  slice holding exactly this frame's payload
 * @param listener receives the stream id and the window size increment
 * @throws Http2Exception a {@code PROTOCOL_ERROR} if the increment is zero - a connection error
 *                        when the frame targets stream 0, otherwise a stream error - or whatever
 *                        the listener throws
 */
private void readWindowUpdateFrame(ChannelHandlerContext ctx, ByteBuf payload,
        Http2FrameListener listener) throws Http2Exception {
    int windowSizeIncrement = readUnsignedInt(payload);
    if (windowSizeIncrement == 0) {
        // Stream id 0 addresses the connection-level flow-control window. There is no stream to
        // reset in that case, so a zero increment has to fail the whole connection instead of
        // being reported as a stream error carrying stream id 0.
        if (streamId == 0) {
            throw connectionError(PROTOCOL_ERROR,
                    "Received WINDOW_UPDATE with delta 0 for connection stream");
        }
        throw streamError(streamId, PROTOCOL_ERROR,
                "Received WINDOW_UPDATE with delta 0 for stream: %d", streamId);
    }
    listener.onWindowUpdateRead(ctx, streamId, windowSizeIncrement);
}
/**
 * Feeds the payload of a CONTINUATION frame to the pending header-block continuation and drops
 * that continuation once the END_HEADERS flag is seen.
 *
 * @param payload  slice holding exactly this frame's payload
 * @param listener forwarded to the continuation, which invokes it when the block is complete
 * @throws Http2Exception if accumulating or decoding the header block fails
 */
private void readContinuationFrame(ByteBuf payload, Http2FrameListener listener)
        throws Http2Exception {
    final ByteBuf block = payload.readSlice(payload.readableBytes());
    headersContinuation.processFragment(flags.endOfHeaders(), block, listener);
    if (flags.endOfHeaders()) {
        closeHeadersContinuation();
    }
}
/**
 * Reports a frame of unrecognised type, together with its raw payload, through
 * {@code listener.onUnknownFrame}.
 *
 * @param ctx      passed through to the listener
 * @param payload  slice holding exactly this frame's payload
 * @param listener receives the frame type, stream id, flags and payload
 * @throws Http2Exception if the listener fails
 */
private void readUnknownFrame(ChannelHandlerContext ctx, ByteBuf payload, Http2FrameListener listener)
        throws Http2Exception {
    final ByteBuf rawBody = payload.readSlice(payload.readableBytes());
    listener.onUnknownFrame(ctx, frameType, streamId, flags, rawBody);
}
/**
 * Reads the padding length field when the current frame's flags say one is present.
 *
 * @param payload slice positioned at the start of the frame payload
 * @return 0 when the frame carries no padding field; otherwise the field value plus one, so the
 *         result also counts the one-byte field itself
 */
private int readPadding(ByteBuf payload) {
    if (flags.paddingPresent()) {
        return payload.readUnsignedByte() + 1;
    }
    return 0;
}
/**
 * Ensures the announced padding fits inside the payload of the current frame.
 *
 * @param padding value returned by {@code readPadding}
 * @throws Http2Exception a connection error of type {@code PROTOCOL_ERROR} when removing the
 *                        trailing padding from {@code payloadLength} would leave a negative length
 */
private void verifyPadding(int padding) throws Http2Exception {
    final int remaining = lengthWithoutTrailingPadding(payloadLength, padding);
    if (remaining >= 0) {
        return;
    }
    throw connectionError(PROTOCOL_ERROR, "Frame payload too small for padding.");
}
/**
 * Computes how many of {@code readableBytes} remain once the trailing padding is excluded.
 *
 * @param readableBytes number of bytes under consideration
 * @param padding       0 for "no padding", otherwise the padding length field value plus one
 *                      (as produced by {@code readPadding})
 * @return {@code readableBytes} when {@code padding} is 0, else
 *         {@code readableBytes - (padding - 1)}; may be negative
 */
private static int lengthWithoutTrailingPadding(int readableBytes, int padding) {
    if (padding == 0) {
        return readableBytes;
    }
    // 'padding' counts the length field as well, which is not part of the trailing bytes.
    final int trailingBytes = padding - 1;
    return readableBytes - trailingBytes;
}
/**
 * Holds the state of a header block that may span several frames. Concrete instances decide
 * which listener callback to invoke once the block is complete.
 */
private abstract class HeadersContinuation {
    private final HeadersBlockBuilder blockBuilder = new HeadersBlockBuilder();

    /**
     * Returns the stream id that follow-up frames for this header block must carry.
     */
    abstract int getStreamId();

    /**
     * Consumes one fragment of the header block.
     *
     * @param endOfHeaders whether this fragment completes the header block
     * @param fragment     the bytes of the fragment
     * @param listener     the listener to notify when the block is complete
     */
    abstract void processFragment(boolean endOfHeaders, ByteBuf fragment,
            Http2FrameListener listener) throws Http2Exception;

    /**
     * Returns the builder that accumulates the fragments of this header block.
     */
    final HeadersBlockBuilder headersBlockBuilder() {
        return blockBuilder;
    }

    /**
     * Closes the underlying builder.
     */
    final void close() {
        blockBuilder.close();
    }
}
/**
 * Accumulates the fragments of one header block into a single buffer and decodes that buffer with
 * the enclosing reader's headers decoder.
 */
protected class HeadersBlockBuilder {
// Bytes accumulated so far; null before the first fragment and after close().
private ByteBuf headerBlock;
/**
 * Releases the accumulated bytes and reports that the header block grew past the limit
 * returned by {@code maxHeaderListSizeGoAway()}.
 */
private void headerSizeExceeded() throws Http2Exception {
close();
// NOTE(review): presumably always throws - the callers in addFragment keep going after this call
// returns, with headerBlock already released. Confirm against Http2CodecUtil.
headerListSizeExceeded(headersDecoder.configuration().maxHeaderListSizeGoAway());
}
/**
 * Appends one fragment to the header block.
 *
 * @param fragment     the bytes to append
 * @param alloc        allocator used when a new or larger buffer is needed
 * @param endOfHeaders whether this fragment completes the header block
 * @throws Http2Exception if the accumulated size would exceed {@code maxHeaderListSizeGoAway()}
 */
final void addFragment(ByteBuf fragment, ByteBufAllocator alloc, boolean endOfHeaders) throws Http2Exception {
if (headerBlock == null) {
// First fragment of the block.
if (fragment.readableBytes() > headersDecoder.configuration().maxHeaderListSizeGoAway()) {
headerSizeExceeded();
}
if (endOfHeaders) {
// Single-fragment block: use the fragment as-is instead of copying. It is retained because
// close() releases headerBlock.
headerBlock = fragment.retain();
} else {
// More fragments will follow, so copy into a buffer owned by this builder.
headerBlock = alloc.buffer(fragment.readableBytes());
headerBlock.writeBytes(fragment);
}
return;
}
// Size check written as a subtraction: limit - incoming < accumulated.
if (headersDecoder.configuration().maxHeaderListSizeGoAway() - fragment.readableBytes() <
headerBlock.readableBytes()) {
headerSizeExceeded();
}
if (headerBlock.isWritable(fragment.readableBytes())) {
// Enough room left in the current buffer.
headerBlock.writeBytes(fragment);
} else {
// Move everything into a new buffer sized for the old content plus the new fragment.
ByteBuf buf = alloc.buffer(headerBlock.readableBytes() + fragment.readableBytes());
buf.writeBytes(headerBlock);
buf.writeBytes(fragment);
headerBlock.release();
headerBlock = buf;
}
}
/**
 * Decodes the accumulated header block and then closes this builder, whether or not decoding
 * succeeded.
 *
 * @return the decoded headers
 * @throws Http2Exception if the headers decoder fails
 */
Http2Headers headers() throws Http2Exception {
try {
// Uses the enclosing reader's current streamId field.
return headersDecoder.decodeHeaders(streamId, headerBlock);
} finally {
close();
}
}
/**
 * Releases the accumulated buffer, if any, and clears the enclosing reader's pending
 * continuation.
 */
void close() {
if (headerBlock != null) {
headerBlock.release();
headerBlock = null;
}
headersContinuation = null;
}
}
/**
 * Rejects the current frame if a header block is still being assembled.
 *
 * @throws Http2Exception a connection error of type {@code PROTOCOL_ERROR} when a continuation is
 *                        pending
 */
private void verifyNotProcessingHeaders() throws Http2Exception {
    if (headersContinuation == null) {
        return;
    }
    throw connectionError(PROTOCOL_ERROR, "Received frame of type %s while processing headers on stream %d.",
            frameType, headersContinuation.getStreamId());
}
/**
 * Rejects a payload length that is larger than the current frame size limit.
 *
 * @param payloadLength the length to check
 * @throws Http2Exception a connection error of type {@code PROTOCOL_ERROR} when the length
 *                        exceeds {@code maxFrameSize}
 */
private void verifyPayloadLength(int payloadLength) throws Http2Exception {
    final boolean withinLimit = payloadLength <= maxFrameSize;
    if (withinLimit) {
        return;
    }
    throw connectionError(PROTOCOL_ERROR, "Total payload length %d exceeds max frame length.", payloadLength);
}
/**
 * Rejects the current frame if its stream id is 0.
 *
 * @throws Http2Exception a connection error of type {@code PROTOCOL_ERROR} when the stream id is 0
 */
private void verifyAssociatedWithAStream() throws Http2Exception {
    if (streamId != 0) {
        return;
    }
    throw connectionError(PROTOCOL_ERROR, "Frame of type %s must be associated with a stream.", frameType);
}
/**
 * Rejects a negative stream id.
 *
 * @param streamId     the id to check
 * @param argumentName label used in the error message
 * @throws Http2Exception a connection error of type {@code PROTOCOL_ERROR} when the id is negative
 */
private static void verifyStreamOrConnectionId(int streamId, String argumentName)
        throws Http2Exception {
    if (streamId >= 0) {
        return;
    }
    throw connectionError(PROTOCOL_ERROR, "%s must be >= 0", argumentName);
}
}