package com.oracle.truffle.llvm.nativemode.runtime.memory;
import java.lang.reflect.Field;
import java.util.ArrayDeque;
import java.util.Arrays;
import java.util.function.IntBinaryOperator;
import java.util.function.LongBinaryOperator;
import com.oracle.truffle.llvm.runtime.memory.LLVMHandleMemoryBase;
import org.graalvm.collections.EconomicMap;
import com.oracle.truffle.api.Assumption;
import com.oracle.truffle.api.CompilerAsserts;
import com.oracle.truffle.api.CompilerDirectives;
import com.oracle.truffle.api.CompilerDirectives.TruffleBoundary;
import com.oracle.truffle.api.nodes.Node;
import com.oracle.truffle.llvm.runtime.LLVMIVarBit;
import com.oracle.truffle.llvm.runtime.LLVMLanguage;
import com.oracle.truffle.llvm.runtime.except.LLVMMemoryException;
import com.oracle.truffle.llvm.runtime.floating.LLVM80BitFloat;
import com.oracle.truffle.llvm.runtime.pointer.LLVMManagedPointer;
import com.oracle.truffle.llvm.runtime.pointer.LLVMNativePointer;
import sun.misc.Unsafe;
/**
 * Memory backend that accesses raw native (off-heap) memory through {@code sun.misc.Unsafe}.
 *
 * <p>
 * Pointer arguments are plain native addresses. Besides loads, stores, bulk operations and
 * atomics, this class creates the {@link HandleContainer}s that hand out synthetic addresses for
 * managed objects inside the handle address ranges ({@code HANDLE_SPACE_*} and
 * {@code DEREF_HANDLE_SPACE_*}).
 *
 * <p>
 * Address sanity checks are written as {@code assert} statements, so they only run when
 * assertions are enabled.
 */
public final class LLVMNativeMemory extends LLVMHandleMemoryBase {

    /*
     * Layout checks for the handle address ranges; evaluated only when assertions are enabled.
     */
    static {
        // The start and end of each handle range must differ in their header bits.
        assert (DEREF_HANDLE_SPACE_START & HANDLE_HEADER_MASK) != (DEREF_HANDLE_SPACE_END & HANDLE_HEADER_MASK);
        assert (HANDLE_SPACE_START & HANDLE_HEADER_MASK) != (HANDLE_SPACE_END & HANDLE_HEADER_MASK);
        // Both ranges must start with non-zero header bits.
        assert (DEREF_HANDLE_SPACE_START & HANDLE_HEADER_MASK) != 0;
        assert (HANDLE_SPACE_START & HANDLE_HEADER_MASK) != 0;
        // The constants are copied into a local before comparing, presumably so the checks are not
        // treated as constant expressions by compilers/checkers (TODO confirm).
        long tmp = HANDLE_HEADER_MASK;
        // The deref range starts exactly at the header mask value ...
        assert DEREF_HANDLE_SPACE_START == tmp;
        tmp = HANDLE_SPACE_START;
        // ... and every bit set in HANDLE_SPACE_START is also set in DEREF_HANDLE_SPACE_START.
        assert (DEREF_HANDLE_SPACE_START & tmp) == HANDLE_SPACE_START;
    }

    /** Unsafe instance used for every raw memory access in this class. */
    private static final Unsafe unsafe = getUnsafe();

    /**
     * Obtains the {@link Unsafe} singleton by reflectively reading its private {@code theUnsafe}
     * field.
     *
     * @return the Unsafe instance
     * @throws AssertionError if the field cannot be read; the reflective exception is attached as
     *             the cause
     */
    private static Unsafe getUnsafe() {
        CompilerAsserts.neverPartOfCompilation();
        try {
            Field theUnsafe = Unsafe.class.getDeclaredField("theUnsafe");
            theUnsafe.setAccessible(true);
            return (Unsafe) theUnsafe.get(null);
        } catch (Exception e) {
            // Chain the reflective failure so the underlying reason stays visible in the trace.
            throw new AssertionError("cannot access sun.misc.Unsafe", e);
        }
    }

    private static final LLVMNativeMemory INSTANCE = new LLVMNativeMemory();

    /**
     * Returns the shared instance of this class.
     *
     * @deprecated this accessor is marked deprecated; NOTE(review): document the preferred way to
     *             obtain the memory backend.
     */
    @Deprecated
    public static LLVMNativeMemory getInstance() {
        return INSTANCE;
    }

