package io.undertow.conduits;
import io.undertow.UndertowMessages;
import org.xnio.Buffers;
import org.xnio.IoUtils;
import io.undertow.connector.PooledByteBuffer;
import org.xnio.channels.StreamSourceChannel;
import org.xnio.conduits.AbstractStreamSinkConduit;
import org.xnio.conduits.ConduitWritableByteChannel;
import org.xnio.conduits.Conduits;
import org.xnio.conduits.StreamSinkConduit;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.channels.FileChannel;
import java.util.ArrayDeque;
import java.util.Deque;
import static org.xnio.Bits.allAreClear;
import static org.xnio.Bits.anyAreSet;
/**
 * Base class for sink conduits that interleave queued protocol "frames" with the data written by
 * the caller.
 * <p>
 * Subclasses queue frames with {@link #queueFrame(FrameCallBack, ByteBuffer...)}. Every subsequent
 * write performs a single gathering write to the delegate containing all queued frame buffers, in
 * queue order, followed by the caller's buffers. A frame's {@link FrameCallBack} is told
 * {@link FrameCallBack#done() done} once all of its bytes have been accepted by the delegate, or
 * {@link FrameCallBack#failed(IOException) failed} if the delegate write throws or writes are
 * truncated.
 * <p>
 * NOTE(review): this class contains no synchronization; presumably each instance is confined to a
 * single I/O thread - confirm against callers.
 */
public class AbstractFramedStreamSinkConduit extends AbstractStreamSinkConduit<StreamSinkConduit> {

    /** Frames whose bytes have not yet been fully accepted by the delegate, oldest first. */
    private final Deque<Frame> frameQueue = new ArrayDeque<>();
    /** Number of frame bytes still pending across all queued frames. */
    private long queuedData = 0;
    /** Number of buffers across all queued frames; used to size the gathering-write array. */
    private int bufferCount = 0;
    /** Bit set built from the {@code FLAG_*} constants below. */
    private int state;

    /** Set once {@link #terminateWrites()} has been called; further writes are rejected. */
    private static final int FLAG_WRITES_TERMINATED = 1;
    /** Set once {@link #doTerminateWrites()} has been invoked on the delegate. */
    private static final int FLAG_DELEGATE_SHUTDOWN = 2;

    protected AbstractFramedStreamSinkConduit(StreamSinkConduit next) {
        super(next);
    }

    /**
     * Queues a frame to be written ahead of any subsequently written caller data.
     *
     * @param callback notified when the frame has been fully written or has failed; may be {@code null}
     * @param data     the frame's buffers, written in order
     */
    protected void queueFrame(FrameCallBack callback, ByteBuffer... data) {
        queuedData += Buffers.remaining(data);
        bufferCount += data.length;
        frameQueue.add(new Frame(callback, data, 0, data.length));
    }

    /** Transfers from a file by routing the bytes through this conduit's own write methods. */
    public long transferFrom(final FileChannel src, final long position, final long count) throws IOException {
        return src.transferTo(position, count, new ConduitWritableByteChannel(this));
    }

    /** Transfers from a stream channel by routing the bytes through this conduit's own write methods. */
    public long transferFrom(final StreamSourceChannel source, final long count, final ByteBuffer throughBuffer) throws IOException {
        return IoUtils.transfer(source, count, throughBuffer, new ConduitWritableByteChannel(this));
    }

    @Override
    public int write(ByteBuffer src) throws IOException {
        if (anyAreSet(state, FLAG_WRITES_TERMINATED)) {
            throw UndertowMessages.MESSAGES.channelIsClosed();
        }
        // A single ByteBuffer cannot hold more than Integer.MAX_VALUE bytes, so the cast is safe.
        return (int) doWrite(new ByteBuffer[]{src}, 0, 1);
    }

    @Override
    public long write(ByteBuffer[] srcs, int offs, int len) throws IOException {
        if (anyAreSet(state, FLAG_WRITES_TERMINATED)) {
            throw UndertowMessages.MESSAGES.channelIsClosed();
        }
        return doWrite(srcs, offs, len);
    }

    @Override
    public int writeFinal(ByteBuffer src) throws IOException {
        return Conduits.writeFinalBasic(this, src);
    }

    @Override
    public long writeFinal(ByteBuffer[] srcs, int offs, int len) throws IOException {
        return Conduits.writeFinalBasic(this, srcs, offs, len);
    }

