package org.pcollections;
import java.io.Serializable;
import java.util.AbstractCollection;
import java.util.Collection;
import java.util.Iterator;
import java.util.Map.Entry;
/**
 * An immutable bag (multiset) backed by a {@link PMap} that maps each distinct element to the
 * number of times it occurs.
 *
 * <p>Operations such as {@link #plus}, {@link #minus}, {@link #plusAll} and {@link #minusAll}
 * return a new {@code MapPBag}; the receiver's fields are final and never reassigned. The total
 * number of occurrences is cached so that {@link #size()} does not need to walk the map.
 *
 * @param <E> the element type
 */
public final class MapPBag<E> extends AbstractCollection<E> implements PBag<E>, Serializable {
  private static final long serialVersionUID = 1L;

  /**
   * Returns an empty bag whose backing map is obtained by removing every key from {@code map}.
   *
   * <p>Only {@code map}'s implementation matters here, not its contents; presumably the result of
   * {@code minusAll} keeps the same key semantics (hashing/ordering) — TODO confirm per PMap impl.
   *
   * @param map a map whose (emptied) form will back the new bag
   * @param <E> the element type
   * @return an empty bag
   */
  public static <E> MapPBag<E> empty(final PMap<E, Integer> map) {
    return new MapPBag<E>(map.minusAll(map.keySet()), 0);
  }

  /** Maps each distinct element to its number of occurrences; counts are always positive. */
  private final PMap<E, Integer> map;

  /** Total number of occurrences, i.e. the sum of all values in {@link #map}. */
  private final int size;

  private MapPBag(final PMap<E, Integer> map, final int size) {
    this.map = map;
    this.size = size;
  }

  /** Returns the total number of elements, counting each occurrence separately. */
  @Override
  public int size() {
    return size;
  }

  /**
   * Returns an iterator that yields each element as many times as it occurs, with equal elements
   * returned consecutively. Distinct elements appear in the order of the backing map's entry set.
   * {@link Iterator#remove()} is not supported.
   */
  @Override
  public Iterator<E> iterator() {
    final Iterator<Entry<E, Integer>> entries = map.entrySet().iterator();
    return new Iterator<E>() {
      /** Element currently being repeated. */
      private E current;
      /** Repetitions of {@code current} still to be returned. */
      private int remaining = 0;

      @Override
      public boolean hasNext() {
        return remaining > 0 || entries.hasNext();
      }

      @Override
      public E next() {
        if (remaining == 0) {
          // Move on to the next distinct element; per the Iterator contract, entries.next()
          // throws NoSuchElementException once the entry set is exhausted.
          final Entry<E, Integer> entry = entries.next();
          current = entry.getKey();
          remaining = entry.getValue();
        }
        remaining--;
        return current;
      }

      @Override
      public void remove() {
        throw new UnsupportedOperationException();
      }
    };
  }

  /** Returns {@code true} if {@code e} occurs at least once in this bag. */
  @Override
  public boolean contains(final Object e) {
    return map.containsKey(e);
  }

  /**
   * Returns the sum of the hash codes of all elements, counting each occurrence (a null element
   * contributes 0).
   *
   * <p>Computed per distinct entry as {@code hash * count}; with int wrap-around this equals
   * adding {@code hash} once per occurrence, but costs O(distinct) instead of O(size).
   */
  @Override
  public int hashCode() {
    int hashCode = 0;
    for (Entry<E, Integer> entry : map.entrySet()) {
      final E element = entry.getKey();
      final int elementHash = element == null ? 0 : element.hashCode();
      hashCode += elementHash * entry.getValue();
    }
    return hashCode;
  }

  /**
   * Returns {@code true} if {@code that} is a {@link PBag} containing the same elements with the
   * same multiplicities.
   *
   * <p>A {@link PBag} that is not a {@code MapPBag} is first copied into a bag backed by this
   * bag's map implementation. If that map rejects the other bag's elements (by throwing
   * {@link ClassCastException} or {@link NullPointerException}), the bags cannot be equal and
   * {@code false} is returned rather than letting the exception escape {@code equals}.
   */
  @Override
  public boolean equals(final Object that) {
    if (that == this) return true;
    if (!(that instanceof PBag)) return false;
    final PBag<?> other = (PBag<?>) that;
    // size always equals the sum of map's counts, so differing sizes imply differing maps.
    if (other.size() != size) return false;
    if (other instanceof MapPBag) return map.equals(((MapPBag<?>) other).map);
    try {
      // Safe: the scratch bag is local and only ever read back through its map.
      @SuppressWarnings("unchecked")
      final MapPBag<Object> scratch = (MapPBag<Object>) (MapPBag<?>) MapPBag.empty(map);
      return map.equals(scratch.plusAll(other).map);
    } catch (ClassCastException incompatible) {
      return false;
    } catch (NullPointerException nullRejected) {
      return false;
    }
  }

  /** Returns a bag with one more occurrence of {@code e}. */
  public MapPBag<E> plus(final E e) {
    return new MapPBag<E>(map.plus(e, count(e) + 1), size + 1);
  }

  /**
   * Returns a bag with one occurrence of {@code e} removed, or this bag itself if {@code e} does
   * not occur in it.
   */
  @SuppressWarnings("unchecked")
  public MapPBag<E> minus(final Object e) {
    final int n = count(e);
    if (n == 0) return this;
    if (n == 1) return new MapPBag<E>(map.minus(e), size - 1);
    // e is known to be a key of map here, so treating it as an E is sound.
    return new MapPBag<E>(map.plus((E) e, n - 1), size - 1);
  }

  /** Returns a bag with one additional occurrence of every element yielded by {@code list}. */
  public MapPBag<E> plusAll(final Collection<? extends E> list) {
    MapPBag<E> bag = this;
    for (E e : list) bag = bag.plus(e);
    return bag;
  }

  /**
   * Returns a bag with <em>every</em> occurrence of each element of {@code list} removed;
   * multiplicities in {@code list} are irrelevant.
   */
  public MapPBag<E> minusAll(final Collection<?> list) {
    final PMap<E, Integer> remaining = this.map.minusAll(list);
    return new MapPBag<E>(remaining, size(remaining));
  }

  /**
   * Returns the number of occurrences of {@code o}, or 0 if it is absent. Stored counts are never
   * null, so a null lookup result means the key is absent; this needs one lookup instead of two.
   */
  @SuppressWarnings("unchecked")
  private int count(final Object o) {
    final Integer n = map.get((E) o);
    return n == null ? 0 : n;
  }

  /** Returns the sum of all occurrence counts in {@code map}. */
  private static int size(final PMap<?, Integer> map) {
    int size = 0;
    for (Integer n : map.values()) size += n;
    return size;
  }
}