package com.oracle.truffle.nfi.test;
import static org.hamcrest.CoreMatchers.instanceOf;
import static org.hamcrest.CoreMatchers.is;
import java.util.HashMap;
import java.util.Map;
import org.junit.Assert;
import org.junit.Test;
import org.junit.runner.RunWith;
import com.oracle.truffle.api.CallTarget;
import com.oracle.truffle.api.CompilerDirectives;
import com.oracle.truffle.api.CompilerDirectives.TruffleBoundary;
import com.oracle.truffle.api.dsl.Cached;
import com.oracle.truffle.api.dsl.GenerateUncached;
import com.oracle.truffle.api.dsl.Specialization;
import com.oracle.truffle.api.frame.VirtualFrame;
import com.oracle.truffle.api.interop.InteropException;
import com.oracle.truffle.api.interop.InteropLibrary;
import com.oracle.truffle.api.interop.TruffleObject;
import com.oracle.truffle.api.library.CachedLibrary;
import com.oracle.truffle.api.library.ExportLibrary;
import com.oracle.truffle.api.library.ExportMessage;
import com.oracle.truffle.api.nodes.Node;
import static com.oracle.truffle.nfi.test.NFITest.NFITestRootNode.getInterop;
import com.oracle.truffle.tck.TruffleRunner;
import com.oracle.truffle.tck.TruffleRunner.Inject;
/**
 * Tests registering native functions through a Truffle callback.
 *
 * <p>
 * The native {@code initialize_package} function receives a {@link FunctionRegistry} as a
 * callback of type {@code (string,string,pointer):void} and invokes it with a function name, an
 * NFI signature and a native symbol. The registry binds each symbol to its signature and stores
 * it by name. The test then calls the registered {@code add}, {@code square} and {@code sqrt}
 * functions to compute the hypotenuse of a right triangle.
 */
@RunWith(TruffleRunner.class)
public class RegisterPackageNFITest extends NFITest {

    /**
     * Registry shared by all test nodes. It is cleared by {@link LoadPackageNode#loadPackage()}
     * before native code re-populates it. NOTE(review): it is static and backed by a plain
     * {@link HashMap}, so it assumes tests in this class do not run concurrently.
     */
    private static final FunctionRegistry REGISTRY = new FunctionRegistry();

    /**
     * Executable interop object passed to native {@code initialize_package} as a callback. Each
     * invocation binds the given symbol to the given signature and records the bound function
     * under the given name.
     */
    @ExportLibrary(InteropLibrary.class)
    static class FunctionRegistry implements TruffleObject {

        /** Bound functions, keyed by the name supplied by native code. */
        private final Map<String, Object> functions;

        @TruffleBoundary
        FunctionRegistry() {
            functions = new HashMap<>();
        }

        /** Removes every registered function. */
        @TruffleBoundary
        void clear() {
            functions.clear();
        }

        /** Stores {@code function} under {@code name}, replacing any previous entry. */
        @TruffleBoundary
        void add(String name, Object function) {
            functions.put(name, function);
        }

        /**
         * Looks up a registered function.
         *
         * @return the bound function registered under {@code name}, or {@code null} if none
         */
        @TruffleBoundary
        Object get(String name) {
            return functions.get(name);
        }

        @ExportMessage
        boolean isExecutable() {
            return true;
        }

        /**
         * Callback entry point invoked from native code.
         *
         * <p>
         * {@code args} is expected to be {@code (name, signature, symbol)}, matching the
         * {@code (string,string,pointer):void} callback type used by {@link LoadPackageNode}.
         * NOTE(review): arity and argument types are not checked, so a mismatched call surfaces
         * as a {@link ClassCastException} or {@link ArrayIndexOutOfBoundsException} rather than
         * an interop exception. That is acceptable for a test fixture.
         *
         * @return a dummy value, because the callback is declared to return {@code void}
         */
        @ExportMessage
        Object execute(Object[] args,
                        @Cached RegisterFunctionNode register) {
            register.execute(this, (String) args[0], (String) args[1], args[2]);
            return "";
        }

