package org.openjdk.jmh.generators.core;
import org.openjdk.jmh.annotations.*;
import org.openjdk.jmh.infra.*;
import org.openjdk.jmh.results.*;
import org.openjdk.jmh.runner.*;
import org.openjdk.jmh.runner.Defaults;
import org.openjdk.jmh.util.HashMultimap;
import org.openjdk.jmh.util.Multimap;
import org.openjdk.jmh.util.SampleBuffer;
import java.io.*;
import java.lang.reflect.Field;
import java.util.*;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
/**
 * Generates benchmark harness classes for {@link Benchmark}-annotated methods and maintains
 * the benchmark list resource.
 *
 * <p>{@link #generate} discovers benchmark methods, validates them, groups them into
 * {@link MethodGroup}s and writes one generated class per group. {@link #complete} merges the
 * collected {@link BenchmarkInfo}s with any previously written benchmark list and writes the
 * list back out.
 */
public class BenchmarkGenerator {

    /** Suffix of the generated static measurement-loop ("stub") methods. */
    private static final String JMH_STUB_SUFFIX = "_jmhStub";

    /** Suffix appended to the name of each generated per-group class. */
    private static final String JMH_TESTCLASS_SUFFIX = "_jmhTest";

    /** Sub-package, under the user's package, that receives the generated classes. */
    protected static final String JMH_GENERATED_SUBPACKAGE = "jmh_generated";

    /** Benchmark descriptors accumulated by {@link #generate}; consumed by {@link #complete}. */
    private final Set<BenchmarkInfo> benchmarkInfos;

    /** Collects compiler-control (inlining) directives for benchmark and stub methods. */
    private final CompilerControlPlugin compilerControl;

    /** Qualified names of user classes already handled, so later calls skip them. */
    private final Set<String> processedBenchmarks;

    /** Shared session passed to {@link StateObjectHandler#writeStateOverrides}. */
    private final BenchmarkGeneratorSession session;

    public BenchmarkGenerator() {
        benchmarkInfos = new HashSet<>();
        processedBenchmarks = new HashSet<>();
        compilerControl = new CompilerControlPlugin();
        session = new BenchmarkGeneratorSession();
    }

    /**
     * Discovers benchmark classes in {@code source}, validates them and writes the generated
     * harness classes to {@code destination}.
     *
     * <p>Validation problems ({@link GenerationException}) are reported per class and do not stop
     * processing of other classes. Any other throwable aborts this call and is reported as a
     * generic error.
     *
     * @param source      where to read class/method metadata from
     * @param destination where to write generated classes and report diagnostics
     */
    public void generate(GeneratorSource source, GeneratorDestination destination) {
        try {
            Multimap<ClassInfo, MethodInfo> clazzes = buildAnnotatedSet(source);

            for (ClassInfo clazz : clazzes.keys()) {
                // A class is marked as processed before validation, so a failed class is not retried.
                if (!processedBenchmarks.add(clazz.getQualifiedName())) continue;
                try {
                    validateBenchmark(clazz, clazzes.get(clazz));
                    Collection<BenchmarkInfo> infos = makeBenchmarkInfo(clazz, clazzes.get(clazz));
                    for (BenchmarkInfo info : infos) {
                        generateClass(destination, clazz, info);
                    }
                    benchmarkInfos.addAll(infos);
                } catch (GenerationException ge) {
                    destination.printError(ge.getMessage(), ge.getElement());
                }
            }

            // Ask that the generated measurement stubs (for every mode) are never inlined.
            for (Mode mode : Mode.values()) {
                compilerControl.alwaysDontInline("*", "*_" + mode.shortLabel() + JMH_STUB_SUFFIX);
            }

            compilerControl.process(source, destination);
        } catch (Throwable t) {
            destination.printError("Annotation generator had thrown the exception.", t);
        }
    }

    /**
     * Finishes compiler-control processing and writes the benchmark list resource.
     *
     * <p>Existing entries are read back first. If this run produced benchmarks for a user class
     * that already had entries, those old entries are dropped and replaced, so methods removed
     * from that class do not linger in the list.
     *
     * @param source      class metadata source, forwarded to the compiler-control plugin
     * @param destination where the benchmark list is read from and written to
     */
    public void complete(GeneratorSource source, GeneratorDestination destination) {
        compilerControl.finish(source, destination);

        Set<BenchmarkListEntry> entries = new HashSet<>();
        Multimap<String, BenchmarkListEntry> entriesByQName = new HashMultimap<>();

        try (InputStream stream = destination.getResource(BenchmarkList.BENCHMARK_LIST.substring(1))) {
            for (BenchmarkListEntry ble : BenchmarkList.readBenchmarkList(stream)) {
                entries.add(ble);
                entriesByQName.put(ble.getUserClassQName(), ble);
            }
        } catch (IOException e) {
            // Deliberately ignored: the previous list could not be read (presumably it does not
            // exist yet), so start from an empty set of entries.
        } catch (UnsupportedOperationException e) {
            destination.printError("Unable to read the existing benchmark list.", e);
        }

        for (BenchmarkInfo info : benchmarkInfos) {
            try {
                MethodGroup group = info.methodGroup;
                for (Mode m : group.getModes()) {
                    BenchmarkListEntry br = new BenchmarkListEntry(
                            info.userClassQName,
                            info.generatedClassQName,
                            group.getName(),
                            m,
                            group.getTotalThreadCount(),
                            group.getGroupThreads(),
                            group.getGroupLabels(),
                            group.getWarmupIterations(),
                            group.getWarmupTime(),
                            group.getWarmupBatchSize(),
                            group.getMeasurementIterations(),
                            group.getMeasurementTime(),
                            group.getMeasurementBatchSize(),
                            group.getForks(),
                            group.getWarmupForks(),
                            group.getJvm(),
                            group.getJvmArgs(),
                            group.getJvmArgsPrepend(),
                            group.getJvmArgsAppend(),
                            group.getParams(),
                            group.getOutputTimeUnit(),
                            group.getOperationsPerInvocation(),
                            group.getTimeout()
                    );

                    // Replace stale entries for this user class only once (they are removed
                    // from entriesByQName right after), then keep adding fresh ones.
                    if (entriesByQName.keys().contains(info.userClassQName)) {
                        destination.printNote("Benchmark entries for " + info.userClassQName + " already exist, overwriting");
                        entries.removeAll(entriesByQName.get(info.userClassQName));
                        entriesByQName.remove(info.userClassQName);
                    }

                    entries.add(br);
                }
            } catch (GenerationException ge) {
                destination.printError(ge.getMessage(), ge.getElement());
            }
        }

        try (OutputStream stream = destination.newResource(BenchmarkList.BENCHMARK_LIST.substring(1))) {
            BenchmarkList.writeBenchmarkList(stream, entries);
        } catch (IOException ex) {
            destination.printError("Error writing benchmark list", ex);
        }
    }

