package org.apache.coyote.http2;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.channels.CompletionHandler;
import java.util.concurrent.TimeUnit;
import javax.servlet.http.WebConnection;
import org.apache.coyote.ProtocolException;
import org.apache.tomcat.util.net.SocketEvent;
import org.apache.tomcat.util.net.SocketWrapperBase;
import org.apache.tomcat.util.net.SocketWrapperBase.BlockingMode;
import org.apache.tomcat.util.net.SocketWrapperBase.CompletionCheck;
import org.apache.tomcat.util.net.SocketWrapperBase.CompletionHandlerCall;
import org.apache.tomcat.util.net.SocketWrapperBase.CompletionState;
/**
 * HTTP/2 frame parser that reads frames through the completion-handler based
 * {@code SocketWrapperBase.read(...)} API.
 * <p>
 * A read may finish after {@link #readFrame(boolean, FrameType)} has returned, so errors detected
 * by the completion handlers are stored in {@link #error}. They are rethrown either immediately
 * (when the read completes inline or with an error) or at the start of the next
 * {@link #readFrame(boolean, FrameType)} call.
 */
class Http2AsyncParser extends Http2Parser {

    private final SocketWrapperBase<?> socketWrapper;
    private final Http2AsyncUpgradeHandler upgradeHandler;
    /** Error recorded by a completion handler; consumed and rethrown by {@link #handleAsyncException()}. */
    private Throwable error = null;

    Http2AsyncParser(String connectionId, Input input, Output output, SocketWrapperBase<?> socketWrapper,
            Http2AsyncUpgradeHandler upgradeHandler) {
        super(connectionId, input, output);
        this.socketWrapper = socketWrapper;
        // Expand the socket buffer to the maximum frame size
        // (presumably so that a complete frame payload can be buffered).
        socketWrapper.getSocketBufferHandler().expand(input.getMaxFrameSize());
        this.upgradeHandler = upgradeHandler;
    }

    /**
     * Starts a non-blocking read of the client connection preface followed by the frame header
     * and payload of the first frame, which must be a SETTINGS frame. The result is processed by a
     * {@link PrefaceCompletionHandler}; this method does not wait for the read to finish.
     *
     * @param webConnection the upgraded connection, passed on to the upgrade handler
     * @param stream        the stream passed on to the upgrade handler
     */
    @Override
    void readConnectionPreface(WebConnection webConnection, Stream stream) throws Http2Exception {
        byte[] prefaceData = new byte[CLIENT_PREFACE_START.length];
        ByteBuffer preface = ByteBuffer.wrap(prefaceData);
        ByteBuffer header = ByteBuffer.allocate(9);
        ByteBuffer framePayload = ByteBuffer.allocate(input.getMaxFrameSize());
        PrefaceCompletionHandler handler =
                new PrefaceCompletionHandler(webConnection, stream, prefaceData, preface, header, framePayload);
        // The handler acts both as the completion check and the completion handler
        socketWrapper.read(BlockingMode.NON_BLOCK, socketWrapper.getReadTimeout(), TimeUnit.MILLISECONDS, null,
                handler, handler, preface, header, framePayload);
    }

    /**
     * Completion handler for the connection preface read. It expects three buffers:
     * [0] the preface bytes, [1] the 9-byte frame header, [2] the SETTINGS frame payload.
     */
    private class PrefaceCompletionHandler extends FrameCompletionHandler {

        private boolean prefaceValidated = false;

        private final WebConnection webConnection;
        private final Stream stream;
        /** Backing array of buffer [0]; compared against {@code CLIENT_PREFACE_START}. */
        private final byte[] prefaceData;

        private PrefaceCompletionHandler(WebConnection webConnection, Stream stream, byte[] prefaceData,
                ByteBuffer... buffers) {
            super(FrameType.SETTINGS, buffers);
            this.webConnection = webConnection;
            this.stream = stream;
            this.prefaceData = prefaceData;
        }

        /**
         * Checks whether enough data has been read. First waits for the full preface and
         * verifies it, then delegates header/payload checks to {@link #validate}.
         */
        @Override
        public CompletionHandlerCall callHandler(CompletionState state, ByteBuffer[] buffers, int offset,
                int length) {
            if (offset != 0 || length != 3) {
                // Record the error directly; it is reported later via handleAsyncException()
                error = new IllegalArgumentException(sm.getString("http2Parser.invalidBuffers"));
                return CompletionHandlerCall.DONE;
            }
            if (!prefaceValidated) {
                if (buffers[0].hasRemaining()) {
                    // Preface not fully read yet
                    return CompletionHandlerCall.CONTINUE;
                }
                for (int i = 0; i < CLIENT_PREFACE_START.length; i++) {
                    if (CLIENT_PREFACE_START[i] != prefaceData[i]) {
                        error = new ProtocolException(sm.getString("http2Parser.preface.invalid"));
                        return CompletionHandlerCall.DONE;
                    }
                }
                prefaceValidated = true;
            }
            return validate(state, buffers[1], buffers[2]);
        }

