package jdk.internal.foreign.abi.aarch64;
import jdk.incubator.foreign.Addressable;
import jdk.incubator.foreign.FunctionDescriptor;
import jdk.incubator.foreign.GroupLayout;
import jdk.incubator.foreign.MemoryAddress;
import jdk.incubator.foreign.MemoryLayout;
import jdk.incubator.foreign.MemorySegment;
import jdk.internal.foreign.PlatformLayouts;
import jdk.internal.foreign.Utils;
import jdk.internal.foreign.abi.CallingSequenceBuilder;
import jdk.internal.foreign.abi.UpcallHandler;
import jdk.internal.foreign.abi.ABIDescriptor;
import jdk.internal.foreign.abi.Binding;
import jdk.internal.foreign.abi.CallingSequence;
import jdk.internal.foreign.abi.ProgrammableInvoker;
import jdk.internal.foreign.abi.ProgrammableUpcallHandler;
import jdk.internal.foreign.abi.VMStorage;
import jdk.internal.foreign.abi.SharedUtils;
import java.lang.invoke.MethodHandle;
import java.lang.invoke.MethodType;
import java.util.List;
import java.util.Optional;
import static jdk.internal.foreign.PlatformLayouts.*;
import static jdk.internal.foreign.abi.aarch64.AArch64Architecture.*;
/**
 * Computes calling sequences for native downcalls and upcalls on AArch64.
 *
 * <p>Each Java parameter/return value is classified with
 * {@code TypeClass.classifyLayout} and mapped to general-purpose registers,
 * SIMD/FP ("vector") registers, or stack slots. Aggregates classified as
 * {@code STRUCT_REFERENCE} are passed by reference; when such an aggregate is
 * the return value, the caller supplies a result buffer whose address travels
 * in {@link #INDIRECT_RESULT} ({@code r8}).
 */
public class CallArranger {
    /** Size in bytes of one stack argument slot; also the chunk size used when spilling structs. */
    private static final int STACK_SLOT_SIZE = 8;
    /** Maximum number of registers of one storage class that {@code regAlloc} will hand out. */
    public static final int MAX_REGISTER_ARGUMENTS = 8;
    /** Register that carries the address of the buffer for in-memory (indirect) returns. */
    private static final VMStorage INDIRECT_RESULT = r8;

    // The first four arrays are the integer/vector argument (input) registers and the
    // integer/vector return (output) registers; StorageCalculator.regAlloc reads them back
    // through C.inputStorage / C.outputStorage. The next two arrays are presumably the
    // volatile (caller-saved) scratch registers -- NOTE(review): confirm the positional
    // meaning against AArch64Architecture.abiFor.
    // AAPCS64 treats all of v16-v31 as caller-saved, so v24 is listed along with the rest
    // (it was previously missing from this list).
    private static final ABIDescriptor C = AArch64Architecture.abiFor(
        new VMStorage[] { r0, r1, r2, r3, r4, r5, r6, r7, INDIRECT_RESULT },
        new VMStorage[] { v0, v1, v2, v3, v4, v5, v6, v7 },
        new VMStorage[] { r0, r1 },
        new VMStorage[] { v0, v1, v2, v3 },
        new VMStorage[] { r9, r10, r11, r12, r13, r14, r15 },
        new VMStorage[] { v16, v17, v18, v19, v20, v21, v22, v23, v24,
                          v25, v26, v27, v28, v29, v30, v31 },
        16, // AAPCS64 requires the stack pointer to be 16-byte aligned
        0   // no shadow space
    );

    /**
     * Result of {@link #getBindings}: the computed calling sequence plus whether the
     * return value is passed through an in-memory (indirect) return buffer.
     */
    public static class Bindings {
        public final CallingSequence callingSequence;
        public final boolean isInMemoryReturn;

        Bindings(CallingSequence callingSequence, boolean isInMemoryReturn) {
            this.callingSequence = callingSequence;
            this.isInMemoryReturn = isInMemoryReturn;
        }
    }

