package org.graalvm.wasm.nodes;
import com.oracle.truffle.api.CompilerDirectives;
import com.oracle.truffle.api.CompilerDirectives.CompilationFinal;
import com.oracle.truffle.api.TruffleLanguage;
import com.oracle.truffle.api.TruffleLanguage.ContextReference;
import com.oracle.truffle.api.frame.VirtualFrame;
import com.oracle.truffle.api.nodes.ExplodeLoop;
import com.oracle.truffle.api.nodes.NodeInfo;
import com.oracle.truffle.api.nodes.RootNode;
import org.graalvm.wasm.WasmType;
import org.graalvm.wasm.WasmCodeEntry;
import org.graalvm.wasm.WasmContext;
import org.graalvm.wasm.WasmInstance;
import org.graalvm.wasm.WasmLanguage;
import org.graalvm.wasm.WasmVoidResult;
import org.graalvm.wasm.exception.Failure;
import org.graalvm.wasm.exception.WasmException;
@NodeInfo(language = "wasm", description = "The root node of all WebAssembly functions")
public class WasmRootNode extends RootNode implements WasmNodeInterface {

    /** Module instance this function belongs to; passed to the linker on every {@link #execute}. */
    protected final WasmInstance instance;

    /**
     * Code metadata for this function (max operand-stack size, frame slot of the locals/stack
     * array, function index). {@link #getName()} and {@link #getQualifiedName()} tolerate a
     * {@code null} entry.
     */
    private final WasmCodeEntry codeEntry;

    /** Lazily resolved context reference; see {@link #contextReference()}. */
    @CompilationFinal private ContextReference<WasmContext> rawContextReference;

    /** The function body; {@code null} until {@link #setBody(WasmNode)} is called. */
    @Child private WasmNode body;

    public WasmRootNode(TruffleLanguage<?> language, WasmInstance instance, WasmCodeEntry codeEntry) {
        super(language);
        this.instance = instance;
        this.codeEntry = codeEntry;
        this.body = null;
    }

    /**
     * Returns the {@link WasmLanguage} context reference, looking it up on first use.
     *
     * <p>The field is {@code @CompilationFinal}, so writing it invalidates compiled code first.
     */
    protected ContextReference<WasmContext> contextReference() {
        if (rawContextReference == null) {
            CompilerDirectives.transferToInterpreterAndInvalidate();
            rawContextReference = lookupContextReference(WasmLanguage.class);
        }
        return rawContextReference;
    }

    /** Installs {@code body} as this root's child node (adopted via {@code insert}). */
    public void setBody(WasmNode body) {
        this.body = insert(body);
    }

    @Override
    protected boolean isInstrumentable() {
        return false;
    }

    /**
     * Asks the context's linker to link {@link #instance}.
     *
     * <p>NOTE(review): presumably a no-op once the instance is already linked; confirm in the
     * linker implementation.
     */
    public void tryInitialize(WasmContext context) {
        context.linker().tryLink(instance);
    }

    /** Resolves the current context, links the instance and then runs the function. */
    @Override
    public final Object execute(VirtualFrame frame) {
        final WasmContext context = contextReference().get();
        tryInitialize(context);
        return executeWithContext(frame, context);
    }

