package com.oracle.truffle.llvm.tests.util;
import java.io.ByteArrayOutputStream;
import java.io.File;
import java.io.IOException;
import java.io.InputStream;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.TimeUnit;
import java.util.function.Function;
import org.graalvm.polyglot.Context;
import org.graalvm.polyglot.Context.Builder;
import org.graalvm.polyglot.Engine;
import org.graalvm.polyglot.Value;
import com.oracle.truffle.llvm.tests.pipe.CaptureOutput;
import com.oracle.truffle.llvm.runtime.LLVMLanguage;
import com.oracle.truffle.llvm.runtime.except.LLVMLinkerException;
import com.oracle.truffle.llvm.tests.options.TestOptions;
/**
 * Helpers for running Sulong test programs and native commands and collecting their exit value
 * and captured stderr/stdout.
 */
public class ProcessUtil {

    /** Chunk size used when draining a child process's output streams. */
    private static final int BUFFER_SIZE = 1024;
    /** Maximum time, in milliseconds, to wait for a native command to exit. */
    private static final int PROCESS_WAIT_TIMEOUT = 60 * 1000;
    /** Maximum time, in milliseconds, to wait for a stream reader thread to finish. */
    private static final int JOIN_TIMEOUT = 5 * 1000;

    /**
     * Outcome of running a command: its return value plus the captured stderr and stdout.
     *
     * <p>{@link #equals(Object)} and {@link #hashCode()} compare only the return value and the
     * captured output; the original command is kept for diagnostics ({@link #toString()}, error
     * messages) and is deliberately not part of equality.
     */
    public static final class ProcessResult {
        private final String originalCommand;
        private final String stdErr;
        private final String stdOutput;
        private final int returnValue;

        private ProcessResult(String originalCommand, int returnValue, String stdErr, String stdOutput) {
            this.originalCommand = originalCommand;
            this.returnValue = returnValue;
            this.stdErr = stdErr;
            this.stdOutput = stdOutput;
        }

        /** @return the captured standard error text */
        public String getStdErr() {
            return stdErr;
        }

        /** @return the captured standard output text */
        public String getStdOutput() {
            return stdOutput;
        }

        /** @return the command's return value (exit value, or the value returned by main) */
        public int getReturnValue() {
            return returnValue;
        }

        @Override
        public String toString() {
            return new StringBuilder()
                            .append("command : ").append(originalCommand).append('\n')
                            .append("stderr: ").append(stdErr).append('\n')
                            .append("stdout: ").append(stdOutput).append('\n')
                            .append("return value: ").append(returnValue).append('\n')
                            .toString();
        }

        @Override
        public boolean equals(Object obj) {
            if (this == obj) {
                return true;
            }
            if (!(obj instanceof ProcessResult)) {
                return false;
            }
            ProcessResult other = (ProcessResult) obj;
            return this.returnValue == other.returnValue &&
                            Objects.equals(this.stdErr, other.stdErr) &&
                            Objects.equals(this.stdOutput, other.stdOutput);
        }

        @Override
        public int hashCode() {
            // Must use exactly the fields compared in equals().
            int hash = 5;
            hash = 97 * hash + Objects.hashCode(this.stdErr);
            hash = 97 * hash + Objects.hashCode(this.stdOutput);
            hash = 97 * hash + this.returnValue;
            return hash;
        }
    }

    /**
     * Thrown when a native command does not exit within the wait timeout. It extends
     * {@link AssertionError}, so it is not wrapped by the {@code catch (Exception)} handler in
     * {@link #executeNativeCommand(String)}.
     */
    public static final class TimeoutError extends AssertionError {
        private static final long serialVersionUID = 1L;

        TimeoutError(String command) {
            super("timeout running command: " + command);
        }
    }

