package io.undertow.websockets.core;
import io.undertow.util.ImmediatePooled;
import org.xnio.ChannelListener;
import io.undertow.connector.PooledByteBuffer;
import org.xnio.Pooled;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
/**
 * Accumulates the payload of a binary WebSocket message read from a
 * {@link StreamSourceFrameChannel} into pooled byte buffers.
 * <p>
 * The accumulation strategy is chosen by {@code bufferFullMessage}:
 * <ul>
 * <li>{@code true}: every buffer that fills up is flipped and parked in an internal list and a
 * fresh buffer is allocated, so data keeps being collected until a channel read returns -1.</li>
 * <li>{@code false}: a single buffer is filled; as soon as it has no space left, blocking reads
 * return and asynchronous reads notify the callback.</li>
 * </ul>
 * Collected bytes are handed out (and the internal buffer state reset) by {@link #getData()}.
 */
public class BufferedBinaryMessage {

    /** Whether all data is collected across multiple buffers before delivery (see class docs). */
    private final boolean bufferFullMessage;
    /** Upper bound on the total bytes read; a value {@code <= 0} disables the check. */
    private final long maxMessageSize;

    /** Buffers that have already been filled and flipped, in read order. */
    private List<PooledByteBuffer> data = new ArrayList<>(1);
    /** The buffer currently being filled, or {@code null} when none is allocated. */
    private PooledByteBuffer current;
    /** Running total of bytes read so far; note that {@link #getData()} does not reset it. */
    private long currentSize;
    /** Set once a read from the channel has returned -1. */
    private boolean complete;

    /**
     * Creates a message buffer with an optional size limit.
     *
     * @param maxMessageSize    maximum total number of bytes to accept; {@code <= 0} means unlimited
     * @param bufferFullMessage {@code true} to keep collecting into additional buffers,
     *                          {@code false} to work with one buffer at a time
     */
    public BufferedBinaryMessage(long maxMessageSize, boolean bufferFullMessage) {
        this.maxMessageSize = maxMessageSize;
        this.bufferFullMessage = bufferFullMessage;
    }

    /**
     * Creates a message buffer with no size limit.
     *
     * @param bufferFullMessage see {@link #BufferedBinaryMessage(long, boolean)}
     */
    public BufferedBinaryMessage(boolean bufferFullMessage) {
        this(-1, bufferFullMessage);
    }

    /**
     * Reads from {@code channel}, waiting via {@code awaitReadable()} whenever a read yields no
     * bytes.
     * <p>
     * Returns when a read returns -1 (after which {@link #isComplete()} is {@code true}) or, when
     * not buffering the full message, when the current buffer has no space left.
     *
     * @param channel the frame channel to read from
     * @throws IOException if reading fails or the configured maximum size is exceeded
     */
    public void readBlocking(StreamSourceFrameChannel channel) throws IOException {
        ensureCurrentBuffer(channel);
        while (true) {
            final int bytesRead = channel.read(current.getBuffer());
            if (bytesRead == -1) {
                complete = true;
                return;
            }
            if (bytesRead == 0) {
                // Nothing available yet: wait, then still run the bookkeeping below (a no-op size
                // update, plus the full-buffer check).
                channel.awaitReadable();
            }
            checkMaxSize(channel, bytesRead);
            if (bufferFullMessage) {
                dealWithFullBuffer(channel);
            } else if (!current.getBuffer().hasRemaining()) {
                return;
            }
        }
    }

    /**
     * Allocates a new current buffer from the channel's pool if there is no current buffer.
     */
    private void ensureCurrentBuffer(StreamSourceFrameChannel channel) {
        if (current == null) {
            current = channel.getWebSocketChannel().getBufferPool().allocate();
        }
    }

    /**
     * If the current buffer is full, flips it, parks it in {@link #data} and allocates a
     * replacement from the channel's pool; otherwise does nothing.
     */
    private void dealWithFullBuffer(StreamSourceFrameChannel channel) {
        if (current.getBuffer().hasRemaining()) {
            return;
        }
        final PooledByteBuffer filled = current;
        filled.getBuffer().flip();
        data.add(filled);
        current = channel.getWebSocketChannel().getBufferPool().allocate();
    }

    /**
     * Reads from {@code channel} without blocking, reporting progress through {@code callback}.
     * <p>
     * Data that is immediately available is consumed in a loop. When a read returns -1, the message
     * is marked complete and {@code callback.complete} is invoked.
     * <p>
     * When a read returns 0:
     * <ul>
     * <li>in single-buffer mode the callback is notified with whatever has been buffered so far;</li>
     * <li>a read listener is installed that continues reading when the channel becomes readable;</li>
     * <li>reads are resumed and this method returns.</li>
     * </ul>
     * In single-buffer mode the callback is also notified each time the buffer fills up, and reading
     * continues afterwards. NOTE(review): this presumably relies on the callback draining the buffer
     * through {@link #getData()}; confirm with callers.
     * <p>
     * Any {@link IOException}, including exceeding the size limit, is reported via
     * {@code callback.onError}.
     *
     * @param channel  the frame channel to read from
     * @param callback receives completion / partial-data / error notifications with this instance
     */
    public void read(final StreamSourceFrameChannel channel, final WebSocketCallback<BufferedBinaryMessage> callback) {
        try {
            while (true) {
                ensureCurrentBuffer(channel);
                final int bytesRead = channel.read(current.getBuffer());
                if (bytesRead == -1) {
                    complete = true;
                    callback.complete(channel.getWebSocketChannel(), this);
                    return;
                }
                if (bytesRead == 0) {
                    if (!bufferFullMessage) {
                        callback.complete(channel.getWebSocketChannel(), this);
                    }
                    // Continue asynchronously once the channel signals readability.
                    channel.getReadSetter().set(new ChannelListener<StreamSourceFrameChannel>() {
                        @Override
                        public void handleEvent(StreamSourceFrameChannel readable) {
                            readWhenReady(readable, callback);
                        }
                    });
                    channel.resumeReads();
                    return;
                }
                consumeReadBytes(channel, bytesRead, callback);
            }
        } catch (IOException e) {
            callback.onError(channel.getWebSocketChannel(), this, e);
        }
    }