    /**
     * Runs the function body and converts its result into a Java value.
     *
     * <p>A single {@code long[]} holds the locals in slots {@code [0, numLocals)}, followed by
     * {@code maxStackSize} operand-stack slots. The array is published in the frame slot
     * {@code codeEntry.stackLocalsSlot()}.
     *
     * @return {@link WasmVoidResult} for void functions; otherwise an {@code Integer},
     *     {@code Long}, {@code Float} or {@code Double} decoded from the value read at the first
     *     operand-stack slot
     * @throws WasmException if the body reports an unknown return type id
     */
    public Object executeWithContext(VirtualFrame frame, WasmContext context) {
        final int maxStackSize = codeEntry.maxStackSize();
        final int numLocals = body.codeEntry().numLocals();
        long[] stacklocals = new long[numLocals + maxStackSize];
        frame.setObject(codeEntry.stackLocalsSlot(), stacklocals);
        moveArgumentsToLocals(frame, stacklocals);
        initializeLocals(stacklocals);
        body.execute(context, frame, stacklocals);
        switch (body.returnTypeId()) {
            // 0x00 is handled exactly like VOID_TYPE.
            case 0x00:
            case WasmType.VOID_TYPE: {
                return WasmVoidResult.getInstance();
            }
            case WasmType.I32_TYPE: {
                long returnValue = pop(stacklocals, numLocals);
                // 32-bit values are kept zero-extended in their 64-bit slot.
                assert returnValue >>> 32 == 0;
                return (int) returnValue;
            }
            case WasmType.I64_TYPE: {
                long returnValue = pop(stacklocals, numLocals);
                return returnValue;
            }
            case WasmType.F32_TYPE: {
                long returnValue = pop(stacklocals, numLocals);
                // The raw float bits live zero-extended in the low 32 bits.
                assert returnValue >>> 32 == 0;
                return Float.intBitsToFloat((int) returnValue);
            }
            case WasmType.F64_TYPE: {
                long returnValue = pop(stacklocals, numLocals);
                return Double.longBitsToDouble(returnValue);
            }
            default:
                throw WasmException.create(Failure.UNSPECIFIED_INTERNAL, this, "Unknown return type id: " + body.returnTypeId());
        }
    }

    /**
     * Copies the call arguments into the first {@code numArguments} local slots, encoding each
     * one as raw bits according to its declared local type.
     *
     * <p>i32 and f32 values are stored zero-extended, matching the invariant asserted when a
     * result is read back in {@link #executeWithContext}. A plain {@code int -> long} widening
     * would sign-extend negative values and set the upper 32 bits. Type ids other than the four
     * numeric types are not written and keep the array's zero default.
     */
    @ExplodeLoop
    private void moveArgumentsToLocals(VirtualFrame frame, long[] stacklocals) {
        Object[] args = frame.getArguments();
        int numArgs = body.instance().symbolTable().function(codeEntry().functionIndex()).numArguments();
        assert args.length == numArgs : "Expected number of arguments " + numArgs + ", actual " + args.length;
        for (int i = 0; i != numArgs; ++i) {
            final Object arg = args[i];
            byte type = body.codeEntry().localType(i);
            switch (type) {
                case WasmType.I32_TYPE:
                    stacklocals[i] = Integer.toUnsignedLong((int) arg);
                    break;
                case WasmType.I64_TYPE:
                    stacklocals[i] = (long) arg;
                    break;
                case WasmType.F32_TYPE:
                    stacklocals[i] = Integer.toUnsignedLong(Float.floatToRawIntBits((float) arg));
                    break;
                case WasmType.F64_TYPE:
                    stacklocals[i] = Double.doubleToRawLongBits((double) arg);
                    break;
            }
        }
    }

    /**
     * Gives every non-argument local its zero value.
     *
     * <p>The array is freshly allocated, so i32/i64 slots are already 0. The bit patterns of
     * {@code 0.0f} and {@code 0.0} are also 0; the float cases spell the default out explicitly.
     */
    @ExplodeLoop
    private void initializeLocals(long[] stacklocals) {
        int numArgs = body.instance().symbolTable().function(codeEntry().functionIndex()).numArguments();
        for (int i = numArgs; i != body.codeEntry().numLocals(); ++i) {
            byte type = body.codeEntry().localType(i);
            switch (type) {
                case WasmType.I32_TYPE:
                    break;
                case WasmType.I64_TYPE:
                    break;
                case WasmType.F32_TYPE:
                    stacklocals[i] = Float.floatToRawIntBits(0.0f);
                    break;
                case WasmType.F64_TYPE:
                    stacklocals[i] = Double.doubleToRawLongBits(0.0);
                    break;
            }
        }
    }

    @Override
    public WasmCodeEntry codeEntry() {
        return codeEntry;
    }

    @Override
    public String toString() {
        return getName();
    }

    /** Returns the function's name, or {@code "function"} when there is no code entry. */
    @Override
    public String getName() {
        if (codeEntry == null) {
            return "function";
        }
        return codeEntry.function().name();
    }

    /** Returns {@code moduleName.functionName}, or just {@link #getName()} without a code entry. */
    @Override
    public String getQualifiedName() {
        if (codeEntry == null) {
            return getName();
        }
        return codeEntry.function().moduleName() + "." + getName();
    }
}