        /**
         * Binds a native symbol to an NFI signature and adds the result to the registry.
         * Uncached generation is required because the node is used from an exported message.
         */
        @GenerateUncached
        abstract static class RegisterFunctionNode extends Node {

            protected abstract void execute(FunctionRegistry receiver, String name, String signature, Object symbol);

            @Specialization(limit = "3")
            static void register(FunctionRegistry receiver, String name, String signature, Object symbol,
                            @CachedLibrary("symbol") InteropLibrary interop) {
                try {
                    // The symbol's "bind" member produces an executable for the given signature.
                    Object boundSymbol = interop.invokeMember(symbol, "bind", signature);
                    receiver.add(name, boundSymbol);
                } catch (InteropException ex) {
                    // A failed bind is a test-setup bug; never expected on the fast path.
                    CompilerDirectives.transferToInterpreter();
                    throw new AssertionError(ex);
                }
            }
        }
    }

    /**
     * Calls native {@code initialize_package} with the shared {@link #REGISTRY} so that native
     * code can register its functions.
     */
    static class LoadPackageNode extends Node {

        // NOTE(review): lookupAndBind/getInterop come from NFITest; presumably they resolve the
        // symbol in the test library and bind it to this signature. Confirm in NFITest.
        private final Object initializePackage = lookupAndBind("initialize_package", "((string,string,pointer):void):void");

        @Child InteropLibrary interop = getInterop(initializePackage);

        /**
         * Clears the shared registry, then lets native code populate it again.
         *
         * @return the shared, freshly populated registry
         */
        FunctionRegistry loadPackage() {
            REGISTRY.clear();
            try {
                interop.execute(initializePackage, REGISTRY);
            } catch (InteropException ex) {
                CompilerDirectives.transferToInterpreter();
                throw new AssertionError(ex);
            }
            return REGISTRY;
        }
    }

    /**
     * Root node computing {@code sqrt(square(a) + square(b))} through the natively registered
     * functions. Arguments: {@code a} and {@code b} as {@link Double}.
     */
    public static class RegisterPackageTestNode extends NFITestRootNode {

        @Child LoadPackageNode loadPackage = new LoadPackageNode();
        @Child InteropLibrary interop = getInterop();

        @Override
        public Object executeTest(VirtualFrame frame) throws InteropException {
            FunctionRegistry registry = loadPackage.loadPackage();
            Object add = lookup(registry, "add");
            Object square = lookup(registry, "square");
            Object sqrt = lookup(registry, "sqrt");

            double a = (Double) frame.getArguments()[0];
            double b = (Double) frame.getArguments()[1];

            double aSq = (Double) interop.execute(square, a);
            double bSq = (Double) interop.execute(square, b);
            double cSq = (Double) interop.execute(add, aSq, bSq);
            return interop.execute(sqrt, cSq);
        }

        /**
         * Fetches a registered function, failing with a descriptive message when native code did
         * not register it. Without this check, {@code null} would be handed to
         * {@link InteropLibrary#execute} and fail with an unrelated-looking error.
         */
        private static Object lookup(FunctionRegistry registry, String name) {
            Object function = registry.get(name);
            if (function == null) {
                // Building the message is slow-path work; leave compiled code first.
                CompilerDirectives.transferToInterpreter();
                throw new AssertionError("function '" + name + "' was not registered by initialize_package");
            }
            return function;
        }
    }

    /** Checks that the registered functions compute the hypotenuse of a 3-4-5 triangle. */
    @Test
    public void testPythagoras(@Inject(RegisterPackageTestNode.class) CallTarget callTarget) {
        Object ret = callTarget.call(3.0, 4.0);
        Assert.assertThat("return value", ret, is(instanceOf(Double.class)));
        Assert.assertEquals("return value", 5.0, ret);
    }
}