package jdk.internal.foreign.abi;
import jdk.incubator.foreign.FunctionDescriptor;
import jdk.incubator.foreign.GroupLayout;
import jdk.incubator.foreign.MemoryAccess;
import jdk.incubator.foreign.MemoryAddress;
import jdk.incubator.foreign.MemoryHandles;
import jdk.incubator.foreign.MemoryLayout;
import jdk.incubator.foreign.MemorySegment;
import jdk.incubator.foreign.LibraryLookup;
import jdk.incubator.foreign.NativeScope;
import jdk.incubator.foreign.SequenceLayout;
import jdk.incubator.foreign.CLinker;
import jdk.incubator.foreign.ValueLayout;
import jdk.internal.foreign.CABI;
import jdk.internal.foreign.MemoryAddressImpl;
import jdk.internal.foreign.Utils;
import jdk.internal.foreign.abi.aarch64.AArch64Linker;
import jdk.internal.foreign.abi.x64.sysv.SysVx64Linker;
import jdk.internal.foreign.abi.x64.windows.Windowsx64Linker;
import java.lang.invoke.MethodHandle;
import java.lang.invoke.MethodHandles;
import java.lang.invoke.MethodType;
import java.lang.invoke.VarHandle;
import java.nio.charset.Charset;
import java.util.List;
import java.util.function.Consumer;
import java.util.stream.IntStream;
import static java.lang.invoke.MethodHandles.collectArguments;
import static java.lang.invoke.MethodHandles.identity;
import static java.lang.invoke.MethodHandles.insertArguments;
import static java.lang.invoke.MethodHandles.permuteArguments;
import static java.lang.invoke.MethodType.methodType;
import static jdk.incubator.foreign.CLinker.*;
/**
 * Static helpers shared by the platform-specific {@link CLinker} implementations.
 *
 * <p>Covers layout alignment computation, method-handle adaptation for in-memory
 * returns (IMR), carrier/layout compatibility checks, NUL-terminated string decoding,
 * raw {@code malloc}/{@code free} access, and {@link VaList} construction/boxing,
 * dispatched on {@link CABI#current()}.
 */
public class SharedUtils {

    /** {@code (MemoryLayout)MemorySegment} — target is {@link #allocateNative(MemoryLayout)}. */
    private static final MethodHandle MH_ALLOC_BUFFER;
    /** {@code (MemorySegment)MemoryAddress} — bound to {@code MemorySegment.address()}. */
    private static final MethodHandle MH_BASEADDRESS;
    /** {@code (MemoryAddress, MemorySegment)MemoryAddress} — target is {@link #bufferCopy}. */
    private static final MethodHandle MH_BUFFER_COPY;

    /** Allocator whose {@code allocate(size, align)} delegates to {@code MemorySegment.allocateNative}. */
    static final Allocator DEFAULT_ALLOCATOR = MemorySegment::allocateNative;

    static {
        try {
            var lookup = MethodHandles.lookup();
            MH_ALLOC_BUFFER = lookup.findStatic(SharedUtils.class, "allocateNative",
                    methodType(MemorySegment.class, MemoryLayout.class));
            MH_BASEADDRESS = lookup.findVirtual(MemorySegment.class, "address",
                    methodType(MemoryAddress.class));
            MH_BUFFER_COPY = lookup.findStatic(SharedUtils.class, "bufferCopy",
                    methodType(MemoryAddress.class, MemoryAddress.class, MemorySegment.class));
        } catch (ReflectiveOperationException e) {
            // Not expected: every looked-up method is declared in this class or in MemorySegment.
            throw new BootstrapMethodError(e);
        }
    }

    /** Target of {@link #MH_ALLOC_BUFFER}: delegates to {@code MemorySegment.allocateNative(layout)}. */
    private static MemorySegment allocateNative(MemoryLayout layout) {
        return MemorySegment.allocateNative(layout);
    }

    /**
     * Rounds {@code addr} up to the alignment required by layout {@code t}.
     *
     * @param t      the layout whose alignment is computed via {@link #alignment(MemoryLayout, boolean)}
     * @param isVar  passed through to the array-alignment rule
     * @param addr   the offset/address to align
     * @return the smallest value {@code >= addr} that is a multiple of the layout's alignment
     */
    public static long align(MemoryLayout t, boolean isVar, long addr) {
        return alignUp(addr, alignment(t, isVar));
    }

