package com.oracle.svm.core.graal.llvm;
import static com.oracle.svm.core.util.VMError.shouldNotReachHere;
import static com.oracle.svm.hosted.image.NativeBootImage.RWDATA_CGLOBALS_PARTITION_OFFSET;
import java.io.FileOutputStream;
import java.io.FileWriter;
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ForkJoinPool;
import java.util.function.Function;
import java.util.function.IntFunction;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import org.graalvm.compiler.code.CompilationResult;
import org.graalvm.compiler.core.common.NumUtil;
import org.graalvm.compiler.debug.DebugContext;
import org.graalvm.compiler.debug.GraalError;
import org.graalvm.compiler.debug.Indent;
import org.graalvm.nativeimage.Platform;
import org.graalvm.nativeimage.Platforms;
import org.graalvm.word.WordFactory;
import com.oracle.graal.pointsto.BigBang;
import com.oracle.graal.pointsto.util.CompletionExecutor;
import com.oracle.graal.pointsto.util.CompletionExecutor.DebugContextRunnable;
import com.oracle.graal.pointsto.util.Timer;
import com.oracle.graal.pointsto.util.Timer.StopTimer;
import com.oracle.objectfile.ObjectFile;
import com.oracle.objectfile.ObjectFile.Element;
import com.oracle.objectfile.SectionName;
import com.oracle.svm.core.SubstrateOptions;
import com.oracle.svm.core.SubstrateUtil;
import com.oracle.svm.core.graal.code.CGlobalDataReference;
import com.oracle.svm.core.graal.llvm.util.LLVMObjectFileReader;
import com.oracle.svm.core.graal.llvm.util.LLVMObjectFileReader.LLVMTextSectionInfo;
import com.oracle.svm.core.graal.llvm.util.LLVMOptions;
import com.oracle.svm.core.graal.llvm.util.LLVMStackMapInfo;
import com.oracle.svm.core.graal.llvm.util.LLVMTargetSpecific;
import com.oracle.svm.core.graal.llvm.util.LLVMToolchain;
import com.oracle.svm.core.graal.llvm.util.LLVMToolchain.RunFailureException;
import com.oracle.svm.core.heap.SubstrateReferenceMap;
import com.oracle.svm.core.jdk.UninterruptibleUtils.AtomicInteger;
import com.oracle.svm.hosted.image.NativeBootImage.NativeTextSectionImpl;
import com.oracle.svm.hosted.image.NativeImageCodeCache;
import com.oracle.svm.hosted.image.NativeImageHeap;
import com.oracle.svm.hosted.image.RelocatableBuffer;
import com.oracle.svm.hosted.meta.HostedMethod;
import com.oracle.svm.hosted.meta.MethodPointer;
import jdk.vm.ci.code.site.Call;
import jdk.vm.ci.code.site.DataPatch;
import jdk.vm.ci.code.site.DataSectionReference;
@Platforms(Platform.HOSTED_ONLY.class)
/**
 * Code cache for the LLVM backend. Each method's compilation result carries bitcode that is written
 * to a temporary directory, grouped into batches ({@code llvm-link}), optimized ({@code opt}),
 * compiled ({@code llc}) and finally combined with a relocatable native link ({@code ld -r}). Code
 * offsets and stack maps are then read back from the linked object file, which is passed to the
 * final C compiler invocation via {@link #getCCInputFiles}.
 */
public class LLVMNativeImageCodeCache extends NativeImageCodeCache {
    /** Maps the dense id used in per-method file names ({@code f<id>.bc}) to its method. */
    private HostedMethod[] methodIndex;
    /** Working directory for all temporary files: {@code <tempDir>/llvm}. */
    private final Path basePath;
    /** Number of methods per batch; computed in {@link #createBitcodeBatches}. */
    private int batchSize;
    private final LLVMObjectFileReader objectFileReader;
    /** Global method symbols created by the text section; returned by {@link #getSymbols}. */
    private final List<ObjectFile.Symbol> globalSymbols = new ArrayList<>();
    private final StackMapDumper stackMapDumper;

    LLVMNativeImageCodeCache(Map<HostedMethod, CompilationResult> compilations, NativeImageHeap imageHeap, Platform targetPlatform, Path tempDir) {
        super(compilations, imageHeap, targetPlatform);
        try {
            basePath = tempDir.resolve("llvm");
            Files.createDirectory(basePath);
        } catch (IOException e) {
            throw new GraalError(e);
        }
        this.stackMapDumper = getStackMapDumper(LLVMOptions.DumpLLVMStackMap.hasBeenSet());
        this.objectFileReader = new LLVMObjectFileReader(stackMapDumper);
    }

