package org.jruby.util.io;
import java.io.IOException;
import java.io.InterruptedIOException;
import java.nio.ByteBuffer;
import java.nio.channels.Channel;
import java.nio.channels.ReadableByteChannel;
import java.nio.channels.SelectableChannel;
import java.nio.channels.SelectionKey;
import java.nio.channels.Selector;
import java.nio.channels.WritableByteChannel;
import java.nio.channels.spi.SelectorProvider;
import java.util.HashSet;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.concurrent.TimeUnit;
/**
 * Helpers for performing blocking reads and writes on non-blocking
 * {@link SelectableChannel}s.
 *
 * <p>Readiness waiting is implemented with one daemon selector thread per
 * {@link SelectorProvider}. A waiting thread enqueues a registration, wakes the
 * selector thread, and then blocks on a monitor until the selector thread reports
 * the channel ready, or the wait is cancelled, interrupted or times out.
 */
public class BlockingIO {

    /**
     * Handle for a single pending readiness wait created by
     * {@link BlockingIO#newCondition(Channel, int, Object)}.
     */
    public static final class Condition {
        private final IOChannel channel;

        Condition(IOChannel channel) {
            this.channel = channel;
        }

        /**
         * Wakes the waiter; {@link #await()} then returns {@code false}.
         * The registration itself stays with the selector until the channel becomes ready.
         */
        public void cancel() {
            channel.wakeup(false);
        }

        /** Wakes the waiter; {@link #await()} then throws {@link InterruptedException}. */
        public void interrupt() {
            channel.interrupt();
        }

        /**
         * Waits until the selector reports the channel ready, or the condition is
         * cancelled or interrupted.
         *
         * @return {@code true} if reported ready, {@code false} if cancelled
         * @throws InterruptedException if the waiting thread or this condition was interrupted
         */
        public boolean await() throws InterruptedException {
            return channel.await();
        }

        /**
         * Like {@link #await()}, but gives up after {@code timeout}. A timeout of zero waits
         * indefinitely.
         *
         * @return {@code true} if reported ready, {@code false} if cancelled or timed out
         * @throws InterruptedException if the waiting thread or this condition was interrupted
         * @throws IllegalArgumentException if {@code timeout} is negative and the waiter has not
         *         already been woken
         */
        public boolean await(long timeout, TimeUnit unit) throws InterruptedException {
            return channel.await(timeout, unit);
        }
    }

    /**
     * One waiter: the channel and interest ops it waits for, plus the wake-up state
     * it shares with the waiting thread. All mutable state is guarded by {@code monitor}.
     */
    static final class IOChannel {
        final SelectableChannel channel;
        final int ops;
        private final Object monitor;
        private boolean woken = false;
        private boolean ready = false;
        private boolean interrupted = false;

        IOChannel(SelectableChannel channel, int ops, Object monitor) {
            this.channel = channel;
            this.ops = ops;
            this.monitor = monitor;
        }

        /** Marks this waiter woken with the given readiness and notifies the monitor. */
        public final void wakeup(boolean ready) {
            synchronized (monitor) {
                this.woken = true;
                this.ready = ready;
                monitor.notifyAll();
            }
        }

        /** Marks this waiter woken and interrupted and notifies the monitor. */
        public final void interrupt() {
            synchronized (monitor) {
                this.woken = true;
                this.interrupted = true;
                monitor.notifyAll();
            }
        }

        /** Waits indefinitely; see {@link #await(long, TimeUnit)}. */
        public final boolean await() throws InterruptedException {
            return await(0, TimeUnit.MILLISECONDS);
        }

        /**
         * Waits until woken or until {@code timeout} elapses (zero means no limit).
         *
         * @return the readiness recorded by the last {@link #wakeup(boolean)}, or
         *         {@code false} if never woken
         * @throws InterruptedException if the thread was interrupted while waiting or
         *         {@link #interrupt()} was called
         */
        public final boolean await(final long timeout, TimeUnit unit) throws InterruptedException {
            synchronized (monitor) {
                if (!woken) {
                    awaitWoken(timeout, unit);
                }
                if (interrupted) {
                    throw new InterruptedException("Interrupted");
                }
                return ready;
            }
        }

        /**
         * Blocks until {@code woken} is set or the timeout elapses. Must be called while
         * holding {@code monitor}.
         */
        private void awaitWoken(long timeout, TimeUnit unit) throws InterruptedException {
            if (timeout < 0) {
                throw new IllegalArgumentException("timeout value is negative");
            }
            if (timeout == 0) {
                // Loop: the monitor may be shared with other conditions, so notifyAll()
                // (or a spurious wakeup) can end wait() before this waiter was woken.
                while (!woken) {
                    monitor.wait();
                }
                return;
            }
            // Track the remaining time in nanoseconds: truncating to milliseconds would turn
            // a sub-millisecond timeout into wait(0), i.e. an unbounded wait.
            long remaining = unit.toNanos(timeout);
            final long deadline = System.nanoTime() + remaining;
            while (!woken && remaining > 0) {
                TimeUnit.NANOSECONDS.timedWait(monitor, remaining);
                remaining = deadline - System.nanoTime();
            }
        }
    }

