package com.oracle.svm.hosted.thread;
import static com.oracle.svm.core.util.VMError.shouldNotReachHere;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.function.Function;
import org.graalvm.compiler.core.common.NumUtil;
import org.graalvm.compiler.nodes.PiNode;
import org.graalvm.compiler.nodes.ValueNode;
import org.graalvm.compiler.nodes.graphbuilderconf.GraphBuilderContext;
import org.graalvm.nativeimage.hosted.Feature;
import com.oracle.svm.core.config.ConfigurationValues;
import com.oracle.svm.core.meta.ReadableJavaField;
import com.oracle.svm.core.meta.SharedField;
import com.oracle.svm.core.meta.SubstrateObjectConstant;
import com.oracle.svm.core.threadlocal.FastThreadLocal;
import com.oracle.svm.core.threadlocal.VMThreadLocalInfo;
import com.oracle.svm.hosted.FeatureImpl.CompilationAccessImpl;
import jdk.vm.ci.meta.JavaKind;
import jdk.vm.ci.meta.ResolvedJavaField;
/**
 * Collects the {@link FastThreadLocal} instances passed to {@link #apply}. After analysis it gives
 * each one a name, a size in bytes and a position in a deterministic order.
 * <p>
 * As a {@link Function}, the collector returns its argument unchanged. Its only effect is recording
 * every {@link FastThreadLocal} it sees. NOTE(review): presumably registered as an object replacer
 * during image building — confirm at the registration site.
 */
class VMThreadLocalCollector implements Function<Object, Object> {

    /** All discovered thread locals, keyed by the {@link FastThreadLocal} instance. */
    final Map<FastThreadLocal, VMThreadLocalInfo> threadLocals = new ConcurrentHashMap<>();

    /**
     * Set to {@code true} by {@link #sortThreadLocals}. After that, {@link #apply} no longer records
     * new thread locals; it only asserts that they are already known.
     */
    private boolean sealed;

    /**
     * Records {@code source} if it is a {@link FastThreadLocal}.
     *
     * @param source any object
     * @return {@code source}, unchanged
     */
    @Override
    public Object apply(Object source) {
        if (source instanceof FastThreadLocal) {
            FastThreadLocal threadLocal = (FastThreadLocal) source;
            if (sealed) {
                assert threadLocals.containsKey(threadLocal) : "VMThreadLocal must have been discovered during static analysis";
            } else {
                // computeIfAbsent builds the info only for a new key. putIfAbsent allocated one on
                // every call and threw it away when the key was already present.
                threadLocals.computeIfAbsent(threadLocal, key -> new VMThreadLocalInfo(key));
            }
        }
        return source;
    }

    /**
     * Returns the info recorded for {@code threadLocal}.
     * <p>
     * A missing entry is caught only by an assertion. With assertions disabled, an unknown thread
     * local yields {@code null}.
     *
     * @param threadLocal a thread local previously passed to {@link #apply}
     * @return the recorded info
     */
    public VMThreadLocalInfo getInfo(FastThreadLocal threadLocal) {
        VMThreadLocalInfo result = threadLocals.get(threadLocal);
        assert result != null : "VMThreadLocal was not discovered during static analysis";
        return result;
    }

    /**
     * Resolves the thread local that a graph node refers to and returns its recorded info.
     *
     * @param b the graph builder context, used only for the error message
     * @param threadLocalNode node that must be a constant wrapping a {@link FastThreadLocal}
     * @return the recorded info, as returned by {@link #getInfo}
     * @throws RuntimeException via {@code shouldNotReachHere} if the node is not a constant
     */
    public VMThreadLocalInfo findInfo(GraphBuilderContext b, ValueNode threadLocalNode) {
        if (!threadLocalNode.isConstant()) {
            throw shouldNotReachHere("Accessed VMThreadLocal is not a compile time constant: " + b.getMethod().asStackTraceElement(b.bci()) + " - node " + unPi(threadLocalNode));
        }
        FastThreadLocal threadLocal = (FastThreadLocal) SubstrateObjectConstant.asObject(threadLocalNode.asConstant());
        return getInfo(threadLocal);
    }

