package org.apache.lucene.search;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import org.apache.lucene.util.PriorityQueue;
/**
 * Base class for scorers that match a document when any of their (at least two) sub-scorers
 * match it. Subclasses decide how the matching sub-scorers are combined into a score via {@link
 * #score(DisiWrapper)}.
 */
abstract class DisjunctionScorer extends Scorer {

  /** False only when the score mode is {@code COMPLETE_NO_SCORES}. */
  private final boolean needsScores;

  /** Queue of wrapped sub-scorers; its top element supplies {@link #docID()}. */
  private final DisiPriorityQueue subScorers;

  /** Iterator over the union of the sub-scorers; the block-max variant under TOP_SCORES. */
  private final DocIdSetIterator approximation;

  /** Block-max wrapper around the union iterator; non-null only for {@code TOP_SCORES}. */
  private final BlockMaxDISI blockMaxApprox;

  /** Two-phase view of this scorer; null when no sub-scorer exposes a two-phase iterator. */
  private final TwoPhase twoPhase;

  /**
   * Wraps every sub-scorer and builds the disjunction's approximation and (optional) two-phase
   * view.
   *
   * @param weight the weight that created this scorer
   * @param subScorers the clauses of the disjunction; must contain at least two scorers
   * @param scoreMode how scores will be consumed; {@code TOP_SCORES} enables block-max iteration
   * @throws IllegalArgumentException if fewer than two sub-scorers are supplied
   * @throws IOException propagated from sub-scorer setup (e.g. {@code advanceShallow})
   */
  protected DisjunctionScorer(Weight weight, List<Scorer> subScorers, ScoreMode scoreMode)
      throws IOException {
    super(weight);
    final int numClauses = subScorers.size();
    if (numClauses < 2) {
      throw new IllegalArgumentException("There must be at least 2 subScorers");
    }

    this.subScorers = new DisiPriorityQueue(numClauses);
    for (Scorer clause : subScorers) {
      this.subScorers.add(new DisiWrapper(clause));
    }

    this.needsScores = scoreMode != ScoreMode.COMPLETE_NO_SCORES;

    if (scoreMode != ScoreMode.TOP_SCORES) {
      this.approximation = new DisjunctionDISIApproximation(this.subScorers);
      this.blockMaxApprox = null;
    } else {
      // Position every clause's shallow cursor at doc 0 before the block-max iterator is built
      // (presumably so block-level score bounds are available from the first block).
      for (Scorer clause : subScorers) {
        clause.advanceShallow(0);
      }
      this.blockMaxApprox =
          new BlockMaxDISI(new DisjunctionDISIApproximation(this.subScorers), this);
      this.approximation = this.blockMaxApprox;
    }

    this.twoPhase = createTwoPhase();
  }

  /**
   * Builds the two-phase view if at least one sub-scorer has a two-phase iterator.
   *
   * <p>The resulting match cost is the average of the sub-scorers' match costs, weighted by each
   * sub-scorer's cost (clamped to a minimum of 1). Sub-scorers without a two-phase iterator add
   * to the total weight but contribute a match cost of zero.
   *
   * @return the two-phase view, or null if no sub-scorer supports two-phase iteration
   */
  private TwoPhase createTwoPhase() {
    boolean anyTwoPhase = false;
    float weightedMatchCost = 0;
    long totalWeight = 0;
    for (DisiWrapper wrapper : this.subScorers) {
      final long weightFactor = Math.max(1L, wrapper.cost);
      totalWeight += weightFactor;
      if (wrapper.twoPhaseView != null) {
        anyTwoPhase = true;
        weightedMatchCost += wrapper.matchCost * weightFactor;
      }
    }
    if (!anyTwoPhase) {
      return null;
    }
    return new TwoPhase(approximation, weightedMatchCost / totalWeight);
  }

  @Override
  public DocIdSetIterator iterator() {
    // Without a two-phase view the approximation is exact and can be returned directly.
    return twoPhase == null ? approximation : TwoPhaseIterator.asDocIdSetIterator(twoPhase);
  }

  @Override
  public TwoPhaseIterator twoPhaseIterator() {
    return twoPhase;
  }