    /** Private: only the shared {@code INSTANCE} is ever created. */
    private LLVMNativeMemory() {
    }

    /**
     * Debug check for raw addresses, intended to be used as {@code assert checkPointer(ptr)}.
     *
     * <p>
     * Asserts that {@code ptr} is strictly greater than {@code 0x100000}. Null, small and negative
     * (high-bit-set) addresses are therefore reported when assertions are enabled.
     *
     * @return always {@code true}, so the call can sit inside an {@code assert}
     */
    private static boolean checkPointer(long ptr) {
        assert ptr > 0x100000 : "trying to access invalid address: " + ptr + " 0x" + Long.toHexString(ptr);
        return true;
    }

    /** Boundary wrapper around {@code Unsafe.setMemory}. */
    @TruffleBoundary
    private static void memsetBoundary(long address, long size, byte value) {
        unsafe.setMemory(address, size, value);
    }

    /**
     * Sets {@code size} bytes starting at {@code address} to {@code value}. The address is only
     * checked when {@code size} is non-zero.
     */
    @Override
    @Deprecated
    @SuppressWarnings("deprecation")
    public void memset(Node location, LLVMNativePointer address, long size, byte value) {
        assert size == 0 || checkPointer(address.asNative());
        try {
            memsetBoundary(address.asNative(), size, value);
        } catch (Throwable e) {
            // Leave compiled code before propagating the failure.
            CompilerDirectives.transferToInterpreter();
            throw e;
        }
    }

    /** Boundary wrapper around {@code Unsafe.copyMemory}. */
    @TruffleBoundary
    private static void copyMemoryBoundary(long sourceAddress, long targetAddress, long length) {
        unsafe.copyMemory(sourceAddress, targetAddress, length);
    }

    /**
     * Copies {@code length} bytes from {@code sourceAddress} to {@code targetAddress}. Both
     * addresses are only checked when {@code length} is non-zero.
     */
    @Override
    @Deprecated
    @SuppressWarnings("deprecation")
    @TruffleBoundary
    public void copyMemory(Node location, long sourceAddress, long targetAddress, long length) {
        assert length == 0 || (checkPointer(sourceAddress) && checkPointer(targetAddress));
        copyMemoryBoundary(sourceAddress, targetAddress, length);
    }

    /** Boundary wrapper around {@code Unsafe.freeMemory}. */
    @TruffleBoundary
    private static void freeBoundary(long address) {
        unsafe.freeMemory(address);
    }

    /** Releases native memory previously obtained from {@link #allocateMemory}. */
    @Override
    public void free(Node location, long address) {
        try {
            freeBoundary(address);
        } catch (Throwable e) {
            // Leave compiled code before propagating the failure.
            CompilerDirectives.transferToInterpreter();
            throw e;
        }
    }

    /** Boundary wrapper around {@code Unsafe.allocateMemory}. */
    @TruffleBoundary
    private static long allocateMemoryBoundary(long size) {
        return unsafe.allocateMemory(size);
    }

    /**
     * Allocates {@code size} bytes of native memory.
     *
     * @return a native pointer to the new block
     */
    @Override
    public LLVMNativePointer allocateMemory(Node location, long size) {
        try {
            return LLVMNativePointer.create(allocateMemoryBoundary(size));
        } catch (Throwable e) {
            // Leave compiled code before propagating the failure.
            CompilerDirectives.transferToInterpreter();
            throw e;
        }
    }

    /** Boundary wrapper around {@code Unsafe.reallocateMemory}. */
    @TruffleBoundary
    private static long reallocateMemoryBoundary(long addr, long size) {
        return unsafe.reallocateMemory(addr, size);
    }

    /**
     * Resizes the native block at {@code addr} to {@code size} bytes.
     *
     * @return a native pointer to the (possibly moved) block
     */
    @Override
    @Deprecated
    @SuppressWarnings("deprecation")
    public LLVMNativePointer reallocateMemory(Node location, LLVMNativePointer addr, long size) {
        try {
            return LLVMNativePointer.create(reallocateMemoryBoundary(addr.asNative(), size));
        } catch (Throwable e) {
            // Leave compiled code before propagating the failure.
            CompilerDirectives.transferToInterpreter();
            throw e;
        }
    }

