package org.apache.cassandra.utils.concurrent;
import java.lang.ref.PhantomReference;
import java.lang.ref.Reference;
import java.lang.ref.ReferenceQueue;
import java.lang.ref.WeakReference;
import java.lang.reflect.Field;
import java.lang.reflect.Modifier;
import java.util.*;
import java.util.concurrent.*;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicIntegerFieldUpdater;
import org.apache.cassandra.concurrent.InfiniteLoopExecutor;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Preconditions;
import org.apache.cassandra.concurrent.NamedThreadFactory;
import org.apache.cassandra.db.ColumnFamilyStore;
import org.apache.cassandra.db.Keyspace;
import org.apache.cassandra.db.lifecycle.View;
import org.apache.cassandra.io.sstable.format.SSTableReader;
import org.apache.cassandra.io.util.Memory;
import org.apache.cassandra.io.util.SafeMemory;
import org.apache.cassandra.utils.ExecutorUtils;
import org.apache.cassandra.utils.NoSpamLogger;
import org.apache.cassandra.utils.Pair;
import org.cliffc.high_scale_lib.NonBlockingHashMap;
import static java.util.Collections.emptyList;
import org.apache.cassandra.concurrent.InfiniteLoopExecutor.InterruptibleRunnable;
import static org.apache.cassandra.utils.ExecutorUtils.awaitTermination;
import static org.apache.cassandra.utils.ExecutorUtils.shutdownNow;
import static org.apache.cassandra.utils.Throwables.maybeFail;
import static org.apache.cassandra.utils.Throwables.merge;
/**
 * A counted reference to a shared {@code referent} whose cleanup is delegated to a {@link Tidy}. The tidy runs once
 * every {@code Ref} sharing the same {@link GlobalState} has been released.
 * <p>
 * Each {@code Ref} owns a {@link State}. The State is a {@link PhantomReference} to the {@code Ref}, registered with
 * {@link #referenceQueue}. If a {@code Ref} is garbage collected without having been released, the
 * "Reference-Reaper" loop ({@link #reapOneReference()}) takes its {@code State} off the queue. It then releases the
 * State and logs a "LEAK DETECTED" error.
 * <p>
 * When {@link #DEBUG_ENABLED} is set, each reference records allocation/deallocation stack traces. Two periodic
 * tasks also run: {@link Visitor}, which looks for strong self-reference loops, and {@link StrongLeakDetector},
 * which reports tidies that stay alive without being accounted for.
 */
public final class Ref<T> implements RefCounted<T>
{
    static final Logger logger = LoggerFactory.getLogger(Ref.class);

    /**
     * {@code true} when the system property {@code cassandra.debugrefcount} equals {@code "true"} (case-insensitive).
     * Turns on per-reference stack traces and the periodic strong-leak checks.
     */
    public static final boolean DEBUG_ENABLED = System.getProperty("cassandra.debugrefcount", "false").equalsIgnoreCase("true");

    /** This reference's own release state; also the phantom reference used for leak detection. */
    final State state;
    /** The guarded object, handed out by {@link #get()}. */
    final T referent;

    /**
     * Creates the first reference to {@code referent}, with a fresh {@link GlobalState}.
     *
     * @param referent the object to guard
     * @param tidy     cleanup to run once all references sharing this global state are released; may be {@code null}
     */
    public Ref(T referent, Tidy tidy)
    {
        this.state = new State(new GlobalState(tidy), this, referenceQueue);
        this.referent = referent;
    }

    /**
     * Creates an additional reference that shares an existing {@link GlobalState} (e.g. from {@link #tryRef()}).
     * The caller is responsible for having already incremented the global count.
     */
    Ref(T referent, GlobalState state)
    {
        this.state = new State(state, this, referenceQueue);
        this.referent = referent;
    }

    /**
     * Releases this reference. If it was the last one sharing its global state, the tidy is run.
     *
     * @throws IllegalStateException if this reference has already been released
     */
    public void release()
    {
        state.release(false);
    }

    /**
     * Releases this reference if it has not been released yet; otherwise does nothing.
     *
     * @param accumulate failures collected so far; may be {@code null}
     * @return {@code accumulate}, merged with any {@link Throwable} raised while tidying
     */
    public Throwable ensureReleased(Throwable accumulate)
    {
        return state.ensureReleased(accumulate);
    }

    /**
     * Releases this reference if it has not been released yet. Any failure raised while tidying is handed to
     * {@code Throwables.maybeFail}.
     */
    public void ensureReleased()
    {
        maybeFail(state.ensureReleased(null));
    }

