package org.apache.lucene.search;
import java.io.IOException;
import java.util.Arrays;
import java.util.Collection;
import java.util.Comparator;
import java.util.List;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.ReaderUtil;
import org.apache.lucene.search.FieldValueHitQueue.Entry;
import org.apache.lucene.search.MaxScoreAccumulator.DocAndScore;
import org.apache.lucene.search.TotalHits.Relation;
import org.apache.lucene.util.FutureObjects;
/**
 * A {@link TopDocsCollector} that keeps the top hits according to a {@link Sort}, comparing
 * documents with the comparators of a {@link FieldValueHitQueue}.
 *
 * <p>Instances are created through the static {@code create} factories, or through
 * {@link #createSharedManager}, whose collectors share one {@link HitsThresholdChecker} and one
 * {@link MaxScoreAccumulator}.
 */
public abstract class TopFieldCollector extends TopDocsCollector<Entry> {

  /**
   * Base leaf collector that reduces the per-field leaf comparators to a single
   * {@link LeafFieldComparator}. With one sort field, that field's comparator and reverse
   * multiplier are used directly. With several, the comparators and their multipliers are handed
   * to a {@link MultiLeafFieldComparator} and the outer multiplier is fixed at 1.
   */
  private static abstract class MultiComparatorLeafCollector implements LeafCollector {

    /** Comparator for the current leaf (possibly combining several sort fields). */
    final LeafFieldComparator comparator;
    /** Multiplier applied to {@link #comparator} results to account for reversed sorts. */
    final int reverseMul;
    /** Scorer most recently passed to {@link #setScorer}. */
    Scorable scorer;

    MultiComparatorLeafCollector(LeafFieldComparator[] comparators, int[] reverseMul) {
      if (comparators.length == 1) {
        this.reverseMul = reverseMul[0];
        this.comparator = comparators[0];
      } else {
        this.reverseMul = 1;
        this.comparator = new MultiLeafFieldComparator(comparators, reverseMul);
      }
    }

    @Override
    public void setScorer(Scorable scorer) throws IOException {
      comparator.setScorer(scorer);
      this.scorer = scorer;
    }
  }

  /**
   * Returns whether collection in a segment may be terminated early. This is true when the search
   * sort starts with {@link SortField#FIELD_DOC}, or when the search sort's fields are a leading
   * prefix of the segment's index sort.
   *
   * @param searchSort the sort requested by the search; must contain at least one field
   * @param indexSort the segment's sort as reported by its metadata, or {@code null}
   */
  static boolean canEarlyTerminate(Sort searchSort, Sort indexSort) {
    return canEarlyTerminateOnDocId(searchSort) ||
        canEarlyTerminateOnPrefix(searchSort, indexSort);
  }

  /** True if the first field of {@code searchSort} is {@link SortField#FIELD_DOC}. */
  private static boolean canEarlyTerminateOnDocId(Sort searchSort) {
    final SortField[] fields1 = searchSort.getSort();
    return SortField.FIELD_DOC.equals(fields1[0]);
  }

  /**
   * True if {@code indexSort} is non-null and the fields of {@code searchSort} equal, in order,
   * the first fields of {@code indexSort}.
   */
  private static boolean canEarlyTerminateOnPrefix(Sort searchSort, Sort indexSort) {
    if (indexSort != null) {
      final SortField[] fields1 = searchSort.getSort();
      final SortField[] fields2 = indexSort.getSort();
      // A longer search sort can never be a prefix of the index sort.
      if (fields1.length > fields2.length) {
        return false;
      }
      return Arrays.asList(fields1).equals(Arrays.asList(fields2).subList(0, fields1.length));
    } else {
      return false;
    }
  }

  /** Collector used when no {@code after} document is given: every matching doc is a candidate. */
  private static class SimpleFieldCollector extends TopFieldCollector {

    final Sort sort;
    final FieldValueHitQueue<Entry> queue;