    /** Loads one byte at {@code ptr}; any non-zero value reads as {@code true}. */
    @Override
    public boolean getI1(Node location, long ptr) {
        assert checkPointer(ptr);
        return unsafe.getByte(ptr) != 0;
    }

    /** Loads one byte at {@code ptr}. */
    @Override
    public byte getI8(Node location, long ptr) {
        assert checkPointer(ptr);
        return unsafe.getByte(ptr);
    }

    /** Loads a 16-bit value at {@code ptr} in host byte order. */
    @Override
    public short getI16(Node location, long ptr) {
        assert checkPointer(ptr);
        return unsafe.getShort(ptr);
    }

    /** Loads a 32-bit value at {@code ptr} in host byte order. */
    @Override
    public int getI32(Node location, long ptr) {
        assert checkPointer(ptr);
        return unsafe.getInt(ptr);
    }

    /**
     * Loads an arbitrary-width integer of {@code bitWidth} bits starting at {@code addr}.
     *
     * <p>
     * Bytes are read in increasing address order and stored into the array from its last index
     * backwards, so the array handed to {@code LLVMIVarBit.create} is the reverse of memory order.
     *
     * @throws AssertionError if {@code bitWidth} is not a multiple of 8
     */
    @Override
    public LLVMIVarBit getIVarBit(Node location, LLVMNativePointer addr, int bitWidth) {
        if (bitWidth % Byte.SIZE != 0) {
            CompilerDirectives.transferToInterpreter();
            throw new AssertionError();
        }
        int bytes = bitWidth / Byte.SIZE;
        byte[] loadedBytes = new byte[bytes];
        long currentAddressPtr = addr.asNative();
        for (int i = loadedBytes.length - 1; i >= 0; i--) {
            loadedBytes[i] = getI8(location, currentAddressPtr);
            currentAddressPtr += Byte.BYTES;
        }
        return LLVMIVarBit.create(bitWidth, loadedBytes, bitWidth, false);
    }

    /** Loads a 64-bit value at {@code ptr} in host byte order. */
    @Override
    public long getI64(Node location, long ptr) {
        assert checkPointer(ptr);
        return unsafe.getLong(ptr);
    }

    /** Loads a {@code float} at {@code ptr}. */
    @Override
    public float getFloat(Node location, long ptr) {
        assert checkPointer(ptr);
        return unsafe.getFloat(ptr);
    }

    /** Loads a {@code double} at {@code ptr}. */
    @Override
    public double getDouble(Node location, long ptr) {
        assert checkPointer(ptr);
        return unsafe.getDouble(ptr);
    }

    /**
     * Reads {@code LLVM80BitFloat.BYTE_WIDTH} bytes in address order starting at {@code addr} and
     * decodes them with {@code LLVM80BitFloat.fromBytes}.
     */
    @Override
    public LLVM80BitFloat get80BitFloat(Node location, LLVMNativePointer addr) {
        byte[] bytes = new byte[LLVM80BitFloat.BYTE_WIDTH];
        long currentPtr = addr.asNative();
        for (int i = 0; i < bytes.length; i++) {
            bytes[i] = getI8(location, currentPtr);
            currentPtr += Byte.BYTES;
        }
        return LLVM80BitFloat.fromBytes(bytes);
    }

    /** Loads a native address stored at {@code ptr} and wraps it as a native pointer. */
    @Override
    public LLVMNativePointer getPointer(Node location, long ptr) {
        assert checkPointer(ptr);
        return LLVMNativePointer.create(unsafe.getAddress(ptr));
    }

    /** Stores {@code value} at {@code ptr} as a single byte, 1 for true and 0 for false. */
    @Override
    public void putI1(Node location, long ptr, boolean value) {
        assert checkPointer(ptr);
        unsafe.putByte(ptr, (byte) (value ? 1 : 0));
    }

    /** Stores one byte at {@code ptr}. */
    @Override
    public void putI8(Node location, long ptr, byte value) {
        assert checkPointer(ptr);
        unsafe.putByte(ptr, value);
    }

    /** Stores a 16-bit value at {@code ptr} in host byte order. */
    @Override
    public void putI16(Node location, long ptr, short value) {
        assert checkPointer(ptr);
        unsafe.putShort(ptr, value);
    }

