package org.graalvm.wasm.memory;
import com.oracle.truffle.api.CompilerDirectives.TruffleBoundary;
import com.oracle.truffle.api.nodes.Node;
import org.graalvm.wasm.exception.Failure;
import org.graalvm.wasm.exception.WasmException;
import sun.misc.Unsafe;
import java.lang.reflect.Field;
/**
 * WebAssembly linear memory backed by a single block of off-heap memory obtained through
 * {@link Unsafe}.
 *
 * <p>Growing the memory reallocates the block: a new block is allocated, the old contents are
 * copied, the new tail is zero-filled and the old block is freed. All sizes are tracked as
 * {@code int} byte counts, so the total size is kept at or below {@link Integer#MAX_VALUE} bytes.
 * The constructor and {@link #grow(int)} refuse sizes beyond that instead of letting
 * {@link #byteSize()} overflow.
 *
 * <p>The native block is only released by {@link #free()} or {@link #close()}.
 *
 * <p>NOTE(review): {@code Unsafe} multi-byte accesses use the host's native byte order, whereas
 * WebAssembly memory is little-endian. This class therefore presumably assumes a little-endian
 * host; confirm.
 */
public class UnsafeWasmMemory extends WasmMemory implements AutoCloseable {
    /** Shared {@link Unsafe} instance, looked up once instead of on every construction. */
    private static final Unsafe UNSAFE = loadUnsafe();

    /** Base address of the native block; set to 0 by {@link #free()}. */
    private long startAddress;
    /** Current size in pages. */
    private int pageSize;
    /** Maximum size in pages; a negative value means "no maximum" (see {@link #grow(int)}). */
    private final int maxPageSize;

    /**
     * Allocates a zero-filled memory of {@code initPageSize} pages.
     *
     * @param initPageSize initial size in pages
     * @param maxPageSize maximum size in pages, or a negative value for no maximum
     * @throws IllegalArgumentException if {@code initPageSize} is negative or the resulting byte
     *     size does not fit in an {@code int}
     */
    public UnsafeWasmMemory(int initPageSize, int maxPageSize) {
        // Compute in long: an int product silently wraps for large page counts.
        final long initialByteSize = (long) initPageSize * PAGE_SIZE;
        if (initPageSize < 0 || initialByteSize > Integer.MAX_VALUE) {
            throw new IllegalArgumentException("Unsupported initial memory size: " + initPageSize + " pages");
        }
        this.pageSize = initPageSize;
        this.maxPageSize = maxPageSize;
        this.startAddress = UNSAFE.allocateMemory(initialByteSize);
        UNSAFE.setMemory(startAddress, initialByteSize, (byte) 0);
    }

