package com.oracle.truffle.llvm.parser.nodes;
import com.oracle.truffle.api.frame.FrameDescriptor;
import com.oracle.truffle.api.frame.FrameSlot;
import com.oracle.truffle.llvm.parser.LLVMParserRuntime;
import com.oracle.truffle.llvm.parser.model.SymbolImpl;
import com.oracle.truffle.llvm.parser.model.symbols.constants.Constant;
import com.oracle.truffle.llvm.parser.model.symbols.constants.NullConstant;
import com.oracle.truffle.llvm.parser.model.symbols.constants.integer.BigIntegerConstant;
import com.oracle.truffle.llvm.parser.model.symbols.constants.integer.IntegerConstant;
import com.oracle.truffle.llvm.runtime.CommonNodeFactory;
import com.oracle.truffle.llvm.runtime.GetStackSpaceFactory;
import com.oracle.truffle.llvm.runtime.NodeFactory;
import com.oracle.truffle.llvm.runtime.datalayout.DataLayout;
import com.oracle.truffle.llvm.runtime.except.LLVMParserException;
import com.oracle.truffle.llvm.runtime.nodes.api.LLVMExpressionNode;
import com.oracle.truffle.llvm.runtime.types.Type;
import com.oracle.truffle.llvm.runtime.types.symbols.SSAValue;
/**
 * Turns parser-model symbols into {@link LLVMExpressionNode}s that produce their values.
 * <p>
 * {@link Constant} symbols are materialized through {@link Constant#createNode}; {@link SSAValue}
 * symbols become frame reads from a slot of the {@link FrameDescriptor} given at construction, with
 * the slot being created on demand.
 */
public final class LLVMSymbolReadResolver {

    /** Parser runtime passed to {@link Constant#createNode}; its node factory is cached below. */
    private final LLVMParserRuntime runtime;
    /** Node factory fetched from {@link #runtime} once, in the constructor. */
    private final NodeFactory nodeFactory;
    /** Frame descriptor in which {@link #resolve} looks up / creates slots for SSA values. */
    private final FrameDescriptor frame;
    /** Passed through unchanged to {@link Constant#createNode}. */
    private final GetStackSpaceFactory getStackSpaceFactory;
    /** Passed to constant materialization and to element-pointer node construction. */
    private final DataLayout dataLayout;
    /**
     * If {@code true}, slots created by {@link #findOrAddFrameSlot} carry their {@link SSAValue} as
     * slot info; otherwise their info is {@code null}.
     */
    private final boolean storeSSAValueInSlot;

    /**
     * @param runtime parser runtime; {@code runtime.getNodeFactory()} is called once here
     * @param frame frame descriptor used by {@link #resolve(SymbolImpl)} for SSA values
     * @param getStackSpaceFactory forwarded to {@link Constant#createNode}
     * @param dataLayout forwarded to constant creation and element-pointer construction
     * @param storeSSAValueInSlot whether newly created slots record their SSA value as slot info
     */
    public LLVMSymbolReadResolver(LLVMParserRuntime runtime, FrameDescriptor frame, GetStackSpaceFactory getStackSpaceFactory, DataLayout dataLayout, boolean storeSSAValueInSlot) {
        this.runtime = runtime;
        this.nodeFactory = runtime.getNodeFactory();
        this.frame = frame;
        this.getStackSpaceFactory = getStackSpaceFactory;
        this.dataLayout = dataLayout;
        this.storeSSAValueInSlot = storeSSAValueInSlot;
    }

    /**
     * Returns the slot keyed by {@code value}'s frame identifier in {@code descriptor}, adding one
     * (with a kind derived from the value's type) if none exists yet.
     * <p>
     * An assertion checks that the slot's info is exactly what this resolver would attach: the
     * value itself when {@code storeSSAValueInSlot} is set, {@code null} otherwise.
     *
     * @param descriptor frame descriptor to search and, on a miss, extend
     * @param value SSA value whose slot is wanted
     * @return the existing or freshly added slot
     */
    public FrameSlot findOrAddFrameSlot(FrameDescriptor descriptor, SSAValue value) {
        Object identifier = value.getFrameIdentifier();
        Object expectedInfo = storeSSAValueInSlot ? value : null;
        FrameSlot found = descriptor.findFrameSlot(identifier);
        // The slot kind is only computed on a miss; an existing slot is returned unchanged.
        FrameSlot result = (found != null) ? found : descriptor.findOrAddFrameSlot(identifier, expectedInfo, Type.getFrameSlotKind(value.getType()));
        assert result.getInfo() == expectedInfo;
        return result;
    }

