package org.apache.cassandra.service.pager;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.util.*;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.primitives.Ints;
import org.apache.cassandra.config.CFMetaData;
import org.apache.cassandra.db.Clustering;
import org.apache.cassandra.db.LegacyLayout;
import org.apache.cassandra.db.marshal.AbstractType;
import org.apache.cassandra.db.marshal.BytesType;
import org.apache.cassandra.db.rows.Cell;
import org.apache.cassandra.db.rows.Row;
import org.apache.cassandra.io.util.DataInputBuffer;
import org.apache.cassandra.io.util.DataOutputBuffer;
import org.apache.cassandra.io.util.DataOutputBufferFixed;
import org.apache.cassandra.net.MessagingService;
import org.apache.cassandra.transport.ProtocolException;
import org.apache.cassandra.transport.ProtocolVersion;
import static org.apache.cassandra.db.TypeSizes.sizeof;
import static org.apache.cassandra.db.TypeSizes.sizeofUnsignedVInt;
import static org.apache.cassandra.utils.ByteBufferUtil.*;
import static org.apache.cassandra.utils.vint.VIntCoding.computeUnsignedVIntSize;
import static org.apache.cassandra.utils.vint.VIntCoding.getUnsignedVInt;
/**
 * Paging state exchanged with clients as an opaque {@link ByteBuffer}.
 * <p>
 * It records a partition key, a {@link RowMark}, and two counters: {@code remaining} and
 * {@code remainingInPartition}.
 * <p>
 * There are two wire formats:
 * <ul>
 *   <li><b>modern</b> (protocol versions greater than V3): the key and the mark are each prefixed with an
 *       unsigned vint length, and both counters are unsigned vints;</li>
 *   <li><b>legacy</b> (V3 and below): the key and the mark are each prefixed with a 2-byte length, and the
 *       counters are 4-byte ints. The {@code remainingInPartition} int is optional on the wire.</li>
 * </ul>
 */
@SuppressWarnings("WeakerAccess")
public class PagingState
{
    /** Partition key recorded in this state, or {@code null}; written as an empty value when {@code null}. */
    public final ByteBuffer partitionKey;
    /** Row mark recorded in this state, or {@code null}; written as an empty value when {@code null}. */
    public final RowMark rowMark;
    public final int remaining;
    public final int remainingInPartition;

    public PagingState(ByteBuffer partitionKey, RowMark rowMark, int remaining, int remainingInPartition)
    {
        this.partitionKey = partitionKey;
        this.rowMark = rowMark;
        this.remaining = remaining;
        this.remainingInPartition = remainingInPartition;
    }

    /**
     * Serializes this state in the format matching {@code protocolVersion}.
     * <p>
     * Versions greater than V3 use the modern format. Older versions use the legacy format, always
     * including {@code remainingInPartition}.
     *
     * @param protocolVersion must equal the row mark's version when a row mark is present (asserted)
     * @return a buffer holding exactly the serialized bytes
     */
    public ByteBuffer serialize(ProtocolVersion protocolVersion)
    {
        assert rowMark == null || protocolVersion == rowMark.protocolVersion;
        try
        {
            return protocolVersion.isGreaterThan(ProtocolVersion.V3) ? modernSerialize() : legacySerialize(true);
        }
        catch (IOException e)
        {
            throw new RuntimeException(e);
        }
    }

    /**
     * Returns the number of bytes {@link #serialize(ProtocolVersion)} produces for {@code protocolVersion}.
     */
    public int serializedSize(ProtocolVersion protocolVersion)
    {
        assert rowMark == null || protocolVersion == rowMark.protocolVersion;
        return protocolVersion.isGreaterThan(ProtocolVersion.V3) ? modernSerializedSize() : legacySerializedSize(true);
    }

