package org.apache.lucene.search;
import java.io.IOException;
import java.util.AbstractCollection;
import java.util.Arrays;
import java.util.Collection;
import java.util.Iterator;
import java.util.NoSuchElementException;
import org.apache.lucene.document.IntPoint;
import org.apache.lucene.index.LeafReader;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.PointValues;
import org.apache.lucene.index.PointValues.IntersectVisitor;
import org.apache.lucene.index.PointValues.Relation;
import org.apache.lucene.index.PrefixCodedTerms;
import org.apache.lucene.index.PrefixCodedTerms.TermIterator;
import org.apache.lucene.util.Accountable;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.BytesRefBuilder;
import org.apache.lucene.util.BytesRefIterator;
import org.apache.lucene.util.DocIdSetBuilder;
import org.apache.lucene.util.FutureArrays;
import org.apache.lucene.util.RamUsageEstimator;
/**
 * Abstract query that matches documents whose point value in {@link #field} is equal to at least
 * one point of a fixed set.
 *
 * <p>The set is handed to the constructor as a sorted {@link Stream} of packed points; it is
 * de-duplicated there and kept in a {@link PrefixCodedTerms} instance. Subclasses supply the
 * human-readable rendering of a single packed point via {@link #toString(byte[])}.
 */
public abstract class PointInSetQuery extends Query implements Accountable {
/** Shallow size of an instance of this class; the fixed part of the {@link #ramBytesUsed} estimate. */
protected static final long BASE_RAM_BYTES = RamUsageEstimator.shallowSizeOfInstance(PointInSetQuery.class);
/** The query points, in ascending order and without duplicates, as built by the constructor. */
final PrefixCodedTerms sortedPackedPoints;
/** Hash code of {@link #sortedPackedPoints}, computed once in the constructor and reused by hashCode/equals. */
final int sortedPackedPointsHashCode;
/** Name of the point field this query runs against. */
final String field;
/** Number of dimensions of every query point. */
final int numDims;
/** Number of bytes used to encode each dimension of a query point. */
final int bytesPerDim;
/** Memory estimate computed once in the constructor and returned by {@link #ramBytesUsed()}. */
final long ramBytesUsed;
/**
 * Supplies the packed points consumed by the {@link PointInSetQuery} constructor, one
 * {@link BytesRef} per call.
 *
 * <p>{@link #next()} is declared without a {@code throws} clause, which lets the constructor
 * drain the stream without handling checked exceptions.
 */
public abstract static class Stream implements BytesRefIterator {
  /**
   * Returns the next packed point, or {@code null} once there are no more points.
   */
  @Override
  public abstract BytesRef next();
}
/**
 * Builds the query from a stream of packed points.
 *
 * <p>The stream must deliver points in ascending {@link BytesRef} order. Consecutive duplicates
 * are silently dropped; any point that sorts before its predecessor is rejected.
 *
 * @param field name of the point field to query
 * @param numDims number of dimensions per point; must be in {@code [1, PointValues.MAX_DIMENSIONS]}
 * @param bytesPerDim bytes per dimension; must be in {@code [1, PointValues.MAX_NUM_BYTES]}
 * @param packedPoints sorted source of packed points, each exactly {@code numDims * bytesPerDim}
 *     bytes long; it is drained until it returns {@code null}
 * @throws IllegalArgumentException if {@code numDims} or {@code bytesPerDim} is out of range, if a
 *     point has the wrong length, or if the points are not in ascending order
 */
protected PointInSetQuery(String field, int numDims, int bytesPerDim, Stream packedPoints) {
  this.field = field;
  if (bytesPerDim < 1 || bytesPerDim > PointValues.MAX_NUM_BYTES) {
    throw new IllegalArgumentException("bytesPerDim must be > 0 and <= " + PointValues.MAX_NUM_BYTES + "; got " + bytesPerDim);
  }
  this.bytesPerDim = bytesPerDim;
  if (numDims < 1 || numDims > PointValues.MAX_DIMENSIONS) {
    throw new IllegalArgumentException("numDims must be > 0 and <= " + PointValues.MAX_DIMENSIONS + "; got " + numDims);
  }
  this.numDims = numDims;

  // Every packed point must have exactly this many bytes; computed once instead of per point.
  final int packedLength = numDims * bytesPerDim;

  PrefixCodedTerms.Builder builder = new PrefixCodedTerms.Builder();
  // Holds a private copy of the last accepted point; the stream may reuse the BytesRef it returns.
  BytesRefBuilder previous = null;
  BytesRef current;
  while ((current = packedPoints.next()) != null) {
    if (current.length != packedLength) {
      throw new IllegalArgumentException("packed point length should be " + packedLength + " but got " + current.length + "; field=\"" + field + "\" numDims=" + numDims + " bytesPerDim=" + bytesPerDim);
    }
    if (previous == null) {
      previous = new BytesRefBuilder();
    } else {
      int cmp = previous.get().compareTo(current);
      if (cmp == 0) {
        // Same point as the one just added: skip it so the stored set has no duplicates.
        continue;
      } else if (cmp > 0) {
        // Format the BytesRef view of the previous point rather than the builder object itself,
        // so both operands of the message are rendered the same way as `current`.
        throw new IllegalArgumentException("values are out of order: saw " + previous.get() + " before " + current);
      }
    }
    builder.add(field, current);
    previous.copyBytes(current);
  }
  sortedPackedPoints = builder.finish();
  sortedPackedPointsHashCode = sortedPackedPoints.hashCode();
  ramBytesUsed = BASE_RAM_BYTES +
      RamUsageEstimator.sizeOfObject(field) +
      RamUsageEstimator.sizeOfObject(sortedPackedPoints);
}
/**
 * Reports this query to {@code visitor} as a leaf, but only when the visitor accepts
 * this query's field.
 */
@Override
public void visit(QueryVisitor visitor) {
  if (visitor.acceptField(field) == false) {
    return;
  }
  visitor.visitLeaf(this);
}
/**
 * Creates a constant-score weight whose per-segment scorer collects every document having one
 * of the query points in {@code field}.
 *
 * <p>For one-dimensional points all query points are resolved in a single pass over the
 * segment's point values; for more dimensions the point values are intersected once per query
 * point. A segment with no point values for the field yields a {@code null} scorer.
 */
@Override
public final Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float boost) throws IOException {
  return new ConstantScoreWeight(this, boost) {

    @Override
    public Scorer scorer(LeafReaderContext context) throws IOException {
      final LeafReader leafReader = context.reader();
      final PointValues pointValues = leafReader.getPointValues(field);
      if (pointValues == null) {
        // Nothing was indexed for this field in this segment.
        return null;
      }

      // The index layout must agree with the layout the query points were packed with.
      if (pointValues.getNumIndexDimensions() != numDims) {
        throw new IllegalArgumentException("field=\"" + field + "\" was indexed with numIndexDims=" + pointValues.getNumIndexDimensions() + " but this query has numIndexDims=" + numDims);
      }
      if (pointValues.getBytesPerDimension() != bytesPerDim) {
        throw new IllegalArgumentException("field=\"" + field + "\" was indexed with bytesPerDim=" + pointValues.getBytesPerDimension() + " but this query has bytesPerDim=" + bytesPerDim);
      }

      final DocIdSetBuilder docIds = new DocIdSetBuilder(leafReader.maxDoc(), pointValues, field);

      if (numDims == 1) {
        // Single dimension: one merge-style pass handles all query points.
        pointValues.intersect(new MergePointVisitor(sortedPackedPoints, docIds));
      } else {
        // Multiple dimensions: run one intersection per query point, reusing the visitor.
        final SinglePointVisitor perPointVisitor = new SinglePointVisitor(docIds);
        final TermIterator queryPoints = sortedPackedPoints.iterator();
        BytesRef queryPoint;
        while ((queryPoint = queryPoints.next()) != null) {
          perPointVisitor.setPoint(queryPoint);
          pointValues.intersect(perPointVisitor);
        }
      }

      final DocIdSetIterator matchingDocs = docIds.build().iterator();
      return new ConstantScoreScorer(this, score(), scoreMode, matchingDocs);
    }

    @Override
    public boolean isCacheable(LeafReaderContext ctx) {
      return true;
    }
  };
}
/**
 * Intersect visitor used for one-dimensional points: it walks the sorted query points and the
 * values handed to it side by side, collecting the documents whose value equals a query point.
 *
 * <p>The query-point iterator is only ever advanced, never rewound.
 * NOTE(review): this presumably relies on the point values being visited in ascending order --
 * confirm against the PointValues#intersect contract.
 */
