package jdk.internal.foreign.abi.aarch64;
import jdk.incubator.foreign.*;
import jdk.internal.foreign.NativeMemorySegmentImpl;
import jdk.internal.foreign.Utils;
import jdk.internal.foreign.abi.SharedUtils;
import jdk.internal.misc.Unsafe;
import java.lang.invoke.VarHandle;
import java.lang.ref.Cleaner;
import java.nio.ByteOrder;
import java.util.ArrayList;
import java.util.List;
import java.util.Objects;
import static jdk.internal.foreign.PlatformLayouts.AArch64;
import static jdk.incubator.foreign.CLinker.VaList;
import static jdk.incubator.foreign.MemoryLayout.PathElement.groupElement;
import static jdk.internal.foreign.abi.SharedUtils.SimpleVaArg;
import static jdk.internal.foreign.abi.SharedUtils.checkCompatibleType;
import static jdk.internal.foreign.abi.SharedUtils.vhPrimitiveOrAddress;
import static jdk.internal.foreign.abi.aarch64.CallArranger.MAX_REGISTER_ARGUMENTS;
/**
 * {@link VaList} implementation backed by a native {@code __va_list} structure
 * (see {@link #LAYOUT}) made of a stack pointer, the tops of the general-purpose
 * and floating-point register save areas, and two negative byte offsets that
 * grow towards zero as register-saved arguments are consumed.
 *
 * <p>Arguments are first served from the register save areas; once the relevant
 * area cannot hold the next argument (see {@code isRegOverflow}) they are served
 * from the memory pointed to by {@code __stack}.
 */
public class AArch64VaList implements VaList {
    private static final Unsafe U = Unsafe.getUnsafe();

    static final Class<?> CARRIER = MemoryAddress.class;

    // Native structure this class reads and writes. Field order matters: the
    // var handles below are derived from it by name.
    static final GroupLayout LAYOUT = MemoryLayout.ofStruct(
        AArch64.C_POINTER.withName("__stack"),
        AArch64.C_POINTER.withName("__gr_top"),
        AArch64.C_POINTER.withName("__vr_top"),
        AArch64.C_INT.withName("__gr_offs"),
        AArch64.C_INT.withName("__vr_offs")
    ).withName("__va_list");

    // One save-area slot per register: 64 bits for GP registers, 128 bits for FP registers.
    private static final MemoryLayout GP_REG
        = MemoryLayout.ofValueBits(64, ByteOrder.nativeOrder());
    private static final MemoryLayout FP_REG
        = MemoryLayout.ofValueBits(128, ByteOrder.nativeOrder());

    private static final MemoryLayout LAYOUT_GP_REGS
        = MemoryLayout.ofSequence(MAX_REGISTER_ARGUMENTS, GP_REG);
    private static final MemoryLayout LAYOUT_FP_REGS
        = MemoryLayout.ofSequence(MAX_REGISTER_ARGUMENTS, FP_REG);

    private static final int GP_SLOT_SIZE = (int) GP_REG.byteSize();
    private static final int FP_SLOT_SIZE = (int) FP_REG.byteSize();

    // Total byte size of each register save area.
    private static final int MAX_GP_OFFSET = (int) LAYOUT_GP_REGS.byteSize();
    private static final int MAX_FP_OFFSET = (int) LAYOUT_FP_REGS.byteSize();

    private static final VarHandle VH_stack
        = MemoryHandles.asAddressVarHandle(LAYOUT.varHandle(long.class, groupElement("__stack")));
    private static final VarHandle VH_gr_top
        = MemoryHandles.asAddressVarHandle(LAYOUT.varHandle(long.class, groupElement("__gr_top")));
    private static final VarHandle VH_vr_top
        = MemoryHandles.asAddressVarHandle(LAYOUT.varHandle(long.class, groupElement("__vr_top")));
    private static final VarHandle VH_gr_offs
        = LAYOUT.varHandle(int.class, groupElement("__gr_offs"));
    private static final VarHandle VH_vr_offs
        = LAYOUT.varHandle(int.class, groupElement("__vr_offs"));

