package org.apache.lucene.search;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Comparator;
import java.util.List;
final class BlockMaxConjunctionScorer extends Scorer {
final Scorer[] scorers;
final DocIdSetIterator[] approximations;
final TwoPhaseIterator[] twoPhases;
final MaxScoreSumPropagator maxScorePropagator;
float minScore;
BlockMaxConjunctionScorer(Weight weight, Collection<Scorer> scorersList) throws IOException {
super(weight);
this.scorers = scorersList.toArray(new Scorer[scorersList.size()]);
Arrays.sort(this.scorers, Comparator.comparingLong(s -> s.iterator().cost()));
this.maxScorePropagator = new MaxScoreSumPropagator(Arrays.asList(scorers));
this.approximations = new DocIdSetIterator[scorers.length];
List<TwoPhaseIterator> twoPhaseList = new ArrayList<>();
for (int i = 0; i < scorers.length; i++) {
Scorer scorer = scorers[i];
TwoPhaseIterator twoPhase = scorer.twoPhaseIterator();
if (twoPhase != null) {
twoPhaseList.add(twoPhase);
approximations[i] = twoPhase.approximation();
} else {
approximations[i] = scorer.iterator();
}
scorer.advanceShallow(0);
}
this.twoPhases = twoPhaseList.toArray(new TwoPhaseIterator[twoPhaseList.size()]);
Arrays.sort(this.twoPhases, Comparator.comparingDouble(TwoPhaseIterator::matchCost));
}
@Override
public TwoPhaseIterator twoPhaseIterator() {
if (twoPhases.length == 0) {
return null;
}
float matchCost = (float) Arrays.stream(twoPhases)
.mapToDouble(TwoPhaseIterator::matchCost)
.sum();
final DocIdSetIterator approx = approximation();
return new TwoPhaseIterator(approx) {
@Override
public boolean matches() throws IOException {
for (TwoPhaseIterator twoPhase : twoPhases) {
assert twoPhase.approximation().docID() == docID();
if (twoPhase.matches() == false) {
return false;
}
}
return true;
}
@Override
public float matchCost() {
return matchCost;
}
};
}
@Override
public DocIdSetIterator iterator() {
return twoPhases.length == 0 ? approximation() :
TwoPhaseIterator.asDocIdSetIterator(twoPhaseIterator());
}
private DocIdSetIterator approximation() {
final DocIdSetIterator lead = approximations[0];
return new DocIdSetIterator() {
float maxScore;
int upTo = -1;
@Override
public int docID() {
return lead.docID();
}
@Override
public long cost() {
return lead.cost();
}
private void moveToNextBlock(int target) throws IOException {
upTo = advanceShallow(target);
maxScore = getMaxScore(upTo);
}
private int advanceTarget(int target) throws IOException {
if (target > upTo) {
moveToNextBlock(target);
}
while (true) {
assert upTo >= target;
if (maxScore >= minScore) {
return target;
}
if (upTo == NO_MORE_DOCS) {
return NO_MORE_DOCS;
}
target = upTo + 1;
moveToNextBlock(target);
}
}
@Override
public int nextDoc() throws IOException {
return advance(docID() + 1);
}
@Override
public int advance(int target) throws IOException {
return doNext(lead.advance(advanceTarget(target)));
}
private int doNext(int doc) throws IOException {
advanceHead: for(;;) {
assert doc == lead.docID();
if (doc == NO_MORE_DOCS) {
return NO_MORE_DOCS;
}
if (doc > upTo) {
final int nextTarget = advanceTarget(doc);
if (nextTarget != doc) {
doc = lead.advance(nextTarget);
continue;
}
}
assert doc <= upTo;
for (int i = 1; i < approximations.length; ++i) {
final DocIdSetIterator other = approximations[i];
if (other.docID() < doc) {
final int next = other.advance(doc);
if (next > doc) {
doc = lead.advance(advanceTarget(next));
continue advanceHead;
}
}
assert other.docID() == doc;
}
return doc;
}
}
};
}
@Override
public int docID() {
return scorers[0].docID();
}
@Override
public float score() throws IOException {
double score = 0;
for (Scorer scorer : scorers) {
score += scorer.score();
}
return (float) score;
}
@Override
public int advanceShallow(int target) throws IOException {
int result = scorers[0].advanceShallow(target);
for (int i = 1; i < scorers.length; ++i) {
scorers[i].advanceShallow(target);
}
return result;
}
@Override
public float getMaxScore(int upTo) throws IOException {
double sum = 0;
for (Scorer scorer : scorers) {
sum += scorer.getMaxScore(upTo);
}
return (float) sum;
}
@Override
public void setMinCompetitiveScore(float score) throws IOException {
minScore = score;
maxScorePropagator.setMinCompetitiveScore(score);
}
@Override
public Collection<ChildScorable> getChildren() {
ArrayList<ChildScorable> children = new ArrayList<>();
for (Scorer scorer : scorers) {
children.add(new ChildScorable(scorer, "MUST"));
}
return children;
}
}