        /**
         * Processes the SETTINGS frame (or swallows its payload after a stream-level error),
         * returns any extra bytes to the socket, then completes the connection set-up and
         * dispatches a read event to the upgrade handler. Errors are stored in {@link #error}.
         */
        @Override
        public void completed(Long result, Void attachment) {
            if (streamException || error == null) {
                ByteBuffer payload = buffers[2];
                payload.flip();
                try {
                    if (streamException) {
                        swallow(streamId, payloadSize, false, payload);
                    } else {
                        readSettingsFrame(flags, payloadSize, payload);
                    }
                } catch (RuntimeException | IOException | Http2Exception e) {
                    error = e;
                }
                // Bytes beyond this frame belong to later frames; give them back to the socket
                if (payload.hasRemaining()) {
                    socketWrapper.unRead(payload);
                }
            }
            upgradeHandler.processConnectionCallback(webConnection, stream);
            upgradeHandler.upgradeDispatch(SocketEvent.OPEN_READ);
        }
    }

    /**
     * Reads one frame (plus any further complete frames already present in the read buffer).
     *
     * @param block    whether to use a blocking read
     * @param expected the frame type that must be read, or {@code null} for any type
     *                 (passed to {@code validateFrame})
     * @return {@code true} if the read completed inline or with an error (any stored error has
     *         then been rethrown), {@code false} if it will complete asynchronously
     * @throws IOException      rethrown from an earlier or inline asynchronous read
     * @throws Http2Exception   rethrown from an earlier or inline asynchronous read
     */
    @Override
    protected boolean readFrame(boolean block, FrameType expected)
            throws IOException, Http2Exception {
        // Surface any error left behind by a previous asynchronous read first
        handleAsyncException();
        ByteBuffer header = ByteBuffer.allocate(9);
        ByteBuffer framePayload = ByteBuffer.allocate(input.getMaxFrameSize());
        FrameCompletionHandler handler = new FrameCompletionHandler(expected, header, framePayload);
        CompletionState state = socketWrapper.read(block ? BlockingMode.BLOCK : BlockingMode.NON_BLOCK,
                socketWrapper.getReadTimeout(), TimeUnit.MILLISECONDS, null, handler, handler, header, framePayload);
        if (state == CompletionState.ERROR || state == CompletionState.INLINE) {
            handleAsyncException();
            return true;
        }
        return false;
    }

    /**
     * Clears the stored asynchronous error (if any) and rethrows it. {@link Http2Exception},
     * {@link IOException} and {@link RuntimeException} are rethrown as-is; any other
     * {@link Throwable} is wrapped in a {@link RuntimeException}.
     */
    private void handleAsyncException()
            throws IOException, Http2Exception {
        if (error != null) {
            Throwable error = this.error;
            this.error = null;
            if (error instanceof Http2Exception) {
                throw (Http2Exception) error;
            } else if (error instanceof IOException) {
                throw (IOException) error;
            } else if (error instanceof RuntimeException) {
                throw (RuntimeException) error;
            } else {
                throw new RuntimeException(error);
            }
        }
    }

    /**
     * Completion check and completion handler for a frame read into two buffers:
     * [0] the 9-byte frame header, [1] the frame payload.
     */
    private class FrameCompletionHandler implements CompletionCheck, CompletionHandler<Long, Void> {

        private final FrameType expected;
        protected final ByteBuffer[] buffers;

        private boolean parsedFrameHeader = false;
        private boolean validated = false;
        /** Last state passed to {@link #validate}; {@code null} until it has been called. */
        private CompletionState state = null;
        protected int payloadSize;
        protected FrameType frameType;
        protected int flags;
        protected int streamId;
        /** {@code true} when {@link #error} holds a StreamException whose payload must be swallowed. */
        protected boolean streamException = false;

        private FrameCompletionHandler(FrameType expected, ByteBuffer... buffers) {
            this.expected = expected;
            this.buffers = buffers;
        }

        @Override
        public CompletionHandlerCall callHandler(CompletionState state,
                ByteBuffer[] buffers, int offset, int length) {
            if (offset != 0 || length != 2) {
                // Record the error directly; it is reported later via handleAsyncException()
                error = new IllegalArgumentException(sm.getString("http2Parser.invalidBuffers"));
                return CompletionHandlerCall.DONE;
            }
            return validate(state, buffers[0], buffers[1]);
        }

