package org.graalvm.compiler.core.aarch64;
import org.graalvm.compiler.asm.aarch64.AArch64Address;
import org.graalvm.compiler.core.common.LIRKind;
import org.graalvm.compiler.core.common.NumUtil;
import org.graalvm.compiler.core.common.type.Stamp;
import org.graalvm.compiler.nodes.ValueNode;
import org.graalvm.compiler.nodes.calc.AddNode;
import org.graalvm.compiler.nodes.memory.address.AddressNode;
import org.graalvm.compiler.nodes.memory.address.OffsetAddressNode;
import org.graalvm.compiler.phases.common.AddressLoweringByUsePhase;
import jdk.vm.ci.aarch64.AArch64Kind;
import jdk.vm.ci.meta.JavaConstant;
public class AArch64AddressLoweringByUse extends AddressLoweringByUsePhase.AddressLoweringByUse {
private AArch64LIRKindTool kindtool;
public AArch64AddressLoweringByUse(AArch64LIRKindTool kindtool) {
this.kindtool = kindtool;
}
@Override
public AddressNode lower(ValueNode use, Stamp stamp, AddressNode address) {
if (address instanceof OffsetAddressNode) {
OffsetAddressNode offsetAddress = (OffsetAddressNode) address;
return doLower(stamp, offsetAddress.getBase(), offsetAddress.getOffset());
} else {
return address;
}
}
@Override
public AddressNode lower(AddressNode address) {
return lower(null, null, address);
}
private AddressNode doLower(Stamp stamp, ValueNode base, ValueNode index) {
AArch64AddressNode ret = new AArch64AddressNode(base, index);
AArch64Kind aarch64Kind = (stamp == null ? null : getAArch64Kind(stamp));
boolean changed;
do {
changed = improve(aarch64Kind, ret);
} while (changed);
return base.graph().unique(ret);
}
protected boolean improve(AArch64Kind kind, AArch64AddressNode ret) {
AArch64Address.AddressingMode mode = ret.getAddressingMode();
if (isDisplacementMode(mode) || isBaseOnlyMode(mode)) {
return false;
}
ValueNode base = ret.getBase();
ValueNode index = ret.getIndex();
if (base == null) {
ret.setBase(index);
ret.setIndex(base);
return true;
}
if (base.isJavaConstant() && base.asJavaConstant().getJavaKind().isNumericInteger() &&
index != null && !index.isJavaConstant()) {
ret.setBase(index);
ret.setIndex(base);
return true;
}
if (index == null && base instanceof AddNode) {
AddNode add = (AddNode) base;
ret.setBase(add.getX());
ret.setIndex(add.getY());
return true;
}
if (index != null && index.isJavaConstant()) {
JavaConstant javaConstant = index.asJavaConstant();
if (javaConstant.getJavaKind().isNumericInteger()) {
long disp = javaConstant.asLong();
mode = immediateMode(kind, disp);
if (isDisplacementMode(mode)) {
index = null;
boolean tryNextBase = (base instanceof AddNode);
while (tryNextBase) {
AddNode add = (AddNode) base;
tryNextBase = false;
ValueNode child = add.getX();
if (child.isJavaConstant() && child.asJavaConstant().getJavaKind().isNumericInteger()) {
long newDisp = disp + child.asJavaConstant().asLong();
AArch64Address.AddressingMode newMode = immediateMode(kind, newDisp);
if (newMode != AArch64Address.AddressingMode.REGISTER_OFFSET) {
disp = newDisp;
mode = newMode;
base = add.getY();
ret.setBase(base);
tryNextBase = (base instanceof AddNode);
}
} else {
child = add.getY();
if (child.isJavaConstant() && child.asJavaConstant().getJavaKind().isNumericInteger()) {
long newDisp = disp + child.asJavaConstant().asLong();
AArch64Address.AddressingMode newMode = immediateMode(kind, newDisp);
if (newMode != AArch64Address.AddressingMode.REGISTER_OFFSET) {
disp = newDisp;
mode = newMode;
base = add.getX();
ret.setBase(base);
tryNextBase = (base instanceof AddNode);
}
}
}
}
if (disp != 0) {
ret.setIndex(null);
int scaleFactor = computeScaleFactor(kind, mode);
ret.setDisplacement(disp, scaleFactor, mode);
} else {
ret.setIndex(null);
ret.setDisplacement(0, 1, AArch64Address.AddressingMode.BASE_REGISTER_ONLY);
}
return true;
}
}
}
return false;
}
private AArch64Kind getAArch64Kind(Stamp stamp) {
LIRKind lirKind = stamp.getLIRKind(kindtool);
if (!lirKind.isValue()) {
if (!lirKind.isReference(0) || lirKind.getReferenceCount() != 1) {
return null;
}
}
return (AArch64Kind) lirKind.getPlatformKind();
}
private static AArch64Address.AddressingMode immediateMode(AArch64Kind kind, long value) {
if (kind != null) {
int size = kind.getSizeInBytes();
if ((value & (size - 1)) == 0) {
long encodedValue = value / size;
if (NumUtil.isUnsignedNbit(12, encodedValue)) {
return AArch64Address.AddressingMode.IMMEDIATE_SCALED;
}
}
}
if (NumUtil.isSignedNbit(9, value)) {
return AArch64Address.AddressingMode.IMMEDIATE_UNSCALED;
}
return AArch64Address.AddressingMode.REGISTER_OFFSET;
}
private static int computeScaleFactor(AArch64Kind kind, AArch64Address.AddressingMode mode) {
if (mode == AArch64Address.AddressingMode.IMMEDIATE_SCALED) {
return kind.getSizeInBytes();
}
return 1;
}
boolean isBaseOnlyMode(AArch64Address.AddressingMode addressingMode) {
return addressingMode == AArch64Address.AddressingMode.BASE_REGISTER_ONLY;
}
private static boolean isDisplacementMode(AArch64Address.AddressingMode addressingMode) {
switch (addressingMode) {
case IMMEDIATE_POST_INDEXED:
case IMMEDIATE_PRE_INDEXED:
case IMMEDIATE_SCALED:
case IMMEDIATE_UNSCALED:
return true;
}
return false;
}
}