package org.graalvm.compiler.lir.aarch64;
import static jdk.vm.ci.aarch64.AArch64.zr;
import static jdk.vm.ci.code.ValueUtil.asRegister;
import static org.graalvm.compiler.lir.LIRInstruction.OperandFlag.REG;
import org.graalvm.compiler.asm.Label;
import org.graalvm.compiler.asm.aarch64.AArch64Address;
import org.graalvm.compiler.asm.aarch64.AArch64Assembler;
import org.graalvm.compiler.asm.aarch64.AArch64Assembler.ConditionFlag;
import org.graalvm.compiler.asm.aarch64.AArch64MacroAssembler;
import org.graalvm.compiler.lir.LIRInstructionClass;
import org.graalvm.compiler.lir.Opcode;
import org.graalvm.compiler.lir.asm.CompilationResultBuilder;
import jdk.vm.ci.code.CodeUtil;
import jdk.vm.ci.code.Register;
import jdk.vm.ci.meta.Value;
/**
 * Zeroes {@code lengthValue} bytes of memory starting at {@code addressValue} by storing the zero
 * register {@code zr}.
 *
 * <p>Bulk zeroing uses one of two strategies:
 * <ul>
 * <li>{@code DC ZVA} in blocks of {@code zvaLength} bytes, when {@code useDcZva} is set and
 * {@code zvaLength > 0};</li>
 * <li>otherwise, 16-byte pair stores ({@code stp zr, zr}).</li>
 * </ul>
 *
 * <p>When {@code isAligned} is {@code false}, the emitted code first zeroes up to 7 leading bytes so
 * that the base becomes 8-byte aligned, and finishes with a per-byte loop for the trailing bytes.
 * When {@code isAligned} is {@code true}, neither step is emitted. Remainders smaller than 8 bytes
 * are then never zeroed, so the caller must guarantee an 8-byte aligned base and a length that is a
 * multiple of 8.
 *
 * <p>Both the address and length registers are modified by the emitted code.
 */
@Opcode("ZERO_MEMORY")
public final class AArch64ZeroMemoryOp extends AArch64LIRInstruction {
    public static final LIRInstructionClass<AArch64ZeroMemoryOp> TYPE = LIRInstructionClass.create(AArch64ZeroMemoryOp.class);

    /** Start address of the region; the register is advanced as bytes are zeroed. */
    @Use({REG}) protected Value addressValue;
    /** Number of bytes to zero; the register is decremented as bytes are zeroed. */
    @Use({REG}) protected Value lengthValue;

    // The emitted code overwrites both input registers, so they are also declared as temps.
    @Temp({REG}) protected Value addressValueTemp;
    @Temp({REG}) protected Value lengthValueTemp;

    private final boolean isAligned;
    private final boolean useDcZva;
    private final int zvaLength;

    /**
     * Creates a memory-zeroing instruction.
     *
     * @param address register holding the start address (clobbered)
     * @param length register holding the byte count (clobbered)
     * @param isAligned if {@code true}, no byte-granular head/tail handling is emitted (see class doc)
     * @param useDcZva whether the {@code DC ZVA} bulk strategy may be used
     * @param zvaLength {@code DC ZVA} block length in bytes; the DC ZVA strategy is used only if
     *            {@code useDcZva} is set and this value is positive
     * @throws IllegalArgumentException if the DC ZVA strategy is selected and {@code zvaLength} is
     *             not a power of two in [4, 2048]
     */
    public AArch64ZeroMemoryOp(Value address, Value length, boolean isAligned, boolean useDcZva, int zvaLength) {
        super(TYPE);
        // This used to be only an assertion at emission time. With assertions disabled, a bad
        // length would silently produce a wrong stride/alignment mask, so fail fast instead.
        if (useDcZva && zvaLength > 0 && !isValidZvaLength(zvaLength)) {
            throw new IllegalArgumentException("invalid DC ZVA block length: " + zvaLength);
        }
        this.addressValue = address;
        this.lengthValue = length;
        this.addressValueTemp = address;
        this.lengthValueTemp = length;
        this.useDcZva = useDcZva;
        this.zvaLength = zvaLength;
        this.isAligned = isAligned;
    }

    /** Returns whether {@code length} is a power of two in [4, 2048]. */
    private static boolean isValidZvaLength(int length) {
        return CodeUtil.isPowerOf2(length) && 4 <= length && length <= 2048;
    }

