package org.openjdk.jmh.runner;
import org.openjdk.jmh.infra.BenchmarkParams;
import org.openjdk.jmh.infra.Control;
import org.openjdk.jmh.infra.IterationParams;
import org.openjdk.jmh.infra.ThreadParams;
import org.openjdk.jmh.profile.InternalProfiler;
import org.openjdk.jmh.profile.ProfilerFactory;
import org.openjdk.jmh.results.*;
import org.openjdk.jmh.runner.format.OutputFormat;
import org.openjdk.jmh.runner.options.Options;
import org.openjdk.jmh.runner.options.TimeValue;
import org.openjdk.jmh.util.ClassUtils;
import org.openjdk.jmh.util.Utils;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.util.*;
import java.util.concurrent.*;
/**
 * Runs iterations of a single generated benchmark.
 * <p>
 * The handler resolves the generated benchmark method and owns the worker {@link ExecutorService}.
 * It hands each worker thread its own benchmark instance plus a {@link ThreadParams} slot. Each
 * iteration is bracketed with the configured internal profilers.
 */
class BenchmarkHandler {

    /** Worker pool running {@link BenchmarkTask}s; its flavor comes from the {@code jmh.executor} property. */
    private final ExecutorService executor;

    /** Per-worker-thread benchmark instance and thread params, created lazily on first access. */
    private final ThreadLocal<ThreadData> threadData;

    /** Sink for diagnostic messages. */
    private final OutputFormat out;

    /** Internal profilers in registration order, notified before each iteration. */
    private final List<InternalProfiler> profilers;

    /** The same profilers in reverse order, notified after each iteration. */
    private final List<InternalProfiler> profilersRev;

    /** Generated benchmark method with signature {@code BenchmarkTaskResult m(InfraControl, ThreadParams)}. */
    private final Method method;

    /**
     * Creates a handler for the benchmark described by {@code executionParams}.
     *
     * @param out             sink for diagnostic messages
     * @param options         run options; only the profiler list is read here
     * @param executionParams benchmark parameters; {@code generatedBenchmark()} is split on its last
     *                        dot into a class name and a method name
     * @throws IllegalStateException if the worker executor cannot be created
     */
    public BenchmarkHandler(OutputFormat out, Options options, BenchmarkParams executionParams) {
        String target = executionParams.generatedBenchmark();
        int lastDot = target.lastIndexOf('.');
        final Class<?> clazz = ClassUtils.loadClass(target.substring(0, lastDot));
        this.method = BenchmarkHandler.findBenchmarkMethod(clazz, target.substring(lastDot + 1));
        this.profilers = ProfilerFactory.getSupportedInternal(options.getProfilers());
        this.profilersRev = new ArrayList<>(profilers);
        Collections.reverse(profilersRev);

        // Exactly one ThreadParams per benchmark thread; each worker thread claims one on first use.
        final BlockingQueue<ThreadParams> tps = new ArrayBlockingQueue<>(executionParams.getThreads());
        tps.addAll(distributeThreads(executionParams.getThreads(), executionParams.getThreadGroups()));

        this.threadData = new ThreadLocal<ThreadData>() {
            @Override
            protected ThreadData initialValue() {
                try {
                    Object o = clazz.getConstructor().newInstance();
                    ThreadParams t = tps.poll();
                    if (t == null) {
                        // More distinct worker threads have asked for params than benchmark threads exist.
                        throw new IllegalStateException("Cannot get another thread params");
                    }
                    return new ThreadData(o, t);
                } catch (InstantiationException | IllegalAccessException | NoSuchMethodException | InvocationTargetException e) {
                    throw new RuntimeException("Class " + clazz.getName() + " instantiation error ", e);
                }
            }
        };

        this.out = out;

        try {
            this.executor = EXECUTOR_TYPE.createExecutor(executionParams.getThreads(), executionParams.getBenchmark());
        } catch (Exception e) {
            throw new IllegalStateException(e);
        }
    }