    private static final Cleaner cleaner = Cleaner.create();

    // Must be initialized after the var handles and the cleaner, which
    // emptyListAddress() uses.
    private static final VaList EMPTY
        = new SharedUtils.EmptyVaList(emptyListAddress());

    // Segment holding the __va_list structure itself.
    private final MemorySegment segment;
    // Register save areas addressed relative to __gr_top / __vr_top.
    private final MemorySegment gpRegsArea;
    private final MemorySegment fpRegsArea;
    // Extra segments closed together with this list (see close()).
    private final List<MemorySegment> attachedSegments;

    private AArch64VaList(MemorySegment segment, MemorySegment gpRegsArea, MemorySegment fpRegsArea,
                          List<MemorySegment> attachedSegments) {
        this.segment = segment;
        this.gpRegsArea = gpRegsArea;
        this.fpRegsArea = fpRegsArea;
        this.attachedSegments = attachedSegments;
    }

    /**
     * Wraps an existing {@code __va_list} structure. The register save areas are
     * reconstructed as the {@code MAX_GP_OFFSET} / {@code MAX_FP_OFFSET} bytes
     * ending at {@code __gr_top} / {@code __vr_top}, owned by the same thread as
     * {@code segment}.
     */
    private static AArch64VaList readFromSegment(MemorySegment segment) {
        MemorySegment gpRegsArea = handoffIfNeeded(grTop(segment).addOffset(-MAX_GP_OFFSET)
                .asSegmentRestricted(MAX_GP_OFFSET), segment.ownerThread());

        MemorySegment fpRegsArea = handoffIfNeeded(vrTop(segment).addOffset(-MAX_FP_OFFSET)
                .asSegmentRestricted(MAX_FP_OFFSET), segment.ownerThread());
        return new AArch64VaList(segment, gpRegsArea, fpRegsArea, List.of(gpRegsArea, fpRegsArea));
    }

    /**
     * Allocates a shared, off-heap {@code __va_list} whose pointers are all
     * {@code NULL} and whose offsets are zero, and returns its address. The
     * memory is freed when the segment is closed, which is registered with the
     * class-level cleaner.
     */
    private static MemoryAddress emptyListAddress() {
        long ptr = U.allocateMemory(LAYOUT.byteSize());
        MemorySegment ms = MemoryAddress.ofLong(ptr)
                .asSegmentRestricted(LAYOUT.byteSize(), () -> U.freeMemory(ptr), null)
                .share();
        cleaner.register(AArch64VaList.class, ms::close);
        VH_stack.set(ms, MemoryAddress.NULL);
        VH_gr_top.set(ms, MemoryAddress.NULL);
        VH_vr_top.set(ms, MemoryAddress.NULL);
        VH_gr_offs.set(ms, 0);
        VH_vr_offs.set(ms, 0);
        return ms.address();
    }

    /**
     * Returns the shared empty list.
     */
    public static VaList empty() {
        return EMPTY;
    }

    private MemoryAddress grTop() {
        return grTop(segment);
    }

    private static MemoryAddress grTop(MemorySegment segment) {
        return (MemoryAddress) VH_gr_top.get(segment);
    }

    private MemoryAddress vrTop() {
        return vrTop(segment);
    }

    private static MemoryAddress vrTop(MemorySegment segment) {
        return (MemoryAddress) VH_vr_top.get(segment);
    }

    /** Current {@code __gr_offs}: a non-positive byte offset relative to {@code __gr_top}. */
    private int grOffs() {
        final int offs = (int) VH_gr_offs.get(segment);
        assert offs <= 0;
        return offs;
    }

    /** Current {@code __vr_offs}: a non-positive byte offset relative to {@code __vr_top}. */
    private int vrOffs() {
        final int offs = (int) VH_vr_offs.get(segment);
        assert offs <= 0;
        return offs;
    }