    /**
     * Maps every concrete, non-generated class to its {@link Benchmark} methods, including
     * methods inherited from superclasses.
     *
     * @param source class metadata source
     * @return multimap from the concrete class to all benchmark methods visible in its hierarchy
     */
    private Multimap<ClassInfo, MethodInfo> buildAnnotatedSet(GeneratorSource source) {
        Multimap<ClassInfo, MethodInfo> result = new HashMultimap<>();
        for (ClassInfo currentClass : source.getClasses()) {
            // Skip classes we generated ourselves and abstract classes.
            if (currentClass.getQualifiedName().contains(JMH_GENERATED_SUBPACKAGE)) continue;
            if (currentClass.isAbstract()) continue;

            // Walk up the superclass chain; inherited benchmark methods belong to currentClass.
            ClassInfo walk = currentClass;
            do {
                for (MethodInfo mi : walk.getMethods()) {
                    Benchmark ann = mi.getAnnotation(Benchmark.class);
                    if (ann != null) {
                        result.put(currentClass, mi);
                    }
                }
            } while ((walk = walk.getSuperClass()) != null);
        }
        return result;
    }

    /**
     * Checks a benchmark class and its benchmark methods against the rules enforced here.
     * The checks cover: package, finality, state arguments and cycles, stray instance fields,
     * annotations, method modifiers, {@code @OperationsPerInvocation} and
     * {@code @Group}/{@code @Threads} mixing.
     *
     * @throws GenerationException on the first violated rule
     */
    private void validateBenchmark(ClassInfo clazz, Collection<MethodInfo> methods) {
        if (clazz.getPackageName().isEmpty()) {
            throw new GenerationException("Benchmark class should have package other than default.", clazz);
        }

        if (clazz.isFinal()) {
            throw new GenerationException("Benchmark classes should not be final.", clazz);
        }

        // All method arguments must be valid state objects.
        for (MethodInfo e : methods) {
            StateObjectHandler.validateStateArgs(e);
        }

        boolean explicitState = BenchmarkGeneratorUtils.getAnnSuper(clazz, State.class) != null;

        // The enclosing class itself is validated as a state only if annotated with @State.
        if (explicitState) {
            StateObjectHandler.validateState(clazz);
        }

        for (MethodInfo e : methods) {
            StateObjectHandler.validateNoCycles(e);
        }

        // Without @State, instance fields would have no defined lifecycle; only statics are allowed.
        if (!explicitState || clazz.isAbstract()) {
            for (FieldInfo fi : BenchmarkGeneratorUtils.getAllFields(clazz)) {
                if (fi.isStatic()) continue;
                throw new GenerationException(
                        "Field \"" + fi.getName() + "\" is declared within " +
                                "the class not having @" + State.class.getSimpleName() + " annotation. " +
                                "This can result in unspecified behavior, and prohibited.", fi);
            }
        }

        BenchmarkGeneratorUtils.checkAnnotations(clazz);
        for (FieldInfo fi : BenchmarkGeneratorUtils.getAllFields(clazz)) {
            BenchmarkGeneratorUtils.checkAnnotations(fi);
        }
        for (MethodInfo mi : methods) {
            BenchmarkGeneratorUtils.checkAnnotations(mi);
        }

        for (MethodInfo m : methods) {
            if (!m.isPublic()) {
                throw new GenerationException("@" + Benchmark.class.getSimpleName() +
                        " method should be public.", m);
            }

            if (m.isAbstract()) {
                throw new GenerationException("@" + Benchmark.class.getSimpleName()
                        + " method can not be abstract.", m);
            }

            if (m.isSynchronized()) {
                State annState = BenchmarkGeneratorUtils.getAnnSuper(m, State.class);
                if (annState == null) {
                    throw new GenerationException("@" + Benchmark.class.getSimpleName()
                            + " method can only be synchronized if the enclosing class is annotated with "
                            + "@" + State.class.getSimpleName() + ".", m);
                } else {
                    if (m.isStatic() && annState.value() != Scope.Benchmark) {
                        throw new GenerationException("@" + Benchmark.class.getSimpleName()
                                + " method can only be static and synchronized if the enclosing class is annotated with "
                                + "@" + State.class.getSimpleName() + "(" + Scope.class.getSimpleName() + "." + Scope.Benchmark + ").", m);
                    }
                }
            }
        }

        for (MethodInfo m : methods) {
            OperationsPerInvocation opi = BenchmarkGeneratorUtils.getAnnSuper(m, clazz, OperationsPerInvocation.class);
            if (opi != null && opi.value() < 1) {
                throw new GenerationException("The " + OperationsPerInvocation.class.getSimpleName() +
                        " needs to be greater than 0.", m);
            }
        }

        for (MethodInfo m : methods) {
            if (m.getAnnotation(Group.class) != null && m.getAnnotation(Threads.class) != null) {
                throw new GenerationException("@" + Threads.class.getSimpleName() + " annotation is placed within " +
                        "the benchmark method with @" + Group.class.getSimpleName() + " annotation. " +
                        "This has ambiguous behavioral effect, and prohibited. " +
                        "Did you mean @" + GroupThreads.class.getSimpleName() + " instead?",
                        m);
            }
        }
    }