    /**
     * Folds an integer-like constant symbol into an {@code int}.
     *
     * @param constant symbol to inspect; anything that is not one of the handled constant kinds
     *            (including {@code null}) yields {@code null}
     * @return the value of an {@link IntegerConstant} or {@link BigIntegerConstant}, {@code 0} for
     *         a {@link NullConstant}, or {@code null} otherwise
     * @throws ArithmeticException if a {@link BigIntegerConstant}'s value does not fit in an int
     */
    public static Integer evaluateIntegerConstant(SymbolImpl constant) {
        if (constant instanceof IntegerConstant) {
            long wide = ((IntegerConstant) constant).getValue();
            int narrow = (int) wide;
            // NOTE(review): range is only checked under -ea; without assertions an out-of-range
            // value is silently truncated, unlike the BigIntegerConstant branch below.
            assert wide == narrow;
            return narrow;
        }
        if (constant instanceof BigIntegerConstant) {
            return ((BigIntegerConstant) constant).getValue().intValueExact();
        }
        if (constant instanceof NullConstant) {
            return 0;
        }
        return null;
    }

    /**
     * Folds an integer-like constant symbol into a {@code long}.
     *
     * @param constant symbol to inspect; anything that is not one of the handled constant kinds
     *            (including {@code null}) yields {@code null}
     * @return the value of an {@link IntegerConstant} or {@link BigIntegerConstant}, {@code 0L} for
     *         a {@link NullConstant}, or {@code null} otherwise
     * @throws ArithmeticException if a {@link BigIntegerConstant}'s value does not fit in a long
     */
    public static Long evaluateLongIntegerConstant(SymbolImpl constant) {
        if (constant instanceof IntegerConstant) {
            return ((IntegerConstant) constant).getValue();
        }
        if (constant instanceof BigIntegerConstant) {
            return ((BigIntegerConstant) constant).getValue().longValueExact();
        }
        if (constant instanceof NullConstant) {
            return 0L;
        }
        return null;
    }

    /**
     * Callback used by {@link #resolveElementPointer} to turn one symbol into a node.
     * <p>
     * As called from {@code resolveElementPointer}: for an index, {@code excludeOtherIndex} is that
     * index's position in {@code others} and {@code other} is the base; for the base itself,
     * {@code excludeOtherIndex} is {@code -1} and {@code other} is {@code null}. In both cases
     * {@code others} is the full index array. (How implementations use these is not visible here.)
     */
    public interface OptimizedResolver {
        LLVMExpressionNode resolve(SymbolImpl symbol, int excludeOtherIndex, SymbolImpl other, SymbolImpl... others);
    }

    /**
     * Builds a nested element-pointer node for {@code base} indexed by {@code indices}.
     * <p>
     * Each index that {@link #evaluateLongIntegerConstant} can fold is passed as a constant (its
     * node slot stays {@code null}); the rest are resolved through {@code resolver}. Indices are
     * processed from last to first, and the base is resolved after all indices.
     *
     * @param base symbol the indices are applied to
     * @param indices index symbols, in source order
     * @param resolver resolves non-constant indices and the base
     * @return the node produced by {@link CommonNodeFactory#createNestedElementPointerNode}
     */
    public LLVMExpressionNode resolveElementPointer(SymbolImpl base, SymbolImpl[] indices, OptimizedResolver resolver) {
        final int count = indices.length;
        LLVMExpressionNode[] dynamicIndices = new LLVMExpressionNode[count];
        Long[] constantIndices = new Long[count];
        Type[] types = new Type[count];

        // Walk backwards, as the resolver may have side effects whose order could matter.
        int position = count;
        while (position-- > 0) {
            SymbolImpl index = indices[position];
            Long folded = evaluateLongIntegerConstant(index);
            constantIndices[position] = folded;
            types[position] = index.getType();
            if (folded == null) {
                dynamicIndices[position] = resolver.resolve(index, position, base, indices);
            }
        }

        LLVMExpressionNode baseAddress = resolver.resolve(base, -1, null, indices);
        return CommonNodeFactory.createNestedElementPointerNode(nodeFactory, dataLayout, dynamicIndices, constantIndices, types, baseAddress, base.getType());
    }

    /**
     * Produces a node that yields the value of {@code symbol}.
     *
     * @param symbol symbol to resolve; may be {@code null}
     * @return {@code null} for a {@code null} symbol, the constant's own node for a
     *         {@link Constant}, or a frame read for an {@link SSAValue}
     * @throws LLVMParserException if the symbol is neither a constant nor an SSA value
     */
    public LLVMExpressionNode resolve(SymbolImpl symbol) {
        if (symbol == null) {
            return null;
        }
        if (symbol instanceof Constant) {
            Constant constant = (Constant) symbol;
            return constant.createNode(runtime, dataLayout, getStackSpaceFactory);
        }
        if (!(symbol instanceof SSAValue)) {
            throw new LLVMParserException("Cannot resolve symbol: " + symbol);
        }
        SSAValue ssa = (SSAValue) symbol;
        FrameSlot existing = frame.findFrameSlot(ssa.getFrameIdentifier());
        // NOTE(review): a slot found here bypasses the info assertion in findOrAddFrameSlot;
        // only a miss goes through the create-on-demand path.
        FrameSlot target = (existing != null) ? existing : findOrAddFrameSlot(frame, ssa);
        return CommonNodeFactory.createFrameRead(ssa.getType(), target);
    }
}