    private MemoryAddress stackPtr() {
        return (MemoryAddress) VH_stack.get(segment);
    }

    private void stackPtr(MemoryAddress ptr) {
        VH_stack.set(segment, ptr);
    }

    /** Advances {@code __gr_offs} past {@code num} general-purpose slots. */
    private void consumeGPSlots(int num) {
        final int old = (int) VH_gr_offs.get(segment);
        VH_gr_offs.set(segment, old + num * GP_SLOT_SIZE);
    }

    /** Advances {@code __vr_offs} past {@code num} floating-point slots. */
    private void consumeFPSlots(int num) {
        final int old = (int) VH_vr_offs.get(segment);
        VH_vr_offs.set(segment, old + num * FP_SLOT_SIZE);
    }

    /** Byte offset, from the start of {@code gpRegsArea}, of the next unread GP slot. */
    private long currentGPOffset() {
        // Size of the area in bytes, plus the (negative) offset from the top.
        return gpRegsArea.byteSize() + grOffs();
    }

    /** Byte offset, from the start of {@code fpRegsArea}, of the next unread FP slot. */
    private long currentFPOffset() {
        // Size of the area in bytes, plus the (negative) offset from the top.
        return fpRegsArea.byteSize() + vrOffs();
    }

    /** Rounds {@code __stack} up to 16 bytes when the next argument is over-aligned. */
    private void preAlignStack(MemoryLayout layout) {
        if (layout.byteAlignment() > 8) {
            stackPtr(Utils.alignUp(stackPtr(), 16));
        }
    }

    /** Moves {@code __stack} past an argument of the given layout, rounded up to 8 bytes. */
    private void postAlignStack(MemoryLayout layout) {
        stackPtr(Utils.alignUp(stackPtr().addOffset(layout.byteSize()), 8));
    }

    @Override
    public int vargAsInt(MemoryLayout layout) {
        return (int) read(int.class, layout);
    }

    @Override
    public long vargAsLong(MemoryLayout layout) {
        return (long) read(long.class, layout);
    }

    @Override
    public double vargAsDouble(MemoryLayout layout) {
        return (double) read(double.class, layout);
    }

    @Override
    public MemoryAddress vargAsAddress(MemoryLayout layout) {
        return (MemoryAddress) read(MemoryAddress.class, layout);
    }

    @Override
    public MemorySegment vargAsSegment(MemoryLayout layout) {
        return (MemorySegment) read(MemorySegment.class, layout);
    }

    @Override
    public MemorySegment vargAsSegment(MemoryLayout layout, NativeScope scope) {
        Objects.requireNonNull(scope);
        return (MemorySegment) read(MemorySegment.class, layout, SharedUtils.Allocator.ofScope(scope));
    }

    private Object read(Class<?> carrier, MemoryLayout layout) {
        return read(carrier, layout, MemorySegment::allocateNative);
    }