  /**
   * Two-phase iterator that confirms a candidate document by checking sub-scorers. Sub-scorers
   * without a two-phase view are trusted outright; the others are verified cheapest
   * {@code matchCost} first.
   */
  private class TwoPhase extends TwoPhaseIterator {

    private final float matchCost;

    /** Singly-linked list (through {@code DisiWrapper.next}) of sub-scorers confirmed to match. */
    DisiWrapper verifiedMatches;

    /** Two-phase sub-scorers on the current doc not yet checked; popped lowest matchCost first. */
    final PriorityQueue<DisiWrapper> unverifiedMatches;

    private TwoPhase(DocIdSetIterator approximation, float matchCost) {
      super(approximation);
      this.matchCost = matchCost;
      this.unverifiedMatches =
          new PriorityQueue<DisiWrapper>(DisjunctionScorer.this.subScorers.size()) {
            @Override
            protected boolean lessThan(DisiWrapper left, DisiWrapper right) {
              return right.matchCost > left.matchCost;
            }
          };
    }

    /**
     * Verifies every still-unchecked candidate and prepends each one that matches to the list of
     * verified matches. After this call, no unverified candidates remain.
     *
     * @return the head of the verified-matches linked list
     */
    DisiWrapper getSubMatches() throws IOException {
      for (DisiWrapper candidate : unverifiedMatches) {
        if (candidate.twoPhaseView.matches()) {
          candidate.next = verifiedMatches;
          verifiedMatches = candidate;
        }
      }
      unverifiedMatches.clear();
      return verifiedMatches;
    }

    @Override
    public boolean matches() throws IOException {
      verifiedMatches = null;
      unverifiedMatches.clear();

      DisiWrapper cursor = subScorers.topList();
      while (cursor != null) {
        final DisiWrapper following = cursor.next;
        if (cursor.twoPhaseView != null) {
          // Needs an explicit confirmation; defer it.
          unverifiedMatches.add(cursor);
        } else {
          // No two-phase view: being positioned here is already a confirmed match.
          cursor.next = verifiedMatches;
          verifiedMatches = cursor;
          if (!needsScores) {
            // One confirmed clause is enough when scores are not needed.
            return true;
          }
        }
        cursor = following;
      }

      if (verifiedMatches != null) {
        return true;
      }

      // Only two-phase clauses are left: check the cheapest first and stop at the first hit.
      // The rest stay queued so getSubMatches() can verify them lazily.
      while (unverifiedMatches.size() != 0) {
        final DisiWrapper candidate = unverifiedMatches.pop();
        if (candidate.twoPhaseView.matches()) {
          candidate.next = null;
          verifiedMatches = candidate;
          return true;
        }
      }
      return false;
    }

    @Override
    public float matchCost() {
      return matchCost;
    }
  }

  @Override
  public final int docID() {
    return subScorers.top().doc;
  }

  /** Returns the block-max approximation, or null unless the score mode was TOP_SCORES. */
  BlockMaxDISI getBlockMaxApprox() {
    return blockMaxApprox;
  }

  /**
   * Returns the sub-scorers matching the current document as a linked list through {@code
   * DisiWrapper.next}. Without a two-phase view this is the queue's top list. Otherwise, the
   * two-phase view first finishes verifying any deferred candidates.
   */
  DisiWrapper getSubMatches() throws IOException {
    return twoPhase != null ? twoPhase.getSubMatches() : subScorers.topList();
  }

  @Override
  public final float score() throws IOException {
    return score(getSubMatches());
  }

  /**
   * Computes the score of the current document.
   *
   * @param topList head of the linked list (through {@code next}) of matching sub-scorers
   * @return the combined score
   */
  protected abstract float score(DisiWrapper topList) throws IOException;

  @Override
  public final Collection<ChildScorable> getChildren() throws IOException {
    final List<ChildScorable> children = new ArrayList<>();
    DisiWrapper node = getSubMatches();
    while (node != null) {
      children.add(new ChildScorable(node.scorer, "SHOULD"));
      node = node.next;
    }
    return children;
  }
}