    /**
     * Runs the main function of a bitcode file using a fresh engine that is closed afterwards.
     *
     * @param bitcodeFile the bitcode file to run
     * @param args program arguments passed to the LLVM language
     * @param options polyglot options applied to the context (or passed on the AOT command line)
     * @param captureOutput installs output capturing on the context builder
     * @return the program's return value and captured output
     * @throws IOException propagated from {@link #executeSulongTestMainSameEngine}
     */
    public static ProcessResult executeSulongTestMain(File bitcodeFile, String[] args, Map<String, String> options, Function<Context.Builder, CaptureOutput> captureOutput) throws IOException {
        // The engine is private to this call, so release it once the run has finished.
        try (Engine engine = Engine.newBuilder().allowExperimentalOptions(true).build()) {
            return executeSulongTestMainSameEngine(bitcodeFile, args, options, captureOutput, engine);
        }
    }

    /**
     * Runs the main function of a bitcode file, either in-process on the given engine or, when
     * {@code TestOptions.TEST_AOT_IMAGE} is set, by invoking that image as a native command.
     *
     * @param bitcodeFile the bitcode file to run
     * @param args program arguments
     * @param options polyglot options
     * @param captureOutput installs output capturing on the context builder (in-process mode only)
     * @param engine the engine to build the context on (unused in AOT mode); not closed here
     * @return the program's return value and captured output
     * @throws IOException e.g. when building the {@code Source} from {@code bitcodeFile} fails
     * @throws LLVMLinkerException if evaluating the bitcode does not yield an executable main
     */
    public static ProcessResult executeSulongTestMainSameEngine(File bitcodeFile, String[] args, Map<String, String> options, Function<Context.Builder, CaptureOutput> captureOutput, Engine engine)
                    throws IOException {
        if (TestOptions.TEST_AOT_IMAGE == null) {
            org.graalvm.polyglot.Source source = org.graalvm.polyglot.Source.newBuilder(LLVMLanguage.ID, bitcodeFile).build();
            Builder builder = Context.newBuilder();
            try (CaptureOutput out = captureOutput.apply(builder)) {
                int result;
                // Close the context before reading the captured output so all output is flushed.
                try (Context context = builder.engine(engine).arguments(LLVMLanguage.ID, args).options(options).allowAllAccess(true).build()) {
                    Value main = context.eval(source);
                    if (!main.canExecute()) {
                        throw new LLVMLinkerException("No main function found.");
                    }
                    result = main.execute().asInt();
                }
                return new ProcessResult(bitcodeFile.getName(), result, out.getStdErr(), out.getStdOut());
            }
        } else {
            String aotArgs = TestOptions.TEST_AOT_ARGS == null ? "" : TestOptions.TEST_AOT_ARGS + " ";
            String cmdline = TestOptions.TEST_AOT_IMAGE + " " + aotArgs + concatOptions(options) + bitcodeFile.getAbsolutePath() + " " + concatCommand(args);
            return executeNativeCommand(cmdline);
        }
    }

    /**
     * Renders options as {@code '--key=value' } command-line fragments (each followed by a space),
     * doubling any embedded single quote.
     *
     * <p>NOTE(review): the result is passed to {@link Runtime#exec(String)}, which splits on
     * whitespace and does not interpret quotes, so the quotes reach the launcher literally. Confirm
     * this is what the AOT launcher expects.
     */
    private static String concatOptions(Map<String, String> options) {
        StringBuilder str = new StringBuilder();
        for (Map.Entry<String, String> entry : options.entrySet()) {
            String encoded = entry.getKey() + '=' + entry.getValue();
            str.append("'--").append(encoded.replace("'", "''")).append("' ");
        }
        return str.toString();
    }

    /**
     * Runs a native command and requires a zero return value.
     *
     * @throws IllegalStateException if the command exits with a non-zero value
     */
    public static ProcessResult executeNativeCommandZeroReturn(String command) {
        ProcessResult result = executeNativeCommand(command);
        checkNoError(result);
        return result;
    }

    /**
     * Runs a native command given as separate words (joined with single spaces) and requires a
     * zero return value. A single element is used as-is.
     *
     * @throws IllegalStateException if the command exits with a non-zero value
     */
    public static ProcessResult executeNativeCommandZeroReturn(String... command) {
        ProcessResult result;
        if (command.length == 1) {
            // Pass the element through unchanged (a null element is rejected by executeNativeCommand).
            result = executeNativeCommand(command[0]);
        } else {
            result = executeNativeCommand(concatCommand(command));
        }
        checkNoError(result);
        return result;
    }