    /**
     * Always 0: the code is not written into the code cache buffer (see {@link #writeCode}) but
     * supplied as a separate linked object file (see {@link #getCCInputFiles}).
     */
    @Override
    public int getCodeCacheSize() {
        return 0;
    }

    /**
     * Runs the four LLVM phases in order, each under its own timer: write per-method bitcode,
     * link it into batches, optimize and compile each batch, then link all batches and record
     * the resulting code offsets.
     */
    @Override
    @SuppressWarnings({"unused", "try"})
    public void layoutMethods(DebugContext debug, String imageName, BigBang bb, ForkJoinPool threadPool) {
        try (Indent indent = debug.logAndIndent("layout methods")) {
            BatchExecutor executor = new BatchExecutor(bb, threadPool);
            try (StopTimer t = new Timer(imageName, "(bitcode)").start()) {
                writeBitcode(executor);
            }
            int numBatches;
            try (StopTimer t = new Timer(imageName, "(prelink)").start()) {
                numBatches = createBitcodeBatches(executor, debug);
            }
            try (StopTimer t = new Timer(imageName, "(llvm)").start()) {
                compileBitcodeBatches(executor, debug, numBatches);
            }
            try (StopTimer t = new Timer(imageName, "(postlink)").start()) {
                linkCompiledBatches(executor, debug, numBatches);
            }
        }
    }

    /**
     * Assigns every compiled method a dense id and writes its target code to {@code f<id>.bc}, in
     * parallel. Ids follow task execution order, so they are only meaningful within one build.
     */
    private void writeBitcode(BatchExecutor executor) {
        methodIndex = new HostedMethod[compilations.size()];
        AtomicInteger num = new AtomicInteger(-1);
        executor.forEach(compilations.entrySet(), entry -> (debugContext) -> {
            int id = num.incrementAndGet();
            methodIndex[id] = entry.getKey();
            try (FileOutputStream fos = new FileOutputStream(getBitcodePath(id).toString())) {
                fos.write(entry.getValue().getTargetCode());
            } catch (IOException e) {
                throw new GraalError(e);
            }
        });
    }

    /**
     * Chooses the batch size and, when batches hold more than one method, links the per-method
     * bitcode files of each batch into {@code b<id>.bc}.
     *
     * @return the number of batches
     */
    private int createBitcodeBatches(BatchExecutor executor, DebugContext debug) {
        int numMethods = methodIndex.length;
        batchSize = LLVMOptions.LLVMMaxFunctionsPerBatch.getValue();
        /* Never make batches larger than needed to give every worker thread one batch. */
        int idealSize = NumUtil.divideAndRoundUp(numMethods, executor.getParallelism());
        if (idealSize < batchSize) {
            batchSize = idealSize;
        }
        if (batchSize == 0) {
            /*
             * Option value 0 (or no methods at all): a single batch for everything. Keep the size
             * at least 1 so the division below cannot divide by zero.
             */
            batchSize = Math.max(numMethods, 1);
        }
        /*
         * ceil(numMethods / batchSize). No batch is empty because (numBatches - 1) * batchSize <
         * numMethods; only the last batch may be partially filled.
         */
        int numBatches = NumUtil.divideAndRoundUp(numMethods, batchSize);
        if (batchSize > 1) {
            // NOTE(review): these tasks log through the caller's `debug` instead of the per-task
            // `debugContext` supplied by the executor; confirm sharing a DebugContext across
            // worker threads is safe.
            executor.forEach(numBatches, batchId -> (debugContext) -> {
                List<String> batchInputs = IntStream.range(getBatchStart(batchId), getBatchEnd(batchId))
                                .mapToObj(this::getBitcodeFilename)
                                .collect(Collectors.toList());
                llvmLink(debug, getBatchBitcodeFilename(batchId), batchInputs);
            });
        }
        return numBatches;
    }