    /**
     * Rounds {@code addr} up to the next multiple of {@code alignment}.
     *
     * <p>{@code alignment} must be a power of two: the bit-mask formula below only
     * rounds correctly in that case. {@code alignUp(0, a)} returns {@code 0}.
     */
    public static long alignUp(long addr, long alignment) {
        return ((addr - 1) | (alignment - 1)) + 1;
    }

    /**
     * Computes the alignment (in bytes) this class assigns to a layout.
     *
     * <ul>
     *   <li>value layouts: their byte size;</li>
     *   <li>sequence layouts: see {@link #alignmentOfArray(SequenceLayout, boolean)};</li>
     *   <li>group layouts: the maximum alignment of their members (1 if empty);</li>
     *   <li>padding layouts: 1.</li>
     * </ul>
     *
     * @throws IllegalArgumentException for any other kind of layout
     */
    public static long alignment(MemoryLayout t, boolean isVar) {
        if (t instanceof ValueLayout) {
            return alignmentOfScalar((ValueLayout) t);
        } else if (t instanceof SequenceLayout) {
            return alignmentOfArray((SequenceLayout) t, isVar);
        } else if (t instanceof GroupLayout) {
            return alignmentOfContainer((GroupLayout) t);
        } else if (t.isPadding()) {
            return 1;
        } else {
            throw new IllegalArgumentException("Invalid type: " + t);
        }
    }

    /** Scalars are aligned to their own byte size. */
    private static long alignmentOfScalar(ValueLayout st) {
        return st.byteSize();
    }

    /**
     * Array alignment: 16 for zero-length arrays and for "variable" arrays of at least
     * 16 bytes, otherwise the element alignment. Presumably this mirrors the x86-64 SysV
     * rule that array variables of 16+ bytes are 16-byte aligned — NOTE(review): confirm.
     * An unbounded sequence (no element count) fails in {@code orElseThrow()}.
     */
    private static long alignmentOfArray(SequenceLayout ar, boolean isVar) {
        if (ar.elementCount().orElseThrow() == 0) {
            return 16;
        } else if ((ar.byteSize()) >= 16 && isVar) {
            return 16;
        } else {
            MemoryLayout elementType = ar.elementLayout();
            return alignment(elementType, false);
        }
    }

    /** Struct/union alignment is the largest member alignment, or 1 for an empty group. */
    private static long alignmentOfContainer(GroupLayout ct) {
        return ct.memberLayouts().stream().mapToLong(t -> alignment(t, false)).max().orElse(1);
    }

    /**
     * Adapts a downcall handle that writes its struct result through a leading
     * return-buffer pointer into a handle that allocates that buffer and returns it.
     *
     * <p>Input: {@code (MemoryAddress retBuf, A1..An)void}. Output:
     * {@code (A1..An)MemorySegment}, where the segment is freshly allocated from
     * {@code cDesc}'s return layout, its address is passed as {@code retBuf}, and the
     * same segment is returned after the call.
     *
     * @throws IllegalArgumentException if the handle does not return void, does not take
     *         a {@code MemoryAddress} first, or {@code cDesc} has no return layout
     */
    public static MethodHandle adaptDowncallForIMR(MethodHandle handle, FunctionDescriptor cDesc) {
        if (handle.type().returnType() != void.class)
            throw new IllegalArgumentException("return expected to be void for in memory returns");
        if (handle.type().parameterType(0) != MemoryAddress.class)
            throw new IllegalArgumentException("MemoryAddress expected as first param");
        if (cDesc.returnLayout().isEmpty())
            throw new IllegalArgumentException("Return layout needed: " + cDesc);

        // (MemorySegment)MemorySegment — returns the buffer after the void call runs.
        MethodHandle ret = identity(MemorySegment.class);
        // -> (MemorySegment, MemoryAddress, A1..An)MemorySegment
        handle = collectArguments(ret, 1, handle);
        // -> (MemorySegment, MemorySegment, A1..An)MemorySegment
        handle = collectArguments(handle, 1, MH_BASEADDRESS);

        // Merge the two MemorySegment params into one: reorder = [0, 0, 1, ..., n].
        MethodType oldType = handle.type();
        MethodType newType = oldType.dropParameterTypes(0, 1);
        int[] reorder = IntStream.range(-1, newType.parameterCount()).toArray();
        reorder[0] = 0;
        handle = permuteArguments(handle, newType, reorder);

        // Supply the segment by allocating it from the return layout -> (A1..An)MemorySegment
        handle = collectArguments(handle, 0, insertArguments(MH_ALLOC_BUFFER, 0, cDesc.returnLayout().get()));
        return handle;
    }