    /**
     * Reads the next argument and advances this list past it.
     *
     * @param carrier   Java type the argument is returned as
     * @param layout    native layout of the argument
     * @param allocator used to allocate the returned segment for aggregate arguments
     * @return the argument value, as an instance of (the box of) {@code carrier}
     * @throws NullPointerException if {@code layout} is {@code null}
     */
    private Object read(Class<?> carrier, MemoryLayout layout, SharedUtils.Allocator allocator) {
        Objects.requireNonNull(layout);
        checkCompatibleType(carrier, layout, AArch64Linker.ADDRESS_SIZE);

        TypeClass typeClass = TypeClass.classifyLayout(layout);
        if (isRegOverflow(currentGPOffset(), currentFPOffset(), typeClass, layout)) {
            // Not enough room left in the relevant register save area: the
            // argument lives in the stack area.
            preAlignStack(layout);
            return switch (typeClass) {
                case STRUCT_REGISTER, STRUCT_HFA, STRUCT_REFERENCE -> {
                    try (MemorySegment slice = handoffIfNeeded(stackPtr()
                            .asSegmentRestricted(layout.byteSize()), segment.ownerThread())) {
                        MemorySegment seg = allocator.allocate(layout);
                        seg.copyFrom(slice);
                        postAlignStack(layout);
                        yield seg;
                    }
                }
                case POINTER, INTEGER, FLOAT -> {
                    VarHandle reader = vhPrimitiveOrAddress(carrier, layout);
                    try (MemorySegment slice = handoffIfNeeded(stackPtr()
                            .asSegmentRestricted(layout.byteSize()), segment.ownerThread())) {
                        Object res = reader.get(slice);
                        postAlignStack(layout);
                        yield res;
                    }
                }
            };
        } else {
            return switch (typeClass) {
                case STRUCT_REGISTER -> {
                    // Struct is spread across consecutive GP slots, 8 bytes at a time.
                    MemorySegment value = allocator.allocate(layout);
                    long offset = 0;
                    while (offset < layout.byteSize()) {
                        final long copy = Math.min(layout.byteSize() - offset, 8);
                        MemorySegment slice = value.asSlice(offset, copy);
                        slice.copyFrom(gpRegsArea.asSlice(currentGPOffset(), copy));
                        consumeGPSlots(1);
                        offset += copy;
                    }
                    yield value;
                }
                case STRUCT_HFA -> {
                    // Struct is spread across consecutive FP slots, one member per slot.
                    MemorySegment value = allocator.allocate(layout);
                    GroupLayout group = (GroupLayout) layout;
                    long offset = 0;
                    for (MemoryLayout elem : group.memberLayouts()) {
                        assert elem.byteSize() <= 8;
                        final long copy = elem.byteSize();
                        MemorySegment slice = value.asSlice(offset, copy);
                        slice.copyFrom(fpRegsArea.asSlice(currentFPOffset(), copy));
                        consumeFPSlots(1);
                        offset += copy;
                    }
                    yield value;
                }
                case STRUCT_REFERENCE -> {
                    // Struct is passed indirectly: a pointer to it sits in the next GP slot.
                    VarHandle ptrReader
                        = SharedUtils.vhPrimitiveOrAddress(MemoryAddress.class, AArch64.C_POINTER);
                    MemoryAddress ptr = (MemoryAddress) ptrReader.get(
                        gpRegsArea.asSlice(currentGPOffset()));
                    consumeGPSlots(1);

                    try (MemorySegment slice = handoffIfNeeded(ptr
                            .asSegmentRestricted(layout.byteSize()), segment.ownerThread())) {
                        MemorySegment seg = allocator.allocate(layout);
                        seg.copyFrom(slice);
                        yield seg;
                    }
                }
                case POINTER, INTEGER -> {
                    VarHandle reader = SharedUtils.vhPrimitiveOrAddress(carrier, layout);
                    Object res = reader.get(gpRegsArea.asSlice(currentGPOffset()));
                    consumeGPSlots(1);
                    yield res;
                }
                case FLOAT -> {
                    VarHandle reader = layout.varHandle(carrier);
                    Object res = reader.get(fpRegsArea.asSlice(currentFPOffset()));
                    consumeFPSlots(1);
                    yield res;
                }
            };
        }
    }

    /**
     * Advances this list past one argument per given layout, without reading it.
     * Each layout consumes exactly the slots (or stack bytes) that reading an
     * argument of that layout would consume.
     *
     * @throws NullPointerException if {@code layouts} or any of its elements is {@code null}
     */
    @Override
    public void skip(MemoryLayout... layouts) {
        Objects.requireNonNull(layouts);
        for (MemoryLayout layout : layouts) {
            Objects.requireNonNull(layout);
            TypeClass typeClass = TypeClass.classifyLayout(layout);
            if (isRegOverflow(currentGPOffset(), currentFPOffset(), typeClass, layout)) {
                preAlignStack(layout);
                postAlignStack(layout);
            } else if (typeClass == TypeClass.FLOAT || typeClass == TypeClass.STRUCT_HFA) {
                // One FP slot per value / per aggregate member, as in read().
                consumeFPSlots(numFPSlots(typeClass, layout));
            } else if (typeClass == TypeClass.STRUCT_REFERENCE) {
                consumeGPSlots(1);
            } else {
                consumeGPSlots(numSlots(layout));
            }
        }
    }