    /**
     * Decodes a paging state received with the given protocol version.
     * <p>
     * Either format is accepted for any version. The format native to {@code protocolVersion} is tried first:
     * <ul>
     *   <li>for V4+, modern is tried before legacy, and a legacy state is decoded as V3;</li>
     *   <li>below V4, legacy is tried before modern, and a modern state is decoded as V4.</li>
     * </ul>
     *
     * @param bytes the serialized state, or {@code null}
     * @param protocolVersion the protocol version the state was received with
     * @return the decoded state, or {@code null} when {@code bytes} is {@code null}
     * @throws ProtocolException if {@code bytes} matches neither format or fails to decode
     */
    public static PagingState deserialize(ByteBuffer bytes, ProtocolVersion protocolVersion)
    {
        if (bytes == null)
            return null;

        try
        {
            if (protocolVersion.isGreaterThan(ProtocolVersion.V3))
            {
                if (isModernSerialized(bytes)) return modernDeserialize(bytes, protocolVersion);
                if (isLegacySerialized(bytes)) return legacyDeserialize(bytes, ProtocolVersion.V3);
            }

            if (protocolVersion.isSmallerThan(ProtocolVersion.V4))
            {
                if (isLegacySerialized(bytes)) return legacyDeserialize(bytes, protocolVersion);
                if (isModernSerialized(bytes)) return modernDeserialize(bytes, ProtocolVersion.V4);
            }
        }
        catch (IOException e)
        {
            // keep the underlying decoding failure attached for diagnosis
            ProtocolException invalid = new ProtocolException("Invalid value for the paging state");
            invalid.initCause(e);
            throw invalid;
        }

        throw new ProtocolException("Invalid value for the paging state");
    }

    /** Writes the modern format: vint-length key, vint-length mark, then both counters as unsigned vints. */
    @SuppressWarnings({ "resource", "RedundantSuppression" })
    private ByteBuffer modernSerialize() throws IOException
    {
        DataOutputBuffer out = new DataOutputBufferFixed(modernSerializedSize());
        writeWithVIntLength(null == partitionKey ? EMPTY_BYTE_BUFFER : partitionKey, out);
        writeWithVIntLength(null == rowMark ? EMPTY_BYTE_BUFFER : rowMark.mark, out);
        out.writeUnsignedVInt(remaining);
        out.writeUnsignedVInt(remainingInPartition);
        return out.buffer(false);
    }

    /**
     * Checks, without moving {@code bytes}' position, whether the content between position and limit is
     * exactly one modern-format paging state.
     * <p>
     * Both counters must also fit in an {@code int}, because {@link #modernDeserialize} stores them as ints.
     */
    private static boolean isModernSerialized(ByteBuffer bytes)
    {
        int index = bytes.position();
        int limit = bytes.limit();

        long partitionKeyLen = getUnsignedVInt(bytes, index, limit);
        if (partitionKeyLen < 0)
            return false;
        index += computeUnsignedVIntSize(partitionKeyLen);
        // The length comes from untrusted bytes. Compare in long arithmetic before advancing, so that a huge
        // value cannot wrap the int index to a negative position read below.
        if (partitionKeyLen >= limit - index)
            return false;
        index += (int) partitionKeyLen;

        long rowMarkerLen = getUnsignedVInt(bytes, index, limit);
        if (rowMarkerLen < 0)
            return false;
        index += computeUnsignedVIntSize(rowMarkerLen);
        if (rowMarkerLen >= limit - index)
            return false;
        index += (int) rowMarkerLen;

        long remaining = getUnsignedVInt(bytes, index, limit);
        if (remaining < 0 || remaining > Integer.MAX_VALUE)
            return false;
        index += computeUnsignedVIntSize(remaining);
        if (index >= limit)
            return false;

        long remainingInPartition = getUnsignedVInt(bytes, index, limit);
        if (remainingInPartition < 0 || remainingInPartition > Integer.MAX_VALUE)
            return false;
        index += computeUnsignedVIntSize(remainingInPartition);

        // the state must consume the buffer exactly
        return index == limit;
    }

    /**
     * Reads a modern-format state.
     * <p>
     * An empty key or an empty mark is decoded as {@code null}.
     *
     * @throws IllegalArgumentException if {@code protocolVersion} is below V4
     */
    @SuppressWarnings({ "resource", "RedundantSuppression" })
    private static PagingState modernDeserialize(ByteBuffer bytes, ProtocolVersion protocolVersion) throws IOException
    {
        if (protocolVersion.isSmallerThan(ProtocolVersion.V4))
            throw new IllegalArgumentException();

        DataInputBuffer in = new DataInputBuffer(bytes, false);

        ByteBuffer partitionKey = readWithVIntLength(in);
        ByteBuffer rawMark = readWithVIntLength(in);
        int remaining = Ints.checkedCast(in.readUnsignedVInt());
        int remainingInPartition = Ints.checkedCast(in.readUnsignedVInt());

        return new PagingState(partitionKey.hasRemaining() ? partitionKey : null,
                               rawMark.hasRemaining() ? new RowMark(rawMark, protocolVersion) : null,
                               remaining,
                               remainingInPartition);
    }