    /** Same as {@link #ensureReleased()}. */
    public void close()
    {
        ensureReleased();
    }

    /**
     * Returns the referent.
     * With assertions enabled, fails if this reference has already been released.
     */
    public T get()
    {
        state.assertNotReleased();
        return referent;
    }

    /**
     * Tries to take another reference to the same referent.
     *
     * @return a new {@code Ref} sharing this one's global state, or {@code null} if that state has already been
     *         fully released (its count dropped below zero)
     */
    public Ref<T> tryRef()
    {
        return state.globalState.ref() ? new Ref<>(referent, state.globalState) : null;
    }

    /**
     * Like {@link #tryRef()}, but on failure asserts that this reference itself has not been released.
     * It may still return {@code null}, e.g. when assertions are disabled.
     */
    public Ref<T> ref()
    {
        Ref<T> ref = tryRef();
        if (ref == null)
            state.assertNotReleased();
        return ref;
    }

    /**
     * Builds a message about who freed the referent.
     * With {@link #DEBUG_ENABLED}, it also logs this reference's allocation (and, if recorded, deallocation) traces.
     */
    public String printDebugInfo()
    {
        if (DEBUG_ENABLED)
        {
            state.debug.log(state.toString());
            return "Memory was freed by " + state.debug.deallocateThread;
        }
        return "Memory was freed";
    }

    /** Returns the number of unreleased references sharing this reference's global state. */
    public int globalCount()
    {
        return state.globalState.count();
    }

    /**
     * Release state of a single {@link Ref}.
     * <p>
     * Being a {@link PhantomReference} to that {@code Ref}, it is enqueued on the supplied queue once the {@code Ref}
     * becomes phantom reachable. That is how references dropped without {@link Ref#release()} are detected.
     */
    static final class State extends PhantomReference<Ref>
    {
        /** Allocation/deallocation traces; {@code null} unless {@link #DEBUG_ENABLED}. */
        final Debug debug = DEBUG_ENABLED ? new Debug() : null;
        final GlobalState globalState;
        /** 0 while live, 1 once released; only ever flipped atomically via {@link #releasedUpdater}. */
        private volatile int released;
        private static final AtomicIntegerFieldUpdater<State> releasedUpdater = AtomicIntegerFieldUpdater.newUpdater(State.class, "released");

        public State(final GlobalState globalState, Ref reference, ReferenceQueue<? super Ref> q)
        {
            super(reference, q);
            this.globalState = globalState;
            globalState.register(this);
        }

        /** Asserts this state is unreleased, logging debug traces first when debugging is enabled. */
        void assertNotReleased()
        {
            if (DEBUG_ENABLED && released == 1)
                debug.log(toString());
            assert released == 0;
        }

        /**
         * Releases this state unless already released. Repeated calls are harmless.
         *
         * @return {@code accumulate} merged with any failure raised while tidying
         */
        Throwable ensureReleased(Throwable accumulate)
        {
            if (releasedUpdater.getAndSet(this, 1) == 0)
            {
                accumulate = globalState.release(this, accumulate);
                if (DEBUG_ENABLED)
                    debug.deallocate();
            }
            return accumulate;
        }

        /**
         * Releases this state.
         *
         * @param leak {@code true} when called by the reaper for a garbage-collected, unreleased reference; a
         *             "LEAK DETECTED" error is logged and a duplicate release is silently ignored
         * @throws IllegalStateException if {@code leak} is {@code false} and this state was already released
         */
        void release(boolean leak)
        {
            if (!releasedUpdater.compareAndSet(this, 0, 1))
            {
                if (!leak)
                {
                    String id = this.toString();
                    logger.error("BAD RELEASE: attempted to release a reference ({}) that has already been released", id);
                    if (DEBUG_ENABLED)
                        debug.log(id);
                    throw new IllegalStateException("Attempted to release a reference that has already been released");
                }
                return;
            }
            Throwable fail = globalState.release(this, null);
            if (leak)
            {
                String id = this.toString();
                logger.error("LEAK DETECTED: a reference ({}) to {} was not released before the reference was garbage collected", id, globalState);
                if (DEBUG_ENABLED)
                    debug.log(id);
            }
            else if (DEBUG_ENABLED)
            {
                debug.deallocate();
            }
            if (fail != null)
                logger.error("Error when closing {}", globalState, fail);
        }
    }

