package org.apache.cassandra.transport;
import java.net.InetAddress;
import java.net.InetSocketAddress;
import java.net.UnknownHostException;
import java.nio.ByteBuffer;
import java.nio.CharBuffer;
import java.nio.charset.CharacterCodingException;
import java.nio.charset.Charset;
import java.nio.charset.CharsetDecoder;
import java.nio.charset.CoderResult;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.UUID;

import io.netty.buffer.ByteBuf;
import io.netty.buffer.ByteBufAllocator;
import io.netty.buffer.ByteBufUtil;
import io.netty.buffer.PooledByteBufAllocator;
import io.netty.buffer.UnpooledByteBufAllocator;
import io.netty.util.CharsetUtil;
import io.netty.util.concurrent.FastThreadLocal;

import org.apache.cassandra.config.Config;
import org.apache.cassandra.db.ConsistencyLevel;
import org.apache.cassandra.db.TypeSizes;
import org.apache.cassandra.utils.ByteBufferUtil;
import org.apache.cassandra.utils.Pair;
import org.apache.cassandra.utils.UUIDGen;
/**
 * Static helpers for reading and writing the primitive encodings used by the native transport
 * to and from Netty {@link ByteBuf}s.
 *
 * <p>Each {@code readX} method has a matching {@code writeX}. Most also have a {@code sizeOfX}
 * that returns the number of bytes the corresponding {@code writeX} emits.
 */
public abstract class CBUtil
{
    /**
     * True when the system property {@code Config.PROPERTY_PREFIX + "netty_use_heap_allocator"} is set to "true".
     */
    public static final boolean USE_HEAP_ALLOCATOR = Boolean.getBoolean(Config.PROPERTY_PREFIX + "netty_use_heap_allocator");

    /**
     * Shared allocator. It is an unpooled heap allocator when {@link #USE_HEAP_ALLOCATOR} is set,
     * otherwise a pooled allocator that prefers direct buffers.
     */
    public static final ByteBufAllocator allocator = USE_HEAP_ALLOCATOR ? new UnpooledByteBufAllocator(false) : new PooledByteBufAllocator(true);

    /** Per-thread UTF-8 decoder: CharsetDecoder instances are stateful and not safe for concurrent use. */
    private final static FastThreadLocal<CharsetDecoder> TL_UTF8_DECODER = new FastThreadLocal<CharsetDecoder>()
    {
        @Override
        protected CharsetDecoder initialValue()
        {
            return CharsetUtil.UTF_8.newDecoder();
        }
    };

    /** Per-thread scratch output buffer reused by {@link #decodeString(ByteBuffer)}. */
    private final static FastThreadLocal<CharBuffer> TL_CHAR_BUFFER = new FastThreadLocal<>();

    /** Smallest scratch buffer ever cached per thread (in chars). */
    private static final int MIN_CACHED_CHAR_BUFFER_SIZE = 4096;

    /**
     * Largest scratch buffer cached per thread (in chars). Larger strings get a one-off buffer, so a single
     * huge string cannot pin a huge buffer to the thread for its lifetime.
     */
    private static final int MAX_CACHED_CHAR_BUFFER_SIZE = 64 * 1024;

    private CBUtil() {}

    /**
     * Decodes all remaining bytes of {@code src} as UTF-8, advancing its position.
     *
     * @throws CharacterCodingException if {@code src} is not valid UTF-8
     */
    private static String decodeString(ByteBuffer src) throws CharacterCodingException
    {
        CharsetDecoder decoder = TL_UTF8_DECODER.get();
        decoder.reset();
        // Upper bound on the decoded length, so one decode() call always has enough room.
        int capacity = (int) ((double) src.remaining() * decoder.maxCharsPerByte());
        CharBuffer dst = scratchCharBuffer(capacity);

        CoderResult cr = decoder.decode(src, dst, true);
        if (!cr.isUnderflow())
            cr.throwException();
        // Complete the decoding operation as the CharsetDecoder contract requires (reset, decode, flush).
        cr = decoder.flush(dst);
        if (!cr.isUnderflow())
            cr.throwException();
        return dst.flip().toString();
    }

