package org.graalvm.compiler.replacements.nodes;
import static org.graalvm.compiler.nodeinfo.NodeCycles.CYCLES_UNKNOWN;
import static org.graalvm.compiler.nodeinfo.NodeSize.SIZE_UNKNOWN;
import java.lang.invoke.MethodHandle;
import java.util.Arrays;
import org.graalvm.compiler.core.common.type.StampFactory;
import org.graalvm.compiler.core.common.type.StampPair;
import org.graalvm.compiler.core.common.type.TypeReference;
import org.graalvm.compiler.debug.GraalError;
import org.graalvm.compiler.graph.NodeClass;
import org.graalvm.compiler.graph.spi.Simplifiable;
import org.graalvm.compiler.graph.spi.SimplifierTool;
import org.graalvm.compiler.nodeinfo.NodeInfo;
import org.graalvm.compiler.nodes.CallTargetNode;
import org.graalvm.compiler.nodes.CallTargetNode.InvokeKind;
import org.graalvm.compiler.nodes.FixedGuardNode;
import org.graalvm.compiler.nodes.FixedNode;
import org.graalvm.compiler.nodes.FixedWithNextNode;
import org.graalvm.compiler.nodes.GuardNode;
import org.graalvm.compiler.nodes.InvokeNode;
import org.graalvm.compiler.nodes.LogicNode;
import org.graalvm.compiler.nodes.NodeView;
import org.graalvm.compiler.nodes.PiNode;
import org.graalvm.compiler.nodes.StructuredGraph;
import org.graalvm.compiler.nodes.ValueNode;
import org.graalvm.compiler.nodes.extended.AnchoringNode;
import org.graalvm.compiler.nodes.extended.GuardingNode;
import org.graalvm.compiler.nodes.extended.ValueAnchorNode;
import org.graalvm.compiler.nodes.java.InstanceOfNode;
import org.graalvm.compiler.nodes.java.MethodCallTargetNode;
import org.graalvm.compiler.nodes.type.StampTool;
import org.graalvm.compiler.nodes.util.GraphUtil;
import jdk.vm.ci.meta.Assumptions;
import jdk.vm.ci.meta.Assumptions.AssumptionResult;
import jdk.vm.ci.meta.DeoptimizationAction;
import jdk.vm.ci.meta.DeoptimizationReason;
import jdk.vm.ci.meta.JavaKind;
import jdk.vm.ci.meta.JavaType;
import jdk.vm.ci.meta.MethodHandleAccessProvider;
import jdk.vm.ci.meta.MethodHandleAccessProvider.IntrinsicMethod;
import jdk.vm.ci.meta.ResolvedJavaMethod;
import jdk.vm.ci.meta.ResolvedJavaType;
import jdk.vm.ci.meta.Signature;
import jdk.vm.ci.meta.SpeculationLog;
import jdk.vm.ci.meta.SpeculationLog.Speculation;
/**
 * Macro node standing in for a call to one of the {@link MethodHandle} intrinsics enumerated by
 * {@link IntrinsicMethod}: {@code invokeBasic} and the {@code linkTo*} family.
 * <p>
 * During {@linkplain #simplify(SimplifierTool) simplification}, the value that identifies the
 * callee is checked. For {@code invokeBasic} this is the receiver; for {@code linkTo*} it is the
 * trailing member-name argument. If that value is a constant and the target can be resolved, this
 * node is replaced by a direct {@link InvokeNode} of the resolved target. Otherwise the graph is
 * left unchanged.
 */
@NodeInfo(cycles = CYCLES_UNKNOWN, size = SIZE_UNKNOWN)
public final class MethodHandleNode extends MacroStateSplitNode implements Simplifiable {
    public static final NodeClass<MethodHandleNode> TYPE = NodeClass.create(MethodHandleNode.class);

    /** The method handle intrinsic this node represents; selects how the target is resolved. */
    protected final IntrinsicMethod intrinsicMethod;