private class MergePointVisitor implements IntersectVisitor {
// Destination for matching doc IDs.
private final DocIdSetBuilder result;
// Forward-only cursor over the sorted query points.
private TermIterator iterator;
// Query point currently under the cursor; null once all query points are consumed.
private BytesRef nextQueryPoint;
// Reusable wrapper that is pointed at whichever byte[] is being compared; its length is fixed in the constructor.
private final BytesRef scratch = new BytesRef();
private final PrefixCodedTerms sortedPackedPoints;
// Bulk adder obtained in grow(); visit(int) adds doc IDs through it.
private DocIdSetBuilder.BulkAdder adder;
/**
 * Positions the cursor on the first query point.
 *
 * @param sortedPackedPoints the sorted query points to merge against
 * @param result builder that receives matching doc IDs
 */
public MergePointVisitor(PrefixCodedTerms sortedPackedPoints, DocIdSetBuilder result) throws IOException {
this.result = result;
this.sortedPackedPoints = sortedPackedPoints;
// This visitor is used for a single dimension, so a value is bytesPerDim bytes long.
scratch.length = bytesPerDim;
this.iterator = this.sortedPackedPoints.iterator();
nextQueryPoint = iterator.next();
}
/** Requests a bulk adder from the result builder for up to {@code count} upcoming documents. */
@Override
public void grow(int count) {
adder = result.grow(count);
}
/** Adds {@code docID} to the result unconditionally. */
@Override
public void visit(int docID) {
adder.add(docID);
}
/** Adds {@code docID} only if {@code packedValue} equals a query point. */
@Override
public void visit(int docID, byte[] packedValue) {
if (matches(packedValue)) {
visit(docID);
}
}
/** Adds every document of {@code iterator} if the shared {@code packedValue} equals a query point. */
@Override
public void visit(DocIdSetIterator iterator, byte[] packedValue) throws IOException {
// The parameter shadows the query-point iterator field here; matches() still uses the field.
if (matches(packedValue)) {
int docID;
while ((docID = iterator.nextDoc()) != DocIdSetIterator.NO_MORE_DOCS) {
visit(docID);
}
}
}
/**
 * Advances the query-point cursor past every point smaller than {@code packedValue} and reports
 * whether the point it stops on is equal to the value.
 */
private boolean matches(byte[] packedValue) {
scratch.bytes = packedValue;
while (nextQueryPoint != null) {
int cmp = nextQueryPoint.compareTo(scratch);
if (cmp == 0) {
return true;
} else if (cmp < 0) {
// Query point is below the value: move on to the next query point.
nextQueryPoint = iterator.next();
} else {
// Query point is above the value: no match; keep the cursor where it is.
break;
}
}
return false;
}
/**
 * Relates the cell {@code [minPackedValue, maxPackedValue]} to the remaining query points.
 * Query points below the cell minimum are skipped for good. The cell is outside when the next
 * query point lies above the cell maximum or no query points remain, inside when the next query
 * point equals both bounds, and crossing otherwise.
 */
@Override
public Relation compare(byte[] minPackedValue, byte[] maxPackedValue) {
while (nextQueryPoint != null) {
scratch.bytes = minPackedValue;
int cmpMin = nextQueryPoint.compareTo(scratch);
if (cmpMin < 0) {
// Query point is below this cell: discard it and try the next one.
nextQueryPoint = iterator.next();
continue;
}
scratch.bytes = maxPackedValue;
int cmpMax = nextQueryPoint.compareTo(scratch);
if (cmpMax > 0) {
// The smallest remaining query point is already beyond this cell.
return Relation.CELL_OUTSIDE_QUERY;
}
if (cmpMin == 0 && cmpMax == 0) {
// Cell min and max both equal the query point, so every value in the cell matches.
return Relation.CELL_INSIDE_QUERY;
} else {
return Relation.CELL_CROSSES_QUERY;
}
}
// No query points left.
return Relation.CELL_OUTSIDE_QUERY;
}
}
/**
 * Intersect visitor that collects the documents whose packed value equals one specific query
 * point. The target point is replaced with {@link #setPoint(BytesRef)} so that a single
 * instance can be reused for several intersections.
 */