    /** Stores a 32-bit value at {@code ptr} in host byte order. */
    @Override
    public void putI32(Node location, long ptr, int value) {
        assert checkPointer(ptr);
        unsafe.putInt(ptr, value);
    }

    /** Stores a 64-bit value at {@code ptr} in host byte order. */
    @Override
    public void putI64(Node location, long ptr, long value) {
        assert checkPointer(ptr);
        unsafe.putLong(ptr, value);
    }

    /**
     * Stores an arbitrary-width integer at {@code addr}. The array from {@code value.getBytes()} is
     * written from its last element to its first in increasing address order, mirroring
     * {@link #getIVarBit}.
     */
    @Override
    public void putIVarBit(Node location, long addr, LLVMIVarBit value) {
        byte[] bytes = value.getBytes();
        long currentptr = addr;
        for (int i = bytes.length - 1; i >= 0; i--) {
            putI8(location, currentptr, bytes[i]);
            currentptr += Byte.BYTES;
        }
    }

    /** Copies {@code bytes} to memory starting at {@code ptr}, one byte at a time, in order. */
    @Override
    public void putByteArray(Node location, long ptr, byte[] bytes) {
        long currentptr = ptr;
        for (int i = 0; i < bytes.length; i++) {
            putI8(location, currentptr, bytes[i]);
            currentptr += Byte.BYTES;
        }
    }

    /** Stores a {@code float} at {@code ptr}. */
    @Override
    public void putFloat(Node location, long ptr, float value) {
        assert checkPointer(ptr);
        unsafe.putFloat(ptr, value);
    }

    /** Stores a {@code double} at {@code ptr}. */
    @Override
    public void putDouble(Node location, long ptr, double value) {
        assert checkPointer(ptr);
        unsafe.putDouble(ptr, value);
    }

    /** Stores the byte encoding from {@code value.getBytes()} at {@code ptr} in array order. */
    @Override
    public void put80BitFloat(Node location, long ptr, LLVM80BitFloat value) {
        putByteArray(location, ptr, value.getBytes());
    }

    /**
     * Stores the native address {@code ptrValue} at {@code ptr}. Unlike the other stores, this
     * only asserts {@code ptr != 0} rather than using {@link #checkPointer}.
     */
    @Override
    public void putPointer(Node location, long ptr, long ptrValue) {
        assert ptr != 0;
        unsafe.putAddress(ptr, ptrValue);
    }

    /**
     * Atomically replaces the 32-bit value at {@code p} with {@code newValue} if it equals
     * {@code comparisonValue}.
     *
     * <p>
     * If the CAS fails but a following volatile read still sees {@code comparisonValue} (the value
     * changed and was restored in between), the operation is retried. As a result, a failure is
     * never reported together with an observed value equal to {@code comparisonValue}.
     *
     * @return the observed value and whether the swap happened
     */
    @Override
    public CMPXCHGI32 compareAndSwapI32(Node location, LLVMNativePointer p, int comparisonValue, int newValue) {
        assert checkPointer(p.asNative());
        while (true) {
            boolean b = unsafe.compareAndSwapInt(null, p.asNative(), comparisonValue, newValue);
            if (CompilerDirectives.injectBranchProbability(CompilerDirectives.LIKELY_PROBABILITY, b)) {
                return new CMPXCHGI32(comparisonValue, b);
            } else {
                int t = unsafe.getIntVolatile(null, p.asNative());
                if (CompilerDirectives.injectBranchProbability(CompilerDirectives.UNLIKELY_PROBABILITY, t == comparisonValue)) {
                    continue;
                } else {
                    return new CMPXCHGI32(t, b);
                }
            }
        }
    }

    /**
     * 64-bit counterpart of {@link #compareAndSwapI32}, with the same retry rule.
     *
     * @return the observed value and whether the swap happened
     */
    @Override
    public CMPXCHGI64 compareAndSwapI64(Node location, LLVMNativePointer p, long comparisonValue, long newValue) {
        assert checkPointer(p.asNative());
        while (true) {
            boolean b = unsafe.compareAndSwapLong(null, p.asNative(), comparisonValue, newValue);
            if (CompilerDirectives.injectBranchProbability(CompilerDirectives.LIKELY_PROBABILITY, b)) {
                return new CMPXCHGI64(comparisonValue, b);
            } else {
                long t = unsafe.getLongVolatile(null, p.asNative());
                if (CompilerDirectives.injectBranchProbability(CompilerDirectives.UNLIKELY_PROBABILITY, t == comparisonValue)) {
                    continue;
                } else {
                    return new CMPXCHGI64(t, b);
                }
            }
        }
    }

