package org.apache.lucene.search;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;
import java.util.PriorityQueue;
import org.apache.lucene.geo.Rectangle;
import org.apache.lucene.index.PointValues.IntersectVisitor;
import org.apache.lucene.index.PointValues.Relation;
import org.apache.lucene.util.Bits;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.SloppyMath;
import org.apache.lucene.util.bkd.BKDReader;
import org.apache.lucene.util.bkd.BKDReader.IndexTree;
import org.apache.lucene.util.bkd.BKDReader.IntersectState;
import static org.apache.lucene.geo.GeoEncodingUtils.decodeLatitude;
import static org.apache.lucene.geo.GeoEncodingUtils.decodeLongitude;
/**
 * Best-first k-nearest-neighbor search over the lat/lon points held in one or more
 * {@link BKDReader}s.
 *
 * <p>Each BKD index node is wrapped in a {@link Cell} and kept in a priority queue ordered by an
 * approximate distance from the query point to the cell's bounding box, so the closest cell is
 * expanded first. Leaf cells are handed to a {@link NearestVisitor}. The visitor keeps the current
 * top-N hits and periodically shrinks a lat/lon bounding box around the query point as the worst
 * retained hit gets closer. Inner cells that fall entirely outside that box are then skipped.
 *
 * <p>Packed values are decoded with latitude in the first {@code Integer.BYTES} bytes and
 * longitude in the following {@code Integer.BYTES} bytes.
 */
class NearestNeighbor {

  /** A BKD index node still to be processed, together with its packed bounding box. */
  static class Cell implements Comparable<Cell> {
    /** Position of the owning reader in the {@code readers} list passed to {@code nearest}. */
    final int readerIndex;
    /** Packed minimum corner of this cell (a private copy of the caller's array). */
    final byte[] minPacked;
    /** Packed maximum corner of this cell (a private copy of the caller's array). */
    final byte[] maxPacked;
    /** Node in the reader's index tree; {@code nearest} moves it in place when splitting. */
    final IndexTree index;
    /** Approximate haversin sort key from the query point to this cell; smaller is closer. */
    final double distanceSortKey;

    public Cell(IndexTree index, int readerIndex, byte[] minPacked, byte[] maxPacked, double distanceSortKey) {
      this.index = index;
      this.readerIndex = readerIndex;
      this.minPacked = minPacked.clone();
      this.maxPacked = maxPacked.clone();
      this.distanceSortKey = distanceSortKey;
    }

    /** Orders cells by ascending {@link #distanceSortKey}, so the closest cell is polled first. */
    @Override
    public int compareTo(Cell other) {
      return Double.compare(distanceSortKey, other.distanceSortKey);
    }

    @Override
    public String toString() {
      double minLat = decodeLatitude(minPacked, 0);
      double minLon = decodeLongitude(minPacked, Integer.BYTES);
      double maxLat = decodeLatitude(maxPacked, 0);
      double maxLon = decodeLongitude(maxPacked, Integer.BYTES);
      return "Cell(readerIndex=" + readerIndex + " nodeID=" + index.getNodeID() + " isLeaf=" + index.isLeafNode() + " lat=" + minLat + " TO " + maxLat + ", lon=" + minLon + " TO " + maxLon + "; distanceSortKey=" + distanceSortKey + ")";
    }
  }

  /**
   * Collects the {@code topN} closest live points into {@code hitQueue}.
   *
   * <p>The queue must be ordered so that its head is the worst retained hit. The queue built in
   * {@code nearest} puts the largest distance at the head, with ties broken by the largest docID.
   * After the queue has filled, replacing the worst hit may shrink the search bounding box (see
   * {@link #maybeUpdateBBox}). Points and cells outside that box are rejected.
   */
  private static class NearestVisitor implements IntersectVisitor {

    /** Added to segment-local docIDs to form {@link NearestHit#docID}; set per reader. */
    public int curDocBase;
    /** Live docs of the reader being visited; {@code null} means every doc is live. */
    public Bits curLiveDocs;
    final int topN;
    final PriorityQueue<NearestHit> hitQueue;
    final double pointLat;
    final double pointLon;
    /** Number of times {@link #maybeUpdateBBox} has been called. */
    private int setBottomCounter;