    /**
     * Validates a single method group.
     *
     * <p>A lone method without {@code @Group} may not use {@code Scope.Group} states, either as an
     * argument or through its declaring class. A multi-method group must be made up entirely of
     * {@code @Group} methods.
     *
     * @throws GenerationException when a rule is violated
     */
    private void validateBenchmarkInfo(BenchmarkInfo info) {
        MethodGroup group = info.methodGroup;
        if (group.methods().size() == 1) {
            MethodInfo meth = group.methods().iterator().next();
            if (meth.getAnnotation(Group.class) == null) {
                for (ParameterInfo param : meth.getParameters()) {
                    State stateAnn = BenchmarkGeneratorUtils.getAnnSuper(param.getType(), State.class);
                    if (stateAnn != null && stateAnn.value() == Scope.Group) {
                        throw new GenerationException(
                                "Only @" + Group.class.getSimpleName() + " methods can reference @" + State.class.getSimpleName()
                                        + "(" + Scope.class.getSimpleName() + "." + Scope.Group + ") states.",
                                meth);
                    }
                }

                State stateAnn = BenchmarkGeneratorUtils.getAnnSuper(meth.getDeclaringClass(), State.class);
                if (stateAnn != null && stateAnn.value() == Scope.Group) {
                    throw new GenerationException(
                            "Only @" + Group.class.getSimpleName() + " methods can implicitly reference @" + State.class.getSimpleName()
                                    + "(" + Scope.class.getSimpleName() + "." + Scope.Group + ") states.",
                            meth);
                }
            }
        } else {
            for (MethodInfo m : group.methods()) {
                if (m.getAnnotation(Group.class) == null) {
                    throw new GenerationException(
                            "Internal error: multiple methods per @" + Group.class.getSimpleName()
                                    + ", but not all methods have @" + Group.class.getSimpleName(),
                            m);
                }
            }
        }
    }

    /**
     * Groups benchmark methods by {@code @Group} name (or method name when absent) and builds one
     * {@link BenchmarkInfo} per group.
     *
     * <p>Groups with no explicit {@code @BenchmarkMode} get {@link Defaults#BENCHMARK_MODE}.
     *
     * @return one validated info per group, ordered by group name (a TreeMap is used)
     * @throws GenerationException for illegal group names or invalid groups
     */
    private Collection<BenchmarkInfo> makeBenchmarkInfo(ClassInfo clazz, Collection<MethodInfo> methods) {
        Map<String, MethodGroup> result = new TreeMap<>();

        for (MethodInfo method : methods) {
            Group groupAnn = method.getAnnotation(Group.class);
            String groupName = (groupAnn != null) ? groupAnn.value() : method.getName();

            // The group name becomes part of generated identifiers.
            if (!BenchmarkGeneratorUtils.checkJavaIdentifier(groupName)) {
                throw new GenerationException("Group name should be the legal Java identifier.", method);
            }

            MethodGroup group = result.get(groupName);
            if (group == null) {
                group = new MethodGroup(clazz, groupName);
                result.put(groupName, group);
            }

            BenchmarkMode mbAn = BenchmarkGeneratorUtils.getAnnSuper(method, clazz, BenchmarkMode.class);
            if (mbAn != null) {
                group.addModes(mbAn.value());
            }

            group.addStrictFP(clazz.isStrictFP());
            group.addStrictFP(method.isStrictFP());
            group.addMethod(method, (method.getAnnotation(GroupThreads.class) != null) ? method.getAnnotation(GroupThreads.class).value() : 1);

            // Collect parameter values from argument types and from the benchmark class itself.
            for (ParameterInfo pi : method.getParameters()) {
                BenchmarkGeneratorUtils.addParameterValuesToGroup(pi.getType(), group);
            }
            BenchmarkGeneratorUtils.addParameterValuesToGroup(clazz, group);
        }

        for (MethodGroup group : result.values()) {
            if (group.getModes().isEmpty()) {
                group.addModes(Defaults.BENCHMARK_MODE);
            }
        }

        Collection<BenchmarkInfo> benchmarks = new ArrayList<>();
        for (MethodGroup group : result.values()) {
            String sourcePackage = clazz.getPackageName();
            String generatedPackageName = sourcePackage + "." + JMH_GENERATED_SUBPACKAGE;
            String generatedClassName = BenchmarkGeneratorUtils.getGeneratedName(clazz) + "_" + group.getName() + JMH_TESTCLASS_SUFFIX;

            BenchmarkInfo info = new BenchmarkInfo(clazz.getQualifiedName(), generatedPackageName, generatedClassName, group);
            validateBenchmarkInfo(info);
            benchmarks.add(info);
        }
        return benchmarks;
    }