    /*
     * Sub-word CAS helpers. The 8- and 16-bit CAS operations run a 32-bit CAS on the enclosing
     * 4-byte-aligned word. Byte offset k within that word maps to bits [8k, 8k + 8) of the loaded
     * int, which matches a little-endian host. NOTE(review): big-endian hosts are not handled.
     */

    /** Clears the two low address bits, giving the enclosing 4-byte-aligned address. */
    private static long alignToI32(long address) {
        long mask = 3;
        return (address & ~mask);
    }

    /** Byte lane (0..3) of {@code address} within its aligned 32-bit word. */
    private static int getI8Index(long address) {
        long mask = 3;
        return (int) (address & mask);
    }

    /** Extracts byte lane {@code index} from {@code value}. */
    private static byte getI8At(int value, int index) {
        return (byte) ((value >> (8 * index)) & 0xff);
    }

    /** Returns {@code value} with byte lane {@code index} replaced by {@code replaceByte}. */
    private static int replaceI8(int index, int value, byte replaceByte) {
        return (value & ~(0xFF << (index * 8))) | ((replaceByte & 0xFF) << (index * 8));
    }

    /**
     * Byte CAS emulated with a 32-bit CAS on the enclosing aligned word. If the CAS fails only
     * because neighbouring bytes changed, the loop retries. It gives up only when the addressed
     * byte itself differs from {@code comparisonValue}.
     *
     * @return the observed byte and whether the swap happened
     */
    @Override
    public CMPXCHGI8 compareAndSwapI8(Node location, LLVMNativePointer p, byte comparisonValue, byte newValue) {
        assert checkPointer(p.asNative());
        int byteIndex = getI8Index(p.asNative());
        long address = alignToI32(p.asNative());
        while (true) {
            int t = unsafe.getIntVolatile(null, address);
            byte b = getI8At(t, byteIndex);
            if (CompilerDirectives.injectBranchProbability(CompilerDirectives.LIKELY_PROBABILITY, b != comparisonValue)) {
                return new CMPXCHGI8(b, false);
            } else {
                int newVal = replaceI8(byteIndex, t, newValue);
                boolean c = unsafe.compareAndSwapInt(null, address, t, newVal);
                if (CompilerDirectives.injectBranchProbability(CompilerDirectives.LIKELY_PROBABILITY, c)) {
                    return new CMPXCHGI8(comparisonValue, true);
                } else {
                    continue;
                }
            }
        }
    }

    /**
     * 16-bit lane (0 or 1) of {@code address} within its aligned 32-bit word. Offsets 1 and 3 map
     * to lanes 0 and 1 respectively, so presumably callers guarantee 2-byte alignment.
     */
    private static int getI16Index(long address) {
        long mask = 3;
        return (int) (address & mask) >> 1;
    }

    /** Extracts 16-bit lane {@code index} from {@code value}. */
    private static short getI16At(int value, int index) {
        return (short) ((value >> (16 * index)) & 0xFFFF);
    }

    /** Returns {@code value} with 16-bit lane {@code index} replaced by {@code replace}. */
    private static int replaceI16(int index, int value, short replace) {
        return (value & ~(0xFFFF << (index * 16))) | ((replace & 0xFFFF) << (index * 16));
    }

    /**
     * 16-bit CAS emulated with a 32-bit CAS on the enclosing aligned word, with the same retry rule
     * as {@link #compareAndSwapI8}.
     *
     * @return the observed value and whether the swap happened
     */
    @Override
    public CMPXCHGI16 compareAndSwapI16(Node location, LLVMNativePointer p, short comparisonValue, short newValue) {
        assert checkPointer(p.asNative());
        int idx = getI16Index(p.asNative());
        long address = alignToI32(p.asNative());
        while (true) {
            int t = unsafe.getIntVolatile(null, address);
            short b = getI16At(t, idx);
            if (CompilerDirectives.injectBranchProbability(CompilerDirectives.LIKELY_PROBABILITY, b != comparisonValue)) {
                return new CMPXCHGI16(b, false);
            } else {
                int newVal = replaceI16(idx, t, newValue);
                boolean c = unsafe.compareAndSwapInt(null, address, t, newVal);
                if (CompilerDirectives.injectBranchProbability(CompilerDirectives.LIKELY_PROBABILITY, c)) {
                    return new CMPXCHGI16(comparisonValue, true);
                } else {
                    continue;
                }
            }
        }
    }