    public SimpleFieldCollector(Sort sort, FieldValueHitQueue<Entry> queue, int numHits,
                                HitsThresholdChecker hitsThresholdChecker,
                                MaxScoreAccumulator minScoreAcc) {
      super(queue, numHits, hitsThresholdChecker, sort.needsScores(), minScoreAcc);
      this.sort = sort;
      this.queue = queue;
    }

    @Override
    public LeafCollector getLeafCollector(LeafReaderContext context) throws IOException {
      docBase = context.docBase;

      final LeafFieldComparator[] comparators = queue.getComparators(context);
      final int[] reverseMul = queue.getReverseMul();
      final Sort indexSort = context.reader().getMetaData().getSort();
      final boolean canEarlyTerminate = canEarlyTerminate(sort, indexSort);

      return new MultiComparatorLeafCollector(comparators, reverseMul) {

        // Set when the queue is full, this segment allows early termination, a doc proved
        // non-competitive, and the hits threshold was not yet reached. Later docs are then only
        // counted, never compared.
        boolean collectedAllCompetitiveHits = false;

        @Override
        public void setScorer(Scorable scorer) throws IOException {
          super.setScorer(scorer);
          // Reset the per-scorer minimum, then push whatever local and shared minimum
          // competitive scores currently apply to the new scorer.
          minCompetitiveScore = 0f;
          updateMinCompetitiveScore(scorer);
          if (minScoreAcc != null) {
            updateGlobalMinCompetitiveScore(scorer);
          }
        }

        @Override
        public void collect(int doc) throws IOException {
          ++totalHits;
          hitsThresholdChecker.incrementHitCount();

          // Periodically pick up a minimum score published through the shared accumulator.
          if (minScoreAcc != null && (totalHits & minScoreAcc.modInterval) == 0) {
            updateGlobalMinCompetitiveScore(scorer);
          }
          // The first time the threshold is reached on a non-exhaustive search, notify the
          // comparator. From here on the hit count is reported as a lower bound.
          if (scoreMode.isExhaustive() == false && totalHitsRelation == TotalHits.Relation.EQUAL_TO &&
              hitsThresholdChecker.isThresholdReached()) {
            comparator.setHitsThresholdReached();
            totalHitsRelation = TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO;
          }

          if (queueFull) {
            // A tie with the bottom (compare == 0) is treated as non-competitive, presumably
            // because docs arrive in increasing doc id order, so a tied doc sorts after the
            // bottom entry.
            if (collectedAllCompetitiveHits || reverseMul * comparator.compareBottom(doc) <= 0) {
              if (canEarlyTerminate) {
                // In this segment the search sort is doc order or a prefix of the index sort,
                // so later docs are presumably no more competitive. Stop if the count no longer
                // needs to be exact; otherwise keep counting without comparing.
                if (hitsThresholdChecker.isThresholdReached()) {
                  totalHitsRelation = Relation.GREATER_THAN_OR_EQUAL_TO;
                  throw new CollectionTerminatedException();
                } else {
                  collectedAllCompetitiveHits = true;
                }
              } else if (totalHitsRelation == Relation.EQUAL_TO) {
                // The threshold may only now be reached; try to set the min score.
                updateMinCompetitiveScore(scorer);
              }
              return;
            }

            // Competitive hit: overwrite the bottom slot and let the queue re-order.
            comparator.copy(bottom.slot, doc);
            updateBottom(doc);
            comparator.setBottom(bottom.slot);
            updateMinCompetitiveScore(scorer);
          } else {
            // Queue still filling: slots are assigned in collection order.
            final int slot = totalHits - 1;
            comparator.copy(slot, doc);
            add(slot, doc);
            if (queueFull) {
              comparator.setBottom(bottom.slot);
              updateMinCompetitiveScore(scorer);
            }
          }
        }

        @Override
        public DocIdSetIterator competitiveIterator() throws IOException {
          return comparator.competitiveIterator();
        }
      };
    }
  }