    /** Thread names and stack traces captured when a reference is allocated and when it is deallocated. */
    static final class Debug
    {
        String allocateThread, deallocateThread;
        StackTraceElement[] allocateTrace, deallocateTrace;

        Debug()
        {
            Thread thread = Thread.currentThread();
            allocateThread = thread.toString();
            allocateTrace = thread.getStackTrace();
        }

        /** Records the current thread and stack as the deallocation site. */
        synchronized void deallocate()
        {
            Thread thread = Thread.currentThread();
            deallocateThread = thread.toString();
            deallocateTrace = thread.getStackTrace();
        }

        /** Logs the allocation trace and, if one was recorded, the deallocation trace, tagged with {@code id}. */
        synchronized void log(String id)
        {
            logger.error("Allocate trace {}:\n{}", id, print(allocateThread, allocateTrace));
            if (deallocateThread != null)
                logger.error("Deallocate trace {}:\n{}", id, print(deallocateThread, deallocateTrace));
        }

        /** Formats a thread name followed by one {@code "\tat <frame>"} line per stack element. */
        String print(String thread, StackTraceElement[] trace)
        {
            StringBuilder sb = new StringBuilder();
            sb.append(thread);
            sb.append("\n");
            for (StackTraceElement element : trace)
            {
                sb.append("\tat ");
                sb.append(element);
                sb.append("\n");
            }
            return sb.toString();
        }
    }

    /**
     * State shared by all {@link Ref}s to the same referent: the reference count and the tidy.
     * The state is registered in {@link #globallyExtant} until the last reference is released.
     */
    static final class GlobalState
    {
        /** States of the not-yet-released Refs that share this global state. */
        private final Collection<State> locallyExtant = new ConcurrentLinkedDeque<>();
        /** Unreleased reference count minus one: 0 for the first Ref, -1 once all are released. */
        private final AtomicInteger counts = new AtomicInteger();
        /** Cleanup run when the count drops to -1; may be {@code null}. */
        private final Tidy tidy;

        GlobalState(Tidy tidy)
        {
            this.tidy = tidy;
            globallyExtant.add(this);
        }

        void register(Ref.State ref)
        {
            locallyExtant.add(ref);
        }

        /**
         * Atomically increments the count unless this state has already been fully released.
         *
         * @return {@code true} if a new reference may be created
         */
        boolean ref()
        {
            while (true)
            {
                int cur = counts.get();
                if (cur < 0)
                    return false;
                if (counts.compareAndSet(cur, cur + 1))
                    return true;
            }
        }

        /**
         * Unregisters {@code ref} and decrements the count. When the last reference goes, it removes this state from
         * {@link #globallyExtant} and runs the tidy. A failure from the tidy is merged into {@code accumulate}
         * rather than thrown.
         */
        Throwable release(Ref.State ref, Throwable accumulate)
        {
            locallyExtant.remove(ref);
            if (-1 == counts.decrementAndGet())
            {
                globallyExtant.remove(this);
                try
                {
                    if (tidy != null)
                        tidy.tidy();
                }
                catch (Throwable t)
                {
                    accumulate = merge(accumulate, t);
                }
            }
            return accumulate;
        }

        /** Returns the number of unreleased references (0 once fully released). */
        int count()
        {
            return 1 + counts.get();
        }

        @Override
        public String toString()
        {
            if (tidy != null)
                return tidy.getClass() + "@" + System.identityHashCode(tidy) + ":" + tidy.name();
            return "@" + System.identityHashCode(this);
        }
    }

    /** Concurrent collection types whose elements (rather than fields) are walked by {@link Visitor}. */
    private static final Class<?>[] concurrentIterableClasses = new Class<?>[]
    {
        ConcurrentLinkedQueue.class,
        ConcurrentLinkedDeque.class,
        ConcurrentSkipListSet.class,
        CopyOnWriteArrayList.class,
        CopyOnWriteArraySet.class,
        DelayQueue.class,
        NonBlockingHashMap.class,
    };
    static final Set<Class<?>> concurrentIterables = Collections.newSetFromMap(new IdentityHashMap<>());
    /** Every {@link GlobalState} that has not yet been fully released. */
    private static final Set<GlobalState> globallyExtant = Collections.newSetFromMap(new ConcurrentHashMap<>());
    /** Queue on which {@link State}s of collected {@link Ref}s are enqueued. */
    static final ReferenceQueue<Object> referenceQueue = new ReferenceQueue<>();
    // NOTE(review): presumably InfiniteLoopExecutor invokes reapOneReference repeatedly on its own thread — confirm.
    private static final InfiniteLoopExecutor EXEC = new InfiniteLoopExecutor("Reference-Reaper", Ref::reapOneReference).start();
    /** Runs the periodic debug checks; {@code null} unless {@link #DEBUG_ENABLED}. */
    static final ScheduledExecutorService STRONG_LEAK_DETECTOR = !DEBUG_ENABLED ? null : Executors.newScheduledThreadPool(1, new NamedThreadFactory("Strong-Reference-Leak-Detector"));

