package com.oracle.truffle.llvm.runtime.nodes.control;
import com.oracle.truffle.api.dsl.NodeChild;
import com.oracle.truffle.api.dsl.NodeField;
import com.oracle.truffle.api.dsl.Specialization;
import com.oracle.truffle.api.frame.VirtualFrame;
import com.oracle.truffle.api.instrumentation.GenerateWrapper;
import com.oracle.truffle.api.instrumentation.ProbeNode;
import com.oracle.truffle.llvm.runtime.LLVMIVarBit;
import com.oracle.truffle.llvm.runtime.floating.LLVM80BitFloat;
import com.oracle.truffle.llvm.runtime.memory.LLVMMemMoveNode;
import com.oracle.truffle.llvm.runtime.nodes.api.LLVMControlFlowNode;
import com.oracle.truffle.llvm.runtime.nodes.api.LLVMExpressionNode;
import com.oracle.truffle.llvm.runtime.nodes.api.LLVMStatementNode;
import com.oracle.truffle.llvm.runtime.nodes.base.LLVMBasicBlockNode;
import com.oracle.truffle.llvm.runtime.nodes.func.LLVMArgNode;
import com.oracle.truffle.llvm.runtime.nodes.func.LLVMArgNodeGen;
import com.oracle.truffle.llvm.runtime.pointer.LLVMNativePointer;
import com.oracle.truffle.llvm.runtime.pointer.LLVMPointer;
import com.oracle.truffle.llvm.runtime.vector.LLVMDoubleVector;
import com.oracle.truffle.llvm.runtime.vector.LLVMFloatVector;
import com.oracle.truffle.llvm.runtime.vector.LLVMI16Vector;
import com.oracle.truffle.llvm.runtime.vector.LLVMI1Vector;
import com.oracle.truffle.llvm.runtime.vector.LLVMI32Vector;
import com.oracle.truffle.llvm.runtime.vector.LLVMI64Vector;
import com.oracle.truffle.llvm.runtime.vector.LLVMI8Vector;
@GenerateWrapper
public abstract class LLVMRetNode extends LLVMControlFlowNode {

    /**
     * Produces the value this return terminator hands back to the caller of the current function.
     *
     * @param frame the frame of the function that is returning
     * @return the value being returned (see the concrete subclasses for what each one yields)
     */
    public abstract Object execute(VirtualFrame frame);

    @Override
    public WrapperNode createWrapper(ProbeNode probe) {
        return new LLVMRetNodeWrapper(this, probe);
    }

    /**
     * Outgoing converter for the generated {@code LLVMRetNodeWrapper}. The argument is ignored and
     * {@code null} is always produced.
     *
     * NOTE(review): presumably this keeps the actual return value from being reported to
     * instruments; confirm against the {@code GenerateWrapper.OutgoingConverter} contract.
     */
    @GenerateWrapper.OutgoingConverter
    Object convertOutgoing(@SuppressWarnings("unused") Object object) {
        return null;
    }

    /** A return terminator always has exactly one successor. */
    @Override
    public int getSuccessorCount() {
        return 1;
    }

    /**
     * Lists the successors of this terminator.
     *
     * @return a freshly allocated one-element array holding
     *         {@link LLVMBasicBlockNode#RETURN_FROM_FUNCTION}
     */
    @Override
    public final int[] getSuccessors() {
        int[] successors = {LLVMBasicBlockNode.RETURN_FROM_FUNCTION};
        return successors;
    }

    /** @return the single successor id, {@link LLVMBasicBlockNode#RETURN_FROM_FUNCTION} */
    public int getSuccessor() {
        return LLVMBasicBlockNode.RETURN_FROM_FUNCTION;
    }

    /**
     * A return never carries phi assignments for its (only) successor.
     *
     * @param successorIndex must be {@code 0}; checked only when assertions are enabled
     * @return always {@code null}
     */
    @Override
    public LLVMStatementNode getPhiNode(int successorIndex) {
        assert successorIndex == 0;
        return null;
    }

    /** Returns a {@code boolean} result unchanged. */
    @NodeChild(value = "retResult", type = LLVMExpressionNode.class)
    public abstract static class LLVMI1RetNode extends LLVMRetNode {
        @Specialization
        protected Object doBoolean(boolean value) {
            return value;
        }
    }

    /** Returns a {@code byte} result unchanged. */
    @NodeChild(value = "retResult", type = LLVMExpressionNode.class)
    public abstract static class LLVMI8RetNode extends LLVMRetNode {
        @Specialization
        protected Object doByte(byte value) {
            return value;
        }
    }

    /** Returns a {@code short} result unchanged. */
    @NodeChild(value = "retResult", type = LLVMExpressionNode.class)
    public abstract static class LLVMI16RetNode extends LLVMRetNode {
        @Specialization
        protected Object doShort(short value) {
            return value;
        }
    }

