package org.apache.lucene.search;
import java.io.IOException;
import java.util.Arrays;
import java.util.Comparator;
import java.util.List;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.util.ArrayUtil;
/**
 * A {@link Rescorer} that runs a second-pass {@link Query} over the hits of a first-pass search
 * and replaces each hit's score with the value returned by {@link #combine}.
 *
 * <p>Note: {@link #rescore(IndexSearcher, TopDocs, int)} copies the first-pass array but not its
 * elements, so the {@link ScoreDoc} instances held by the caller's {@link TopDocs} have their
 * {@code score} field overwritten with the combined score.
 */
public abstract class QueryRescorer extends Rescorer {

  /**
   * Orders hits by ascending doc ID. Uses {@link Integer#compare(int, int)} rather than
   * subtraction so the result can never be corrupted by integer overflow.
   */
  private static final Comparator<ScoreDoc> BY_DOC_ID =
      new Comparator<ScoreDoc>() {
        @Override
        public int compare(ScoreDoc a, ScoreDoc b) {
          return Integer.compare(a.doc, b.doc);
        }
      };

  /** Orders hits by descending score; equal scores are ordered by ascending doc ID. */
  private static final Comparator<ScoreDoc> BY_SCORE_DESC_THEN_DOC_ID =
      new Comparator<ScoreDoc>() {
        @Override
        public int compare(ScoreDoc a, ScoreDoc b) {
          // Plain relational operators are kept on purpose: they treat -0.0f and 0.0f as a tie.
          if (a.score > b.score) {
            return -1;
          }
          if (a.score < b.score) {
            return 1;
          }
          return Integer.compare(a.doc, b.doc);
        }
      };

  /** The second-pass query used to compute the additional score for each hit. */
  private final Query query;

  /**
   * Creates a rescorer that scores first-pass hits with the given query.
   *
   * @param query the second-pass query
   */
  public QueryRescorer(Query query) {
    this.query = query;
  }

  /**
   * Computes the final score of a hit from its first-pass and second-pass scores.
   *
   * @param firstPassScore the score the hit carried before rescoring
   * @param secondPassMatches whether the second-pass query matched the hit
   * @param secondPassScore the second-pass score; this class always passes {@code 0.0f} when
   *     {@code secondPassMatches} is {@code false}
   * @return the score to assign to the hit
   */
  protected abstract float combine(
      float firstPassScore, boolean secondPassMatches, float secondPassScore);

  /**
   * Rescores the first-pass hits with the second-pass query and returns the best {@code topN} of
   * them, sorted by descending combined score (ties by ascending doc ID).
   *
   * @param searcher the searcher used to rewrite the query, create its weight and supply the
   *     index leaves
   * @param firstPassTopDocs the hits to rescore; its {@link ScoreDoc} elements are updated in
   *     place with the combined score
   * @param topN the maximum number of hits to return; when it is not smaller than the number of
   *     first-pass hits, all hits are returned
   * @return a new {@link TopDocs} carrying the first pass's {@code totalHits} and the rescored
   *     hits
   * @throws IOException propagated from the underlying rewrite, weight or scorer calls
   */
  @Override
  public TopDocs rescore(IndexSearcher searcher, TopDocs firstPassTopDocs, int topN)
      throws IOException {
    // Shallow copy: reordering does not disturb the caller's array, but elements are shared.
    ScoreDoc[] hits = firstPassTopDocs.scoreDocs.clone();

    // Visit hits in doc ID order so that each leaf's scorer is only ever advanced forwards.
    Arrays.sort(hits, BY_DOC_ID);

    List<LeafReaderContext> leaves = searcher.getIndexReader().leaves();
    Query rewritten = searcher.rewrite(query);
    Weight weight = searcher.createWeight(rewritten, ScoreMode.COMPLETE, 1);

    // State of the merge between the sorted hits and the leaves. The walk only moves forward
    // through the leaves, so it relies on them being ordered by increasing docBase.
    int readerUpto = -1;
    int endDoc = 0;
    int docBase = 0;
    Scorer scorer = null;

    for (ScoreDoc hit : hits) {
      int docID = hit.doc;

      // Skip ahead to the leaf whose doc ID range contains this hit.
      LeafReaderContext readerContext = null;
      while (docID >= endDoc) {
        readerUpto++;
        readerContext = leaves.get(readerUpto);
        endDoc = readerContext.docBase + readerContext.reader().maxDoc();
      }

      if (readerContext != null) {
        // We moved to another leaf: switch to its doc base and its scorer (which may be null).
        docBase = readerContext.docBase;
        scorer = weight.scorer(readerContext);
      }

      boolean secondPassMatches = false;
      float secondPassScore = 0.0f;
      if (scorer != null) {
        int targetDoc = docID - docBase;
        int actualDoc = scorer.docID();
        if (actualDoc < targetDoc) {
          actualDoc = scorer.iterator().advance(targetDoc);
        }
        if (actualDoc == targetDoc) {
          // The second-pass query matched this hit.
          secondPassMatches = true;
          secondPassScore = scorer.score();
        } else {
          // The scorer moved past the hit, so the second-pass query did not match it.
          assert actualDoc > targetDoc;
        }
      }
      hit.score = combine(hit.score, secondPassMatches, secondPassScore);
    }

    if (topN < hits.length) {
      // Partition so the best topN hits occupy the front, then keep only those.
      ArrayUtil.select(hits, 0, hits.length, topN, BY_SCORE_DESC_THEN_DOC_ID);
      ScoreDoc[] subset = new ScoreDoc[topN];
      System.arraycopy(hits, 0, subset, 0, topN);
      hits = subset;
    }

    Arrays.sort(hits, BY_SCORE_DESC_THEN_DOC_ID);

    return new TopDocs(firstPassTopDocs.totalHits, hits);
  }