    /**
     * Adapts an upcall target returning a struct as {@code MemorySegment} into one that
     * takes a destination address first, copies the returned segment there, and returns
     * that address: {@code (A...)MemorySegment -> (MemoryAddress, A...)MemoryAddress}.
     *
     * @throws IllegalArgumentException if {@code target} does not return {@code MemorySegment}
     */
    public static MethodHandle adaptUpcallForIMR(MethodHandle target) {
        if (target.type().returnType() != MemorySegment.class)
            throw new IllegalArgumentException("Must return MemorySegment for IMR");

        target = collectArguments(MH_BUFFER_COPY, 1, target);

        return target;
    }

    /**
     * Target of {@link #MH_BUFFER_COPY}: copies {@code buffer} to the raw address
     * {@code dest} (viewed as an unchecked segment of the buffer's size) and returns {@code dest}.
     */
    private static MemoryAddress bufferCopy(MemoryAddress dest, MemorySegment buffer) {
        MemoryAddressImpl.ofLongUnchecked(dest.toRawLongValue(), buffer.byteSize()).copyFrom(buffer);
        return dest;
    }

    /**
     * Checks that a Java carrier type can represent a C layout.
     *
     * <p>Primitives are checked by {@code Utils.checkPrimitiveCarrierCompat};
     * {@code MemoryAddress} requires a value layout of exactly {@code addressSize} bits;
     * {@code MemorySegment} requires a group layout.
     *
     * @throws IllegalArgumentException for an incompatible pair or an unsupported carrier
     */
    public static void checkCompatibleType(Class<?> carrier, MemoryLayout layout, long addressSize) {
        if (carrier.isPrimitive()) {
            Utils.checkPrimitiveCarrierCompat(carrier, layout);
        } else if (carrier == MemoryAddress.class) {
            Utils.checkLayoutType(layout, ValueLayout.class);
            if (layout.bitSize() != addressSize)
                throw new IllegalArgumentException("Address size mismatch: " + addressSize + " != " + layout.bitSize());
        } else if (carrier == MemorySegment.class) {
            Utils.checkLayoutType(layout, GroupLayout.class);
        } else {
            throw new IllegalArgumentException("Unsupported carrier: " + carrier);
        }
    }

    /**
     * Validates a Java method type against a C function descriptor: void-ness of the
     * return, arity, and per-position carrier/layout compatibility.
     *
     * @throws IllegalArgumentException on the first mismatch found
     */
    public static void checkFunctionTypes(MethodType mt, FunctionDescriptor cDesc, long addressSize) {
        // A void Java return must correspond exactly to an absent C return layout.
        if ((mt.returnType() == void.class) != cDesc.returnLayout().isEmpty())
            throw new IllegalArgumentException("Return type mismatch: " + mt + " != " + cDesc);
        List<MemoryLayout> argLayouts = cDesc.argumentLayouts();
        if (mt.parameterCount() != argLayouts.size())
            throw new IllegalArgumentException("Arity mismatch: " + mt + " != " + cDesc);

        int paramCount = mt.parameterCount();
        for (int i = 0; i < paramCount; i++) {
            checkCompatibleType(mt.parameterType(i), argLayouts.get(i), addressSize);
        }
        cDesc.returnLayout().ifPresent(rl -> checkCompatibleType(mt.returnType(), rl, addressSize));
    }

    /**
     * Picks the integral primitive carrier for a byte size: 1 -> byte, 2 -> short,
     * other sizes up to 4 -> int, 5..8 -> long.
     * NOTE(review): sizes of 0 or below also fall into the int branch.
     *
     * @throws IllegalArgumentException if {@code size > 8}
     */
    public static Class<?> primitiveCarrierForSize(long size) {
        if (size == 1) {
            return byte.class;
        } else if (size == 2) {
            return short.class;
        } else if (size <= 4) {
            return int.class;
        } else if (size <= 8) {
            return long.class;
        }

        throw new IllegalArgumentException("Size too large: " + size);
    }

    /** Returns the linker singleton matching {@link CABI#current()}. */
    public static CLinker getSystemLinker() {
        return switch (CABI.current()) {
            case Win64 -> Windowsx64Linker.getInstance();
            case SysV -> SysVx64Linker.getInstance();
            case AArch64 -> AArch64Linker.getInstance();
        };
    }