  /**
   * Collector used for search-after paging. The comparators' top values are set from
   * {@code after.fields}. Docs that compare at or before {@code after} (using the doc id as the
   * tie-break) are skipped as belonging to an earlier page.
   */
  private final static class PagingFieldCollector extends TopFieldCollector {

    final Sort sort;
    /** Number of docs actually added to the queue; unlike totalHits it excludes skipped docs. */
    int collectedHits;
    final FieldValueHitQueue<Entry> queue;
    final FieldDoc after;

    public PagingFieldCollector(Sort sort, FieldValueHitQueue<Entry> queue, FieldDoc after, int numHits,
                                HitsThresholdChecker hitsThresholdChecker, MaxScoreAccumulator minScoreAcc) {
      super(queue, numHits, hitsThresholdChecker, sort.needsScores(), minScoreAcc);
      this.sort = sort;
      this.queue = queue;
      this.after = after;

      FieldComparator<?>[] comparators = queue.comparators;
      // Tell each comparator the sort values of the last hit of the previous page.
      for (int i = 0; i < comparators.length; i++) {
        @SuppressWarnings("unchecked")
        FieldComparator<Object> comparator = (FieldComparator<Object>) comparators[i];
        comparator.setTopValue(after.fields[i]);
      }
    }

    @Override
    public LeafCollector getLeafCollector(LeafReaderContext context) throws IOException {
      docBase = context.docBase;
      // The after doc translated into this leaf's doc id space (may be out of range for the leaf).
      final int afterDoc = after.doc - docBase;
      final Sort indexSort = context.reader().getMetaData().getSort();
      final boolean canEarlyTerminate = canEarlyTerminate(sort, indexSort);

      return new MultiComparatorLeafCollector(queue.getComparators(context), queue.getReverseMul()) {

        // See SimpleFieldCollector: set once later docs in this segment only need counting.
        boolean collectedAllCompetitiveHits = false;

        @Override
        public void setScorer(Scorable scorer) throws IOException {
          super.setScorer(scorer);
          // Reset the per-scorer minimum, then push whatever local and shared minimum
          // competitive scores currently apply to the new scorer.
          minCompetitiveScore = 0f;
          updateMinCompetitiveScore(scorer);
          if (minScoreAcc != null) {
            updateGlobalMinCompetitiveScore(scorer);
          }
        }

        @Override
        public void collect(int doc) throws IOException {
          totalHits++;
          hitsThresholdChecker.incrementHitCount();

          // Periodically pick up a minimum score published through the shared accumulator.
          if (minScoreAcc != null && (totalHits & minScoreAcc.modInterval) == 0) {
            updateGlobalMinCompetitiveScore(scorer);
          }
          // The first time the threshold is reached on a non-exhaustive search, notify the
          // comparator. From here on the hit count is reported as a lower bound.
          if (scoreMode.isExhaustive() == false && totalHitsRelation == TotalHits.Relation.EQUAL_TO &&
              hitsThresholdChecker.isThresholdReached()) {
            comparator.setHitsThresholdReached();
            totalHitsRelation = TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO;
          }

          if (queueFull) {
            // Fast reject: not better than the current bottom (ties count as non-competitive).
            if (collectedAllCompetitiveHits || reverseMul * comparator.compareBottom(doc) <= 0) {
              if (canEarlyTerminate) {
                if (hitsThresholdChecker.isThresholdReached()) {
                  totalHitsRelation = Relation.GREATER_THAN_OR_EQUAL_TO;
                  throw new CollectionTerminatedException();
                } else {
                  collectedAllCompetitiveHits = true;
                }
              } else if (totalHitsRelation == TotalHits.Relation.EQUAL_TO) {
                // The threshold may only now be reached; try to set the min score.
                updateMinCompetitiveScore(scorer);
              }
              return;
            }
          }

          // Skip docs that sort at or before the `after` entry (doc id breaks ties): they
          // belong to an earlier page.
          final int topCmp = reverseMul * comparator.compareTop(doc);
          if (topCmp > 0 || (topCmp == 0 && doc <= afterDoc)) {
            if (totalHitsRelation == TotalHits.Relation.EQUAL_TO) {
              updateMinCompetitiveScore(scorer);
            }
            return;
          }

          if (queueFull) {
            // Competitive hit: overwrite the bottom slot and let the queue re-order.
            comparator.copy(bottom.slot, doc);
            updateBottom(doc);
            comparator.setBottom(bottom.slot);
            updateMinCompetitiveScore(scorer);
          } else {
            // Queue still filling. Slots follow collectedHits, not totalHits, because skipped
            // docs were counted in totalHits but never enqueued.
            collectedHits++;
            final int slot = collectedHits - 1;
            comparator.copy(slot, doc);
            bottom = pq.add(new Entry(slot, docBase + doc));
            queueFull = collectedHits == numHits;
            if (queueFull) {
              comparator.setBottom(bottom.slot);
              updateMinCompetitiveScore(scorer);
            }
          }
        }

        @Override
        public DocIdSetIterator competitiveIterator() throws IOException {
          return comparator.competitiveIterator();
        }
      };
    }
  }