    @Override
    protected void emitCode(CompilationResultBuilder crb, AArch64MacroAssembler masm) {
        Register base = asRegister(addressValue);
        Register size = asRegister(lengthValue);
        try (AArch64MacroAssembler.ScratchRegister scratchRegister = masm.getScratchRegister()) {
            Register alignmentBits = scratchRegister.getRegister();
            Label tail = new Label();
            Label done = new Label();

            // Nothing to do for an empty region.
            masm.cbz(64, size, done);
            if (!isAligned) {
                emitAlignBaseTo8Bytes(masm, base, size, alignmentBits, tail);
            }
            if (useDcZva && zvaLength > 0) {
                emitDcZvaZeroing(crb, masm, base, size, alignmentBits);
            } else {
                emitPairStoreZeroing(crb, masm, base, size, tail);
            }
            masm.bind(tail);
            if (!isAligned) {
                emitByteTail(crb, masm, base, size, done);
            }
            masm.bind(done);
        }
    }

    /**
     * Zeroes 1, 2 and/or 4 leading bytes so that {@code base} becomes 8-byte aligned, decrementing
     * {@code size} accordingly.
     *
     * <p>Regions shorter than 8 bytes branch directly to {@code tail}.
     */
    private static void emitAlignBaseTo8Bytes(AArch64MacroAssembler masm, Register base, Register size, Register alignmentBits, Label tail) {
        Label baseAlignedTo2Bytes = new Label();
        Label baseAlignedTo4Bytes = new Label();
        Label baseAlignedTo8Bytes = new Label();

        masm.cmp(64, size, 8);
        masm.branchConditionally(ConditionFlag.LT, tail);

        // alignmentBits = (-base) & 7: the distance in bytes to the next 8-byte boundary.
        // Each set bit of that distance is peeled off with one store of the matching width.
        masm.neg(64, alignmentBits, base);
        masm.and(64, alignmentBits, alignmentBits, 7);

        masm.tbz(alignmentBits, 0, baseAlignedTo2Bytes);
        masm.sub(64, size, size, 1);
        masm.str(8, zr, AArch64Address.createPostIndexedImmediateAddress(base, 1));
        masm.bind(baseAlignedTo2Bytes);

        masm.tbz(alignmentBits, 1, baseAlignedTo4Bytes);
        masm.sub(64, size, size, 2);
        masm.str(16, zr, AArch64Address.createPostIndexedImmediateAddress(base, 2));
        masm.bind(baseAlignedTo4Bytes);

        masm.tbz(alignmentBits, 2, baseAlignedTo8Bytes);
        masm.sub(64, size, size, 4);
        masm.str(32, zr, AArch64Address.createPostIndexedImmediateAddress(base, 4));
        masm.bind(baseAlignedTo8Bytes);
    }

    /**
     * Bulk-zeroes using {@code DC ZVA}.
     *
     * <p>The sequence is:
     * <ol>
     * <li>8-byte stores up to the next {@code zvaLength} boundary;</li>
     * <li>one {@code DC ZVA} per whole {@code zvaLength}-byte block;</li>
     * <li>8-byte stores for the remainder.</li>
     * </ol>
     * Expects {@code base} to be 8-byte aligned on entry. When {@code !isAligned}, leaves the
     * remaining 0..7 bytes in {@code size}.
     */
    private void emitDcZvaZeroing(CompilationResultBuilder crb, AArch64MacroAssembler masm, Register base, Register size, Register alignmentBits) {
        assert isValidZvaLength(zvaLength) : zvaLength;
        Label preCheck = new Label();
        Label preLoop = new Label();
        Label mainCheck = new Label();
        Label mainLoop = new Label();
        Label postCheck = new Label();
        Label postLoop = new Label();

        // alignmentBits = number of bytes from base up to the next zvaLength boundary.
        masm.neg(64, alignmentBits, base);
        masm.and(64, alignmentBits, alignmentBits, zvaLength - 1);

        // If the region does not extend past that boundary, no whole DC ZVA block fits; zero
        // everything with 8-byte stores in the post loop.
        masm.cmp(64, size, alignmentBits);
        masm.branchConditionally(AArch64Assembler.ConditionFlag.LE, postCheck);
        masm.sub(64, size, size, alignmentBits);

        // Pre loop: 8-byte stores until base is zvaLength-aligned.
        masm.jmp(preCheck);
        masm.align(crb.target.wordSize * 2);
        masm.bind(preLoop);
        masm.str(64, zr, AArch64Address.createPostIndexedImmediateAddress(base, 8));
        masm.bind(preCheck);
        masm.subs(64, alignmentBits, alignmentBits, 8);
        masm.branchConditionally(AArch64Assembler.ConditionFlag.GE, preLoop);

        // Main loop: one DC ZVA per zvaLength-byte block while a whole block remains.
        masm.jmp(mainCheck);
        masm.align(crb.target.wordSize * 2);
        masm.bind(mainLoop);
        masm.dc(AArch64Assembler.DataCacheOperationType.ZVA, base);
        masm.add(64, base, base, zvaLength);
        masm.bind(mainCheck);
        masm.subs(64, size, size, zvaLength);
        masm.branchConditionally(AArch64Assembler.ConditionFlag.GE, mainLoop);
        // Undo the final over-subtraction: size is now the remainder (< zvaLength).
        masm.add(64, size, size, zvaLength);

        // Post loop: 8-byte stores for the remainder.
        masm.jmp(postCheck);
        masm.align(crb.target.wordSize * 2);
        masm.bind(postLoop);
        masm.str(64, zr, AArch64Address.createPostIndexedImmediateAddress(base, 8));
        masm.bind(postCheck);
        masm.subs(64, size, size, 8);
        masm.branchConditionally(AArch64Assembler.ConditionFlag.GE, postLoop);
        if (!isAligned) {
            // Undo the final over-subtraction so the byte tail sees the remaining 0..7 bytes.
            masm.add(64, size, size, 8);
        }
    }