    /**
     * Body of the read listener installed by {@link #read}: drains whatever {@code source} has
     * available.
     * <p>
     * Stops when a read returns 0 and waits for the next readiness event. When a read returns -1,
     * it completes the message, suspends reads and invokes the callback. On an
     * {@link IOException}, it suspends reads and reports the error. Does nothing once the message
     * is already complete.
     */
    private void readWhenReady(StreamSourceFrameChannel source, WebSocketCallback<BufferedBinaryMessage> callback) {
        if (complete) {
            return;
        }
        try {
            while (true) {
                ensureCurrentBuffer(source);
                final int bytesRead = source.read(current.getBuffer());
                if (bytesRead == 0) {
                    return;
                }
                if (bytesRead == -1) {
                    complete = true;
                    source.suspendReads();
                    callback.complete(source.getWebSocketChannel(), this);
                    return;
                }
                consumeReadBytes(source, bytesRead, callback);
            }
        } catch (IOException e) {
            source.suspendReads();
            callback.onError(source.getWebSocketChannel(), this, e);
        }
    }

    /**
     * Post-read bookkeeping shared by the synchronous and listener-driven read loops.
     * <p>
     * Enforces the size limit, then either rolls over a full buffer (full-message mode) or notifies
     * the callback that the single buffer is full.
     */
    private void consumeReadBytes(StreamSourceFrameChannel source, int bytesRead,
                                  WebSocketCallback<BufferedBinaryMessage> callback) throws IOException {
        checkMaxSize(source, bytesRead);
        if (bufferFullMessage) {
            dealWithFullBuffer(source);
        } else if (!current.getBuffer().hasRemaining()) {
            callback.complete(source.getWebSocketChannel(), this);
        }
    }

    /**
     * Adds {@code res} to the running byte count and enforces {@link #maxMessageSize}.
     * <p>
     * When the limit is positive and exceeded, this method:
     * <ol>
     * <li>releases all buffered data;</li>
     * <li>sends a {@code MSG_TOO_BIG} close frame;</li>
     * <li>throws.</li>
     * </ol>
     *
     * @throws IOException if the limit has been exceeded
     */
    private void checkMaxSize(StreamSourceFrameChannel channel, int res) throws IOException {
        currentSize += res;
        if (maxMessageSize <= 0 || currentSize <= maxMessageSize) {
            return;
        }
        getData().free();
        WebSockets.sendClose(
                new CloseMessage(CloseMessage.MSG_TOO_BIG, WebSocketMessages.MESSAGES.messageToBig(maxMessageSize)),
                channel.getWebSocketChannel(),
                null);
        throw new IOException(WebSocketMessages.MESSAGES.messageToBig(maxMessageSize));
    }

    /**
     * Hands out everything buffered so far as an array of flipped (read-ready) buffers.
     * <p>
     * Afterwards there is no current buffer and the parked-buffer list is empty, so a later read
     * allocates a fresh buffer. The returned {@link Pooled} closes the underlying pooled buffers
     * when freed, discarded or closed. If nothing has been allocated, an empty array is returned.
     *
     * @return the buffered payload; the caller is responsible for freeing it
     */
    public Pooled<ByteBuffer[]> getData() {
        if (current == null) {
            return new ImmediatePooled<>(new ByteBuffer[0]);
        }
        final PooledByteBuffer last = current;
        last.getBuffer().flip();
        current = null;
        if (data.isEmpty()) {
            // Common single-buffer case: avoid touching the accumulated list at all.
            return new PooledByteBufferArray(Collections.singletonList(last), new ByteBuffer[]{last.getBuffer()});
        }
        data.add(last);
        final ByteBuffer[] buffers = new ByteBuffer[data.size()];
        int index = 0;
        for (PooledByteBuffer pooledBuffer : data) {
            buffers[index++] = pooledBuffer.getBuffer();
        }
        final List<PooledByteBuffer> collected = data;
        data = new ArrayList<>();
        return new PooledByteBufferArray(collected, buffers);
    }

    /**
     * @return {@code true} once a read from the channel has returned -1
     */
    public boolean isComplete() {
        return complete;
    }

    /**
     * {@link Pooled} view over several pooled buffers.
     * <p>
     * Releasing it in any way ({@code free}, {@code discard} or {@code close}) closes every
     * underlying {@link PooledByteBuffer}.
     */
    private static final class PooledByteBufferArray implements Pooled<ByteBuffer[]> {

        private final List<PooledByteBuffer> pooled;
        private final ByteBuffer[] data;

        private PooledByteBufferArray(List<PooledByteBuffer> pooled, ByteBuffer[] data) {
            this.pooled = pooled;
            this.data = data;
        }

        /** Closes every underlying pooled buffer. */
        private void closeAll() {
            for (PooledByteBuffer buffer : pooled) {
                buffer.close();
            }
        }

        @Override
        public void discard() {
            closeAll();
        }

        @Override
        public void free() {
            closeAll();
        }

        @Override
        public ByteBuffer[] getResource() throws IllegalStateException {
            return data;
        }

        @Override
        public void close() {
            closeAll();
        }
    }
}