    /**
     * Selector loop for one {@link SelectorProvider}. Each registered key carries as
     * its attachment the list of waiters currently interested in that channel.
     */
    static final class IOSelector implements Runnable {
        private final Selector selector;
        private final ConcurrentLinkedQueue<IOChannel> registrationQueue;

        public IOSelector(SelectorProvider provider) throws IOException {
            selector = SelectorFactory.openWithRetryFrom(null, provider);
            registrationQueue = new ConcurrentLinkedQueue<IOChannel>();
        }

        /** Runs forever: wake ready waiters, register new ones, drop idle keys, select. */
        public void run() {
            for ( ; ; ) {
                try {
                    // Snapshot, then clear: NIO never removes keys from the selected-key set by
                    // itself, so a stale key would spuriously wake waiters registered later.
                    Set<SelectionKey> fired = new HashSet<SelectionKey>(selector.selectedKeys());
                    selector.selectedKeys().clear();

                    wakeWaiters(fired);
                    Set<SelectableChannel> added = registerPending();
                    cancelUnused(fired, added);

                    selector.select();
                } catch (IOException ignored) {
                    // Best effort: keep the selector thread alive and retry on the next pass.
                }
            }
        }

        /** Wakes, as ready, every waiter attached to the given keys and empties their wait lists. */
        private static void wakeWaiters(Set<SelectionKey> keys) {
            for (SelectionKey key : keys) {
                List<IOChannel> waitq = waitQueueOf(key);
                for (IOChannel ch : waitq) {
                    ch.wakeup(true);
                }
                waitq.clear();
            }
        }

        /**
         * Drains the registration queue. A waiter whose registration fails is woken with
         * {@code false} so that its retried I/O reports the underlying problem, rather than
         * the waiter hanging or the failure killing this thread.
         *
         * @return the channels that received a new waiter in this pass
         */
        private Set<SelectableChannel> registerPending() {
            Set<SelectableChannel> added = new HashSet<SelectableChannel>();
            IOChannel ch;
            while ((ch = registrationQueue.poll()) != null) {
                try {
                    register(ch);
                    added.add(ch.channel);
                } catch (IOException ex) {
                    // e.g. ClosedChannelException
                    ch.wakeup(false);
                } catch (RuntimeException ex) {
                    // e.g. CancelledKeyException, IllegalBlockingModeException
                    ch.wakeup(false);
                }
            }
            return added;
        }

        /** Registers one waiter, merging its ops with those of waiters already pending on the key. */
        private void register(IOChannel ch) throws IOException {
            SelectionKey key = ch.channel.keyFor(selector);
            List<IOChannel> waitq;
            int interest;
            if (key == null) {
                waitq = new LinkedList<IOChannel>();
                interest = ch.ops;
            } else {
                waitq = waitQueueOf(key);
                // register() replaces the interest set, so keep the ops of pending waiters;
                // otherwise e.g. a reader is silently dropped when a writer registers.
                interest = waitq.isEmpty() ? ch.ops : (key.interestOps() | ch.ops);
            }
            ch.channel.register(selector, interest, waitq);
            waitq.add(ch);
        }

        /** Cancels keys that fired this pass and received no new waiter. */
        private static void cancelUnused(Set<SelectionKey> fired, Set<SelectableChannel> added) {
            for (SelectionKey key : fired) {
                if (!added.contains(key.channel())) {
                    key.cancel();
                }
            }
        }

        /** Returns the waiter list attached to a key registered by this selector. */
        @SuppressWarnings("unchecked") // every key registered here carries a List<IOChannel> attachment
        private static List<IOChannel> waitQueueOf(SelectionKey key) {
            return (List<IOChannel>) key.attachment();
        }

        /** Queues a waiter for {@code ops} on {@code channel} and returns its condition. */
        Condition add(Channel channel, int ops, Object monitor) {
            IOChannel io = new IOChannel((SelectableChannel) channel, ops, monitor);
            registrationQueue.add(io);
            // Break the selector thread out of select() so it picks up the new registration.
            selector.wakeup();
            return new Condition(io);
        }

        /** Queues a waiter with a private monitor and blocks until it is woken. */
        public void await(Channel channel, int op) throws InterruptedException {
            add(channel, op, new Object()).await();
        }
    }

    private static final Map<SelectorProvider, IOSelector> selectors
            = new ConcurrentHashMap<SelectorProvider, IOSelector>();

    /** Guards selector creation; a private lock rather than the JDK-owned provider object. */
    private static final Object selectorLock = new Object();

