package org.jruby.ext.timeout;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ScheduledFuture;
import java.util.concurrent.ScheduledThreadPoolExecutor;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import org.jcodings.specific.UTF8Encoding;
import org.jruby.Ruby;
import org.jruby.RubyClass;
import org.jruby.RubyException;
import org.jruby.RubyKernel;
import org.jruby.RubyModule;
import org.jruby.RubyNumeric;
import org.jruby.RubyObject;
import org.jruby.RubyString;
import org.jruby.RubyThread;
import org.jruby.RubyTime;
import org.jruby.anno.JRubyMethod;
import org.jruby.exceptions.RaiseException;
import org.jruby.runtime.Helpers;
import org.jruby.runtime.Block;
import org.jruby.runtime.ThreadContext;
import org.jruby.runtime.builtin.IRubyObject;
import org.jruby.threading.DaemonThreadFactory;
import org.jruby.util.ByteList;
/**
 * Native implementation of the methods of Ruby's {@code Timeout} module.
 *
 * <p>Each timed call schedules a {@link TimeoutTask} on the {@link ScheduledThreadPoolExecutor}
 * that {@link #define} stores on the module. The caller and the task share an
 * {@link AtomicBoolean} latch. Whichever side flips it first wins:
 * <ul>
 *   <li>If the block finishes first, the task is cancelled or becomes a no-op.</li>
 *   <li>If the task fires first, it raises into the calling Ruby thread.</li>
 * </ul>
 */
public class Timeout {
    /** Internal-variable key under which the module's scheduler is stored. */
    public static final String EXECUTOR_VARIABLE = "__executor__";

    /** Internal-variable key used to tag an anonymous timeout error with the call that raised it. */
    private static final String IDENTIFIER_VARIABLE = "__identifier__";

    /** Default message ("execution expired", UTF-8) used when no message is supplied. */
    private static final ByteList TIMEOUT_MESSAGE = ByteList.create("execution expired");
    static { TIMEOUT_MESSAGE.setEncoding(UTF8Encoding.INSTANCE); }

    /**
     * Gets or creates the {@code Timeout} module in {@code runtime} and defines the native
     * methods on it.
     */
    public static void load(Ruby runtime) {
        define(runtime.getOrCreateModule("Timeout"));
    }

    /**
     * Installs the annotated {@code timeout} methods on {@code timeout}. It also attaches a
     * scheduler, with one core per available processor and threads from a
     * {@link DaemonThreadFactory}, as the {@link #EXECUTOR_VARIABLE} internal variable.
     */
    public static void define(RubyModule timeout) {
        timeout.defineAnnotatedMethods(Timeout.class);
        ScheduledThreadPoolExecutor executor =
            new ScheduledThreadPoolExecutor(Runtime.getRuntime().availableProcessors(), new DaemonThreadFactory());
        // Tasks cancelled because their block finished in time are removed from the work queue
        // immediately, rather than being retained until their delay elapses.
        executor.setRemoveOnCancelPolicy(true);
        timeout.setInternalVariable(EXECUTOR_VARIABLE, executor);
    }

    /** {@code Timeout.timeout(sec) { ... }}: times out with an anonymous {@code Timeout::Error}. */
    @JRubyMethod(module = true)
    public static IRubyObject timeout(final ThreadContext context, IRubyObject recv, IRubyObject seconds, Block block) {
        return timeout(context, recv, seconds, context.nil, block);
    }

    /** Returns a new frozen copy of the default "execution expired" message. */
    private static RubyString defaultTimeoutMessage(final ThreadContext context) {
        RubyString message = RubyString.newString(context.runtime, TIMEOUT_MESSAGE);
        message.setFrozen(true);
        return message;
    }

    /**
     * {@code Timeout.timeout(sec, klass) { ... }}: times out with {@code klass}, or with an
     * anonymous {@code Timeout::Error} when {@code klass} is nil. Uses the default message.
     */
    @JRubyMethod(module = true)
    public static IRubyObject timeout(final ThreadContext context, IRubyObject recv, IRubyObject seconds, IRubyObject exceptionType, Block block) {
        return timeout(context, recv, seconds, exceptionType, defaultTimeoutMessage(context), block);
    }