    /**
     * Computes argument and return bindings for a native function.
     *
     * @param mt        the Java method type of the call
     * @param cDesc     the native function descriptor; its argument layouts are matched
     *                  positionally with {@code mt}'s parameter types
     * @param forUpcall {@code true} to compute bindings for an upcall (native to Java),
     *                  {@code false} for a downcall (Java to native)
     * @return the calling sequence and whether the return is in-memory
     */
    public static Bindings getBindings(MethodType mt, FunctionDescriptor cDesc, boolean forUpcall) {
        SharedUtils.checkFunctionTypes(mt, cDesc, AArch64Linker.ADDRESS_SIZE);

        CallingSequenceBuilder csb = new CallingSequenceBuilder(forUpcall);

        // For a downcall, Java arguments are unboxed into native storage and the native
        // result is boxed back; an upcall is the mirror image.
        BindingCalculator argCalc = forUpcall ? new BoxBindingCalculator(true) : new UnboxBindingCalculator(true);
        BindingCalculator retCalc = forUpcall ? new UnboxBindingCalculator(false) : new BoxBindingCalculator(false);

        boolean returnInMemory = isInMemoryReturn(cDesc.returnLayout());
        if (returnInMemory) {
            // The return buffer's address is passed as an extra leading pointer argument (in r8).
            csb.addArgumentBindings(MemoryAddress.class, AArch64.C_POINTER,
                    argCalc.getIndirectBindings());
        } else if (cDesc.returnLayout().isPresent()) {
            Class<?> carrier = mt.returnType();
            MemoryLayout layout = cDesc.returnLayout().get();
            csb.setReturnBindings(carrier, layout, retCalc.getBindings(carrier, layout));
        }

        for (int i = 0; i < mt.parameterCount(); i++) {
            Class<?> carrier = mt.parameterType(i);
            MemoryLayout layout = cDesc.argumentLayouts().get(i);
            csb.addArgumentBindings(carrier, layout, argCalc.getBindings(carrier, layout));
        }

        csb.setTrivial(SharedUtils.isTrivial(cDesc));

        return new Bindings(csb.build(), returnInMemory);
    }

    /**
     * Creates a method handle that calls the native function at {@code addr}.
     *
     * @param addr  the native entry point
     * @param mt    the Java method type of the resulting handle
     * @param cDesc the native function descriptor
     * @return a handle performing the downcall, adapted for in-memory returns when needed
     */
    public static MethodHandle arrangeDowncall(Addressable addr, MethodType mt, FunctionDescriptor cDesc) {
        Bindings bindings = getBindings(mt, cDesc, false);

        MethodHandle handle = new ProgrammableInvoker(C, addr, bindings.callingSequence).getBoundMethodHandle();

        if (bindings.isInMemoryReturn) {
            handle = SharedUtils.adaptDowncallForIMR(handle, cDesc);
        }

        return handle;
    }

    /**
     * Creates an upcall handler that lets native code call {@code target}.
     *
     * @param target the Java method handle to invoke
     * @param mt     the Java method type of {@code target}
     * @param cDesc  the native function descriptor
     * @return the upcall handler
     */
    public static UpcallHandler arrangeUpcall(MethodHandle target, MethodType mt, FunctionDescriptor cDesc) {
        Bindings bindings = getBindings(mt, cDesc, true);

        if (bindings.isInMemoryReturn) {
            target = SharedUtils.adaptUpcallForIMR(target);
        }

        return new ProgrammableUpcallHandler(C, target, bindings.callingSequence);
    }

    /**
     * Returns whether the return layout is a group layout classified as
     * {@code STRUCT_REFERENCE}, i.e. returned through a caller-provided buffer.
     */
    private static boolean isInMemoryReturn(Optional<MemoryLayout> returnLayout) {
        return returnLayout
            .filter(GroupLayout.class::isInstance)
            .filter(g -> TypeClass.classifyLayout(g) == TypeClass.STRUCT_REFERENCE)
            .isPresent();
    }

    /**
     * Hands out registers (per storage class) and stack slots in order, tracking
     * how many of each have been used so far.
     */
    static class StorageCalculator {
        private final boolean forArguments;