    /**
     * Writes the generated harness class for one method group: package, imports, shared fields,
     * one entry method plus stubs per mode (except {@link Mode#All}), state initializers and fields.
     *
     * @throws IOException if the destination cannot create the class file
     */
    private void generateClass(GeneratorDestination destination, ClassInfo classInfo, BenchmarkInfo info) throws IOException {
        StateObjectHandler states = new StateObjectHandler(compilerControl);
        states.bindMethods(classInfo, info.methodGroup);

        // NOTE(review): if an exception escapes before writer.close() below, the writer is left
        // unclosed. Confirm whether that is intended (closing would publish a partial file).
        PrintWriter writer = new PrintWriter(destination.newClass(info.generatedClassQName), false);

        writer.println("package " + info.generatedPackageName + ';');
        writer.println();

        generateImport(writer);
        states.addImports(writer);

        writer.println("public final class " + info.generatedClassName + " {");
        writer.println();

        Paddings.padding(writer);

        writer.println(ident(1) + "int startRndMask;");
        writer.println(ident(1) + "BenchmarkParams benchmarkParams;");
        writer.println(ident(1) + "IterationParams iterationParams;");
        writer.println(ident(1) + "ThreadParams threadParams;");
        writer.println(ident(1) + "Blackhole blackhole;");
        writer.println(ident(1) + "Control notifyControl;");

        for (Mode benchmarkKind : Mode.values()) {
            if (benchmarkKind == Mode.All) continue;
            generateMethod(benchmarkKind, writer, info.methodGroup, states);
        }

        for (String s : states.getStateInitializers()) {
            writer.println(ident(1) + s);
        }
        writer.println();

        for (String s : states.getFields()) {
            writer.println(ident(1) + s);
        }
        writer.println();

        states.writeStateOverrides(session, destination);

        writer.println("}");
        writer.println();

        writer.close();
    }

    /** Emits the fixed import list every generated class relies on. */
    private void generateImport(PrintWriter writer) {
        Class<?>[] imports = new Class<?>[]{
                List.class, AtomicInteger.class,
                Collection.class, ArrayList.class,
                TimeUnit.class, CompilerControl.class,
                InfraControl.class, ThreadParams.class,
                BenchmarkTaskResult.class,
                Result.class, ThroughputResult.class, AverageTimeResult.class,
                SampleTimeResult.class, SingleShotResult.class, SampleBuffer.class,
                Mode.class, Fork.class, Measurement.class, Threads.class, Warmup.class,
                BenchmarkMode.class, RawResults.class, ResultRole.class,
                Field.class, BenchmarkParams.class, IterationParams.class,
                Blackhole.class, Control.class,
                ScalarResult.class, AggregationPolicy.class,
                FailureAssistException.class
        };

        for (Class<?> c : imports) {
            writer.println("import " + c.getName() + ';');
        }
        writer.println();
    }

    /**
     * Dispatches to the mode-specific generator.
     *
     * @throws AssertionError for modes without a generator (e.g. {@link Mode#All})
     */
    private void generateMethod(Mode benchmarkKind, PrintWriter writer, MethodGroup methodGroup, StateObjectHandler states) {
        writer.println();
        switch (benchmarkKind) {
            case Throughput:
                generateThroughput(writer, benchmarkKind, methodGroup, states);
                break;
            case AverageTime:
                generateAverageTime(writer, benchmarkKind, methodGroup, states);
                break;
            case SampleTime:
                generateSampleTime(writer, benchmarkKind, methodGroup, states);
                break;
            case SingleShotTime:
                generateSingleShotTime(writer, benchmarkKind, methodGroup, states);
                break;
            default:
                throw new AssertionError("Shouldn't be here");
        }
    }

    /** Generates the Throughput entry method and stubs (reported as {@code ThroughputResult}). */
    private void generateThroughput(PrintWriter writer, Mode benchmarkKind, MethodGroup methodGroup, StateObjectHandler states) {
        generateTimedLoop(writer, benchmarkKind, methodGroup, states, "ThroughputResult");
    }

    /** Generates the AverageTime entry method and stubs (reported as {@code AverageTimeResult}). */
    private void generateAverageTime(PrintWriter writer, Mode benchmarkKind, MethodGroup methodGroup, StateObjectHandler states) {
        generateTimedLoop(writer, benchmarkKind, methodGroup, states, "AverageTimeResult");
    }