    /**
     * Decodes the NUL-terminated byte string starting at {@code start} in {@code segment}
     * using {@code charset}. The terminator itself is not included.
     */
    public static String toJavaStringInternal(MemorySegment segment, long start, Charset charset) {
        int len = strlen(segment, start);
        byte[] bytes = new byte[len];
        MemorySegment.ofArray(bytes)
                .copyFrom(segment.asSlice(start, len));
        return new String(bytes, charset);
    }

    /**
     * Counts bytes from {@code start} up to (not including) the first zero byte.
     *
     * @throws IllegalArgumentException if no terminator is found before the int offset
     *         overflows (i.e. the string would exceed {@code Integer.MAX_VALUE} bytes)
     */
    private static int strlen(MemorySegment segment, long start) {
        // The loop condition turns false only when offset overflows past Integer.MAX_VALUE.
        for (int offset = 0; offset >= 0; offset++) {
            byte curr = MemoryAccess.getByteAtOffset(segment, start + offset);
            if (curr == 0) {
                return offset;
            }
        }
        throw new IllegalArgumentException("String too large");
    }

    /**
     * Holder for the {@code malloc}/{@code free} downcall handles. Being a nested class,
     * it is initialized on first access to one of its fields, not when SharedUtils loads.
     * A missing symbol fails initialization via {@code orElseThrow()}.
     */
    private static class AllocHolder {

        final static LibraryLookup LOOKUP = LibraryLookup.ofDefault();

        final static MethodHandle MH_MALLOC = getSystemLinker().downcallHandle(LOOKUP.lookup("malloc").orElseThrow(),
                        MethodType.methodType(MemoryAddress.class, long.class),
                FunctionDescriptor.of(C_POINTER, C_LONG_LONG));

        final static MethodHandle MH_FREE = getSystemLinker().downcallHandle(LOOKUP.lookup("free").orElseThrow(),
                        MethodType.methodType(void.class, MemoryAddress.class),
                FunctionDescriptor.ofVoid(C_POINTER));
    }

    /**
     * Calls native {@code malloc(size)} and returns the resulting address.
     * Any throwable from the invocation is wrapped in a {@link RuntimeException}.
     */
    public static MemoryAddress allocateMemoryInternal(long size) {
        try {
            return (MemoryAddress) AllocHolder.MH_MALLOC.invokeExact(size);
        } catch (Throwable th) {
            throw new RuntimeException(th);
        }
    }

    /**
     * Calls native {@code free(addr)}.
     * Any throwable from the invocation is wrapped in a {@link RuntimeException}.
     */
    public static void freeMemoryInternal(MemoryAddress addr) {
        try {
            AllocHolder.MH_FREE.invokeExact(addr);
        } catch (Throwable th) {
            throw new RuntimeException(th);
        }
    }

    /** Builds a platform-specific VaList by running {@code actions} against a builder. */
    public static VaList newVaList(Consumer<VaList.Builder> actions, Allocator allocator) {
        return switch (CABI.current()) {
            case Win64 -> Windowsx64Linker.newVaList(actions, allocator);
            case SysV -> SysVx64Linker.newVaList(actions, allocator);
            case AArch64 -> AArch64Linker.newVaList(actions, allocator);
        };
    }

    /**
     * Returns a VarHandle for {@code layout} with the given carrier. For
     * {@code MemoryAddress}, the layout is accessed through an integral carrier of the
     * layout's byte size and adapted to addresses.
     */
    public static VarHandle vhPrimitiveOrAddress(Class<?> carrier, MemoryLayout layout) {
        return carrier == MemoryAddress.class
            ? MemoryHandles.asAddressVarHandle(layout.varHandle(primitiveCarrierForSize(layout.byteSize())))
            : layout.varHandle(carrier);
    }

    /** Wraps an existing native va_list at {@code ma} in the platform-specific VaList. */
    public static VaList newVaListOfAddress(MemoryAddress ma) {
        return switch (CABI.current()) {
            case Win64 -> Windowsx64Linker.newVaListOfAddress(ma);
            case SysV -> SysVx64Linker.newVaListOfAddress(ma);
            case AArch64 -> AArch64Linker.newVaListOfAddress(ma);
        };
    }

    /** Returns the platform-specific empty VaList. */
    public static VaList emptyVaList() {
        return switch (CABI.current()) {
            case Win64 -> Windowsx64Linker.emptyVaList();
            case SysV -> SysVx64Linker.emptyVaList();
            case AArch64 -> AArch64Linker.emptyVaList();
        };
    }

