package org.apache.cassandra.hadoop.cql3;
import com.datastax.driver.core.Cluster;
import com.datastax.driver.core.Host;
import com.datastax.driver.core.HostDistance;
import com.datastax.driver.core.Statement;
import com.datastax.driver.core.policies.LoadBalancingPolicy;
import com.google.common.base.Function;
import com.google.common.collect.Iterators;
import com.google.common.collect.Sets;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.net.InetAddress;
import java.net.NetworkInterface;
import java.net.SocketException;
import java.net.UnknownHostException;
import java.util.*;
import java.util.concurrent.CopyOnWriteArraySet;
class LimitedLocalNodeFirstLocalBalancingPolicy implements LoadBalancingPolicy
{
private final static Logger logger = LoggerFactory.getLogger(LimitedLocalNodeFirstLocalBalancingPolicy.class);
private final static Set<InetAddress> localAddresses = Collections.unmodifiableSet(getLocalInetAddresses());
private final CopyOnWriteArraySet<Host> liveReplicaHosts = new CopyOnWriteArraySet<>();
private final Set<InetAddress> replicaAddresses = new HashSet<>();
public LimitedLocalNodeFirstLocalBalancingPolicy(String[] replicas)
{
for (String replica : replicas)
{
try
{
InetAddress[] addresses = InetAddress.getAllByName(replica);
Collections.addAll(replicaAddresses, addresses);
}
catch (UnknownHostException e)
{
logger.warn("Invalid replica host name: {}, skipping it", replica);
}
}
logger.trace("Created instance with the following replicas: {}", Arrays.asList(replicas));
}
@Override
public void init(Cluster cluster, Collection<Host> hosts)
{
List<Host> replicaHosts = new ArrayList<>();
for (Host host : hosts)
{
if (replicaAddresses.contains(host.getAddress()))
{
replicaHosts.add(host);
}
}
liveReplicaHosts.addAll(replicaHosts);
logger.trace("Initialized with replica hosts: {}", replicaHosts);
}
@Override
public void close()
{
}
@Override
public HostDistance distance(Host host)
{
if (isLocalHost(host))
{
return HostDistance.LOCAL;
}
else
{
return HostDistance.REMOTE;
}
}
@Override
public Iterator<Host> newQueryPlan(String keyspace, Statement statement)
{
List<Host> local = new ArrayList<>(1);
List<Host> remote = new ArrayList<>(liveReplicaHosts.size());
for (Host liveReplicaHost : liveReplicaHosts)
{
if (isLocalHost(liveReplicaHost))
{
local.add(liveReplicaHost);
}
else
{
remote.add(liveReplicaHost);
}
}
Collections.shuffle(remote);
logger.trace("Using the following hosts order for the new query plan: {} | {}", local, remote);
return Iterators.concat(local.iterator(), remote.iterator());
}
@Override
public void onAdd(Host host)
{
if (replicaAddresses.contains(host.getAddress()))
{
liveReplicaHosts.add(host);
logger.trace("Added a new host {}", host);
}
}
@Override
public void onUp(Host host)
{
if (replicaAddresses.contains(host.getAddress()))
{
liveReplicaHosts.add(host);
logger.trace("The host {} is now up", host);
}
}
@Override
public void onDown(Host host)
{
if (liveReplicaHosts.remove(host))
{
logger.trace("The host {} is now down", host);
}
}
@Override
public void onRemove(Host host)
{
if (liveReplicaHosts.remove(host))
{
logger.trace("Removed the host {}", host);
}
}
public void onSuspected(Host host)
{
}
private static boolean isLocalHost(Host host)
{
InetAddress hostAddress = host.getAddress();
return hostAddress.isLoopbackAddress() || localAddresses.contains(hostAddress);
}
private static Set<InetAddress> getLocalInetAddresses()
{
try
{
return Sets.newHashSet(Iterators.concat(
Iterators.transform(
Iterators.forEnumeration(NetworkInterface.getNetworkInterfaces()),
new Function<NetworkInterface, Iterator<InetAddress>>()
{
@Override
public Iterator<InetAddress> apply(NetworkInterface netIface)
{
return Iterators.forEnumeration(netIface.getInetAddresses());
}
})));
}
catch (SocketException e)
{
logger.warn("Could not retrieve local network interfaces.", e);
return Collections.emptySet();
}
}
}