package org.jruby.ext.set;
import org.jcodings.specific.USASCIIEncoding;
import org.jruby.*;
import org.jruby.anno.JRubyMethod;
import org.jruby.common.IRubyWarnings;
import org.jruby.javasupport.JavaUtil;
import org.jruby.runtime.*;
import org.jruby.runtime.builtin.IRubyObject;
import org.jruby.runtime.marshal.MarshalStream;
import org.jruby.runtime.marshal.UnmarshalStream;
import org.jruby.util.ArraySupport;
import java.io.IOException;
import java.lang.reflect.Array;
import java.util.Collection;
import java.util.IdentityHashMap;
import java.util.Iterator;
import java.util.Set;
import static org.jruby.RubyEnumerator.enumeratorizeWithSize;
@org.jruby.anno.JRubyClass(name="Set", include = { "Enumerable" })
public class RubySet extends RubyObject implements Set {
/**
 * Defines the Ruby {@code Set} class on the given runtime, backs it with this Java class
 * and finally loads the {@code jruby/set.rb} script.
 *
 * @param runtime the runtime to define the class in
 * @return the newly defined {@code Set} class
 */
static RubyClass createSetClass(final Ruby runtime) {
    final RubyClass setClass = runtime.defineClass("Set", runtime.getObject(), ALLOCATOR);

    setClass.setReifiedClass(RubySet.class);
    setClass.includeModule(runtime.getEnumerable());
    setClass.defineAnnotatedMethods(RubySet.class);
    // wrap the class' current marshaller so the Java-side hash field gets restored on load
    setClass.setMarshal(new SetMarshal(setClass.getMarshal()));

    runtime.getLoadService().require("jruby/set.rb");
    return setClass;
}
/**
 * Marshalling strategy for sets: writing is delegated unchanged to the wrapped marshaller,
 * reading is delegated and then followed by {@link RubySet#unmarshal()} so the set's
 * Java-side {@code hash} field is re-read from the {@code @hash} instance variable.
 */
private static final class SetMarshal implements ObjectMarshal {

    protected final ObjectMarshal defaultMarshal;

    SetMarshal(ObjectMarshal defaultMarshal) {
        this.defaultMarshal = defaultMarshal;
    }

    /** Dumps the object using the wrapped marshaller. */
    public void marshalTo(Ruby runtime, Object obj, RubyClass type, MarshalStream marshalStream) throws IOException {
        defaultMarshal.marshalTo(runtime, obj, type, marshalStream);
    }