    /**
     * Writes all queued frame buffers followed by {@code additionalData[offs, offs+len)} to the
     * delegate in one gathering write.
     * <p>
     * If the delegate write (or a done callback) throws, every still-queued frame is failed, the
     * queue is reset, and the original exception is rethrown. Exceptions thrown by failure callbacks
     * are attached to it as suppressed exceptions.
     *
     * @param additionalData caller data to append after the frames, or {@code null} for none
     * @param offs           offset of the first caller buffer
     * @param len            number of caller buffers
     * @return the number of caller bytes written. This is 0 while any queued frame is still only
     *         partially written, because frame bytes always precede caller bytes.
     * @throws IOException if the delegate write fails
     */
    private long doWrite(ByteBuffer[] additionalData, int offs, int len) throws IOException {
        final int callerBuffers = additionalData == null ? 0 : len;
        final ByteBuffer[] buffers = new ByteBuffer[bufferCount + callerBuffers];
        int count = 0;
        for (Frame frame : frameQueue) {
            System.arraycopy(frame.data, frame.offs, buffers, count, frame.len);
            count += frame.len;
        }
        if (additionalData != null) {
            System.arraycopy(additionalData, offs, buffers, count, len);
        }
        try {
            final long written = next.write(buffers, 0, buffers.length);
            // Frame bytes come first in the gathering array, so they are consumed before caller bytes.
            queuedData = written > queuedData ? 0 : queuedData - written;
            return retireWrittenFrames(written);
        } catch (IOException | RuntimeException | Error e) {
            final IOException cause = e instanceof IOException ? (IOException) e : new IOException(e);
            try {
                failQueuedFrames(cause);
            } catch (RuntimeException | Error callbackFailure) {
                // The write failure is what the caller must see; keep the callback failure attached
                // instead of silently discarding it.
                if (callbackFailure != e) {
                    e.addSuppressed(callbackFailure);
                }
            }
            throw e;
        }
    }

    /**
     * Removes every queued frame fully covered by {@code written} bytes and invokes its done callback.
     *
     * @param written the number of bytes the delegate accepted
     * @return bytes of {@code written} left over after all frames were retired (the caller-data
     *         portion), or 0 if a frame remains partially written
     */
    private long retireWrittenFrames(long written) {
        long unallocated = written;
        Frame frame = frameQueue.peek();
        while (frame != null) {
            if (frame.remaining > unallocated) {
                frame.remaining -= unallocated;
                return 0;
            }
            frameQueue.poll();
            // Update bookkeeping before the callback, so a re-entrant write sees consistent counters.
            bufferCount -= frame.len;
            unallocated -= frame.remaining;
            final FrameCallBack cb = frame.callback;
            if (cb != null) {
                cb.done();
            }
            frame = frameQueue.peek();
        }
        return unallocated;
    }

    /**
     * Notifies every queued frame's callback of {@code cause}, then clears the queue and resets the
     * pending counters. The reset happens even if a callback throws.
     * <p>
     * All callbacks are notified even if some throw a {@link RuntimeException}. The first such
     * exception is rethrown after the reset, with any later ones attached as suppressed.
     *
     * @param cause the failure reported to each callback
     */
    private void failQueuedFrames(final IOException cause) {
        RuntimeException callbackFailure = null;
        try {
            for (Frame frame : frameQueue) {
                final FrameCallBack cb = frame.callback;
                if (cb == null) {
                    continue;
                }
                try {
                    cb.failed(cause);
                } catch (RuntimeException ex) {
                    if (callbackFailure == null) {
                        callbackFailure = ex;
                    } else if (callbackFailure != ex) {
                        callbackFailure.addSuppressed(ex);
                    }
                }
            }
        } finally {
            // Drop references to frames whose buffers may already have been released by their callbacks.
            frameQueue.clear();
            bufferCount = 0;
            queuedData = 0;
        }
        if (callbackFailure != null) {
            throw callbackFailure;
        }
    }

    /** @return the number of queued frame bytes not yet accepted by the delegate */
    protected long queuedDataLength() {
        return queuedData;
    }