    /**
     * Shared generator for the Throughput and AverageTime modes. The generated code for these two
     * modes differs only in the result class it instantiates.
     *
     * <p>Emits an entry method that selects the subgroup's method by
     * {@code threadParams.getSubgroupIndex()}. For that method it runs warmup catch-up, calls the
     * measurement stub, runs warmdown catch-up, scales the op counts by opsPerInv/batchSize and
     * returns the results. It then emits one static do/while measurement stub per method.
     *
     * @param resultClass simple name of the result class to construct, e.g. {@code "ThroughputResult"}
     */
    private void generateTimedLoop(PrintWriter writer, Mode benchmarkKind, MethodGroup methodGroup,
                                   StateObjectHandler states, String resultClass) {
        writer.println(ident(1) + "public BenchmarkTaskResult " + methodGroup.getName() + "_" + benchmarkKind +
                "(InfraControl control, ThreadParams threadParams) throws Throwable {");

        methodProlog(writer);

        boolean isSingleMethod = (methodGroup.methods().size() == 1);
        int subGroup = -1;
        for (MethodInfo method : methodGroup.methods()) {
            subGroup++;

            writer.println(ident(2) + "if (threadParams.getSubgroupIndex() == " + subGroup + ") {");
            writer.println(ident(3) + "RawResults res = new RawResults();");

            iterationProlog(writer, 3, method, states);
            emitWarmupCatchup(writer, method, states);

            writer.println(ident(3) + "notifyControl.startMeasurement = true;");
            writer.println(ident(3) + method.getName() + "_" + benchmarkKind.shortLabel() + JMH_STUB_SUFFIX +
                    "(" + getStubArgs() + prefix(states.getArgList(method)) + ");");
            writer.println(ident(3) + "notifyControl.stopMeasurement = true;");

            emitWarmdownCatchup(writer, method, states);
            iterationEpilog(writer, 3, method, states);

            // Every counted call is one @Benchmark invocation: scale up by opsPerInv and down
            // by batchSize.
            writer.println(ident(3) + "res.allOps += res.measuredOps;");
            writer.println(ident(3) + "int batchSize = iterationParams.getBatchSize();");
            writer.println(ident(3) + "int opsPerInv = benchmarkParams.getOpsPerInvocation();");
            writer.println(ident(3) + "res.allOps *= opsPerInv;");
            writer.println(ident(3) + "res.allOps /= batchSize;");
            writer.println(ident(3) + "res.measuredOps *= opsPerInv;");
            writer.println(ident(3) + "res.measuredOps /= batchSize;");

            writer.println(ident(3) + "BenchmarkTaskResult results = new BenchmarkTaskResult((long)res.allOps, (long)res.measuredOps);");
            if (isSingleMethod) {
                writer.println(ident(3) + "results.add(new " + resultClass + "(ResultRole.PRIMARY, \"" + method.getName() + "\", res.measuredOps, res.getTime(), benchmarkParams.getTimeUnit()));");
            } else {
                writer.println(ident(3) + "results.add(new " + resultClass + "(ResultRole.PRIMARY, \"" + methodGroup.getName() + "\", res.measuredOps, res.getTime(), benchmarkParams.getTimeUnit()));");
                writer.println(ident(3) + "results.add(new " + resultClass + "(ResultRole.SECONDARY, \"" + method.getName() + "\", res.measuredOps, res.getTime(), benchmarkParams.getTimeUnit()));");
            }
            addAuxCounters(writer, resultClass, states, method);

            methodEpilog(writer);

            writer.println(ident(3) + "return results;");
            // Chains into the next subgroup's "if", or into the final throw.
            writer.println(ident(2) + "} else");
        }
        writer.println(ident(3) + "throw new IllegalStateException(\"Harness failed to distribute threads among groups properly\");");
        writer.println(ident(1) + "}");
        writer.println();

        for (MethodInfo method : methodGroup.methods()) {
            String methodName = method.getName() + "_" + benchmarkKind.shortLabel() + JMH_STUB_SUFFIX;
            compilerControl.defaultForceInline(method);

            writer.println(ident(1) + "public static" + (methodGroup.isStrictFP() ? " strictfp" : "") + " void " + methodName + "(" +
                    getStubTypeArgs() + prefix(states.getTypeArgList(method)) + ") throws Throwable {");
            writer.println(ident(2) + "long operations = 0;");
            writer.println(ident(2) + "long realTime = 0;");
            writer.println(ident(2) + "result.startTime = System.nanoTime();");
            writer.println(ident(2) + "do {");

            invocationProlog(writer, 3, method, states, true);
            writer.println(ident(3) + emitCall(method, states) + ';');
            invocationEpilog(writer, 3, method, states, true);

            writer.println(ident(3) + "operations++;");
            writer.println(ident(2) + "} while(!control.isDone);");
            writer.println(ident(2) + "result.stopTime = System.nanoTime();");
            writer.println(ident(2) + "result.realTime = realTime;");
            writer.println(ident(2) + "result.measuredOps = operations;");
            writer.println(ident(1) + "}");
            writer.println();
        }
    }

    /**
     * Emits, at indent level 3, the warmup catch-up loop. The loop announces warmup readiness and
     * keeps invoking the benchmark (counting into {@code res.allOps}) while
     * {@code control.warmupShouldWait} is true. A blank line follows it.
     */
    private void emitWarmupCatchup(PrintWriter writer, MethodInfo method, StateObjectHandler states) {
        writer.println(ident(3) + "control.announceWarmupReady();");
        writer.println(ident(3) + "while (control.warmupShouldWait) {");
        invocationProlog(writer, 4, method, states, false);
        writer.println(ident(4) + emitCall(method, states) + ';');
        invocationEpilog(writer, 4, method, states, false);
        writer.println(ident(4) + "res.allOps++;");
        writer.println(ident(3) + "}");
        writer.println();
    }

    /**
     * Emits, at indent level 3, the warmdown catch-up loop. The loop announces warmdown readiness
     * and keeps invoking the benchmark while {@code control.warmdownShouldWait} is true, then calls
     * {@code control.preTearDown()}. If interrupted, it calls {@code control.preTearDownForce()}
     * instead.
     */
    private void emitWarmdownCatchup(PrintWriter writer, MethodInfo method, StateObjectHandler states) {
        writer.println(ident(3) + "control.announceWarmdownReady();");
        writer.println(ident(3) + "try {");
        writer.println(ident(4) + "while (control.warmdownShouldWait) {");
        invocationProlog(writer, 5, method, states, false);
        writer.println(ident(5) + emitCall(method, states) + ';');
        invocationEpilog(writer, 5, method, states, false);
        writer.println(ident(5) + "res.allOps++;");
        writer.println(ident(4) + "}");
        writer.println(ident(4) + "control.preTearDown();");
        writer.println(ident(3) + "} catch (InterruptedException ie) {");
        writer.println(ident(4) + "control.preTearDownForce();");
        writer.println(ident(3) + "}");
    }