  /**
   * Explains how the combined score of one document was derived.
   *
   * @param searcher the searcher used to explain the second-pass query
   * @param firstPassExplanation the explanation of the document's first-pass score
   * @param docID the document to explain
   * @return a matching explanation whose value is the combined score and whose details are the
   *     first-pass explanation and the second-pass explanation (or a no-match placeholder)
   * @throws IOException propagated from {@code searcher.explain}
   */
  @Override
  public Explanation explain(IndexSearcher searcher, Explanation firstPassExplanation, int docID)
      throws IOException {
    Explanation secondPassExplanation = searcher.explain(query, docID);
    Number firstPassValue = firstPassExplanation.getValue();

    // A null second-pass score means the second-pass query did not match the document.
    Number secondPassScore =
        secondPassExplanation.isMatch() ? secondPassExplanation.getValue() : null;

    float score;
    Explanation second;
    if (secondPassScore == null) {
      score = combine(firstPassValue.floatValue(), false, 0.0f);
      second = Explanation.noMatch("no second pass score");
    } else {
      score = combine(firstPassValue.floatValue(), true, secondPassScore.floatValue());
      second = Explanation.match(secondPassScore, "second pass score", secondPassExplanation);
    }

    Explanation first = Explanation.match(firstPassValue, "first pass score", firstPassExplanation);

    return Explanation.match(
        score, "combined first and second pass score using " + getClass(), first, second);
  }

  /**
   * Convenience method that rescores with a linear combination: the combined score is {@code
   * firstPassScore + weight * secondPassScore} when the second-pass query matches, and the
   * unchanged first-pass score otherwise.
   *
   * @param searcher the searcher to rescore against
   * @param topDocs the first-pass hits; see {@link #rescore(IndexSearcher, TopDocs, int)} for how
   *     they are updated
   * @param query the second-pass query
   * @param weight the multiplier applied to the second-pass score
   * @param topN the maximum number of hits to return
   * @return the rescored hits
   * @throws IOException propagated from {@link #rescore(IndexSearcher, TopDocs, int)}
   */
  public static TopDocs rescore(
      IndexSearcher searcher, TopDocs topDocs, Query query, final double weight, int topN)
      throws IOException {
    return new QueryRescorer(query) {
      @Override
      protected float combine(
          float firstPassScore, boolean secondPassMatches, float secondPassScore) {
        float score = firstPassScore;
        if (secondPassMatches) {
          // Compound assignment narrows the double product back to float.
          score += weight * secondPassScore;
        }
        return score;
      }
    }.rescore(searcher, topDocs, topN);
  }
}