package org.graalvm.compiler.replacements.nodes;
import static org.graalvm.compiler.nodeinfo.NodeCycles.CYCLES_UNKNOWN;
import static org.graalvm.compiler.nodeinfo.NodeSize.SIZE_UNKNOWN;
import java.lang.invoke.MethodHandle;
import java.util.Arrays;
import org.graalvm.compiler.core.common.type.StampFactory;
import org.graalvm.compiler.core.common.type.StampPair;
import org.graalvm.compiler.core.common.type.TypeReference;
import org.graalvm.compiler.debug.GraalError;
import org.graalvm.compiler.graph.NodeClass;
import org.graalvm.compiler.graph.spi.Simplifiable;
import org.graalvm.compiler.graph.spi.SimplifierTool;
import org.graalvm.compiler.nodeinfo.NodeInfo;
import org.graalvm.compiler.nodes.CallTargetNode;
import org.graalvm.compiler.nodes.CallTargetNode.InvokeKind;
import org.graalvm.compiler.nodes.FixedNode;
import org.graalvm.compiler.nodes.InvokeNode;
import org.graalvm.compiler.nodes.PiNode;
import org.graalvm.compiler.nodes.ValueNode;
import org.graalvm.compiler.nodes.java.MethodCallTargetNode;
import org.graalvm.compiler.nodes.type.StampTool;
import org.graalvm.compiler.nodes.util.GraphUtil;
import jdk.vm.ci.meta.Assumptions;
import jdk.vm.ci.meta.Assumptions.AssumptionResult;
import jdk.vm.ci.meta.JavaKind;
import jdk.vm.ci.meta.JavaType;
import jdk.vm.ci.meta.MethodHandleAccessProvider;
import jdk.vm.ci.meta.MethodHandleAccessProvider.IntrinsicMethod;
import jdk.vm.ci.meta.ResolvedJavaMethod;
import jdk.vm.ci.meta.ResolvedJavaType;
import jdk.vm.ci.meta.Signature;
/**
 * Macro node for the {@link MethodHandle} intrinsics ({@code invokeBasic} and the
 * {@code linkTo*} family). Once the value that selects the call target (the method handle for
 * {@code invokeBasic}, the trailing member name for {@code linkTo*}) is a constant and the target
 * can be bound, {@link #simplify(SimplifierTool)} replaces this node with a direct
 * {@link InvokeNode} of the resolved target.
 */
@NodeInfo(cycles = CYCLES_UNKNOWN, size = SIZE_UNKNOWN)
public final class MethodHandleNode extends MacroStateSplitNode implements Simplifiable {
    public static final NodeClass<MethodHandleNode> TYPE = NodeClass.create(MethodHandleNode.class);

    /** The method handle intrinsic this node stands for. */
    protected final IntrinsicMethod intrinsicMethod;

    public MethodHandleNode(IntrinsicMethod intrinsicMethod, InvokeKind invokeKind, ResolvedJavaMethod targetMethod, int bci, StampPair returnStamp, ValueNode... arguments) {
        super(TYPE, invokeKind, targetMethod, bci, returnStamp, arguments);
        this.intrinsicMethod = intrinsicMethod;
    }

    /**
     * Tries to build an invoke of the method that {@code intrinsicMethod} would dispatch to for the
     * given arguments. The returned node has not been added to any graph.
     *
     * @param assumptions receives any assumptions recorded while narrowing arguments or
     *            devirtualizing the target
     * @param arguments the intrinsic's arguments: for {@code INVOKE_BASIC} the first one is the
     *            method handle, for the {@code LINK_TO_*} intrinsics the last one is the member
     *            name
     * @return the detached invoke of the resolved target, or {@code null} if the target cannot be
     *         determined (yet)
     * @throws GraalError if {@code intrinsicMethod} is not one of the supported intrinsics
     */
    public static InvokeNode tryResolveTargetInvoke(Assumptions assumptions, MethodHandleAccessProvider methodHandleAccess, IntrinsicMethod intrinsicMethod, ResolvedJavaMethod original, int bci,
                    StampPair returnStamp, ValueNode... arguments) {
        ResolvedJavaMethod resolved;
        switch (intrinsicMethod) {
            case INVOKE_BASIC: {
                ValueNode methodHandle = getReceiver(arguments);
                if (!methodHandle.isConstant()) {
                    return null;
                }
                resolved = methodHandleAccess.resolveInvokeBasicTarget(methodHandle.asJavaConstant(), true);
                break;
            }
            case LINK_TO_STATIC:
            case LINK_TO_SPECIAL:
            case LINK_TO_VIRTUAL:
            case LINK_TO_INTERFACE: {
                ValueNode memberName = getMemberName(arguments);
                if (!memberName.isConstant()) {
                    return null;
                }
                resolved = methodHandleAccess.resolveLinkToTarget(memberName.asJavaConstant());
                break;
            }
            default:
                throw GraalError.shouldNotReachHere();
        }
        return getTargetInvokeNode(assumptions, intrinsicMethod, bci, returnStamp, arguments, resolved, original);
    }