    /** Emits the auxiliary result additions for {@code method}, using result type {@code resName}. */
    private void addAuxCounters(PrintWriter writer, String resName, StateObjectHandler states, MethodInfo method) {
        for (String res : states.getAuxResults(method, resName)) {
            writer.println(ident(3) + "results.add(" + res + ");");
        }
    }

    /** Argument list used when the entry method calls a measurement stub; matches {@link #getStubTypeArgs()}. */
    private String getStubArgs() {
        return "control, res, benchmarkParams, iterationParams, threadParams, blackhole, notifyControl, startRndMask";
    }

    /** Parameter declarations shared by all measurement stubs; matches {@link #getStubArgs()}. */
    private String getStubTypeArgs() {
        return "InfraControl control, RawResults result, " +
                "BenchmarkParams benchmarkParams, IterationParams iterationParams, ThreadParams threadParams, " +
                "Blackhole blackhole, Control notifyControl, int startRndMask";
    }

    /** Emits entry-method setup: caches params/controls in fields and lazily creates the Blackhole. */
    private void methodProlog(PrintWriter writer) {
        writer.println(ident(2) + "this.benchmarkParams = control.benchmarkParams;");
        writer.println(ident(2) + "this.iterationParams = control.iterationParams;");
        writer.println(ident(2) + "this.threadParams = threadParams;");
        writer.println(ident(2) + "this.notifyControl = control.notifyControl;");
        writer.println(ident(2) + "if (this.blackhole == null) {");
        writer.println(ident(3) + "this.blackhole = new Blackhole(\"Today's password is swordfish. I understand instantiating Blackholes directly is dangerous.\");");
        writer.println(ident(2) + "}");
    }

    /** Emits the per-iteration Blackhole {@code evaporate} call. */
    private void methodEpilog(PrintWriter writer) {
        writer.println(ident(3) + "this.blackhole.evaporate(\"Yes, I am Stephen Hawking, and know a thing or two about black holes.\");");
    }

    /**
     * Prepends {@code ", "} to a non-blank argument list.
     *
     * @return {@code ""} if {@code argList} is blank, otherwise {@code ", " + argList}
     */
    private String prefix(String argList) {
        if (argList.trim().isEmpty()) {
            return "";
        } else {
            return ", " + argList;
        }
    }

    /**
     * Generates the SampleTime entry method and stubs.
     *
     * <p>The stub times a randomly chosen subset of batches into a {@link SampleBuffer}. When more
     * than {@code targetSamples} samples accumulate, the buffer is halved and the sampling mask is
     * widened.
     */
    private void generateSampleTime(PrintWriter writer, Mode benchmarkKind, MethodGroup methodGroup, StateObjectHandler states) {
        writer.println(ident(1) + "public BenchmarkTaskResult " + methodGroup.getName() + "_" + benchmarkKind +
                "(InfraControl control, ThreadParams threadParams) throws Throwable {");

        methodProlog(writer);

        boolean isSingleMethod = (methodGroup.methods().size() == 1);
        int subGroup = -1;
        for (MethodInfo method : methodGroup.methods()) {
            subGroup++;

            writer.println(ident(2) + "if (threadParams.getSubgroupIndex() == " + subGroup + ") {");
            writer.println(ident(3) + "RawResults res = new RawResults();");

            iterationProlog(writer, 3, method, states);
            emitWarmupCatchup(writer, method, states);

            writer.println(ident(3) + "notifyControl.startMeasurement = true;");
            writer.println(ident(3) + "int targetSamples = (int) (control.getDuration(TimeUnit.MILLISECONDS) * 20); // at max, 20 timestamps per millisecond");
            writer.println(ident(3) + "int batchSize = iterationParams.getBatchSize();");
            writer.println(ident(3) + "int opsPerInv = benchmarkParams.getOpsPerInvocation();");
            writer.println(ident(3) + "SampleBuffer buffer = new SampleBuffer();");
            writer.println(ident(3) + method.getName() + "_" + benchmarkKind.shortLabel() + JMH_STUB_SUFFIX + "(" +
                    getStubArgs() + ", buffer, targetSamples, opsPerInv, batchSize" + prefix(states.getArgList(method)) + ");");
            writer.println(ident(3) + "notifyControl.stopMeasurement = true;");

            emitWarmdownCatchup(writer, method, states);
            iterationEpilog(writer, 3, method, states);

            // measuredOps counted whole batches here, so only allOps gets the batch adjustment.
            writer.println(ident(3) + "res.allOps += res.measuredOps * batchSize;");
            writer.println(ident(3) + "res.allOps *= opsPerInv;");
            writer.println(ident(3) + "res.allOps /= batchSize;");
            writer.println(ident(3) + "res.measuredOps *= opsPerInv;");

            writer.println(ident(3) + "BenchmarkTaskResult results = new BenchmarkTaskResult((long)res.allOps, (long)res.measuredOps);");
            if (isSingleMethod) {
                writer.println(ident(3) + "results.add(new SampleTimeResult(ResultRole.PRIMARY, \"" + method.getName() + "\", buffer, benchmarkParams.getTimeUnit()));");
            } else {
                writer.println(ident(3) + "results.add(new SampleTimeResult(ResultRole.PRIMARY, \"" + methodGroup.getName() + "\", buffer, benchmarkParams.getTimeUnit()));");
                writer.println(ident(3) + "results.add(new SampleTimeResult(ResultRole.SECONDARY, \"" + method.getName() + "\", buffer, benchmarkParams.getTimeUnit()));");
            }
            addAuxCounters(writer, "SampleTimeResult", states, method);

            methodEpilog(writer);

            writer.println(ident(3) + "return results;");
            writer.println(ident(2) + "} else");
        }
        writer.println(ident(3) + "throw new IllegalStateException(\"Harness failed to distribute threads among groups properly\");");
        writer.println(ident(1) + "}");
        writer.println();

        for (MethodInfo method : methodGroup.methods()) {
            String methodName = method.getName() + "_" + benchmarkKind.shortLabel() + JMH_STUB_SUFFIX;
            compilerControl.defaultForceInline(method);

            writer.println(ident(1) + "public static" + (methodGroup.isStrictFP() ? " strictfp" : "") + " void " + methodName + "(" +
                    getStubTypeArgs() + ", SampleBuffer buffer, int targetSamples, long opsPerInv, int batchSize" + prefix(states.getTypeArgList(method)) + ") throws Throwable {");

            writer.println(ident(2) + "long realTime = 0;");
            writer.println(ident(2) + "long operations = 0;");
            writer.println(ident(2) + "int rnd = (int)System.nanoTime();");
            writer.println(ident(2) + "int rndMask = startRndMask;");
            writer.println(ident(2) + "long time = 0;");
            writer.println(ident(2) + "int currentStride = 0;");
            writer.println(ident(2) + "do {");

            invocationProlog(writer, 3, method, states, true);

            // Linear congruential step; sample this batch only when the masked bits are zero.
            writer.println(ident(3) + "rnd = (rnd * 1664525 + 1013904223);");
            writer.println(ident(3) + "boolean sample = (rnd & rndMask) == 0;");
            writer.println(ident(3) + "if (sample) {");
            writer.println(ident(4) + "time = System.nanoTime();");
            writer.println(ident(3) + "}");

            writer.println(ident(3) + "for (int b = 0; b < batchSize; b++) {");
            writer.println(ident(4) + "if (control.volatileSpoiler) return;");
            writer.println(ident(4) + "" + emitCall(method, states) + ';');
            writer.println(ident(3) + "}");

            writer.println(ident(3) + "if (sample) {");
            writer.println(ident(4) + "buffer.add((System.nanoTime() - time) / opsPerInv);");
            writer.println(ident(4) + "if (currentStride++ > targetSamples) {");
            writer.println(ident(5) + "buffer.half();");
            writer.println(ident(5) + "currentStride = 0;");
            writer.println(ident(5) + "rndMask = (rndMask << 1) + 1;");
            writer.println(ident(4) + "}");
            writer.println(ident(3) + "}");

            invocationEpilog(writer, 3, method, states, true);

            writer.println(ident(3) + "operations++;");
            writer.println(ident(2) + "} while(!control.isDone);");
            // NOTE(review): in the generated stub startRndMask is a parameter, so this assignment
            // does not reach the caller's field; confirm this is intended.
            writer.println(ident(2) + "startRndMask = Math.max(startRndMask, rndMask);");
            writer.println(ident(2) + "result.realTime = realTime;");
            writer.println(ident(2) + "result.measuredOps = operations;");
            writer.println(ident(1) + "}");
            writer.println();
        }
    }