  private static final ScoreDoc[] EMPTY_SCOREDOCS = new ScoreDoc[0];

  /** Maximum number of hits kept in the queue. */
  final int numHits;
  final HitsThresholdChecker hitsThresholdChecker;
  /** Non-null only when the first sort field is relevance in natural order (see constructor). */
  final FieldComparator.RelevanceComparator relevanceComparator;
  /** Whether a minimum competitive score may be pushed down to the scorer. */
  final boolean canSetMinScore;
  /** Accumulator shared by the collectors of one shared manager, or {@code null}. */
  final MaxScoreAccumulator minScoreAcc;
  /** Last minimum competitive score passed to the current scorer. */
  float minCompetitiveScore;
  final int numComparators;
  /** Least competitive entry in the queue, as last returned by the queue. */
  FieldValueHitQueue.Entry bottom = null;
  /** Whether the queue holds {@link #numHits} entries. */
  boolean queueFull;
  /** Doc base of the leaf currently being collected. */
  int docBase;
  final boolean needsScores;
  final ScoreMode scoreMode;

  /**
   * Chooses the score mode and decides whether minimum competitive scores may be used.
   *
   * <ul>
   *   <li>If the first sort field uses {@link FieldComparator.RelevanceComparator} in its natural
   *       order (reverse multiplier 1) and the threshold is finite, the mode is
   *       {@link ScoreMode#TOP_SCORES} and min-score pruning is enabled.</li>
   *   <li>Otherwise, a finite threshold (anything but {@code Integer.MAX_VALUE}) selects
   *       {@code TOP_DOCS}/{@code TOP_DOCS_WITH_SCORES}.</li>
   *   <li>Otherwise the mode is {@code COMPLETE}/{@code COMPLETE_NO_SCORES}, depending on
   *       {@code needsScores}.</li>
   * </ul>
   */
  private TopFieldCollector(FieldValueHitQueue<Entry> pq, int numHits,
                            HitsThresholdChecker hitsThresholdChecker, boolean needsScores,
                            MaxScoreAccumulator minScoreAcc) {
    super(pq);
    this.needsScores = needsScores;
    this.numHits = numHits;
    this.hitsThresholdChecker = hitsThresholdChecker;
    this.numComparators = pq.getComparators().length;
    FieldComparator<?> firstComparator = pq.getComparators()[0];
    int reverseMul = pq.reverseMul[0];

    if (firstComparator.getClass().equals(FieldComparator.RelevanceComparator.class)
        && reverseMul == 1
        && hitsThresholdChecker.getHitsThreshold() != Integer.MAX_VALUE) {
      relevanceComparator = (FieldComparator.RelevanceComparator) firstComparator;
      scoreMode = ScoreMode.TOP_SCORES;
      canSetMinScore = true;
    } else {
      relevanceComparator = null;
      canSetMinScore = false;
      if (hitsThresholdChecker.getHitsThreshold() != Integer.MAX_VALUE) {
        scoreMode = needsScores ? ScoreMode.TOP_DOCS_WITH_SCORES : ScoreMode.TOP_DOCS;
      } else {
        scoreMode = needsScores ? ScoreMode.COMPLETE : ScoreMode.COMPLETE_NO_SCORES;
      }
    }
    this.minScoreAcc = minScoreAcc;
  }