    /**
     * Bulk-zeroes using 16-byte pair stores.
     *
     * <p>The sequence is:
     * <ol>
     * <li>an optional 8-byte store to reach 16-byte alignment;</li>
     * <li>the pair-store loop;</li>
     * <li>a final optional 8-byte store.</li>
     * </ol>
     * When {@code !isAligned}, leaves the remaining 0..7 bytes in {@code size}.
     */
    private void emitPairStoreZeroing(CompilationResultBuilder crb, AArch64MacroAssembler masm, Register base, Register size, Label tail) {
        Label mainCheck = new Label();
        Label mainLoop = new Label();
        if (!isAligned) {
            // After the byte-granular alignment, fewer than 8 bytes may remain.
            masm.cmp(64, size, 8);
            masm.branchConditionally(ConditionFlag.LT, tail);
        }

        // Bit 3 of an 8-byte aligned base is set iff base is not 16-byte aligned.
        masm.tbz(base, 3, mainCheck);
        masm.sub(64, size, size, 8);
        masm.str(64, zr, AArch64Address.createPostIndexedImmediateAddress(base, 8));
        masm.jmp(mainCheck);

        masm.align(crb.target.wordSize * 2);
        masm.bind(mainLoop);
        // NOTE(review): the post-index immediate here is 2, while the 8-byte str above uses 8.
        // This matches an API where pair offsets are scaled by the 8-byte operand size
        // (2 * 8 = 16 bytes per iteration, as the loop's "subs 16" expects) - confirm against
        // AArch64Assembler.stp before changing it.
        masm.stp(64, zr, zr, AArch64Address.createPostIndexedImmediateAddress(base, 2));
        masm.bind(mainCheck);
        masm.subs(64, size, size, 16);
        masm.branchConditionally(AArch64Assembler.ConditionFlag.GE, mainLoop);
        // Undo the final over-subtraction: size is now the remainder (< 16).
        masm.add(64, size, size, 16);

        // Zero a trailing 8-byte word if bit 3 of the remainder is set.
        masm.tbz(size, 3, tail);
        masm.str(64, zr, AArch64Address.createPostIndexedImmediateAddress(base, 8));
        if (!isAligned) {
            masm.sub(64, size, size, 8);
        }
    }

    /**
     * Zeroes the remaining {@code size} bytes one at a time.
     *
     * <p>The loop condition is checked after each store, so {@code size == 0} must branch to
     * {@code done} before the loop is entered.
     */
    private static void emitByteTail(CompilationResultBuilder crb, AArch64MacroAssembler masm, Register base, Register size, Label done) {
        Label perByteZeroingLoop = new Label();
        masm.cbz(64, size, done);
        masm.align(crb.target.wordSize * 2);
        masm.bind(perByteZeroingLoop);
        masm.str(8, zr, AArch64Address.createPostIndexedImmediateAddress(base, 1));
        masm.subs(64, size, size, 1);
        masm.branchConditionally(AArch64Assembler.ConditionFlag.NE, perByteZeroingLoop);
    }
}