    // Current search box. A longitude is accepted if it lies in [minLon, maxLon] OR is >= minLon2.
    // minLon2 stays +infinity unless the box crosses the dateline. All bounds start unbounded.
    private double minLon = Double.NEGATIVE_INFINITY;
    private double maxLon = Double.POSITIVE_INFINITY;
    private double minLat = Double.NEGATIVE_INFINITY;
    private double maxLat = Double.POSITIVE_INFINITY;
    private double minLon2 = Double.POSITIVE_INFINITY;

    public NearestVisitor(PriorityQueue<NearestHit> hitQueue, int topN, double pointLat, double pointLon) {
      this.hitQueue = hitQueue;
      this.topN = topN;
      this.pointLat = pointLat;
      this.pointLon = pointLon;
    }

    /** Not supported: this visitor needs the packed value of every point. */
    @Override
    public void visit(int docID) {
      throw new AssertionError();
    }

    /**
     * Recomputes the search box as the rectangle around the query point with radius equal to the
     * current worst hit's distance.
     *
     * <p>This runs on each of the first 1024 calls. After that it runs only when the low six bits
     * of the counter are all set, i.e. once every 64 calls, presumably to amortize the cost of
     * {@code Rectangle.fromPointDistance}.
     */
    private void maybeUpdateBBox() {
      if (setBottomCounter < 1024 || (setBottomCounter & 0x3F) == 0x3F) {
        NearestHit hit = hitQueue.peek();
        Rectangle box = Rectangle.fromPointDistance(pointLat, pointLon,
            SloppyMath.haversinMeters(hit.distanceSortKey));
        minLat = box.minLat;
        maxLat = box.maxLat;
        if (box.crossesDateline()) {
          // Box wraps around +/-180: accept lon <= box.maxLon (the lower bound is disabled)
          // or lon >= box.minLon (via minLon2).
          minLon = Double.NEGATIVE_INFINITY;
          maxLon = box.maxLon;
          minLon2 = box.minLon;
        } else {
          minLon = box.minLon;
          maxLon = box.maxLon;
          minLon2 = Double.POSITIVE_INFINITY;
        }
      }
      setBottomCounter++;
    }

    /**
     * Considers one point. The visitor first skips deleted docs and points outside the current
     * search box. Otherwise it adds the point to the hit queue.
     *
     * <p>When the queue is already full, the point replaces the worst hit only if it is strictly
     * closer, or equally close with a smaller global docID.
     */
    @Override
    public void visit(int docID, byte[] packedValue) {
      if (curLiveDocs != null && curLiveDocs.get(docID) == false) {
        return;
      }

      double docLatitude = decodeLatitude(packedValue, 0);
      double docLongitude = decodeLongitude(packedValue, Integer.BYTES);

      if (docLatitude < minLat || docLatitude > maxLat) {
        return;
      }
      if ((docLongitude < minLon || docLongitude > maxLon) && (docLongitude < minLon2)) {
        return;
      }

      double distanceSortKey = SloppyMath.haversinSortKey(pointLat, pointLon, docLatitude, docLongitude);
      int fullDocID = curDocBase + docID;

      if (hitQueue.size() == topN) {
        // Queue full: recycle the worst hit's object if the new point beats it.
        NearestHit hit = hitQueue.peek();
        if (distanceSortKey < hit.distanceSortKey || (distanceSortKey == hit.distanceSortKey && fullDocID < hit.docID)) {
          hitQueue.poll();
          hit.docID = fullDocID;
          hit.distanceSortKey = distanceSortKey;
          hitQueue.offer(hit);
          maybeUpdateBBox();
        }
      } else {
        NearestHit hit = new NearestHit();
        hit.docID = fullDocID;
        hit.distanceSortKey = distanceSortKey;
        hitQueue.offer(hit);
      }
    }