    static AArch64VaList.Builder builder(SharedUtils.Allocator allocator) {
        return new AArch64VaList.Builder(allocator);
    }

    /**
     * Creates a list view over the {@code __va_list} structure at the given address.
     */
    public static VaList ofAddress(MemoryAddress ma) {
        return readFromSegment(ma.asSegmentRestricted(LAYOUT.byteSize()));
    }

    @Override
    public boolean isAlive() {
        return segment.isAlive();
    }

    /**
     * Closes the segment holding the structure, then every attached segment.
     */
    @Override
    public void close() {
        segment.close();
        attachedSegments.forEach(MemorySegment::close);
    }

    @Override
    public VaList copy() {
        return copy(MemorySegment::allocateNative);
    }

    @Override
    public VaList copy(NativeScope scope) {
        Objects.requireNonNull(scope);
        return copy(SharedUtils.Allocator.ofScope(scope));
    }

    /**
     * Duplicates only the {@code __va_list} structure (and therefore the read
     * position); the register save areas are shared with this list and are not
     * attached to the copy, so closing the copy leaves them open.
     */
    private VaList copy(SharedUtils.Allocator allocator) {
        MemorySegment copy = allocator.allocate(LAYOUT);
        copy.copyFrom(segment);
        return new AArch64VaList(copy, gpRegsArea, fpRegsArea, List.of());
    }

    @Override
    public MemoryAddress address() {
        return segment.address();
    }

    /** Number of 8-byte slots needed to hold a value of the given layout. */
    private static int numSlots(MemoryLayout layout) {
        // Divide before narrowing so the cast applies to the slot count, not the byte size.
        return (int) (Utils.alignUp(layout.byteSize(), 8) / 8);
    }

    /**
     * Number of FP save-area slots an FP-class argument occupies: one per member
     * for a {@code STRUCT_HFA} (each member is stored in its own slot by
     * {@code read} and {@code Builder.arg}), one otherwise.
     */
    private static int numFPSlots(TypeClass typeClass, MemoryLayout layout) {
        return typeClass == TypeClass.STRUCT_HFA
                ? ((GroupLayout) layout).memberLayouts().size()
                : 1;
    }

    /**
     * Tells whether an argument of the given class and layout no longer fits in
     * the remaining part of its register save area and must therefore be taken
     * from (or placed in) the stack area.
     *
     * @param currentGPOffset byte offset of the next free GP slot from the start of the GP area
     * @param currentFPOffset byte offset of the next free FP slot from the start of the FP area
     */
    private static boolean isRegOverflow(long currentGPOffset, long currentFPOffset,
                                         TypeClass typeClass, MemoryLayout layout) {
        if (typeClass == TypeClass.FLOAT || typeClass == TypeClass.STRUCT_HFA) {
            return currentFPOffset > MAX_FP_OFFSET - numFPSlots(typeClass, layout) * FP_SLOT_SIZE;
        } else if (typeClass == TypeClass.STRUCT_REFERENCE) {
            // Only the pointer to the struct occupies a register.
            return currentGPOffset > MAX_GP_OFFSET - GP_SLOT_SIZE;
        } else {
            return currentGPOffset > MAX_GP_OFFSET - numSlots(layout) * GP_SLOT_SIZE;
        }
    }

    @Override
    public String toString() {
        return "AArch64VaList{"
            + "__stack=" + stackPtr()
            + ", __gr_top=" + grTop()
            + ", __vr_top=" + vrTop()
            + ", __gr_offs=" + grOffs()
            + ", __vr_offs=" + vrOffs()
            + '}';
    }

