package org.apache.cassandra.repair;
import java.net.InetAddress;
import java.security.MessageDigest;
import java.util.ArrayList;
import java.util.List;
import java.util.Random;
import com.google.common.annotations.VisibleForTesting;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.apache.cassandra.concurrent.Stage;
import org.apache.cassandra.concurrent.StageManager;
import org.apache.cassandra.db.ColumnFamilyStore;
import org.apache.cassandra.db.DecoratedKey;
import org.apache.cassandra.db.rows.UnfilteredRowIterator;
import org.apache.cassandra.db.rows.UnfilteredRowIterators;
import org.apache.cassandra.dht.Range;
import org.apache.cassandra.dht.Token;
import org.apache.cassandra.net.MessagingService;
import org.apache.cassandra.repair.messages.ValidationComplete;
import org.apache.cassandra.tracing.Tracing;
import org.apache.cassandra.utils.FBUtilities;
import org.apache.cassandra.utils.MerkleTree;
import org.apache.cassandra.utils.MerkleTree.RowHash;
import org.apache.cassandra.utils.MerkleTrees;
/**
 * Builds the {@link MerkleTrees} for a single repair job and sends them to the
 * node that asked for the validation.
 *
 * <p>Usage, as implied by the assertions in this class: call
 * {@link #prepare(ColumnFamilyStore, MerkleTrees)} once, feed partitions in
 * strictly increasing key order through {@link #add(UnfilteredRowIterator)},
 * then call {@link #complete()}. Call {@link #fail()} instead of
 * {@code complete()} when the trees could not be built.
 */
public class Validator implements Runnable
{
    private static final Logger logger = LoggerFactory.getLogger(Validator.class);

    /** Description of the repair job being validated; its {@code ranges} bound the accepted tokens. */
    public final RepairJobDesc desc;
    /** Address that receives the {@link ValidationComplete} message. */
    public final InetAddress initiator;
    /** Stored for callers; this class never reads it. */
    public final int gcBefore;
    /** When true, {@link #prepare} always uses {@code tree.init()} instead of sampling keys. */
    private final boolean evenTreeDistribution;

    /** Number of partitions handed to {@link #rowHash} so far. */
    private long validated;
    /** Trees being filled; set by {@link #prepare}. */
    private MerkleTrees trees;
    /** Tree range currently receiving hashes; {@code null} until the first {@link #add}. */
    private MerkleTree.TreeRange range;
    /** Iterator over the tree ranges; {@code null} until {@link #prepare} has run. */
    private MerkleTrees.TreeRangeIterator ranges;
    /** Key of the most recently added partition, used to assert ordering. */
    private DecoratedKey lastKey;

    /**
     * Creates a validator with {@code evenTreeDistribution} disabled.
     *
     * @param desc      repair job description
     * @param initiator address the result is sent to
     * @param gcBefore  stored in {@link #gcBefore}
     */
    public Validator(RepairJobDesc desc, InetAddress initiator, int gcBefore)
    {
        this(desc, initiator, gcBefore, false);
    }

    /**
     * @param desc                 repair job description
     * @param initiator            address the result is sent to
     * @param gcBefore             stored in {@link #gcBefore}
     * @param evenTreeDistribution if true, {@link #prepare} skips key sampling
     */
    public Validator(RepairJobDesc desc, InetAddress initiator, int gcBefore, boolean evenTreeDistribution)
    {
        this.desc = desc;
        this.initiator = initiator;
        this.gcBefore = gcBefore;
        this.evenTreeDistribution = evenTreeDistribution;
        // validated, range and ranges keep their default values (0 / null) until prepare()/add().
    }

    /**
     * Stores {@code tree} as the trees to fill and lays out its ranges.
     *
     * <p>If the partitioner does not preserve order, or even distribution was
     * requested, {@code tree.init()} is used. Otherwise each range of the tree
     * is split on randomly chosen key samples from {@code cfs} until
     * {@code tree.split} reports that no further split was made; a range
     * without any samples falls back to {@code tree.init(range)}.
     *
     * @param cfs  store providing key samples
     * @param tree trees to initialise and subsequently fill
     */
    public void prepare(ColumnFamilyStore cfs, MerkleTrees tree)
    {
        this.trees = tree;

        if (!tree.partitioner().preservesOrder() || evenTreeDistribution)
        {
            tree.init();
        }
        else
        {
            // Reused across ranges; it is always empty at the start of an iteration.
            List<DecoratedKey> keys = new ArrayList<>();
            Random random = new Random();

            // Named sampledRange so it does not shadow the TreeRange field "range".
            for (Range<Token> sampledRange : tree.ranges())
            {
                for (DecoratedKey sample : cfs.keySamples(sampledRange))
                {
                    assert sampledRange.contains(sample.getToken()) : "Token " + sample.getToken() + " is not within range " + sampledRange;
                    keys.add(sample);
                }

                if (keys.isEmpty())
                {
                    // Nothing to sample from: let the tree lay out this range itself.
                    tree.init(sampledRange);
                }
                else
                {
                    int numKeys = keys.size();
                    // Keep splitting on random samples until the tree refuses a split.
                    while (true)
                    {
                        DecoratedKey dk = keys.get(random.nextInt(numKeys));
                        if (!tree.split(dk.getToken()))
                            break;
                    }
                    keys.clear();
                }
            }
        }
        logger.debug("Prepared AEService trees of size {} for {}", trees.size(), desc);
        ranges = tree.invalids();
    }

