package io.undertow.server.protocol.ajp;
import static org.xnio.Bits.anyAreSet;
import static org.xnio.Bits.longBitMask;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.channels.ClosedChannelException;
import java.nio.channels.FileChannel;
import java.util.concurrent.TimeUnit;
import io.undertow.UndertowMessages;
import io.undertow.conduits.ConduitListener;
import io.undertow.server.Connectors;
import io.undertow.server.HttpServerExchange;
import io.undertow.util.ImmediatePooledByteBuffer;
import org.xnio.IoUtils;
import org.xnio.channels.StreamSinkChannel;
import org.xnio.conduits.AbstractStreamSourceConduit;
import org.xnio.conduits.ConduitReadableByteChannel;
import org.xnio.conduits.StreamSourceConduit;
public class AjpServerRequestConduit extends AbstractStreamSourceConduit<StreamSourceConduit> {
private static final ByteBuffer READ_BODY_CHUNK;
static {
ByteBuffer readBody = ByteBuffer.allocateDirect(7);
readBody.put((byte) 'A');
readBody.put((byte) 'B');
readBody.put((byte) 0);
readBody.put((byte) 3);
readBody.put((byte) 6);
readBody.put((byte) 0x1F);
readBody.put((byte) 0xFA);
readBody.flip();
READ_BODY_CHUNK = readBody;
}
private static final int = 6;
private static final long STATE_READING = 1L << 63L;
private static final long STATE_SEND_REQUIRED = 1L << 62L;
private static final long STATE_FINISHED = 1L << 61L;
private static final long STATE_MASK = longBitMask(0, 60);
private final HttpServerExchange exchange;
private final AjpServerResponseConduit ajpResponseConduit;
private final ByteBuffer = ByteBuffer.allocateDirect(HEADER_LENGTH);
private final ConduitListener<? super AjpServerRequestConduit> finishListener;
private long remaining;
private long state;
private long totalRead;
public AjpServerRequestConduit(final StreamSourceConduit delegate, HttpServerExchange exchange, AjpServerResponseConduit ajpResponseConduit, Long size, ConduitListener<? super AjpServerRequestConduit> finishListener) {
super(delegate);
this.exchange = exchange;
this.ajpResponseConduit = ajpResponseConduit;
this.finishListener = finishListener;
if (size == null) {
state = STATE_SEND_REQUIRED;
remaining = -1;
} else if (size.longValue() == 0L) {
state = STATE_FINISHED;
remaining = 0;
} else {
state = STATE_READING;
remaining = size;
}
}
@Override
public long transferTo(long position, long count, FileChannel target) throws IOException {
try {
return target.transferFrom(new ConduitReadableByteChannel(this), position, count);
} catch (IOException | RuntimeException e) {
IoUtils.safeClose(exchange.getConnection());
throw e;
}
}
@Override
public long transferTo(long count, ByteBuffer throughBuffer, StreamSinkChannel target) throws IOException {
try {
return IoUtils.transfer(new ConduitReadableByteChannel(this), count, throughBuffer, target);
} catch (IOException | RuntimeException e) {
IoUtils.safeClose(exchange.getConnection());
throw e;
}
}
@Override
public void terminateReads() throws IOException {
if(exchange.isPersistent() && anyAreSet(state, STATE_FINISHED)) {
return;
}
super.terminateReads();
}
@Override
public long read(ByteBuffer[] dsts, int offset, int length) throws IOException {
try {
long total = 0;
for (int i = offset; i < length; ++i) {
while (dsts[i].hasRemaining()) {
int r = read(dsts[i]);
if (r <= 0 && total > 0) {
return total;
} else if (r <= 0) {
return r;
} else {
total += r;
}
}
}
return total;
} catch (IOException | RuntimeException e) {
IoUtils.safeClose(exchange.getConnection());
throw e;
}
}
@Override
public int read(ByteBuffer dst) throws IOException {
try {
long state = this.state;
if (anyAreSet(state, STATE_FINISHED)) {
return -1;
} else if (anyAreSet(state, STATE_SEND_REQUIRED)) {
state = this.state = (state & STATE_MASK) | STATE_READING;
if (ajpResponseConduit.isWriteShutdown()) {
this.state = STATE_FINISHED;
if (finishListener != null) {
finishListener.handleEvent(this);
}
return -1;
}
if (!ajpResponseConduit.doGetRequestBodyChunk(READ_BODY_CHUNK.duplicate(), this)) {
return 0;
}
}
if (anyAreSet(state, STATE_READING)) {
return doRead(dst, state);
}
assert STATE_FINISHED == state;
return -1;
} catch (IOException | RuntimeException e) {
IoUtils.safeClose(exchange.getConnection());
throw e;
}
}
private int doRead(final ByteBuffer dst, long state) throws IOException {
ByteBuffer headerBuffer = this.headerBuffer;
long headerRead = HEADER_LENGTH - headerBuffer.remaining();
long remaining = this.remaining;
if (remaining == 0) {
this.state = STATE_FINISHED;
if (finishListener != null) {
finishListener.handleEvent(this);
}
return -1;
}
long chunkRemaining;
if (headerRead != HEADER_LENGTH) {
int read = next.read(headerBuffer);
if (read == -1) {
this.state = STATE_FINISHED;
if (finishListener != null) {
finishListener.handleEvent(this);
}
throw new ClosedChannelException();
} else if (headerBuffer.hasRemaining()) {
if(headerBuffer.remaining() <= 2) {
byte b1 = headerBuffer.get(0);
byte b2 = headerBuffer.get(1);
if (b1 != 0x12 || b2 != 0x34) {
throw UndertowMessages.MESSAGES.wrongMagicNumber((b1 & 0xFF) << 8 | (b2 & 0xFF));
}
b1 = headerBuffer.get(2);
b2 = headerBuffer.get(3);
int totalSize = ((b1 & 0xFF) << 8) | (b2 & 0xFF);
if(totalSize == 0) {
if(headerBuffer.remaining() < 2) {
byte[] data = new byte[1];
ByteBuffer bb = ByteBuffer.wrap(data);
bb.put(headerBuffer.get(4));
bb.flip();
Connectors.ungetRequestBytes(exchange, new ImmediatePooledByteBuffer(bb));
}
this.remaining = 0;
this.state = STATE_FINISHED;
if (finishListener != null) {
finishListener.handleEvent(this);
}
return -1;
}
}
return 0;
} else {
headerBuffer.flip();
byte b1 = headerBuffer.get();
byte b2 = headerBuffer.get();
if (b1 != 0x12 || b2 != 0x34) {
throw UndertowMessages.MESSAGES.wrongMagicNumber((b1 & 0xFF) << 8 | (b2 & 0xFF));
}
b1 = headerBuffer.get();
b2 = headerBuffer.get();
int totalSize = ((b1 & 0xFF) << 8) | (b2 & 0xFF);
if(totalSize == 0) {
byte[] data = new byte[2];
ByteBuffer bb = ByteBuffer.wrap(data);
bb.put(headerBuffer);
bb.flip();
Connectors.ungetRequestBytes(exchange, new ImmediatePooledByteBuffer(bb));
this.remaining = 0;
this.state = STATE_FINISHED;
if (finishListener != null) {
finishListener.handleEvent(this);
}
return -1;
}
b1 = headerBuffer.get();
b2 = headerBuffer.get();
chunkRemaining = ((b1 & 0xFF) << 8) | (b2 & 0xFF);
if (chunkRemaining == 0) {
this.remaining = 0;
this.state = STATE_FINISHED;
if (finishListener != null) {
finishListener.handleEvent(this);
}
return -1;
}
}
} else {
chunkRemaining = this.state & STATE_MASK;
}
int limit = dst.limit();
try {
if (dst.remaining() > chunkRemaining) {
dst.limit((int) (dst.position() + chunkRemaining));
}
int read = next.read(dst);
chunkRemaining -= read;
if (remaining != -1) {
remaining -= read;
}
this.totalRead += read;
if (remaining != 0) {
if (chunkRemaining == 0) {
headerBuffer.clear();
this.state = STATE_SEND_REQUIRED;
} else {
this.state = (state & ~STATE_MASK) | chunkRemaining;
}
}
return read;
} finally {
this.remaining = remaining;
dst.limit(limit);
final long maxEntitySize = exchange.getMaxEntitySize();
if (maxEntitySize > 0) {
if (totalRead > maxEntitySize) {
terminateReads();
exchange.setPersistent(false);
throw UndertowMessages.MESSAGES.requestEntityWasTooLarge(maxEntitySize);
}
}
}
}
@Override
public void awaitReadable() throws IOException {
try {
if (anyAreSet(state, STATE_READING)) {
next.awaitReadable();
}
} catch (IOException | RuntimeException e) {
IoUtils.safeClose(exchange.getConnection());
throw e;
}
}
@Override
public void awaitReadable(long time, TimeUnit timeUnit) throws IOException {
try {
if (anyAreSet(state, STATE_READING)) {
next.awaitReadable(time, timeUnit);
}
} catch (IOException | RuntimeException e) {
IoUtils.safeClose(exchange.getConnection());
throw e;
}
}
void setReadBodyChunkError(IOException e) {
IoUtils.safeClose(exchange.getConnection());
if(isReadResumed()) {
wakeupReads();
}
}
}