        // Number of registers already used, indexed by storage class (integer / vector).
        private final int[] nRegs = new int[] { 0, 0 };
        private long stackOffset = 0;

        public StorageCalculator(boolean forArguments) {
            this.forArguments = forArguments;
        }

        /**
         * Reserves {@code size} bytes on the stack, aligned to at least one stack slot.
         *
         * @return stack storage indexed by the slot number of the aligned offset
         */
        VMStorage stackAlloc(long size, long alignment) {
            assert forArguments : "no stack returns";
            alignment = Math.max(alignment, STACK_SLOT_SIZE);
            stackOffset = Utils.alignUp(stackOffset, alignment);

            VMStorage storage =
                AArch64Architecture.stackStorage((int)(stackOffset / STACK_SLOT_SIZE));
            stackOffset += size;
            return storage;
        }

        /** Reserves stack space for {@code layout} using its size and alignment. */
        VMStorage stackAlloc(MemoryLayout layout) {
            return stackAlloc(layout.byteSize(), SharedUtils.alignment(layout, true));
        }

        /**
         * Allocates {@code count} consecutive registers of the given storage class.
         *
         * @return the registers, or {@code null} if they do not all fit. In that case the
         *         class is marked exhausted, so later values of that class also go to the
         *         stack instead of back-filling the leftover registers (as AAPCS64 requires
         *         once an aggregate spills).
         */
        VMStorage[] regAlloc(int type, int count) {
            if (nRegs[type] + count <= MAX_REGISTER_ARGUMENTS) {
                VMStorage[] source =
                    (forArguments ? C.inputStorage : C.outputStorage)[type];
                VMStorage[] result = new VMStorage[count];
                for (int i = 0; i < count; i++) {
                    result[i] = source[nRegs[type]++];
                }
                return result;
            } else {
                // Any further allocations for this class must go on the stack
                nRegs[type] = MAX_REGISTER_ARGUMENTS;
                return null;
            }
        }

        /** Allocates enough 8-byte registers to hold {@code layout}, rounding its size up. */
        VMStorage[] regAlloc(int type, MemoryLayout layout) {
            return regAlloc(type, (int)Utils.alignUp(layout.byteSize(), 8) / 8);
        }

        /** Returns the next free register of the class, or a stack slot once registers run out. */
        VMStorage nextStorage(int type, MemoryLayout layout) {
            VMStorage[] storage = regAlloc(type, 1);
            if (storage == null) {
                return stackAlloc(layout);
            }

            return storage[0];
        }
    }

    /**
     * Base class for translating a (carrier, layout) pair into a list of binding
     * operations. The operations act on an operand stack, which is why {@code dup}
     * appears wherever a value is needed more than once.
     */
    static abstract class BindingCalculator {
        protected final StorageCalculator storageCalculator;

        protected BindingCalculator(boolean forArguments) {
            this.storageCalculator = new StorageCalculator(forArguments);
        }

        /**
         * Copies a struct (on the operand stack) to consecutive stack slots, 8 bytes at a
         * time. The struct is duplicated for every chunk but the last, which consumes it.
         */
        protected void spillStructUnbox(Binding.Builder bindings, MemoryLayout layout) {
            long offset = 0;
            while (offset < layout.byteSize()) {
                long copy = Math.min(layout.byteSize() - offset, STACK_SLOT_SIZE);
                VMStorage storage =
                    storageCalculator.stackAlloc(copy, STACK_SLOT_SIZE);
                if (offset + STACK_SLOT_SIZE < layout.byteSize()) {
                    bindings.dup();
                }
                Class<?> type = SharedUtils.primitiveCarrierForSize(copy);
                bindings.bufferLoad(offset, type)
                        .vmStore(storage, type);
                offset += STACK_SLOT_SIZE;
            }
        }