    /**
     * Marks writes as terminated after giving subclasses a chance to queue closing frames.
     * <p>
     * If nothing is pending, the delegate is shut down immediately and {@link #finished()} is
     * invoked. Otherwise both happen in a later successful {@link #flushQueuedData()}.
     * Repeated calls are ignored.
     */
    @Override
    public void terminateWrites() throws IOException {
        if (anyAreSet(state, FLAG_WRITES_TERMINATED)) {
            return;
        }
        queueCloseFrames();
        state |= FLAG_WRITES_TERMINATED;
        if (queuedData == 0) {
            state |= FLAG_DELEGATE_SHUTDOWN;
            doTerminateWrites();
            finished();
        }
    }

    /** Shuts down the delegate's writes; subclasses may override to customize shutdown. */
    protected void doTerminateWrites() throws IOException {
        next.terminateWrites();
    }

    /**
     * Attempts to write out all queued frame data. Once writes have been terminated and no data
     * remains, it also shuts down the delegate (exactly once) and then flushes the delegate.
     *
     * @return {@code false} if queued data remains, otherwise the result of the delegate's flush
     * @throws IOException if writing, terminating or flushing the delegate fails
     */
    protected boolean flushQueuedData() throws IOException {
        if (queuedData > 0) {
            doWrite(null, 0, 0);
        }
        if (queuedData > 0) {
            return false;
        }
        if (anyAreSet(state, FLAG_WRITES_TERMINATED) && allAreClear(state, FLAG_DELEGATE_SHUTDOWN)) {
            doTerminateWrites();
            state |= FLAG_DELEGATE_SHUTDOWN;
            finished();
        }
        return next.flush();
    }

    /**
     * Discards all queued frames, then truncates the delegate.
     * <p>
     * Each queued frame's callback is failed with a "channel is closed" exception, and the queue is
     * cleared so released buffers are never written or released again. The delegate is truncated
     * even if a callback throws.
     */
    @Override
    public void truncateWrites() throws IOException {
        try {
            final IOException closed = UndertowMessages.MESSAGES.channelIsClosed();
            failQueuedFrames(closed);
        } finally {
            next.truncateWrites();
        }
    }

    /** @return {@code true} once {@link #terminateWrites()} has been called */
    protected boolean isWritesTerminated() {
        return anyAreSet(state, FLAG_WRITES_TERMINATED);
    }

    /**
     * Hook invoked once by {@link #terminateWrites()} before writes are marked terminated. Subclasses
     * may queue trailing frames here. The default does nothing.
     */
    protected void queueCloseFrames() {
    }

    /** Hook invoked once the delegate's writes have been terminated. The default does nothing. */
    protected void finished() {
    }

    /** Completion callback for a queued frame. */
    public interface FrameCallBack {

        /** Called once every byte of the frame has been accepted by the delegate. */
        void done();

        /**
         * Called if the frame could not be written, either because the delegate write failed or
         * because writes were truncated.
         *
         * @param e the failure cause
         */
        void failed(final IOException e);
    }

    /** A queued frame: the buffer range {@code data[offs, offs+len)} plus its completion callback. */
    private static class Frame {
        final FrameCallBack callback;
        final ByteBuffer[] data;
        final int offs;
        final int len;
        /** Bytes of this frame not yet accepted by the delegate. */
        long remaining;

        private Frame(FrameCallBack callback, ByteBuffer[] data, int offs, int len) {
            this.callback = callback;
            this.data = data;
            this.offs = offs;
            this.len = len;
            this.remaining = Buffers.remaining(data, offs, len);
        }
    }

    /** Frame callback that closes a single pooled buffer on both success and failure. */
    protected static class PooledBufferFrameCallback implements FrameCallBack {

        private final PooledByteBuffer buffer;

        public PooledBufferFrameCallback(PooledByteBuffer buffer) {
            this.buffer = buffer;
        }

        @Override
        public void done() {
            buffer.close();
        }

        @Override
        public void failed(IOException e) {
            buffer.close();
        }
    }

    /** Frame callback that closes several pooled buffers on both success and failure. */
    protected static class PooledBuffersFrameCallback implements FrameCallBack {

        private final PooledByteBuffer[] buffers;

        public PooledBuffersFrameCallback(PooledByteBuffer... buffers) {
            this.buffers = buffers;
        }

        @Override
        public void done() {
            for (PooledByteBuffer buffer : buffers) {
                buffer.close();
            }
        }

        @Override
        public void failed(IOException e) {
            done();
        }
    }
}