  @Override
  public ScoreMode scoreMode() {
    return scoreMode;
  }

  /**
   * Applies the shared accumulator's minimum score to {@code scorer} when it is higher than the
   * current local minimum. This only happens if min-score pruning is enabled and the hits
   * threshold has been reached; the local queue does not need to be full. Requires a non-null
   * {@link #minScoreAcc}.
   */
  protected void updateGlobalMinCompetitiveScore(Scorable scorer) throws IOException {
    assert minScoreAcc != null;
    if (canSetMinScore
        && hitsThresholdChecker.isThresholdReached()) {
      DocAndScore maxMinScore = minScoreAcc.get();
      if (maxMinScore != null && maxMinScore.score > minCompetitiveScore) {
        scorer.setMinCompetitiveScore(maxMinScore.score);
        minCompetitiveScore = maxMinScore.score;
        totalHitsRelation = TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO;
      }
    }
  }

  /**
   * Applies the relevance value of the bottom entry to {@code scorer} as its minimum competitive
   * score. This requires min-score pruning to be enabled, the queue to be full and the threshold
   * to be reached, and the value must be higher than the current minimum. The new minimum is also
   * published to the shared accumulator, if any.
   */
  protected void updateMinCompetitiveScore(Scorable scorer) throws IOException {
    if (canSetMinScore
        && queueFull
        && hitsThresholdChecker.isThresholdReached()) {
      assert bottom != null && relevanceComparator != null;
      float minScore = relevanceComparator.value(bottom.slot);
      if (minScore > minCompetitiveScore) {
        scorer.setMinCompetitiveScore(minScore);
        minCompetitiveScore = minScore;
        totalHitsRelation = TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO;
        if (minScoreAcc != null) {
          minScoreAcc.accumulate(bottom.doc, minScore);
        }
      }
    }
  }

  /**
   * Creates a collector that keeps the top {@code numHits} hits sorted by {@code sort}.
   *
   * @param sort the sort criteria; must contain at least one field
   * @param numHits the number of hits to keep; must be &gt; 0
   * @param totalHitsThreshold the number of hits to count accurately; must be &gt;= 0. The
   *     effective threshold is {@code max(totalHitsThreshold, numHits)}, and beyond it the total
   *     may be reported as a lower bound.
   * @throws IllegalArgumentException if any argument is invalid
   */
  public static TopFieldCollector create(Sort sort, int numHits, int totalHitsThreshold) {
    return create(sort, numHits, null, totalHitsThreshold);
  }

  /**
   * Creates a collector that keeps the top {@code numHits} hits sorted by {@code sort} that come
   * after {@code after}, or from the start when {@code after} is {@code null}.
   *
   * @param sort the sort criteria; must contain at least one field
   * @param numHits the number of hits to keep; must be &gt; 0
   * @param after the last hit of the previous page, or {@code null}; when non-null, its
   *     {@code fields} must be set and have one value per sort field
   * @param totalHitsThreshold the number of hits to count accurately; must be &gt;= 0
   * @throws IllegalArgumentException if any argument is invalid
   */
  public static TopFieldCollector create(Sort sort, int numHits, FieldDoc after, int totalHitsThreshold) {
    if (totalHitsThreshold < 0) {
      throw new IllegalArgumentException("totalHitsThreshold must be >= 0, got " + totalHitsThreshold);
    }

    return create(sort, numHits, after, HitsThresholdChecker.create(Math.max(totalHitsThreshold, numHits)),
        null /* no shared min-score accumulator */);
  }