    /**
     * {@code Timeout.timeout(sec, klass, message) { ... }}: runs the block, and raises into the
     * calling thread if it has not finished within {@code seconds}.
     *
     * @param seconds       time limit; nil or zero means "no limit" (the block is simply yielded to)
     * @param exceptionType class to raise on timeout, or nil for an anonymous
     *                      {@code Timeout::Error}
     * @param message       message for the raised exception. A falsy value (nil/false) selects the
     *                      default "execution expired"; anything else is converted with
     *                      {@code convertToString()}.
     * @return the block's result
     */
    @JRubyMethod(module = true)
    public static IRubyObject timeout(final ThreadContext context, IRubyObject recv, IRubyObject seconds, IRubyObject exceptionType, IRubyObject message, Block block) {
        if ( nilOrZeroSeconds(context, seconds) ) {
            return block.yieldSpecific(context);
        }

        final RubyModule timeout = context.runtime.getModule("Timeout");
        // Mirror Ruby's `message ||= "execution expired"`: an explicit nil/false message must fall
        // back to the default instead of failing in convertToString() with a TypeError.
        final RubyString timeoutMessage = message.isTrue() ? message.convertToString() : defaultTimeoutMessage(context);
        final RubyThread currentThread = context.getThread();
        final AtomicBoolean latch = new AtomicBoolean(false);
        // With no exception class, the task raises a Timeout::Error tagged with this unique id.
        // That lets the catch block below recognise the error as coming from this particular call.
        final Object id = exceptionType.isNil() ? new Object() : null;
        final Runnable timeoutTask = id != null
            ? TimeoutTask.newAnonymousTask(currentThread, timeout, latch, id, timeoutMessage)
            : TimeoutTask.newTaskWithException(currentThread, timeout, latch, exceptionType, timeoutMessage);
        final ScheduledThreadPoolExecutor executor =
            (ScheduledThreadPoolExecutor) timeout.getInternalVariable(EXECUTOR_VARIABLE);

        try {
            return yieldWithTimeout(executor, context, seconds, block, timeoutTask, latch);
        } catch (RaiseException re) {
            if (id != null && re.getException().getMetaClass() == getTimeoutError(timeout)) {
                // Raises (does not return) when the error carries this call's id.
                raiseTimeoutErrorIfMatches(context, timeout, re, id);
            }
            throw re;
        }
    }

    /**
     * Returns true when {@code seconds} is nil or zero. Numerics are checked with
     * {@code isZero()}; any other object is asked via its Ruby {@code zero?} method.
     */
    private static boolean nilOrZeroSeconds(final ThreadContext context, final IRubyObject seconds) {
        if (seconds instanceof RubyNumeric) return ((RubyNumeric) seconds).isZero();
        return seconds.isNil() || Helpers.invoke(context, seconds, "zero?").isTrue();
    }

    /**
     * Schedules {@code runnable} after {@code seconds} and yields {@code seconds} to the block.
     *
     * <p>The delay comes from {@code RubyTime.convertTimeInterval} and is truncated to whole
     * microseconds. Once scheduling has succeeded, the pending timeout is always resolved via
     * {@link #killTimeoutThread}, however the block exits.
     */
    private static IRubyObject yieldWithTimeout(ScheduledThreadPoolExecutor executor, ThreadContext context,
                                                final IRubyObject seconds, final Block block,
                                                final Runnable runnable, final AtomicBoolean latch) throws RaiseException {
        final long micros = (long) ( RubyTime.convertTimeInterval(context, seconds) * 1000000 );
        ScheduledFuture<?> timeoutFuture = null;
        try {
            timeoutFuture = executor.schedule(runnable, micros, TimeUnit.MICROSECONDS);
            return block.yield(context, seconds);
        } finally {
            if ( timeoutFuture != null ) killTimeoutThread(context, timeoutFuture, latch);
        }
    }