private class SinglePointVisitor implements IntersectVisitor {

  private final DocIdSetBuilder result;
  /** Private copy of the query point currently being searched for. */
  private final byte[] pointBytes;
  private DocIdSetBuilder.BulkAdder adder;

  public SinglePointVisitor(DocIdSetBuilder result) {
    this.result = result;
    this.pointBytes = new byte[bytesPerDim * numDims];
  }

  /** Copies {@code point} into this visitor as the value to match from now on. */
  public void setPoint(BytesRef point) {
    assert point.length == pointBytes.length;
    System.arraycopy(point.bytes, point.offset, pointBytes, 0, pointBytes.length);
  }

  @Override
  public void grow(int count) {
    adder = result.grow(count);
  }

  @Override
  public void visit(int docID) {
    adder.add(docID);
  }

  @Override
  public void visit(int docID, byte[] packedValue) {
    assert packedValue.length == pointBytes.length;
    if (Arrays.equals(packedValue, pointBytes) == false) {
      return;
    }
    adder.add(docID);
  }

  @Override
  public void visit(DocIdSetIterator iterator, byte[] packedValue) throws IOException {
    assert packedValue.length == pointBytes.length;
    if (Arrays.equals(packedValue, pointBytes) == false) {
      return;
    }
    // All documents of the iterator share the matching value: collect every one of them.
    for (int docID = iterator.nextDoc(); docID != DocIdSetIterator.NO_MORE_DOCS; docID = iterator.nextDoc()) {
      adder.add(docID);
    }
  }