    /**
     * Optimizes and compiles every batch in parallel, then parses the batch's stack map and hands
     * it to the object file reader for each method in the batch.
     */
    private void compileBitcodeBatches(BatchExecutor executor, DebugContext debug, int numBatches) {
        stackMapDumper.startDumpingFunctions();
        executor.forEach(numBatches, batchId -> (debugContext) -> {
            llvmOptimize(debug, getBatchOptimizedFilename(batchId), getBatchBitcodeFilename(batchId));
            llvmCompile(debug, getBatchCompiledFilename(batchId), getBatchOptimizedFilename(batchId));
            LLVMStackMapInfo stackMap = objectFileReader.parseStackMap(getBatchCompiledPath(batchId));
            IntStream.range(getBatchStart(batchId), getBatchEnd(batchId))
                            .forEach(id -> objectFileReader.readStackMap(stackMap, compilations.get(methodIndex[id]), methodIndex[id], id));
        });
    }

    /**
     * Links all compiled batches into {@code llvm.o}, then records each method's code offset and
     * size (distance to the next function start) and builds the runtime metadata.
     */
    private void linkCompiledBatches(BatchExecutor executor, DebugContext debug, int numBatches) {
        List<String> compiledBatches = IntStream.range(0, numBatches).mapToObj(this::getBatchCompiledFilename).collect(Collectors.toList());
        nativeLink(debug, getLinkedFilename(), compiledBatches);
        LLVMTextSectionInfo textSectionInfo = objectFileReader.parseCode(getLinkedPath());
        executor.forEach(compilations.entrySet(), entry -> (debugContext) -> {
            HostedMethod method = entry.getKey();
            int offset = textSectionInfo.getOffset(SubstrateUtil.uniqueShortName(method));
            int nextFunctionStartOffset = textSectionInfo.getNextOffset(offset);
            int functionSize = nextFunctionStartOffset - offset;
            CompilationResult compilation = entry.getValue();
            compilation.setTargetCode(null, functionSize);
            method.setCodeAddressOffset(offset);
        });
        compilations.forEach((method, compilation) -> compilationsByStart.put(method.getCodeAddressOffset(), compilation));
        stackMapDumper.dumpOffsets(textSectionInfo);
        stackMapDumper.close();
        HostedMethod firstMethod = (HostedMethod) getFirstCompilation().getMethods()[0];
        buildRuntimeMetadata(MethodPointer.factory(firstMethod), WordFactory.signed(textSectionInfo.getCodeSize()));
    }

    /** Runs {@code opt} on {@code inputPath}, writing the result to {@code outputPath}. */
    private void llvmOptimize(DebugContext debug, String outputPath, String inputPath) {
        List<String> args = new ArrayList<>();
        if (LLVMOptions.BitcodeOptimizations.getValue()) {
            args.add("-disable-inlining");
            args.add("-O2");
        } else {
            args.add("-mem2reg");
        }
        args.add("-rewrite-statepoints-for-gc");
        args.add("-always-inline");
        args.add("-o");
        args.add(outputPath);
        args.add(inputPath);
        try {
            LLVMToolchain.runLLVMCommand("opt", basePath, args);
        } catch (RunFailureException e) {
            throw toolchainFailure(debug, e, "LLVM optimization failed for " + getFunctionName(inputPath), "opt " + String.join(" ", args));
        }
    }

    /** Runs {@code llc} on {@code inputPath}, emitting a PIC object file to {@code outputPath}. */
    private void llvmCompile(DebugContext debug, String outputPath, String inputPath) {
        List<String> args = new ArrayList<>();
        args.add("-relocation-model=pic");
        args.add("--trap-unreachable");
        args.add("-march=" + LLVMTargetSpecific.get().getLLVMArchName());
        args.addAll(LLVMTargetSpecific.get().getLLCAdditionalOptions());
        args.add("-O" + SubstrateOptions.Optimize.getValue());
        args.add("-filetype=obj");
        args.add("-o");
        args.add(outputPath);
        args.add(inputPath);
        try {
            LLVMToolchain.runLLVMCommand("llc", basePath, args);
        } catch (RunFailureException e) {
            throw toolchainFailure(debug, e, "LLVM compilation failed for " + getFunctionName(inputPath), "llc " + String.join(" ", args));
        }
    }

    /** Runs {@code llvm-link} to merge {@code inputPaths} into {@code outputPath}. */
    private void llvmLink(DebugContext debug, String outputPath, List<String> inputPaths) {
        List<String> args = new ArrayList<>();
        args.add("-o");
        args.add(outputPath);
        args.addAll(inputPaths);
        try {
            LLVMToolchain.runLLVMCommand("llvm-link", basePath, args);
        } catch (RunFailureException e) {
            throw toolchainFailure(debug, e, "LLVM linking failed into " + getFunctionName(outputPath), "llvm-link " + String.join(" ", args));
        }
    }