    /** Atomically stores {@code value} at {@code address}; returns the previous 64-bit value. */
    @Override
    public long getAndSetI64(Node location, LLVMNativePointer address, long value) {
        assert checkPointer(address.asNative());
        return unsafe.getAndSetLong(null, address.asNative(), value);
    }

    /** Atomically adds {@code value}; returns the previous 64-bit value. */
    @Override
    public long getAndAddI64(Node location, LLVMNativePointer address, long value) {
        assert checkPointer(address.asNative());
        return unsafe.getAndAddLong(null, address.asNative(), value);
    }

    /**
     * Atomically subtracts {@code value}; returns the previous 64-bit value. Implemented as an add
     * of {@code -value}, which gives the same two's-complement result even for
     * {@code Long.MIN_VALUE}.
     */
    @Override
    public long getAndSubI64(Node location, LLVMNativePointer address, long value) {
        assert checkPointer(address.asNative());
        return unsafe.getAndAddLong(null, address.asNative(), -value);
    }

    /**
     * Atomically replaces the 64-bit value {@code old} at {@code address} with
     * {@code f(old, value)}, retrying the CAS until it succeeds.
     *
     * @return the value before the update
     */
    @Override
    public long getAndOpI64(Node location, LLVMNativePointer address, long value, LongBinaryOperator f) {
        assert checkPointer(address.asNative());
        long addr = address.asNative();
        long old;
        long nevv;
        do {
            old = getI64(location, address);
            nevv = f.applyAsLong(old, value);
        } while (!unsafe.compareAndSwapLong(null, addr, old, nevv));
        return old;
    }

    /** Atomically stores {@code value} at {@code address}; returns the previous 32-bit value. */
    @Override
    public int getAndSetI32(Node location, LLVMNativePointer address, int value) {
        assert checkPointer(address.asNative());
        return unsafe.getAndSetInt(null, address.asNative(), value);
    }

    /** Atomically adds {@code value}; returns the previous 32-bit value. */
    @Override
    public int getAndAddI32(Node location, LLVMNativePointer address, int value) {
        assert checkPointer(address.asNative());
        return unsafe.getAndAddInt(null, address.asNative(), value);
    }

    /**
     * Atomically subtracts {@code value}; returns the previous 32-bit value. Implemented as an add
     * of {@code -value}.
     */
    @Override
    public int getAndSubI32(Node location, LLVMNativePointer address, int value) {
        assert checkPointer(address.asNative());
        return unsafe.getAndAddInt(null, address.asNative(), -value);
    }

    /**
     * Atomically replaces the 32-bit value {@code old} at {@code address} with
     * {@code f(old, value)}, retrying the CAS until it succeeds.
     *
     * @return the value before the update
     */
    @Override
    public int getAndOpI32(Node location, LLVMNativePointer address, int value, IntBinaryOperator f) {
        assert checkPointer(address.asNative());
        long addr = address.asNative();
        int old;
        int nevv;
        do {
            old = getI32(location, address);
            nevv = f.applyAsInt(old, value);
        } while (!unsafe.compareAndSwapInt(null, addr, old, nevv));
        return old;
    }

    /**
     * Read-modify-write on a 16-bit value via {@link #compareAndSwapI16}, retried until the swap
     * succeeds.
     *
     * @return the value before the update
     */
    @Override
    public short getAndOpI16(Node location, LLVMNativePointer address, short value, ShortBinaryOperator f) {
        short old;
        short nevv;
        do {
            old = getI16(location, address);
            nevv = f.apply(old, value);
        } while (!compareAndSwapI16(location, address, old, nevv).isSwap());
        return old;
    }

    /**
     * Read-modify-write on a byte via {@link #compareAndSwapI8}, retried until the swap succeeds.
     *
     * @return the value before the update
     */
    @Override
    public byte getAndOpI8(Node location, LLVMNativePointer address, byte value, ByteBinaryOperator f) {
        byte old;
        byte nevv;
        do {
            old = getI8(location, address);
            nevv = f.apply(old, value);
        } while (!compareAndSwapI8(location, address, old, nevv).isSwap());
        return old;
    }