    /**
     * Returns a cleared CharBuffer holding at least {@code capacity} chars. Buffers up to
     * {@link #MAX_CACHED_CHAR_BUFFER_SIZE} are cached and reused per thread; larger ones are allocated per call.
     */
    private static CharBuffer scratchCharBuffer(int capacity)
    {
        if (capacity > MAX_CACHED_CHAR_BUFFER_SIZE)
            return CharBuffer.allocate(capacity);

        CharBuffer cached = TL_CHAR_BUFFER.get();
        if (cached == null || cached.capacity() < capacity)
        {
            cached = CharBuffer.allocate(Math.max(capacity, MIN_CACHED_CHAR_BUFFER_SIZE));
            TL_CHAR_BUFFER.set(cached);
        }
        else
        {
            cached.clear();
        }
        return cached;
    }

    /**
     * Reads {@code length} bytes at the reader index of {@code cb} as a UTF-8 string.
     * The reader index advances past them only if decoding succeeds.
     *
     * @throws IndexOutOfBoundsException if {@code length} is negative or larger than the readable bytes
     * @throws ProtocolException if the bytes are not valid UTF-8
     */
    private static String readString(ByteBuf cb, int length)
    {
        if (length == 0)
            return "";

        // nioBuffer() does not check against the writer index, so an oversized length could otherwise
        // decode stale bytes beyond the readable region before failing.
        if (length < 0 || length > cb.readableBytes())
            throw new IndexOutOfBoundsException("Cannot read " + length + " bytes, only " + cb.readableBytes() + " readable");

        int start = cb.readerIndex();
        try
        {
            String str = decodeString(cb.nioBuffer(start, length));
            cb.readerIndex(start + length);
            return str;
        }
        catch (IllegalStateException | CharacterCodingException e)
        {
            // Take a fresh view for the message: decoding advanced the position of the buffer it was given.
            throw new ProtocolException("Cannot decode string as UTF8: '" + ByteBufferUtil.bytesToHex(cb.nioBuffer(start, length)) + "'; " + e);
        }
    }

    /**
     * Reads an unsigned 2-byte length followed by that many bytes of UTF-8.
     *
     * @throws ProtocolException if the buffer is too short or the bytes are not valid UTF-8
     */
    public static String readString(ByteBuf cb)
    {
        try
        {
            int length = cb.readUnsignedShort();
            return readString(cb, length);
        }
        catch (IndexOutOfBoundsException e)
        {
            throw new ProtocolException("Not enough bytes to read an UTF8 serialized string preceded by its 2 bytes length");
        }
    }

    /**
     * Writes {@code str} as a 2-byte length followed by its UTF-8 bytes.
     * NOTE(review): setShort keeps only the low 16 bits, so a string whose UTF-8 form exceeds 65535 bytes
     * gets a wrong length prefix instead of an error.
     */
    public static void writeString(String str, ByteBuf cb)
    {
        int writerIndex = cb.writerIndex();
        // Reserve the prefix, encode directly into the buffer, then back-fill the actual byte count.
        cb.writeShort(0);
        int lengthBytes = ByteBufUtil.writeUtf8(cb, str);
        cb.setShort(writerIndex, lengthBytes);
    }

    /**
     * Returns the encoded size of {@code str} as written by {@link #writeString(String, ByteBuf)}.
     * NOTE(review): relies on TypeSizes.encodedUTF8Length agreeing with ByteBufUtil.writeUtf8's byte count;
     * confirm for NUL and supplementary (surrogate-pair) characters.
     */
    public static int sizeOfString(String str)
    {
        return 2 + TypeSizes.encodedUTF8Length(str);
    }

    /**
     * Reads a signed 4-byte length followed by that many bytes of UTF-8.
     *
     * @throws ProtocolException if the length is negative, the buffer is too short, or the bytes are not valid UTF-8
     */
    public static String readLongString(ByteBuf cb)
    {
        try
        {
            int length = cb.readInt();
            return readString(cb, length);
        }
        catch (IndexOutOfBoundsException e)
        {
            throw new ProtocolException("Not enough bytes to read an UTF8 serialized string preceded by its 4 bytes length");
        }
    }

    /** Writes {@code str} as a 4-byte length followed by its UTF-8 bytes. */
    public static void writeLongString(String str, ByteBuf cb)
    {
        byte[] bytes = str.getBytes(CharsetUtil.UTF_8);
        cb.writeInt(bytes.length);
        cb.writeBytes(bytes);
    }

