package com.datastax.oss.protocol.internal.request;
import static com.datastax.oss.protocol.internal.ProtocolConstants.Version.V5;
import com.datastax.oss.protocol.internal.Message;
import com.datastax.oss.protocol.internal.PrimitiveCodec;
import com.datastax.oss.protocol.internal.PrimitiveSizes;
import com.datastax.oss.protocol.internal.ProtocolConstants;
import com.datastax.oss.protocol.internal.ProtocolErrors;
import com.datastax.oss.protocol.internal.request.query.Values;
import com.datastax.oss.protocol.internal.util.Flags;
import com.datastax.oss.protocol.internal.util.collection.NullAllowingImmutableList;
import java.nio.ByteBuffer;
import java.util.List;
/**
 * A BATCH request message: a list of statements (query strings and/or prepared statement ids),
 * each with its own positional values, executed together with shared options.
 *
 * <p>The optional options ({@link #serialConsistency}, {@link #defaultTimestamp}, {@link
 * #keyspace}) are only written to the wire when the corresponding bit is set in {@link #flags}.
 */
public class Batch extends Message {

  /**
   * The batch type byte, written verbatim by the codec (presumably one of the protocol's batch
   * type constants such as logged/unlogged/counter — confirm against callers).
   */
  public final byte type;

  /**
   * The statements of the batch. Each element is either a {@link String} (a query string, encoded
   * with kind byte 0 as a long string) or a {@code byte[]} (a prepared statement id, encoded with
   * kind byte 1 as short bytes).
   */
  public final List<Object> queriesOrIds;

  /** One list of positional values per statement, in the same order as {@link #queriesOrIds}. */
  public final List<List<ByteBuffer>> values;

  /** The consistency level code, always encoded as an unsigned short. */
  public final int consistency;

  /**
   * The serial consistency level code. Only encoded when {@link #flags} contains the
   * SERIAL_CONSISTENCY query flag; decoded as {@code ConsistencyLevel.SERIAL} when absent.
   */
  public final int serialConsistency;

  /**
   * The default timestamp. {@link Long#MIN_VALUE} means "none": it is then neither flagged by
   * {@link #computeFlags} nor present after decoding.
   */
  public final long defaultTimestamp;

  /** The keyspace to execute in, or {@code null} for none. */
  public final String keyspace;

  /**
   * The query flags. Encoded as an int on protocol V5 and later, as a single byte before that.
   * The codec relies on these bits, not on the field values, to decide which optional fields to
   * write.
   */
  public final int flags;

  /**
   * Creates a batch with explicitly provided flags. The flags are used as-is and are not checked
   * against the other arguments.
   *
   * @param flags the query flags to encode
   * @param type the batch type byte
   * @param queriesOrIds the statements: {@link String} query strings or {@code byte[]} ids
   * @param values one positional value list per statement
   * @param consistency the consistency level code
   * @param serialConsistency the serial consistency level code
   * @param defaultTimestamp the default timestamp, or {@link Long#MIN_VALUE} for none
   * @param keyspace the keyspace, or {@code null} for none
   */
  public Batch(
      int flags,
      byte type,
      List<Object> queriesOrIds,
      List<List<ByteBuffer>> values,
      int consistency,
      int serialConsistency,
      long defaultTimestamp,
      String keyspace) {
    super(false, ProtocolConstants.Opcode.BATCH);
    this.type = type;
    this.queriesOrIds = queriesOrIds;
    this.values = values;
    this.consistency = consistency;
    this.serialConsistency = serialConsistency;
    this.defaultTimestamp = defaultTimestamp;
    this.keyspace = keyspace;
    this.flags = flags;
  }

  /**
   * Creates a batch whose flags are derived from the optional arguments (see {@link
   * #computeFlags(int, long, String)}).
   *
   * @param type the batch type byte
   * @param queriesOrIds the statements: {@link String} query strings or {@code byte[]} ids
   * @param values one positional value list per statement
   * @param consistency the consistency level code
   * @param serialConsistency the serial consistency level code
   * @param defaultTimestamp the default timestamp, or {@link Long#MIN_VALUE} for none
   * @param keyspace the keyspace, or {@code null} for none
   */
  public Batch(
      byte type,
      List<Object> queriesOrIds,
      List<List<ByteBuffer>> values,
      int consistency,
      int serialConsistency,
      long defaultTimestamp,
      String keyspace) {
    this(
        computeFlags(serialConsistency, defaultTimestamp, keyspace),
        type,
        queriesOrIds,
        values,
        consistency,
        serialConsistency,
        defaultTimestamp,
        keyspace);
  }