    /**
     * Read-modify-write on a boolean stored as a byte. A non-zero byte reads as {@code true}, and
     * the result is written back as 1 or 0.
     *
     * @return the value before the update
     */
    @Override
    public boolean getAndOpI1(Node location, LLVMNativePointer address, boolean value, BooleanBinaryOperator f) {
        byte old;
        boolean nevv;
        do {
            old = getI8(location, address);
            nevv = f.apply(old != 0, value);
        } while (!compareAndSwapI8(location, address, old, (byte) (nevv ? 1 : 0)).isSwap());
        return old != 0;
    }

    /** Issues a full memory fence via {@code Unsafe.fullFence}. */
    @Override
    public void fullFence() {
        unsafe.fullFence();
    }

    /**
     * Creates a handle container for the deref handle range ({@code deref == true}) or the common
     * handle range.
     */
    @Override
    public HandleContainer createHandleContainer(boolean deref, Assumption noHandleAssumption) {
        return deref ? new DerefHandleContainer(noHandleAssumption) : new CommonHandleContainer(noHandleAssumption);
    }

    /**
     * Bidirectional, reference-counted registry between managed objects and synthetic addresses in
     * {@code [getStart(), getEnd())}.
     *
     * <p>
     * Each new handle takes {@code HANDLE_OBJECT_SIZE} bytes of address space. Allocating an object
     * that is already registered (per the map's key equivalence) returns the existing address and
     * increments its count. A slot is recycled once its count drops to zero.
     *
     * <p>
     * {@link #allocate} and {@link #free} are {@code synchronized}. {@link #isHandle} and
     * {@link #getValue} are not.
     */
    private abstract static class AbstractHandleContainer extends HandleContainer {
        /** Invalidated whenever a never-used handle slot is handed out. */
        private final Assumption noHandleAssumption;
        /** Addresses of released slots, reused before advancing {@link #top}. */
        private final ArrayDeque<Long> freeList = new ArrayDeque<>();
        /** Managed object to its handle. */
        private final EconomicMap<Object, Handle> handleFromManaged = EconomicMap.create();
        /** Slot index (see {@link #indexFromPointer}) to handle; grown on demand. */
        private Handle[] handleFromPointer = new Handle[1024];
        /**
         * Next never-used address. It is initialized through the abstract {@link #getStart()}
         * during construction. That is safe only because the subclasses below return constants
         * and read no instance state.
         */
        private long top = getStart();

        AbstractHandleContainer(Assumption noHandleAssumption) {
            this.noHandleAssumption = noHandleAssumption;
        }

        /** First address of this container's handle range (inclusive). */
        protected abstract long getStart();

        /** Upper bound of this container's handle range; allocation fails once reached. */
        protected abstract long getEnd();

        /**
         * Maps an address to its slot index: {@code (address - start) >> HANDLE_OBJECT_SIZE_BITS}.
         * Offset bits inside a handle are shifted out, presumably because
         * {@code HANDLE_OBJECT_SIZE == 1 << HANDLE_OBJECT_SIZE_BITS} (TODO confirm).
         */
        private int indexFromPointer(long address) {
            return (int) (((address - getStart()) >> HANDLE_OBJECT_SIZE_BITS));
        }

        /**
         * Returns the handle address for {@code value}, registering it if necessary, and increments
         * its reference count.
         *
         * @throws LLVMMemoryException wrapping an {@link OutOfMemoryError} if the handle range is
         *             exhausted
         */
        @Override
        @TruffleBoundary
        public synchronized LLVMNativePointer allocate(Node location, Object value) {
            Handle handle = handleFromManaged.get(value);
            if (handle == null) {
                Long free = freeList.pollFirst();
                long address;
                if (free != null) {
                    address = free;
                } else {
                    // Only never-used slots reach this branch. Recycled slots were first handed out
                    // here too, so the assumption is already invalid when one of them is reused.
                    noHandleAssumption.invalidate();
                    if (top >= getEnd()) {
                        throw new LLVMMemoryException(location, new OutOfMemoryError("handle space exhausted"));
                    }
                    address = top;
                    top += HANDLE_OBJECT_SIZE;
                }
                handle = new Handle(LLVMNativePointer.create(address), value);
                int index = indexFromPointer(address);
                if (index >= handleFromPointer.length) {
                    // Grow geometrically, but never to less than the new index requires.
                    handleFromPointer = Arrays.copyOf(handleFromPointer, Math.max(handleFromPointer.length * 2, index + 1));
                }
                handleFromPointer[index] = handle;
                handleFromManaged.put(value, handle);
            }
            handle.refcnt++;
            return handle.pointer;
        }