    /** Returns {@code mt} with every {@code VaList} parameter type replaced by {@code carrier}. */
    public static MethodType convertVaListCarriers(MethodType mt, Class<?> carrier) {
        Class<?>[] params = new Class<?>[mt.parameterCount()];
        for (int i = 0; i < params.length; i++) {
            Class<?> pType = mt.parameterType(i);
            params[i] = ((pType == VaList.class) ? carrier : pType);
        }
        return methodType(mt.returnType(), params);
    }

    /**
     * Applies {@code unboxer} as an argument filter at every position where {@code type}
     * declares a {@code VaList} parameter.
     */
    public static MethodHandle unboxVaLists(MethodType type, MethodHandle handle, MethodHandle unboxer) {
        for (int i = 0; i < type.parameterCount(); i++) {
            if (type.parameterType(i) == VaList.class) {
               handle = MethodHandles.filterArguments(handle, i, unboxer);
            }
        }
        return handle;
    }

    /**
     * Applies {@code boxer} as an argument filter at every position where
     * {@code handle}'s own type has a {@code VaList} parameter.
     */
    public static MethodHandle boxVaLists(MethodHandle handle, MethodHandle boxer) {
        MethodType type = handle.type();
        for (int i = 0; i < type.parameterCount(); i++) {
            if (type.parameterType(i) == VaList.class) {
               handle = MethodHandles.filterArguments(handle, i, boxer);
            }
        }
        return handle;
    }

    /** @throws IllegalArgumentException unless {@code actualType} is exactly {@code expectedType} */
    static void checkType(Class<?> actualType, Class<?> expectedType) {
        if (expectedType != actualType) {
            throw new IllegalArgumentException(
                    String.format("Invalid operand type: %s. %s expected", actualType, expectedType));
        }
    }

    /**
     * Returns the descriptor's {@code TRIVIAL_ATTRIBUTE_NAME} attribute cast to Boolean,
     * or {@code false} if absent. A non-Boolean attribute value fails the cast.
     */
    public static boolean isTrivial(FunctionDescriptor cDesc) {
        return cDesc.attribute(FunctionDescriptor.TRIVIAL_ATTRIBUTE_NAME)
                .map(Boolean.class::cast)
                .orElse(false);
    }

    /**
     * Source of native segments for marshalling. Closing is a no-op by default;
     * {@link #ofScope(NativeScope)} forwards close to the scope.
     */
    public interface Allocator extends AutoCloseable {
        /** Allocator that rejects every allocation with {@link UnsupportedOperationException}. */
        Allocator THROWING_ALLOCATOR = (size, align) -> { throw new UnsupportedOperationException("Null allocator"); };

        /** Allocates using the layout's byte size and byte alignment. */
        default MemorySegment allocate(MemoryLayout layout) {
            return allocate(layout.byteSize(), layout.byteAlignment());
        }

        /** Allocates {@code size} bytes with alignment 1. */
        default MemorySegment allocate(long size) {
            return allocate(size, 1);
        }

        @Override
        default void close() {}

        /** Allocates {@code size} bytes aligned to {@code align}. */
        MemorySegment allocate(long size, long align);

        /** Adapts a {@link NativeScope}: allocations and {@code close()} delegate to it. */
        static Allocator ofScope(NativeScope scope) {
            return new Allocator() {
                @Override
                public MemorySegment allocate(long size, long align) {
                    return scope.allocate(size, align);
                }

                @Override
                public void close() {
                    scope.close();
                }
            };
        }
    }

    /** A single variadic argument: its Java carrier, C layout, and boxed value. */
    public static class SimpleVaArg {
        public final Class<?> carrier;
        public final MemoryLayout layout;
        public final Object value;

        public SimpleVaArg(Class<?> carrier, MemoryLayout layout, Object value) {
            this.carrier = carrier;
            this.layout = layout;
            this.value = value;
        }

        /** Returns the accessor for this argument; same rule as {@code vhPrimitiveOrAddress}. */
        public VarHandle varHandle() {
            // Delegate instead of duplicating the address/primitive selection logic.
            return vhPrimitiveOrAddress(carrier, layout);
        }
    }

    /**
     * A VaList with no arguments wrapping a fixed address. Every read, skip, close and
     * scoped copy throws {@link UnsupportedOperationException}; {@code isAlive()} is always
     * true and {@code copy()} returns this same instance.
     */
    public static class EmptyVaList implements VaList {