    /**
     * Assigns group and subgroup coordinates to {@code threads} benchmark threads.
     * <p>
     * {@code groups[i]} is the number of threads in subgroup {@code i} of one group. Threads fill
     * subgroups in order, then wrap into the next group. For example, {@code threads = 4} with
     * {@code groups = {1, 1}} gives two groups of two single-thread subgroups.
     *
     * @param threads total number of benchmark threads
     * @param groups  per-subgroup thread counts of a single group
     * @return one {@link ThreadParams} per thread, in thread-index order
     * @throws IllegalArgumentException if {@code threads > 0} and a group count is negative or no
     *                                  group count is positive (the distribution loop would never
     *                                  terminate otherwise)
     */
    static List<ThreadParams> distributeThreads(int threads, int[] groups) {
        if (threads > 0) {
            boolean hasPositiveGroup = false;
            for (int i = 0; i < groups.length; i++) {
                if (groups[i] < 0) {
                    throw new IllegalArgumentException("Thread group #" + i + " has negative thread count: " + groups[i]);
                }
                if (groups[i] > 0) {
                    hasPositiveGroup = true;
                }
            }
            if (!hasPositiveGroup) {
                throw new IllegalArgumentException("Thread groups must contain at least one thread: " + Arrays.toString(groups));
            }
        }

        List<ThreadParams> result = new ArrayList<>();

        int totalGroupThreads = Utils.sum(groups);
        int totalGroups = (int) Math.ceil(1D * threads / totalGroupThreads);
        int totalSubgroups = groups.length;

        int currentGroupThread = 0;
        int currentSubgroupThread = 0;
        int currentGroup = 0;
        int currentSubgroup = 0;

        for (int t = 0; t < threads; t++) {
            // Advance past full (or zero-sized) subgroups, wrapping into the next group as needed.
            while (currentSubgroupThread >= groups[currentSubgroup]) {
                currentSubgroup++;
                if (currentSubgroup == groups.length) {
                    currentGroup++;
                    currentSubgroup = 0;
                    currentGroupThread = 0;
                }
                currentSubgroupThread = 0;
            }

            result.add(new ThreadParams(
                    t, threads,
                    currentGroup, totalGroups,
                    currentSubgroup, totalSubgroups,
                    currentGroupThread, totalGroupThreads,
                    currentSubgroupThread, groups[currentSubgroup]
                    )
            );

            currentGroupThread++;
            currentSubgroupThread++;
        }
        return result;
    }

    /**
     * Finds the single method named {@code methodName} in {@code clazz}.
     *
     * @param clazz      generated benchmark class to search
     * @param methodName method name to look for
     * @return the unique matching method
     * @throws IllegalArgumentException if a same-named method has the wrong signature, several valid
     *                                  methods match, or none is found
     */
    public static Method findBenchmarkMethod(Class<?> clazz, String methodName) {
        Method method = null;
        for (Method m : ClassUtils.enumerateMethods(clazz)) {
            if (m.getName().equals(methodName)) {
                if (isValidBenchmarkSignature(m)) {
                    if (method != null) {
                        throw new IllegalArgumentException("Ambiguous methods: \n" + method + "\n and \n" + m + "\n, which one to execute?");
                    }
                    method = m;
                } else {
                    throw new IllegalArgumentException("Benchmark parameters do not match the signature contract.");
                }
            }
        }
        if (method == null) {
            throw new IllegalArgumentException("No matching methods found in benchmark");
        }
        return method;
    }

    /**
     * Checks whether {@code m} has the signature
     * {@code BenchmarkTaskResult m(InfraControl, ThreadParams)}.
     */
    private static boolean isValidBenchmarkSignature(Method m) {
        if (m.getReturnType() != BenchmarkTaskResult.class) {
            return false;
        }
        final Class<?>[] parameterTypes = m.getParameterTypes();
        if (parameterTypes.length != 2) {
            return false;
        }
        if (parameterTypes[0] != InfraControl.class) {
            return false;
        }
        if (parameterTypes[1] != ThreadParams.class) {
            return false;
        }
        return true;
    }