    /**
     * Runs a relocatable native link ({@code ld -r}, or the linker given by the CustomLD option)
     * to merge {@code inputPaths} into {@code outputPath}.
     */
    private void nativeLink(DebugContext debug, String outputPath, List<String> inputPaths) {
        List<String> cmd = new ArrayList<>();
        cmd.add(LLVMOptions.CustomLD.hasBeenSet() ? LLVMOptions.CustomLD.getValue() : "ld");
        cmd.add("-r");
        cmd.add("-o");
        cmd.add(outputPath);
        cmd.addAll(inputPaths);
        try {
            LLVMToolchain.runCommand(basePath, cmd);
        } catch (RunFailureException e) {
            throw toolchainFailure(debug, e, "Native linking failed into " + getFunctionName(outputPath), String.join(" ", cmd));
        }
    }

    /**
     * Logs the failed tool's output and builds the error to throw. The message is passed as a
     * {@code %s} argument so '%' characters in paths are never interpreted as format specifiers,
     * and the original failure is kept as the cause.
     */
    private static GraalError toolchainFailure(DebugContext debug, RunFailureException e, String description, String commandLine) {
        debug.log("%s", e.getOutput());
        return new GraalError(e, "%s", description + ": " + e.getStatus() + "\nCommand: " + commandLine);
    }

    private Path getBitcodePath(int id) {
        return basePath.resolve(getBitcodeFilename(id));
    }

    private String getBitcodeFilename(int id) {
        return "f" + id + ".bc";
    }

    /**
     * Prefix of batch file names. Single-method batches reuse the per-method {@code f} prefix, so
     * their bitcode file is the method's own file and no prelinking is needed.
     */
    private String getBatchPrefix() {
        return batchSize == 1 ? "f" : "b";
    }

    private String getBatchBitcodeFilename(int id) {
        return getBatchPrefix() + id + ".bc";
    }

    private String getBatchOptimizedFilename(int id) {
        return getBatchPrefix() + id + "o.bc";
    }

    private Path getBatchCompiledPath(int id) {
        return basePath.resolve(getBatchCompiledFilename(id));
    }

    private String getBatchCompiledFilename(int id) {
        return getBatchPrefix() + id + ".o";
    }

    private Path getLinkedPath() {
        return basePath.resolve(getLinkedFilename());
    }

    private static String getLinkedFilename() {
        return "llvm.o";
    }

    /** First method id (inclusive) of batch {@code id}. */
    private int getBatchStart(int id) {
        return id * batchSize;
    }

    /** Last method id (exclusive) of batch {@code id}. */
    private int getBatchEnd(int id) {
        return Math.min((id + 1) * batchSize, methodIndex.length);
    }

    /**
     * Describes which method or batch a temporary file name belongs to, for error messages.
     * Understands {@code f<id>…}, {@code b<id>…} (optionally with an {@code o} suffix before the
     * extension) and the final linked file.
     */
    private String getFunctionName(String fileName) {
        String function;
        if (fileName.equals(getLinkedFilename())) {
            function = "the final object file";
        } else {
            char type = fileName.charAt(0);
            String idString = fileName.substring(1, fileName.indexOf('.'));
            if (idString.endsWith("o")) {
                /* Optimized bitcode files are named "<prefix><id>o.bc". */
                idString = idString.substring(0, idString.length() - 1);
            }
            int id = Integer.parseInt(idString);
            switch (type) {
                case 'f':
                    function = methodIndex[id].getQualifiedName();
                    break;
                case 'b':
                    /* getBatchEnd is exclusive; report the last file actually in the batch. */
                    function = "batch " + id + " (f" + getBatchStart(id) + "-f" + (getBatchEnd(id) - 1) +
                                    "). Use -H:LLVMMaxFunctionsPerBatch=1 to compile each method individually.";
                    break;
                default:
                    throw shouldNotReachHere();
            }
        }
        return function + " (" + basePath.resolve(fileName).toString() + ")";
    }

