package io.undertow.conduits;
import static org.xnio.Bits.allAreClear;
import static org.xnio.Bits.anyAreSet;
import static org.xnio.Bits.longBitMask;
import java.io.IOException;
import java.nio.ByteBuffer;
import org.xnio.conduits.Conduit;
import io.undertow.UndertowMessages;
import io.undertow.util.Attachable;
import io.undertow.util.AttachmentKey;
import io.undertow.util.HeaderMap;
import io.undertow.util.HttpString;
class ChunkReader<T extends Conduit> {
private static final long FLAG_FINISHED = 1L << 62L;
private static final long FLAG_READING_LENGTH = 1L << 61L;
private static final long FLAG_READING_TILL_END_OF_LINE = 1L << 60L;
private static final long FLAG_READING_NEWLINE = 1L << 59L;
private static final long FLAG_READING_AFTER_LAST = 1L << 58L;
private static final long MASK_COUNT = longBitMask(0, 56);
private long state;
private final Attachable attachable;
private final AttachmentKey<HeaderMap> trailerAttachmentKey;
private TrailerParser trailerParser;
private final T conduit;
ChunkReader(final Attachable attachable, final AttachmentKey<HeaderMap> trailerAttachmentKey, T conduit) {
this.attachable = attachable;
this.trailerAttachmentKey = trailerAttachmentKey;
this.conduit = conduit;
this.state = FLAG_READING_LENGTH;
}
public long readChunk(final ByteBuffer buf) throws IOException {
long oldVal = state;
long chunkRemaining = state & MASK_COUNT;
if (chunkRemaining > 0 && !anyAreSet(state, FLAG_READING_AFTER_LAST | FLAG_READING_LENGTH | FLAG_READING_NEWLINE | FLAG_READING_TILL_END_OF_LINE)) {
return chunkRemaining;
}
long newVal = oldVal & ~MASK_COUNT;
try {
if (anyAreSet(oldVal, FLAG_READING_AFTER_LAST)) {
int ret = handleChunkedRequestEnd(buf);
if (ret == -1) {
newVal |= FLAG_FINISHED & ~FLAG_READING_AFTER_LAST;
return -1;
}
return 0;
}
while (anyAreSet(newVal, FLAG_READING_NEWLINE)) {
while (buf.hasRemaining()) {
byte b = buf.get();
if (b == '\n') {
newVal = newVal & ~FLAG_READING_NEWLINE | FLAG_READING_LENGTH;
break;
}
}
if (anyAreSet(newVal, FLAG_READING_NEWLINE)) {
return 0;
}
}
while (anyAreSet(newVal, FLAG_READING_LENGTH)) {
while (buf.hasRemaining()) {
byte b = buf.get();
if ((b >= '0' && b <= '9') || (b >= 'a' && b <= 'f') || (b >= 'A' && b <= 'F')) {
chunkRemaining <<= 4;
chunkRemaining += Character.digit((char) b, 16);
} else {
if (b == '\n') {
newVal = newVal & ~FLAG_READING_LENGTH;
} else {
newVal = newVal & ~FLAG_READING_LENGTH | FLAG_READING_TILL_END_OF_LINE;
}
break;
}
}
if (anyAreSet(newVal, FLAG_READING_LENGTH)) {
return 0;
}
}
while (anyAreSet(newVal, FLAG_READING_TILL_END_OF_LINE)) {
while (buf.hasRemaining()) {
if (buf.get() == '\n') {
newVal = newVal & ~FLAG_READING_TILL_END_OF_LINE;
break;
}
}
if (anyAreSet(newVal, FLAG_READING_TILL_END_OF_LINE)) {
return 0;
}
}
if (allAreClear(newVal, FLAG_READING_NEWLINE | FLAG_READING_LENGTH | FLAG_READING_TILL_END_OF_LINE) && chunkRemaining == 0) {
newVal |= FLAG_READING_AFTER_LAST;
int ret = handleChunkedRequestEnd(buf);
if (ret == -1) {
newVal |= FLAG_FINISHED & ~FLAG_READING_AFTER_LAST;
return -1;
}
return 0;
}
return chunkRemaining;
} finally {
state = newVal | chunkRemaining;
}
}
public long getChunkRemaining() {
if (anyAreSet(state, FLAG_FINISHED)) {
return -1;
}
if (anyAreSet(state, FLAG_READING_LENGTH | FLAG_READING_TILL_END_OF_LINE | FLAG_READING_NEWLINE | FLAG_READING_AFTER_LAST)) {
return 0;
}
return state & MASK_COUNT;
}
public void setChunkRemaining(final long remaining) {
if (remaining < 0 || anyAreSet(state, FLAG_READING_LENGTH | FLAG_READING_TILL_END_OF_LINE | FLAG_READING_NEWLINE | FLAG_READING_AFTER_LAST)) {
return;
}
long old = state;
long oldRemaining = old & MASK_COUNT;
if (remaining == 0 && oldRemaining != 0) {
old |= FLAG_READING_NEWLINE;
}
state = (old & ~MASK_COUNT) | remaining;
}
private int handleChunkedRequestEnd(ByteBuffer buffer) throws IOException {
if (trailerParser != null) {
return trailerParser.handle(buffer);
}
while (buffer.hasRemaining()) {
byte b = buffer.get();
if (b == '\n') {
return -1;
} else if (b != '\r') {
buffer.position(buffer.position() - 1);
trailerParser = new TrailerParser();
return trailerParser.handle(buffer);
}
}
return 0;
}
private final class TrailerParser {
private HeaderMap = new HeaderMap();
private StringBuilder builder = new StringBuilder();
private HttpString httpString;
int state = 0;
private static final int STATE_TRAILER_NAME = 0;
private static final int STATE_TRAILER_VALUE = 1;
private static final int STATE_ENDING = 2;
public int handle(ByteBuffer buf) throws IOException {
while (buf.hasRemaining()) {
final byte b = buf.get();
if (state == STATE_TRAILER_NAME) {
if (b == '\r') {
if (builder.length() == 0) {
state = STATE_ENDING;
} else {
throw UndertowMessages.MESSAGES.couldNotDecodeTrailers();
}
} else if (b == '\n') {
if (builder.length() == 0) {
attachable.putAttachment(trailerAttachmentKey, headerMap);
return -1;
} else {
throw UndertowMessages.MESSAGES.couldNotDecodeTrailers();
}
} else if (b == ':') {
httpString = HttpString.tryFromString(builder.toString().trim());
state = STATE_TRAILER_VALUE;
builder.setLength(0);
} else {
builder.append((char) b);
}
} else if (state == STATE_TRAILER_VALUE) {
if (b == '\n') {
headerMap.put(httpString, builder.toString().trim());
httpString = null;
builder.setLength(0);
state = STATE_TRAILER_NAME;
} else if (b != '\r') {
builder.append((char) b);
}
} else if (state == STATE_ENDING) {
if (b == '\n') {
if (attachable != null) {
attachable.putAttachment(trailerAttachmentKey, headerMap);
}
return -1;
} else {
throw UndertowMessages.MESSAGES.couldNotDecodeTrailers();
}
} else {
throw new IllegalStateException();
}
}
return 0;
}
}
}