  /**
   * Validates the arguments and returns a {@link SimpleFieldCollector}, or a
   * {@link PagingFieldCollector} when {@code after} is non-null.
   *
   * @throws IllegalArgumentException if the sort is empty, {@code numHits <= 0}, the checker is
   *     null, or {@code after.fields} is missing or does not match the sort's length
   */
  static TopFieldCollector create(Sort sort, int numHits, FieldDoc after,
                                  HitsThresholdChecker hitsThresholdChecker, MaxScoreAccumulator minScoreAcc) {

    if (sort.fields.length == 0) {
      throw new IllegalArgumentException("Sort must contain at least one field");
    }

    if (numHits <= 0) {
      throw new IllegalArgumentException("numHits must be > 0; please use TotalHitCountCollector if you just need the total hit count");
    }

    if (hitsThresholdChecker == null) {
      throw new IllegalArgumentException("hitsThresholdChecker should not be null");
    }

    FieldValueHitQueue<Entry> queue = FieldValueHitQueue.create(sort.fields, numHits);

    if (after == null) {
      return new SimpleFieldCollector(sort, queue, numHits, hitsThresholdChecker, minScoreAcc);
    } else {
      if (after.fields == null) {
        throw new IllegalArgumentException("after.fields wasn't set; you must pass fillFields=true for the previous search");
      }

      if (after.fields.length != sort.getSort().length) {
        throw new IllegalArgumentException("after.fields has " + after.fields.length + " values but sort has " + sort.getSort().length);
      }

      return new PagingFieldCollector(sort, queue, after, numHits, hitsThresholdChecker, minScoreAcc);
    }
  }

  /**
   * Creates a {@link CollectorManager} whose collectors share one {@link HitsThresholdChecker}
   * (created with {@code max(totalHitsThreshold, numHits)}) and one {@link MaxScoreAccumulator}.
   * Its reduce step merges the per-collector results with {@link TopDocs#merge}.
   *
   * @param sort the sort criteria; must contain at least one field
   * @param numHits the number of hits to keep; must be &gt; 0 (checked when a collector is created)
   * @param after the last hit of the previous page, or {@code null}
   * @param totalHitsThreshold the number of hits to count accurately; must be &gt;= 0
   * @throws IllegalArgumentException if {@code totalHitsThreshold} is negative, matching the
   *     check in {@link #create(Sort, int, FieldDoc, int)}
   */
  public static CollectorManager<TopFieldCollector, TopFieldDocs> createSharedManager(Sort sort, int numHits, FieldDoc after,
                                                                                 int totalHitsThreshold) {
    if (totalHitsThreshold < 0) {
      throw new IllegalArgumentException("totalHitsThreshold must be >= 0, got " + totalHitsThreshold);
    }

    return new CollectorManager<TopFieldCollector, TopFieldDocs>() {

      private final HitsThresholdChecker hitsThresholdChecker = HitsThresholdChecker.createShared(Math.max(totalHitsThreshold, numHits));
      private final MaxScoreAccumulator minScoreAcc = new MaxScoreAccumulator();

      @Override
      public TopFieldCollector newCollector() throws IOException {
        return create(sort, numHits, after, hitsThresholdChecker, minScoreAcc);
      }

      @Override
      public TopFieldDocs reduce(Collection<TopFieldCollector> collectors) throws IOException {
        final TopFieldDocs[] topDocs = new TopFieldDocs[collectors.size()];
        int i = 0;
        for (TopFieldCollector collector : collectors) {
          topDocs[i++] = collector.topDocs();
        }
        return TopDocs.merge(sort, numHits, topDocs);
      }
    };
  }