    /**
     * Scheduled task that raises into the timed thread, but only if it claims the latch first.
     */
    private static final class TimeoutTask implements Runnable {
        private final RubyThread currentThread;
        private final AtomicBoolean latch;
        private final RubyModule timeout;
        /** Tag for anonymous errors; null when {@link #exception} is set. */
        private final Object id;
        /** Exception class to raise; null for the anonymous-error mode. */
        private final IRubyObject exception;
        private final RubyString message;

        private TimeoutTask(final RubyThread currentThread, final RubyModule timeout,
                            final AtomicBoolean latch, final Object id, final IRubyObject exception, final RubyString message) {
            this.currentThread = currentThread;
            this.timeout = timeout;
            this.latch = latch;
            this.id = id;
            this.exception = exception;
            this.message = message;
        }

        /** Task that raises a {@code Timeout::Error} tagged with {@code id}. */
        static TimeoutTask newAnonymousTask(final RubyThread currentThread, final RubyModule timeout,
                                            final AtomicBoolean latch, final Object id, final RubyString message) {
            return new TimeoutTask(currentThread, timeout, latch, id, null, message);
        }

        /** Task that raises {@code exception} with {@code message}. */
        static TimeoutTask newTaskWithException(final RubyThread currentThread, final RubyModule timeout,
                                                final AtomicBoolean latch, final IRubyObject exception, final RubyString message) {
            return new TimeoutTask(currentThread, timeout, latch, null, exception, message);
        }

        @Override
        public void run() {
            // Losing the race means the block already finished; do nothing.
            if ( latch.compareAndSet(false, true) ) {
                if ( exception == null ) {
                    raiseAnonymous();
                } else {
                    raiseException();
                }
            }
        }

        private void raiseAnonymous() {
            RubyObject anonException = (RubyObject)
                getTimeoutError(timeout).newInstance(timeout.getRuntime().getCurrentContext(), message, Block.NULL_BLOCK);
            anonException.setInternalVariable(IDENTIFIER_VARIABLE, id);
            currentThread.raise(anonException);
        }

        private void raiseException() {
            currentThread.raise(exception, message);
        }
    }

    /**
     * Resolves the pending timeout after the block exits.
     *
     * <p>If this thread claims the latch and cancels the future, the task will never raise.
     * Otherwise, one of two things is true:
     * <ul>
     *   <li>the task already claimed the latch and has raised, or is raising, into this thread; or</li>
     *   <li>the task is already running and will find the latch taken.</li>
     * </ul>
     * In both cases, wait for it to finish and then poll thread events, presumably so a raise
     * queued by the task is handled here.
     */
    private static void killTimeoutThread(ThreadContext context, ScheduledFuture<?> timeoutFuture, AtomicBoolean latch) {
        if (latch.compareAndSet(false, true) && timeoutFuture.cancel(false)) {
            return; // the block won the race; the task can no longer raise
        }
        try {
            timeoutFuture.get();
        } catch (ExecutionException | InterruptedException ignored) {
            // Best-effort wait only: the task's outcome is observed through the thread-event poll
            // below, not through the future.
        }
        context.pollThreadEvents();
    }

    /**
     * If {@code ex} carries this call's {@code id}, re-raises it as a fresh {@code Timeout::Error}.
     * The new error keeps the original's message and backtrace. Otherwise returns null and the
     * caller rethrows {@code ex} unchanged.
     */
    private static IRubyObject raiseTimeoutErrorIfMatches(ThreadContext context,
                                                          final RubyModule timeout, final RaiseException ex, final Object id) {
        if ( ex.getException().getInternalVariable(IDENTIFIER_VARIABLE) == id ) {
            final RubyException rubyException = ex.getException();
            return RubyKernel.raise(
                context,
                context.runtime.getKernel(),
                new IRubyObject[] {
                    getTimeoutError(timeout),
                    rubyException.callMethod(context, "message"),
                    rubyException.callMethod(context, "backtrace")
                },
                Block.NULL_BLOCK);
        }
        return null;
    }

    /** Returns {@code Timeout::Error} as found under {@code timeout}. */
    private static RubyClass getTimeoutError(final RubyModule timeout) {
        return timeout.getClass("Error");
    }
}