  /**
   * Relates a cell to the current point, dimension by dimension: outside as soon as the point
   * falls below the cell minimum or above the cell maximum in any dimension, inside only when
   * the cell's bounds equal the point in every dimension, and crossing otherwise.
   */
  @Override
  public Relation compare(byte[] minPackedValue, byte[] maxPackedValue) {
    boolean boundsEqualPoint = true;
    int from = 0;
    for (int dim = 0; dim < numDims; dim++, from += bytesPerDim) {
      final int to = from + bytesPerDim;

      final int minVsPoint = FutureArrays.compareUnsigned(minPackedValue, from, to, pointBytes, from, to);
      if (minVsPoint > 0) {
        return Relation.CELL_OUTSIDE_QUERY;
      }

      final int maxVsPoint = FutureArrays.compareUnsigned(maxPackedValue, from, to, pointBytes, from, to);
      if (maxVsPoint < 0) {
        return Relation.CELL_OUTSIDE_QUERY;
      }

      if (minVsPoint != 0 || maxVsPoint != 0) {
        boundsEqualPoint = false;
      }
    }
    return boundsEqualPoint ? Relation.CELL_INSIDE_QUERY : Relation.CELL_CROSSES_QUERY;
  }
}
/**
 * Returns a read-only collection view over the query points. Each iteration decodes the stored
 * points afresh and hands out an independent copy of every point's bytes.
 */
