package jdk.internal.net.http.websocket;
import jdk.internal.net.http.common.Logger;
import jdk.internal.net.http.common.Utils;
import jdk.internal.net.http.websocket.Frame.Opcode;
import java.nio.ByteBuffer;
import java.nio.CharBuffer;
import java.nio.charset.CharacterCodingException;
import static java.lang.String.format;
import static java.nio.charset.StandardCharsets.UTF_8;
import static java.util.Objects.requireNonNull;
import static jdk.internal.net.http.common.Utils.dump;
import static jdk.internal.net.http.websocket.StatusCodes.NO_STATUS_CODE;
import static jdk.internal.net.http.websocket.StatusCodes.isLegalToReceiveFromServer;
class MessageDecoder implements Frame.Consumer {
private static final Logger debug =
Utils.getWebSocketLogger("[Input]"::toString, Utils.DEBUG_WS);
private final MessageStreamConsumer output;
private final UTF8AccumulatingDecoder decoder = new UTF8AccumulatingDecoder();
private boolean fin;
private Opcode opcode, originatingOpcode;
private long payloadLen;
private long unconsumedPayloadLen;
private ByteBuffer binaryData;
MessageDecoder(MessageStreamConsumer output) {
this.output = requireNonNull(output);
}
MessageStreamConsumer getOutput() {
return output;
}
@Override
public void fin(boolean value) {
if (debug.on()) {
debug.log("fin %s", value);
}
fin = value;
}
@Override
public void rsv1(boolean value) {
if (debug.on()) {
debug.log("rsv1 %s", value);
}
if (value) {
throw new FailWebSocketException("Unexpected rsv1 bit");
}
}
@Override
public void rsv2(boolean value) {
if (debug.on()) {
debug.log("rsv2 %s", value);
}
if (value) {
throw new FailWebSocketException("Unexpected rsv2 bit");
}
}
@Override
public void rsv3(boolean value) {
if (debug.on()) {
debug.log("rsv3 %s", value);
}
if (value) {
throw new FailWebSocketException("Unexpected rsv3 bit");
}
}
@Override
public void opcode(Opcode v) {
if (debug.on()) {
debug.log("opcode %s", v);
}
if (v == Opcode.PING || v == Opcode.PONG || v == Opcode.CLOSE) {
if (!fin) {
throw new FailWebSocketException("Fragmented control frame " + v);
}
opcode = v;
} else if (v == Opcode.TEXT || v == Opcode.BINARY) {
if (originatingOpcode != null) {
throw new FailWebSocketException(
format("Unexpected frame %s (fin=%s)", v, fin));
}
opcode = v;
if (!fin) {
originatingOpcode = v;
}
} else if (v == Opcode.CONTINUATION) {
if (originatingOpcode == null) {
throw new FailWebSocketException(
format("Unexpected frame %s (fin=%s)", v, fin));
}
opcode = v;
} else {
throw new FailWebSocketException("Unexpected opcode " + v);
}
}
@Override
public void mask(boolean value) {
if (debug.on()) {
debug.log("mask %s", value);
}
if (value) {
throw new FailWebSocketException("Masked frame received");
}
}
@Override
public void payloadLen(long value) {
if (debug.on()) {
debug.log("payloadLen %s", value);
}
if (opcode.isControl()) {
if (value > Frame.MAX_CONTROL_FRAME_PAYLOAD_LENGTH) {
throw new FailWebSocketException(
format("%s's payload length %s", opcode, value));
}
assert Opcode.CLOSE.isControl();
if (opcode == Opcode.CLOSE && value == 1) {
throw new FailWebSocketException("Incomplete status code");
}
}
payloadLen = value;
unconsumedPayloadLen = value;
}
@Override
public void maskingKey(int value) {
throw new InternalError();
}
@Override
public void payloadData(ByteBuffer data) {
if (debug.on()) {
debug.log("payload %s", data);
}
unconsumedPayloadLen -= data.remaining();
boolean lastPayloadChunk = unconsumedPayloadLen == 0;
if (opcode.isControl()) {
if (binaryData != null) {
binaryData.put(data);
} else if (!lastPayloadChunk) {
int remaining = data.remaining();
assert remaining < Frame.MAX_CONTROL_FRAME_PAYLOAD_LENGTH
: dump(remaining);
binaryData = ByteBuffer.allocate(
Frame.MAX_CONTROL_FRAME_PAYLOAD_LENGTH).put(data);
} else {
binaryData = ByteBuffer.allocate(data.remaining()).put(data);
}
} else {
boolean last = fin && lastPayloadChunk;
boolean text = opcode == Opcode.TEXT || originatingOpcode == Opcode.TEXT;
if (!text) {
output.onBinary(data.slice(), last);
data.position(data.limit());
} else {
boolean binaryNonEmpty = data.hasRemaining();
CharBuffer textData;
try {
textData = decoder.decode(data, last);
} catch (CharacterCodingException e) {
throw new FailWebSocketException(
"Invalid UTF-8 in frame " + opcode,
StatusCodes.NOT_CONSISTENT).initCause(e);
}
if (!(binaryNonEmpty && !textData.hasRemaining())) {
output.onText(textData, last);
}
}
}
}
@Override
public void endFrame() {
if (debug.on()) {
debug.log("end frame");
}
if (opcode.isControl()) {
binaryData.flip();
}
switch (opcode) {
case CLOSE:
char statusCode = NO_STATUS_CODE;
String reason = "";
if (payloadLen != 0) {
int len = binaryData.remaining();
assert 2 <= len
&& len <= Frame.MAX_CONTROL_FRAME_PAYLOAD_LENGTH
: dump(len, payloadLen);
statusCode = binaryData.getChar();
if (!isLegalToReceiveFromServer(statusCode)) {
throw new FailWebSocketException(
"Illegal status code: " + statusCode);
}
try {
reason = UTF_8.newDecoder().decode(binaryData).toString();
} catch (CharacterCodingException e) {
throw new FailWebSocketException("Illegal close reason")
.initCause(e);
}
}
output.onClose(statusCode, reason);
break;
case PING:
output.onPing(binaryData);
binaryData = null;
break;
case PONG:
output.onPong(binaryData);
binaryData = null;
break;
default:
assert opcode == Opcode.TEXT || opcode == Opcode.BINARY
|| opcode == Opcode.CONTINUATION : dump(opcode);
if (fin) {
originatingOpcode = null;
}
break;
}
payloadLen = 0;
opcode = null;
}
}