    /**
     * Seals the collector and completes every recorded {@link VMThreadLocalInfo}.
     * <p>
     * It assigns each info a name and a size, then returns all infos in the order defined by
     * {@link #compareThreadLocal}.
     *
     * @param a the compilation access; must be a {@link CompilationAccessImpl}
     * @return a new, mutable, sorted list of all recorded thread locals
     * @throws RuntimeException via {@code shouldNotReachHere} if any of these holds:
     *             a thread local is held by a non-final field;
     *             a thread local is held by two static final fields;
     *             a thread local was never recorded;
     *             a thread local is not held by any static final field.
     */
    public List<VMThreadLocalInfo> sortThreadLocals(Feature.CompilationAccess a) {
        CompilationAccessImpl config = (CompilationAccessImpl) a;
        // From here on apply() must not record new thread locals.
        sealed = true;

        assignNames(config);
        for (VMThreadLocalInfo info : threadLocals.values()) {
            if (info.name == null) {
                throw shouldNotReachHere("VMThreadLocal found that is not referenced from a static final field");
            }
            computeSize(info);
        }

        List<VMThreadLocalInfo> sortedThreadLocals = new ArrayList<>(threadLocals.values());
        sortedThreadLocals.sort(VMThreadLocalCollector::compareThreadLocal);
        return sortedThreadLocals;
    }

    /**
     * Names each thread local after the static final field that holds it.
     * <p>
     * The name is the field formatted with {@code "%H.%n"}, i.e. its holder-qualified field name.
     */
    private void assignNames(CompilationAccessImpl config) {
        for (ResolvedJavaField f : config.getFields()) {
            SharedField field = (SharedField) f;
            if (!field.isStatic() || field.getStorageKind() != JavaKind.Object) {
                continue;
            }
            Object fieldValue = SubstrateObjectConstant.asObject(((ReadableJavaField) field).readValue(null));
            if (!(fieldValue instanceof FastThreadLocal)) {
                continue;
            }
            FastThreadLocal threadLocal = (FastThreadLocal) fieldValue;
            VMThreadLocalInfo info = threadLocals.get(threadLocal);
            String fieldName = field.format("%H.%n");
            if (!field.isFinal()) {
                throw shouldNotReachHere("VMThreadLocal referenced from non-final field: " + fieldName);
            } else if (info == null) {
                // Previously this dereferenced info and failed with a bare NullPointerException.
                throw shouldNotReachHere("VMThreadLocal was not discovered during static analysis: " + fieldName);
            } else if (info.name != null) {
                throw shouldNotReachHere("VMThreadLocal referenced from two static final fields: " + info.name + ", " + fieldName);
            }
            info.name = fieldName;
        }
    }

    /**
     * Computes {@code info.sizeInBytes}.
     * <p>
     * If a size supplier is present, its value is rounded up to a multiple of 8. Otherwise the
     * object layout's size for the info's storage kind is used. Expects {@code sizeInBytes} to
     * still be {@code -1} (presumably the "not yet computed" marker).
     */
    private static void computeSize(VMThreadLocalInfo info) {
        assert info.sizeInBytes == -1;
        if (info.sizeSupplier != null) {
            int unalignedSize = info.sizeSupplier.getAsInt();
            assert unalignedSize > 0;
            info.sizeInBytes = NumUtil.roundUp(unalignedSize, 8);
        } else {
            info.sizeInBytes = ConfigurationValues.getObjectLayout().sizeInBytes(info.storageKind);
        }
    }

    /**
     * Orders thread locals by the following keys, in turn:
     * <ol>
     * <li>ascending {@code maxOffset};</li>
     * <li>descending {@code sizeInBytes};</li>
     * <li>object thread locals before non-object ones;</li>
     * <li>ascending name.</li>
     * </ol>
     * Two distinct infos must never compare equal; this is asserted.
     */
    private static int compareThreadLocal(VMThreadLocalInfo info1, VMThreadLocalInfo info2) {
        if (info1 == info2) {
            return 0;
        }
        int result = Integer.compare(info1.maxOffset, info2.maxOffset);
        if (result == 0) {
            result = -Integer.compare(info1.sizeInBytes, info2.sizeInBytes);
            if (result == 0) {
                // Negated so that isObject == true sorts first.
                result = -Boolean.compare(info1.isObject, info2.isObject);
                if (result == 0) {
                    result = info1.name.compareTo(info2.name);
                }
            }
        }
        assert result != 0 : "not distinguishable: " + info1 + ", " + info2;
        return result;
    }

    /** Strips any chain of {@link PiNode}s and returns the underlying value, for error messages. */
    private static ValueNode unPi(ValueNode n) {
        ValueNode cur = n;
        while (cur instanceof PiNode) {
            cur = ((PiNode) cur).object();
        }
        return cur;
    }
}