    /**
     * Hashes {@code partition} and adds the hash to the tree range containing its token.
     *
     * <p>With assertions enabled this checks that the token lies inside
     * {@code desc.ranges} and that partitions arrive in strictly increasing
     * key order. A partition for which {@link #rowHash} returns {@code null}
     * contributes nothing to the tree.
     *
     * @param partition partition to add
     */
    public void add(UnfilteredRowIterator partition)
    {
        assert Range.isInRanges(partition.partitionKey().getToken(), desc.ranges) : partition.partitionKey().getToken() + " is not contained in " + desc.ranges;
        assert lastKey == null || lastKey.compareTo(partition.partitionKey()) < 0
               : "partition " + partition.partitionKey() + " received out of order wrt " + lastKey;
        lastKey = partition.partitionKey();

        // First partition: start from the first tree range.
        if (range == null)
            range = ranges.next();

        if (!findCorrectRange(lastKey.getToken()))
        {
            // The current iterator ran out without a match: restart from a fresh
            // iterator over the trees and search once more.
            ranges = trees.invalids();
            findCorrectRange(lastKey.getToken());
        }

        assert range.contains(lastKey.getToken()) : "Token not in MerkleTree: " + lastKey.getToken();

        RowHash rowHash = rowHash(partition);
        if (rowHash != null)
        {
            range.addHash(rowHash);
        }
    }

    /**
     * Advances the current range until it contains {@code t} or the range
     * iterator is exhausted. Requires the current range to be set, which
     * {@link #add} does before calling this.
     *
     * @param t token to locate
     * @return true if the current range contains {@code t} on return
     */
    public boolean findCorrectRange(Token t)
    {
        while (!range.contains(t) && ranges.hasNext())
        {
            range = ranges.next();
        }

        return range.contains(t);
    }

    /**
     * {@link MessageDigest} that forwards every update to a wrapped digest and
     * counts the number of bytes it was given.
     */
    static class CountingDigest extends MessageDigest
    {
        /** Bytes passed to this digest since creation or the last reset. */
        private long count;
        private final MessageDigest underlying;

        /**
         * @param underlying digest that performs the actual hashing
         */
        public CountingDigest(MessageDigest underlying)
        {
            super(underlying.getAlgorithm());
            this.underlying = underlying;
        }

        @Override
        protected void engineUpdate(byte input)
        {
            underlying.update(input);
            count += 1;
        }

        @Override
        protected void engineUpdate(byte[] input, int offset, int len)
        {
            underlying.update(input, offset, len);
            count += len;
        }

        /**
         * Returns the wrapped digest's result. The byte count is intentionally
         * left untouched here: {@code rowHash} reads {@code count} after it has
         * already called {@code digest()}.
         */
        @Override
        protected byte[] engineDigest()
        {
            return underlying.digest();
        }

        /**
         * Resets the wrapped digest and the byte count together, so the count
         * always describes the bytes that are actually part of the hash.
         */
        @Override
        protected void engineReset()
        {
            underlying.reset();
            count = 0;
        }
    }

    /**
     * Computes the SHA-256 based hash of a partition and counts the partition as validated.
     *
     * @param partition partition to hash
     * @return the hash with the partition's token and digested byte count, or
     *         {@code null} if no bytes were fed to the digest
     */
    private MerkleTree.RowHash rowHash(UnfilteredRowIterator partition)
    {
        validated++;
        CountingDigest digest = new CountingDigest(FBUtilities.newMessageDigest("SHA-256"));
        UnfilteredRowIterators.digest(null, partition, digest, MessagingService.current_version);
        // A digest that received no bytes yields no hash for the tree.
        return digest.count > 0
             ? new MerkleTree.RowHash(partition.partitionKey().getToken(), digest.digest(), digest.count)
             : null;
    }

    /**
     * Finishes the trees via {@link #completeTree()}, submits this validator to
     * the {@code ANTI_ENTROPY} stage (where {@link #run()} sends the result),
     * and, at debug level, logs the per-leaf partition counts and sizes.
     */
    public void complete()
    {
        completeTree();

        StageManager.getStage(Stage.ANTI_ENTROPY).execute(this);

        if (logger.isDebugEnabled())
        {
            logger.debug("Validated {} partitions for {}. Partitions per leaf are:", validated, desc.sessionId);
            trees.logRowCountPerLeaf(logger);
            logger.debug("Validated {} partitions for {}. Partition sizes are:", validated, desc.sessionId);
            trees.logRowSizePerLeaf(logger);
        }
    }

    /**
     * Walks every range returned by {@code trees.invalids()} and calls
     * {@code ensureHashInitialised()} on it.
     * Asserts that {@link #prepare} has been called.
     */
    @VisibleForTesting
    public void completeTree()
    {
        assert ranges != null : "Validator was not prepared()";

        ranges = trees.invalids();

        while (ranges.hasNext())
        {
            range = ranges.next();
            range.ensureHashInitialised();
        }
    }

    /**
     * Logs an error and sends the initiator a {@link ValidationComplete}
     * that carries no trees.
     */
    public void fail()
    {
        logger.error("Failed creating a merkle tree for {}, {} (see log for details)", desc, initiator);
        MessagingService.instance().sendOneWay(new ValidationComplete(desc).createMessage(), initiator);
    }

    /**
     * Sends the trees to the initiator in a {@link ValidationComplete} message.
     * The send is logged and traced only when the initiator is not this node's
     * broadcast address.
     */
    @Override
    public void run()
    {
        if (!initiator.equals(FBUtilities.getBroadcastAddress()))
        {
            logger.info("[repair #{}] Sending completed merkle tree to {} for {}.{}", desc.sessionId, initiator, desc.keyspace, desc.columnFamily);
            Tracing.traceRepair("Sending completed merkle tree to {} for {}.{}", initiator, desc.keyspace, desc.columnFamily);
        }
        MessagingService.instance().sendOneWay(new ValidationComplete(desc, trees).createMessage(), initiator);
    }
}