    /** Returns the selector for {@code provider}, creating it and its daemon thread on first use. */
    private static IOSelector getSelector(SelectorProvider provider) throws IOException {
        IOSelector sel = selectors.get(provider);
        if (sel != null) {
            return sel;
        }
        synchronized (selectorLock) {
            sel = selectors.get(provider);
            if (sel == null) {
                sel = new IOSelector(provider);
                Thread t = new Thread(sel);
                t.setDaemon(true);
                // Start before publishing so no caller can obtain a selector without a thread.
                t.start();
                selectors.put(provider, sel);
            }
        }
        return sel;
    }

    /**
     * Returns the selector for the channel's provider.
     *
     * @throws IllegalArgumentException if {@code channel} is not a {@link SelectableChannel}
     */
    private static IOSelector getSelector(Channel channel) throws IOException {
        if (!(channel instanceof SelectableChannel)) {
            throw new IllegalArgumentException("channel must be a SelectableChannel");
        }
        return getSelector(((SelectableChannel) channel).provider());
    }

    /**
     * Registers interest in {@code ops} on {@code channel} and returns a condition to wait on.
     * Waiting and wake-ups use {@code monitor}; waking one condition calls
     * {@code notifyAll()} on that monitor.
     *
     * @throws IllegalArgumentException if {@code channel} is not a {@link SelectableChannel}
     */
    public static final Condition newCondition(Channel channel, int ops, Object monitor) throws IOException {
        return getSelector(channel).add(channel, ops, monitor);
    }

    /** Like {@link #newCondition(Channel, int, Object)} with a fresh private monitor. */
    public static final Condition newCondition(Channel channel, int ops) throws IOException {
        return newCondition(channel, ops, new Object());
    }

    /**
     * Blocks until the selector reports {@code channel} ready for {@code op}.
     *
     * @throws IllegalArgumentException if {@code channel} is not a {@link SelectableChannel}
     */
    public static void waitForIO(Channel channel, int op) throws InterruptedException, IOException {
        getSelector(channel).await(channel, op);
    }

    /** Blocks until {@code channel} is reported readable. */
    public static void awaitReadable(ReadableByteChannel channel) throws InterruptedException, IOException {
        waitForIO(channel, SelectionKey.OP_READ);
    }

    /** Blocks until {@code channel} is reported writable. */
    public static void awaitWritable(WritableByteChannel channel) throws InterruptedException, IOException {
        waitForIO(channel, SelectionKey.OP_WRITE);
    }

    /**
     * Reads into {@code buf}. When {@code blocking} is set and the channel is selectable,
     * a read that transfers nothing while {@code buf} still has room waits for
     * readability and retries.
     *
     * @return the value of the last {@code channel.read}, e.g. -1 at end of stream
     * @throws InterruptedIOException if the wait is interrupted; the cause is the
     *         original {@link InterruptedException}
     */
    public static int read(ReadableByteChannel channel, ByteBuffer buf, boolean blocking) throws IOException {
        for ( ; ; ) {
            int n = channel.read(buf);
            if (n != 0 || !blocking || !(channel instanceof SelectableChannel) || !buf.hasRemaining()) {
                return n;
            }
            try {
                awaitReadable(channel);
            } catch (InterruptedException ex) {
                throw interruptedIO(ex);
            }
        }
    }

    /**
     * Writes from {@code buf}. When {@code blocking} is set and the channel is selectable,
     * a write that transfers nothing while {@code buf} still has data waits for
     * writability and retries.
     *
     * @return the value of the last {@code channel.write}
     * @throws InterruptedIOException if the wait is interrupted; the cause is the
     *         original {@link InterruptedException}
     */
    public static int write(WritableByteChannel channel, ByteBuffer buf, boolean blocking) throws IOException {
        for ( ; ; ) {
            int n = channel.write(buf);
            if (n != 0 || !blocking || !(channel instanceof SelectableChannel) || !buf.hasRemaining()) {
                return n;
            }
            try {
                awaitWritable(channel);
            } catch (InterruptedException ex) {
                throw interruptedIO(ex);
            }
        }
    }

    /** Wraps an {@link InterruptedException}, keeping it as the cause (the message is unchanged). */
    private static InterruptedIOException interruptedIO(InterruptedException cause) {
        InterruptedIOException ex = new InterruptedIOException(cause.getMessage());
        ex.initCause(cause);
        return ex;
    }

    /** {@link #read(ReadableByteChannel, ByteBuffer, boolean)} with {@code blocking == true}. */
    public static int blockingRead(ReadableByteChannel channel, ByteBuffer buf) throws IOException {
        return read(channel, buf, true);
    }

    /** {@link #write(WritableByteChannel, ByteBuffer, boolean)} with {@code blocking == true}. */
    public static int blockingWrite(WritableByteChannel channel, ByteBuffer buf) throws IOException {
        return write(channel, buf, true);
    }
}