    static
    {
        if (DEBUG_ENABLED)
        {
            STRONG_LEAK_DETECTOR.scheduleAtFixedRate(new Visitor(), 1, 15, TimeUnit.MINUTES);
            STRONG_LEAK_DETECTOR.scheduleAtFixedRate(new StrongLeakDetector(), 2, 15, TimeUnit.MINUTES);
        }
        concurrentIterables.addAll(Arrays.asList(concurrentIterableClasses));
    }

    /** Waits up to 100ms for a collected reference and, if one arrives, releases it as a leak. */
    private static void reapOneReference() throws InterruptedException
    {
        Object obj = referenceQueue.remove(100);
        if (obj instanceof Ref.State)
        {
            ((Ref.State) obj).release(true);
        }
    }

    /** Pool of reusable {@link InProgressVisit}s. Backed by a non-thread-safe ArrayDeque, so use from one thread only. */
    static final Deque<InProgressVisit> inProgressVisitPool = new ArrayDeque<InProgressVisit>();

    /**
     * Returns a (possibly pooled) visit over {@code o}.
     * Arrays, {@link ConcurrentMap}s (keys then values) and the {@link #concurrentIterables} / {@link BlockingQueue}s
     * are walked element by element instead of field by field.
     */
    @SuppressWarnings({ "rawtypes", "unchecked" })
    static InProgressVisit newInProgressVisit(Object o, List<Field> fields, Field field, String name)
    {
        Preconditions.checkNotNull(o);
        InProgressVisit ipv = inProgressVisitPool.pollLast();
        if (ipv == null)
            ipv = new InProgressVisit();

        ipv.o = o;
        if (o instanceof Object[])
            ipv.collectionIterator = Arrays.asList((Object[])o).iterator();
        else if (o instanceof ConcurrentMap)
        {
            ipv.isMapIterator = true;
            ipv.collectionIterator = ((Map)o).entrySet().iterator();
        }
        else if (concurrentIterables.contains(o.getClass()) || o instanceof BlockingQueue)
            ipv.collectionIterator = ((Iterable)o).iterator();

        ipv.fields = fields;
        ipv.field = field;
        ipv.name = name;
        return ipv;
    }

    /** Clears {@code ipv} and returns it to the pool, unless the pool already holds more than 1024 entries. */
    static void returnInProgressVisit(InProgressVisit ipv)
    {
        if (inProgressVisitPool.size() > 1024)
            return;
        ipv.name = null;
        ipv.fields = null;
        ipv.o = null;
        ipv.fieldIndex = 0;
        ipv.field = null;
        ipv.collectionIterator = null;
        ipv.mapEntryValue = null;
        ipv.isMapIterator = false;
        inProgressVisitPool.offer(ipv);
    }

    /** Cursor over the children (field values or collection elements) of one object in the {@link Visitor} walk. */
    @SuppressWarnings({ "rawtypes"})
    static class InProgressVisit
    {
        String name;
        List<Field> fields;
        Object o;
        int fieldIndex = 0;
        Field field;
        boolean isMapIterator;
        Iterator<Object> collectionIterator;
        /** Value of the map entry whose key was just returned; handed out on the next call. */
        Object mapEntryValue;

        private Field nextField()
        {
            if (fields.isEmpty())
                return null;

            if (fieldIndex >= fields.size())
                return null;

            Field retval = fields.get(fieldIndex);
            fieldIndex++;
            return retval;
        }