    /**
     * Builds an {@link AArch64VaList}. Arguments are written into the register
     * save areas while they fit; the remaining ones are queued and laid out in a
     * stack area when {@link #build()} is called.
     */
    static class Builder implements VaList.Builder {
        private final SharedUtils.Allocator allocator;
        private final MemorySegment gpRegs;
        private final MemorySegment fpRegs;

        // Byte offsets of the next free slot in gpRegs / fpRegs.
        private long currentGPOffset = 0;
        private long currentFPOffset = 0;
        // Arguments that did not fit in the register save areas, in insertion order.
        private final List<SimpleVaArg> stackArgs = new ArrayList<>();

        Builder(SharedUtils.Allocator allocator) {
            this.allocator = allocator;
            this.gpRegs = allocator.allocate(LAYOUT_GP_REGS);
            this.fpRegs = allocator.allocate(LAYOUT_FP_REGS);
        }

        @Override
        public Builder vargFromInt(ValueLayout layout, int value) {
            return arg(int.class, layout, value);
        }

        @Override
        public Builder vargFromLong(ValueLayout layout, long value) {
            return arg(long.class, layout, value);
        }

        @Override
        public Builder vargFromDouble(ValueLayout layout, double value) {
            return arg(double.class, layout, value);
        }

        @Override
        public Builder vargFromAddress(ValueLayout layout, Addressable value) {
            return arg(MemoryAddress.class, layout, value.address());
        }

        @Override
        public Builder vargFromSegment(GroupLayout layout, MemorySegment value) {
            return arg(MemorySegment.class, layout, value);
        }

        /**
         * Appends one argument, storing it in the matching register save area
         * when it fits and queueing it for the stack area otherwise.
         *
         * @throws NullPointerException if {@code layout} or {@code value} is {@code null}
         */
        private Builder arg(Class<?> carrier, MemoryLayout layout, Object value) {
            Objects.requireNonNull(layout);
            Objects.requireNonNull(value);
            checkCompatibleType(carrier, layout, AArch64Linker.ADDRESS_SIZE);

            TypeClass typeClass = TypeClass.classifyLayout(layout);
            if (isRegOverflow(currentGPOffset, currentFPOffset, typeClass, layout)) {
                stackArgs.add(new SimpleVaArg(carrier, layout, value));
            } else {
                switch (typeClass) {
                    case STRUCT_REGISTER -> {
                        // Spread the struct across consecutive GP slots, 8 bytes at a time.
                        MemorySegment valueSegment = (MemorySegment) value;
                        long offset = 0;
                        while (offset < layout.byteSize()) {
                            final long copy = Math.min(layout.byteSize() - offset, 8);
                            MemorySegment slice = valueSegment.asSlice(offset, copy);
                            gpRegs.asSlice(currentGPOffset, copy).copyFrom(slice);
                            currentGPOffset += GP_SLOT_SIZE;
                            offset += copy;
                        }
                    }
                    case STRUCT_HFA -> {
                        // Store each member in its own FP slot.
                        MemorySegment valueSegment = (MemorySegment) value;
                        GroupLayout group = (GroupLayout) layout;
                        long offset = 0;
                        for (MemoryLayout elem : group.memberLayouts()) {
                            assert elem.byteSize() <= 8;
                            final long copy = elem.byteSize();
                            MemorySegment slice = valueSegment.asSlice(offset, copy);
                            fpRegs.asSlice(currentFPOffset, copy).copyFrom(slice);
                            currentFPOffset += FP_SLOT_SIZE;
                            offset += copy;
                        }
                    }
                    case STRUCT_REFERENCE -> {
                        // Pass the struct indirectly: store its address in the next GP slot.
                        MemorySegment valueSegment = (MemorySegment) value;
                        VarHandle writer
                            = SharedUtils.vhPrimitiveOrAddress(MemoryAddress.class,
                                                               AArch64.C_POINTER);
                        writer.set(gpRegs.asSlice(currentGPOffset),
                                   valueSegment.address());
                        currentGPOffset += GP_SLOT_SIZE;
                    }
                    case POINTER, INTEGER -> {
                        VarHandle writer = SharedUtils.vhPrimitiveOrAddress(carrier, layout);
                        writer.set(gpRegs.asSlice(currentGPOffset), value);
                        currentGPOffset += GP_SLOT_SIZE;
                    }
                    case FLOAT -> {
                        VarHandle writer = layout.varHandle(carrier);
                        writer.set(fpRegs.asSlice(currentFPOffset), value);
                        currentFPOffset += FP_SLOT_SIZE;
                    }
                }
            }
            return this;
        }

