package io.undertow.servlet.spec;
import io.undertow.servlet.UndertowServletMessages;
import org.xnio.ChannelListener;
import org.xnio.IoUtils;
import org.xnio.channels.Channels;
import org.xnio.channels.StreamSinkChannel;
import javax.servlet.ServletOutputStream;
import javax.servlet.WriteListener;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.util.concurrent.Executor;
import static org.xnio.Bits.anyAreClear;
import static org.xnio.Bits.anyAreSet;
/**
 * {@link ServletOutputStream} for HTTP upgrade connections that writes directly to an
 * underlying {@link StreamSinkChannel}.
 * <p>
 * Until a {@link WriteListener} is registered the stream works in blocking mode. Once a
 * listener is set, writes are non-blocking: bytes the channel cannot accept immediately are
 * copied into an internal buffer, the stream is marked not ready, and write notifications are
 * resumed so that {@link WriteChannelListener} can finish the write and then invoke
 * {@link WriteListener#onWritePossible()} again.
 */
public class UpgradeServletOutputStream extends ServletOutputStream {

    private final StreamSinkChannel channel;
    private WriteListener listener;
    private final Executor ioExecutor;

    /** Set while a non-blocking write may be issued. */
    private static final int FLAG_READY = 1;
    /** Set once {@link #close()} or {@link #closeBlocking()} has been called. */
    private static final int FLAG_CLOSED = 1 << 1;
    /** Set once {@code shutdownWrites()} has been invoked on the underlying channel. */
    private static final int FLAG_DELEGATE_SHUTDOWN = 1 << 2;

    private int state;

    /**
     * Bytes from a non-blocking write that the channel could not accept yet, or {@code null}
     * when nothing is pending.
     */
    private ByteBuffer buffer;

    /**
     * @param channel    the channel all output is written to
     * @param ioExecutor executor used to resume channel writes when not on the channel's IO thread
     */
    protected UpgradeServletOutputStream(final StreamSinkChannel channel, Executor ioExecutor) {
        this.channel = channel;
        this.ioExecutor = ioExecutor;
    }

    @Override
    public void write(final byte[] b) throws IOException {
        write(b, 0, b.length);
    }

    /**
     * Writes {@code len} bytes of {@code b} starting at {@code off}.
     * <p>
     * In blocking mode (no listener) this blocks until all bytes are written. In non-blocking
     * mode it writes as much as the channel accepts. Any remainder is copied into an internal
     * buffer and flushed later by the write listener. Until then, {@link #isReady()} returns
     * {@code false}.
     *
     * @throws IOException if the stream is closed, if it is in non-blocking mode and not ready,
     *                     or if the channel write fails
     */
    @Override
    public void write(final byte[] b, final int off, final int len) throws IOException {
        if (anyAreSet(state, FLAG_CLOSED)) {
            throw UndertowServletMessages.MESSAGES.streamIsClosed();
        }
        // Only the requested slice is written (wrap also validates off/len against b).
        final ByteBuffer data = ByteBuffer.wrap(b, off, len);
        if (listener == null) {
            Channels.writeBlocking(channel, data);
            return;
        }
        if (anyAreClear(state, FLAG_READY)) {
            throw UndertowServletMessages.MESSAGES.streamNotReady();
        }
        // A zero-length write is a no-op and must not flip the stream to "not ready".
        while (data.hasRemaining()) {
            if (channel.write(data) == 0) {
                // The caller may reuse its array after we return, so keep a private copy.
                final ByteBuffer copy = ByteBuffer.allocate(data.remaining());
                copy.put(data);
                copy.flip();
                this.buffer = copy;
                state &= ~FLAG_READY;
                resumeChannelWrites();
                return;
            }
        }
    }

    @Override
    public void write(final int b) throws IOException {
        write(new byte[]{(byte) b}, 0, 1);
    }

    /**
     * Flushes the channel in blocking mode. In non-blocking mode this does nothing beyond the
     * closed-stream check; pending data is flushed by the write listener.
     */
    @Override
    public void flush() throws IOException {
        if (anyAreSet(state, FLAG_CLOSED)) {
            throw UndertowServletMessages.MESSAGES.streamIsClosed();
        }
        if (listener == null) {
            Channels.flushBlocking(channel);
        }
    }