    /** Returns the encoded size of {@code str} as written by {@link #writeLongString(String, ByteBuf)}. */
    public static int sizeOfLongString(String str)
    {
        return 4 + str.getBytes(CharsetUtil.UTF_8).length;
    }

    /**
     * Reads an unsigned 2-byte length followed by that many raw bytes.
     *
     * @throws ProtocolException if the buffer is too short
     */
    public static byte[] readBytes(ByteBuf cb)
    {
        try
        {
            int length = cb.readUnsignedShort();
            byte[] bytes = new byte[length];
            cb.readBytes(bytes);
            return bytes;
        }
        catch (IndexOutOfBoundsException e)
        {
            throw new ProtocolException("Not enough bytes to read a byte array preceded by its 2 bytes length");
        }
    }

    /** Writes {@code bytes} prefixed by a 2-byte length. */
    public static void writeBytes(byte[] bytes, ByteBuf cb)
    {
        cb.writeShort(bytes.length);
        cb.writeBytes(bytes);
    }

    /** Returns the encoded size of {@code bytes} as written by {@link #writeBytes(byte[], ByteBuf)}. */
    public static int sizeOfBytes(byte[] bytes)
    {
        return 2 + bytes.length;
    }

    /**
     * Reads an unsigned 2-byte entry count followed by that many (string key, value) pairs.
     * Values are read with {@link #readValue(ByteBuf)}, so a negative value length yields a null value.
     */
    public static Map<String, ByteBuffer> readBytesMap(ByteBuf cb)
    {
        int length = cb.readUnsignedShort();
        Map<String, ByteBuffer> m = new HashMap<>(length);
        for (int i = 0; i < length; i++)
        {
            String k = readString(cb);
            ByteBuffer v = readValue(cb);
            m.put(k, v);
        }
        return m;
    }

    /** Writes {@code m} as a 2-byte entry count followed by (string key, value) pairs. */
    public static void writeBytesMap(Map<String, ByteBuffer> m, ByteBuf cb)
    {
        cb.writeShort(m.size());
        for (Map.Entry<String, ByteBuffer> entry : m.entrySet())
        {
            writeString(entry.getKey(), cb);
            writeValue(entry.getValue(), cb);
        }
    }

    /** Returns the encoded size of {@code m} as written by {@link #writeBytesMap(Map, ByteBuf)}. */
    public static int sizeOfBytesMap(Map<String, ByteBuffer> m)
    {
        int size = 2;
        for (Map.Entry<String, ByteBuffer> entry : m.entrySet())
        {
            size += sizeOfString(entry.getKey());
            size += sizeOfValue(entry.getValue());
        }
        return size;
    }

    /** Reads an unsigned 2-byte consistency code and maps it through {@link ConsistencyLevel#fromCode}. */
    public static ConsistencyLevel readConsistencyLevel(ByteBuf cb)
    {
        return ConsistencyLevel.fromCode(cb.readUnsignedShort());
    }

    /** Writes the consistency level's code as 2 bytes. */
    public static void writeConsistencyLevel(ConsistencyLevel consistency, ByteBuf cb)
    {
        cb.writeShort(consistency.code);
    }

    /** Returns the encoded size of a consistency level, which is always 2 bytes. */
    public static int sizeOfConsistencyLevel(ConsistencyLevel consistency)
    {
        return 2;
    }

    /**
     * Reads a string (see {@link #readString(ByteBuf)}) and resolves it case-insensitively to a constant of {@code enumType}.
     *
     * @throws ProtocolException if the string names no constant of {@code enumType}
     */
    public static <T extends Enum<T>> T readEnumValue(Class<T> enumType, ByteBuf cb)
    {
        String value = CBUtil.readString(cb);
        try
        {
            // Locale.ROOT: the default locale may change the casing of ASCII letters (e.g. Turkish dotted i).
            return Enum.valueOf(enumType, value.toUpperCase(Locale.ROOT));
        }
        catch (IllegalArgumentException e)
        {
            throw new ProtocolException(String.format("Invalid value '%s' for %s", value, enumType.getSimpleName()));
        }
    }

