package com.datastax.oss.driver.internal.core.util;
import com.datastax.oss.driver.shaded.guava.common.annotations.VisibleForTesting;
import com.datastax.oss.driver.shaded.guava.common.base.Preconditions;
import com.datastax.oss.driver.shaded.guava.common.collect.LinkedHashMultimap;
import com.datastax.oss.driver.shaded.guava.common.collect.Lists;
import com.datastax.oss.driver.shaded.guava.common.collect.Maps;
import com.datastax.oss.driver.shaded.guava.common.collect.Multimap;
import java.util.ArrayDeque;
import java.util.Arrays;
import java.util.Collection;
import java.util.List;
import java.util.Map;
import java.util.Queue;
import net.jcip.annotations.NotThreadSafe;
/**
 * A minimal directed graph over a fixed set of vertices, whose only operation besides edge
 * insertion is a one-shot topological sort.
 *
 * <p>The vertex set is fixed at construction time; edges are then added with {@link
 * #addEdge(Object, Object)}. Vertices and edges are kept in insertion order, so the result of
 * {@link #topologicalSort()} is deterministic for a given sequence of insertions.
 *
 * <p>The sort consumes the internal in-degree counters, which is why it can be invoked only once
 * per instance.
 *
 * @param <VertexT> the type of the vertices; instances are used as map keys.
 */
@NotThreadSafe
public class DirectedGraph<VertexT> {

  // Maps each vertex to its in-degree, i.e. the number of distinct incoming edges that have not
  // yet been consumed by the sort.
  private final Map<VertexT, Integer> vertices;
  // Maps each vertex to its distinct successors (set semantics: a given edge is stored once).
  private final Multimap<VertexT, VertexT> adjacencyList;
  private boolean wasSorted;

  /**
   * Creates a graph containing the given vertices and no edges.
   *
   * @param vertices the vertices of the graph; duplicates are collapsed into a single vertex.
   */
  public DirectedGraph(Collection<VertexT> vertices) {
    this.vertices = Maps.newLinkedHashMapWithExpectedSize(vertices.size());
    this.adjacencyList = LinkedHashMultimap.create();
    for (VertexT vertex : vertices) {
      this.vertices.put(vertex, 0);
    }
  }

  /** Varargs convenience variant of {@link #DirectedGraph(Collection)}. */
  @VisibleForTesting
  @SafeVarargs
  DirectedGraph(VertexT... vertices) {
    this(Arrays.asList(vertices));
  }

  /**
   * Adds a directed edge between two vertices of this graph.
   *
   * <p>Adding an edge that is already present is a no-op.
   *
   * @param from the vertex the edge starts from.
   * @param to the vertex the edge points to.
   * @throws IllegalArgumentException if either vertex was not passed to the constructor.
   */
  public void addEdge(VertexT from, VertexT to) {
    Preconditions.checkArgument(vertices.containsKey(from), "unknown vertex: %s", from);
    Preconditions.checkArgument(vertices.containsKey(to), "unknown vertex: %s", to);
    // The adjacency list has set semantics, so put() returns false for an edge that is already
    // present. Only count the edge when it was actually stored: otherwise the in-degree of `to`
    // would exceed its number of stored incoming edges, it could never drop back to zero, and the
    // sort would report a cycle that does not exist.
    if (adjacencyList.put(from, to)) {
      vertices.put(to, vertices.get(to) + 1);
    }
  }

  /**
   * Returns the vertices ordered so that every vertex appears before all of its successors.
   *
   * <p>This uses in-degree counting (Kahn's algorithm): vertices with no remaining incoming edge
   * are emitted one by one, and each emission releases the outgoing edges of that vertex.
   *
   * <p>This method can be called only once per instance, because it consumes the in-degree
   * counters; the instance is marked as sorted even if the call fails because of a cycle.
   *
   * @return the sorted vertices, as a new list.
   * @throws IllegalStateException if this method was already called on this instance.
   * @throws IllegalArgumentException if the graph contains a cycle.
   */
  public List<VertexT> topologicalSort() {
    Preconditions.checkState(!wasSorted, "topologicalSort() can only be called once");
    wasSorted = true;

    // Start from the vertices that have no incoming edge.
    Queue<VertexT> queue = new ArrayDeque<>();
    for (Map.Entry<VertexT, Integer> entry : vertices.entrySet()) {
      if (entry.getValue() == 0) {
        queue.add(entry.getKey());
      }
    }

    List<VertexT> result = Lists.newArrayList();
    while (!queue.isEmpty()) {
      VertexT vertex = queue.remove();
      result.add(vertex);
      // Emitting a vertex removes its outgoing edges; any successor left without incoming edges
      // becomes eligible.
      for (VertexT successor : adjacencyList.get(vertex)) {
        if (decrementAndGetCount(successor) == 0) {
          queue.add(successor);
        }
      }
    }

    // Vertices that are part of (or only reachable through) a cycle never reach an in-degree of
    // zero, so they are missing from the result.
    if (result.size() != vertices.size()) {
      throw new IllegalArgumentException("failed to perform topological sort, graph has a cycle");
    }
    return result;
  }

  /** Decrements the in-degree of the given vertex and returns the new value. */
  private int decrementAndGetCount(VertexT vertex) {
    int count = vertices.get(vertex) - 1;
    vertices.put(vertex, count);
    return count;
  }
}