    /**
     * Generates the SingleShotTime entry method and stubs.
     *
     * <p>There is no warmup or warmdown catch-up. The stub runs exactly {@code batchSize}
     * invocations, and the result counts {@code opsPerInv} total ops.
     */
    private void generateSingleShotTime(PrintWriter writer, Mode benchmarkKind, MethodGroup methodGroup, StateObjectHandler states) {
        writer.println(ident(1) + "public BenchmarkTaskResult " + methodGroup.getName() + "_" + benchmarkKind + "(InfraControl control, ThreadParams threadParams) throws Throwable {");

        methodProlog(writer);

        boolean isSingleMethod = (methodGroup.methods().size() == 1);
        int subGroup = -1;
        for (MethodInfo method : methodGroup.methods()) {
            compilerControl.defaultForceInline(method);

            subGroup++;
            writer.println(ident(2) + "if (threadParams.getSubgroupIndex() == " + subGroup + ") {");

            iterationProlog(writer, 3, method, states);

            writer.println(ident(3) + "notifyControl.startMeasurement = true;");
            writer.println(ident(3) + "RawResults res = new RawResults();");
            writer.println(ident(3) + "int batchSize = iterationParams.getBatchSize();");
            writer.println(ident(3) + method.getName() + "_" + benchmarkKind.shortLabel() + JMH_STUB_SUFFIX + "(" +
                    getStubArgs() + ", batchSize" + prefix(states.getArgList(method)) + ");");
            writer.println(ident(3) + "control.preTearDown();");

            iterationEpilog(writer, 3, method, states);

            writer.println(ident(3) + "int opsPerInv = control.benchmarkParams.getOpsPerInvocation();");
            writer.println(ident(3) + "long totalOps = opsPerInv;");
            writer.println(ident(3) + "BenchmarkTaskResult results = new BenchmarkTaskResult((long)totalOps, (long)totalOps);");
            if (isSingleMethod) {
                writer.println(ident(3) + "results.add(new SingleShotResult(ResultRole.PRIMARY, \"" + method.getName() + "\", res.getTime(), benchmarkParams.getTimeUnit()));");
            } else {
                writer.println(ident(3) + "results.add(new SingleShotResult(ResultRole.PRIMARY, \"" + methodGroup.getName() + "\", res.getTime(), benchmarkParams.getTimeUnit()));");
                writer.println(ident(3) + "results.add(new SingleShotResult(ResultRole.SECONDARY, \"" + method.getName() + "\", res.getTime(), benchmarkParams.getTimeUnit()));");
            }
            addAuxCounters(writer, "SingleShotResult", states, method);

            methodEpilog(writer);

            writer.println(ident(3) + "return results;");
            writer.println(ident(2) + "} else");
        }
        writer.println(ident(3) + "throw new IllegalStateException(\"Harness failed to distribute threads among groups properly\");");
        writer.println(ident(1) + "}");
        writer.println();

        for (MethodInfo method : methodGroup.methods()) {
            String methodName = method.getName() + "_" + benchmarkKind.shortLabel() + JMH_STUB_SUFFIX;
            compilerControl.defaultForceInline(method);

            writer.println(ident(1) + "public static" + (methodGroup.isStrictFP() ? " strictfp" : "") + " void " + methodName +
                    "(" + getStubTypeArgs() + ", int batchSize" + prefix(states.getTypeArgList(method)) + ") throws Throwable {");
            writer.println(ident(2) + "long realTime = 0;");
            writer.println(ident(2) + "result.startTime = System.nanoTime();");
            writer.println(ident(2) + "for (int b = 0; b < batchSize; b++) {");
            writer.println(ident(3) + "if (control.volatileSpoiler) return;");

            invocationProlog(writer, 3, method, states, true);
            writer.println(ident(3) + emitCall(method, states) + ';');
            invocationEpilog(writer, 3, method, states, true);

            writer.println(ident(2) + "}");
            writer.println(ident(2) + "result.stopTime = System.nanoTime();");
            writer.println(ident(2) + "result.realTime = realTime;");
            writer.println(ident(1) + "}");
            writer.println();
        }
    }