    /** Writes {@code enumValue.toString()} as a string. */
    public static <T extends Enum<T>> void writeEnumValue(T enumValue, ByteBuf cb)
    {
        writeString(enumValue.toString(), cb);
    }

    /** Returns the encoded size of {@code enumValue} as written by {@link #writeEnumValue(Enum, ByteBuf)}. */
    public static <T extends Enum<T>> int sizeOfEnumValue(T enumValue)
    {
        return sizeOfString(enumValue.toString());
    }

    /** Reads 16 raw bytes and converts them with {@link UUIDGen#getUUID(ByteBuffer)}. */
    public static UUID readUUID(ByteBuf cb)
    {
        byte[] bytes = new byte[16];
        cb.readBytes(bytes);
        return UUIDGen.getUUID(ByteBuffer.wrap(bytes));
    }

    /** Writes the bytes produced by {@link UUIDGen#decompose(UUID)}. */
    public static void writeUUID(UUID uuid, ByteBuf cb)
    {
        cb.writeBytes(UUIDGen.decompose(uuid));
    }

    /** Returns the encoded size of a UUID, which is always 16 bytes. */
    public static int sizeOfUUID(UUID uuid)
    {
        return 16;
    }

    /** Reads an unsigned 2-byte count followed by that many strings. */
    public static List<String> readStringList(ByteBuf cb)
    {
        int length = cb.readUnsignedShort();
        List<String> l = new ArrayList<>(length);
        for (int i = 0; i < length; i++)
            l.add(readString(cb));
        return l;
    }

    /** Writes {@code l} as a 2-byte count followed by each string. */
    public static void writeStringList(List<String> l, ByteBuf cb)
    {
        cb.writeShort(l.size());
        for (String str : l)
            writeString(str, cb);
    }

    /** Returns the encoded size of {@code l} as written by {@link #writeStringList(List, ByteBuf)}. */
    public static int sizeOfStringList(List<String> l)
    {
        int size = 2;
        for (String str : l)
            size += sizeOfString(str);
        return size;
    }

    /** Reads an unsigned 2-byte entry count followed by that many (string key, string value) pairs. */
    public static Map<String, String> readStringMap(ByteBuf cb)
    {
        int length = cb.readUnsignedShort();
        Map<String, String> m = new HashMap<>(length);
        for (int i = 0; i < length; i++)
        {
            String k = readString(cb);
            String v = readString(cb);
            m.put(k, v);
        }
        return m;
    }

    /** Writes {@code m} as a 2-byte entry count followed by (string key, string value) pairs. */
    public static void writeStringMap(Map<String, String> m, ByteBuf cb)
    {
        cb.writeShort(m.size());
        for (Map.Entry<String, String> entry : m.entrySet())
        {
            writeString(entry.getKey(), cb);
            writeString(entry.getValue(), cb);
        }
    }

    /** Returns the encoded size of {@code m} as written by {@link #writeStringMap(Map, ByteBuf)}. */
    public static int sizeOfStringMap(Map<String, String> m)
    {
        int size = 2;
        for (Map.Entry<String, String> entry : m.entrySet())
        {
            size += sizeOfString(entry.getKey());
            size += sizeOfString(entry.getValue());
        }
        return size;
    }

    /**
     * Reads an unsigned 2-byte entry count followed by that many (string key, string list) pairs.
     * Keys are upper-cased in the returned map.
     */
    public static Map<String, List<String>> readStringToStringListMap(ByteBuf cb)
    {
        int length = cb.readUnsignedShort();
        Map<String, List<String>> m = new HashMap<>(length);
        for (int i = 0; i < length; i++)
        {
            // Locale.ROOT keeps key normalisation independent of the JVM's default locale.
            String k = readString(cb).toUpperCase(Locale.ROOT);
            List<String> v = readStringList(cb);
            m.put(k, v);
        }
        return m;
    }

    /** Writes {@code m} as a 2-byte entry count followed by (string key, string list) pairs. */
    public static void writeStringToStringListMap(Map<String, List<String>> m, ByteBuf cb)
    {
        cb.writeShort(m.size());
        for (Map.Entry<String, List<String>> entry : m.entrySet())
        {
            writeString(entry.getKey(), cb);
            writeStringList(entry.getValue(), cb);
        }
    }

