package io.undertow.server.protocol.framed;
import io.undertow.UndertowLogger;
import io.undertow.UndertowMessages;
import io.undertow.connector.PooledByteBuffer;
import io.undertow.util.ImmediatePooledByteBuffer;
import org.xnio.Buffers;
import org.xnio.ChannelListener;
import org.xnio.ChannelListeners;
import org.xnio.IoUtils;
import org.xnio.Option;
import org.xnio.XnioExecutor;
import org.xnio.XnioIoThread;
import org.xnio.XnioWorker;
import org.xnio.channels.Channels;
import org.xnio.channels.StreamSinkChannel;
import org.xnio.channels.StreamSourceChannel;
import java.io.IOException;
import java.io.InterruptedIOException;
import java.nio.ByteBuffer;
import java.nio.channels.FileChannel;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicIntegerFieldUpdater;
import static org.xnio.Bits.allAreClear;
import static org.xnio.Bits.anyAreSet;
public abstract class AbstractFramedStreamSinkChannel<C extends AbstractFramedChannel<C, R, S>, R extends AbstractFramedStreamSourceChannel<C, R, S>, S extends AbstractFramedStreamSinkChannel<C, R, S>> implements StreamSinkChannel {
private static final PooledByteBuffer EMPTY_BYTE_BUFFER = new ImmediatePooledByteBuffer(ByteBuffer.allocateDirect(0));
private final C channel;
private final ChannelListener.SimpleSetter<S> writeSetter = new ChannelListener.SimpleSetter<>();
private final ChannelListener.SimpleSetter<S> closeSetter = new ChannelListener.SimpleSetter<>();
private final Object lock = new Object();
private volatile int state = 0;
private volatile boolean readyForFlush;
private volatile boolean fullyFlushed;
private volatile boolean finalFrameQueued;
private volatile boolean broken;
private volatile int waiterCount = 0;
private volatile SendFrameHeader ;
private volatile PooledByteBuffer writeBuffer;
private volatile PooledByteBuffer body;
private static final int STATE_CLOSED = 1;
private static final int STATE_WRITES_SHUTDOWN = 1 << 1;
private static final int STATE_FIRST_DATA_WRITTEN = 1 << 2;
private static final int STATE_PRE_WRITE_CALLED = 1 << 3;
private volatile boolean bufferFull;
private volatile boolean writesResumed;
@SuppressWarnings("unused")
private volatile int inListenerLoop;
private static final AtomicIntegerFieldUpdater<AbstractFramedStreamSinkChannel> inListenerLoopUpdater = AtomicIntegerFieldUpdater.newUpdater(AbstractFramedStreamSinkChannel.class, "inListenerLoop");
protected AbstractFramedStreamSinkChannel(C channel) {
this.channel = channel;
}
public long transferFrom(final FileChannel src, final long position, final long count) throws IOException {
return src.transferTo(position, count, this);
}
public long transferFrom(final StreamSourceChannel source, final long count, final ByteBuffer throughBuffer) throws IOException {
return IoUtils.transfer(source, count, throughBuffer, this);
}
@Override
public void suspendWrites() {
writesResumed = false;
}
final SendFrameHeader () throws IOException {
if (header == null) {
header = createFrameHeader();
if (header == null) {
header = new SendFrameHeader(0, null);
}
}
return header;
}
protected SendFrameHeader () throws IOException{
return null;
}
final void preWrite() {
synchronized (lock) {
if (allAreClear(state, STATE_PRE_WRITE_CALLED)) {
state |= STATE_PRE_WRITE_CALLED;
body = preWriteTransform(body);
}
}
}
protected PooledByteBuffer preWriteTransform(PooledByteBuffer body) {
return body;
}
@Override
public boolean isWriteResumed() {
return writesResumed;
}
@Override
public void wakeupWrites() {
resumeWritesInternal(true);
}
@Override
public void resumeWrites() {
resumeWritesInternal(false);
}
protected void resumeWritesInternal(boolean wakeup) {
boolean alreadyResumed = writesResumed;
if(!wakeup && alreadyResumed) {
return;
}
writesResumed = true;
if(readyForFlush && !wakeup) {
return;
}
if (inListenerLoopUpdater.compareAndSet(this, 0, 1)) {
getChannel().runInIoThread(new Runnable() {
int loopCount = 0;
@Override
public void run() {
try {
ChannelListener<? super S> listener = getWriteListener();
if (listener == null || !isWriteResumed()) {
return;
}
if (loopCount++ == 100) {
UndertowLogger.ROOT_LOGGER.listenerNotProgressing();
IoUtils.safeClose(AbstractFramedStreamSinkChannel.this);
return;
}
ChannelListeners.invokeChannelListener((S) AbstractFramedStreamSinkChannel.this, listener);
} finally {
inListenerLoopUpdater.set(AbstractFramedStreamSinkChannel.this, 0);
}
if (writesResumed && allAreClear(state, STATE_CLOSED) && !broken && !readyForFlush && !fullyFlushed) {
if (inListenerLoopUpdater.compareAndSet(AbstractFramedStreamSinkChannel.this, 0, 1)) {
getIoThread().execute(this);
}
}
}
});
}
}
@Override
public void shutdownWrites() throws IOException {
synchronized (lock) {
if (anyAreSet(state, STATE_WRITES_SHUTDOWN) || broken) {
return;
}
queueFinalFrame();
state |= STATE_WRITES_SHUTDOWN;
}
}
private void queueFinalFrame() throws IOException {
synchronized (lock) {
if (!readyForFlush && !fullyFlushed && allAreClear(state, STATE_CLOSED) && !broken && !finalFrameQueued) {
if (null == body && null != writeBuffer) {
sendWriteBuffer();
} else if (null == body) {
body = EMPTY_BYTE_BUFFER;
}
readyForFlush = true;
state |= STATE_FIRST_DATA_WRITTEN;
state |= STATE_WRITES_SHUTDOWN;
finalFrameQueued = true;
channel.queueFrame((S) this);
}
}
}
protected boolean isFinalFrameQueued() {
return finalFrameQueued;
}
@Override
public void awaitWritable() throws IOException {
if(Thread.currentThread() == getIoThread()) {
throw UndertowMessages.MESSAGES.awaitCalledFromIoThread();
}
synchronized (lock) {
if (anyAreSet(state, STATE_CLOSED) || broken) {
return;
}
if (readyForFlush) {
try {
waiterCount++;
if(readyForFlush && !anyAreSet(state, STATE_CLOSED) && !broken) {
lock.wait();
}
} catch (InterruptedException e) {
throw new InterruptedIOException();
} finally {
waiterCount--;
}
}
}
}
@Override
public void awaitWritable(long l, TimeUnit timeUnit) throws IOException {
if(Thread.currentThread() == getIoThread()) {
throw UndertowMessages.MESSAGES.awaitCalledFromIoThread();
}
synchronized (lock) {
if (anyAreSet(state, STATE_CLOSED) || broken) {
return;
}
if (readyForFlush) {
try {
waiterCount++;
if(readyForFlush && !anyAreSet(state, STATE_CLOSED) && !broken) {
lock.wait(timeUnit.toMillis(l));
}
} catch (InterruptedException e) {
throw new InterruptedIOException();
} finally {
waiterCount--;
}
}
}
}
@Override
public XnioExecutor getWriteThread() {
return channel.getIoThread();
}
@Override
public ChannelListener.Setter<? extends S> getWriteSetter() {
return writeSetter;
}
@Override
public ChannelListener.Setter<? extends S> getCloseSetter() {
return closeSetter;
}
@Override
public XnioWorker getWorker() {
return channel.getWorker();
}
@Override
public XnioIoThread getIoThread() {
return channel.getIoThread();
}
@Override
public boolean flush() throws IOException {
if(anyAreSet(state, STATE_CLOSED)) {
return true;
}
if (broken) {
throw UndertowMessages.MESSAGES.channelIsClosed();
}
if (readyForFlush) {
return false;
}
synchronized (lock) {
if (fullyFlushed) {
state |= STATE_CLOSED;
return true;
}
}
if (anyAreSet(state, STATE_WRITES_SHUTDOWN) && !finalFrameQueued) {
queueFinalFrame();
return false;
}
if(anyAreSet(state, STATE_WRITES_SHUTDOWN)) {
return false;
}
if(isFlushRequiredOnEmptyBuffer() || (writeBuffer != null && writeBuffer.getBuffer().position() > 0)) {
handleBufferFull();
return !readyForFlush;
}
return true;
}
protected boolean isFlushRequiredOnEmptyBuffer() {
return false;
}
@Override
public long write(ByteBuffer[] srcs, int offset, int length) throws IOException {
if(!safeToSend()) {
return 0;
}
if(writeBuffer == null) {
writeBuffer = getChannel().getBufferPool().allocate();
}
ByteBuffer buffer = writeBuffer.getBuffer();
int copied = Buffers.copy(buffer, srcs, offset, length);
if(!buffer.hasRemaining()) {
handleBufferFull();
}
return copied;
}
@Override
public long write(ByteBuffer[] srcs) throws IOException {
return write(srcs, 0, srcs.length);
}
@Override
public int write(ByteBuffer src) throws IOException {
if(!safeToSend()) {
return 0;
}
if(writeBuffer == null) {
writeBuffer = getChannel().getBufferPool().allocate();
}
ByteBuffer buffer = writeBuffer.getBuffer();
int copied = Buffers.copy(buffer, src);
if(!buffer.hasRemaining()) {
handleBufferFull();
}
return copied;
}
public boolean send(PooledByteBuffer pooled) throws IOException {
if(isWritesShutdown()) {
throw UndertowMessages.MESSAGES.channelIsClosed();
}
boolean result = sendInternal(pooled);
if(result) {
flush();
}
return result;
}
protected boolean sendInternal(PooledByteBuffer pooled) throws IOException {
if (safeToSend()) {
this.body = pooled;
return true;
}
return false;
}
protected boolean safeToSend() throws IOException {
int state = this.state;
if (anyAreSet(state, STATE_CLOSED) || broken) {
throw UndertowMessages.MESSAGES.channelIsClosed();
}
if (readyForFlush) {
return false;
}
if( null != this.body) {
throw UndertowMessages.MESSAGES.bodyIsSetAndNotReadyForFlush();
}
return true;
}
@Override
public long writeFinal(ByteBuffer[] srcs, int offset, int length) throws IOException {
return Channels.writeFinalBasic(this, srcs, offset, length);
}
@Override
public long writeFinal(ByteBuffer[] srcs) throws IOException {
return writeFinal(srcs, 0, srcs.length);
}
@Override
public int writeFinal(ByteBuffer src) throws IOException {
return Channels.writeFinalBasic(this, src);
}
private void handleBufferFull() throws IOException {
synchronized (lock) {
bufferFull = true;
if (!readyForFlush) {
sendWriteBuffer();
readyForFlush = true;
state |= STATE_FIRST_DATA_WRITTEN;
channel.queueFrame((S) this);
}
}
}
private void sendWriteBuffer() throws IOException {
if(writeBuffer == null) {
writeBuffer = EMPTY_BYTE_BUFFER;
}
writeBuffer.getBuffer().flip();
if(!sendInternal(writeBuffer)) {
throw UndertowMessages.MESSAGES.failedToSendAfterBeingSafe();
}
writeBuffer = null;
}
protected abstract boolean isLastFrame();
public boolean isReadyForFlush() {
return readyForFlush;
}
public boolean isWritesShutdown() {
return anyAreSet(state, STATE_WRITES_SHUTDOWN);
}
@Override
public boolean isOpen() {
return allAreClear(state, STATE_CLOSED);
}
@Override
public void close() throws IOException {
if(fullyFlushed || anyAreSet(state, STATE_CLOSED)) {
return;
}
try {
synchronized (lock) {
state |= STATE_CLOSED;
}
if(writeBuffer != null) {
writeBuffer.close();
writeBuffer = null;
}
if(body != null) {
body.close();
body = null;
}
if (header != null && header.getByteBuffer() != null) {
header.getByteBuffer().close();
header = null;
}
channelForciblyClosed();
if (isWriteResumed()) {
ChannelListeners.invokeChannelListener(getIoThread(), this, (ChannelListener) getWriteListener());
}
wakeupWrites();
} finally {
wakeupWaiters();
}
}
protected void channelForciblyClosed() throws IOException {
if(isFirstDataWritten()) {
getChannel().markWritesBroken(null);
}
wakeupWaiters();
}
@Override
public boolean supportsOption(Option<?> option) {
return false;
}
@Override
public <T> T getOption(Option<T> tOption) throws IOException {
return null;
}
@Override
public <T> T setOption(Option<T> tOption, T t) throws IllegalArgumentException, IOException {
return null;
}
public ByteBuffer getBuffer() {
if(anyAreSet(state, STATE_CLOSED)) {
throw new IllegalStateException();
}
if(body == null) {
body = EMPTY_BYTE_BUFFER;
}
return body.getBuffer();
}
final void flushComplete() throws IOException {
synchronized (lock) {
try {
bufferFull = false;
int remaining = header.getRemainingInBuffer();
boolean finalFrame = finalFrameQueued;
boolean channelClosed = finalFrame && remaining == 0 && !header.isAnotherFrameRequired();
if (remaining > 0) {
body.getBuffer().limit(body.getBuffer().limit() + remaining);
if (finalFrame) {
this.finalFrameQueued = false;
}
} else if (header.isAnotherFrameRequired()) {
this.finalFrameQueued = false;
if (body != null) {
body.close();
body = null;
state &= ~STATE_PRE_WRITE_CALLED;
}
} else if (body != null) {
body.close();
body = null;
state &= ~STATE_PRE_WRITE_CALLED;
}
if (channelClosed) {
fullyFlushed = true;
if (body != null) {
body.close();
body = null;
state &= ~STATE_PRE_WRITE_CALLED;
}
} else if (body != null) {
body.getBuffer().compact();
writeBuffer = body;
body = null;
state &= ~STATE_PRE_WRITE_CALLED;
}
if (header.getByteBuffer() != null) {
header.getByteBuffer().close();
}
header = null;
readyForFlush = false;
if (isWriteResumed() && !channelClosed) {
wakeupWrites();
} else if (isWriteResumed()) {
ChannelListeners.invokeChannelListener(getIoThread(), (S) this, getWriteListener());
}
final ChannelListener<? super S> closeListener = this.closeSetter.get();
if (channelClosed && closeListener != null) {
ChannelListeners.invokeChannelListener(getIoThread(), (S) AbstractFramedStreamSinkChannel.this, closeListener);
}
handleFlushComplete(channelClosed);
} finally {
wakeupWaiters();
}
}
}
protected void handleFlushComplete(boolean finalFrame) {
}
protected boolean isFirstDataWritten() {
return anyAreSet(state, STATE_FIRST_DATA_WRITTEN);
}
public void markBroken() {
this.broken = true;
try {
wakeupWrites();
wakeupWaiters();
if (isWriteResumed()) {
ChannelListener<? super S> writeListener = this.writeSetter.get();
if (writeListener != null) {
ChannelListeners.invokeChannelListener(getIoThread(), (S) this, writeListener);
}
}
ChannelListener<? super S> closeListener = this.closeSetter.get();
if (closeListener != null) {
ChannelListeners.invokeChannelListener(getIoThread(), (S) this, closeListener);
}
} finally {
if(header != null) {
if( header.getByteBuffer() != null) {
header.getByteBuffer().close();
header = null;
}
}
if(body != null) {
body.close();
body = null;
}
if(writeBuffer != null) {
writeBuffer.close();
writeBuffer = null;
}
}
}
ChannelListener<? super S> getWriteListener() {
return writeSetter.get();
}
private void wakeupWaiters() {
if(waiterCount > 0) {
synchronized (lock) {
if (waiterCount > 0) {
lock.notifyAll();
}
}
}
}
public C getChannel() {
return channel;
}
public boolean isBroken() {
return broken;
}
public boolean isBufferFull() {
return bufferFull;
}
}