    /** Reflectively reads the {@code theUnsafe} singleton field of {@link Unsafe}. */
    private static Unsafe loadUnsafe() {
        try {
            final Field field = Unsafe.class.getDeclaredField("theUnsafe");
            field.setAccessible(true);
            return (Unsafe) field.get(null);
        } catch (ReflectiveOperationException | SecurityException e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Traps unless the {@code offset}-byte range starting at {@code address} lies entirely within
     * this memory.
     *
     * @param node the node reported in the trap
     * @param address start address of the access
     * @param offset length of the access in bytes (all callers in this class pass a byte count)
     */
    public void validateAddress(Node node, int address, int offset) {
        // An access of `offset` bytes is in bounds iff address + offset <= byteSize(). The sum is
        // computed in long so that a large address cannot wrap around and pass the check.
        if (address < 0 || offset < 0 || (long) address + offset > byteSize()) {
            throw trapOutOfBounds(node, address, offset);
        }
    }

    /**
     * Throws a wasm trap describing an out-of-bounds access. The declared return type lets callers
     * write {@code throw trapOutOfBounds(...)}; the method itself always throws.
     */
    @TruffleBoundary
    private WasmException trapOutOfBounds(Node node, int address, int offset) {
        throw WasmException.format(Failure.UNSPECIFIED_TRAP, node, "%d-byte memory access at address 0x%016X (%d) is out-of-bounds (memory size %d bytes).",
                        offset, address, address, byteSize());
    }

    /**
     * Copies {@code n} bytes from {@code src} to {@code dst} within this memory, trapping if
     * either range is out of bounds.
     *
     * <p>NOTE(review): wasm {@code memory.copy} needs memmove semantics for overlapping ranges.
     * Confirm that {@code Unsafe.copyMemory} provides them on all supported JDKs.
     */
    @Override
    public void copy(Node node, int src, int dst, int n) {
        validateAddress(node, src, n);
        validateAddress(node, dst, n);
        UNSAFE.copyMemory(startAddress + src, startAddress + dst, n);
    }

    /** Zero-fills the entire memory. */
    @Override
    public void clear() {
        UNSAFE.setMemory(startAddress, byteSize(), (byte) 0);
    }

    @Override
    public int pageSize() {
        return pageSize;
    }

    /**
     * Returns the current size in bytes. This cannot overflow because the constructor and
     * {@link #grow(int)} keep the byte size within {@code int} range.
     */
    @Override
    public int byteSize() {
        return pageSize * PAGE_SIZE;
    }

    @Override
    public int maxPageSize() {
        return maxPageSize;
    }

    /**
     * Grows the memory by {@code extraPageSize} pages, zero-filling the new pages.
     *
     * @param extraPageSize number of pages to add; must not be negative
     * @return {@code true} if the memory now has the requested size; {@code false} if the request
     *     exceeds the maximum page count or a byte size representable as an {@code int}
     * @throws WasmException if {@code extraPageSize} is negative
     */
    @Override
    @TruffleBoundary
    public boolean grow(int extraPageSize) {
        if (extraPageSize < 0) {
            throw WasmException.create(Failure.UNSPECIFIED_TRAP, null, "Extra size cannot be negative.");
        }
        // Compare in page units and compute in long to avoid int overflow of page * PAGE_SIZE.
        final long targetPageSize = (long) pageSize + extraPageSize;
        if (maxPageSize >= 0 && targetPageSize > maxPageSize) {
            return false;
        }
        final long targetByteSize = targetPageSize * PAGE_SIZE;
        if (targetByteSize > Integer.MAX_VALUE) {
            // byteSize() and all addresses are ints; a larger memory could not be addressed.
            return false;
        }
        if (extraPageSize == 0) {
            // Growing by zero pages always succeeds and needs no reallocation.
            return true;
        }
        final long currentByteSize = byteSize();
        final long updatedStartAddress = UNSAFE.allocateMemory(targetByteSize);
        UNSAFE.copyMemory(startAddress, updatedStartAddress, currentByteSize);
        UNSAFE.setMemory(updatedStartAddress + currentByteSize, targetByteSize - currentByteSize, (byte) 0);
        UNSAFE.freeMemory(startAddress);
        startAddress = updatedStartAddress;
        pageSize = (int) targetPageSize;
        return true;
    }

    // ---- Loads: each validates the access range, then reads at startAddress + address. ----

    @Override
    public int load_i32(Node node, int address) {
        validateAddress(node, address, 4);
        return UNSAFE.getInt(startAddress + address);
    }

    @Override
    public long load_i64(Node node, int address) {
        validateAddress(node, address, 8);
        return UNSAFE.getLong(startAddress + address);
    }

    @Override
    public float load_f32(Node node, int address) {
        validateAddress(node, address, 4);
        return UNSAFE.getFloat(startAddress + address);
    }

    @Override
    public double load_f64(Node node, int address) {
        validateAddress(node, address, 8);
        return UNSAFE.getDouble(startAddress + address);
    }

    @Override
    public int load_i32_8s(Node node, int address) {
        validateAddress(node, address, 1);
        return UNSAFE.getByte(startAddress + address);
    }

    @Override
    public int load_i32_8u(Node node, int address) {
        validateAddress(node, address, 1);
        // Mask after sign extension to obtain the unsigned value.
        return 0x0000_00ff & UNSAFE.getByte(startAddress + address);
    }

    @Override
    public int load_i32_16s(Node node, int address) {
        validateAddress(node, address, 2);
        return UNSAFE.getShort(startAddress + address);
    }

    @Override
    public int load_i32_16u(Node node, int address) {
        validateAddress(node, address, 2);
        return 0x0000_ffff & UNSAFE.getShort(startAddress + address);
    }

    @Override
    public long load_i64_8s(Node node, int address) {
        validateAddress(node, address, 1);
        return UNSAFE.getByte(startAddress + address);
    }

    @Override
    public long load_i64_8u(Node node, int address) {
        validateAddress(node, address, 1);
        return 0x0000_0000_0000_00ffL & UNSAFE.getByte(startAddress + address);
    }

    @Override
    public long load_i64_16s(Node node, int address) {
        validateAddress(node, address, 2);
        return UNSAFE.getShort(startAddress + address);
    }

    @Override
    public long load_i64_16u(Node node, int address) {
        validateAddress(node, address, 2);
        return 0x0000_0000_0000_ffffL & UNSAFE.getShort(startAddress + address);
    }

    @Override
    public long load_i64_32s(Node node, int address) {
        validateAddress(node, address, 4);
        return UNSAFE.getInt(startAddress + address);
    }

    @Override
    public long load_i64_32u(Node node, int address) {
        validateAddress(node, address, 4);
        return 0x0000_0000_ffff_ffffL & UNSAFE.getInt(startAddress + address);
    }

    // ---- Stores: each validates the access range, then writes at startAddress + address. ----

    @Override
    public void store_i32(Node node, int address, int value) {
        validateAddress(node, address, 4);
        UNSAFE.putInt(startAddress + address, value);
    }

    @Override
    public void store_i64(Node node, int address, long value) {
        validateAddress(node, address, 8);
        UNSAFE.putLong(startAddress + address, value);
    }

    @Override
    public void store_f32(Node node, int address, float value) {
        validateAddress(node, address, 4);
        UNSAFE.putFloat(startAddress + address, value);
    }

    @Override
    public void store_f64(Node node, int address, double value) {
        validateAddress(node, address, 8);
        UNSAFE.putDouble(startAddress + address, value);
    }

    @Override
    public void store_i32_8(Node node, int address, byte value) {
        validateAddress(node, address, 1);
        UNSAFE.putByte(startAddress + address, value);
    }

    @Override
    public void store_i32_16(Node node, int address, short value) {
        validateAddress(node, address, 2);
        UNSAFE.putShort(startAddress + address, value);
    }

    @Override
    public void store_i64_8(Node node, int address, byte value) {
        validateAddress(node, address, 1);
        UNSAFE.putByte(startAddress + address, value);
    }

    @Override
    public void store_i64_16(Node node, int address, short value) {
        validateAddress(node, address, 2);
        UNSAFE.putShort(startAddress + address, value);
    }

    @Override
    public void store_i64_32(Node node, int address, int value) {
        validateAddress(node, address, 4);
        UNSAFE.putInt(startAddress + address, value);
    }

    /**
     * Returns a new, independently allocated memory with the same page counts and a copy of this
     * memory's contents. The caller is responsible for freeing the copy.
     */
    @Override
    public WasmMemory duplicate() {
        final UnsafeWasmMemory other = new UnsafeWasmMemory(pageSize, maxPageSize);
        UNSAFE.copyMemory(this.startAddress, other.startAddress, this.byteSize());
        return other;
    }

    /** Releases the native block and resets this memory to address 0 and size 0. */
    public void free() {
        UNSAFE.freeMemory(this.startAddress);
        startAddress = 0;
        pageSize = 0;
    }

    /**
     * Returns whether {@link #free()} has reset the base address.
     *
     * <p>NOTE(review): a memory created with 0 pages may also have a base address of 0, since
     * {@code Unsafe.allocateMemory(0)} can return 0. That case is indistinguishable here; confirm
     * it is harmless for callers.
     */
    public boolean freed() {
        return startAddress == 0;
    }

    /** Frees the native block unless it has already been freed; safe to call repeatedly. */
    @Override
    public void close() {
        if (!freed()) {
            free();
        }
    }
}