        private boolean isEmpty() {
            return currentGPOffset == 0 && currentFPOffset == 0 && stackArgs.isEmpty();
        }

        /**
         * Materializes the list. Returns the shared empty list when no argument
         * was added; otherwise allocates the {@code __va_list} structure, lays
         * out the queued stack arguments and points the structure at the
         * register save areas and the stack area.
         */
        public VaList build() {
            if (isEmpty()) {
                return EMPTY;
            }

            MemorySegment vaListSegment = allocator.allocate(LAYOUT);
            List<MemorySegment> attachedSegments = new ArrayList<>();
            MemoryAddress stackArgsPtr = MemoryAddress.NULL;
            if (!stackArgs.isEmpty()) {
                // Size the stack area with the same rule the reader uses to walk it
                // (preAlignStack / postAlignStack): 16-byte pre-alignment for
                // over-aligned layouts, then the size rounded up to 8 bytes.
                long stackArgsSize = 0;
                for (SimpleVaArg arg : stackArgs) {
                    if (arg.layout.byteAlignment() > 8) {
                        stackArgsSize = Utils.alignUp(stackArgsSize, 16);
                    }
                    stackArgsSize += Utils.alignUp(arg.layout.byteSize(), 8);
                }
                MemorySegment stackArgsSegment = allocator.allocate(stackArgsSize, 16);
                stackArgsPtr = stackArgsSegment.address();

                // Write position; the allocated segment itself is kept for attachment below.
                MemorySegment cursor = stackArgsSegment;
                for (SimpleVaArg arg : stackArgs) {
                    final long size = arg.layout.byteSize();
                    if (arg.layout.byteAlignment() > 8) {
                        cursor = Utils.alignUp(cursor, 16);
                    }
                    if (arg.value instanceof MemorySegment) {
                        // Aggregates are copied byte-wise, mirroring how read()
                        // copies them back out of the stack area.
                        MemorySegment src = ((MemorySegment) arg.value).asSlice(0, size);
                        cursor.asSlice(0, size).copyFrom(src);
                    } else {
                        VarHandle writer = arg.varHandle();
                        writer.set(cursor, arg.value);
                    }
                    cursor = cursor.asSlice(Utils.alignUp(size, 8));
                }
                attachedSegments.add(stackArgsSegment);
            }

            // The tops point one past the end of each save area; the offsets start
            // at minus the full area size, i.e. at the first slot.
            VH_gr_top.set(vaListSegment, gpRegs.asSlice(gpRegs.byteSize()).address());
            VH_vr_top.set(vaListSegment, fpRegs.asSlice(fpRegs.byteSize()).address());
            VH_stack.set(vaListSegment, stackArgsPtr);
            VH_gr_offs.set(vaListSegment, -MAX_GP_OFFSET);
            VH_vr_offs.set(vaListSegment, -MAX_FP_OFFSET);

            attachedSegments.add(gpRegs);
            attachedSegments.add(fpRegs);
            assert gpRegs.ownerThread() == vaListSegment.ownerThread();
            assert fpRegs.ownerThread() == vaListSegment.ownerThread();
            return new AArch64VaList(vaListSegment, gpRegs, fpRegs, attachedSegments);
        }
    }

    /** Returns {@code segment} if it is already owned by {@code thread}, else hands it off to it. */
    private static MemorySegment handoffIfNeeded(MemorySegment segment, Thread thread) {
        return segment.ownerThread() == thread ?
                segment : segment.handoff(thread);
    }
}