        private final MemoryAddress address;

        public EmptyVaList(MemoryAddress address) {
            this.address = address;
        }

        private static UnsupportedOperationException uoe() {
            return new UnsupportedOperationException("Empty VaList");
        }

        @Override
        public int vargAsInt(MemoryLayout layout) {
            throw uoe();
        }

        @Override
        public long vargAsLong(MemoryLayout layout) {
            throw uoe();
        }

        @Override
        public double vargAsDouble(MemoryLayout layout) {
            throw uoe();
        }

        @Override
        public MemoryAddress vargAsAddress(MemoryLayout layout) {
            throw uoe();
        }

        @Override
        public MemorySegment vargAsSegment(MemoryLayout layout) {
            throw uoe();
        }

        @Override
        public MemorySegment vargAsSegment(MemoryLayout layout, NativeScope scope) {
            throw uoe();
        }

        @Override
        public void skip(MemoryLayout... layouts) {
            throw uoe();
        }

        @Override
        public boolean isAlive() {
            return true;
        }

        @Override
        public void close() {
            throw uoe();
        }

        @Override
        public VaList copy() {
            return this;
        }

        @Override
        public VaList copy(NativeScope scope) {
            throw uoe();
        }

        @Override
        public MemoryAddress address() {
            return address;
        }
    }

    /**
     * Writes a boxed primitive at the start of {@code ptr}. Integral values are widened
     * to a full 8-byte long write (byte/short/int sign-extend, char zero-extends);
     * float and double are written at their natural width.
     *
     * @throws IllegalArgumentException for a non-primitive-numeric carrier
     */
    static void writeOverSized(MemorySegment ptr, Class<?> type, Object o) {
        // NOTE: 'long' is not a supported type here.
        if (type == long.class) {
            MemoryAccess.setLong(ptr, (long) o);
        } else if (type == int.class) {
            MemoryAccess.setLong(ptr, (int) o);
        } else if (type == short.class) {
            MemoryAccess.setLong(ptr, (short) o);
        } else if (type == char.class) {
            MemoryAccess.setLong(ptr, (char) o);
        } else if (type == byte.class) {
            MemoryAccess.setLong(ptr, (byte) o);
        } else if (type == float.class) {
            MemoryAccess.setFloat(ptr, (float) o);
        } else if (type == double.class) {
            MemoryAccess.setDouble(ptr, (double) o);
        } else {
            throw new IllegalArgumentException("Unsupported carrier: " + type);
        }
    }

    /**
     * Writes a boxed primitive at the start of {@code ptr} at the carrier's natural width.
     *
     * @throws IllegalArgumentException for a non-primitive-numeric carrier
     */
    static void write(MemorySegment ptr, Class<?> type, Object o) {
        if (type == long.class) {
            MemoryAccess.setLong(ptr, (long) o);
        } else if (type == int.class) {
            MemoryAccess.setInt(ptr, (int) o);
        } else if (type == short.class) {
            MemoryAccess.setShort(ptr, (short) o);
        } else if (type == char.class) {
            MemoryAccess.setChar(ptr, (char) o);
        } else if (type == byte.class) {
            MemoryAccess.setByte(ptr, (byte) o);
        } else if (type == float.class) {
            MemoryAccess.setFloat(ptr, (float) o);
        } else if (type == double.class) {
            MemoryAccess.setDouble(ptr, (double) o);
        } else {
            throw new IllegalArgumentException("Unsupported carrier: " + type);
        }
    }

    /**
     * Reads a primitive of the given carrier from the start of {@code ptr} and returns it boxed.
     *
     * @throws IllegalArgumentException for a non-primitive-numeric carrier
     */
    static Object read(MemorySegment ptr, Class<?> type) {
        if (type == long.class) {
            return MemoryAccess.getLong(ptr);
        } else if (type == int.class) {
            return MemoryAccess.getInt(ptr);
        } else if (type == short.class) {
            return MemoryAccess.getShort(ptr);
        } else if (type == char.class) {
            return MemoryAccess.getChar(ptr);
        } else if (type == byte.class) {
            return MemoryAccess.getByte(ptr);
        } else if (type == float.class) {
            return MemoryAccess.getFloat(ptr);
        } else if (type == double.class) {
            return MemoryAccess.getDouble(ptr);
        } else {
            throw new IllegalArgumentException("Unsupported carrier: " + type);
        }
    }
}