    /**
     * Returns {@link Relation#CELL_OUTSIDE_QUERY} when the cell's box is disjoint from the
     * current search box, and {@link Relation#CELL_CROSSES_QUERY} otherwise. It never reports
     * {@code CELL_INSIDE_QUERY}.
     */
    @Override
    public Relation compare(byte[] minPackedValue, byte[] maxPackedValue) {
      double cellMinLat = decodeLatitude(minPackedValue, 0);
      double cellMinLon = decodeLongitude(minPackedValue, Integer.BYTES);
      double cellMaxLat = decodeLatitude(maxPackedValue, 0);
      double cellMaxLon = decodeLongitude(maxPackedValue, Integer.BYTES);

      if (cellMaxLat < minLat || maxLat < cellMinLat || ((cellMaxLon < minLon || maxLon < cellMinLon) && cellMaxLon < minLon2)) {
        return Relation.CELL_OUTSIDE_QUERY;
      }
      return Relation.CELL_CROSSES_QUERY;
    }
  }

  /** One result: a global docID (docBase + segment docID) and its haversin distance sort key. */
  static class NearestHit {
    public int docID;
    public double distanceSortKey;

    @Override
    public String toString() {
      return "NearestHit(docID=" + docID + " distanceSortKey=" + distanceSortKey + ")";
    }
  }

  /**
   * Finds the {@code n} indexed points closest to ({@code pointLat}, {@code pointLon}).
   *
   * @param pointLat query latitude (presumably degrees, matching {@code decodeLatitude})
   * @param pointLon query longitude (presumably degrees, matching {@code decodeLongitude})
   * @param readers BKD readers to search; all must share the same bytes per dimension
   * @param liveDocs per-reader live docs, parallel to {@code readers}; a {@code null} entry means
   *     every doc of that reader is live
   * @param docBases per-reader values added to segment docIDs, parallel to {@code readers}
   * @param n maximum number of hits to return; must be at least 1
   * @return hits sorted by ascending distance sort key, ties broken by ascending docID; the array
   *     may hold fewer than {@code n} entries
   * @throws IllegalArgumentException if {@code n < 1}
   * @throws IllegalStateException if the readers disagree on bytes per dimension
   * @throws IOException if reading a BKD tree fails
   */
  public static NearestHit[] nearest(double pointLat, double pointLon, List<BKDReader> readers, List<Bits> liveDocs, List<Integer> docBases, final int n) throws IOException {
    if (n < 1) {
      // Without this check, PriorityQueue would reject the capacity with a less helpful message.
      throw new IllegalArgumentException("n must be at least 1; got " + n);
    }

    // Reverse order: the largest distance (then the largest docID) sits at the head. The head is
    // therefore the hit to evict when a closer point arrives.
    final PriorityQueue<NearestHit> hitQueue = new PriorityQueue<>(n, new Comparator<NearestHit>() {
      @Override
      public int compare(NearestHit a, NearestHit b) {
        int cmp = Double.compare(a.distanceSortKey, b.distanceSortKey);
        if (cmp != 0) {
          return -cmp;
        }
        return Integer.compare(b.docID, a.docID);
      }
    });

    // Cells ordered by ascending approximate distance (best-first traversal).
    PriorityQueue<Cell> cellQueue = new PriorityQueue<>();

    NearestVisitor visitor = new NearestVisitor(hitQueue, n, pointLat, pointLon);
    List<BKDReader.IntersectState> states = new ArrayList<>();

    // Seed the queue with the root cell of every reader.
    int bytesPerDim = -1;
    for (int i = 0; i < readers.size(); i++) {
      BKDReader reader = readers.get(i);
      if (bytesPerDim == -1) {
        bytesPerDim = reader.getBytesPerDimension();
      } else if (bytesPerDim != reader.getBytesPerDimension()) {
        throw new IllegalStateException("bytesPerDim changed from " + bytesPerDim + " to " + reader.getBytesPerDimension() + " across readers");
      }
      byte[] minPackedValue = reader.getMinPackedValue();
      byte[] maxPackedValue = reader.getMaxPackedValue();
      IntersectState state = reader.getIntersectState(visitor);
      states.add(state);

      cellQueue.offer(new Cell(state.index, i, minPackedValue, maxPackedValue,
          approxBestDistance(minPackedValue, maxPackedValue, pointLat, pointLon)));
    }

    // No distance-based early exit: the loop drains the whole queue. Only inner cells that fall
    // outside the visitor's search box are pruned.
    while (cellQueue.isEmpty() == false) {
      Cell cell = cellQueue.poll();

      if (cell.index.isLeafNode()) {
        // Leaf: visit every point. NearestVisitor filters by liveness and the search box.
        BKDReader reader = readers.get(cell.readerIndex);
        visitor.curDocBase = docBases.get(cell.readerIndex);
        visitor.curLiveDocs = liveDocs.get(cell.readerIndex);
        reader.visitLeafBlockValues(cell.index, states.get(cell.readerIndex));
      } else {
        // Inner node: drop it if it lies entirely outside the current search box. Otherwise
        // split it at the split value and enqueue both children.
        if (visitor.compare(cell.minPacked, cell.maxPacked) == Relation.CELL_OUTSIDE_QUERY) {
          continue;
        }

        // Copy the split value, since the returned BytesRef presumably may be reused by the tree.
        BytesRef splitValue = BytesRef.deepCopyOf(cell.index.getSplitDimValue());
        int splitDim = cell.index.getSplitDim();

        // Clone before pushLeft() moves cell.index, so the right child still starts from this node.
        IndexTree newIndex = cell.index.clone();

        // Left child: same min corner; max corner clamped to the split value on splitDim.
        byte[] splitPackedValue = cell.maxPacked.clone();
        System.arraycopy(splitValue.bytes, splitValue.offset, splitPackedValue, splitDim * bytesPerDim, bytesPerDim);
        cell.index.pushLeft();
        cellQueue.offer(new Cell(cell.index, cell.readerIndex, cell.minPacked, splitPackedValue,
            approxBestDistance(cell.minPacked, splitPackedValue, pointLat, pointLon)));

        // Right child: min corner raised to the split value on splitDim; same max corner.
        splitPackedValue = cell.minPacked.clone();
        System.arraycopy(splitValue.bytes, splitValue.offset, splitPackedValue, splitDim * bytesPerDim, bytesPerDim);
        newIndex.pushRight();
        cellQueue.offer(new Cell(newIndex, cell.readerIndex, splitPackedValue, cell.maxPacked,
            approxBestDistance(splitPackedValue, cell.maxPacked, pointLat, pointLon)));
      }
    }

    // The queue head is the worst hit. Filling the array from the back yields ascending order.
    NearestHit[] hits = new NearestHit[hitQueue.size()];
    for (int downTo = hits.length - 1; downTo >= 0; downTo--) {
      hits[downTo] = hitQueue.poll();
    }
    return hits;
  }