    /** Executor flavor selected by the {@code jmh.executor} system property; defaults to {@code FIXED_TPE}. */
    private static final ExecutorType EXECUTOR_TYPE = Enum.valueOf(ExecutorType.class, System.getProperty("jmh.executor", ExecutorType.FIXED_TPE.name()));

    /** Strategies for creating the worker executor. */
    private enum ExecutorType {

        /** Unbounded cached thread pool. */
        CACHED_TPE {
            @Override
            ExecutorService createExecutor(int maxThreads, String prefix) {
                return Executors.newCachedThreadPool(new WorkerThreadFactory(prefix));
            }
        },

        /** Fixed pool with {@code maxThreads} threads. */
        FIXED_TPE {
            @Override
            ExecutorService createExecutor(int maxThreads, String prefix) {
                return Executors.newFixedThreadPool(maxThreads, new WorkerThreadFactory(prefix));
            }
        },

        /** Dedicated ForkJoinPool with parallelism {@code maxThreads}; {@code prefix} is unused. */
        FJP {
            @Override
            ExecutorService createExecutor(int maxThreads, String prefix) {
                return new ForkJoinPool(maxThreads);
            }
        },

        /** Shared common ForkJoinPool, looked up reflectively; it must never be shut down by us. */
        FJP_COMMON {
            @Override
            ExecutorService createExecutor(int maxThreads, String prefix) throws Exception {
                Method m = Class.forName("java.util.concurrent.ForkJoinPool").getMethod("commonPool");
                return (ExecutorService) m.invoke(null);
            }

            @Override
            boolean shutdownForbidden() {
                return true;
            }
        },

        /**
         * User-supplied executor class named by {@code jmh.executor.class}. The class must have a
         * public {@code (int, String)} constructor.
         */
        CUSTOM {
            @Override
            ExecutorService createExecutor(int maxThreads, String prefix) throws Exception {
                String className = System.getProperty("jmh.executor.class");
                if (className == null) {
                    // Fail with a clear message instead of an opaque NPE from Class.forName(null).
                    throw new IllegalStateException("-Djmh.executor.class must be set when -Djmh.executor=CUSTOM");
                }
                return (ExecutorService) Class.forName(className).getConstructor(int.class, String.class)
                        .newInstance(maxThreads, prefix);
            }
        },

        ;

        /**
         * Creates the executor.
         *
         * @param maxThreads number of benchmark threads
         * @param prefix     benchmark name, used for worker thread naming where supported
         */
        abstract ExecutorService createExecutor(int maxThreads, String prefix) throws Exception;

        /** Whether {@link BenchmarkHandler#shutdown()} must leave this executor running. */
        boolean shutdownForbidden() {
            return false;
        }
    }

    /**
     * Notifies profilers, in registration order, that an iteration is about to start.
     *
     * @throws BenchmarkException wrapping the first profiler failure
     */
    protected void startProfilers(BenchmarkParams benchmarkParams, IterationParams iterationParams) {
        for (InternalProfiler prof : profilers) {
            try {
                prof.beforeIteration(benchmarkParams, iterationParams);
            } catch (Throwable ex) {
                throw new BenchmarkException(ex);
            }
        }
    }

    /**
     * Notifies profilers, in reverse registration order, that an iteration has ended, and appends
     * their results to {@code iterationResults}.
     *
     * @throws BenchmarkException wrapping the first profiler failure; later profilers are not notified
     */
    protected void stopProfilers(BenchmarkParams benchmarkParams, IterationParams iterationParams, IterationResult iterationResults) {
        for (InternalProfiler prof : profilersRev) {
            try {
                iterationResults.addResults(prof.afterIteration(benchmarkParams, iterationParams, iterationResults));
            } catch (Throwable ex) {
                throw new BenchmarkException(ex);
            }
        }
    }