  @Override
  public String toString() {
    return "BATCH(" + queriesOrIds.size() + " statements)";
  }

  /**
   * Computes the query flags implied by the optional options.
   *
   * <p>SERIAL_CONSISTENCY is set unless the serial consistency is {@code ConsistencyLevel.SERIAL}.
   * DEFAULT_TIMESTAMP is set unless the timestamp is {@link Long#MIN_VALUE}. WITH_KEYSPACE is set
   * when a keyspace is given.
   *
   * @param serialConsistency the serial consistency level code
   * @param defaultTimestamp the default timestamp, or {@link Long#MIN_VALUE} for none
   * @param keyspace the keyspace, or {@code null} for none
   * @return the combined flags
   */
  protected static int computeFlags(int serialConsistency, long defaultTimestamp, String keyspace) {
    int flags = 0;
    if (serialConsistency != ProtocolConstants.ConsistencyLevel.SERIAL) {
      flags = Flags.add(flags, ProtocolConstants.QueryFlag.SERIAL_CONSISTENCY);
    }
    if (defaultTimestamp != Long.MIN_VALUE) {
      flags = Flags.add(flags, ProtocolConstants.QueryFlag.DEFAULT_TIMESTAMP);
    }
    if (keyspace != null) {
      flags = Flags.add(flags, ProtocolConstants.QueryFlag.WITH_KEYSPACE);
    }
    return flags;
  }

  /** Encodes and decodes {@link Batch} messages for a given protocol version. */
  public static class Codec extends Message.Codec {

    public Codec(int protocolVersion) {
      super(ProtocolConstants.Opcode.BATCH, protocolVersion);
    }

    /**
     * Writes the batch: type, statement count, each statement with its values, consistency,
     * flags, then the optional options selected by the flags.
     */
    @Override
    public <B> void encode(B dest, Message message, PrimitiveCodec<B> encoder) {
      Batch batch = (Batch) message;
      encoder.writeByte(batch.type, dest);
      int queryCount = batch.queriesOrIds.size();
      encoder.writeUnsignedShort(queryCount, dest);
      for (int i = 0; i < queryCount; i++) {
        Object q = batch.queriesOrIds.get(i);
        if (q instanceof String) {
          // kind 0: the statement is a query string
          encoder.writeByte((byte) 0, dest);
          encoder.writeLongString((String) q, dest);
        } else {
          // kind 1: the statement is a prepared statement id
          encoder.writeByte((byte) 1, dest);
          encoder.writeShortBytes((byte[]) q, dest);
        }
        Values.writePositionalValues(batch.values.get(i), dest, encoder);
      }
      encoder.writeUnsignedShort(batch.consistency, dest);
      if (protocolVersion >= V5) {
        encoder.writeInt(batch.flags, dest);
      } else {
        // NOTE(review): bits above 0xFF are silently dropped by this cast on pre-V5 versions.
        encoder.writeByte((byte) batch.flags, dest);
      }
      if (Flags.contains(batch.flags, ProtocolConstants.QueryFlag.SERIAL_CONSISTENCY)) {
        encoder.writeUnsignedShort(batch.serialConsistency, dest);
      }
      if (Flags.contains(batch.flags, ProtocolConstants.QueryFlag.DEFAULT_TIMESTAMP)) {
        encoder.writeLong(batch.defaultTimestamp, dest);
      }
      if (Flags.contains(batch.flags, ProtocolConstants.QueryFlag.WITH_KEYSPACE)) {
        encoder.writeString(batch.keyspace, dest);
      }
    }