    /** Returns the byte size of the modern format produced by {@link #modernSerialize()}. */
    private int modernSerializedSize()
    {
        return serializedSizeWithVIntLength(null == partitionKey ? EMPTY_BYTE_BUFFER : partitionKey)
             + serializedSizeWithVIntLength(null == rowMark ? EMPTY_BYTE_BUFFER : rowMark.mark)
             + sizeofUnsignedVInt(remaining)
             + sizeofUnsignedVInt(remainingInPartition);
    }

    /**
     * Writes the legacy format: short-length key, short-length mark, then {@code remaining} as an int.
     * <p>
     * {@code remainingInPartition} is also written as an int when {@code withRemainingInPartition} is true.
     */
    @VisibleForTesting
    @SuppressWarnings({ "resource", "RedundantSuppression" })
    ByteBuffer legacySerialize(boolean withRemainingInPartition) throws IOException
    {
        DataOutputBuffer out = new DataOutputBufferFixed(legacySerializedSize(withRemainingInPartition));
        writeWithShortLength(null == partitionKey ? EMPTY_BYTE_BUFFER : partitionKey, out);
        writeWithShortLength(null == rowMark ? EMPTY_BYTE_BUFFER : rowMark.mark, out);
        out.writeInt(remaining);
        if (withRemainingInPartition)
            out.writeInt(remainingInPartition);
        return out.buffer(false);
    }

    /**
     * Checks, without moving {@code bytes}' position, whether the content between position and limit is
     * exactly one legacy-format paging state.
     * <p>
     * The state may end right after {@code remaining} or carry one more non-negative int.
     */
    private static boolean isLegacySerialized(ByteBuffer bytes)
    {
        int index = bytes.position();
        int limit = bytes.limit();

        if (limit - index < 2)
            return false;
        // NOTE(review): lengths are read as signed shorts, so a key or mark of 32768 bytes or more is never
        // recognised as legacy here; confirm this matches the writer's length encoding.
        short partitionKeyLen = bytes.getShort(index);
        if (partitionKeyLen < 0)
            return false;
        index += 2 + partitionKeyLen;

        if (limit - index < 2)
            return false;
        short rowMarkerLen = bytes.getShort(index);
        if (rowMarkerLen < 0)
            return false;
        index += 2 + rowMarkerLen;

        if (limit - index < 4)
            return false;
        int remaining = bytes.getInt(index);
        if (remaining < 0)
            return false;
        index += 4;

        // without remainingInPartition
        if (index == limit)
            return true;

        // with remainingInPartition
        if (limit - index == 4)
        {
            int remainingInPartition = bytes.getInt(index);
            return remainingInPartition >= 0;
        }
        return false;
    }

    /**
     * Reads a legacy-format state.
     * <p>
     * A missing trailing {@code remainingInPartition} becomes {@link Integer#MAX_VALUE}. An empty key or an
     * empty mark is decoded as {@code null}.
     *
     * @throws IllegalArgumentException if {@code protocolVersion} is greater than V3
     */
    @SuppressWarnings({ "resource", "RedundantSuppression" })
    private static PagingState legacyDeserialize(ByteBuffer bytes, ProtocolVersion protocolVersion) throws IOException
    {
        if (protocolVersion.isGreaterThan(ProtocolVersion.V3))
            throw new IllegalArgumentException();

        DataInputBuffer in = new DataInputBuffer(bytes, false);

        ByteBuffer partitionKey = readWithShortLength(in);
        ByteBuffer rawMark = readWithShortLength(in);
        int remaining = in.readInt();
        int remainingInPartition = in.available() > 0 ? in.readInt() : Integer.MAX_VALUE;

        return new PagingState(partitionKey.hasRemaining() ? partitionKey : null,
                               rawMark.hasRemaining() ? new RowMark(rawMark, protocolVersion) : null,
                               remaining,
                               remainingInPartition);
    }

    /** Returns the byte size of the legacy format produced by {@link #legacySerialize(boolean)}. */
    @VisibleForTesting
    int legacySerializedSize(boolean withRemainingInPartition)
    {
        return serializedSizeWithShortLength(null == partitionKey ? EMPTY_BYTE_BUFFER : partitionKey)
             + serializedSizeWithShortLength(null == rowMark ? EMPTY_BYTE_BUFFER : rowMark.mark)
             + sizeof(remaining)
             + (withRemainingInPartition ? sizeof(remainingInPartition) : 0);
    }