  /**
   * Sets {@link ScoreDoc#score} on each of {@code topDocs} by scoring it against {@code query}.
   * The caller's array is not reordered: a doc-id-sorted shallow copy is iterated, so the scores
   * are written into the caller's {@link ScoreDoc} instances. A doc id that appears more than
   * once gets the same score each time.
   *
   * @throws IllegalArgumentException if a doc does not match {@code query}
   * @throws IndexOutOfBoundsException if a doc id is outside {@code [0, maxDoc)} (raised by
   *     {@code FutureObjects.checkIndex})
   */
  public static void populateScores(ScoreDoc[] topDocs, IndexSearcher searcher, Query query) throws IOException {
    // Visit docs in doc id order so that each segment's scorer only ever moves forward.
    topDocs = topDocs.clone();
    Arrays.sort(topDocs, Comparator.comparingInt(scoreDoc -> scoreDoc.doc));

    final Weight weight = searcher.createWeight(searcher.rewrite(query), ScoreMode.COMPLETE, 1);
    List<LeafReaderContext> contexts = searcher.getIndexReader().leaves();
    LeafReaderContext currentContext = null;
    Scorer currentScorer = null;
    for (ScoreDoc scoreDoc : topDocs) {
      if (currentContext == null || scoreDoc.doc >= currentContext.docBase + currentContext.reader().maxDoc()) {
        // Moved past the current segment: locate the segment holding this doc.
        FutureObjects.checkIndex(scoreDoc.doc, searcher.getIndexReader().maxDoc());
        int newContextIndex = ReaderUtil.subIndex(scoreDoc.doc, contexts);
        currentContext = contexts.get(newContextIndex);
        final ScorerSupplier scorerSupplier = weight.scorerSupplier(currentContext);
        if (scorerSupplier == null) {
          throw new IllegalArgumentException("Doc id " + scoreDoc.doc + " doesn't match the query");
        }
        currentScorer = scorerSupplier.get(1);
      }
      final int leafDoc = scoreDoc.doc - currentContext.docBase;
      assert leafDoc >= 0;
      final DocIdSetIterator iterator = currentScorer.iterator();
      // advance() is only defined for targets beyond the current doc. A repeated doc id leaves
      // the iterator already positioned on it, so skip the advance in that case.
      final int current = iterator.docID();
      final int positioned = current < leafDoc ? iterator.advance(leafDoc) : current;
      if (positioned != leafDoc) {
        throw new IllegalArgumentException("Doc id " + scoreDoc.doc + " doesn't match the query");
      }
      scoreDoc.score = currentScorer.score();
    }
  }

  /**
   * Adds a new entry for leaf doc {@code doc} stored in comparator slot {@code slot}. The queue
   * is full once {@code totalHits} reaches {@link #numHits}.
   */
  final void add(int slot, int doc) {
    bottom = pq.add(new Entry(slot, docBase + doc));
    queueFull = totalHits == numHits;
  }

  /** Reassigns the bottom entry to leaf doc {@code doc} and refreshes {@link #bottom}. */
  final void updateBottom(int doc) {
    bottom.doc = docBase + doc;
    bottom = pq.updateTop();
  }

  /**
   * Pops the top {@code howMany} entries into {@code results} from the back, converting each one
   * via {@link FieldValueHitQueue#fillFields}.
   */
  @Override
  protected void populateResults(ScoreDoc[] results, int howMany) {
    FieldValueHitQueue<Entry> queue = (FieldValueHitQueue<Entry>) pq;
    for (int i = howMany - 1; i >= 0; i--) {
      results[i] = queue.fillFields(queue.pop());
    }
  }

  /** Wraps {@code results} (or an empty array for {@code null}) with the hit count and sort fields. */
  @Override
  protected TopDocs newTopDocs(ScoreDoc[] results, int start) {
    if (results == null) {
      results = EMPTY_SCOREDOCS;
    }

    return new TopFieldDocs(new TotalHits(totalHits, totalHitsRelation), results, ((FieldValueHitQueue<Entry>) pq).getFields());
  }

  @Override
  public TopFieldDocs topDocs() {
    return (TopFieldDocs) super.topDocs();
  }

  /**
   * Returns {@code true} if the reported total hit count is a lower bound
   * ({@link Relation#GREATER_THAN_OR_EQUAL_TO}) rather than an exact count.
   */
  public boolean isEarlyTerminated() {
    return totalHitsRelation == Relation.GREATER_THAN_OR_EQUAL_TO;
  }
}