    /**
     * Computes the encoded size of the batch, mirroring {@link #encode}.
     *
     * @throws RuntimeException (via {@code ProtocolErrors.check}) if there are more than 0xFFFF
     *     statements, or if the number of value lists differs from the number of statements
     */
    @Override
    public int encodedSize(Message message) {
      Batch batch = (Batch) message;
      int size = PrimitiveSizes.BYTE; // type
      size += PrimitiveSizes.SHORT; // statement count
      int queryCount = batch.queriesOrIds.size();
      ProtocolErrors.check(
          queryCount <= 0xFFFF, "Batch messages can contain at most %d queries", 0xFFFF);
      ProtocolErrors.check(
          batch.values.size() == queryCount,
          "Batch contains %d queries but %d value lists",
          queryCount,
          batch.values.size());
      for (int i = 0; i < queryCount; i++) {
        Object q = batch.queriesOrIds.get(i);
        size +=
            PrimitiveSizes.BYTE
                + (q instanceof String
                    ? PrimitiveSizes.sizeOfLongString((String) q)
                    : PrimitiveSizes.sizeOfShortBytes((byte[]) q));
        size += Values.sizeOfPositionalValues(batch.values.get(i));
      }
      size += PrimitiveSizes.SHORT; // consistency
      size += (protocolVersion >= V5) ? PrimitiveSizes.INT : PrimitiveSizes.BYTE; // flags
      if (Flags.contains(batch.flags, ProtocolConstants.QueryFlag.SERIAL_CONSISTENCY)) {
        size += PrimitiveSizes.SHORT;
      }
      if (Flags.contains(batch.flags, ProtocolConstants.QueryFlag.DEFAULT_TIMESTAMP)) {
        size += PrimitiveSizes.LONG;
      }
      if (Flags.contains(batch.flags, ProtocolConstants.QueryFlag.WITH_KEYSPACE)) {
        size += PrimitiveSizes.sizeOfString(batch.keyspace);
      }
      return size;
    }

    /**
     * Reads a batch in the layout produced by {@link #encode}. When an optional option's flag is
     * absent, its "none" value is used: {@code ConsistencyLevel.SERIAL}, {@link Long#MIN_VALUE},
     * or {@code null}.
     */
    @Override
    public <B> Message decode(B source, PrimitiveCodec<B> decoder) {
      byte type = decoder.readByte(source);
      int queryCount = decoder.readUnsignedShort(source);
      NullAllowingImmutableList.Builder<Object> queriesOrIds =
          NullAllowingImmutableList.builder(queryCount);
      NullAllowingImmutableList.Builder<List<ByteBuffer>> values =
          NullAllowingImmutableList.builder(queryCount);
      for (int i = 0; i < queryCount; i++) {
        // kind 0 = query string; any other kind is treated as a prepared statement id
        boolean isQueryString = (decoder.readByte(source) == 0);
        queriesOrIds.add(
            isQueryString ? decoder.readLongString(source) : decoder.readShortBytes(source));
        values.add(Values.readPositionalValues(source, decoder));
      }
      int consistency = decoder.readUnsignedShort(source);
      // Before V5 the flags are a single unsigned byte. Mask it: otherwise a set high bit (0x80)
      // is sign-extended into every higher bit, which makes all flags appear set and causes
      // options that were never written to be read.
      int flags =
          (protocolVersion >= V5) ? decoder.readInt(source) : (decoder.readByte(source) & 0xFF);
      int serialConsistency =
          (Flags.contains(flags, ProtocolConstants.QueryFlag.SERIAL_CONSISTENCY))
              ? decoder.readUnsignedShort(source)
              : ProtocolConstants.ConsistencyLevel.SERIAL;
      long defaultTimestamp =
          (Flags.contains(flags, ProtocolConstants.QueryFlag.DEFAULT_TIMESTAMP))
              ? decoder.readLong(source)
              : Long.MIN_VALUE;
      String keyspace =
          (Flags.contains(flags, ProtocolConstants.QueryFlag.WITH_KEYSPACE))
              ? decoder.readString(source)
              : null;
      return new Batch(
          flags,
          type,
          queriesOrIds.build(),
          values.build(),
          consistency,
          serialConsistency,
          defaultTimestamp,
          keyspace);
    }
  }
}