package org.graalvm.wasm.test;
import com.oracle.truffle.api.Truffle;
import junit.framework.AssertionFailedError;
import org.graalvm.polyglot.Context;
import org.graalvm.polyglot.PolyglotException;
import org.graalvm.polyglot.Source;
import org.graalvm.polyglot.Value;
import org.graalvm.wasm.GlobalRegistry;
import org.graalvm.wasm.WasmContext;
import org.graalvm.wasm.WasmFunctionInstance;
import org.graalvm.wasm.WasmInstance;
import org.graalvm.wasm.memory.WasmMemory;
import org.graalvm.wasm.test.options.WasmTestOptions;
import org.graalvm.wasm.utils.cases.WasmCase;
import org.graalvm.wasm.utils.cases.WasmCaseData;
import org.junit.Assert;
import java.io.BufferedReader;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.InputStreamReader;
import java.io.PrintStream;
import java.util.ArrayList;
import java.util.Collection;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import static junit.framework.TestCase.fail;
import static org.graalvm.wasm.WasmUtil.prepend;
public abstract class WasmFileSuite extends AbstractWasmSuite {
/** ANSI escape sequence (CSI 1 D) that moves the terminal cursor one column to the left. */
private static final String MOVE_LEFT = "\u001b[1D";
/** Icon printed for a test case that completed without an error (smiling face with heart-shaped eyes). */
private static final String TEST_PASSED_ICON = "\uD83D\uDE0D";
/** Icon printed for a test case that threw (pouting face). */
private static final String TEST_FAILED_ICON = "\uD83D\uDE21";
/** Initial status icon of a test case, a plain question mark. */
private static final String TEST_IN_PROGRESS_ICON = "\u003F";
/** Status icon shown while the sources are being evaluated (open book). */
private static final String PHASE_PARSE_ICON = "\uD83D\uDCD6";
/** Status icon of the synchronous, non-inlining compilation phase (small blue diamond). */
private static final String PHASE_SYNC_NO_INLINE_ICON = "\uD83D\uDD39";
/** Status icon of the synchronous, inlining compilation phase (large blue diamond). */
private static final String PHASE_SYNC_INLINE_ICON = "\uD83D\uDD37";
/** Status icon of the background-compilation phase (large orange diamond). */
private static final String PHASE_ASYNC_ICON = "\uD83D\uDD36";
/** Status icon of the interpreter-only phase (robot face). */
private static final String PHASE_INTERPRETER_ICON = "\uD83E\uDD16";
/** Number of terminal columns reserved for a status icon when erasing or laying out the status. */
private static final int STATUS_ICON_WIDTH = 2;
/** Exact number of characters a status label is truncated or padded to. */
private static final int STATUS_LABEL_WIDTH = 11;
/** Iteration count of the interpreter phase when the test case does not set "interpreter-iterations". */
private static final int DEFAULT_INTERPRETER_ITERATIONS = 1;
/** Iteration count of the sync/no-inline phase when the test case does not set "sync-noinline-iterations". */
private static final int DEFAULT_SYNC_NOINLINE_ITERATIONS = 3;
/** Iteration count of the sync/inline phase when the test case does not set "sync-inline-iterations". */
private static final int DEFAULT_SYNC_INLINE_ITERATIONS = 3;
/** Iteration count of the async phase when the test case does not set "async-iterations". */
private static final int DEFAULT_ASYNC_ITERATIONS = 100000;
/** Every iteration with an index below this value gets a context-state check. */
private static final int INITIAL_STATE_CHECK_ITERATIONS = 10;
/** Beyond the initial iterations, the context state is checked whenever the index is a multiple of this value. */
private static final int STATE_CHECK_PERIODICITY = 2000;
/**
 * Builds a context with both compilation and inlining switched off.
 *
 * @param contextBuilder builder to configure; the options are set on this very builder
 * @return the context built from {@code contextBuilder}
 */
private static Context getInterpretedNoInline(Context.Builder contextBuilder) {
    return contextBuilder
                    .option("engine.Compilation", "false")
                    .option("engine.Inlining", "false")
                    .build();
}
/**
 * Builds a context that compiles immediately, without background compilation and without inlining.
 *
 * @param contextBuilder builder to configure; the options are set on this very builder
 * @return the context built from {@code contextBuilder}
 */
private static Context getSyncCompiledNoInline(Context.Builder contextBuilder) {
    return contextBuilder
                    .option("engine.Compilation", "true")
                    .option("engine.BackgroundCompilation", "false")
                    .option("engine.CompileImmediately", "true")
                    .option("engine.Inlining", "false")
                    .build();
}
/**
 * Builds a context that compiles immediately, without background compilation, with inlining enabled.
 *
 * @param contextBuilder builder to configure; the options are set on this very builder
 * @return the context built from {@code contextBuilder}
 */
private static Context getSyncCompiledWithInline(Context.Builder contextBuilder) {
    return contextBuilder
                    .option("engine.Compilation", "true")
                    .option("engine.BackgroundCompilation", "false")
                    .option("engine.CompileImmediately", "true")
                    .option("engine.Inlining", "true")
                    .build();
}
/**
 * Builds a context that uses background compilation with a compilation threshold of 100, without
 * immediate compilation and without inlining.
 *
 * @param contextBuilder builder to configure; the options are set on this very builder
 * @return the context built from {@code contextBuilder}
 */
private static Context getAsyncCompiled(Context.Builder contextBuilder) {
    return contextBuilder
                    .option("engine.Compilation", "true")
                    .option("engine.BackgroundCompilation", "true")
                    .option("engine.CompileImmediately", "false")
                    .option("engine.Inlining", "false")
                    .option("engine.CompilationThreshold", "100")
                    .build();
}
/**
 * Tells whether the {@code CONTINUOUS_INTEGRATION} environment variable is defined (with any value).
 *
 * @return {@code true} if the variable is present in the environment
 */
private static boolean inCI() {
    return System.getenv("CONTINUOUS_INTEGRATION") != null;
}
/**
 * Tells whether the JVM reports a Windows operating system.
 *
 * <p>
 * The {@code os.name} system property is matched case-insensitively against the prefix
 * {@code "windows"}. A case-sensitive search for {@code "win"} would never match the capitalised
 * names that Windows JVMs report (for example {@code "Windows 10"}).
 *
 * @return {@code true} if {@code os.name} starts with "windows" in any letter case; {@code false}
 *         otherwise, including when the property is not set
 */
private static boolean inWindows() {
    final String osName = System.getProperty("os.name");
    if (osName == null) {
        return false;
    }
    final String windowsPrefix = "windows";
    // regionMatches with ignoreCase=true avoids locale-sensitive case conversion.
    return osName.regionMatches(true, 0, windowsPrefix, 0, windowsPrefix.length());
}
/**
 * Tells whether the status output should avoid cursor-movement sequences, which is the case when
 * either {@link #inCI()} or {@link #inWindows()} reports {@code true}.
 *
 * @return {@code true} if running in CI or on Windows
 */
private static boolean isPoorShell() {
    if (inCI()) {
        return true;
    }
    return inWindows();
}
/**
 * Looks through the module instances of the context and returns the first inferred entry point.
 *
 * @param wasmContext context whose module instances are searched
 * @return the first non-null result of {@code inferEntryPoint()}, wrapped as a polyglot value
 * @throws AssertionFailedError if no instance yields an entry point
 */
private static Value findMain(WasmContext wasmContext) {
    for (final WasmInstance candidate : wasmContext.moduleInstances().values()) {
        final WasmFunctionInstance entryPoint = candidate.inferEntryPoint();
        if (entryPoint == null) {
            continue;
        }
        return Value.asValue(entryPoint);
    }
    throw new AssertionFailedError("No start function exported.");
}
/**
 * Evaluates the sources in the given context, then executes the entry point {@code iterations}
 * times, validating every outcome against the expectations stored in the test case.
 *
 * <p>
 * While this method runs, {@code System.out} is redirected into an in-memory buffer so that the
 * program output can be validated; progress is written to the original stream. On exit the context
 * is closed and the original {@code System.out} is restored, even if closing the context throws.
 *
 * @param testCase test case providing the options ("zero-memory", "argument") and expected data
 * @param context context to run in; it is entered here and closed before returning
 * @param sources sources evaluated in order before the entry point is looked up
 * @param iterations number of times the entry point is executed; non-positive values run nothing
 * @param phaseIcon icon shown in the status area while iterating
 * @param phaseLabel label shown in the status area and used in error messages
 */
private static void runInContext(WasmCase testCase, Context context, List<Source> sources, int iterations, String phaseIcon, String phaseLabel) {
    final PrintStream oldOut = System.out;
    try {
        // Capture everything the program prints so that it can be validated per iteration.
        final ByteArrayOutputStream capturedStream = new ByteArrayOutputStream();
        final PrintStream capturedStdout = new PrintStream(capturedStream);
        System.setOut(capturedStdout);

        final boolean requiresZeroMemory = Boolean.parseBoolean(testCase.options().getProperty("zero-memory", "false"));

        resetStatus(oldOut, PHASE_PARSE_ICON, "parsing");
        context.enter();
        try {
            sources.forEach(context::eval);
        } catch (PolyglotException e) {
            // An error while evaluating the sources is only acceptable as an expected validation error.
            validateThrown(testCase.data(), WasmCaseData.ErrorType.Validation, e);
            return;
        }

        final WasmContext wasmContext = WasmContext.getCurrent();
        final Value mainFunction = findMain(wasmContext);

        resetStatus(oldOut, phaseIcon, phaseLabel);

        final String argString = testCase.options().getProperty("argument");
        final Integer arg = argString == null ? null : Integer.parseInt(argString);
        ContextState firstIterationContextState = null;

        for (int i = 0; i < iterations; ++i) {
            try {
                capturedStream.reset();
                final Value result = arg == null ? mainFunction.execute() : mainFunction.execute(arg);
                WasmCase.validateResult(testCase.data().resultValidator(), result, capturedStream);
            } catch (PolyglotException e) {
                if (e.isExit() && testCase.data().expectedErrorMessage() == null) {
                    // JUnit convention: expected value first, actual value second.
                    Assert.assertEquals("Program exited with non-zero return code.", 0, e.getExitStatus());
                    WasmCase.validateResult(testCase.data().resultValidator(), null, capturedStream);
                } else if (testCase.data().expectedErrorTime() == WasmCaseData.ErrorType.Validation) {
                    validateThrown(testCase.data(), WasmCaseData.ErrorType.Validation, e);
                    return;
                } else {
                    validateThrown(testCase.data(), WasmCaseData.ErrorType.Runtime, e);
                }
            } catch (Throwable t) {
                // Wrap with the phase label; the wrapper's own stack trace is dropped, the cause keeps its trace.
                final RuntimeException e = new RuntimeException("Error during test phase '" + phaseLabel + "'", t);
                e.setStackTrace(new StackTraceElement[0]);
                throw e;
            } finally {
                // Compare the context state of selected iterations against the first recorded state.
                if (iterationNeedsStateCheck(i)) {
                    final ContextState contextState = saveContext(wasmContext);
                    if (firstIterationContextState == null) {
                        firstIterationContextState = contextState;
                    } else {
                        assertContextEqual(firstIterationContextState, contextState);
                    }
                }

                // Memory is cleared when the test demands it or when the next iteration will be state-checked.
                final boolean reinitMemory = requiresZeroMemory || iterationNeedsStateCheck(i + 1);
                if (reinitMemory) {
                    for (int j = 0; j < wasmContext.memories().count(); ++j) {
                        wasmContext.memories().memory(j).clear();
                    }
                }

                for (final WasmInstance instance : wasmContext.moduleInstances().values()) {
                    if (!instance.isBuiltin()) {
                        wasmContext.reinitInstance(instance, reinitMemory);
                    }
                }
            }
        }
    } finally {
        // Restore stdout even if closing the context fails, so later output is not lost in the buffer.
        try {
            context.close(true);
        } finally {
            System.setOut(oldOut);
        }
    }
}
/**
 * Decides whether the context state is checked for the iteration with the given index: always for
 * the first {@code INITIAL_STATE_CHECK_ITERATIONS} indices, afterwards only for multiples of
 * {@code STATE_CHECK_PERIODICITY}.
 *
 * @param i zero-based iteration index
 * @return {@code true} if the state should be checked for this iteration
 */
private static boolean iterationNeedsStateCheck(int i) {
    if (i < INITIAL_STATE_CHECK_ITERATIONS) {
        return true;
    }
    return i % STATE_CHECK_PERIODICITY == 0;
}
/**
 * Replaces the status shown for the running test with the given icon and label.
 *
 * <p>
 * The label is cut or space-padded to exactly {@code STATUS_LABEL_WIDTH} characters. In a poor
 * shell the status starts on a fresh line; otherwise the previous status is stepped over with
 * cursor movements first.
 *
 * @param oldOut stream the status is written to
 * @param icon icon printed in front of the label
 * @param label status text
 */
private static void resetStatus(PrintStream oldOut, String icon, String label) {
    final StringBuilder fixedWidthLabel = new StringBuilder(STATUS_LABEL_WIDTH);
    fixedWidthLabel.append(label, 0, Math.min(label.length(), STATUS_LABEL_WIDTH));
    while (fixedWidthLabel.length() < STATUS_LABEL_WIDTH) {
        fixedWidthLabel.append(' ');
    }

    if (isPoorShell()) {
        oldOut.println();
    } else {
        eraseStatus(oldOut);
    }
    oldOut.print(icon);
    oldOut.print(fixedWidthLabel.toString());
    oldOut.flush();
}
/**
 * Emits one {@code MOVE_LEFT} sequence per column of the status area (icon width plus label width).
 *
 * @param oldOut stream the cursor movements are written to
 */
private static void eraseStatus(PrintStream oldOut) {
    int columnsLeft = STATUS_ICON_WIDTH + STATUS_LABEL_WIDTH;
    while (columnsLeft > 0) {
        oldOut.print(MOVE_LEFT);
        columnsLeft--;
    }
}
/**
 * Runs one test case through four phases, each in a freshly built context derived from the same
 * builder: interpreter, synchronous compilation without inlining, synchronous compilation with
 * inlining, and background compilation.
 *
 * <p>
 * The iteration count of each phase comes from the test case options ("interpreter-iterations",
 * "sync-noinline-iterations", "sync-inline-iterations", "async-iterations") and falls back to the
 * corresponding {@code DEFAULT_*} constant.
 *
 * @param testCase test case to run
 * @return {@code WasmTestStatus.OK} if all phases complete
 * @throws AssertionError if an I/O error or an interruption is reported; the original exception is
 *             attached as the cause, and the interrupt flag is restored on interruption
 */
private WasmTestStatus runTestCase(WasmCase testCase) {
    try {
        Context.Builder contextBuilder = Context.newBuilder("wasm");
        contextBuilder.allowExperimentalOptions(true);
        contextBuilder.option("engine.EncodedGraphCacheCapacity", "-1");

        if (WasmTestOptions.LOG_LEVEL != null && !WasmTestOptions.LOG_LEVEL.equals("")) {
            contextBuilder.option("log.wasm.level", WasmTestOptions.LOG_LEVEL);
        }

        if (WasmTestOptions.STORE_CONSTANTS_POLICY != null && !WasmTestOptions.STORE_CONSTANTS_POLICY.equals("")) {
            contextBuilder.option("wasm.StoreConstantsPolicy", WasmTestOptions.STORE_CONSTANTS_POLICY);
            System.out.println("wasm.StoreConstantsPolicy: " + WasmTestOptions.STORE_CONSTANTS_POLICY);
        }

        contextBuilder.option("wasm.Builtins", includedExternalModules());
        String commandLineArgs = testCase.options().getProperty("command-line-args");
        if (commandLineArgs != null) {
            // An empty string is prepended to the space-separated arguments before they are passed on.
            contextBuilder.arguments("wasm", prepend(commandLineArgs.split(" "), ""));
        }

        Context context;
        final List<Source> sources = testCase.getSources();

        // Phase 1: interpreter only.
        int interpreterIterations = Integer.parseInt(testCase.options().getProperty("interpreter-iterations", String.valueOf(DEFAULT_INTERPRETER_ITERATIONS)));
        context = getInterpretedNoInline(contextBuilder);
        runInContext(testCase, context, sources, interpreterIterations, PHASE_INTERPRETER_ICON, "interpreter");

        // Phase 2: synchronous compilation, no inlining.
        int syncNoinlineIterations = Integer.parseInt(testCase.options().getProperty("sync-noinline-iterations", String.valueOf(DEFAULT_SYNC_NOINLINE_ITERATIONS)));
        context = getSyncCompiledNoInline(contextBuilder);
        runInContext(testCase, context, sources, syncNoinlineIterations, PHASE_SYNC_NO_INLINE_ICON, "sync,no-inl");

        // Phase 3: synchronous compilation with inlining.
        int syncInlineIterations = Integer.parseInt(testCase.options().getProperty("sync-inline-iterations", String.valueOf(DEFAULT_SYNC_INLINE_ITERATIONS)));
        context = getSyncCompiledWithInline(contextBuilder);
        runInContext(testCase, context, sources, syncInlineIterations, PHASE_SYNC_INLINE_ICON, "sync,inl");

        // Phase 4: background compilation.
        int asyncIterations = Integer.parseInt(testCase.options().getProperty("async-iterations", String.valueOf(DEFAULT_ASYNC_ITERATIONS)));
        context = getAsyncCompiled(contextBuilder);
        runInContext(testCase, context, sources, asyncIterations, PHASE_ASYNC_ICON, "async,multi");
    } catch (InterruptedException | IOException e) {
        if (e instanceof InterruptedException) {
            // Keep the interruption visible to the code further up the stack.
            Thread.currentThread().interrupt();
        }
        // Same exception type and message as Assert.fail, but with the cause preserved.
        throw new AssertionError(String.format("Test %s failed: %s", testCase.name(), e.getMessage()), e);
    }
    return WasmTestStatus.OK;
}
/**
 * Supplies the value passed to the {@code wasm.Builtins} context option. Subclasses may override it.
 *
 * @return the builtins specification, {@code "testutil:testutil"} by default
 */
protected String includedExternalModules() {
    final String defaultBuiltins = "testutil:testutil";
    return defaultBuiltins;
}
/**
 * Checks a thrown polyglot exception against the error expected by the test data.
 *
 * @param data expected error message and expected error phase
 * @param phase phase in which the exception was actually observed
 * @param e the exception that was thrown
 * @throws PolyglotException {@code e} itself, if no error message is expected or the expected
 *             message differs from the message of {@code e}
 */
private static void validateThrown(WasmCaseData data, WasmCaseData.ErrorType phase, PolyglotException e) throws PolyglotException {
    final boolean messageAsExpected = data.expectedErrorMessage() != null && data.expectedErrorMessage().equals(e.getMessage());
    if (!messageAsExpected) {
        throw e;
    }
    Assert.assertEquals("Unexpected error phase.", data.expectedErrorTime(), phase);
}
/**
 * Runs every qualifying test case of the suite, printing progress to standard output, and fails
 * at the end if any test case threw.
 *
 * <p>
 * Failures are collected per test case in encounter order; after all cases have run, each failure
 * is reported on standard error together with its stack trace.
 *
 * @throws IOException if collecting the test cases fails
 */
@Override
public void test() throws IOException {
    final Collection<? extends WasmCase> everyCase = collectTestCases();
    final Collection<? extends WasmCase> selectedCases = filterTestCases(everyCase);
    final Map<WasmCase, Throwable> failures = new LinkedHashMap<>();
    final String separator = "--------------------------------------------------------------------------------";

    // Banner.
    System.out.println();
    System.out.println(separator);
    System.out.printf("Running: %s ", suiteName());
    if (everyCase.size() == selectedCases.size()) {
        System.out.printf("(%d tests)%n", selectedCases.size());
    } else {
        System.out.printf("(%d/%d tests - you have enabled filters)%n", selectedCases.size(), everyCase.size());
    }
    System.out.println(separator);
    System.out.println("Using runtime: " + Truffle.getRuntime().toString());

    final int terminalWidth = retrieveTerminalWidth();
    // One leading blank plus the status area that follows the test name.
    final int extraWidth = 1 + STATUS_ICON_WIDTH + STATUS_LABEL_WIDTH;
    final StringBuilder statusPlaceholder = new StringBuilder();
    for (int i = 1; i < extraWidth; i++) {
        statusPlaceholder.append(" ");
    }
    // Steps back one column, blanks it, and steps back again.
    final String eraseOneColumn = MOVE_LEFT + " " + MOVE_LEFT;

    int position = 0;
    for (WasmCase testCase : selectedCases) {
        final int requiredWidth = testCase.name().length() + extraWidth;
        if (position + requiredWidth >= terminalWidth) {
            System.out.println();
            position = 0;
        }

        String statusIcon = TEST_IN_PROGRESS_ICON;
        try {
            System.out.print(" ");
            System.out.print(testCase.name());
            System.out.print(statusPlaceholder.toString());
            System.out.flush();
            runTestCase(testCase);
            statusIcon = TEST_PASSED_ICON;
        } catch (Throwable e) {
            statusIcon = TEST_FAILED_ICON;
            failures.put(testCase, e);
        } finally {
            if (!isPoorShell()) {
                // Wipe the name and status area, leaving only the final icon in their place.
                for (int i = 0; i < requiredWidth; i++) {
                    System.out.print(eraseOneColumn);
                }
                System.out.print(statusIcon);
                System.out.flush();
            } else {
                System.out.println();
                System.out.println(statusIcon);
                System.out.println();
            }
        }
        position++;
    }

    System.out.println();
    System.out.println("Finished running: " + suiteName());

    final int totalRun = selectedCases.size();
    final int passed = totalRun - failures.size();
    if (failures.isEmpty()) {
        System.out.printf("\uD83C\uDF40\u001B[32m %d/%d Wasm tests passed.\u001B[0m%n", passed, totalRun);
        System.out.println();
        return;
    }

    failures.forEach((failedCase, cause) -> {
        System.err.printf("Failure in: %s.%s%n", suiteName(), failedCase.name());
        System.err.println(cause.getClass().getSimpleName() + ": " + cause.getMessage());
        cause.printStackTrace();
    });
    System.err.printf("\uD83D\uDCA5\u001B[31m %d/%d Wasm tests passed.\u001B[0m%n", passed, totalRun);
    System.out.println();
    fail();
}
/**
 * Asks {@code stty size} (run through {@code /bin/sh} against {@code /dev/tty}) for the terminal
 * dimensions and returns the second reported number, i.e. the column count.
 *
 * @return the terminal width in columns, or {@code -1} if the command cannot be started, exits
 *         with a non-zero status, or prints nothing or something that is not "rows columns"
 * @throws RuntimeException wrapping the {@link InterruptedException} if the thread is interrupted
 *             while waiting for the command; the interrupt flag is restored before throwing
 */
private static int retrieveTerminalWidth() {
    try {
        final ProcessBuilder builder = new ProcessBuilder("/bin/sh", "-c", "stty size </dev/tty");
        final Process process = builder.start();
        // try-with-resources releases the pipe to the child process on every path.
        try (BufferedReader reader = new BufferedReader(new InputStreamReader(process.getInputStream()))) {
            final String output = reader.readLine();
            // readLine yields null when the command printed nothing at all.
            if (process.waitFor() != 0 || output == null) {
                return -1;
            }
            final String[] dimensions = output.trim().split("\\s+");
            if (dimensions.length < 2) {
                return -1;
            }
            return Integer.parseInt(dimensions[1]);
        }
    } catch (IOException | NumberFormatException e) {
        System.err.println("Could not retrieve terminal width: " + e.getMessage());
        return -1;
    } catch (InterruptedException e) {
        Thread.currentThread().interrupt();
        throw new RuntimeException(e);
    }
}
/**
 * Names the resource handed to {@code WasmCase.collectFileCases} when file-based test cases are
 * collected. Subclasses may override it.
 *
 * @return {@code null} by default
 */
protected String testResource() {
    final String noResource = null;
    return noResource;
}
/**
 * Gathers all test cases of the suite: the string-defined cases first, followed by the file-based
 * cases collected for {@link #testResource()}.
 *
 * @return a new list holding both groups, in that order
 * @throws IOException if collecting the file-based cases fails
 */
protected Collection<? extends WasmCase> collectTestCases() throws IOException {
    final List<WasmCase> combined = new ArrayList<>(collectStringTestCases());
    combined.addAll(WasmCase.collectFileCases("test", testResource()));
    return combined;
}
/**
 * Supplies test cases that are not read from files. Subclasses may override it.
 *
 * @return a new, empty list by default
 */
protected Collection<? extends WasmCase> collectStringTestCases() {
    final List<WasmCase> noCases = new ArrayList<>();
    return noCases;
}
/**
 * Keeps the test cases whose name is accepted by {@code filterTestName()}.
 *
 * @param testCases candidates, visited in iteration order
 * @return a new list with the accepted test cases, in the order they were visited
 */
protected Collection<? extends WasmCase> filterTestCases(Collection<? extends WasmCase> testCases) {
    final List<WasmCase> accepted = new ArrayList<>();
    for (final WasmCase candidate : testCases) {
        if (filterTestName().test(candidate.name())) {
            accepted.add(candidate);
        }
    }
    return accepted;
}
/**
 * Names the suite in the console output.
 *
 * @return the simple name of the runtime class of this suite
 */
protected String suiteName() {
    final Class<?> suiteClass = getClass();
    return suiteClass.getSimpleName();
}
/**
 * Takes a snapshot of the context: a duplicate of its single memory (if it has one) and a
 * duplicate of its globals.
 *
 * @param context context to snapshot; asserted to have at most one memory
 * @return the snapshot; its memory is {@code null} when the context has no memory
 */
private static ContextState saveContext(WasmContext context) {
    Assert.assertTrue("Currently, only 0 or 1 memories can be saved.", context.memories().count() <= 1);
    final WasmMemory memorySnapshot;
    if (context.memories().count() == 1) {
        memorySnapshot = context.memories().memory(0).duplicate();
    } else {
        memorySnapshot = null;
    }
    final GlobalRegistry globalsSnapshot = context.globals().duplicate();
    return new ContextState(memorySnapshot, globalsSnapshot);
}
/**
 * Asserts that two context snapshots are equal: both memories are absent, or they have the same
 * byte size and the same content byte by byte; and both global registries have the same count and
 * the same value at every address.
 *
 * @param expectedState reference snapshot
 * @param actualState snapshot compared against the reference
 */
private static void assertContextEqual(ContextState expectedState, ContextState actualState) {
    // Memory comparison.
    final WasmMemory referenceMemory = expectedState.memory();
    final WasmMemory observedMemory = actualState.memory();
    if (referenceMemory != null) {
        Assert.assertNotNull("Memory should not be null", observedMemory);
        Assert.assertEquals("Mismatch in memory lengths", referenceMemory.byteSize(), observedMemory.byteSize());
        for (int ptr = 0; ptr < referenceMemory.byteSize(); ptr++) {
            final byte referenceByte = (byte) referenceMemory.load_i32_8s(null, ptr);
            final byte observedByte = (byte) observedMemory.load_i32_8s(null, ptr);
            Assert.assertEquals("Memory mismatch", referenceByte, observedByte);
        }
    } else {
        Assert.assertNull("Memory should be null", observedMemory);
    }

    // Globals comparison.
    final GlobalRegistry referenceGlobals = expectedState.globals();
    final GlobalRegistry observedGlobals = actualState.globals();
    Assert.assertEquals("Mismatch in global counts.", referenceGlobals.count(), observedGlobals.count());
    for (int address = 0; address < referenceGlobals.count(); address++) {
        final long referenceValue = referenceGlobals.loadAsLong(address);
        final long observedValue = observedGlobals.loadAsLong(address);
        Assert.assertEquals("Mismatch in global at " + address + ". ", referenceValue, observedValue);
    }
}
/**
 * Immutable pair of a memory snapshot and a globals snapshot. The class only stores the two
 * references it is given; it makes no copies itself.
 */
private static final class ContextState {
    private final WasmMemory memory;
    private final GlobalRegistry globals;

    /**
     * @param memorySnapshot memory reference to store; {@code null} is stored as is
     * @param globalsSnapshot globals reference to store
     */
    private ContextState(WasmMemory memorySnapshot, GlobalRegistry globalsSnapshot) {
        memory = memorySnapshot;
        globals = globalsSnapshot;
    }

    /** @return the stored globals reference */
    public GlobalRegistry globals() {
        return this.globals;
    }

    /** @return the stored memory reference, possibly {@code null} */
    public WasmMemory memory() {
        return this.memory;
    }
}
}