        /**
         * Decrements the reference count of the handle at {@code address}. When the count reaches
         * zero, both mappings are removed and the address is queued for reuse.
         *
         * @throws LLVMMemoryException if {@code address} has a non-zero offset, lies outside this
         *             container's header range, or has no live handle (e.g. a double free)
         */
        @Override
        @TruffleBoundary
        public synchronized void free(Node location, long address) {
            if ((address & HANDLE_OFFSET_MASK) != 0) {
                throw new LLVMMemoryException(location, new UnsupportedOperationException("Cannot resolve invalid native handle: " + address));
            }
            if ((address & HANDLE_HEADER_MASK) != getStart()) {
                throw new LLVMMemoryException(location, new UnsupportedOperationException("Cannot resolve invalid native handle: " + address));
            }
            int index = indexFromPointer(address);
            if (index < 0 || index >= handleFromPointer.length) {
                throw new LLVMMemoryException(location, new UnsupportedOperationException("Cannot resolve native handle: " + address));
            }
            Handle handle = handleFromPointer[index];
            if (handle == null) {
                throw new LLVMMemoryException(location, new UnsupportedOperationException("Cannot resolve native handle (double-free?): " + address));
            }
            if (--handle.refcnt == 0) {
                handleFromPointer[index] = null;
                handleFromManaged.removeKey(handle.managed);
                freeList.addLast(address);
            }
        }

        /**
         * Returns whether {@code address} has this container's header bits and refers to a live
         * slot. Not synchronized; it reads the lookup array once into a local and checks that
         * snapshot.
         */
        @Override
        public boolean isHandle(long address) {
            if ((address & HANDLE_HEADER_MASK) != getStart()) {
                return false;
            }
            int index = indexFromPointer(address);
            Handle[] array = handleFromPointer;
            return index >= 0 && index < array.length && array[index] != null;
        }

        /**
         * Resolves {@code address} to a managed pointer. The target is the registered object and
         * the offset is the low {@code HANDLE_OFFSET_MASK} bits of the address.
         *
         * <p>
         * Not synchronized and does no validation: an address without a live slot fails with an
         * array-index or null-pointer exception. NOTE(review): callers presumably check
         * {@link #isHandle} first.
         */
        @Override
        public LLVMManagedPointer getValue(Node location, long address) {
            return LLVMManagedPointer.create(handleFromPointer[indexFromPointer(address)].managed, address & HANDLE_OFFSET_MASK);
        }
    }

    /** One registered managed object, its synthetic address and its reference count. */
    private static final class Handle {
        /** Number of outstanding {@code allocate} calls not yet matched by {@code free}. */
        private int refcnt;
        private final LLVMNativePointer pointer;
        private final Object managed;

        private Handle(LLVMNativePointer pointer, Object managed) {
            this.refcnt = 0;
            this.pointer = pointer;
            this.managed = managed;
        }
    }

    /** Handle container for the range {@code [HANDLE_SPACE_START, HANDLE_SPACE_END)}. */
    private static final class CommonHandleContainer extends AbstractHandleContainer {
        CommonHandleContainer(Assumption noHandleAssumption) {
            super(noHandleAssumption);
        }

        @Override
        protected long getStart() {
            return HANDLE_SPACE_START;
        }

        @Override
        protected long getEnd() {
            return HANDLE_SPACE_END;
        }
    }

    /** Handle container for the range {@code [DEREF_HANDLE_SPACE_START, DEREF_HANDLE_SPACE_END)}. */
    private static final class DerefHandleContainer extends AbstractHandleContainer {
        DerefHandleContainer(Assumption noHandleAssumption) {
            super(noHandleAssumption);
        }

        @Override
        protected long getStart() {
            return DEREF_HANDLE_SPACE_START;
        }

        @Override
        protected long getEnd() {
            return DEREF_HANDLE_SPACE_END;
        }
    }
}