        /**
         * Returns the next non-null child paired with the field it was reached through, or {@code null} when exhausted.
         *
         * @throws IllegalAccessException if a field value cannot be read reflectively
         */
        Pair<Object, Field> nextChild() throws IllegalAccessException
        {
            // A map's value is returned right after its key.
            if (mapEntryValue != null)
            {
                Pair<Object, Field> retval = Pair.create(mapEntryValue, field);
                mapEntryValue = null;
                return retval;
            }

            if (collectionIterator != null)
            {
                if (!collectionIterator.hasNext())
                    return null;
                Object nextItem = null;
                while (collectionIterator.hasNext() && (nextItem = collectionIterator.next()) == null)
                {
                    // skip null elements
                }
                if (nextItem != null)
                {
                    if (isMapIterator && nextItem instanceof Map.Entry)
                    {
                        Map.Entry entry = (Map.Entry)nextItem;
                        mapEntryValue = entry.getValue();
                        return Pair.create(entry.getKey(), field);
                    }
                    return Pair.create(nextItem, field);
                }
                else
                {
                    return null;
                }
            }

            while (true)
            {
                Field nextField = nextField();
                if (nextField == null)
                    return null;

                // The referent of a WeakReference is not strongly reachable, so skip fields declared by Reference
                // itself. Fields added by subclasses can hold strong references and are still walked.
                if (o instanceof WeakReference && nextField.getDeclaringClass() == Reference.class)
                    continue;

                // Read the field exactly once. A second read could observe a concurrent write (e.g. to null),
                // which the caller would misinterpret as "no more children" for this object.
                Object nextObject = nextField.get(o);
                if (nextObject != null)
                    return Pair.create(nextObject, nextField);
            }
        }

        @Override
        public String toString()
        {
            return field == null ? name : field.toString() + "-" + o.getClass().getName();
        }
    }

    /**
     * Periodic debug task that walks the object graph reachable from each live {@link Tidy}.
     * It logs a "Strong self-ref loop" error if the walk reaches the {@link GlobalState} that owns that tidy.
     */
    static class Visitor implements Runnable
    {
        final Deque<InProgressVisit> path = new ArrayDeque<>();
        final Set<Object> visited = Collections.newSetFromMap(new IdentityHashMap<>());
        @VisibleForTesting
        int lastVisitedCount;
        @VisibleForTesting
        long iterations = 0;
        GlobalState visiting;
        /** When non-null, every GlobalState found to be part of a loop is added here. */
        Set<GlobalState> haveLoops;

        @Override
        public void run()
        {
            try
            {
                for (GlobalState globalState : globallyExtant)
                {
                    if (globalState.tidy == null)
                        continue;

                    path.clear();
                    visited.clear();
                    lastVisitedCount = 0;
                    iterations = 0;
                    visited.add(globalState);
                    visiting = globalState;
                    traverse(globalState.tidy);
                }
            }
            catch (Throwable t)
            {
                logger.error("Unexpected error while checking for strong self-referential leaks", t);
            }
            finally
            {
                lastVisitedCount = visited.size();
                path.clear();
                visited.clear();
            }
        }

        /** Depth-first, iterative walk from {@code rootObject}; {@link #path} holds the suspended parent visits. */
        void traverse(final RefCounted.Tidy rootObject)
        {
            path.offer(newInProgressVisit(rootObject, getFields(rootObject.getClass()), null, rootObject.name()));

            InProgressVisit inProgress = null;
            while (inProgress != null || !path.isEmpty())
            {
                if (inProgress == null)
                    inProgress = path.pollLast();

                try
                {
                    Pair<Object, Field> p = inProgress.nextChild();
                    Object child = null;
                    Field field = null;

                    if (p != null)
                    {
                        iterations++;
                        child = p.left;
                        field = p.right;
                    }

                    if (child != null && visited.add(child))
                    {
                        // Unvisited child: suspend the current visit and descend.
                        path.offer(inProgress);
                        inProgress = newInProgressVisit(child, getFields(child.getClass()), field, null);
                        continue;
                    }
                    else if (visiting == child)
                    {
                        if (haveLoops != null)
                            haveLoops.add(visiting);
                        NoSpamLogger.log(logger,
                                NoSpamLogger.Level.ERROR,
                                rootObject.getClass().getName(),
                                1,
                                TimeUnit.SECONDS,
                                "Strong self-ref loop detected {}",
                                path);
                    }
                    else if (child == null)
                    {
                        // This object has no more children: recycle it and resume its parent.
                        returnInProgressVisit(inProgress);
                        inProgress = null;
                        continue;
                    }
                }
                catch (IllegalAccessException e)
                {
                    NoSpamLogger.log(logger, NoSpamLogger.Level.ERROR, 5, TimeUnit.MINUTES, "Could not fully check for self-referential leaks", e);
                }
            }
        }
    }