    /**
     * @param intrinsicMethod the intrinsic being called
     * @param invokeKind kind of the original call
     * @param targetMethod the method originally invoked; later passed on as the {@code original}
     *            method of the resolved call target
     * @param bci bytecode index of the call
     * @param returnStamp stamp of the original call's return value
     * @param arguments the original call's arguments
     */
    public MethodHandleNode(IntrinsicMethod intrinsicMethod, InvokeKind invokeKind, ResolvedJavaMethod targetMethod, int bci, StampPair returnStamp, ValueNode... arguments) {
        super(TYPE, invokeKind, targetMethod, bci, returnStamp, arguments);
        this.intrinsicMethod = intrinsicMethod;
    }

    /**
     * Tries to resolve a method handle intrinsic call to a direct invoke of its target.
     * <p>
     * Any nodes needed to narrow argument types to the target's declared parameter types are added
     * through {@code adder}. The returned invoke itself is <em>not</em> added to any graph. An
     * intrinsic other than those handled below is a compiler error
     * ({@code GraalError.shouldNotReachHere}).
     *
     * @param adder receives the type-narrowing nodes created during resolution
     * @param methodHandleAccess used to resolve the target from a constant method handle or member
     *            name
     * @param intrinsicMethod the intrinsic being called
     * @param original the method originally invoked
     * @param bci bytecode index for the new invoke
     * @param returnStamp stamp of the original call's return value
     * @param arguments the original call's arguments; this array is not modified
     * @return a new, graph-less {@link InvokeNode}, or {@code null} if the target could not be
     *         resolved
     */
    public static InvokeNode tryResolveTargetInvoke(GraphAdder adder, MethodHandleAccessProvider methodHandleAccess, IntrinsicMethod intrinsicMethod,
                    ResolvedJavaMethod original, int bci,
                    StampPair returnStamp, ValueNode... arguments) {
        switch (intrinsicMethod) {
            case INVOKE_BASIC:
                return getInvokeBasicTarget(adder, intrinsicMethod, methodHandleAccess, original, bci, returnStamp, arguments);
            case LINK_TO_STATIC:
            case LINK_TO_SPECIAL:
            case LINK_TO_VIRTUAL:
            case LINK_TO_INTERFACE:
                return getLinkToTarget(adder, intrinsicMethod, methodHandleAccess, original, bci, returnStamp, arguments);
            default:
                // Name the offending constant so a newly added intrinsic is easy to diagnose.
                throw GraalError.shouldNotReachHere("unexpected method handle intrinsic: " + intrinsicMethod);
        }
    }

    /**
     * Callback through which {@link #tryResolveTargetInvoke} inserts the nodes it creates to cast
     * arguments. It lets each caller decide where those nodes go.
     */
    public abstract static class GraphAdder {
        private final StructuredGraph graph;

        public GraphAdder(StructuredGraph graph) {
            this.graph = graph;
        }

        /**
         * Adds {@code node} to the graph.
         * <p>
         * Callers always continue with the returned node, so an implementation may return an
         * existing equivalent node instead (as {@code addOrUnique} does).
         *
         * @return the node to use from now on
         */
        public abstract <T extends ValueNode> T add(T node);

        /**
         * Returns the anchor for floating {@link GuardNode}s.
         * <p>
         * A {@code null} result makes argument casts use a {@link FixedGuardNode} instead. The
         * default implementation returns {@code null}.
         */
        public AnchoringNode getGuardAnchor() {
            return null;
        }

        /**
         * Returns the assumptions of the graph given to the constructor. Resolution records
         * concrete-method assumptions into them and uses them when building type references.
         */
        public Assumptions getAssumptions() {
            return graph.getAssumptions();
        }
    }