    /**
     * Closes the stream and shuts down writes on the channel. Calling it again has no effect.
     * <p>
     * In non-blocking mode, if data is still pending, shutdown is deferred to the write
     * listener, which performs it once the pending data has been written.
     */
    @Override
    public void close() throws IOException {
        if (anyAreSet(state, FLAG_CLOSED)) {
            // Closeable contract: closing an already closed stream has no effect.
            return;
        }
        state |= FLAG_CLOSED;
        state &= ~FLAG_READY;
        if (listener == null) {
            channel.shutdownWrites();
            state |= FLAG_DELEGATE_SHUTDOWN;
            Channels.flushBlocking(channel);
        } else if (buffer == null) {
            channel.shutdownWrites();
            state |= FLAG_DELEGATE_SHUTDOWN;
            if (!channel.flush()) {
                // Let the write listener finish the flush once the channel is writable.
                resumeChannelWrites();
            }
        }
        // Otherwise writes are already resumed for the pending buffer; the listener shuts down after it.
    }

    /**
     * Closes the stream by writing any pending data, then shutting down and flushing the
     * channel, all blocking. If any step fails the channel is closed and the exception rethrown.
     */
    void closeBlocking() throws IOException {
        state |= FLAG_CLOSED;
        try {
            if (buffer != null) {
                Channels.writeBlocking(channel, buffer);
            }
            channel.shutdownWrites();
            Channels.flushBlocking(channel);
        } catch (IOException e) {
            channel.close();
            throw e;
        }
    }

    /**
     * @return whether a non-blocking write may be issued now
     * @throws IllegalStateException-like servlet error if no write listener has been set
     */
    @Override
    public boolean isReady() {
        if (listener == null) {
            throw UndertowServletMessages.MESSAGES.streamNotInAsyncMode();
        }
        return anyAreSet(state, FLAG_READY);
    }

    /**
     * Switches the stream to non-blocking mode. The listener is installed on the channel's write
     * setter, the stream is marked ready, and channel writes are resumed via the IO executor.
     */
    @Override
    public void setWriteListener(final WriteListener writeListener) {
        if (writeListener == null) {
            throw UndertowServletMessages.MESSAGES.paramCannotBeNull("writeListener");
        }
        if (listener != null) {
            throw UndertowServletMessages.MESSAGES.listenerAlreadySet();
        }
        listener = writeListener;
        channel.getWriteSetter().set(new WriteChannelListener());
        state |= FLAG_READY;
        ioExecutor.execute(new Runnable() {
            @Override
            public void run() {
                channel.resumeWrites();
            }
        });
    }

    /**
     * Resumes write notifications on the channel. The call is made directly when already on the
     * channel's IO thread; otherwise it is handed to {@link #ioExecutor}.
     */
    private void resumeChannelWrites() {
        if (Thread.currentThread() == channel.getIoThread()) {
            channel.resumeWrites();
        } else {
            ioExecutor.execute(new Runnable() {
                @Override
                public void run() {
                    channel.resumeWrites();
                }
            });
        }
    }

    /**
     * Channel write listener that drives non-blocking mode. It finishes a pending shutdown
     * flush, writes any buffered data, then either shuts the channel down (if the stream was
     * closed) or hands control back to the user's {@link WriteListener}.
     */
    private class WriteChannelListener implements ChannelListener<StreamSinkChannel> {

        @Override
        public void handleEvent(final StreamSinkChannel channel) {
            if (anyAreSet(state, FLAG_DELEGATE_SHUTDOWN)) {
                // Writes are already shut down; only the flush remains.
                try {
                    if (channel.flush()) {
                        channel.suspendWrites();
                    }
                } catch (IOException e) {
                    handleError(channel, e);
                }
                return;
            }
            if (buffer != null) {
                try {
                    while (buffer.hasRemaining()) {
                        if (channel.write(buffer) == 0) {
                            // Still not writable; writes stay resumed and we are called again.
                            return;
                        }
                    }
                } catch (IOException e) {
                    // The channel has been closed by handleError; do not touch it again.
                    handleError(channel, e);
                    return;
                }
                buffer = null;
            }
            if (anyAreSet(state, FLAG_CLOSED)) {
                try {
                    channel.shutdownWrites();
                    state |= FLAG_DELEGATE_SHUTDOWN;
                    if (channel.flush()) {
                        channel.suspendWrites();
                    }
                } catch (IOException e) {
                    handleError(channel, e);
                }
            } else {
                state |= FLAG_READY;
                channel.suspendWrites();
                try {
                    listener.onWritePossible();
                } catch (IOException e) {
                    listener.onError(e);
                }
            }
        }

        /** Reports {@code e} to the user's listener and always closes the channel afterwards. */
        private void handleError(final StreamSinkChannel channel, final IOException e) {
            try {
                listener.onError(e);
            } finally {
                IoUtils.safeClose(channel);
            }
        }
    }
}