    /**
     * Replaces this node with an invoke of the resolved target when one can be built. The new
     * invoke takes over this node's usages and state-after, and is inserted in front of this
     * node's former successor.
     */
    @Override
    public void simplify(SimplifierTool tool) {
        MethodHandleAccessProvider access = tool.getConstantReflection().getMethodHandleAccess();
        ValueNode[] args = arguments.toArray(new ValueNode[arguments.size()]);
        InvokeNode detached = tryResolveTargetInvoke(graph().getAssumptions(), access, intrinsicMethod, targetMethod, bci, returnStamp, args);
        if (detached == null) {
            // Target not resolvable (yet); keep the macro node.
            return;
        }
        assert detached.graph() == null;
        InvokeNode invoke = graph().addOrUniqueWithInputs(detached);
        invoke.setStateAfter(stateAfter());
        // Read the successor before removing this node: it is the insertion point afterwards.
        FixedNode successor = next();
        replaceAtUsages(invoke);
        GraphUtil.removeFixedWithUnusedInputs(this);
        graph().addBeforeFixed(successor, invoke);
    }

    /** Returns the first argument (the method handle for {@code invokeBasic}, or the receiver). */
    private static ValueNode getReceiver(ValueNode[] arguments) {
        return arguments[0];
    }

    /** Returns the last argument (the member name for the {@code linkTo*} intrinsics). */
    private static ValueNode getMemberName(ValueNode[] arguments) {
        return arguments[arguments.length - 1];
    }

    /**
     * Builds an invoke of {@code target}, or of its unique concrete implementation when
     * {@code target} cannot be statically bound.
     *
     * @param originalArguments the intrinsic's arguments; not modified (a narrowed copy is used)
     * @param target the resolved method, or {@code null} if resolution failed
     * @return the detached invoke, or {@code null} if {@code target} is {@code null} or no unique
     *         concrete implementation can be relied on
     */
    private static InvokeNode getTargetInvokeNode(Assumptions assumptions, IntrinsicMethod intrinsicMethod, int bci, StampPair returnStamp, ValueNode[] originalArguments, ResolvedJavaMethod target,
                    ResolvedJavaMethod original) {
        if (target == null) {
            return null;
        }
        ValueNode[] narrowed = castArgumentsToSignature(assumptions, target, originalArguments);
        if (target.canBeStaticallyBound()) {
            return createTargetInvokeNode(assumptions, intrinsicMethod, target, original, bci, returnStamp, narrowed);
        }
        ResolvedJavaType lookupType;
        if (intrinsicMethod == IntrinsicMethod.LINK_TO_VIRTUAL || intrinsicMethod == IntrinsicMethod.LINK_TO_INTERFACE) {
            // Virtual/interface dispatch: search from the (possibly narrowed) receiver's type.
            TypeReference receiverType = StampTool.typeReferenceOrNull(getReceiver(narrowed).stamp());
            if (receiverType == null) {
                return null;
            }
            lookupType = receiverType.getType();
        } else {
            lookupType = target.getDeclaringClass();
        }
        return createUniqueConcreteTargetInvokeNode(assumptions, intrinsicMethod, lookupType, target, original, bci, returnStamp, narrowed);
    }

    /**
     * Returns a copy of {@code originalArguments} in which the receiver (for instance methods) and
     * each declared parameter of {@code target} have been passed through
     * {@link #maybeCastArgument}, in argument order.
     */
    private static ValueNode[] castArgumentsToSignature(Assumptions assumptions, ResolvedJavaMethod target, ValueNode[] originalArguments) {
        ValueNode[] copy = originalArguments.clone();
        ResolvedJavaType declaringClass = target.getDeclaringClass();
        int firstParameter = 0;
        if (!target.isStatic()) {
            maybeCastArgument(assumptions, copy, 0, declaringClass);
            firstParameter = 1;
        }
        Signature signature = target.getSignature();
        int parameterCount = signature.getParameterCount(false);
        for (int i = 0; i < parameterCount; i++) {
            maybeCastArgument(assumptions, copy, firstParameter + i, signature.getParameterType(i, declaringClass));
        }
        return copy;
    }