    /** Returns the encoded size of {@code m} as written by {@link #writeStringToStringListMap(Map, ByteBuf)}. */
    public static int sizeOfStringToStringListMap(Map<String, List<String>> m)
    {
        int size = 2;
        for (Map.Entry<String, List<String>> entry : m.entrySet())
        {
            size += sizeOfString(entry.getKey());
            size += sizeOfStringList(entry.getValue());
        }
        return size;
    }

    /**
     * Reads a signed 4-byte length followed by that many bytes, copied into a new heap buffer.
     *
     * @return the value, or {@code null} if the length is negative
     */
    public static ByteBuffer readValue(ByteBuf cb)
    {
        int length = cb.readInt();
        if (length < 0)
            return null;
        ByteBuf slice = cb.readSlice(length);
        return ByteBuffer.wrap(readRawBytes(slice));
    }

    /**
     * Reads a bound value, like {@link #readValue(ByteBuf)}, but with protocol-version-specific handling of negative lengths.
     * Before V4 any negative length means {@code null}. From V4 on, -1 means {@code null}, -2 means
     * {@link ByteBufferUtil#UNSET_BYTE_BUFFER}, and any other negative length is rejected.
     *
     * @throws ProtocolException on an invalid negative length for V4+
     */
    public static ByteBuffer readBoundValue(ByteBuf cb, ProtocolVersion protocolVersion)
    {
        int length = cb.readInt();
        if (length < 0)
        {
            if (protocolVersion.isSmallerThan(ProtocolVersion.V4))
                return null;
            if (length == -1)
                return null;
            else if (length == -2)
                return ByteBufferUtil.UNSET_BYTE_BUFFER;
            else
                throw new ProtocolException("Invalid ByteBuf length " + length);
        }
        ByteBuf slice = cb.readSlice(length);
        return ByteBuffer.wrap(readRawBytes(slice));
    }

    /** Writes {@code bytes} prefixed by a 4-byte length; {@code null} is written as length -1 with no payload. */
    public static void writeValue(byte[] bytes, ByteBuf cb)
    {
        if (bytes == null)
        {
            cb.writeInt(-1);
            return;
        }
        cb.writeInt(bytes.length);
        cb.writeBytes(bytes);
    }

    /**
     * Writes the remaining bytes of {@code bytes} prefixed by a 4-byte length; {@code null} is written as length -1.
     * The position of {@code bytes} is not modified (a duplicate is consumed).
     */
    public static void writeValue(ByteBuffer bytes, ByteBuf cb)
    {
        if (bytes == null)
        {
            cb.writeInt(-1);
            return;
        }
        int remaining = bytes.remaining();
        cb.writeInt(remaining);
        if (remaining > 0)
            cb.writeBytes(bytes.duplicate());
    }

    /** Returns the encoded size of {@code bytes} as written by {@link #writeValue(byte[], ByteBuf)}. */
    public static int sizeOfValue(byte[] bytes)
    {
        return 4 + (bytes == null ? 0 : bytes.length);
    }

    /** Returns the encoded size of {@code bytes} as written by {@link #writeValue(ByteBuffer, ByteBuf)}. */
    public static int sizeOfValue(ByteBuffer bytes)
    {
        return 4 + (bytes == null ? 0 : bytes.remaining());
    }

    /** Returns the encoded size of a value whose payload is {@code valueSize} bytes; negative sizes count as no payload. */
    public static int sizeOfValue(int valueSize)
    {
        return 4 + (valueSize < 0 ? 0 : valueSize);
    }

    /**
     * Reads an unsigned 2-byte count followed by that many bound values (see {@link #readBoundValue}).
     * A zero count returns an immutable empty list.
     */
    public static List<ByteBuffer> readValueList(ByteBuf cb, ProtocolVersion protocolVersion)
    {
        int size = cb.readUnsignedShort();
        if (size == 0)
            return Collections.<ByteBuffer>emptyList();

        List<ByteBuffer> l = new ArrayList<>(size);
        for (int i = 0; i < size; i++)
            l.add(readBoundValue(cb, protocolVersion));
        return l;
    }

    /** Writes {@code values} as a 2-byte count followed by each value. */
    public static void writeValueList(List<ByteBuffer> values, ByteBuf cb)
    {
        cb.writeShort(values.size());
        for (ByteBuffer value : values)
            CBUtil.writeValue(value, cb);
    }