public Collection<byte[]> getPackedPoints() {
  return new AbstractCollection<byte[]>() {

    @Override
    public int size() {
      return (int) sortedPackedPoints.size();
    }

    @Override
    public Iterator<byte[]> iterator() {
      final int total = (int) sortedPackedPoints.size();
      final PrefixCodedTerms.TermIterator points = sortedPackedPoints.iterator();

      return new Iterator<byte[]>() {
        /** Number of points already handed out. */
        private int consumed = 0;

        @Override
        public boolean hasNext() {
          return consumed < total;
        }

        @Override
        public byte[] next() {
          if (consumed == total) {
            throw new NoSuchElementException();
          }
          consumed++;
          // Deep copy: the caller gets bytes that are independent of the iterator's BytesRef.
          final BytesRef copy = BytesRef.deepCopyOf(points.next());
          return copy.bytes;
        }
      };
    }
  };
}
/** Returns the name of the point field this query runs against. */
public String getField() {
  return this.field;
}
/** Returns the number of dimensions of each query point. */
public int getNumDims() {
  return this.numDims;
}
/** Returns the number of bytes used to encode each dimension of a query point. */
public int getBytesPerDim() {
  return this.bytesPerDim;
}
/**
 * Combines the class hash with the field name, the precomputed hash of the query points, the
 * dimension count and the bytes per dimension, using the usual 31-multiplier fold.
 */
@Override
public final int hashCode() {
  int hash = classHash();
  final int[] components = {field.hashCode(), sortedPackedPointsHashCode, numDims, bytesPerDim};
  for (int component : components) {
    hash = 31 * hash + component;
  }
  return hash;
}
/**
 * Two queries are equal when {@code other} passes {@code sameClassAs} and the field-by-field
 * comparison in {@code equalsTo} succeeds.
 */
@Override
public final boolean equals(Object other) {
  if (sameClassAs(other) == false) {
    return false;
  }
  return equalsTo(getClass().cast(other));
}
/**
 * Field-by-field comparison with another query. The cheap checks run first, and the precomputed
 * hash of the point set is compared before the point sets themselves.
 */
private boolean equalsTo(PointInSetQuery other) {
  if (other.field.equals(field) == false) {
    return false;
  }
  if (other.numDims != numDims) {
    return false;
  }
  if (other.bytesPerDim != bytesPerDim) {
    return false;
  }
  if (other.sortedPackedPointsHashCode != sortedPackedPointsHashCode) {
    return false;
  }
  return other.sortedPackedPoints.equals(sortedPackedPoints);
}
/**
 * Renders the query as {@code {p1 p2 ...}}, prefixed with {@code fieldName:} when the query's
 * field differs from the {@code field} argument. Each point is rendered by
 * {@link #toString(byte[])}.
 */
@Override
public final String toString(String field) {
  final StringBuilder out = new StringBuilder();
  if (!this.field.equals(field)) {
    out.append(this.field).append(':');
  }
  out.append("{");

  final TermIterator points = sortedPackedPoints.iterator();
  // One buffer is reused for every point that is passed to toString(byte[]).
  final byte[] buffer = new byte[numDims * bytesPerDim];
  BytesRef point = points.next();
  while (point != null) {
    System.arraycopy(point.bytes, point.offset, buffer, 0, buffer.length);
    out.append(toString(buffer));
    point = points.next();
    if (point != null) {
      // Another point follows: separate the two with a single space.
      out.append(" ");
    }
  }

  out.append("}");
  return out.toString();
}
/**
 * Renders one packed point as text; called by {@link #toString(String)} once per query point.
 *
 * @param value the packed point, {@code numDims * bytesPerDim} bytes long; the caller reuses this
 *     array for the next point, so implementations should not hold on to it
 * @return the textual form of the point
 */
protected abstract String toString(byte[] value);
/** Returns the memory estimate that was computed once when this query was constructed. */
@Override
public long ramBytesUsed() {
  return this.ramBytesUsed;
}
}