    /**
     * Looks up the unique concrete implementation of {@code target} in {@code lookupType}. On
     * success the supporting assumption is recorded and an invoke of that implementation is
     * returned; otherwise returns {@code null} without recording that assumption.
     */
    private static InvokeNode createUniqueConcreteTargetInvokeNode(Assumptions assumptions, IntrinsicMethod intrinsicMethod, ResolvedJavaType lookupType, ResolvedJavaMethod target,
                    ResolvedJavaMethod original, int bci, StampPair returnStamp, ValueNode[] arguments) {
        AssumptionResult<ResolvedJavaMethod> concrete = lookupType.findUniqueConcreteMethod(target);
        if (concrete == null || !concrete.canRecordTo(assumptions)) {
            return null;
        }
        concrete.recordTo(assumptions);
        return createTargetInvokeNode(assumptions, intrinsicMethod, concrete.getResult(), original, bci, returnStamp, arguments);
    }

    /**
     * Narrows {@code arguments[index]} to {@code type} with a {@link PiNode}. The narrowing
     * happens only when a type reference can be created for {@code type}, both that type and the
     * argument are non-primitive, and the argument's static type is unknown or a proper supertype
     * of it.
     */
    private static void maybeCastArgument(Assumptions assumptions, ValueNode[] arguments, int index, JavaType type) {
        if (!(type instanceof ResolvedJavaType)) {
            return;
        }
        TypeReference targetType = TypeReference.create(assumptions, (ResolvedJavaType) type);
        ValueNode argument = arguments[index];
        if (targetType == null) {
            return;
        }
        ResolvedJavaType narrowType = targetType.getType();
        if (narrowType.isPrimitive() || argument.getStackKind().isPrimitive()) {
            return;
        }
        ResolvedJavaType argumentType = StampTool.typeOrNull(argument.stamp());
        boolean argumentIsWider = argumentType == null || (argumentType.isAssignableFrom(narrowType) && !argumentType.equals(narrowType));
        if (argumentIsWider) {
            // NOTE(review): this PiNode carries no guard, so the narrower stamp is asserted rather
            // than checked; confirm the method handle machinery guarantees the argument's type.
            arguments[index] = new PiNode(argument, StampFactory.object(targetType));
        }
    }

    /**
     * Creates a detached invoke of {@code target} through an {@link InvokeKind#Static} or
     * {@link InvokeKind#Special} call target.
     */
    private static InvokeNode createTargetInvokeNode(Assumptions assumptions, IntrinsicMethod intrinsicMethod, ResolvedJavaMethod target, ResolvedJavaMethod original, int bci, StampPair returnStamp,
                    ValueNode[] arguments) {
        InvokeKind kind = target.isStatic() ? InvokeKind.Static : InvokeKind.Special;
        JavaType declaredReturnType = target.getSignature().getReturnType(null);
        ValueNode[] callArguments = targetArguments(intrinsicMethod, arguments);
        StampPair declaredReturnStamp = StampFactory.forDeclaredType(assumptions, declaredReturnType, false);
        MethodCallTargetNode callTarget = ResolvedMethodHandleCallTargetNode.create(kind, target, callArguments, declaredReturnStamp, original, arguments, returnStamp);
        // A void call site forces a void stamp; otherwise the two-argument InvokeNode constructor
        // picks the stamp (presumably from the call target's declared return type).
        boolean voidCallSite = returnStamp.getTrustedStamp().getStackKind() == JavaKind.Void;
        return voidCallSite ? new InvokeNode(callTarget, bci, StampFactory.forVoid()) : new InvokeNode(callTarget, bci);
    }

    /**
     * Returns the arguments to pass to the resolved target: unchanged for {@code INVOKE_BASIC}, and
     * a copy without the trailing member name for the {@code LINK_TO_*} intrinsics.
     */
    private static ValueNode[] targetArguments(IntrinsicMethod intrinsicMethod, ValueNode[] arguments) {
        switch (intrinsicMethod) {
            case INVOKE_BASIC:
                return arguments;
            case LINK_TO_STATIC:
            case LINK_TO_SPECIAL:
            case LINK_TO_VIRTUAL:
            case LINK_TO_INTERFACE:
                return Arrays.copyOfRange(arguments, 0, arguments.length - 1);
            default:
                throw GraalError.shouldNotReachHere();
        }
    }
}