    /**
     * Replaces this node with a direct invoke when the target can be resolved.
     * <p>
     * Fixed cast nodes are inserted immediately before this node. The new invoke then takes over
     * this node's frame state, its usages and its position in control flow, and this node is
     * removed.
     */
    @Override
    public void simplify(SimplifierTool tool) {
        MethodHandleAccessProvider methodHandleAccess = tool.getConstantReflection().getMethodHandleAccess();
        ValueNode[] argumentsArray = arguments.toArray(new ValueNode[arguments.size()]);

        final FixedNode before = this;
        GraphAdder adder = new GraphAdder(graph()) {
            @Override
            public <T extends ValueNode> T add(T node) {
                T added = graph().addOrUnique(node);
                // Fixed nodes (e.g. FixedGuardNode) must also be linked into control flow.
                if (added instanceof FixedWithNextNode) {
                    graph().addBeforeFixed(before, (FixedWithNextNode) added);
                }
                return added;
            }
        };
        InvokeNode invoke = tryResolveTargetInvoke(adder, methodHandleAccess, intrinsicMethod, targetMethod, bci, returnStamp, argumentsArray);
        if (invoke != null) {
            assert invoke.graph() == null;
            invoke = graph().addOrUniqueWithInputs(invoke);
            invoke.setStateAfter(stateAfter());
            // Capture the successor before this node is unlinked from control flow.
            FixedNode currentNext = next();
            replaceAtUsages(invoke);
            GraphUtil.removeFixedWithUnusedInputs(this);
            graph().addBeforeFixed(currentNext, invoke);
        }
    }

    /** Returns the first argument; for {@code invokeBasic} this is the method handle. */
    private static ValueNode getReceiver(ValueNode[] arguments) {
        return arguments[0];
    }

    /** Returns the last argument; for {@code linkTo*} this is the member name naming the target. */
    private static ValueNode getMemberName(ValueNode[] arguments) {
        return arguments[arguments.length - 1];
    }

    /**
     * Resolves an {@code invokeBasic} call.
     *
     * @return {@code null} unless the method handle receiver is a constant whose target resolves
     */
    private static InvokeNode getInvokeBasicTarget(GraphAdder adder, IntrinsicMethod intrinsicMethod, MethodHandleAccessProvider methodHandleAccess,
                    ResolvedJavaMethod original,
                    int bci,
                    StampPair returnStamp, ValueNode[] arguments) {
        ValueNode methodHandleNode = getReceiver(arguments);
        if (methodHandleNode.isConstant()) {
            // NOTE(review): the `true` argument is presumably JVMCI's forceBytecodeGeneration flag -- confirm.
            return getTargetInvokeNode(adder, intrinsicMethod, bci, returnStamp, arguments, methodHandleAccess.resolveInvokeBasicTarget(methodHandleNode.asJavaConstant(), true), original);
        }
        return null;
    }

    /**
     * Resolves a {@code linkTo*} call.
     *
     * @return {@code null} unless the trailing member-name argument is a constant whose target
     *         resolves
     */
    private static InvokeNode getLinkToTarget(GraphAdder adder, IntrinsicMethod intrinsicMethod, MethodHandleAccessProvider methodHandleAccess,
                    ResolvedJavaMethod original,
                    int bci,
                    StampPair returnStamp, ValueNode[] arguments) {
        ValueNode memberNameNode = getMemberName(arguments);
        if (memberNameNode.isConstant()) {
            return getTargetInvokeNode(adder, intrinsicMethod, bci, returnStamp, arguments, methodHandleAccess.resolveLinkToTarget(memberNameNode.asJavaConstant()), original);
        }
        return null;
    }

