package com.datastax.oss.driver.internal.core.protocol;
import com.datastax.oss.driver.api.core.connection.FrameTooLongException;
import com.datastax.oss.driver.internal.core.util.Loggers;
import com.datastax.oss.protocol.internal.Frame;
import com.datastax.oss.protocol.internal.FrameCodec;
import com.datastax.oss.protocol.internal.ProtocolConstants;
import com.datastax.oss.protocol.internal.response.Error;
import io.netty.buffer.ByteBuf;
import io.netty.channel.ChannelHandlerContext;
import io.netty.handler.codec.CorruptedFrameException;
import io.netty.handler.codec.LengthFieldBasedFrameDecoder;
import io.netty.handler.codec.TooLongFrameException;
import java.util.Collections;
import net.jcip.annotations.NotThreadSafe;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
 * Splits the inbound byte stream into protocol frames and decodes each one with a {@link
 * FrameCodec}.
 *
 * <p>Frame boundaries are located by Netty's {@link LengthFieldBasedFrameDecoder}, using the
 * 4-byte body length stored at offset 5 of the frame header. The first response received by an
 * instance is inspected beforehand: if it declares a protocol version below 3 (whose 8-byte header
 * stores the length at offset 4 instead), its bytes are skipped and a synthetic {@code
 * PROTOCOL_ERROR} response frame is returned in its place.
 *
 * <p>Instances keep mutable decoding state (the first-response flag, plus the parent decoder's
 * state) and are not thread-safe.
 */
@NotThreadSafe
public class FrameDecoder extends LengthFieldBasedFrameDecoder {
  private static final Logger LOG = LoggerFactory.getLogger(FrameDecoder.class);

  // Location of the body length in the frame header parsed by the parent decoder.
  private static final int LENGTH_FIELD_OFFSET = 5;
  private static final int LENGTH_FIELD_LENGTH = 4;

  // Size of a protocol v1/v2 header: version, flags, stream id (1 byte), opcode, length (4 bytes).
  private static final int LEGACY_HEADER_LENGTH = 8;

  private final FrameCodec<ByteBuf> frameCodec;
  private final int maxFrameLengthInBytes;

  // True until the first response has been classified as legacy (v1/v2) or not. It must start out
  // true, otherwise the legacy check in decode() can never run.
  private boolean isFirstResponse = true;

  /**
   * @param frameCodec decodes each complete frame buffer into a {@link Frame}.
   * @param maxFrameLengthInBytes upper bound on a frame's size, header included; larger frames are
   *     reported through a {@link FrameTooLongException}.
   */
  public FrameDecoder(FrameCodec<ByteBuf> frameCodec, int maxFrameLengthInBytes) {
    super(maxFrameLengthInBytes, LENGTH_FIELD_OFFSET, LENGTH_FIELD_LENGTH, 0, 0, true);
    this.frameCodec = frameCodec;
    this.maxFrameLengthInBytes = maxFrameLengthInBytes;
  }

  /**
   * Decodes the next frame from {@code in}, or returns {@code null} if more bytes are needed.
   *
   * @throws FrameDecodingException if decoding fails; it carries the stream id read from the frame
   *     header, or -1 if that id could not be read. A Netty {@link TooLongFrameException} is
   *     translated into a {@link FrameTooLongException} before being wrapped.
   */
  @Override
  protected Object decode(ChannelHandlerContext ctx, ByteBuf in) throws Exception {
    int startIndex = in.readerIndex();
    if (isFirstResponse) {
      // A full legacy header is needed to classify the response. The flag stays set while we
      // wait, so a response split across several reads is still checked.
      if (in.readableBytes() < LEGACY_HEADER_LENGTH) {
        return null;
      }
      // Ignore the high bit of the version byte.
      int protocolVersion = (int) in.getByte(startIndex) & 0b0111_1111;
      if (protocolVersion < 3) {
        return decodeLegacyResponse(ctx, in, startIndex, protocolVersion);
      }
      isFirstResponse = false;
    }

    try {
      ByteBuf buffer = (ByteBuf) super.decode(ctx, in);
      // null: the parent decoder does not have a complete frame yet
      return (buffer == null) ? null : frameCodec.decode(buffer);
    } catch (Exception e) {
      // Try to recover the stream id (bytes 2-3 of the header) so it can be attached to the error.
      int streamId;
      try {
        streamId = in.getShort(startIndex + 2);
      } catch (Exception e1) {
        Loggers.warnWithException(LOG, "Unexpected error while reading stream id", e1);
        streamId = -1;
      }
      if (e instanceof TooLongFrameException) {
        // Translate Netty's exception into the driver's own type.
        e = new FrameTooLongException(ctx.channel().remoteAddress(), e.getMessage());
      }
      throw new FrameDecodingException(streamId, e);
    }
  }

  /**
   * Skips a protocol v1/v2 response and replaces it with a synthetic PROTOCOL_ERROR frame.
   *
   * <p>The legacy frame itself is not decoded: only its header is read (stream id and body
   * length), then the whole message is skipped in one go.
   *
   * @return the synthetic error frame, or {@code null} if the complete legacy message has not been
   *     received yet (nothing is consumed in that case and the first-response flag stays set).
   * @throws FrameDecodingException if the declared body length is negative, or if the message
   *     would exceed the configured maximum frame length.
   */
  private Frame decodeLegacyResponse(
      ChannelHandlerContext ctx, ByteBuf in, int startIndex, int protocolVersion) {
    int streamId = in.getByte(startIndex + 2);
    int length = in.getInt(startIndex + 4);
    if (length < 0) {
      isFirstResponse = false;
      throw new FrameDecodingException(
          streamId, new CorruptedFrameException("Negative legacy frame body length: " + length));
    }
    // long arithmetic: header + length overflows int for lengths close to Integer.MAX_VALUE
    long frameLength = (long) LEGACY_HEADER_LENGTH + length;
    if (frameLength > maxFrameLengthInBytes) {
      isFirstResponse = false;
      throw new FrameDecodingException(
          streamId,
          new FrameTooLongException(
              ctx.channel().remoteAddress(),
              "Adjusted frame length exceeds " + maxFrameLengthInBytes + ": " + frameLength));
    }
    if (in.readableBytes() < frameLength) {
      // Wait until the whole message can be discarded at once.
      return null;
    }
    isFirstResponse = false;
    // Safe cast: frameLength <= maxFrameLengthInBytes <= Integer.MAX_VALUE
    in.readerIndex(startIndex + (int) frameLength);
    return Frame.forResponse(
        protocolVersion,
        streamId,
        null,
        Frame.NO_PAYLOAD,
        Collections.emptyList(),
        new Error(
            ProtocolConstants.ErrorCode.PROTOCOL_ERROR,
            "Invalid or unsupported protocol version"));
  }

  /**
   * Returns a plain (non-retained) slice of the frame bytes, avoiding a memory copy and any extra
   * reference count to release. The slice is decoded right away in {@link #decode} and is not
   * stored there. NOTE(review): this assumes {@code frameCodec.decode} does not keep references to
   * the slice after it returns.
   */
  @Override
  protected ByteBuf extractFrame(ChannelHandlerContext ctx, ByteBuf buffer, int index, int length) {
    return buffer.slice(index, length);
  }
}