package org.junit.jupiter.api;
import static org.junit.jupiter.api.AssertionUtils.buildPrefix;
import static org.junit.jupiter.api.AssertionUtils.nullSafeGet;
import static org.junit.jupiter.api.Assertions.fail;
import java.time.Duration;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.ThreadFactory;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Supplier;
import org.junit.jupiter.api.function.Executable;
import org.junit.jupiter.api.function.ThrowingSupplier;
import org.junit.platform.commons.JUnitException;
import org.junit.platform.commons.util.ExceptionUtils;
import org.opentest4j.AssertionFailedError;
class AssertTimeout {
private AssertTimeout() {
}
static void assertTimeout(Duration timeout, Executable executable) {
assertTimeout(timeout, executable, (String) null);
}
static void assertTimeout(Duration timeout, Executable executable, String message) {
assertTimeout(timeout, () -> {
executable.execute();
return null;
}, message);
}
static void assertTimeout(Duration timeout, Executable executable, Supplier<String> messageSupplier) {
assertTimeout(timeout, () -> {
executable.execute();
return null;
}, messageSupplier);
}
static <T> T assertTimeout(Duration timeout, ThrowingSupplier<T> supplier) {
return assertTimeout(timeout, supplier, (Object) null);
}
static <T> T assertTimeout(Duration timeout, ThrowingSupplier<T> supplier, String message) {
return assertTimeout(timeout, supplier, (Object) message);
}
static <T> T assertTimeout(Duration timeout, ThrowingSupplier<T> supplier, Supplier<String> messageSupplier) {
return assertTimeout(timeout, supplier, (Object) messageSupplier);
}
private static <T> T assertTimeout(Duration timeout, ThrowingSupplier<T> supplier, Object messageOrSupplier) {
long timeoutInMillis = timeout.toMillis();
long start = System.currentTimeMillis();
T result = null;
try {
result = supplier.get();
}
catch (Throwable ex) {
ExceptionUtils.throwAsUncheckedException(ex);
}
long timeElapsed = System.currentTimeMillis() - start;
if (timeElapsed > timeoutInMillis) {
fail(buildPrefix(nullSafeGet(messageOrSupplier)) + "execution exceeded timeout of " + timeoutInMillis
+ " ms by " + (timeElapsed - timeoutInMillis) + " ms");
}
return result;
}
static void assertTimeoutPreemptively(Duration timeout, Executable executable) {
assertTimeoutPreemptively(timeout, executable, (String) null);
}
static void assertTimeoutPreemptively(Duration timeout, Executable executable, String message) {
assertTimeoutPreemptively(timeout, () -> {
executable.execute();
return null;
}, message);
}
static void assertTimeoutPreemptively(Duration timeout, Executable executable, Supplier<String> messageSupplier) {
assertTimeoutPreemptively(timeout, () -> {
executable.execute();
return null;
}, messageSupplier);
}
static <T> T assertTimeoutPreemptively(Duration timeout, ThrowingSupplier<T> supplier) {
return assertTimeoutPreemptively(timeout, supplier, (Object) null);
}
static <T> T assertTimeoutPreemptively(Duration timeout, ThrowingSupplier<T> supplier, String message) {
return assertTimeoutPreemptively(timeout, supplier, (Object) message);
}
static <T> T assertTimeoutPreemptively(Duration timeout, ThrowingSupplier<T> supplier,
Supplier<String> messageSupplier) {
return assertTimeoutPreemptively(timeout, supplier, (Object) messageSupplier);
}
private static <T> T assertTimeoutPreemptively(Duration timeout, ThrowingSupplier<T> supplier,
Object messageOrSupplier) {
AtomicReference<Thread> threadReference = new AtomicReference<>();
ExecutorService executorService = Executors.newSingleThreadExecutor(new TimeoutThreadFactory());
try {
Future<T> future = executorService.submit(() -> {
try {
threadReference.set(Thread.currentThread());
return supplier.get();
}
catch (Throwable throwable) {
throw ExceptionUtils.throwAsUncheckedException(throwable);
}
});
long timeoutInMillis = timeout.toMillis();
try {
return future.get(timeoutInMillis, TimeUnit.MILLISECONDS);
}
catch (TimeoutException ex) {
String message = buildPrefix(nullSafeGet(messageOrSupplier)) + "execution timed out after "
+ timeoutInMillis + " ms";
Thread thread = threadReference.get();
if (thread != null) {
ExecutionTimeoutException exception = new ExecutionTimeoutException(
"Execution timed out in thread " + thread.getName());
exception.setStackTrace(thread.getStackTrace());
throw new AssertionFailedError(message, exception);
}
else {
throw new AssertionFailedError(message);
}
}
catch (ExecutionException ex) {
throw ExceptionUtils.throwAsUncheckedException(ex.getCause());
}
catch (Throwable ex) {
throw ExceptionUtils.throwAsUncheckedException(ex);
}
}
finally {
executorService.shutdownNow();
}
}
private static class ExecutionTimeoutException extends JUnitException {
private static final long serialVersionUID = 1L;
ExecutionTimeoutException(String message) {
super(message);
}
}
private static class TimeoutThreadFactory implements ThreadFactory {
private static final AtomicInteger threadNumber = new AtomicInteger(1);
public Thread newThread(Runnable r) {
return new Thread(r, "junit-timeout-thread-" + threadNumber.getAndIncrement());
}
}
}