package io.undertow.servlet.spec;
import static org.xnio.Bits.allAreClear;
import static org.xnio.Bits.anyAreClear;
import static org.xnio.Bits.anyAreSet;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.util.concurrent.atomic.AtomicIntegerFieldUpdater;
import javax.servlet.ReadListener;
import javax.servlet.ServletInputStream;
import org.xnio.Buffers;
import org.xnio.ChannelListener;
import org.xnio.IoUtils;
import org.xnio.channels.Channels;
import org.xnio.channels.EmptyStreamSourceChannel;
import org.xnio.channels.StreamSourceChannel;
import io.undertow.connector.ByteBufferPool;
import io.undertow.connector.PooledByteBuffer;
import io.undertow.servlet.UndertowServletMessages;
public class ServletInputStreamImpl extends ServletInputStream {
private final HttpServletRequestImpl request;
private final StreamSourceChannel channel;
private final ByteBufferPool bufferPool;
private volatile ReadListener listener;
private volatile ServletInputStreamChannelListener internalListener;
private static final int FLAG_READY = 1;
private static final int FLAG_CLOSED = 1 << 1;
private static final int FLAG_FINISHED = 1 << 2;
private static final int FLAG_ON_DATA_READ_CALLED = 1 << 3;
private static final int FLAG_CALL_ON_ALL_DATA_READ = 1 << 4;
private static final int FLAG_BEING_INVOKED_IN_IO_THREAD = 1 << 5;
private static final int FLAG_IS_READY_CALLED = 1 << 6;
private volatile int state;
private volatile AsyncContextImpl asyncContext;
private volatile PooledByteBuffer pooled;
private static final AtomicIntegerFieldUpdater<ServletInputStreamImpl> stateUpdater = AtomicIntegerFieldUpdater.newUpdater(ServletInputStreamImpl.class, "state");
public ServletInputStreamImpl(final HttpServletRequestImpl request) {
this.request = request;
if (request.getExchange().isRequestChannelAvailable()) {
this.channel = request.getExchange().getRequestChannel();
} else {
this.channel = new EmptyStreamSourceChannel(request.getExchange().getIoThread());
}
this.bufferPool = request.getExchange().getConnection().getByteBufferPool();
}
@Override
public boolean isFinished() {
return anyAreSet(state, FLAG_FINISHED);
}
@Override
public boolean isReady() {
boolean finished = anyAreSet(state, FLAG_FINISHED);
if(finished) {
if (anyAreClear(state, FLAG_ON_DATA_READ_CALLED)) {
if(allAreClear(state, FLAG_BEING_INVOKED_IN_IO_THREAD)) {
setFlags(FLAG_ON_DATA_READ_CALLED);
request.getServletContext().invokeOnAllDataRead(request.getExchange(), listener);
} else {
setFlags(FLAG_CALL_ON_ALL_DATA_READ);
}
}
}
boolean ready = anyAreSet(state, FLAG_READY) && !finished;
if(!ready && listener != null && !finished) {
channel.resumeReads();
}
if(ready) {
setFlags(FLAG_IS_READY_CALLED);
}
return ready;
}
@Override
public void setReadListener(final ReadListener readListener) {
if (readListener == null) {
throw UndertowServletMessages.MESSAGES.listenerCannotBeNull();
}
if (listener != null) {
throw UndertowServletMessages.MESSAGES.listenerAlreadySet();
}
if (!request.isAsyncStarted()) {
throw UndertowServletMessages.MESSAGES.asyncNotStarted();
}
asyncContext = request.getAsyncContext();
listener = readListener;
channel.getReadSetter().set(internalListener = new ServletInputStreamChannelListener());
asyncContext.addAsyncTask(new Runnable() {
@Override
public void run() {
channel.getIoThread().execute(new Runnable() {
@Override
public void run() {
internalListener.handleEvent(channel);
}
});
}
});
}
@Override
public int read() throws IOException {
byte[] b = new byte[1];
int read = read(b);
if (read == -1) {
return -1;
}
return b[0] & 0xff;
}
@Override
public int read(final byte[] b) throws IOException {
return read(b, 0, b.length);
}
@Override
public int read(final byte[] b, final int off, final int len) throws IOException {
if (anyAreSet(state, FLAG_CLOSED)) {
throw UndertowServletMessages.MESSAGES.streamIsClosed();
}
if (listener != null) {
if (anyAreClear(state, FLAG_READY | FLAG_IS_READY_CALLED) ) {
throw UndertowServletMessages.MESSAGES.streamNotReady();
}
clearFlags(FLAG_IS_READY_CALLED);
} else {
readIntoBuffer();
}
if (anyAreSet(state, FLAG_FINISHED)) {
return -1;
}
if (len == 0) {
return 0;
}
ByteBuffer buffer = pooled.getBuffer();
int copied = Buffers.copy(ByteBuffer.wrap(b, off, len), buffer);
if (!buffer.hasRemaining()) {
pooled.close();
pooled = null;
if (listener != null) {
readIntoBufferNonBlocking();
}
}
return copied;
}
private void readIntoBuffer() throws IOException {
if (pooled == null && !anyAreSet(state, FLAG_FINISHED)) {
pooled = bufferPool.allocate();
int res = Channels.readBlocking(channel, pooled.getBuffer());
pooled.getBuffer().flip();
if (res == -1) {
setFlags(FLAG_FINISHED);
pooled.close();
pooled = null;
}
}
}
private void readIntoBufferNonBlocking() throws IOException {
if (pooled == null && !anyAreSet(state, FLAG_FINISHED)) {
pooled = bufferPool.allocate();
if (listener == null) {
int res = channel.read(pooled.getBuffer());
if (res == 0) {
pooled.close();
pooled = null;
return;
}
pooled.getBuffer().flip();
if (res == -1) {
setFlags(FLAG_FINISHED);
pooled.close();
pooled = null;
}
} else {
int res = channel.read(pooled.getBuffer());
pooled.getBuffer().flip();
if (res == -1) {
setFlags(FLAG_FINISHED);
pooled.close();
pooled = null;
} else if (res == 0) {
clearFlags(FLAG_READY);
pooled.close();
pooled = null;
}
}
}
}
@Override
public int available() throws IOException {
if (anyAreSet(state, FLAG_CLOSED)) {
throw UndertowServletMessages.MESSAGES.streamIsClosed();
}
readIntoBufferNonBlocking();
if (anyAreSet(state, FLAG_FINISHED)) {
return 0;
}
if (pooled == null) {
return 0;
}
return pooled.getBuffer().remaining();
}
@Override
public void close() throws IOException {
if (anyAreSet(state, FLAG_CLOSED)) {
return;
}
setFlags(FLAG_CLOSED);
try {
while (allAreClear(state, FLAG_FINISHED)) {
readIntoBuffer();
if (pooled != null) {
pooled.close();
pooled = null;
}
}
} finally {
setFlags(FLAG_FINISHED);
if (pooled != null) {
pooled.close();
pooled = null;
}
channel.shutdownReads();
}
}
private class ServletInputStreamChannelListener implements ChannelListener<StreamSourceChannel> {
@Override
public void handleEvent(final StreamSourceChannel channel) {
try {
if (asyncContext.isDispatched()) {
channel.suspendReads();
return;
}
if (anyAreSet(state, FLAG_FINISHED)) {
channel.suspendReads();
return;
}
readIntoBufferNonBlocking();
if (pooled != null) {
channel.suspendReads();
setFlags(FLAG_READY);
if (!anyAreSet(state, FLAG_FINISHED)) {
setFlags(FLAG_BEING_INVOKED_IN_IO_THREAD);
try {
request.getServletContext().invokeOnDataAvailable(request.getExchange(), listener);
} finally {
clearFlags(FLAG_BEING_INVOKED_IN_IO_THREAD);
}
if(anyAreSet(state, FLAG_CALL_ON_ALL_DATA_READ) && allAreClear(state, FLAG_ON_DATA_READ_CALLED)) {
setFlags(FLAG_ON_DATA_READ_CALLED);
request.getServletContext().invokeOnAllDataRead(request.getExchange(), listener);
}
}
} else if(anyAreSet(state, FLAG_FINISHED)) {
if (allAreClear(state, FLAG_ON_DATA_READ_CALLED)) {
setFlags(FLAG_ON_DATA_READ_CALLED);
request.getServletContext().invokeOnAllDataRead(request.getExchange(), listener);
}
} else {
channel.resumeReads();
}
} catch (final Throwable e) {
try {
request.getServletContext().invokeRunnable(request.getExchange(), new Runnable() {
@Override
public void run() {
listener.onError(e);
}
});
} finally {
if (pooled != null) {
pooled.close();
pooled = null;
}
IoUtils.safeClose(channel);
}
}
}
}
private void setFlags(int flags) {
int old;
do {
old = state;
} while (!stateUpdater.compareAndSet(this, old, old | flags));
}
private void clearFlags(int flags) {
int old;
do {
old = state;
} while (!stateUpdater.compareAndSet(this, old, old & ~flags));
}
}