    /**
     * Builds a direct invoke of {@code target} once it is known.
     * <p>
     * The actual method called is chosen in this order:
     * <ol>
     * <li>{@code target} itself, if it can be statically bound.</li>
     * <li>A unique concrete implementation found on its declaring class.</li>
     * <li>For {@code linkToVirtual}/{@code linkToInterface} only, a unique concrete implementation
     * found on the receiver's stamp type.</li>
     * </ol>
     * A concrete-method result is used only if its assumptions can be recorded. The receiver and
     * reference arguments are then narrowed to the types declared by {@code target}.
     *
     * @param originalArguments the call's arguments; cloned before any narrowing, never modified
     * @param target the resolved target, or {@code null} if resolution failed
     * @return the new invoke, or {@code null} if no method could be bound
     */
    private static InvokeNode getTargetInvokeNode(GraphAdder adder, IntrinsicMethod intrinsicMethod, int bci, StampPair returnStamp, ValueNode[] originalArguments, ResolvedJavaMethod target,
                    ResolvedJavaMethod original) {
        if (target == null) {
            return null;
        }

        Signature signature = target.getSignature();
        final boolean isStatic = target.isStatic();
        final int receiverSkip = isStatic ? 0 : 1;

        Assumptions assumptions = adder.getAssumptions();
        ResolvedJavaMethod realTarget = null;
        if (target.canBeStaticallyBound()) {
            realTarget = target;
        } else {
            ResolvedJavaType targetType = target.getDeclaringClass();
            // Try to bind based on the declared type first.
            AssumptionResult<ResolvedJavaMethod> concreteMethod = targetType.findUniqueConcreteMethod(target);
            if (concreteMethod == null) {
                // Fall back to the (possibly more precise) type of the receiver for virtual dispatch.
                if (intrinsicMethod == IntrinsicMethod.LINK_TO_VIRTUAL || intrinsicMethod == IntrinsicMethod.LINK_TO_INTERFACE) {
                    ValueNode receiver = getReceiver(originalArguments);
                    TypeReference receiverType = StampTool.typeReferenceOrNull(receiver.stamp(NodeView.DEFAULT));
                    if (receiverType != null) {
                        concreteMethod = receiverType.getType().findUniqueConcreteMethod(target);
                    }
                }
            }
            if (concreteMethod != null && concreteMethod.canRecordTo(assumptions)) {
                concreteMethod.recordTo(assumptions);
                realTarget = concreteMethod.getResult();
            }
        }

        if (realTarget != null) {
            // Don't mutate the caller's array.
            ValueNode[] arguments = originalArguments.clone();

            // Narrow the receiver to the target's declaring class.
            if (!isStatic) {
                JavaType receiverType = target.getDeclaringClass();
                maybeCastArgument(adder, arguments, 0, receiverType);
            }

            // Narrow each declared parameter; any trailing member name lies past these indices.
            for (int index = 0; index < signature.getParameterCount(false); index++) {
                JavaType parameterType = signature.getParameterType(index, target.getDeclaringClass());
                maybeCastArgument(adder, arguments, receiverSkip + index, parameterType);
            }
            InvokeNode invoke = createTargetInvokeNode(assumptions, intrinsicMethod, realTarget, original, bci, returnStamp, arguments);
            assert invoke != null : "graph has been modified so this must result in an invoke";
            return invoke;
        }
        return null;
    }