  /** Decodes the packed cell corners and delegates to the degree-based overload. */
  private static double approxBestDistance(byte[] minPackedValue, byte[] maxPackedValue, double pointLat, double pointLon) {
    double minLat = decodeLatitude(minPackedValue, 0);
    double minLon = decodeLongitude(minPackedValue, Integer.BYTES);
    double maxLat = decodeLatitude(maxPackedValue, 0);
    double maxLon = decodeLongitude(maxPackedValue, Integer.BYTES);
    return approxBestDistance(minLat, maxLat, minLon, maxLon, pointLat, pointLon);
  }

  /**
   * Approximates the distance sort key from the query point to a box. The result is 0 if the
   * point lies inside the box; otherwise it is the smallest sort key to any of the four corners.
   *
   * <p>This is not a true lower bound: the closest point of the box can lie on an edge rather
   * than a corner. Because {@code nearest} drains the whole cell queue, this only affects
   * traversal order, not which hits are found. The bounds come from BKD cells, so presumably
   * {@code minLon <= maxLon} (no dateline crossing).
   */
  private static double approxBestDistance(double minLat, double maxLat, double minLon, double maxLon, double pointLat, double pointLon) {
    if (pointLat >= minLat && pointLat <= maxLat && pointLon >= minLon && pointLon <= maxLon) {
      // Query point is inside the cell.
      return 0.0;
    }

    double d1 = SloppyMath.haversinSortKey(pointLat, pointLon, minLat, minLon);
    double d2 = SloppyMath.haversinSortKey(pointLat, pointLon, minLat, maxLon);
    double d3 = SloppyMath.haversinSortKey(pointLat, pointLon, maxLat, maxLon);
    double d4 = SloppyMath.haversinSortKey(pointLat, pointLon, maxLat, minLon);
    return Math.min(Math.min(d1, d2), Math.min(d3, d4));
  }
}