        /**
         * Fills the struct buffer on the operand stack from consecutive stack slots,
         * 8 bytes at a time. The buffer is duplicated before every store so it remains
         * on the stack as the result.
         */
        protected void spillStructBox(Binding.Builder bindings, MemoryLayout layout) {
            long offset = 0;
            while (offset < layout.byteSize()) {
                long copy = Math.min(layout.byteSize() - offset, STACK_SLOT_SIZE);
                VMStorage storage =
                    storageCalculator.stackAlloc(copy, STACK_SLOT_SIZE);
                Class<?> type = SharedUtils.primitiveCarrierForSize(copy);
                bindings.dup()
                        .vmLoad(storage, type)
                        .bufferStore(offset, type);
                offset += STACK_SLOT_SIZE;
            }
        }

        /** Returns the bindings that move a value of {@code carrier}/{@code layout}. */
        abstract List<Binding> getBindings(Class<?> carrier, MemoryLayout layout);

        /** Returns the bindings for the hidden in-memory-return buffer address (in r8). */
        abstract List<Binding> getIndirectBindings();
    }

    /** Moves Java values into native storage (downcall arguments, upcall returns). */
    static class UnboxBindingCalculator extends BindingCalculator {
        UnboxBindingCalculator(boolean forArguments) {
            super(forArguments);
        }

        @Override
        List<Binding> getIndirectBindings() {
            return Binding.builder()
                .unboxAddress()
                .vmStore(INDIRECT_RESULT, long.class)
                .build();
        }

        @Override
        List<Binding> getBindings(Class<?> carrier, MemoryLayout layout) {
            TypeClass argumentClass = TypeClass.classifyLayout(layout);
            Binding.Builder bindings = Binding.builder();
            switch (argumentClass) {
                case STRUCT_REGISTER: {
                    assert carrier == MemorySegment.class;
                    VMStorage[] regs = storageCalculator.regAlloc(
                        StorageClasses.INTEGER, layout);
                    if (regs != null) {
                        // Load the struct in 8-byte chunks into consecutive integer registers
                        int regIndex = 0;
                        long offset = 0;
                        while (offset < layout.byteSize()) {
                            final long copy = Math.min(layout.byteSize() - offset, 8);
                            VMStorage storage = regs[regIndex++];
                            Class<?> type = SharedUtils.primitiveCarrierForSize(copy);
                            // Keep the segment on the stack until the last chunk consumes it
                            if (offset + copy < layout.byteSize()) {
                                bindings.dup();
                            }
                            bindings.bufferLoad(offset, type)
                                    .vmStore(storage, type);
                            offset += copy;
                        }
                    } else {
                        spillStructUnbox(bindings, layout);
                    }
                    break;
                }
                case STRUCT_REFERENCE: {
                    assert carrier == MemorySegment.class;
                    // Pass a pointer to a private copy of the struct
                    bindings.copy(layout)
                            .baseAddress()
                            .unboxAddress();
                    VMStorage storage = storageCalculator.nextStorage(
                        StorageClasses.INTEGER, AArch64.C_POINTER);
                    bindings.vmStore(storage, long.class);
                    break;
                }
                case STRUCT_HFA: {
                    assert carrier == MemorySegment.class;
                    GroupLayout group = (GroupLayout)layout;
                    // One vector register per member
                    VMStorage[] regs = storageCalculator.regAlloc(
                        StorageClasses.VECTOR, group.memberLayouts().size());
                    if (regs != null) {
                        long offset = 0;
                        for (int i = 0; i < group.memberLayouts().size(); i++) {
                            VMStorage storage = regs[i];
                            final long size = group.memberLayouts().get(i).byteSize();
                            Class<?> type = SharedUtils.primitiveCarrierForSize(size);
                            if (i + 1 < group.memberLayouts().size()) {
                                bindings.dup();
                            }
                            bindings.bufferLoad(offset, type)
                                    .vmStore(storage, type);
                            offset += size;
                        }
                    } else {
                        spillStructUnbox(bindings, layout);
                    }
                    break;
                }
                case POINTER: {
                    bindings.unboxAddress();
                    VMStorage storage =
                        storageCalculator.nextStorage(StorageClasses.INTEGER, layout);
                    bindings.vmStore(storage, long.class);
                    break;
                }
                case INTEGER: {
                    VMStorage storage =
                        storageCalculator.nextStorage(StorageClasses.INTEGER, layout);
                    bindings.vmStore(storage, carrier);
                    break;
                }
                case FLOAT: {
                    VMStorage storage =
                        storageCalculator.nextStorage(StorageClasses.VECTOR, layout);
                    bindings.vmStore(storage, carrier);
                    break;
                }
                default:
                    throw new UnsupportedOperationException("Unhandled class " + argumentClass);
            }
            return bindings.build();
        }
    }