        /**
         * Parses the frame header once all 9 bytes are available, validates the frame once,
         * and reports {@code DONE} when the payload buffer holds at least {@code payloadSize}
         * bytes or a non-stream {@link Http2Exception} was raised by validation.
         */
        protected CompletionHandlerCall validate(CompletionState state, ByteBuffer frameHeaderBuffer,
                ByteBuffer payload) {
            if (!parsedFrameHeader) {
                if (frameHeaderBuffer.position() < 9) {
                    return CompletionHandlerCall.CONTINUE;
                }
                parsedFrameHeader = true;
                // Header layout: 24-bit length, 8-bit type, 8-bit flags, 31-bit stream id
                payloadSize = ByteUtil.getThreeBytes(frameHeaderBuffer, 0);
                frameType = FrameType.valueOf(ByteUtil.getOneByte(frameHeaderBuffer, 3));
                flags = ByteUtil.getOneByte(frameHeaderBuffer, 4);
                streamId = ByteUtil.get31Bits(frameHeaderBuffer, 5);
            }
            this.state = state;

            if (!validated) {
                validated = true;
                try {
                    validateFrame(expected, frameType, streamId, flags, payloadSize);
                } catch (StreamException e) {
                    // Stream-level error: keep reading so the payload can be swallowed
                    error = e;
                    streamException = true;
                } catch (Http2Exception e) {
                    error = e;
                    return CompletionHandlerCall.DONE;
                }
            }

            if (payload.position() < payloadSize) {
                return CompletionHandlerCall.CONTINUE;
            }
            return CompletionHandlerCall.DONE;
        }

        /**
         * Processes the frame that was read. If the buffer also contains further complete frames,
         * they are processed too, but only while no error is pending. Leftover bytes are returned
         * to the socket. Errors are stored in {@link #error}.
         */
        @Override
        public void completed(Long result, Void attachment) {
            // A pending stream-level error still needs its payload consumed; any other
            // recorded error means the frame is not processed at all.
            if (streamException || error == null) {
                ByteBuffer payload = buffers[1];
                payload.flip();
                try {
                    boolean continueParsing;
                    do {
                        continueParsing = false;
                        if (streamException) {
                            swallow(streamId, payloadSize, false, payload);
                        } else {
                            switch (frameType) {
                                case DATA:
                                    readDataFrame(streamId, flags, payloadSize, payload);
                                    break;
                                case HEADERS:
                                    readHeadersFrame(streamId, flags, payloadSize, payload);
                                    break;
                                case PRIORITY:
                                    readPriorityFrame(streamId, payload);
                                    break;
                                case RST:
                                    readRstFrame(streamId, payload);
                                    break;
                                case SETTINGS:
                                    readSettingsFrame(flags, payloadSize, payload);
                                    break;
                                case PUSH_PROMISE:
                                    readPushPromiseFrame(streamId, payload);
                                    break;
                                case PING:
                                    readPingFrame(flags, payload);
                                    break;
                                case GOAWAY:
                                    readGoawayFrame(payloadSize, payload);
                                    break;
                                case WINDOW_UPDATE:
                                    readWindowUpdateFrame(streamId, payload);
                                    break;
                                case CONTINUATION:
                                    readContinuationFrame(streamId, flags, payloadSize, payload);
                                    break;
                                case UNKNOWN:
                                    readUnknownFrame(streamId, frameType, flags, payloadSize, payload);
                                    break;
                            }
                        }
                        // The read may have pulled in further frames. Only carry on with them while
                        // no error is pending: a recorded StreamException must be rethrown before
                        // another frame is processed, otherwise a later StreamException would
                        // overwrite it and the first one would never be reported. Unprocessed
                        // bytes are returned to the socket below and read again later.
                        if (error == null && payload.remaining() >= 9) {
                            int position = payload.position();
                            payloadSize = ByteUtil.getThreeBytes(payload, position);
                            frameType = FrameType.valueOf(ByteUtil.getOneByte(payload, position + 3));
                            flags = ByteUtil.getOneByte(payload, position + 4);
                            streamId = ByteUtil.get31Bits(payload, position + 5);
                            streamException = false;
                            if (payload.remaining() - 9 >= payloadSize) {
                                continueParsing = true;
                                // Skip over the frame header to the payload
                                payload.position(payload.position() + 9);
                                try {
                                    validateFrame(null, frameType, streamId, flags, payloadSize);
                                } catch (StreamException e) {
                                    error = e;
                                    streamException = true;
                                } catch (Http2Exception e) {
                                    error = e;
                                    continueParsing = false;
                                }
                            }
                        }
                    } while (continueParsing);
                } catch (RuntimeException | IOException | Http2Exception e) {
                    error = e;
                }
                // Return any unprocessed bytes (partial or not yet processed frames) to the socket
                if (payload.hasRemaining()) {
                    socketWrapper.unRead(payload);
                }
            }
            // For inline completions readFrame() handles the result itself; otherwise
            // trigger a new read cycle.
            if (state == CompletionState.DONE) {
                upgradeHandler.upgradeDispatch(SocketEvent.OPEN_READ);
            }
        }

        /**
         * Records the I/O failure and, unless the read completed inline (in which case
         * readFrame() rethrows the error itself), dispatches an ERROR event.
         */
        @Override
        public void failed(Throwable e, Void attachment) {
            error = e;
            if (log.isDebugEnabled()) {
                log.debug(sm.getString("http2Parser.error", connectionId, Integer.valueOf(streamId), frameType), e);
            }
            // NOTE(review): state == null means validate() never ran for this read
            if (state == null || state == CompletionState.DONE) {
                upgradeHandler.upgradeDispatch(SocketEvent.ERROR);
            }
        }
    }
}