    /** Loads the object using the wrapped marshaller and re-links its backing hash. */
    public Object unmarshalFrom(Ruby runtime, RubyClass type, UnmarshalStream unmarshalStream) throws IOException {
        final RubySet loaded = (RubySet) defaultMarshal.unmarshalFrom(runtime, type, unmarshalStream);
        loaded.unmarshal();
        return loaded;
    }
}
/**
 * Re-reads the {@code @hash} instance variable into the {@code hash} field
 * (invoked by {@code SetMarshal} after an instance has been loaded).
 */
void unmarshal() {
    final IRubyObject stored = getInstanceVariable("@hash");
    this.hash = (RubyHash) stored;
}
/** Allocator registered for the {@code Set} class; creates a bare {@code RubySet} (no hash allocated yet). */
private static final ObjectAllocator ALLOCATOR = new ObjectAllocator() {
    public RubySet allocate(Ruby ruby, RubyClass metaClass) {
        final RubySet instance = new RubySet(ruby, metaClass);
        return instance;
    }
};
/**
 * Hash holding the set's elements as keys. Assigned by {@code setHash} and {@code unmarshal};
 * remains {@code null} until one of those has run.
 */
RubyHash hash;
/**
 * Creates a set instance of the given class. The backing hash is NOT allocated here;
 * callers follow up with {@code allocHash}/{@code setHash} (or one of the initialize methods).
 */
protected RubySet(Ruby runtime, RubyClass klass) {
super(runtime, klass);
}
/** Installs a fresh backing hash, constructed with the runtime's {@code false} as its second argument. */
final void allocHash(final Ruby runtime) {
    final RubyHash fresh = new RubyHash(runtime, runtime.getFalse());
    setHash(fresh);
}
/**
 * Installs a fresh backing hash, passing {@code size} through to the hash constructor
 * (presumably as an initial capacity hint).
 */
final void allocHash(final Ruby runtime, final int size) {
    final RubyHash fresh = new RubyHash(runtime, runtime.getFalse(), size);
    setHash(fresh);
}
/**
 * Makes the given hash the set's backing store: updates the Java field first and then
 * mirrors it into the {@code @hash} instance variable.
 */
final void setHash(final RubyHash hash) {
this.hash = hash;
setInstanceVariable("@hash", hash);
}
/** Creates an empty set (hash allocated) of the same class as this set. */
RubySet newSet(final Ruby runtime) {
    final RubySet empty = new RubySet(runtime, getMetaClass());
    empty.allocHash(runtime);
    return empty;
}
/** Creates a set of the given class and fills it with the array's elements via {@code initSet}. */
private RubySet newSet(final ThreadContext context, final RubyClass metaClass, final RubyArray elements) {
    final RubySet created = new RubySet(context.runtime, metaClass);
    final IRubyObject[] values = elements.toJavaArrayMaybeUnsafe();
    final int count = elements.size();
    return created.initSet(context, values, 0, count);
}
/**
 * Allocates the backing hash and adds {@code len} elements of the array, starting at
 * index {@code off}, by dispatching to the Ruby-level {@code add} method.
 *
 * @param context  current thread context
 * @param elements source array
 * @param off      index of the first element to add
 * @param len      number of elements to add
 * @return this set
 */
final RubySet initSet(final ThreadContext context, final IRubyObject[] elements, final int off, final int len) {
    allocHash(context.runtime, Math.max(4, len));
    // the range is [off, off + len); bounding the loop by len alone would silently
    // drop elements whenever a non-zero offset is passed
    final int end = off + len;
    for ( int i = off; i < end; i++ ) {
        invokeAdd(context, elements[i]);
    }
    return this;
}
/** {@code Set[...]}: builds a new set of the receiving class containing the given arguments. */
@JRubyMethod(name = "[]", rest = true, meta = true)
public static RubySet create(final ThreadContext context, IRubyObject self, IRubyObject... ary) {
    final RubySet created = new RubySet(context.runtime, (RubyClass) self);
    return created.initSet(context, ary, 0, ary.length);
}
/**
 * {@code Set.new} without an enumerable: allocates an empty hash. A given block is unused,
 * which is reported as a warning when the runtime is verbose.
 */
@JRubyMethod(visibility = Visibility.PRIVATE)
public IRubyObject initialize(ThreadContext context, Block block) {
    final Ruby runtime = context.runtime;
    if ( block.isGiven() && runtime.isVerbose() ) {
        runtime.getWarnings().warning(IRubyWarnings.ID.BLOCK_UNUSED, "given block not used");
    }
    allocHash(runtime);
    return this;
}
/**
 * {@code Set.new(enum)}: a nil argument behaves like the no-argument form; with a block the
 * elements are passed through the block first; otherwise the enumerable is merged via the
 * Ruby-level {@code merge} method.
 */
@JRubyMethod(required = 1, visibility = Visibility.PRIVATE)
public IRubyObject initialize(ThreadContext context, IRubyObject enume, Block block) {
    if ( enume.isNil() ) {
        return initialize(context, block);
    }
    if ( ! block.isGiven() ) {
        allocHash(context.runtime);
        return callMethod(context, "merge", enume);
    }
    return initWithEnum(context, enume, block);
}
/**
 * Dispatches to the zero- or one-argument initializer depending on the argument count.
 * Any other count raises an argument error (expected: 1).
 */
protected IRubyObject initialize(ThreadContext context, IRubyObject[] args, Block block) {
    final int argc = args.length;
    if ( argc == 0 ) return initialize(context, block);
    if ( argc == 1 ) return initialize(context, args[0], block);
    throw context.runtime.newArgumentError(argc, 1);
}
/**
 * Initializes the set from an enumerable, adding the block's result for every element.
 * Arrays and sets are iterated directly; anything else goes through {@code doWithEnum}.
 *
 * @return the array/set that was iterated, or whatever the enumerable's each method returned
 */
private IRubyObject initWithEnum(final ThreadContext context, final IRubyObject enume, final Block block) {
    final Ruby runtime = context.runtime;

    if ( enume instanceof RubyArray ) {
        final RubyArray source = (RubyArray) enume;
        allocHash(runtime, source.size());
        // the array size is re-read on every pass, exactly like an index-based for loop
        int idx = 0;
        while ( idx < source.size() ) {
            invokeAdd(context, block.yield(context, source.eltInternal(idx)));
            idx++;
        }
        return source;
    }

    if ( enume instanceof RubySet ) {
        final RubySet source = (RubySet) enume;
        allocHash(runtime, source.size());
        final Iterator<IRubyObject> it = source.elementsOrdered().iterator();
        while ( it.hasNext() ) {
            invokeAdd(context, block.yield(context, it.next()));
        }
        return source;
    }

    allocHash(runtime);
    final EachBody adder = new EachBody(runtime) {
        IRubyObject yieldImpl(ThreadContext context, IRubyObject val) {
            return invokeAdd(context, block.yield(context, val));
        }
    };
    return doWithEnum(context, enume, adder);
}
/**
 * Iterates the enumerable with the given block body, preferring {@code each_entry}
 * over {@code each}. Raises an argument error when the object responds to neither.
 */
private static IRubyObject doWithEnum(final ThreadContext context, final IRubyObject enume, final EachBody blockImpl) {
    final String iterMethod;
    if ( enume.respondsTo("each_entry") ) {
        iterMethod = "each_entry";
    }
    else if ( enume.respondsTo("each") ) {
        iterMethod = "each";
    }
    else {
        throw context.runtime.newArgumentError("value must be enumerable");
    }
    return enume.callMethod(context, iterMethod, IRubyObject.NULL_ARRAY, new Block(blockImpl));
}
/**
 * Intercepts assignment of a {@code RubyHash} to {@code @hash} (when the name equals the
 * {@code :@hash} symbol) so the Java-side field stays in sync; everything else is
 * handled by the superclass.
 */
@Override
public IRubyObject instance_variable_set(IRubyObject name, IRubyObject value) {
    final boolean targetsHash = getRuntime().newSymbol("@hash").equals(name);
    if ( targetsHash && value instanceof RubyHash ) {
        setHash((RubyHash) value);
        return value;
    }
    return super.instance_variable_set(name, value);
}
/**
 * Adds a value by dynamically calling the Ruby-level {@code add} method on this object
 * (as opposed to {@code addImpl}, which writes to the hash directly).
 */
IRubyObject invokeAdd(final ThreadContext context, final IRubyObject val) {
return this.callMethod(context,"add", val);
}
/**
 * Base class for Java-implemented single-argument block bodies used when iterating
 * enumerables. All yield entry points below funnel into {@link #yieldImpl}.
 */
private static abstract class EachBody extends JavaInternalBlockBody {
EachBody(final Ruby runtime) {
super(runtime, Signature.ONE_REQUIRED);
}
// array-based entry point: only the first argument is forwarded
@Override
public IRubyObject yield(ThreadContext context, IRubyObject[] args) {
return yieldImpl(context, args[0]);
}
/** Handles one yielded value; implemented by the (anonymous) subclasses. */
abstract IRubyObject yieldImpl(ThreadContext context, IRubyObject val) ;
// the block and self parameters are ignored; only the first argument is forwarded
@Override
protected final IRubyObject doYield(ThreadContext context, Block block, IRubyObject[] args, IRubyObject self) {
return yieldImpl(context, args[0]);
}
@Override
protected final IRubyObject doYield(ThreadContext context, Block block, IRubyObject value) {
return yieldImpl(context, value);
}
}
/** Copy hook for {@code dup}: after the superclass copy, installs a {@code dup} of the original's hash. */
@JRubyMethod
public IRubyObject initialize_dup(ThreadContext context, IRubyObject orig) {
    super.initialize_copy(orig);
    final RubyHash source = ((RubySet) orig).hash;
    setHash((RubyHash) source.dup(context));
    return this;
}
/** Copy hook for {@code clone}: after the superclass copy, installs a clone of the original's hash. */
@JRubyMethod
public IRubyObject initialize_clone(ThreadContext context, IRubyObject orig) {
    super.initialize_copy(orig);
    final RubyHash source = ((RubySet) orig).hash;
    setHash((RubyHash) source.rbClone(context));
    return this;
}
/** Freezes the backing hash (when one is allocated) and then the set itself. */
@Override
@JRubyMethod
public IRubyObject freeze(ThreadContext context) {
    final RubyHash backing = this.hash;
    if ( backing != null ) {
        backing.freeze(context);
    }
    return super.freeze(context);
}
/** Taints the backing hash (when one is allocated) and then the set itself. */
@Override
@JRubyMethod
public IRubyObject taint(ThreadContext context) {
    final RubyHash backing = this.hash;
    if ( backing != null ) {
        backing.taint(context);
    }
    return super.taint(context);
}
/** Untaints the backing hash (when one is allocated) and then the set itself. */
@Override
@JRubyMethod
public IRubyObject untaint(ThreadContext context) {
    final RubyHash backing = this.hash;
    if ( backing != null ) {
        backing.untaint(context);
    }
    return super.untaint(context);
}
/** {@code size} / {@code length}: the number of elements as a Fixnum. */
@JRubyMethod(name = "size", alias = "length")
public IRubyObject length(ThreadContext context) {
    final int count = size();
    return context.runtime.newFixnum(count);
}
/** {@code empty?}: whether the set has no elements. */
@JRubyMethod(name = "empty?")
public IRubyObject empty_p(ThreadContext context) {
    final boolean empty = isEmpty();
    return context.runtime.newBoolean(empty);
}
/** {@code clear}: removes all elements after checking the set is not frozen; returns self. */
@JRubyMethod(name = "clear")
public IRubyObject rb_clear(ThreadContext context) {
    final Ruby runtime = context.runtime;
    modifyCheck(runtime);
    clearImpl();
    return this;
}
/** Empties the backing hash. Performs no frozen check of its own on the set. */
protected void clearImpl() {
hash.rb_clear();
}
/**
 * {@code replace(enum)}: replaces the contents of this set with the contents of the
 * given enumerable and returns self.
 *
 * @param context current thread context
 * @param enume   a set, or any object that is Enumerable / responds to {@code each_entry}
 * @return this set
 * @throws org.jruby.exceptions.RaiseException ArgumentError when {@code enume} is not
 *         enumerable; frozen error when this set is frozen
 */
@JRubyMethod
public RubySet replace(final ThreadContext context, IRubyObject enume) {
    final Ruby runtime = context.runtime;
    if ( enume instanceof RubySet ) {
        modifyCheck(runtime);
        final RubySet other = (RubySet) enume;
        // Replacing a set with itself (or with a set sharing the very same hash) must
        // leave it untouched: clearing first would wipe the elements about to be copied.
        if ( other.hash == this.hash ) return this;
        clearImpl();
        addImplSet(context, other);
    }
    else {
        // make sure the argument is enumerable before anything is cleared
        if ( ! enume.getMetaClass().hasModuleInHierarchy(runtime.getEnumerable()) ) {
            if ( ! enume.respondsTo("each_entry") ) {
                throw runtime.newArgumentError("value must be enumerable");
            }
        }
        // same frozen check as performed in the set branch above
        modifyCheck(runtime);
        clearImpl();
        rb_merge(context, enume);
    }
    return this;
}
/** {@code to_a}: returns the keys of the backing hash as an array. */
@JRubyMethod
public RubyArray to_a(final ThreadContext context) {
return this.hash.keys(context);
}
/**
 * {@code to_set} without arguments: returns self unless a block is given, in which case a
 * new set of the same class is initialized from this set using the block.
 */
@JRubyMethod
public RubySet to_set(final ThreadContext context, final Block block) {
    if ( ! block.isGiven() ) {
        return this;
    }
    final RubySet converted = new RubySet(context.runtime, getMetaClass());
    converted.initialize(context, this, block);
    return converted;
}
/**
 * {@code to_set(klass = Set, *args, &block)}: returns self when the requested class is
 * exactly {@code Set}, no further arguments were passed and no block is given; otherwise
 * creates an instance of the requested class and initializes it with the remaining arguments.
 *
 * NOTE(review): only the remaining arguments are handed to {@code initialize}; this set
 * itself is not passed along. Confirm this is intended for callers that pass a class.
 *
 * @param context current thread context
 * @param args    optional target class followed by initializer arguments
 * @param block   optional block forwarded to the initializer
 * @return self or the newly created set
 */
@JRubyMethod(rest = true)
public RubySet to_set(final ThreadContext context, final IRubyObject[] args, final Block block) {
    if ( args.length == 0 ) return to_set(context, block);

    final Ruby runtime = context.runtime;
    final RubyClass Set = runtime.getClass("Set");
    IRubyObject klass = args[0];

    // logical (short-circuit) AND throughout; the result is the same as with the bitwise
    // form, but the intent is explicit and the block check is skipped when not needed
    if ( klass == Set && args.length == 1 && ! block.isGiven() ) {
        return this;
    }

    final IRubyObject[] rest;
    if ( klass instanceof RubyClass ) {
        rest = ArraySupport.newCopy(args, 1, args.length - 1);
    }
    else {
        // first argument is not a class: treat all arguments as initializer arguments for Set
        klass = Set; rest = args;
    }

    RubySet set = new RubySet(context.runtime, (RubyClass) klass);
    set.initialize(context, rest, block);
    return set;
}
/** {@code compare_by_identity}: delegates to the backing hash and returns self. */
@JRubyMethod
public IRubyObject compare_by_identity(ThreadContext context) {
this.hash.compare_by_identity(context);
return this;
}
/** {@code compare_by_identity?}: returns the backing hash's answer. */
@JRubyMethod(name = "compare_by_identity?")
public IRubyObject compare_by_identity_p(ThreadContext context) {
return this.hash.compare_by_identity_p(context);
}
/** Merges the (recursively flattened) elements of the given set into this set; returns self. */
@JRubyMethod(visibility = Visibility.PROTECTED)
public RubySet flatten_merge(final ThreadContext context, IRubyObject set) {
    final IdentityHashMap visiting = new IdentityHashMap();
    flattenMerge(context, set, visiting);
    return this;
}
/**
 * Feeds every element of {@code set} to {@code addFlattened}. Native sets are iterated
 * directly, other objects through their Ruby-level {@code each} method.
 */
private void flattenMerge(final ThreadContext context, final IRubyObject set, final IdentityHashMap seen) {
    if ( ! ( set instanceof RubySet ) ) {
        final EachBody body = new EachBody(context.runtime) {
            IRubyObject yieldImpl(ThreadContext context, IRubyObject e) {
                addFlattened(context, seen, e);
                return context.nil;
            }
        };
        set.callMethod(context, "each", IRubyObject.NULL_ARRAY, new Block(body));
        return;
    }
    for ( IRubyObject elem : ((RubySet) set).elementsOrdered() ) {
        addFlattened(context, seen, elem);
    }
}
/**
 * Adds a single element while flattening: plain values are added as-is, nested sets are
 * descended into. {@code seen} tracks the sets currently being descended (by identity);
 * meeting one of them again raises an argument error.
 */
private void addFlattened(final ThreadContext context, final IdentityHashMap seen, IRubyObject e) {
    if ( ! ( e instanceof RubySet ) ) {
        add(context, e);
        return;
    }
    if ( seen.containsKey(e) ) {
        throw context.runtime.newArgumentError("tried to flatten recursive Set");
    }
    seen.put(e, null);
    flattenMerge(context, e, seen);
    seen.remove(e);
}
/** {@code flatten}: returns a new set of the same class with nested sets flattened. */
@JRubyMethod
public RubySet flatten(final ThreadContext context) {
    final RubySet flat = newSet(context.runtime);
    return flat.flatten_merge(context, this);
}
/**
 * {@code flatten!}: flattens in place. Returns nil when no element is itself a set,
 * otherwise replaces the contents with the flattened set and returns self.
 */
@JRubyMethod(name = "flatten!")
public IRubyObject flatten_bang(final ThreadContext context) {
    boolean nested = false;
    for ( IRubyObject elem : elementsOrdered() ) {
        if ( elem instanceof RubySet ) {
            nested = true;
            break;
        }
    }
    if ( ! nested ) {
        return context.nil;
    }
    return replace(context, flatten(context));
}
/** {@code include?} / {@code member?} / {@code ===}: whether the object is an element of the set. */
@JRubyMethod(name = "include?", alias = { "member?", "===" })
public RubyBoolean include_p(final ThreadContext context, IRubyObject obj) {
    final boolean present = containsImpl(obj);
    return context.runtime.newBoolean(present);
}
/** Membership test: true when looking the object up in the backing hash yields a non-null result. */
final boolean containsImpl(IRubyObject obj) {
return hash.fastARef(obj) != null;
}
/** Returns true when every element of the given set is also contained in THIS set. */
private boolean allElementsIncluded(final RubySet set) {
    final Iterator<IRubyObject> it = set.elements().iterator();
    while ( it.hasNext() ) {
        if ( ! containsImpl(it.next()) ) {
            return false;
        }
    }
    return true;
}
/**
 * {@code superset?} / {@code >=}. When the argument is an instance of this set's class the
 * comparison is delegated to the hashes; otherwise sizes are compared and every element
 * of the argument is looked up in this set. Non-set arguments raise an argument error.
 */
@JRubyMethod(name = "superset?", alias = { ">=" })
public IRubyObject superset_p(final ThreadContext context, IRubyObject set) {
    if ( ! ( set instanceof RubySet ) ) {
        throw context.runtime.newArgumentError("value must be a set");
    }
    final RubySet other = (RubySet) set;
    if ( getMetaClass().isInstance(set) ) {
        return this.hash.op_ge(context, other.hash);
    }
    final boolean superset = size() >= other.size() && allElementsIncluded(other);
    return context.runtime.newBoolean(superset);
}
/**
 * {@code proper_superset?} / {@code >}. Same strategy as {@code superset?} but requires this
 * set to be strictly larger. Non-set arguments raise an argument error.
 */
@JRubyMethod(name = "proper_superset?", alias = { ">" })
public IRubyObject proper_superset_p(final ThreadContext context, IRubyObject set) {
    if ( ! ( set instanceof RubySet ) ) {
        throw context.runtime.newArgumentError("value must be a set");
    }
    final RubySet other = (RubySet) set;
    if ( getMetaClass().isInstance(set) ) {
        return this.hash.op_gt(context, other.hash);
    }
    final boolean properSuperset = size() > other.size() && allElementsIncluded(other);
    return context.runtime.newBoolean(properSuperset);
}
/**
 * {@code subset?} / {@code <=}: whether every element of this set is contained in the given set.
 *
 * When the argument is an instance of this set's class the comparison is delegated to the
 * backing hashes; otherwise sizes are compared and each element of THIS set is looked up
 * in the argument.
 *
 * @param context current thread context
 * @param set     the set to compare against
 * @return Ruby true/false
 * @throws org.jruby.exceptions.RaiseException ArgumentError when {@code set} is not a set
 */
@JRubyMethod(name = "subset?", alias = { "<=" })
public IRubyObject subset_p(final ThreadContext context, IRubyObject set) {
    if ( set instanceof RubySet ) {
        final RubySet other = (RubySet) set;
        if ( getMetaClass().isInstance(set) ) {
            return this.hash.op_le(context, other.hash);
        }
        // size <= set.size && all? { |o| set.include?(o) }
        // the inclusion test must run on the OTHER set (are all of our elements in it?);
        // asking whether all of its elements are in this set answers the superset question
        return context.runtime.newBoolean(
            size() <= other.size() && other.allElementsIncluded(this)
        );
    }
    throw context.runtime.newArgumentError("value must be a set");
}
/**
 * {@code proper_subset?} / {@code <}: whether this set is strictly smaller than the given
 * set and every element of this set is contained in it.
 *
 * When the argument is an instance of this set's class the comparison is delegated to the
 * backing hashes; otherwise sizes are compared and each element of THIS set is looked up
 * in the argument.
 *
 * @param context current thread context
 * @param set     the set to compare against
 * @return Ruby true/false
 * @throws org.jruby.exceptions.RaiseException ArgumentError when {@code set} is not a set
 */
@JRubyMethod(name = "proper_subset?", alias = { "<" })
public IRubyObject proper_subset_p(final ThreadContext context, IRubyObject set) {
    if ( set instanceof RubySet ) {
        final RubySet other = (RubySet) set;
        if ( getMetaClass().isInstance(set) ) {
            return this.hash.op_lt(context, other.hash);
        }
        // size < set.size && all? { |o| set.include?(o) }
        // the inclusion test must run on the OTHER set (are all of our elements in it?)
        return context.runtime.newBoolean(
            size() < other.size() && other.allElementsIncluded(this)
        );
    }
    throw context.runtime.newArgumentError("value must be a set");
}
/** {@code intersect?}: whether the two sets share at least one element; non-sets raise an argument error. */
@JRubyMethod(name = "intersect?")
public IRubyObject intersect_p(final ThreadContext context, IRubyObject set) {
    if ( ! ( set instanceof RubySet ) ) {
        throw context.runtime.newArgumentError("value must be a set");
    }
    final boolean shared = intersect((RubySet) set);
    return context.runtime.newBoolean(shared);
}
/**
 * Returns true when this set and the given set have at least one element in common.
 * The smaller set is iterated and its elements are looked up in the other one
 * (on equal sizes the argument is the one iterated).
 */
public boolean intersect(final RubySet set) {
    final RubySet iterated;
    final RubySet probed;
    if ( size() < set.size() ) {
        iterated = this; probed = set;
    }
    else {
        iterated = set; probed = this;
    }
    for ( IRubyObject elem : iterated.elementsOrdered() ) {
        if ( probed.containsImpl(elem) ) {
            return true;
        }
    }
    return false;
}
/** {@code disjoint?}: whether the two sets share no element; non-sets raise an argument error. */
@JRubyMethod(name = "disjoint?")
public IRubyObject disjoint_p(final ThreadContext context, IRubyObject set) {
    if ( ! ( set instanceof RubySet ) ) {
        throw context.runtime.newArgumentError("value must be a set");
    }
    final boolean disjoint = ! intersect((RubySet) set);
    return context.runtime.newBoolean(disjoint);
}
/** {@code each}: yields every element and returns self; without a block returns a sized enumerator. */
@JRubyMethod
public IRubyObject each(final ThreadContext context, Block block) {
    if ( block.isGiven() ) {
        for ( IRubyObject elem : elementsOrdered() ) {
            block.yield(context, elem);
        }
        return this;
    }
    return enumeratorizeWithSize(context, this, "each", enumSize());
}
/**
 * Size callback handed to the enumerators created by this class; it reports the set's
 * size at the time it is invoked (the arguments are ignored).
 */
private RubyEnumerator.SizeFn enumSize() {
return new RubyEnumerator.SizeFn() {
@Override
public IRubyObject size(IRubyObject[] args) {
return getRuntime().newFixnum( RubySet.this.size() );
}
};
}
/** {@code add} / {@code <<}: adds the object after checking the set is not frozen; returns self. */
@JRubyMethod(name = "add", alias = "<<")
public RubySet add(final ThreadContext context, IRubyObject obj) {
    final Ruby runtime = context.runtime;
    modifyCheck(runtime);
    addImpl(runtime, obj);
    return this;
}
/**
 * Stores the object as a key of the backing hash with the runtime's {@code true} as value.
 * No frozen check is performed here; callers do that via {@code modifyCheck} where needed.
 */
protected void addImpl(final Ruby runtime, final IRubyObject obj) {
hash.fastASetCheckString(runtime, obj, runtime.getTrue());
}
/** Adds all elements of another set by merging its backing hash into this set's hash (no block). */
protected void addImplSet(final ThreadContext context, final RubySet set) {
hash.merge_bang(context, set.hash, Block.NULL_BLOCK);
}
/** {@code add?}: adds the object and returns self, or returns nil when it is already present. */
@JRubyMethod(name = "add?")
public IRubyObject add_p(final ThreadContext context, IRubyObject obj) {
    final boolean present = containsImpl(obj);
    if ( ! present ) {
        return add(context, obj);
    }
    return context.nil;
}
/** {@code delete}: removes the object (if present) after checking the set is not frozen; returns self. */
@JRubyMethod
public IRubyObject delete(final ThreadContext context, IRubyObject obj) {
    final Ruby runtime = context.runtime;
    modifyCheck(runtime);
    deleteImpl(obj);
    return this;
}
/**
 * Removes the object from the backing hash, calling the hash's {@code modify()} first.
 *
 * @return the result of the hash's {@code fastDelete}
 */
protected boolean deleteImpl(final IRubyObject obj) {
hash.modify();
return hash.fastDelete(obj);
}
/**
 * Removes the element most recently returned by the given iterator through the iterator
 * itself. The {@code obj} parameter is not used by this implementation.
 */
protected void deleteImplIterator(final IRubyObject obj, final Iterator it) {
it.remove();
}
/** {@code delete?}: removes the object and returns self, or returns nil when it is not present. */
@JRubyMethod(name = "delete?")
public IRubyObject delete_p(final ThreadContext context, IRubyObject obj) {
    final boolean present = containsImpl(obj);
    if ( present ) {
        return delete(context, obj);
    }
    return context.nil;
}
/**
 * {@code delete_if}: removes every element for which the block is truthy and returns self;
 * without a block returns a sized enumerator.
 */
@JRubyMethod
public IRubyObject delete_if(final ThreadContext context, Block block) {
    if ( ! block.isGiven() ) {
        return enumeratorizeWithSize(context, this, "delete_if", enumSize());
    }
    for ( Iterator<IRubyObject> it = elementsOrdered().iterator(); it.hasNext(); ) {
        final IRubyObject candidate = it.next();
        final boolean drop = block.yield(context, candidate).isTrue();
        if ( drop ) deleteImplIterator(candidate, it);
    }
    return this;
}
/**
 * {@code keep_if}: removes every element for which the block is NOT truthy and returns self;
 * without a block returns a sized enumerator.
 */
@JRubyMethod
public IRubyObject keep_if(final ThreadContext context, Block block) {
    if ( ! block.isGiven() ) {
        return enumeratorizeWithSize(context, this, "keep_if", enumSize());
    }
    for ( Iterator<IRubyObject> it = elementsOrdered().iterator(); it.hasNext(); ) {
        final IRubyObject candidate = it.next();
        final boolean keep = block.yield(context, candidate).isTrue();
        if ( ! keep ) deleteImplIterator(candidate, it);
    }
    return this;
}
/**
 * {@code collect!} / {@code map!}: replaces every element with the block's result and returns
 * self; without a block returns a sized enumerator. The current elements are snapshotted
 * into an array and the set is cleared BEFORE the block is invoked.
 */
@JRubyMethod(name = "collect!", alias = "map!")
public IRubyObject collect_bang(final ThreadContext context, Block block) {
    if ( ! block.isGiven() ) {
        return enumeratorizeWithSize(context, this, "collect!", enumSize());
    }
    final RubyArray snapshot = to_a(context);
    clearImpl();
    int idx = 0;
    while ( idx < snapshot.size() ) {
        final Ruby runtime = context.runtime;
        addImpl(runtime, block.yield(context, snapshot.eltInternal(idx)));
        idx++;
    }
    return this;
}
/**
 * {@code reject!}: like {@code delete_if} but returns nil when the size is unchanged
 * afterwards; without a block returns a sized enumerator.
 */
@JRubyMethod(name = "reject!")
public IRubyObject reject_bang(final ThreadContext context, Block block) {
    if ( ! block.isGiven() ) {
        return enumeratorizeWithSize(context, this, "reject!", enumSize());
    }
    final int before = size();
    for ( Iterator<IRubyObject> it = elementsOrdered().iterator(); it.hasNext(); ) {
        final IRubyObject candidate = it.next();
        if ( block.yield(context, candidate).isTrue() ) {
            deleteImplIterator(candidate, it);
        }
    }
    if ( before == size() ) {
        return context.nil;
    }
    return this;
}
/**
 * {@code select!}: like {@code keep_if} but returns nil when the size is unchanged
 * afterwards; without a block returns a sized enumerator.
 */
@JRubyMethod(name = "select!")
public IRubyObject select_bang(final ThreadContext context, Block block) {
    if ( ! block.isGiven() ) {
        return enumeratorizeWithSize(context, this, "select!", enumSize());
    }
    final int before = size();
    for ( Iterator<IRubyObject> it = elementsOrdered().iterator(); it.hasNext(); ) {
        final IRubyObject candidate = it.next();
        if ( ! block.yield(context, candidate).isTrue() ) {
            deleteImplIterator(candidate, it);
        }
    }
    if ( before == size() ) {
        return context.nil;
    }
    return this;
}
/**
 * {@code merge(enum)}: adds all elements of the given enumerable to this set; returns self.
 *
 * Sets are merged hash-to-hash, arrays are iterated by index, and any other object is
 * iterated through {@code each_entry}/{@code each}.
 *
 * @param context current thread context
 * @param enume   the elements to add
 * @return this set
 * @throws org.jruby.exceptions.RaiseException frozen error when this set is frozen;
 *         ArgumentError when {@code enume} is not enumerable
 */
@JRubyMethod(name = "merge")
public RubySet rb_merge(final ThreadContext context, IRubyObject enume) {
    final Ruby runtime = context.runtime;
    if ( enume instanceof RubySet ) {
        modifyCheck(runtime);
        addImplSet(context, (RubySet) enume);
    }
    else if ( enume instanceof RubyArray ) {
        modifyCheck(runtime);
        RubyArray ary = (RubyArray) enume;
        for ( int i = 0; i < ary.size(); i++ ) {
            addImpl(runtime, ary.eltInternal(i));
        }
    }
    else {
        doWithEnum(context, enume, new EachBody(runtime) {
            IRubyObject yieldImpl(ThreadContext context, IRubyObject val) {
                // addImpl itself performs no frozen check, so do what the two branches
                // above do; checked per element so a non-enumerable argument is still
                // reported (by doWithEnum) before any frozen error
                modifyCheck(context.runtime);
                addImpl(context.runtime, val); return context.nil;
            }
        });
    }
    return this;
}
/**
 * {@code subtract(enum)}: removes every element of the given enumerable from this set and
 * returns self. Sets and arrays are iterated directly (after a frozen check), other objects
 * through {@code doWithEnum}.
 */
@JRubyMethod(name = "subtract")
public IRubyObject subtract(final ThreadContext context, IRubyObject enume) {
    final Ruby runtime = context.runtime;

    if ( enume instanceof RubySet ) {
        modifyCheck(runtime);
        final Iterator<IRubyObject> it = ((RubySet) enume).elementsOrdered().iterator();
        while ( it.hasNext() ) {
            deleteImpl(it.next());
        }
        return this;
    }

    if ( enume instanceof RubyArray ) {
        modifyCheck(runtime);
        final RubyArray values = (RubyArray) enume;
        int idx = 0;
        while ( idx < values.size() ) {
            deleteImpl(values.eltInternal(idx));
            idx++;
        }
        return this;
    }

    final EachBody remover = new EachBody(runtime) {
        IRubyObject yieldImpl(ThreadContext context, IRubyObject val) {
            deleteImpl(val);
            return context.nil;
        }
    };
    doWithEnum(context, enume, remover);
    return this;
}
/** {@code |} / {@code +} / {@code union}: a dup of this set with the enumerable merged in. */
@JRubyMethod(name = "|", alias = { "+", "union" })
public IRubyObject op_or(final ThreadContext context, IRubyObject enume) {
    final RubySet copy = (RubySet) dup();
    return copy.rb_merge(context, enume);
}
/** {@code -} / {@code difference}: a dup of this set with the enumerable's elements removed. */
@JRubyMethod(name = "-", alias = { "difference" })
public IRubyObject op_diff(final ThreadContext context, IRubyObject enume) {
    final RubySet copy = (RubySet) dup();
    return copy.subtract(context, enume);
}
/**
 * {@code &} / {@code intersection}: a new set (of this set's class) holding the elements of
 * the enumerable that are also contained in this set.
 */
@JRubyMethod(name = "&", alias = { "intersection" })
public IRubyObject op_and(final ThreadContext context, IRubyObject enume) {
    final Ruby runtime = context.runtime;
    final RubySet result = new RubySet(runtime, getMetaClass());

    if ( enume instanceof RubySet ) {
        final RubySet other = (RubySet) enume;
        result.allocHash(runtime, other.size());
        for ( IRubyObject candidate : other.elementsOrdered() ) {
            if ( containsImpl(candidate) ) result.addImpl(runtime, candidate);
        }
        return result;
    }

    if ( enume instanceof RubyArray ) {
        final RubyArray values = (RubyArray) enume;
        result.allocHash(runtime, values.size());
        int idx = 0;
        while ( idx < values.size() ) {
            final IRubyObject candidate = values.eltInternal(idx);
            if ( containsImpl(candidate) ) result.addImpl(runtime, candidate);
            idx++;
        }
        return result;
    }

    result.allocHash(runtime);
    doWithEnum(context, enume, new EachBody(runtime) {
        IRubyObject yieldImpl(ThreadContext context, IRubyObject candidate) {
            if ( containsImpl(candidate) ) result.addImpl(runtime, candidate);
            return context.nil;
        }
    });
    return result;
}
/**
 * {@code ^}: a new {@code Set} built from the enumerable, from which every element of this
 * set is removed when present and added when absent.
 */
@JRubyMethod(name = "^")
public IRubyObject op_xor(final ThreadContext context, IRubyObject enume) {
    final Ruby runtime = context.runtime;
    final RubySet result = new RubySet(runtime, runtime.getClass("Set"));
    result.initialize(context, enume, Block.NULL_BLOCK);

    for ( IRubyObject elem : elementsOrdered() ) {
        if ( result.containsImpl(elem) ) {
            result.deleteImpl(elem);
        }
        else {
            result.addImpl(runtime, elem);
        }
    }
    return result;
}
/**
 * {@code ==}: identical objects are equal; instances of this set's class are compared via
 * their hashes; other sets are equal when sizes match and all of this set's elements are
 * contained in the other. Anything else is not equal.
 */
@Override
@JRubyMethod(name = "==")
public IRubyObject op_equal(ThreadContext context, IRubyObject other) {
    if ( this == other ) {
        return context.tru;
    }
    if ( getMetaClass().isInstance(other) ) {
        return this.hash.op_equal(context, ((RubySet) other).hash);
    }
    if ( ! ( other instanceof RubySet ) ) {
        return context.fals;
    }
    final RubySet that = (RubySet) other;
    if ( this.size() != that.size() ) {
        return context.fals;
    }
    for ( IRubyObject elem : elementsOrdered() ) {
        if ( ! that.containsImpl(elem) ) return context.fals;
    }
    return context.tru;
}
/** {@code reset}: rehashes the backing hash and returns self. */
@JRubyMethod(name = "reset")
public IRubyObject reset(ThreadContext context) {
this.hash.rehash();
return this;
}
/** {@code eql?}: false for non-sets, otherwise the result of {@code eql?} on the two backing hashes. */
@JRubyMethod(name = "eql?")
public IRubyObject op_eql(ThreadContext context, IRubyObject other) {
    if ( ! ( other instanceof RubySet ) ) {
        return context.fals;
    }
    final RubyHash otherHash = ((RubySet) other).hash;
    return this.hash.op_eql(context, otherHash);
}
/** Java-side {@code eql?}: same comparison as {@code op_eql}, using the runtime's current context. */
@Override
public boolean eql(IRubyObject other) {
    if ( ! ( other instanceof RubySet ) ) {
        return false;
    }
    final Ruby runtime = getRuntime();
    final IRubyObject result = this.hash.op_eql(runtime.getCurrentContext(), ((RubySet) other).hash);
    return result == runtime.getTrue();
}
/** {@code hash}: the hash code of the backing hash (the field {@code hash}, not this method). */
@Override
@JRubyMethod
public RubyFixnum hash() {
return hash.hash();
}
/**
 * {@code classify}: groups the elements by the block's result. Returns a hash mapping each
 * distinct block result to a set (of this set's class) of the elements that produced it;
 * without a block returns a sized enumerator.
 */
@JRubyMethod(name = "classify")
public IRubyObject classify(ThreadContext context, final Block block) {
    if ( ! block.isGiven() ) {
        return enumeratorizeWithSize(context, this, "classify", enumSize());
    }
    final Ruby runtime = context.runtime;
    final RubyHash groups = new RubyHash(runtime, size());

    for ( IRubyObject elem : elementsOrdered() ) {
        final IRubyObject key = block.yield(context, elem);
        IRubyObject bucket = groups.fastARef(key);
        if ( bucket == null ) {
            // first element for this key: start a new group
            bucket = newSet(runtime);
            groups.fastASet(key, bucket);
        }
        ((RubySet) bucket).invokeAdd(context, elem);
    }
    return groups;
}
/**
 * {@code divide}: partitions the set into a {@code Set} of subsets. A block whose arity is 2
 * is handled by {@code divideTSort}; otherwise the subsets are the groups produced by
 * {@code classify}. Without a block returns a sized enumerator.
 */
@JRubyMethod(name = "divide")
public IRubyObject divide(ThreadContext context, final Block block) {
    if ( ! block.isGiven() ) {
        return enumeratorizeWithSize(context, this, "divide", enumSize());
    }
    if ( block.getSignature().arityValue() == 2 ) {
        return divideTSort(context, block);
    }

    final Ruby runtime = context.runtime;
    final RubyHash groups = (RubyHash) classify(context, block);
    final RubySet result = new RubySet(runtime, runtime.getClass("Set"));
    result.allocHash(runtime, groups.size());

    final Collection<IRubyObject> parts = (Collection<IRubyObject>) groups.directValues();
    for ( IRubyObject part : parts ) {
        result.invokeAdd(context, part);
    }
    return result;
}
/**
 * Two-argument-block variant of {@code divide}. Builds a {@code DivideTSortHash} that maps
 * every element {@code u} to the array of elements {@code v} for which the block is truthy,
 * then collects each component yielded by {@code each_strongly_connected_component} as a
 * set inside the resulting {@code Set}.
 */
private IRubyObject divideTSort(ThreadContext context, final Block block) {
final Ruby runtime = context.runtime;
final RubyHash dig = DivideTSortHash.newInstance(context);
// the block is invoked for every ordered pair of elements (including u with itself)
for ( IRubyObject u : elementsOrdered() ) {
RubyArray a;
dig.fastASet(u, a = runtime.newArray());
for ( IRubyObject v : elementsOrdered() ) {
IRubyObject ret = block.call(context, u, v);
if ( ret.isTrue() ) a.append(v);
}
}
final RubyClass Set = runtime.getClass("Set");
final RubySet set = new RubySet(runtime, Set);
set.allocHash(runtime, dig.size());
dig.callMethod(context, "each_strongly_connected_component", IRubyObject.NULL_ARRAY, new Block(
new JavaInternalBlockBody(runtime, Signature.ONE_REQUIRED) {
@Override
public IRubyObject yield(ThreadContext context, IRubyObject[] args) {
return doYield(context, null, args[0]);
}
@Override
protected IRubyObject doYield(ThreadContext context, Block block, IRubyObject css) {
// each yielded component (cast to an array) becomes one Set element of the result
set.addImpl(runtime, RubySet.this.newSet(context, Set, (RubyArray) css));
return context.nil;
}
})
);
return set;
}
/**
 * Hash subclass used by {@code divideTSort}. Its Ruby class is defined under {@code Set}
 * on first use, includes the {@code TSort} module and supplies the two iteration methods
 * declared below.
 */
public static final class DivideTSortHash extends RubyHash {
private static final String NAME = "DivideTSortHash";
/**
 * Creates a new instance, defining the backing Ruby class first when the constant is
 * not yet present on {@code Set}.
 */
static DivideTSortHash newInstance(final ThreadContext context) {
final Ruby runtime = context.runtime;
RubyClass Set = runtime.getClass("Set");
RubyClass klass = (RubyClass) Set.getConstantAt(NAME, true);
if (klass == null) {
// the constant is looked up a second time while holding the lock before defining it
synchronized (DivideTSortHash.class) {
klass = (RubyClass) Set.getConstantAt(NAME, true);
if (klass == null) {
klass = Set.defineClassUnder(NAME, runtime.getHash(), runtime.getHash().getAllocator());
// NOTE(review): the boolean presumably marks the constant as hidden/private - confirm
Set.setConstantVisibility(runtime, NAME, true);
klass.includeModule(getTSort(runtime));
klass.defineAnnotatedMethods(DivideTSortHash.class);
}
}
}
return new DivideTSortHash(runtime, klass);
}
DivideTSortHash(final Ruby runtime, final RubyClass metaClass) {
super(runtime, metaClass);
}
/** Iterates the nodes: the keys of this hash. */
@JRubyMethod
public IRubyObject tsort_each_node(ThreadContext context, Block block) {
return each_key(context, block);
}
/**
 * Iterates the children of a node: the value fetched for the node, iterated natively
 * when it is a {@code RubySet} and through its {@code each} method otherwise.
 */
@JRubyMethod
public IRubyObject tsort_each_child(ThreadContext context, IRubyObject node, Block block) {
IRubyObject set = fetch(context, node, Block.NULL_BLOCK);
if ( set instanceof RubySet ) {
return ((RubySet) set).each(context, block);
}
return set.callMethod(context, "each", IRubyObject.NULL_ARRAY, block);
}
}
/** Returns the {@code TSort} module, requiring {@code tsort} first when the constant is not defined on Object. */
static RubyModule getTSort(final Ruby runtime) {
    final boolean defined = runtime.getObject().hasConstant("TSort");
    if ( ! defined ) {
        runtime.getLoadService().require("tsort");
    }
    return runtime.getModule("TSort");
}
/** Java-side inspect: delegates to {@code inspect(ThreadContext)} using the runtime's current context. */
@Override
public final IRubyObject inspect() {
return inspect(getRuntime().getCurrentContext());
}
/** The bytes of {@code "..."}, printed by {@code inspectRecurse} in place of the elements. */
private static final byte[] RECURSIVE_BYTES = new byte[] { '.','.','.' };
/**
 * {@code inspect} / {@code to_s}: renders the set as {@code #<ClassName: {e1, e2}>}. Empty sets
 * and sets that are already being inspected have dedicated short forms. The set stays
 * registered as "inspecting" while its elements are rendered.
 */
@JRubyMethod(name = "inspect", alias = "to_s")
public RubyString inspect(ThreadContext context) {
    final Ruby runtime = context.runtime;

    if ( size() == 0 ) {
        return inspectEmpty(runtime);
    }
    if ( runtime.isInspecting(this) ) {
        return inspectRecurse(runtime);
    }

    final RubyString out = RubyString.newStringLight(runtime, 32, USASCIIEncoding.INSTANCE);
    inspectPrefix(out, getMetaClass());
    try {
        runtime.registerInspecting(this);
        inspectSet(context, out);
        return out.cat('>');
    }
    finally {
        runtime.unregisterInspecting(this);
    }
}
/** Renders the inspect form of an empty set: the prefix followed by <code>{}&gt;</code>. */
private RubyString inspectEmpty(final Ruby runtime) {
    final RubyString out = RubyString.newStringLight(runtime, 16, USASCIIEncoding.INSTANCE);
    inspectPrefix(out, getMetaClass());
    out.cat('{').cat('}').cat('>');
    return out;
}
/** Renders the inspect form used for a set already being inspected: the prefix followed by <code>{...}&gt;</code>. */
private RubyString inspectRecurse(final Ruby runtime) {
    final RubyString out = RubyString.newStringLight(runtime, 20, USASCIIEncoding.INSTANCE);
    inspectPrefix(out, getMetaClass());
    out.cat('{').cat(RECURSIVE_BYTES).cat('}').cat('>');
    return out;
}
/** Appends {@code #<ClassName: } (name of the real class, UTF-8 bytes) to the string and returns it. */
private static RubyString inspectPrefix(final RubyString str, final RubyClass metaClass) {
    final byte[] className = metaClass.getRealClass().getName().getBytes(RubyEncoding.UTF8);
    str.cat('#').cat('<').cat(className);
    str.cat(':').cat(' ');
    return str;
}
/**
 * Appends <code>{e1, e2, ...}</code> to the string using each element's inspect output. The
 * string takes the encoding of the first element's inspect string and becomes tainted
 * when the set or any element's inspect string is tainted.
 */
private void inspectSet(final ThreadContext context, final RubyString str) {
    str.cat((byte) '{');

    boolean tainted = isTaint();
    boolean first = true;
    for ( IRubyObject elem : elementsOrdered() ) {
        final RubyString part = inspect(context, elem);
        if ( part.isTaint() ) {
            tainted = true;
        }
        if ( first ) {
            str.setEncoding( part.getEncoding() );
            first = false;
        }
        else {
            str.cat((byte) ',').cat((byte) ' ');
        }
        str.cat19( part );
    }

    str.cat((byte) '}');
    if ( tainted ) {
        str.setTaint(true);
    }
}
/** The set's elements as the backing hash's direct key set (Ruby objects, no Java conversion applied here). */
protected final Set<IRubyObject> elements() {
return hash.directKeySet();
}
/**
 * The elements in the order used for iteration. This implementation simply returns
 * {@code elements()}; the method is non-final, so subclasses may supply another order.
 */
protected Set<IRubyObject> elementsOrdered() {
return elements();
}
/** Raises a frozen error (for "Set") when this object's frozen flag is set. */
protected final void modifyCheck(final Ruby runtime) {
    final boolean frozen = (flags & FROZEN_F) != 0;
    if ( frozen ) {
        throw runtime.newFrozenError("Set");
    }
}
// java.util.Set implementation

/** Number of elements, i.e. the size of the backing hash. */
public int size() { return hash.size(); }
/** Whether the backing hash is empty. */
public boolean isEmpty() { return hash.isEmpty(); }
/** Removes all elements via {@code clearImpl()} (no set-level frozen check is made here). */
public void clear() { clearImpl(); }
/** Whether the set contains the Ruby equivalent of the given Java object. */
public boolean contains(Object o) {
    final IRubyObject converted = toRuby(o);
    return containsImpl(converted);
}
/** Iterator over the elements as Ruby objects, in {@code elementsOrdered()} order. */
public Iterator<IRubyObject> rawIterator() {
return elementsOrdered().iterator();
}
/**
 * Iterator obtained from the backing hash's {@code keySet()} (as opposed to the direct key
 * set used by {@code rawIterator()}).
 */
public Iterator<Object> iterator() {
return hash.keySet().iterator();
}
/** Returns a new array holding every element converted with {@code toJava(Object.class)}. */
public Object[] toArray() {
    final Object[] result = new Object[size()];
    int idx = 0;
    final Iterator<IRubyObject> it = elementsOrdered().iterator();
    while ( it.hasNext() ) {
        result[idx] = it.next().toJava(Object.class);
        idx++;
    }
    return result;
}
/**
 * Copies the elements, converted to the array's component type, into the given array, or
 * into a newly allocated array of the same component type when the given one is too small.
 * As required by {@code Collection.toArray(T[])}, when the array has room to spare the
 * slot immediately following the last element is set to {@code null}.
 *
 * @param ary the array to fill (also determines the element type)
 * @return {@code ary} when it was large enough, otherwise a new array
 */
public Object[] toArray(final Object[] ary) {
    final Class type = ary.getClass().getComponentType();
    Object[] array = ary;
    if (array.length < size()) {
        array = (Object[]) Array.newInstance(type, size());
    }
    int i = 0;
    for ( IRubyObject elem : elementsOrdered() ) {
        array[i++] = elem.toJava(type);
    }
    // mark the end of the set within a larger caller-supplied array
    if ( i < array.length ) {
        array[i] = null;
    }
    return array;
}
/** Adds the Ruby equivalent of the given Java object; returns true when the set grew. */
public boolean add(Object element) {
    final Ruby runtime = getRuntime();
    final int before = size();
    final IRubyObject converted = toRuby(runtime, element);
    addImpl(runtime, converted);
    return size() > before;
}
/** Removes the Ruby equivalent of the given Java object; returns the result of {@code deleteImpl}. */
public boolean remove(Object element) {
return deleteImpl(toRuby(element));
}
/** Whether every element of the collection is contained in this set. */
public boolean containsAll(Collection coll) {
    for ( Object elem : coll ) {
        if ( contains(elem) ) continue;
        return false;
    }
    return true;
}
/** Adds the Ruby equivalents of all elements of the collection; returns true when the set grew. */
public boolean addAll(Collection coll) {
    final Ruby runtime = getRuntime();
    final int before = size();
    for ( Object elem : coll ) {
        final IRubyObject converted = toRuby(runtime, elem);
        addImpl(runtime, converted);
    }
    return size() > before;
}
/**
 * Removes every element whose Java conversion ({@code toJava(Object.class)}) is not contained
 * in the given collection; returns true when the set shrank.
 */
public boolean retainAll(Collection coll) {
    final int before = size();
    final Iterator<IRubyObject> it = rawIterator();
    while ( it.hasNext() ) {
        final IRubyObject elem = it.next();
        final boolean keep = coll.contains(elem.toJava(Object.class));
        if ( ! keep ) {
            deleteImplIterator(elem, it);
        }
    }
    return size() < before;
}
/** Removes every element of the collection from this set; returns true when any removal succeeded. */
public boolean removeAll(Collection coll) {
    boolean changed = false;
    for ( Object elem : coll ) {
        // remove() is always invoked, regardless of earlier results
        if ( remove(elem) ) {
            changed = true;
        }
    }
    return changed;
}
/** Converts a Java object to a Ruby object using {@code JavaUtil.convertJavaToUsableRubyObject}. */
static IRubyObject toRuby(Ruby runtime, Object obj) {
return JavaUtil.convertJavaToUsableRubyObject(runtime, obj);
}
/** Instance variant of {@code toRuby(Ruby, Object)} that uses this object's runtime. */
final IRubyObject toRuby(Object obj) {
return JavaUtil.convertJavaToUsableRubyObject(getRuntime(), obj);
}
}