    /** Reads native storage back into Java values (downcall returns, upcall arguments). */
    static class BoxBindingCalculator extends BindingCalculator {
        BoxBindingCalculator(boolean forArguments) {
            super(forArguments);
        }

        @Override
        List<Binding> getIndirectBindings() {
            return Binding.builder()
                .vmLoad(INDIRECT_RESULT, long.class)
                .boxAddress()
                .build();
        }

        @Override
        List<Binding> getBindings(Class<?> carrier, MemoryLayout layout) {
            TypeClass argumentClass = TypeClass.classifyLayout(layout);
            Binding.Builder bindings = Binding.builder();
            switch (argumentClass) {
                case STRUCT_REGISTER: {
                    assert carrier == MemorySegment.class;
                    bindings.allocate(layout);
                    VMStorage[] regs = storageCalculator.regAlloc(
                        StorageClasses.INTEGER, layout);
                    if (regs != null) {
                        // Store consecutive integer registers into the buffer in 8-byte chunks
                        int regIndex = 0;
                        long offset = 0;
                        while (offset < layout.byteSize()) {
                            final long copy = Math.min(layout.byteSize() - offset, 8);
                            VMStorage storage = regs[regIndex++];
                            bindings.dup();
                            Class<?> type = SharedUtils.primitiveCarrierForSize(copy);
                            bindings.vmLoad(storage, type)
                                    .bufferStore(offset, type);
                            offset += copy;
                        }
                    } else {
                        spillStructBox(bindings, layout);
                    }
                    break;
                }
                case STRUCT_REFERENCE: {
                    assert carrier == MemorySegment.class;
                    // Wrap the incoming pointer as a segment, then take a private copy
                    VMStorage storage = storageCalculator.nextStorage(
                        StorageClasses.INTEGER, AArch64.C_POINTER);
                    bindings.vmLoad(storage, long.class)
                            .boxAddress()
                            .toSegment(layout);
                    bindings.copy(layout);
                    break;
                }
                case STRUCT_HFA: {
                    assert carrier == MemorySegment.class;
                    bindings.allocate(layout);
                    GroupLayout group = (GroupLayout)layout;
                    // One vector register per member
                    VMStorage[] regs = storageCalculator.regAlloc(
                        StorageClasses.VECTOR, group.memberLayouts().size());
                    if (regs != null) {
                        long offset = 0;
                        for (int i = 0; i < group.memberLayouts().size(); i++) {
                            VMStorage storage = regs[i];
                            final long size = group.memberLayouts().get(i).byteSize();
                            Class<?> type = SharedUtils.primitiveCarrierForSize(size);
                            bindings.dup()
                                    .vmLoad(storage, type)
                                    .bufferStore(offset, type);
                            offset += size;
                        }
                    } else {
                        spillStructBox(bindings, layout);
                    }
                    break;
                }
                case POINTER: {
                    VMStorage storage =
                        storageCalculator.nextStorage(StorageClasses.INTEGER, layout);
                    bindings.vmLoad(storage, long.class)
                            .boxAddress();
                    break;
                }
                case INTEGER: {
                    VMStorage storage =
                        storageCalculator.nextStorage(StorageClasses.INTEGER, layout);
                    bindings.vmLoad(storage, carrier);
                    break;
                }
                case FLOAT: {
                    VMStorage storage =
                        storageCalculator.nextStorage(StorageClasses.VECTOR, layout);
                    bindings.vmLoad(storage, carrier);
                    break;
                }
                default:
                    throw new UnsupportedOperationException("Unhandled class " + argumentClass);
            }
            return bindings.build();
        }
    }
}