    /**
     * Shuts down the worker executor and waits for termination in 10-second rounds. A message is
     * printed after each unsuccessful round. Returns early if the calling thread is interrupted,
     * restoring its interrupt flag. Does nothing for executors that must not be shut down.
     */
    public void shutdown() {
        if (EXECUTOR_TYPE.shutdownForbidden() || (executor == null)) {
            return;
        }

        while (true) {
            executor.shutdown();

            try {
                if (executor.awaitTermination(10, TimeUnit.SECONDS)) {
                    return;
                }
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
                return;
            }

            out.println("Failed to stop executor service " + executor + ", trying again; check for the unaccounted running threads");
        }
    }

    /**
     * Runs one iteration on {@code benchmarkParams.getThreads()} worker tasks and collects the results.
     *
     * @param benchmarkParams benchmark-wide parameters (threads, mode, timeout)
     * @param params          iteration parameters (iteration time)
     * @param last            whether this is the last iteration; passed through to {@link InfraControl}
     * @return the aggregated iteration result, including profiler results
     * @throws BenchmarkException if any task failed (other than with {@link FailureAssistException}),
     *                            if a profiler failed, or if this thread was interrupted while
     *                            collecting results
     */
    public IterationResult runIteration(BenchmarkParams benchmarkParams, IterationParams params, boolean last) {
        int numThreads = benchmarkParams.getThreads();
        TimeValue runtime = params.getTime();

        CountDownLatch preSetupBarrier = new CountDownLatch(numThreads);
        CountDownLatch preTearDownBarrier = new CountDownLatch(numThreads);

        List<Result> iterationResults = new ArrayList<>();

        InfraControl control = new InfraControl(benchmarkParams, params,
                preSetupBarrier, preTearDownBarrier, last,
                new Control());

        BenchmarkTask[] runners = new BenchmarkTask[numThreads];
        for (int i = 0; i < runners.length; i++) {
            runners[i] = new BenchmarkTask(control);
        }

        long waitDeadline = System.nanoTime() + benchmarkParams.getTimeout().convertTo(TimeUnit.NANOSECONDS);

        startProfilers(benchmarkParams, params);

        List<Future<BenchmarkTaskResult>> completed = new ArrayList<>();
        CompletionService<BenchmarkTaskResult> srv = new ExecutorCompletionService<>(executor);
        for (BenchmarkTask runner : runners) {
            srv.submit(runner);
        }

        control.awaitWarmupReady();

        switch (benchmarkParams.getMode()) {
            case SingleShotTime:
                // No timed wait in single-shot mode: go straight to announcing completion.
                break;
            default:
                try {
                    Future<BenchmarkTaskResult> failing = srv.poll(runtime.convertTo(TimeUnit.NANOSECONDS), TimeUnit.NANOSECONDS);
                    if (failing != null) {
                        // A task finished before the iteration time elapsed (before announceDone).
                        // Record it, and pull the deadline to "now" so the remaining tasks are
                        // interrupted promptly instead of waiting out the full timeout.
                        completed.add(failing);
                        waitDeadline = System.nanoTime();
                    }
                } catch (InterruptedException ignored) {
                    // Deliberately ignored: proceed to announceDone() and the collection loop below.
                }
        }

        control.announceDone();
        control.awaitWarmdownReady();

        while (completed.size() < numThreads) {
            try {
                // Poll at least every 100 ms, so that once the deadline has passed we keep
                // re-interrupting stuck tasks.
                long waitFor = Math.max(TimeUnit.MILLISECONDS.toNanos(100), waitDeadline - System.nanoTime());
                Future<BenchmarkTaskResult> fr = srv.poll(waitFor, TimeUnit.NANOSECONDS);
                if (fr == null) {
                    // Nothing completed in time: interrupt every task still bound to a thread.
                    out.print("(*interrupt*) ");
                    for (BenchmarkTask task : runners) {
                        Thread runner = task.runner;
                        if (runner != null) {
                            runner.interrupt();
                        }
                    }
                } else {
                    completed.add(fr);
                }
            } catch (InterruptedException ex) {
                throw new BenchmarkException(ex);
            }
        }

        long allOps = 0;
        long measuredOps = 0;

        List<Throwable> errors = new ArrayList<>();
        for (Future<BenchmarkTaskResult> fr : completed) {
            try {
                BenchmarkTaskResult btr = fr.get();
                iterationResults.addAll(btr.getResults());
                allOps += btr.getAllOps();
                measuredOps += btr.getMeasuredOps();
            } catch (ExecutionException ex) {
                Throwable cause = unwrapTaskFailure(ex);

                // FailureAssistException is the assist signal, not a real error; skip it.
                if (!(cause instanceof FailureAssistException)) {
                    errors.add(cause);
                }
            } catch (InterruptedException ex) {
                throw new BenchmarkException(ex);
            }
        }

        IterationResult result = new IterationResult(benchmarkParams, params, new IterationResultMetaData(allOps, measuredOps));
        result.addResults(iterationResults);

        // Profilers are stopped even when tasks failed, so they can release their resources.
        stopProfilers(benchmarkParams, params, result);

        if (!errors.isEmpty()) {
            throw new BenchmarkException("Benchmark error during the run", errors);
        }

        return result;
    }