    /**
     * Creates the object-file symbols that data patches in the compilations refer to via their
     * {@code note}. CGlobal data is defined in the data section at its offset plus
     * {@code RWDATA_CGLOBALS_PARTITION_OFFSET} (and symbol references become undefined symbols);
     * data-section constants are defined in the read-only data section at their offset.
     */
    @Override
    public void patchMethods(DebugContext debug, RelocatableBuffer relocs, ObjectFile objectFile) {
        Element rodataSection = objectFile.elementForName(SectionName.RODATA.getFormatDependentName(objectFile.getFormat()));
        Element dataSection = objectFile.elementForName(SectionName.DATA.getFormatDependentName(objectFile.getFormat()));
        for (CompilationResult result : getCompilations().values()) {
            for (DataPatch dataPatch : result.getDataPatches()) {
                if (dataPatch.reference instanceof CGlobalDataReference) {
                    CGlobalDataReference reference = (CGlobalDataReference) dataPatch.reference;
                    if (reference.getDataInfo().isSymbolReference()) {
                        objectFile.createUndefinedSymbol(reference.getDataInfo().getData().symbolName, 0, true);
                    }
                    int offset = reference.getDataInfo().getOffset();
                    String symbolName = (String) dataPatch.note;
                    if (reference.getDataInfo().getData().symbolName == null && objectFile.getOrCreateSymbolTable().getSymbol(symbolName) == null) {
                        objectFile.createDefinedSymbol(symbolName, dataSection, offset + RWDATA_CGLOBALS_PARTITION_OFFSET, 0, false, true);
                    }
                } else if (dataPatch.reference instanceof DataSectionReference) {
                    DataSectionReference reference = (DataSectionReference) dataPatch.reference;
                    int offset = reference.getOffset();
                    String symbolName = (String) dataPatch.note;
                    if (objectFile.getOrCreateSymbolTable().getSymbol(symbolName) == null) {
                        objectFile.createDefinedSymbol(symbolName, rodataSection, offset, 0, false, true);
                    }
                }
            }
        }
    }

    /**
     * Returns a text section implementation that only declares method symbols as undefined
     * (presumably they are defined by the linked LLVM object file); global ones are remembered
     * for {@link #getSymbols}.
     */
    @Override
    public NativeTextSectionImpl getTextSectionImpl(RelocatableBuffer buffer, ObjectFile objectFile, NativeImageCodeCache codeCache) {
        return new NativeTextSectionImpl(buffer, objectFile, codeCache) {
            @Override
            protected void defineMethodSymbol(String name, boolean global, Element section, HostedMethod method, CompilationResult result) {
                ObjectFile.Symbol symbol = objectFile.createUndefinedSymbol(name, 0, true);
                if (global) {
                    globalSymbols.add(symbol);
                }
            }
        };
    }

    /** Intentionally empty: no code is written into the buffer for the LLVM backend. */
    @Override
    public void writeCode(RelocatableBuffer buffer) {
    }

    /** Returns the default input files plus the linked LLVM object file ({@code llvm.o}). */
    @Override
    public Path[] getCCInputFiles(Path tempDirectory, String imageName) {
        Path[] nativeImageFiles = super.getCCInputFiles(tempDirectory, imageName);
        Path[] allInputFiles = Arrays.copyOf(nativeImageFiles, nativeImageFiles.length + 1);
        Path linkedObjectFile = getLinkedPath();
        allInputFiles[nativeImageFiles.length] = linkedObjectFile;
        return allInputFiles;
    }

    /** Returns the global method symbols regardless of {@code onlyGlobal}. */
    @Override
    public List<ObjectFile.Symbol> getSymbols(ObjectFile objectFile, boolean onlyGlobal) {
        return globalSymbols;
    }

    /**
     * Thin wrapper around {@link CompletionExecutor} that runs one batch of tasks to completion
     * and re-initializes the executor so it can be reused for the next phase.
     */
    private static final class BatchExecutor {
        private final CompletionExecutor executor;

        private BatchExecutor(BigBang bb, ForkJoinPool threadPool) {
            this.executor = new CompletionExecutor(bb, threadPool, bb.getHeartbeatCallback());
            executor.init();
        }

        /** Parallelism of the underlying executor service. */
        private int getParallelism() {
            return executor.getExecutorService().getParallelism();
        }

        /** Runs {@code callback.apply(i)} for every {@code i} in {@code [0, num)} and waits. */
        private void forEach(int num, IntFunction<DebugContextRunnable> callback) {
            List<Integer> ids = IntStream.range(0, num).boxed().collect(Collectors.toList());
            forEach(ids, callback::apply);
        }

