package io.undertow.conduits;
import java.io.IOException;
import java.io.UnsupportedEncodingException;
import java.nio.ByteBuffer;
import java.nio.channels.ClosedChannelException;
import java.nio.channels.FileChannel;
import java.nio.charset.StandardCharsets;
import java.util.concurrent.TimeUnit;
import java.util.function.Supplier;
import io.undertow.UndertowLogger;
import io.undertow.server.protocol.http.HttpAttachments;
import io.undertow.util.Attachable;
import io.undertow.util.AttachmentKey;
import io.undertow.util.HeaderMap;
import io.undertow.util.HeaderValues;
import io.undertow.util.Headers;
import io.undertow.util.ImmediatePooledByteBuffer;
import org.xnio.IoUtils;
import io.undertow.connector.ByteBufferPool;
import io.undertow.connector.PooledByteBuffer;
import org.xnio.channels.StreamSourceChannel;
import org.xnio.conduits.AbstractStreamSinkConduit;
import org.xnio.conduits.ConduitWritableByteChannel;
import org.xnio.conduits.Conduits;
import org.xnio.conduits.StreamSinkConduit;
import static org.xnio.Bits.allAreClear;
import static org.xnio.Bits.anyAreSet;
public class ChunkedStreamSinkConduit extends AbstractStreamSinkConduit<StreamSinkConduit> {
@Deprecated
public static final AttachmentKey<HeaderMap> TRAILERS = HttpAttachments.RESPONSE_TRAILERS;
private final HeaderMap ;
private final ConduitListener<? super ChunkedStreamSinkConduit> finishListener;
private final int config;
private final ByteBufferPool bufferPool;
private static final byte[] LAST_CHUNK = new byte[] {(byte) 48, (byte) 13, (byte) 10};
private static final byte[] CRLF = new byte[] {(byte) 13, (byte) 10};
private final Attachable attachable;
private int state;
private int chunkleft = 0;
private final ByteBuffer chunkingBuffer = ByteBuffer.allocate(12);
private final ByteBuffer chunkingSepBuffer;
private PooledByteBuffer lastChunkBuffer;
private static final int CONF_FLAG_CONFIGURABLE = 1 << 0;
private static final int CONF_FLAG_PASS_CLOSE = 1 << 1;
private static final int FLAG_WRITES_SHUTDOWN = 1;
private static final int FLAG_NEXT_SHUTDOWN = 1 << 2;
private static final int FLAG_WRITTEN_FIRST_CHUNK = 1 << 3;
private static final int FLAG_FIRST_DATA_WRITTEN = 1 << 4;
private static final int FLAG_FINISHED = 1 << 5;
public (final StreamSinkConduit next, final ByteBufferPool bufferPool, final boolean configurable, final boolean passClose, HeaderMap responseHeaders, final ConduitListener<? super ChunkedStreamSinkConduit> finishListener, final Attachable attachable) {
super(next);
this.bufferPool = bufferPool;
this.responseHeaders = responseHeaders;
this.finishListener = finishListener;
this.attachable = attachable;
config = (configurable ? CONF_FLAG_CONFIGURABLE : 0) | (passClose ? CONF_FLAG_PASS_CLOSE : 0);
chunkingSepBuffer = ByteBuffer.allocate(2);
chunkingSepBuffer.flip();
}
@Override
public int write(final ByteBuffer src) throws IOException {
return doWrite(src);
}
int doWrite(final ByteBuffer src) throws IOException {
if (anyAreSet(state, FLAG_WRITES_SHUTDOWN)) {
throw new ClosedChannelException();
}
if(src.remaining() == 0) {
return 0;
}
this.state |= FLAG_FIRST_DATA_WRITTEN;
int oldLimit = src.limit();
boolean dataRemaining = false;
if (chunkleft == 0 && !chunkingSepBuffer.hasRemaining()) {
chunkingBuffer.clear();
putIntAsHexString(chunkingBuffer, src.remaining());
chunkingBuffer.put(CRLF);
chunkingBuffer.flip();
chunkingSepBuffer.clear();
chunkingSepBuffer.put(CRLF);
chunkingSepBuffer.flip();
state |= FLAG_WRITTEN_FIRST_CHUNK;
chunkleft = src.remaining();
} else {
if (src.remaining() > chunkleft) {
dataRemaining = true;
src.limit(chunkleft + src.position());
}
}
try {
int chunkingSize = chunkingBuffer.remaining();
int chunkingSepSize = chunkingSepBuffer.remaining();
if (chunkingSize > 0 || chunkingSepSize > 0 || lastChunkBuffer != null) {
int originalRemaining = src.remaining();
long result;
if (lastChunkBuffer == null || dataRemaining) {
final ByteBuffer[] buf = new ByteBuffer[]{chunkingBuffer, src, chunkingSepBuffer};
result = next.write(buf, 0, buf.length);
} else {
final ByteBuffer[] buf = new ByteBuffer[]{chunkingBuffer, src, lastChunkBuffer.getBuffer()};
if (anyAreSet(state, CONF_FLAG_PASS_CLOSE)) {
result = next.writeFinal(buf, 0, buf.length);
} else {
result = next.write(buf, 0, buf.length);
}
if (!src.hasRemaining()) {
state |= FLAG_WRITES_SHUTDOWN;
}
if (!lastChunkBuffer.getBuffer().hasRemaining()) {
state |= FLAG_NEXT_SHUTDOWN;
lastChunkBuffer.close();
}
}
int srcWritten = originalRemaining - src.remaining();
chunkleft -= srcWritten;
if (result < chunkingSize) {
return 0;
} else {
return srcWritten;
}
} else {
int result = next.write(src);
chunkleft -= result;
return result;
}
} finally {
src.limit(oldLimit);
}
}
@Override
public void truncateWrites() throws IOException {
try {
if (lastChunkBuffer != null) {
lastChunkBuffer.close();
}
if (allAreClear(state, FLAG_FINISHED)) {
invokeFinishListener();
}
} finally {
super.truncateWrites();
}
}
@Override
public long write(final ByteBuffer[] srcs, final int offset, final int length) throws IOException {
for (int i = offset; i < length; ++i) {
if (srcs[i].hasRemaining()) {
return write(srcs[i]);
}
}
return 0;
}
@Override
public long writeFinal(ByteBuffer[] srcs, int offset, int length) throws IOException {
return Conduits.writeFinalBasic(this, srcs, offset, length);
}
@Override
public int writeFinal(ByteBuffer src) throws IOException {
if(!src.hasRemaining()) {
terminateWrites();
return 0;
}
if (lastChunkBuffer == null) {
createLastChunk(true);
}
return doWrite(src);
}
@Override
public long transferFrom(final FileChannel src, final long position, final long count) throws IOException {
if (anyAreSet(state, FLAG_WRITES_SHUTDOWN)) {
throw new ClosedChannelException();
}
return src.transferTo(position, count, new ConduitWritableByteChannel(this));
}
@Override
public long transferFrom(final StreamSourceChannel source, final long count, final ByteBuffer throughBuffer) throws IOException {
if (anyAreSet(state, FLAG_WRITES_SHUTDOWN)) {
throw new ClosedChannelException();
}
return IoUtils.transfer(source, count, throughBuffer, new ConduitWritableByteChannel(this));
}
@Override
public boolean flush() throws IOException {
this.state |= FLAG_FIRST_DATA_WRITTEN;
if (anyAreSet(state, FLAG_WRITES_SHUTDOWN)) {
if (anyAreSet(state, FLAG_NEXT_SHUTDOWN)) {
boolean val = next.flush();
if (val && allAreClear(state, FLAG_FINISHED)) {
invokeFinishListener();
}
return val;
} else {
next.write(lastChunkBuffer.getBuffer());
if (!lastChunkBuffer.getBuffer().hasRemaining()) {
lastChunkBuffer.close();
if (anyAreSet(config, CONF_FLAG_PASS_CLOSE)) {
next.terminateWrites();
}
state |= FLAG_NEXT_SHUTDOWN;
boolean val = next.flush();
if (val && allAreClear(state, FLAG_FINISHED)) {
invokeFinishListener();
}
return val;
} else {
return false;
}
}
} else {
return next.flush();
}
}
private void invokeFinishListener() {
state |= FLAG_FINISHED;
if (finishListener != null) {
finishListener.handleEvent(this);
}
}
@Override
public void terminateWrites() throws IOException {
if(anyAreSet(state, FLAG_WRITES_SHUTDOWN)) {
return;
}
if (this.chunkleft != 0) {
UndertowLogger.REQUEST_IO_LOGGER.debugf("Channel closed mid-chunk");
next.truncateWrites();
}
if (!anyAreSet(state, FLAG_FIRST_DATA_WRITTEN)) {
responseHeaders.put(Headers.CONTENT_LENGTH, "0");
responseHeaders.remove(Headers.TRANSFER_ENCODING);
state |= FLAG_NEXT_SHUTDOWN | FLAG_WRITES_SHUTDOWN;
if(anyAreSet(state, CONF_FLAG_PASS_CLOSE)) {
next.terminateWrites();
}
} else {
createLastChunk(false);
state |= FLAG_WRITES_SHUTDOWN;
}
}
private void createLastChunk(final boolean writeFinal) throws UnsupportedEncodingException {
PooledByteBuffer lastChunkBufferPooled = bufferPool.allocate();
ByteBuffer lastChunkBuffer = lastChunkBufferPooled.getBuffer();
if (writeFinal) {
lastChunkBuffer.put(CRLF);
} else if(chunkingSepBuffer.hasRemaining()) {
lastChunkBuffer.put(chunkingSepBuffer);
}
lastChunkBuffer.put(LAST_CHUNK);
HeaderMap attachment = attachable.getAttachment(HttpAttachments.RESPONSE_TRAILERS);
final HeaderMap trailers;
Supplier<HeaderMap> supplier = attachable.getAttachment(HttpAttachments.RESPONSE_TRAILER_SUPPLIER);
if(attachment != null && supplier == null) {
trailers = attachment;
} else if(attachment == null && supplier != null) {
trailers = supplier.get();
} else if(attachment != null) {
HeaderMap supplied = supplier.get();
for(HeaderValues k : supplied) {
attachment.putAll(k.getHeaderName(), k);
}
trailers = attachment;
} else {
trailers = null;
}
if (trailers != null && trailers.size() != 0) {
for (HeaderValues trailer : trailers) {
for (String val : trailer) {
trailer.getHeaderName().appendTo(lastChunkBuffer);
lastChunkBuffer.put((byte) ':');
lastChunkBuffer.put((byte) ' ');
lastChunkBuffer.put(val.getBytes(StandardCharsets.US_ASCII));
lastChunkBuffer.put(CRLF);
}
}
lastChunkBuffer.put(CRLF);
} else {
lastChunkBuffer.put(CRLF);
}
lastChunkBuffer.flip();
ByteBuffer data = ByteBuffer.allocate(lastChunkBuffer.remaining());
data.put(lastChunkBuffer);
data.flip();
this.lastChunkBuffer = new ImmediatePooledByteBuffer(data);
lastChunkBufferPooled.close();
}
@Override
public void awaitWritable() throws IOException {
next.awaitWritable();
}
@Override
public void awaitWritable(final long time, final TimeUnit timeUnit) throws IOException {
next.awaitWritable(time, timeUnit);
}
private static void putIntAsHexString(final ByteBuffer buf, final int v) {
byte int3 = (byte) (v >> 24);
byte int2 = (byte) (v >> 16);
byte int1 = (byte) (v >> 8);
byte int0 = (byte) (v );
boolean nonZeroFound = false;
if (int3 != 0) {
buf.put(DIGITS[(0xF0 & int3) >>> 4])
.put(DIGITS[0x0F & int3]);
nonZeroFound = true;
}
if (nonZeroFound || int2 != 0) {
buf.put(DIGITS[(0xF0 & int2) >>> 4])
.put(DIGITS[0x0F & int2]);
nonZeroFound = true;
}
if (nonZeroFound || int1 != 0) {
buf.put(DIGITS[(0xF0 & int1) >>> 4])
.put(DIGITS[0x0F & int1]);
}
buf.put(DIGITS[(0xF0 & int0) >>> 4])
.put(DIGITS[0x0F & int0]);
}
private static final byte[] DIGITS = new byte[] {
(byte) 48, (byte) 49, (byte) 50, (byte) 51, (byte) 52, (byte) 53,
(byte) 54, (byte) 55, (byte) 56, (byte) 57, (byte) 97, (byte) 98,
(byte) 99, (byte) 100, (byte) 101, (byte) 102};
}