    /**
     * Cache of the traversable fields per class. It is a plain HashMap, so it is not thread-safe.
     * NOTE(review): appears to be used only from the single leak-detector thread (and tests) — confirm.
     */
    static final Map<Class<?>, List<Field>> fieldMap = new HashMap<>();

    /**
     * Returns the non-static, non-primitive fields of {@code clazz} and its superclasses, made accessible.
     * Returns an empty list for {@code null}, {@link PhantomReference}, {@link Class} and {@link java.lang.reflect.Member}
     * types. Results are cached in {@link #fieldMap}.
     */
    static List<Field> getFields(Class<?> clazz)
    {
        if (clazz == null || clazz == PhantomReference.class || clazz == Class.class || java.lang.reflect.Member.class.isAssignableFrom(clazz))
            return emptyList();
        List<Field> fields = fieldMap.get(clazz);
        if (fields != null)
            return fields;
        fieldMap.put(clazz, fields = new ArrayList<>());
        for (Field field : clazz.getDeclaredFields())
        {
            if (field.getType().isPrimitive() || Modifier.isStatic(field.getModifiers()))
                continue;
            field.setAccessible(true);
            fields.add(field);
        }
        fields.addAll(getFields(clazz.getSuperclass()));
        return fields;
    }

    /**
     * Removes from a candidate set the {@link Tidy} behind each reference that is expected to be alive.
     * Each {@code add} overload unwraps its argument down to a {@link Ref}.
     */
    public static class IdentityCollection
    {
        final Set<Tidy> candidates;

        public IdentityCollection(Set<Tidy> candidates)
        {
            this.candidates = candidates;
        }

        public void add(Ref<?> ref)
        {
            candidates.remove(ref.state.globalState.tidy);
        }

        public void add(SelfRefCounted<?> ref)
        {
            add(ref.selfRef());
        }

        public void add(SharedCloseable ref)
        {
            if (ref instanceof SharedCloseableImpl)
                add((SharedCloseableImpl)ref);
        }

        public void add(SharedCloseableImpl ref)
        {
            add(ref.ref);
        }

        public void add(Memory memory)
        {
            if (memory instanceof SafeMemory)
                ((SafeMemory) memory).addTo(this);
        }
    }

    /**
     * Periodic debug task that warns about tidies that were live across two consecutive runs and were not
     * accounted for by any known sstable.
     */
    private static class StrongLeakDetector implements Runnable
    {
        /** Unexpected tidies seen on the previous run; identity-based, like the sets built in {@link #run()}. */
        Set<Tidy> candidates = Collections.newSetFromMap(new IdentityHashMap<>());

        @Override
        public void run()
        {
            // An uncaught exception would silently cancel every later execution of this fixed-rate task.
            try
            {
                final Set<Tidy> current = Collections.newSetFromMap(new IdentityHashMap<>());
                for (GlobalState state : globallyExtant)
                {
                    // A null tidy has nothing to report (and no name), so skip it, as Visitor does.
                    if (state.tidy != null)
                        current.add(state.tidy);
                }
                removeExpected(current);
                this.candidates.retainAll(current);
                if (!this.candidates.isEmpty())
                {
                    List<String> names = new ArrayList<>();
                    for (Tidy tidy : this.candidates)
                        names.add(tidy.name());
                    logger.warn("Strong reference leak candidates detected: {}", names);
                }
                this.candidates = current;
            }
            catch (Throwable t)
            {
                logger.error("Unexpected error while checking for strong reference leak candidates", t);
            }
        }

        /** Drops from {@code candidates} every tidy reachable from the sstables known to any table's tracker view. */
        private void removeExpected(Set<Tidy> candidates)
        {
            final Ref.IdentityCollection expected = new Ref.IdentityCollection(candidates);
            for (Keyspace ks : Keyspace.all())
            {
                for (ColumnFamilyStore cfs : ks.getColumnFamilyStores())
                {
                    View view = cfs.getTracker().getView();
                    for (SSTableReader reader : view.allKnownSSTables())
                        reader.addTo(expected);
                }
            }
        }
    }

    /** Shuts down the reaper and (if present) the debug leak detector, waiting up to {@code timeout}. */
    @VisibleForTesting
    public static void shutdownReferenceReaper(long timeout, TimeUnit unit) throws InterruptedException, TimeoutException
    {
        ExecutorUtils.shutdownNowAndWait(timeout, unit, EXEC, STRONG_LEAK_DETECTOR);
    }
}