    /** Joins the string forms of {@code command}'s elements with single spaces. */
    static String concatCommand(Object[] command) {
        StringBuilder sb = new StringBuilder();
        for (int i = 0; i < command.length; i++) {
            if (i != 0) {
                sb.append(" ");
            }
            sb.append(command[i]);
        }
        return sb.toString();
    }

    /**
     * Executes {@code command} via {@link Runtime#exec(String)}, waiting up to
     * {@code PROCESS_WAIT_TIMEOUT} ms, and captures its exit value, stderr and stdout. The process
     * is forcibly destroyed afterwards in all cases.
     *
     * @throws IllegalArgumentException if {@code command} is null
     * @throws TimeoutError if the process does not exit in time
     * @throws RuntimeException wrapping any other failure (I/O, interruption, ...)
     */
    public static ProcessResult executeNativeCommand(String command) {
        if (command == null) {
            throw new IllegalArgumentException("command is null!");
        }
        Process process = null;
        try {
            process = Runtime.getRuntime().exec(command);
            // Drain both pipes concurrently so the child cannot block on a full pipe buffer.
            StreamReader readError = StreamReader.read(process.getErrorStream());
            StreamReader readOutput = StreamReader.read(process.getInputStream());
            if (!process.waitFor(PROCESS_WAIT_TIMEOUT, TimeUnit.MILLISECONDS)) {
                throw new TimeoutError(command);
            }
            int exitValue = process.exitValue();
            return new ProcessResult(command, exitValue, readError.getResult(), readOutput.getResult());
        } catch (InterruptedException e) {
            // Preserve the interrupt for callers further up the stack.
            Thread.currentThread().interrupt();
            throw new RuntimeException(command + " ", e);
        } catch (Exception e) {
            throw new RuntimeException(command + " ", e);
        } finally {
            if (process != null) {
                process.destroyForcibly();
            }
        }
    }

    /**
     * Fails if the result carries a non-zero return value.
     *
     * @throws IllegalStateException with the command, return value and stderr in the message
     */
    public static void checkNoError(ProcessResult processResult) {
        if (processResult.getReturnValue() != 0) {
            throw new IllegalStateException(processResult.originalCommand + " exited with value " + processResult.getReturnValue() + " " + processResult.getStdErr());
        }
    }

    /** Drains an {@link InputStream} into memory on a background thread. */
    private static final class StreamReader {
        private final Thread thread;
        private final ByteArrayOutputStream result;
        /**
         * Set by the reader thread, read by the caller. Volatile because {@link #getResult()} may
         * read it after a timed-out join, without any other happens-before edge.
         */
        private volatile IOException exception;

        /** Creates a reader for {@code inputStream} and starts its thread. */
        static StreamReader read(InputStream inputStream) {
            StreamReader reader = new StreamReader(inputStream);
            reader.thread.start();
            return reader;
        }

        StreamReader(InputStream inputStream) {
            this.result = new ByteArrayOutputStream();
            this.thread = new Thread(() -> drain(inputStream));
            // The reader may outlive JOIN_TIMEOUT if something keeps the pipe open; never let it
            // keep the JVM alive.
            this.thread.setDaemon(true);
        }

        /** Copies {@code inputStream} into {@link #result} until EOF, recording any I/O failure. */
        private void drain(InputStream inputStream) {
            final byte[] buffer = new byte[BUFFER_SIZE];
            try {
                int length;
                while ((length = inputStream.read(buffer)) != -1) {
                    result.write(buffer, 0, length);
                }
            } catch (IOException ex) {
                exception = ex;
            }
        }

        /**
         * Waits up to {@code JOIN_TIMEOUT} ms for the reader to finish and returns what was read,
         * decoded with the platform default charset. If the wait times out or is interrupted, the
         * output read so far is returned (best effort).
         *
         * @throws IOException the failure recorded by the reader thread, if any
         */
        String getResult() throws IOException {
            try {
                thread.join(JOIN_TIMEOUT);
            } catch (InterruptedException ex) {
                // Keep the best-effort result but do not lose the interrupt.
                Thread.currentThread().interrupt();
            }
            IOException failure = exception;
            if (failure != null) {
                throw failure;
            }
            return result.toString();
        }
    }
}