    /**
     * Emits per-invocation setups, if the method has invocation stubs.
     *
     * @param pauseMeasurement when true, also start the {@code rt} timer. The matching
     *                         {@link #invocationEpilog} adds the elapsed time to {@code realTime}.
     */
    private void invocationProlog(PrintWriter writer, int prefix, MethodInfo method, StateObjectHandler states, boolean pauseMeasurement) {
        if (states.hasInvocationStubs(method)) {
            for (String s : states.getInvocationSetups(method))
                writer.println(ident(prefix) + s);
            if (pauseMeasurement)
                writer.println(ident(prefix) + "long rt = System.nanoTime();");
        }
    }

    /** Counterpart of {@link #invocationProlog}: accumulates {@code realTime}, then emits tear-downs. */
    private void invocationEpilog(PrintWriter writer, int prefix, MethodInfo method, StateObjectHandler states, boolean pauseMeasurement) {
        if (states.hasInvocationStubs(method)) {
            if (pauseMeasurement)
                writer.println(ident(prefix) + "realTime += (System.nanoTime() - rt);");
            for (String s : states.getInvocationTearDowns(method))
                writer.println(ident(prefix) + s);
        }
    }

    /** Emits, in order: state getters, {@code control.preSetup()}, iteration setups, aux-counter resets. */
    private void iterationProlog(PrintWriter writer, int prefix, MethodInfo method, StateObjectHandler states) {
        for (String s : states.getStateGetters(method)) writer.println(ident(prefix) + s);
        writer.println();

        writer.println(ident(prefix) + "control.preSetup();");

        for (String s : states.getIterationSetups(method)) writer.println(ident(prefix) + s);
        writer.println();

        for (String s : states.getAuxResets(method)) writer.println(ident(prefix) + s);
        writer.println();
    }

    /** Emits iteration tear-downs, then run tear-downs and state destructors guarded by {@code isLastIteration()}. */
    private void iterationEpilog(PrintWriter writer, int prefix, MethodInfo method, StateObjectHandler states) {
        for (String s : states.getIterationTearDowns(method)) writer.println(ident(prefix) + s);
        writer.println();

        writer.println(ident(prefix) + "if (control.isLastIteration()) {");
        for (String s : states.getRunTearDowns(method)) writer.println(ident(prefix + 1) + s);
        for (String s : states.getStateDestructors(method)) writer.println(ident(prefix + 1) + s);
        writer.println(ident(prefix) + "}");
    }

    /**
     * Builds the source expression that invokes the user benchmark method on the implicit
     * {@code bench} instance. Non-void results are wrapped in {@code blackhole.consume(...)}.
     */
    private String emitCall(MethodInfo method, StateObjectHandler states) {
        if ("void".equalsIgnoreCase(method.getReturnType())) {
            return states.getImplicit("bench").localIdentifier + "." + method.getName() + "(" + states.getBenchmarkArgList(method) + ")";
        } else {
            return "blackhole.consume(" + states.getImplicit("bench").localIdentifier + "." + method.getName() + "(" + states.getBenchmarkArgList(method) + "))";
        }
    }

    /** Lazily grown cache of indentation strings, indexed by nesting level (4 spaces per level). */
    static volatile String[] INDENTS;
    static final Object INDENTS_LOCK = new Object();

    /**
     * Returns {@code tabs * 4} spaces.
     *
     * <p>The cache is grown under double-checked locking on the volatile {@link #INDENTS}.
     *
     * @param tabs nesting level; must be non-negative
     */
    static String ident(int tabs) {
        String[] is = INDENTS;
        if (is == null || tabs >= is.length) {
            synchronized (INDENTS_LOCK) {
                is = INDENTS;
                if (is == null || tabs >= is.length) {
                    final int TAB_SIZE = 4;
                    is = new String[tabs + 1];
                    for (int p = 0; p <= tabs; p++) {
                        char[] cs = new char[p * TAB_SIZE];
                        Arrays.fill(cs, ' ');
                        is[p] = new String(cs);
                    }
                    INDENTS = is;
                }
            }
        }
        return is[tabs];
    }
}