    /**
     * Narrows {@code arguments[index]} to {@code type} when its static type is unknown or a strict
     * supertype of {@code type}.
     * <p>
     * The narrowing is a null-allowing instanceof check, a guard and a {@link PiNode}:
     * <ul>
     * <li>The guard deoptimizes with {@code ClassCastException}/{@code InvalidateRecompile} if the
     * check fails.</li>
     * <li>The guard is a {@link FixedGuardNode} when the adder has no guard anchor, otherwise an
     * anchored {@link GuardNode} plus a {@link ValueAnchorNode}.</li>
     * <li>The array slot is replaced by the {@link PiNode}.</li>
     * </ul>
     * Nothing is done in these cases:
     * <ul>
     * <li>{@code java.lang.Object} targets.</li>
     * <li>Unresolved or primitive target types.</li>
     * <li>Arguments with a primitive stack kind.</li>
     * <li>Checks that fold to a tautology.</li>
     * </ul>
     */
    private static void maybeCastArgument(GraphAdder adder, ValueNode[] arguments, int index, JavaType type) {
        ValueNode argument = arguments[index];
        if (type instanceof ResolvedJavaType && !((ResolvedJavaType) type).isJavaLangObject()) {
            Assumptions assumptions = adder.getAssumptions();
            TypeReference targetType = TypeReference.create(assumptions, (ResolvedJavaType) type);
            // The primitive checks skip object/primitive mismatches; presumably these come from
            // Word-typed values. Not casting is a safe fallback. TODO confirm the exact source.
            if (targetType != null && !targetType.getType().isPrimitive() && !argument.getStackKind().isPrimitive()) {
                ResolvedJavaType argumentType = StampTool.typeOrNull(argument.stamp(NodeView.DEFAULT));
                if (argumentType == null || (argumentType.isAssignableFrom(targetType.getType()) && !argumentType.equals(targetType.getType()))) {
                    LogicNode inst = InstanceOfNode.createAllowNull(targetType, argument, null, null);
                    assert !inst.isAlive();
                    if (!inst.isTautology()) {
                        inst = adder.add(inst);
                        AnchoringNode guardAnchor = adder.getGuardAnchor();
                        DeoptimizationReason reason = DeoptimizationReason.ClassCastException;
                        DeoptimizationAction action = DeoptimizationAction.InvalidateRecompile;
                        Speculation speculation = SpeculationLog.NO_SPECULATION;
                        GuardingNode guard;
                        if (guardAnchor == null) {
                            FixedGuardNode fixedGuard = adder.add(new FixedGuardNode(inst, reason, action, speculation, false));
                            guard = fixedGuard;
                        } else {
                            GuardNode newGuard = adder.add(new GuardNode(inst, guardAnchor, reason, action, false, speculation, null));
                            adder.add(new ValueAnchorNode(newGuard));
                            guard = newGuard;
                        }
                        ValueNode valueNode = adder.add(PiNode.create(argument, StampFactory.object(targetType), guard.asNode()));
                        arguments[index] = valueNode;
                    }
                }
            }
        }
    }

    /**
     * Creates the (graph-less) invoke of {@code target}.
     * <p>
     * The invoke kind is {@code Static} for static targets and {@code Special} otherwise. For
     * {@code linkTo*} intrinsics the trailing member-name argument is dropped from the target's
     * arguments. The full argument array is still handed to the call target as the original
     * arguments.
     */
    private static InvokeNode createTargetInvokeNode(Assumptions assumptions, IntrinsicMethod intrinsicMethod, ResolvedJavaMethod target, ResolvedJavaMethod original, int bci, StampPair returnStamp,
                    ValueNode[] arguments) {
        InvokeKind targetInvokeKind = target.isStatic() ? InvokeKind.Static : InvokeKind.Special;
        JavaType targetReturnType = target.getSignature().getReturnType(null);

        ValueNode[] targetArguments;
        switch (intrinsicMethod) {
            case INVOKE_BASIC:
                targetArguments = arguments;
                break;
            case LINK_TO_STATIC:
            case LINK_TO_SPECIAL:
            case LINK_TO_VIRTUAL:
            case LINK_TO_INTERFACE:
                // linkTo* calls carry a trailing member name the target does not take.
                targetArguments = Arrays.copyOfRange(arguments, 0, arguments.length - 1);
                break;
            default:
                throw GraalError.shouldNotReachHere("unexpected method handle intrinsic: " + intrinsicMethod);
        }
        StampPair targetReturnStamp = StampFactory.forDeclaredType(assumptions, targetReturnType, false);

        MethodCallTargetNode callTarget = ResolvedMethodHandleCallTargetNode.create(targetInvokeKind, target, targetArguments, targetReturnStamp, original, arguments, returnStamp);

        // The target may return a value while the invoker is void. In that case the invoker's
        // void stamp must be used. Otherwise keep the invoke's default stamp, derived from the
        // call target.
        if (returnStamp.getTrustedStamp().getStackKind() == JavaKind.Void) {
            return new InvokeNode(callTarget, bci, StampFactory.forVoid());
        } else {
            return new InvokeNode(callTarget, bci);
        }
    }
}