package org.testng.internal;
import org.testng.IMethodInstance;
import org.testng.ITestNGMethod;
import org.testng.collections.Lists;
import org.testng.collections.Sets;
import org.testng.thread.IWorker;
import org.testng.xml.XmlSuite;
import org.testng.xml.XmlTest;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;
class ClassBasedParallelWorker extends AbstractParallelWorker {
@Override
public List<IWorker<ITestNGMethod>> createWorkers(Arguments arguments) {
List<IWorker<ITestNGMethod>> result = Lists.newArrayList();
Set<Class<?>> sequentialClasses = Sets.newHashSet();
for (ITestNGMethod m : arguments.getMethods()) {
Class<?> cls = m.getRealClass();
org.testng.annotations.ITestAnnotation test =
arguments.getFinder().findAnnotation(cls, org.testng.annotations.ITestAnnotation.class);
if (isSequential(test, arguments.getTestContext().getCurrentXmlTest())) {
sequentialClasses.add(cls);
}
}
List<IMethodInstance> methodInstances = Lists.newArrayList();
for (ITestNGMethod tm : arguments.getMethods()) {
methodInstances.addAll(methodsToMultipleMethodInstances(tm));
}
Set<Class<?>> processedClasses = Sets.newHashSet();
Map<String, String> params = null;
Class<?> prevClass = null;
for (IMethodInstance im : methodInstances) {
Class<?> c = im.getMethod().getTestClass().getRealClass();
if (!c.equals(prevClass)) {
params = getParameters(im);
prevClass = c;
}
if (sequentialClasses.contains(c)) {
if (!processedClasses.contains(c)) {
processedClasses.add(c);
TestMethodWorker worker = createTestMethodWorker(arguments, methodInstances, params, c);
result.add(worker);
}
} else {
TestMethodWorker worker =
createTestMethodWorker(arguments, Collections.singletonList(im), params, c);
result.add(worker);
}
}
return result;
}
private static List<IMethodInstance> findClasses(
List<IMethodInstance> methodInstances, Class<?> c) {
return methodInstances
.stream()
.filter(mi -> mi.getMethod().getTestClass().getRealClass() == c)
.collect(Collectors.toList());
}
private static TestMethodWorker createTestMethodWorker(
Arguments attributes,
List<IMethodInstance> methodInstances,
Map<String, String> params,
Class<?> c) {
IInvoker invoker = attributes.getInvoker();
ITestInvoker testInvoker = invoker.getTestInvoker();
IConfigInvoker configInvoker = invoker.getConfigInvoker();
return new TestMethodWorker(
testInvoker, configInvoker, findClasses(methodInstances, c),
params,
attributes.getConfigMethods(),
attributes.getClassMethodMap(),
attributes.getTestContext(),
attributes.getListeners());
}
private List<MethodInstance> methodsToMultipleMethodInstances(ITestNGMethod... methods) {
return Arrays.stream(methods).map(MethodInstance::new).collect(Collectors.toList());
}
private static boolean isSequential(
org.testng.annotations.ITestAnnotation test, XmlTest xmlTest) {
return test != null && test.getSingleThreaded()
|| XmlSuite.ParallelMode.CLASSES.equals(xmlTest.getParallel());
}
private static Map<String, String> getParameters(IMethodInstance im) {
XmlTest xmlTest = im.getMethod().getXmlTest();
return im.getMethod().findMethodParameters(xmlTest);
}
}