        /** Runs the task produced for every element of {@code items} and waits for all of them. */
        private <T> void forEach(Iterable<T> items, Function<T, DebugContextRunnable> callback) {
            try {
                executor.start();
                for (T item : items) {
                    executor.execute(callback.apply(item));
                }
                executor.complete();
                executor.init();
            } catch (InterruptedException e) {
                /* Restore the interrupt status before converting to an unchecked error. */
                Thread.currentThread().interrupt();
                throw new GraalError(e);
            }
        }
    }

    private StackMapDumper getStackMapDumper(boolean enable) {
        return enable ? new EnabledStackMapDumper() : new DisabledStackMapDumper();
    }

    /** Receives stack map and code offset information for the optional DumpLLVMStackMap output. */
    public interface StackMapDumper {
        void dumpOffsets(LLVMTextSectionInfo textSectionInfo);

        void startDumpingFunctions();

        void startDumpingFunction(String methodSymbolName, int id, int totalFrameSize);

        void dumpCallSite(Call call, int actualPcOffset, SubstrateReferenceMap referenceMap);

        void endDumpingFunction();

        void close();
    }

    /** Writes the dump to the file named by the DumpLLVMStackMap option. */
    private class EnabledStackMapDumper implements StackMapDumper {
        private final FileWriter stackMapDump;
        /** Per-thread buffer so each function's dump is emitted with a single write. */
        private final ThreadLocal<StringBuilder> functionDump = new ThreadLocal<>();

        EnabledStackMapDumper() {
            try {
                stackMapDump = new FileWriter(LLVMOptions.DumpLLVMStackMap.getValue());
            } catch (IOException e) {
                throw new GraalError(e);
            }
        }

        /** Dumps start offset, symbol and size of every function in the linked text section. */
        @Override
        public void dumpOffsets(LLVMTextSectionInfo textSectionInfo) {
            dump("\nOffsets\n=======\n");
            textSectionInfo.forEachOffsetRange((startOffset, endOffset) -> {
                CompilationResult compilationResult = compilationsByStart.get(startOffset);
                assert startOffset + compilationResult.getTargetCodeSize() == endOffset : compilationResult.getName();
                String methodName = textSectionInfo.getSymbol(startOffset);
                dump("[" + startOffset + "] " + methodName + " (" + compilationResult.getTargetCodeSize() + ")\n");
            });
        }

        @Override
        public void startDumpingFunctions() {
            dump("Patchpoints\n===========\n");
        }

        @Override
        public void startDumpingFunction(String methodSymbolName, int id, int totalFrameSize) {
            StringBuilder builder = new StringBuilder();
            builder.append(methodSymbolName);
            builder.append(" -> f");
            builder.append(id);
            builder.append(" (");
            builder.append(totalFrameSize);
            builder.append(")\n");
            functionDump.set(builder);
        }

        @Override
        public void dumpCallSite(Call call, int actualPcOffset, SubstrateReferenceMap referenceMap) {
            StringBuilder builder = functionDump.get();
            builder.append(" [");
            builder.append(actualPcOffset);
            builder.append("] -> ");
            builder.append(call.target != null ? ((HostedMethod) call.target).format("%H.%n") : "???");
            builder.append(" (");
            builder.append(call.pcOffset);
            builder.append(") ");
            referenceMap.dump(builder);
            builder.append("\n");
        }

        @Override
        public void endDumpingFunction() {
            dump(functionDump.get().toString());
        }

        @Override
        public void close() {
            try {
                stackMapDump.close();
            } catch (IOException e) {
                throw new GraalError(e);
            }
        }

        private void dump(String str) {
            try {
                stackMapDump.write(str);
            } catch (IOException e) {
                throw new GraalError(e);
            }
        }
    }

    /** No-op dumper used when DumpLLVMStackMap is not set. */
    private static class DisabledStackMapDumper implements StackMapDumper {
        @Override
        public void dumpOffsets(LLVMTextSectionInfo textSectionInfo) {
        }

        @Override
        public void startDumpingFunctions() {
        }

        @Override
        public void startDumpingFunction(String methodSymbolName, int id, int totalFrameSize) {
        }

        @Override
        public void dumpCallSite(Call call, int actualPcOffset, SubstrateReferenceMap referenceMap) {
        }

        @Override
        public void endDumpingFunction() {
        }

        @Override
        public void close() {
        }
    }
}