    /** Returns the encoded size of {@code values} as written by {@link #writeValueList(List, ByteBuf)}. */
    public static int sizeOfValueList(List<ByteBuffer> values)
    {
        int size = 2;
        for (ByteBuffer value : values)
            size += CBUtil.sizeOfValue(value);
        return size;
    }

    /**
     * Reads an unsigned 2-byte count followed by that many (string name, bound value) pairs.
     * The result holds the names and values as two parallel lists. A zero count returns two immutable empty lists.
     */
    public static Pair<List<String>, List<ByteBuffer>> readNameAndValueList(ByteBuf cb, ProtocolVersion protocolVersion)
    {
        int size = cb.readUnsignedShort();
        if (size == 0)
            return Pair.create(Collections.<String>emptyList(), Collections.<ByteBuffer>emptyList());

        List<String> s = new ArrayList<>(size);
        List<ByteBuffer> l = new ArrayList<>(size);
        for (int i = 0; i < size; i++)
        {
            s.add(readString(cb));
            l.add(readBoundValue(cb, protocolVersion));
        }
        return Pair.create(s, l);
    }

    /**
     * Reads an unsigned 1-byte address length, the raw address bytes, and a 4-byte port.
     *
     * @throws ProtocolException if the address bytes have a length InetAddress does not accept
     */
    public static InetSocketAddress readInet(ByteBuf cb)
    {
        int addrSize = cb.readByte() & 0xFF;
        byte[] address = new byte[addrSize];
        cb.readBytes(address);
        int port = cb.readInt();
        try
        {
            return new InetSocketAddress(InetAddress.getByAddress(address), port);
        }
        catch (UnknownHostException e)
        {
            throw invalidInetAddress(address);
        }
    }

    /** Writes the address length as one byte, the raw address bytes, then the port as 4 bytes. */
    public static void writeInet(InetSocketAddress inet, ByteBuf cb)
    {
        byte[] address = inet.getAddress().getAddress();
        cb.writeByte(address.length);
        cb.writeBytes(address);
        cb.writeInt(inet.getPort());
    }

    /** Returns the encoded size of {@code inet} as written by {@link #writeInet(InetSocketAddress, ByteBuf)}. */
    public static int sizeOfInet(InetSocketAddress inet)
    {
        byte[] address = inet.getAddress().getAddress();
        return 1 + address.length + 4;
    }

    /**
     * Reads an unsigned 1-byte address length followed by the raw address bytes (no port).
     *
     * @throws ProtocolException if the address bytes have a length InetAddress does not accept
     */
    public static InetAddress readInetAddr(ByteBuf cb)
    {
        int addressSize = cb.readByte() & 0xFF;
        byte[] address = new byte[addressSize];
        cb.readBytes(address);
        try
        {
            return InetAddress.getByAddress(address);
        }
        catch (UnknownHostException e)
        {
            throw invalidInetAddress(address);
        }
    }

    /**
     * Builds the error for an address that InetAddress rejected.
     * getByAddress only fails on an illegal length, so the message must not assume 4 bytes
     * (indexing address[0..3] would throw for shorter inputs).
     */
    private static ProtocolException invalidInetAddress(byte[] address)
    {
        return new ProtocolException(String.format("Invalid IP address of %d bytes (0x%s) while deserializing inet address",
                                                   address.length,
                                                   ByteBufferUtil.bytesToHex(ByteBuffer.wrap(address))));
    }

    /** Writes the address length as one byte followed by the raw address bytes. */
    public static void writeInetAddr(InetAddress inetAddr, ByteBuf cb)
    {
        byte[] address = inetAddr.getAddress();
        cb.writeByte(address.length);
        cb.writeBytes(address);
    }

    /** Returns the encoded size of {@code inetAddr} as written by {@link #writeInetAddr(InetAddress, ByteBuf)}. */
    public static int sizeOfInetAddr(InetAddress inetAddr)
    {
        return 1 + inetAddr.getAddress().length;
    }

    /** Copies all readable bytes of {@code cb} into a new array, advancing its reader index to the writer index. */
    public static byte[] readRawBytes(ByteBuf cb)
    {
        byte[] bytes = new byte[cb.readableBytes()];
        cb.readBytes(bytes);
        return bytes;
    }
}