package io.undertow.protocols.http2;
import io.undertow.UndertowMessages;
import io.undertow.connector.PooledByteBuffer;
import io.undertow.server.protocol.framed.SendFrameHeader;
import io.undertow.util.HeaderMap;
import io.undertow.util.ImmediatePooledByteBuffer;
import org.xnio.ChannelListener;
import org.xnio.ChannelListeners;
import java.io.IOException;
import java.nio.ByteBuffer;
public class Http2DataStreamSinkChannel extends Http2StreamSinkChannel implements Http2Stream {
private final HeaderMap ;
private boolean first = true;
private final HpackEncoder encoder;
private ChannelListener<Http2DataStreamSinkChannel> completionListener;
private final int frameType;
private boolean completionListenerReady;
private TrailersProducer trailersProducer;
Http2DataStreamSinkChannel(Http2Channel channel, int streamId, int frameType) {
this(channel, streamId, new HeaderMap(), frameType);
}
(Http2Channel channel, int streamId, HeaderMap headers, int frameType) {
super(channel, streamId);
this.encoder = channel.getEncoder();
this.headers = headers;
this.frameType = frameType;
}
public TrailersProducer getTrailersProducer() {
return trailersProducer;
}
public void setTrailersProducer(TrailersProducer trailersProducer) {
this.trailersProducer = trailersProducer;
}
@Override
protected SendFrameHeader () {
int dataPaddingBytes = getChannel().getPaddingBytes();
int attempted = getBuffer().remaining() + dataPaddingBytes + (dataPaddingBytes > 0 ? 1 : 0);
final int fcWindow = grabFlowControlBytes(attempted);
if (fcWindow == 0 && getBuffer().hasRemaining()) {
return new SendFrameHeader(getBuffer().remaining(), null);
}
if(fcWindow <= dataPaddingBytes + 1) {
if(getBuffer().remaining() >= fcWindow) {
dataPaddingBytes = 0;
} else if (getBuffer().remaining() == dataPaddingBytes ){
dataPaddingBytes = 1;
} else {
dataPaddingBytes = fcWindow - getBuffer().remaining() - 1;
}
}
final boolean finalFrame = isFinalFrameQueued() && fcWindow >= (getBuffer().remaining() + (dataPaddingBytes > 0 ? dataPaddingBytes + 1 : 0));
PooledByteBuffer firstHeaderBuffer = getChannel().getBufferPool().allocate();
PooledByteBuffer[] allHeaderBuffers = null;
ByteBuffer firstBuffer = firstHeaderBuffer.getBuffer();
boolean firstFrame = false;
HeaderMap trailers = null;
if(finalFrame && this.trailersProducer != null) {
trailers = this.trailersProducer.getTrailers();
if(trailers != null && trailers.size() == 0) {
trailers = null;
}
}
if (first) {
firstFrame = true;
first = false;
firstBuffer.put((byte) 0);
firstBuffer.put((byte) 0);
firstBuffer.put((byte) 0);
firstBuffer.put((byte) frameType);
firstBuffer.put((byte) 0);
Http2ProtocolUtils.putInt(firstBuffer, getStreamId());
int paddingBytes = getChannel().getPaddingBytes();
if(paddingBytes > 0) {
firstBuffer.put((byte) (paddingBytes & 0xFF));
}
writeBeforeHeaderBlock(firstBuffer);
HeaderMap headers = this.headers;
HpackEncoder.State result = encoder.encode(headers, firstBuffer);
PooledByteBuffer current = firstHeaderBuffer;
int headerFrameLength = firstBuffer.position() - 9 + paddingBytes;
firstBuffer.put(0, (byte) ((headerFrameLength >> 16) & 0xFF));
firstBuffer.put(1, (byte) ((headerFrameLength >> 8) & 0xFF));
firstBuffer.put(2, (byte) (headerFrameLength & 0xFF));
firstBuffer.put(4, (byte) ((isFinalFrameQueued() && !getBuffer().hasRemaining() && frameType == Http2Channel.FRAME_TYPE_HEADERS && trailers == null ? Http2Channel.HEADERS_FLAG_END_STREAM : 0) | (result == HpackEncoder.State.COMPLETE ? Http2Channel.HEADERS_FLAG_END_HEADERS : 0 ) | (paddingBytes > 0 ? Http2Channel.HEADERS_FLAG_PADDED : 0)));
ByteBuffer currentBuffer = firstBuffer;
if(currentBuffer.remaining() < paddingBytes) {
allHeaderBuffers = allocateAll(allHeaderBuffers, current);
current = allHeaderBuffers[allHeaderBuffers.length - 1];
currentBuffer = current.getBuffer();
}
for(int i = 0; i < paddingBytes; ++ i) {
currentBuffer.put((byte) 0);
}
while (result != HpackEncoder.State.COMPLETE) {
allHeaderBuffers = allocateAll(allHeaderBuffers, current);
current = allHeaderBuffers[allHeaderBuffers.length - 1];
result = encodeContinuationFrame(headers, current);
}
}
PooledByteBuffer currentPooled = allHeaderBuffers == null ? firstHeaderBuffer : allHeaderBuffers[allHeaderBuffers.length - 1];
ByteBuffer currentBuffer = currentPooled.getBuffer();
ByteBuffer trailer = null;
int remainingInBuffer = 0;
boolean requiresTrailers = false;
if (getBuffer().remaining() > 0) {
if (fcWindow > 0) {
if (currentBuffer.remaining() < 10) {
allHeaderBuffers = allocateAll(allHeaderBuffers, currentPooled);
currentPooled = allHeaderBuffers == null ? firstHeaderBuffer : allHeaderBuffers[allHeaderBuffers.length - 1];
currentBuffer = currentPooled.getBuffer();
}
int toSend = fcWindow - dataPaddingBytes - (dataPaddingBytes > 0 ? 1 :0);
remainingInBuffer = getBuffer().remaining() - toSend;
getBuffer().limit(getBuffer().position() + toSend);
currentBuffer.put((byte) ((fcWindow >> 16) & 0xFF));
currentBuffer.put((byte) ((fcWindow >> 8) & 0xFF));
currentBuffer.put((byte) (fcWindow & 0xFF));
currentBuffer.put((byte) Http2Channel.FRAME_TYPE_DATA);
if(trailers == null) {
currentBuffer.put((byte) ((finalFrame ? Http2Channel.DATA_FLAG_END_STREAM : 0) | (dataPaddingBytes > 0 ? Http2Channel.DATA_FLAG_PADDED : 0)));
} else {
if(finalFrame) {
requiresTrailers = true;
}
currentBuffer.put((byte) (dataPaddingBytes > 0 ? Http2Channel.DATA_FLAG_PADDED : 0));
}
Http2ProtocolUtils.putInt(currentBuffer, getStreamId());
if(dataPaddingBytes > 0) {
currentBuffer.put((byte) (dataPaddingBytes & 0xFF));
trailer = ByteBuffer.allocate(dataPaddingBytes);
}
} else {
remainingInBuffer = getBuffer().remaining();
}
} else if (finalFrame && !firstFrame) {
currentBuffer.put((byte) ((fcWindow >> 16) & 0xFF));
currentBuffer.put((byte) ((fcWindow >> 8) & 0xFF));
currentBuffer.put((byte) (fcWindow & 0xFF));
currentBuffer.put((byte) Http2Channel.FRAME_TYPE_DATA);
if (trailers == null) {
currentBuffer.put((byte) ((Http2Channel.HEADERS_FLAG_END_STREAM & 0xFF) | (dataPaddingBytes > 0 ? Http2Channel.DATA_FLAG_PADDED : 0)));
} else {
requiresTrailers = true;
currentBuffer.put((byte) ((dataPaddingBytes > 0 ? Http2Channel.DATA_FLAG_PADDED : 0)));
}
Http2ProtocolUtils.putInt(currentBuffer, getStreamId());
if (dataPaddingBytes > 0) {
currentBuffer.put((byte) (dataPaddingBytes & 0xFF));
trailer = ByteBuffer.allocate(dataPaddingBytes);
}
} else if(finalFrame && trailers != null) {
requiresTrailers = true;
}
if (requiresTrailers) {
PooledByteBuffer firstTrailerBuffer = getChannel().getBufferPool().allocate();
if (trailer != null) {
firstTrailerBuffer.getBuffer().put(trailer);
}
firstTrailerBuffer.getBuffer().put((byte) 0);
firstTrailerBuffer.getBuffer().put((byte) 0);
firstTrailerBuffer.getBuffer().put((byte) 0);
firstTrailerBuffer.getBuffer().put((byte) Http2Channel.FRAME_TYPE_HEADERS);
firstTrailerBuffer.getBuffer().put((byte) (Http2Channel.HEADERS_FLAG_END_STREAM | Http2Channel.HEADERS_FLAG_END_HEADERS));
Http2ProtocolUtils.putInt(firstTrailerBuffer.getBuffer(), getStreamId());
HpackEncoder.State result = encoder.encode(trailers, firstTrailerBuffer.getBuffer());
if (result != HpackEncoder.State.COMPLETE) {
throw UndertowMessages.MESSAGES.http2TrailerToLargeForSingleBuffer();
}
int headerFrameLength = firstTrailerBuffer.getBuffer().position() - 9;
firstTrailerBuffer.getBuffer().put(0, (byte) ((headerFrameLength >> 16) & 0xFF));
firstTrailerBuffer.getBuffer().put(1, (byte) ((headerFrameLength >> 8) & 0xFF));
firstTrailerBuffer.getBuffer().put(2, (byte) (headerFrameLength & 0xFF));
firstTrailerBuffer.getBuffer().flip();
int size = firstTrailerBuffer.getBuffer().remaining();
trailer = ByteBuffer.allocate(size);
trailer.put(firstTrailerBuffer.getBuffer());
trailer.flip();
firstTrailerBuffer.close();
}
if (allHeaderBuffers == null) {
currentBuffer.flip();
return new SendFrameHeader(remainingInBuffer, currentPooled, false, trailer);
} else {
int length = 0;
for (int i = 0; i < allHeaderBuffers.length; ++i) {
length += allHeaderBuffers[i].getBuffer().position();
allHeaderBuffers[i].getBuffer().flip();
}
try {
ByteBuffer newBuf = ByteBuffer.allocate(length);
for (int i = 0; i < allHeaderBuffers.length; ++i) {
newBuf.put(allHeaderBuffers[i].getBuffer());
}
newBuf.flip();
return new SendFrameHeader(remainingInBuffer, new ImmediatePooledByteBuffer(newBuf), false, trailer);
} finally {
for (int i = 0; i < allHeaderBuffers.length; ++i) {
allHeaderBuffers[i].close();
}
}
}
}
private HpackEncoder.State (HeaderMap headers, PooledByteBuffer current) {
ByteBuffer currentBuffer;
HpackEncoder.State result;
currentBuffer = current.getBuffer();
currentBuffer.put((byte) 0);
currentBuffer.put((byte) 0);
currentBuffer.put((byte) 0);
currentBuffer.put((byte) Http2Channel.FRAME_TYPE_CONTINUATION);
currentBuffer.put((byte) 0);
Http2ProtocolUtils.putInt(currentBuffer, getStreamId());
result = encoder.encode(headers, currentBuffer);
int contFrameLength = currentBuffer.position() - 9;
currentBuffer.put(0, (byte) ((contFrameLength >> 16) & 0xFF));
currentBuffer.put(1, (byte) ((contFrameLength >> 8) & 0xFF));
currentBuffer.put(2, (byte) (contFrameLength & 0xFF));
currentBuffer.put(4, (byte) (result == HpackEncoder.State.COMPLETE ? Http2Channel.HEADERS_FLAG_END_HEADERS : 0 ));
return result;
}
@Override
public boolean flush() throws IOException {
if(completionListenerReady && completionListener != null) {
ChannelListeners.invokeChannelListener(this, completionListener);
completionListener = null;
}
return super.flush();
}
protected void (ByteBuffer buffer) {
}
protected boolean isFlushRequiredOnEmptyBuffer() {
return first;
}
public HeaderMap () {
return headers;
}
@Override
protected void handleFlushComplete(boolean finalFrame) {
super.handleFlushComplete(finalFrame);
if (finalFrame) {
if (completionListener != null) {
completionListenerReady = true;
}
}
}
@Override
protected void channelForciblyClosed() throws IOException {
super.channelForciblyClosed();
if (completionListener != null) {
ChannelListeners.invokeChannelListener(this, completionListener);
completionListener = null;
}
}
public ChannelListener<Http2DataStreamSinkChannel> getCompletionListener() {
return completionListener;
}
public void setCompletionListener(ChannelListener<Http2DataStreamSinkChannel> completionListener) {
this.completionListener = completionListener;
}
public interface TrailersProducer {
HeaderMap getTrailers();
}
}