    /**
     * Extracts the underlying failure from a failed {@link BenchmarkTask} future.
     * <p>
     * For a failure thrown by the benchmark body, the chain is: {@code ExecutionException} →
     * the plain {@code Exception} wrapper thrown by {@link BenchmarkTask#call()} →
     * {@code InvocationTargetException} from {@link Method#invoke} → the actual failure.
     * <p>
     * Failures raised outside the benchmark body have a shorter chain. For example, the
     * {@code IllegalStateException} thrown when thread params run out has no cause of its own.
     * Unwrapping a fixed three levels would then yield {@code null} and lose the real error.
     * This method therefore unwraps at most three levels and stops at the deepest non-null cause.
     *
     * @param ex the failure reported by {@link Future#get()}
     * @return the most specific available cause; never {@code null}
     */
    private static Throwable unwrapTaskFailure(ExecutionException ex) {
        Throwable cause = ex;
        for (int depth = 0; depth < 3; depth++) {
            Throwable next = cause.getCause();
            if (next == null) {
                break;
            }
            cause = next;
        }
        return cause;
    }

    /**
     * One worker's share of an iteration: invokes the benchmark method with this thread's
     * instance and params. Non-static because it reads {@code threadData} and {@code method}.
     */
    class BenchmarkTask implements Callable<BenchmarkTaskResult> {

        /** Thread currently executing this task, or {@code null}; read by the interrupting loop. */
        private volatile Thread runner;
        private final InfraControl control;

        BenchmarkTask(InfraControl control) {
            this.control = control;
        }

        /**
         * Runs the benchmark method on the current thread.
         * <p>
         * On any failure it marks the control as failing and forces the setup/teardown barriers.
         * When iterations are synchronized, it also announces warmup/warmdown readiness so that
         * peers and the coordinating thread are not left waiting. The original failure is then
         * rethrown, wrapped in a plain {@link Exception}.
         */
        @Override
        public BenchmarkTaskResult call() throws Exception {
            try {
                runner = Thread.currentThread();
                ThreadData td = threadData.get();
                return (BenchmarkTaskResult) method.invoke(td.instance, control, td.params);
            } catch (Throwable e) {
                control.isFailing = true;

                control.preSetupForce();
                control.preTearDownForce();

                if (control.benchmarkParams.shouldSynchIterations()) {
                    try {
                        control.announceWarmupReady();
                    } catch (Exception ignored) {
                        // Best effort: the original failure below is what gets reported.
                    }

                    try {
                        control.announceWarmdownReady();
                    } catch (Exception ignored) {
                        // Best effort: the original failure below is what gets reported.
                    }
                }

                throw new Exception(e);
            } finally {
                runner = null;
            }
        }
    }

    /** Benchmark instance and thread params bound to one worker thread. */
    private static class ThreadData {
        final Object instance;
        final ThreadParams params;

        public ThreadData(Object instance, ThreadParams params) {
            this.instance = instance;
            this.params = params;
        }
    }
}