package org.apache.lucene.search;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collection;
/**
 * Scorer over a set of required clauses. Iteration is delegated to the iterator built by
 * {@link ConjunctionDISI#intersectScorers(Collection)} from the {@code required} scorers, while
 * the score of a document is the sum of the scores of the {@code scorers} passed at construction.
 */
class ConjunctionScorer extends Scorer {

  /** Iterator obtained from {@code ConjunctionDISI.intersectScorers(required)}. */
  final DocIdSetIterator disi;

  /** Clauses whose scores are summed by {@link #score()}. */
  final Scorer[] scorers;

  /** All required clauses; reported as "MUST" children by {@link #getChildren()}. */
  final Collection<Scorer> required;

  /**
   * Creates a new conjunction scorer.
   *
   * @param weight the weight handed to the {@link Scorer} super constructor
   * @param required the clauses used to build the underlying iterator
   * @param scorers the clauses contributing to the score; asserted to be contained in
   *     {@code required}
   * @throws IOException declared by the constructor; presumably propagated from building the
   *     iterator
   */
  ConjunctionScorer(Weight weight, Collection<Scorer> required, Collection<Scorer> scorers) throws IOException {
    super(weight);
    assert required.containsAll(scorers);
    this.disi = ConjunctionDISI.intersectScorers(required);
    final Scorer[] scoringClauses = new Scorer[scorers.size()];
    this.scorers = scorers.toArray(scoringClauses);
    this.required = required;
  }

  /** Returns the result of {@link TwoPhaseIterator#unwrap(DocIdSetIterator)} on the iterator. */
  @Override
  public TwoPhaseIterator twoPhaseIterator() {
    final DocIdSetIterator underlying = disi;
    return TwoPhaseIterator.unwrap(underlying);
  }

  /** Returns the iterator built at construction time. */
  @Override
  public DocIdSetIterator iterator() {
    return this.disi;
  }

  /** Returns the document the underlying iterator currently reports. */
  @Override
  public int docID() {
    final int current = disi.docID();
    return current;
  }

  /**
   * Adds up the scores of all scoring clauses in array order. The accumulation is done in
   * {@code double} and narrowed to {@code float} only once, at the end.
   */
  @Override
  public float score() throws IOException {
    double total = 0.0d;
    for (int i = 0; i < scorers.length; i++) {
      total = total + scorers[i].score();
    }
    return (float) total;
  }

  /**
   * Returns {@code 0} when there is no scoring clause, the clause's own maximum score when there
   * is exactly one, and {@link Float#POSITIVE_INFINITY} otherwise.
   */
  @Override
  public float getMaxScore(int upTo) throws IOException {
    final int clauseCount = scorers.length;
    if (clauseCount == 0) {
      return 0;
    }
    if (clauseCount == 1) {
      return scorers[0].getMaxScore(upTo);
    }
    // Two or more scoring clauses: this scorer reports positive infinity.
    return Float.POSITIVE_INFINITY;
  }

  /**
   * Delegates to the single scoring clause when there is exactly one; otherwise uses the
   * inherited implementation.
   */
  @Override
  public int advanceShallow(int target) throws IOException {
    return scorers.length == 1
        ? scorers[0].advanceShallow(target)
        : super.advanceShallow(target);
  }

  /**
   * Forwards the minimum competitive score to the single scoring clause. Does nothing when the
   * number of scoring clauses is not exactly one.
   */
  @Override
  public void setMinCompetitiveScore(float minScore) throws IOException {
    if (scorers.length != 1) {
      return;
    }
    scorers[0].setMinCompetitiveScore(minScore);
  }

  /** Returns one {@link ChildScorable} with relationship {@code "MUST"} per required clause. */
  @Override
  public Collection<ChildScorable> getChildren() {
    final ArrayList<ChildScorable> result = new ArrayList<>();
    for (Scorer clause : required) {
      final ChildScorable child = new ChildScorable(clause, "MUST");
      result.add(child);
    }
    return result;
  }

  /** Holder pairing an iterator with its cost, captured at construction, and a doc slot. */
  static final class DocsAndFreqs {

    /** Value of {@code iterator.cost()} read once in the constructor. */
    final long cost;

    /** The wrapped iterator. */
    final DocIdSetIterator iterator;

    /** Mutable document slot; starts at {@code -1}. */
    int doc = -1;

    DocsAndFreqs(DocIdSetIterator iterator) {
      this.cost = iterator.cost();
      this.iterator = iterator;
    }
  }
}