    /** Returns an {@code int} result unchanged. */
    @NodeChild(value = "retResult", type = LLVMExpressionNode.class)
    public abstract static class LLVMI32RetNode extends LLVMRetNode {
        @Specialization
        protected Object doInt(int value) {
            return value;
        }
    }

    /** Returns a 64-bit result, which may arrive either as a {@code long} or as an {@link LLVMPointer}. */
    @NodeChild(value = "retResult", type = LLVMExpressionNode.class)
    public abstract static class LLVMI64RetNode extends LLVMRetNode {
        @Specialization
        protected Object doLong(long value) {
            return value;
        }

        @Specialization
        protected Object doPointer(LLVMPointer value) {
            return value;
        }
    }

    /** Returns an {@link LLVMIVarBit} result unchanged. */
    @NodeChild(value = "retResult", type = LLVMExpressionNode.class)
    public abstract static class LLVMIVarBitRetNode extends LLVMRetNode {
        @Specialization
        protected Object doVarBit(LLVMIVarBit value) {
            return value;
        }
    }

    /** Returns a {@code float} result unchanged. */
    @NodeChild(value = "retResult", type = LLVMExpressionNode.class)
    public abstract static class LLVMFloatRetNode extends LLVMRetNode {
        @Specialization
        protected Object doFloat(float value) {
            return value;
        }
    }

    /** Returns a {@code double} result unchanged. */
    @NodeChild(value = "retResult", type = LLVMExpressionNode.class)
    public abstract static class LLVMDoubleRetNode extends LLVMRetNode {
        @Specialization
        protected Object doDouble(double value) {
            return value;
        }
    }

    /** Returns an {@link LLVM80BitFloat} result unchanged. */
    @NodeChild(value = "retResult", type = LLVMExpressionNode.class)
    public abstract static class LLVM80BitFloatRetNode extends LLVMRetNode {
        @Specialization
        protected Object do80BitFloat(LLVM80BitFloat value) {
            return value;
        }
    }

    /**
     * Returns any value unchanged; the single specialization accepts {@code Object}.
     * NOTE(review): the name suggests this is used for pointer/address returns; confirm at call sites.
     */
    @NodeChild(value = "retResult", type = LLVMExpressionNode.class)
    public abstract static class LLVMAddressRetNode extends LLVMRetNode {
        @Specialization
        protected Object doObject(Object value) {
            return value;
        }
    }

    /** Returns any of the supported LLVM vector values unchanged. */
    @NodeChild(value = "retResult", type = LLVMExpressionNode.class)
    public abstract static class LLVMVectorRetNode extends LLVMRetNode {
        @Specialization
        protected Object doDoubleVector(LLVMDoubleVector value) {
            return value;
        }

        @Specialization
        protected Object doFloatVector(LLVMFloatVector value) {
            return value;
        }

        @Specialization
        protected Object doI16Vector(LLVMI16Vector value) {
            return value;
        }

        @Specialization
        protected Object doI1Vector(LLVMI1Vector value) {
            return value;
        }

        @Specialization
        protected Object doI32Vector(LLVMI32Vector value) {
            return value;
        }

        @Specialization
        protected Object doI64Vector(LLVMI64Vector value) {
            return value;
        }

        @Specialization
        protected Object doI8Vector(LLVMI8Vector value) {
            return value;
        }
    }

    /**
     * Returns a struct by copying it into a target address taken from frame argument index 1.
     * The amount copied is {@link #getStructSize()} (presumably bytes). The target address is the
     * node's result.
     *
     * NOTE(review): argument 1 is presumably the caller-supplied return buffer; confirm against the
     * calling convention used by the call nodes.
     */
    @NodeChild(value = "retResult", type = LLVMExpressionNode.class)
    @NodeField(name = "structSize", type = long.class)
    public abstract static class LLVMStructRetNode extends LLVMRetNode {
        // Reads the destination address for the struct copy.
        @Child private LLVMArgNode retAddressArg = LLVMArgNodeGen.create(1);
        @Child private LLVMMemMoveNode copyNode;

        public LLVMStructRetNode(LLVMMemMoveNode memMove) {
            this.copyNode = memMove;
        }

        public abstract long getStructSize();

        @Specialization
        protected Object doStruct(VirtualFrame frame, LLVMPointer value) {
            Object target = retAddressArg.executeGeneric(frame);
            // Widen to Object so the same executeWithTarget overload as before is selected.
            Object source = value;
            copyNode.executeWithTarget(target, source, getStructSize());
            return target;
        }
    }

    /** Return for functions without a result: yields a native null pointer as a placeholder value. */
    public abstract static class LLVMVoidReturnNode extends LLVMRetNode {
        @Specialization
        protected Object doVoid() {
            return LLVMNativePointer.createNull();
        }
    }
}