package org.graalvm.compiler.lir.amd64;
import static jdk.vm.ci.amd64.AMD64.k7;
import static jdk.vm.ci.amd64.AMD64.rax;
import static jdk.vm.ci.amd64.AMD64.rcx;
import static jdk.vm.ci.amd64.AMD64.rdx;
import static jdk.vm.ci.code.ValueUtil.asRegister;
import static org.graalvm.compiler.lir.LIRInstruction.OperandFlag.ILLEGAL;
import static org.graalvm.compiler.lir.LIRInstruction.OperandFlag.REG;
import java.util.EnumSet;
import org.graalvm.compiler.asm.Label;
import org.graalvm.compiler.asm.amd64.AMD64Address;
import org.graalvm.compiler.asm.amd64.AMD64Address.Scale;
import org.graalvm.compiler.asm.amd64.AMD64Assembler.ConditionFlag;
import org.graalvm.compiler.asm.amd64.AMD64MacroAssembler;
import org.graalvm.compiler.core.common.LIRKind;
import org.graalvm.compiler.lir.LIRInstructionClass;
import org.graalvm.compiler.lir.Opcode;
import org.graalvm.compiler.lir.asm.CompilationResultBuilder;
import org.graalvm.compiler.lir.gen.LIRGeneratorTool;
import jdk.vm.ci.amd64.AMD64;
import jdk.vm.ci.amd64.AMD64.CPUFeature;
import jdk.vm.ci.amd64.AMD64Kind;
import jdk.vm.ci.code.CodeUtil;
import jdk.vm.ci.code.Register;
import jdk.vm.ci.code.TargetDescription;
import jdk.vm.ci.meta.JavaKind;
import jdk.vm.ci.meta.Value;
@Opcode("ARRAY_COMPARE_TO")
public final class AMD64ArrayCompareToOp extends AMD64LIRInstruction {
public static final LIRInstructionClass<AMD64ArrayCompareToOp> TYPE = LIRInstructionClass.create(AMD64ArrayCompareToOp.class);
private final JavaKind kind1;
private final JavaKind kind2;
private final int array1BaseOffset;
private final int array2BaseOffset;
private final int useAVX3Threshold;
@Def({REG}) protected Value resultValue;
@Alive({REG}) protected Value array1Value;
@Alive({REG}) protected Value array2Value;
@Use({REG}) protected Value length1Value;
@Use({REG}) protected Value length2Value;
@Temp({REG}) protected Value length1ValueTemp;
@Temp({REG}) protected Value length2ValueTemp;
@Temp({REG}) protected Value temp1;
@Temp({REG}) protected Value temp2;
@Temp({REG, ILLEGAL}) protected Value vectorTemp1;
public AMD64ArrayCompareToOp(LIRGeneratorTool tool, int useAVX3Threshold, JavaKind kind1, JavaKind kind2, Value result, Value array1, Value array2, Value length1, Value length2) {
super(TYPE);
assert CodeUtil.isPowerOf2(useAVX3Threshold) : "AVX3Threshold must be power of 2";
this.useAVX3Threshold = useAVX3Threshold;
this.kind1 = kind1;
this.kind2 = kind2;
this.array1BaseOffset = tool.getProviders().getMetaAccess().getArrayBaseOffset(kind1);
this.array2BaseOffset = tool.getProviders().getMetaAccess().getArrayBaseOffset(kind2);
this.resultValue = result;
this.array1Value = array1;
this.array2Value = array2;
this.length1Value = length1;
this.length2Value = length2;
this.length1ValueTemp = length1;
this.length2ValueTemp = length2;
this.temp1 = tool.newVariable(LIRKind.unknownReference(tool.target().arch.getWordKind()));
this.temp2 = tool.newVariable(LIRKind.unknownReference(tool.target().arch.getWordKind()));
if (supportsSSE42(tool.target())) {
this.vectorTemp1 = tool.newVariable(LIRKind.value(AMD64Kind.DOUBLE));
} else {
this.vectorTemp1 = Value.ILLEGAL;
}
}
private static boolean supportsSSE42(TargetDescription target) {
AMD64 arch = (AMD64) target.arch;
return arch.getFeatures().contains(CPUFeature.SSE4_2);
}
private static boolean supportsAVX2(TargetDescription target) {
AMD64 arch = (AMD64) target.arch;
return arch.getFeatures().contains(CPUFeature.AVX2);
}
private static boolean supportsAVX512VLBW(TargetDescription target) {
EnumSet<CPUFeature> features = ((AMD64) target.arch).getFeatures();
return features.contains(CPUFeature.AVX512BW) && features.contains(CPUFeature.AVX512VL);
}
@Override
public void emitCode(CompilationResultBuilder crb, AMD64MacroAssembler masm) {
Register result = asRegister(resultValue);
Register str1 = asRegister(temp1);
Register str2 = asRegister(temp2);
masm.leaq(str1, new AMD64Address(asRegister(array1Value), array1BaseOffset));
masm.leaq(str2, new AMD64Address(asRegister(array2Value), array2BaseOffset));
Register cnt1 = asRegister(length1Value);
Register cnt2 = asRegister(length2Value);
Label LENGTH_DIFF_LABEL = new Label();
Label POP_LABEL = new Label();
Label DONE_LABEL = new Label();
Label WHILE_HEAD_LABEL = new Label();
Label COMPARE_WIDE_VECTORS_LOOP_FAILED = new Label();
int stride, stride2;
int adr_stride = -1;
int adr_stride1 = -1;
int adr_stride2 = -1;
int stride2x2 = 0x40;
AMD64Address.Scale scale = null;
AMD64Address.Scale scale1 = null;
AMD64Address.Scale scale2 = null;
if (!(kind1 == JavaKind.Byte && kind2 == JavaKind.Byte)) {
stride2x2 = 0x20;
}
if (kind1 != kind2) {
masm.shrl(cnt2, 1);
}
masm.movl(result, cnt1);
masm.subl(cnt1, cnt2);
masm.push(cnt1);
masm.cmovl(ConditionFlag.LessEqual, cnt2, result);
masm.testlAndJcc(cnt2, cnt2, ConditionFlag.Zero, LENGTH_DIFF_LABEL, false);
if (kind1 == JavaKind.Byte && kind2 == JavaKind.Byte) {
masm.movzbl(result, new AMD64Address(str1, 0));
masm.movzbl(cnt1, new AMD64Address(str2, 0));
} else if (kind1 == JavaKind.Char && kind2 == JavaKind.Char) {
masm.movzwl(result, new AMD64Address(str1, 0));
masm.movzwl(cnt1, new AMD64Address(str2, 0));
} else {
masm.movzbl(result, new AMD64Address(str1, 0));
masm.movzwl(cnt1, new AMD64Address(str2, 0));
}
masm.sublAndJcc(result, cnt1, ConditionFlag.NotZero, POP_LABEL, false);
if (kind1 == JavaKind.Char && kind2 == JavaKind.Char) {
masm.shrl(cnt2, 1);
}
masm.cmplAndJcc(cnt2, 1, ConditionFlag.Equal, LENGTH_DIFF_LABEL, false);
if (kind1 == kind2) {
masm.cmpqAndJcc(str1, str2, ConditionFlag.Equal, LENGTH_DIFF_LABEL, false);
if (kind1 == JavaKind.Byte && kind2 == JavaKind.Byte) {
scale = AMD64Address.Scale.Times1;
stride = 16;
} else {
scale = AMD64Address.Scale.Times2;
stride = 8;
}
} else {
scale1 = AMD64Address.Scale.Times1;
scale2 = AMD64Address.Scale.Times2;
stride = 8;
}
if (supportsAVX2(crb.target) && supportsSSE42(crb.target)) {
Register vec1 = asRegister(vectorTemp1, AMD64Kind.DOUBLE);
Label COMPARE_WIDE_VECTORS = new Label();
Label VECTOR_NOT_EQUAL = new Label();
Label COMPARE_WIDE_TAIL = new Label();
Label COMPARE_SMALL_STR = new Label();
Label COMPARE_WIDE_VECTORS_LOOP = new Label();
Label COMPARE_16_CHARS = new Label();
Label COMPARE_INDEX_CHAR = new Label();
Label COMPARE_WIDE_VECTORS_LOOP_AVX2 = new Label();
Label COMPARE_TAIL_LONG = new Label();
Label COMPARE_WIDE_VECTORS_LOOP_AVX3 = new Label();
int pcmpmask = 0x19;
if (kind1 == JavaKind.Byte && kind2 == JavaKind.Byte) {
pcmpmask &= ~0x01;
}
if (kind1 == JavaKind.Byte && kind2 == JavaKind.Byte) {
stride2 = 32;
} else {
stride2 = 16;
}
if (kind1 == kind2) {
adr_stride = stride << scale.log2;
} else {
adr_stride1 = 8;
adr_stride2 = 16;
}
assert result.equals(rax) && cnt2.equals(rdx) && cnt1.equals(rcx) : "pcmpestri";
masm.movl(result, cnt2);
masm.andlAndJcc(cnt2, ~(stride2 - 1), ConditionFlag.Zero, COMPARE_TAIL_LONG, false);
masm.bind(COMPARE_16_CHARS);
if (kind1 == kind2) {
masm.movdqu(vec1, new AMD64Address(str1, 0));
} else {
masm.pmovzxbw(vec1, new AMD64Address(str1, 0));
}
masm.pcmpestri(vec1, new AMD64Address(str2, 0), pcmpmask);
masm.jccb(ConditionFlag.Below, COMPARE_INDEX_CHAR);
if (kind1 == kind2) {
masm.movdqu(vec1, new AMD64Address(str1, adr_stride));
masm.pcmpestri(vec1, new AMD64Address(str2, adr_stride), pcmpmask);
} else {
masm.pmovzxbw(vec1, new AMD64Address(str1, adr_stride1));
masm.pcmpestri(vec1, new AMD64Address(str2, adr_stride2), pcmpmask);
}
masm.jccb(ConditionFlag.AboveEqual, COMPARE_WIDE_VECTORS);
masm.addl(cnt1, stride);
masm.bind(COMPARE_INDEX_CHAR);
loadNextElements(masm, result, cnt2, str1, str2, scale, scale1, scale2, cnt1);
masm.subl(result, cnt2);
masm.jmp(POP_LABEL);
masm.bind(COMPARE_WIDE_VECTORS);
if (kind1 == kind2) {
masm.leaq(str1, new AMD64Address(str1, result, scale));
masm.leaq(str2, new AMD64Address(str2, result, scale));
} else {
masm.leaq(str1, new AMD64Address(str1, result, scale1));
masm.leaq(str2, new AMD64Address(str2, result, scale2));
}
masm.subl(result, stride2);
masm.sublAndJcc(cnt2, stride2, ConditionFlag.Zero, COMPARE_WIDE_TAIL, false);
masm.negq(result);
masm.bind(COMPARE_WIDE_VECTORS_LOOP);
if (useAVX3Threshold == 0 && supportsAVX512VLBW(crb.target)) {
masm.cmplAndJcc(cnt2, stride2x2, ConditionFlag.Below, COMPARE_WIDE_VECTORS_LOOP_AVX2, true);
masm.testlAndJcc(cnt2, stride2x2 - 1, ConditionFlag.NotZero, COMPARE_WIDE_VECTORS_LOOP_AVX2, true);
masm.bind(COMPARE_WIDE_VECTORS_LOOP_AVX3);
if (kind1 == kind2) {
masm.evmovdqu64(vec1, new AMD64Address(str1, result, scale));
masm.evpcmpeqb(k7, vec1, new AMD64Address(str2, result, scale));
} else {
masm.evpmovzxbw(vec1, new AMD64Address(str1, result, scale1));
masm.evpcmpeqb(k7, vec1, new AMD64Address(str2, result, scale2));
}
masm.kortestq(k7, k7);
masm.jcc(ConditionFlag.AboveEqual, COMPARE_WIDE_VECTORS_LOOP_FAILED);
masm.addq(result, stride2x2);
masm.sublAndJcc(cnt2, stride2x2, ConditionFlag.NotZero, COMPARE_WIDE_VECTORS_LOOP_AVX3, true);
masm.vpxor(vec1, vec1, vec1);
masm.jmpb(COMPARE_WIDE_TAIL);
}
masm.bind(COMPARE_WIDE_VECTORS_LOOP_AVX2);
if (kind1 == kind2) {
masm.vmovdqu(vec1, new AMD64Address(str1, result, scale));
masm.vpxor(vec1, vec1, new AMD64Address(str2, result, scale));
} else {
masm.vpmovzxbw(vec1, new AMD64Address(str1, result, scale1));
masm.vpxor(vec1, vec1, new AMD64Address(str2, result, scale2));
}
masm.vptest(vec1, vec1);
masm.jcc(ConditionFlag.NotZero, VECTOR_NOT_EQUAL);
masm.addq(result, stride2);
masm.sublAndJcc(cnt2, stride2, ConditionFlag.NotZero, COMPARE_WIDE_VECTORS_LOOP, false);
masm.vpxor(vec1, vec1, vec1);
masm.bind(COMPARE_WIDE_TAIL);
masm.testqAndJcc(result, result, ConditionFlag.Zero, LENGTH_DIFF_LABEL, false);
masm.movl(result, stride2);
masm.movl(cnt2, result);
masm.negq(result);
masm.jmp(COMPARE_WIDE_VECTORS_LOOP_AVX2);
masm.bind(VECTOR_NOT_EQUAL);
masm.vpxor(vec1, vec1, vec1);
if (kind1 == kind2) {
masm.leaq(str1, new AMD64Address(str1, result, scale));
masm.leaq(str2, new AMD64Address(str2, result, scale));
} else {
masm.leaq(str1, new AMD64Address(str1, result, scale1));
masm.leaq(str2, new AMD64Address(str2, result, scale2));
}
masm.jmp(COMPARE_16_CHARS);
masm.bind(COMPARE_TAIL_LONG);
masm.movl(cnt2, result);
masm.cmplAndJcc(cnt2, stride, ConditionFlag.Less, COMPARE_SMALL_STR, false);
if (kind1 == kind2) {
masm.movdqu(vec1, new AMD64Address(str1, 0));
} else {
masm.pmovzxbw(vec1, new AMD64Address(str1, 0));
}
masm.pcmpestri(vec1, new AMD64Address(str2, 0), pcmpmask);
masm.jcc(ConditionFlag.Below, COMPARE_INDEX_CHAR);
masm.subqAndJcc(cnt2, stride, ConditionFlag.Zero, LENGTH_DIFF_LABEL, false);
if (kind1 == kind2) {
masm.leaq(str1, new AMD64Address(str1, result, scale));
masm.leaq(str2, new AMD64Address(str2, result, scale));
} else {
masm.leaq(str1, new AMD64Address(str1, result, scale1));
masm.leaq(str2, new AMD64Address(str2, result, scale2));
}
masm.negq(cnt2);
masm.jmpb(WHILE_HEAD_LABEL);
masm.bind(COMPARE_SMALL_STR);
} else if (supportsSSE42(crb.target)) {
Register vec1 = asRegister(vectorTemp1, AMD64Kind.DOUBLE);
Label COMPARE_WIDE_VECTORS = new Label();
Label VECTOR_NOT_EQUAL = new Label();
Label COMPARE_TAIL = new Label();
int pcmpmask = 0x19;
masm.movl(result, cnt2);
if (kind1 == JavaKind.Byte && kind2 == JavaKind.Byte) {
pcmpmask &= ~0x01;
}
masm.andlAndJcc(cnt2, ~(stride - 1), ConditionFlag.Zero, COMPARE_TAIL, false);
if (kind1 == kind2) {
masm.leaq(str1, new AMD64Address(str1, result, scale));
masm.leaq(str2, new AMD64Address(str2, result, scale));
} else {
masm.leaq(str1, new AMD64Address(str1, result, scale1));
masm.leaq(str2, new AMD64Address(str2, result, scale2));
}
masm.negq(result);
assert result.equals(rax) && cnt2.equals(rdx) && cnt1.equals(rcx) : "pcmpestri";
masm.bind(COMPARE_WIDE_VECTORS);
if (kind1 == kind2) {
masm.movdqu(vec1, new AMD64Address(str1, result, scale));
masm.pcmpestri(vec1, new AMD64Address(str2, result, scale), pcmpmask);
} else {
masm.pmovzxbw(vec1, new AMD64Address(str1, result, scale1));
masm.pcmpestri(vec1, new AMD64Address(str2, result, scale2), pcmpmask);
}
masm.jccb(ConditionFlag.Below, VECTOR_NOT_EQUAL);
masm.addq(result, stride);
masm.subqAndJcc(cnt2, stride, ConditionFlag.NotZero, COMPARE_WIDE_VECTORS, true);
masm.testqAndJcc(result, result, ConditionFlag.Zero, LENGTH_DIFF_LABEL, false);
masm.movl(cnt2, stride);
masm.movl(result, stride);
masm.negq(result);
if (kind1 == kind2) {
masm.movdqu(vec1, new AMD64Address(str1, result, scale));
masm.pcmpestri(vec1, new AMD64Address(str2, result, scale), pcmpmask);
} else {
masm.pmovzxbw(vec1, new AMD64Address(str1, result, scale1));
masm.pcmpestri(vec1, new AMD64Address(str2, result, scale2), pcmpmask);
}
masm.jccb(ConditionFlag.AboveEqual, LENGTH_DIFF_LABEL);
masm.bind(VECTOR_NOT_EQUAL);
masm.addq(cnt1, result);
loadNextElements(masm, result, cnt2, str1, str2, scale, scale1, scale2, cnt1);
masm.subl(result, cnt2);
masm.jmpb(POP_LABEL);
masm.bind(COMPARE_TAIL);
masm.movl(cnt2, result);
}
if (kind1 == kind2) {
masm.leaq(str1, new AMD64Address(str1, cnt2, scale));
masm.leaq(str2, new AMD64Address(str2, cnt2, scale));
} else {
masm.leaq(str1, new AMD64Address(str1, cnt2, scale1));
masm.leaq(str2, new AMD64Address(str2, cnt2, scale2));
}
masm.decrementl(cnt2);
masm.negq(cnt2);
masm.bind(WHILE_HEAD_LABEL);
loadNextElements(masm, result, cnt1, str1, str2, scale, scale1, scale2, cnt2);
masm.sublAndJcc(result, cnt1, ConditionFlag.NotZero, POP_LABEL, true);
masm.incqAndJcc(cnt2, ConditionFlag.NotZero, WHILE_HEAD_LABEL, true);
masm.bind(LENGTH_DIFF_LABEL);
masm.pop(result);
if (kind1 == JavaKind.Char && kind2 == JavaKind.Char) {
masm.sarl(result, 1);
}
masm.jmpb(DONE_LABEL);
if (supportsAVX512VLBW(crb.target)) {
masm.bind(COMPARE_WIDE_VECTORS_LOOP_FAILED);
masm.kmovq(cnt1, k7);
masm.notq(cnt1);
masm.bsfq(cnt2, cnt1);
if (!(kind1 == JavaKind.Byte && kind2 == JavaKind.Byte)) {
masm.sarl(cnt2, 1);
}
masm.addq(result, cnt2);
if (kind1 == JavaKind.Byte && kind2 == JavaKind.Byte) {
masm.movzbl(cnt1, new AMD64Address(str2, result, Scale.Times1));
masm.movzbl(result, new AMD64Address(str1, result, Scale.Times1));
} else if (kind1 == JavaKind.Char && kind2 == JavaKind.Char) {
masm.movzwl(cnt1, new AMD64Address(str2, result, scale));
masm.movzwl(result, new AMD64Address(str1, result, scale));
} else {
masm.movzwl(cnt1, new AMD64Address(str2, result, scale2));
masm.movzbl(result, new AMD64Address(str1, result, scale1));
}
masm.subl(result, cnt1);
masm.jmpb(POP_LABEL);
}
masm.bind(POP_LABEL);
masm.pop(cnt1);
masm.bind(DONE_LABEL);
if (kind1 == JavaKind.Char && kind2 == JavaKind.Byte) {
masm.negl(result);
}
}
private void loadNextElements(AMD64MacroAssembler masm, Register elem1, Register elem2, Register str1, Register str2,
AMD64Address.Scale scale, AMD64Address.Scale scale1,
AMD64Address.Scale scale2, Register index) {
if (kind1 == JavaKind.Byte && kind2 == JavaKind.Byte) {
masm.movzbl(elem1, new AMD64Address(str1, index, scale, 0));
masm.movzbl(elem2, new AMD64Address(str2, index, scale, 0));
} else if (kind1 == JavaKind.Char && kind2 == JavaKind.Char) {
masm.movzwl(elem1, new AMD64Address(str1, index, scale, 0));
masm.movzwl(elem2, new AMD64Address(str2, index, scale, 0));
} else {
masm.movzbl(elem1, new AMD64Address(str1, index, scale1, 0));
masm.movzwl(elem2, new AMD64Address(str2, index, scale2, 0));
}
}
@Override
public boolean needsClearUpperVectorRegisters() {
return true;
}
}