package io.undertow.websockets.core.protocol.version07;
import io.undertow.server.protocol.framed.AbstractFramedStreamSourceChannel;
import io.undertow.websockets.core.StreamSinkFrameChannel;
import io.undertow.websockets.core.StreamSourceFrameChannel;
import io.undertow.websockets.core.WebSocketChannel;
import io.undertow.websockets.core.WebSocketException;
import io.undertow.websockets.core.WebSocketFrame;
import io.undertow.websockets.core.WebSocketFrameCorruptedException;
import io.undertow.websockets.core.WebSocketFrameType;
import io.undertow.websockets.core.WebSocketLogger;
import io.undertow.websockets.core.WebSocketMessages;
import io.undertow.websockets.core.WebSocketVersion;
import io.undertow.websockets.extensions.ExtensionFunction;
import org.xnio.IoUtils;
import org.xnio.OptionMap;
import io.undertow.connector.ByteBufferPool;
import io.undertow.connector.PooledByteBuffer;
import org.xnio.StreamConnection;
import java.nio.ByteBuffer;
import java.util.Set;
public class WebSocket07Channel extends WebSocketChannel {
private enum State {
READING_FIRST,
READING_SECOND,
READING_EXTENDED_SIZE1,
READING_EXTENDED_SIZE2,
READING_EXTENDED_SIZE3,
READING_EXTENDED_SIZE4,
READING_EXTENDED_SIZE5,
READING_EXTENDED_SIZE6,
READING_EXTENDED_SIZE7,
READING_EXTENDED_SIZE8,
READING_MASK_1,
READING_MASK_2,
READING_MASK_3,
READING_MASK_4,
DONE,
}
private int fragmentedFramesCount;
private final ByteBuffer lengthBuffer = ByteBuffer.allocate(8);
private UTF8Checker checker;
protected static final byte OPCODE_CONT = 0x0;
protected static final byte OPCODE_TEXT = 0x1;
protected static final byte OPCODE_BINARY = 0x2;
protected static final byte OPCODE_CLOSE = 0x8;
protected static final byte OPCODE_PING = 0x9;
protected static final byte OPCODE_PONG = 0xA;
public WebSocket07Channel(StreamConnection channel, ByteBufferPool bufferPool,
String wsUrl, String subProtocol, final boolean client, boolean allowExtensions, final ExtensionFunction extensionFunction, Set<WebSocketChannel> openConnections, OptionMap options) {
super(channel, bufferPool, WebSocketVersion.V08, wsUrl, subProtocol, client, allowExtensions, extensionFunction, openConnections, options);
}
@Override
protected PartialFrame receiveFrame() {
return new WebSocketFrameHeader();
}
@Override
protected void markReadsBroken(Throwable cause) {
super.markReadsBroken(cause);
}
@Override
protected void closeSubChannels() {
IoUtils.safeClose(fragmentedChannel);
}
@Override
protected StreamSinkFrameChannel createStreamSinkChannel(WebSocketFrameType type) {
switch (type) {
case TEXT:
return new WebSocket07TextFrameSinkChannel(this);
case BINARY:
return new WebSocket07BinaryFrameSinkChannel(this);
case CLOSE:
return new WebSocket07CloseFrameSinkChannel(this);
case PONG:
return new WebSocket07PongFrameSinkChannel(this);
case PING:
return new WebSocket07PingFrameSinkChannel(this);
default:
throw WebSocketMessages.MESSAGES.unsupportedFrameType(type);
}
}
class implements WebSocketFrame {
private boolean ;
private int ;
private int ;
private int ;
private boolean ;
private long ;
private State = State.READING_FIRST;
private int ;
private boolean = false;
@Override
public StreamSourceFrameChannel (PooledByteBuffer pooled) {
StreamSourceFrameChannel channel = createChannel(pooled);
if (frameFinalFlag) {
channel.finalFrame();
} else {
fragmentedChannel = channel;
}
return channel;
}
public StreamSourceFrameChannel (PooledByteBuffer pooled) {
if (frameOpcode == OPCODE_PING) {
if (frameMasked) {
return new WebSocket07PingFrameSourceChannel(WebSocket07Channel.this, frameRsv, new Masker(maskingKey), pooled, framePayloadLength);
} else {
return new WebSocket07PingFrameSourceChannel(WebSocket07Channel.this, frameRsv, pooled, framePayloadLength);
}
}
if (frameOpcode == OPCODE_PONG) {
if (frameMasked) {
return new WebSocket07PongFrameSourceChannel(WebSocket07Channel.this, frameRsv, new Masker(maskingKey), pooled, framePayloadLength);
} else {
return new WebSocket07PongFrameSourceChannel(WebSocket07Channel.this, frameRsv, pooled, framePayloadLength);
}
}
if (frameOpcode == OPCODE_CLOSE) {
if (frameMasked) {
return new WebSocket07CloseFrameSourceChannel(WebSocket07Channel.this, frameRsv, new Masker(maskingKey), pooled, framePayloadLength);
} else {
return new WebSocket07CloseFrameSourceChannel(WebSocket07Channel.this, frameRsv, pooled, framePayloadLength);
}
}
if (frameOpcode == OPCODE_TEXT) {
UTF8Checker checker = WebSocket07Channel.this.checker;
if (checker == null) {
checker = new UTF8Checker();
}
if (!frameFinalFlag) {
WebSocket07Channel.this.checker = checker;
} else {
WebSocket07Channel.this.checker = null;
}
if (frameMasked) {
return new WebSocket07TextFrameSourceChannel(WebSocket07Channel.this, frameRsv, frameFinalFlag, new Masker(maskingKey), checker, pooled, framePayloadLength);
} else {
return new WebSocket07TextFrameSourceChannel(WebSocket07Channel.this, frameRsv, frameFinalFlag, checker, pooled, framePayloadLength);
}
} else if (frameOpcode == OPCODE_BINARY) {
if (frameMasked) {
return new WebSocket07BinaryFrameSourceChannel(WebSocket07Channel.this, frameRsv, frameFinalFlag, new Masker(maskingKey), pooled, framePayloadLength);
} else {
return new WebSocket07BinaryFrameSourceChannel(WebSocket07Channel.this, frameRsv, frameFinalFlag, pooled, framePayloadLength);
}
} else if (frameOpcode == OPCODE_CONT) {
throw new RuntimeException();
} else {
if (hasReservedOpCode) {
if (frameMasked) {
return new WebSocket07BinaryFrameSourceChannel(WebSocket07Channel.this, frameRsv, frameFinalFlag, new Masker(maskingKey), pooled, framePayloadLength);
} else {
return new WebSocket07BinaryFrameSourceChannel(WebSocket07Channel.this, frameRsv, frameFinalFlag, pooled, framePayloadLength);
}
} else {
throw WebSocketMessages.MESSAGES.unsupportedOpCode(frameOpcode);
}
}
}
@Override
public void handle(final ByteBuffer buffer) throws WebSocketException {
if (!buffer.hasRemaining()) {
return;
}
while (state != State.DONE) {
byte b;
switch (state) {
case READING_FIRST:
b = buffer.get();
frameFinalFlag = (b & 0x80) != 0;
frameRsv = (b & 0x70) >> 4;
frameOpcode = b & 0x0F;
if (WebSocketLogger.REQUEST_LOGGER.isDebugEnabled()) {
WebSocketLogger.REQUEST_LOGGER.decodingFrameWithOpCode(frameOpcode);
}
state = State.READING_SECOND;
lengthBuffer.clear();
case READING_SECOND:
if (!buffer.hasRemaining()) {
return;
}
b = buffer.get();
frameMasked = (b & 0x80) != 0;
framePayloadLen1 = b & 0x7F;
if (frameRsv != 0) {
if (!areExtensionsSupported()) {
throw WebSocketMessages.MESSAGES.extensionsNotAllowed(frameRsv);
}
}
if (frameOpcode > 7) {
validateControlFrame();
} else {
validateDataFrame();
}
if (framePayloadLen1 == 126 || framePayloadLen1 == 127) {
state = State.READING_EXTENDED_SIZE1;
} else {
framePayloadLength = framePayloadLen1;
if (frameMasked) {
state = State.READING_MASK_1;
} else {
state = State.DONE;
}
continue;
}
case READING_EXTENDED_SIZE1:
if (!buffer.hasRemaining()) {
return;
}
b = buffer.get();
lengthBuffer.put(b);
state = State.READING_EXTENDED_SIZE2;
case READING_EXTENDED_SIZE2:
if (!buffer.hasRemaining()) {
return;
}
b = buffer.get();
lengthBuffer.put(b);
if (framePayloadLen1 == 126) {
lengthBuffer.flip();
framePayloadLength = lengthBuffer.getShort() & 0xFFFF;
if (frameMasked) {
state = State.READING_MASK_1;
} else {
state = State.DONE;
}
continue;
}
state = State.READING_EXTENDED_SIZE3;
case READING_EXTENDED_SIZE3:
if (!buffer.hasRemaining()) {
return;
}
b = buffer.get();
lengthBuffer.put(b);
state = State.READING_EXTENDED_SIZE4;
case READING_EXTENDED_SIZE4:
if (!buffer.hasRemaining()) {
return;
}
b = buffer.get();
lengthBuffer.put(b);
state = State.READING_EXTENDED_SIZE5;
case READING_EXTENDED_SIZE5:
if (!buffer.hasRemaining()) {
return;
}
b = buffer.get();
lengthBuffer.put(b);
state = State.READING_EXTENDED_SIZE6;
case READING_EXTENDED_SIZE6:
if (!buffer.hasRemaining()) {
return;
}
b = buffer.get();
lengthBuffer.put(b);
state = State.READING_EXTENDED_SIZE7;
case READING_EXTENDED_SIZE7:
if (!buffer.hasRemaining()) {
return;
}
b = buffer.get();
lengthBuffer.put(b);
state = State.READING_EXTENDED_SIZE8;
case READING_EXTENDED_SIZE8:
if (!buffer.hasRemaining()) {
return;
}
b = buffer.get();
lengthBuffer.put(b);
lengthBuffer.flip();
framePayloadLength = lengthBuffer.getLong();
if (frameMasked) {
state = State.READING_MASK_1;
} else {
state = State.DONE;
break;
}
state = State.READING_MASK_1;
case READING_MASK_1:
if (!buffer.hasRemaining()) {
return;
}
b = buffer.get();
maskingKey = b & 0xFF;
state = State.READING_MASK_2;
case READING_MASK_2:
if (!buffer.hasRemaining()) {
return;
}
b = buffer.get();
maskingKey = maskingKey << 8 | b & 0xFF;
state = State.READING_MASK_3;
case READING_MASK_3:
if (!buffer.hasRemaining()) {
return;
}
b = buffer.get();
maskingKey = maskingKey << 8 | b & 0xFF;
state = State.READING_MASK_4;
case READING_MASK_4:
if (!buffer.hasRemaining()) {
return;
}
b = buffer.get();
maskingKey = maskingKey << 8 | b & 0xFF;
state = State.DONE;
break;
default:
throw new IllegalStateException(state.toString());
}
}
if (frameFinalFlag) {
if (frameOpcode != OPCODE_PING && frameOpcode != OPCODE_PONG) {
fragmentedFramesCount = 0;
}
} else {
fragmentedFramesCount++;
}
done = true;
}
private void () throws WebSocketFrameCorruptedException {
if (!isClient() && !frameMasked) {
throw WebSocketMessages.MESSAGES.frameNotMasked();
}
if (!(frameOpcode == OPCODE_CONT || frameOpcode == OPCODE_TEXT || frameOpcode == OPCODE_BINARY)) {
throw WebSocketMessages.MESSAGES.reservedOpCodeInDataFrame(frameOpcode);
}
if (fragmentedFramesCount == 0 && frameOpcode == OPCODE_CONT) {
throw WebSocketMessages.MESSAGES.continuationFrameOutsideFragmented();
}
if (fragmentedFramesCount != 0 && frameOpcode != OPCODE_CONT) {
throw WebSocketMessages.MESSAGES.nonContinuationFrameInsideFragmented();
}
}
private void () throws WebSocketFrameCorruptedException {
if (!frameFinalFlag) {
throw WebSocketMessages.MESSAGES.fragmentedControlFrame();
}
if (framePayloadLen1 > 125) {
throw WebSocketMessages.MESSAGES.toBigControlFrame();
}
if (!(frameOpcode == OPCODE_CLOSE || frameOpcode == OPCODE_PING || frameOpcode == OPCODE_PONG)) {
throw WebSocketMessages.MESSAGES.reservedOpCodeInControlFrame(frameOpcode);
}
if (frameOpcode == 8 && framePayloadLen1 == 1) {
throw WebSocketMessages.MESSAGES.controlFrameWithPayloadLen1();
}
}
@Override
public boolean () {
return done;
}
@Override
public long () {
return framePayloadLength;
}
int () {
return maskingKey;
}
@Override
public AbstractFramedStreamSourceChannel<?, ?, ?> () {
if (frameOpcode == OPCODE_CONT) {
StreamSourceFrameChannel ret = fragmentedChannel;
if(frameFinalFlag) {
fragmentedChannel = null;
}
return ret;
}
return null;
}
@Override
public boolean () {
return frameFinalFlag;
}
}
}