    @Override
    public final int hashCode()
    {
        return Objects.hash(partitionKey, rowMark, remaining, remainingInPartition);
    }

    @Override
    public final boolean equals(Object o)
    {
        if (!(o instanceof PagingState))
            return false;
        PagingState that = (PagingState) o;

        return Objects.equals(this.partitionKey, that.partitionKey)
            && Objects.equals(this.rowMark, that.rowMark)
            && this.remaining == that.remaining
            && this.remainingInPartition == that.remainingInPartition;
    }

    @Override
    public String toString()
    {
        return String.format("PagingState(key=%s, cellname=%s, remaining=%d, remainingInPartition=%d)",
                             partitionKey != null ? bytesToHex(partitionKey) : null,
                             rowMark,
                             remaining,
                             remainingInPartition);
    }

    /**
     * Opaque encoding of a row position, together with the protocol version whose encoding it uses.
     * <p>
     * V3 and below use a legacy cell-name encoding. Later versions use a serialized clustering.
     */
    public static class RowMark
    {
        // Encoded mark; its layout depends on protocolVersion (see create/clustering).
        private final ByteBuffer mark;
        private final ProtocolVersion protocolVersion;

        private RowMark(ByteBuffer mark, ProtocolVersion protocolVersion)
        {
            this.mark = mark;
            this.protocolVersion = protocolVersion;
        }

        /**
         * Builds one BytesType per clustering column of {@code metadata}.
         * <p>
         * This lets the clustering be (de)serialized as raw bytes without depending on the actual column types.
         */
        private static List<AbstractType<?>> makeClusteringTypes(CFMetaData metadata)
        {
            int size = metadata.clusteringColumns().size();
            List<AbstractType<?>> l = new ArrayList<>(size);
            for (int i = 0; i < size; i++)
                l.add(BytesType.instance);
            return l;
        }

        /**
         * Creates a mark for {@code row} in the encoding used by {@code protocolVersion}.
         * <ul>
         *   <li>V3 and below: the clustering plus the first cell of
         *       {@code row.cellsInLegacyOrder(metadata, true)} is encoded via
         *       {@link LegacyLayout#encodeCellName}. For a complex column the first element of the cell path is
         *       included. An empty column name is used when the row has no cells.</li>
         *   <li>Later versions: the row's clustering is serialized with {@code MessagingService.VERSION_30}.</li>
         * </ul>
         */
        public static RowMark create(CFMetaData metadata, Row row, ProtocolVersion protocolVersion)
        {
            ByteBuffer mark;
            if (protocolVersion.isSmallerOrEqualTo(ProtocolVersion.V3))
            {
                Iterator<Cell> cells = row.cellsInLegacyOrder(metadata, true).iterator();
                if (!cells.hasNext())
                {
                    // a row without cells is expected only for non-compact tables
                    assert !metadata.isCompactTable();
                    mark = LegacyLayout.encodeCellName(metadata, row.clustering(), EMPTY_BYTE_BUFFER, null);
                }
                else
                {
                    Cell cell = cells.next();
                    mark = LegacyLayout.encodeCellName(metadata, row.clustering(), cell.column().name.bytes, cell.column().isComplex() ? cell.path().get(0) : null);
                }
            }
            else
            {
                mark = Clustering.serializer.serialize(row.clustering(), MessagingService.VERSION_30, makeClusteringTypes(metadata));
            }
            return new RowMark(mark, protocolVersion);
        }

        /**
         * Decodes the clustering stored in this mark, using the encoding that matches its protocol version.
         *
         * @return the clustering, or {@code null} if the mark is {@code null}
         */
        public Clustering clustering(CFMetaData metadata)
        {
            if (mark == null)
                return null;

            return protocolVersion.isSmallerOrEqualTo(ProtocolVersion.V3)
                 ? LegacyLayout.decodeClustering(metadata, mark)
                 : Clustering.serializer.deserialize(mark, MessagingService.VERSION_30, makeClusteringTypes(metadata));
        }

        @Override
        public final int hashCode()
        {
            return Objects.hash(mark, protocolVersion);
        }

        @Override
        public final boolean equals(Object o)
        {
            if (!(o instanceof RowMark))
                return false;
            RowMark that = (RowMark) o;
            return Objects.equals(this.mark, that.mark) && this.protocolVersion == that.protocolVersion;
        }

        @Override
        public String toString()
        {
            return mark == null ? "null" : bytesToHex(mark);
        }
    }
}