package org.apache.lucene.search;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.Comparator;
import java.util.List;
import org.apache.lucene.search.spans.Spans;
import org.apache.lucene.util.ArrayUtil;
import org.apache.lucene.util.BitSet;
import org.apache.lucene.util.BitSetIterator;
import org.apache.lucene.util.CollectionUtil;
/**
 * A {@link DocIdSetIterator} over the intersection of several sub-iterators: it only stops on
 * documents that every sub-iterator also stops on.
 *
 * <p>Instances are obtained through the static {@code intersect*} factories. The factories
 * flatten nested conjunctions, route sub-iterators that are {@link BitSetIterator}s (and are not
 * the cheapest input) to bit-set membership checks, and defer two-phase confirmation to a single
 * combined {@link TwoPhaseIterator}.
 *
 * <p>NOTE(review): the factories do not verify that all inputs are positioned on the same
 * document when they are combined; callers presumably have to guarantee this - confirm.
 */
public final class ConjunctionDISI extends DocIdSetIterator {

  /** Message used by every factory when fewer than two inputs are supplied. */
  private static final String TOO_FEW_ITERATORS =
      "Cannot make a ConjunctionDISI of less than 2 iterators";

  /**
   * Creates a conjunction over the provided {@link Scorer}s.
   *
   * @param scorers the scorers to intersect; at least two are required
   * @return an iterator over the documents matched by all scorers
   * @throws IllegalArgumentException if fewer than two scorers are given
   */
  public static DocIdSetIterator intersectScorers(Collection<Scorer> scorers) {
    if (scorers.size() < 2) {
      throw new IllegalArgumentException(TOO_FEW_ITERATORS);
    }
    final List<DocIdSetIterator> allIterators = new ArrayList<>();
    final List<TwoPhaseIterator> twoPhaseIterators = new ArrayList<>();
    for (Scorer scorer : scorers) {
      addScorer(scorer, allIterators, twoPhaseIterators);
    }
    return createConjunction(allIterators, twoPhaseIterators);
  }

  /**
   * Creates a conjunction over the provided {@link DocIdSetIterator}s.
   *
   * @param iterators the iterators to intersect; at least two are required
   * @return an iterator over the documents matched by all iterators
   * @throws IllegalArgumentException if fewer than two iterators are given
   */
  public static DocIdSetIterator intersectIterators(List<DocIdSetIterator> iterators) {
    if (iterators.size() < 2) {
      throw new IllegalArgumentException(TOO_FEW_ITERATORS);
    }
    final List<DocIdSetIterator> allIterators = new ArrayList<>();
    final List<TwoPhaseIterator> twoPhaseIterators = new ArrayList<>();
    for (DocIdSetIterator iterator : iterators) {
      addIterator(iterator, allIterators, twoPhaseIterators);
    }
    return createConjunction(allIterators, twoPhaseIterators);
  }

  /**
   * Creates a conjunction over the provided {@link Spans}.
   *
   * @param spanList the spans to intersect; at least two are required
   * @return an iterator over the documents matched by all spans
   * @throws IllegalArgumentException if fewer than two spans are given
   */
  public static DocIdSetIterator intersectSpans(List<Spans> spanList) {
    if (spanList.size() < 2) {
      throw new IllegalArgumentException(TOO_FEW_ITERATORS);
    }
    final List<DocIdSetIterator> allIterators = new ArrayList<>();
    final List<TwoPhaseIterator> twoPhaseIterators = new ArrayList<>();
    for (Spans spans : spanList) {
      addSpans(spans, allIterators, twoPhaseIterators);
    }
    return createConjunction(allIterators, twoPhaseIterators);
  }

  /**
   * Registers a scorer: through its two-phase iterator when it provides one (non-null), otherwise
   * through its plain iterator.
   */
  private static void addScorer(
      Scorer scorer,
      List<DocIdSetIterator> allIterators,
      List<TwoPhaseIterator> twoPhaseIterators) {
    TwoPhaseIterator twoPhaseIter = scorer.twoPhaseIterator();
    if (twoPhaseIter != null) {
      addTwoPhaseIterator(twoPhaseIter, allIterators, twoPhaseIterators);
    } else {
      addIterator(scorer.iterator(), allIterators, twoPhaseIterators);
    }
  }

  /**
   * Registers a {@link Spans}: through its two-phase view when it provides one (non-null),
   * otherwise the spans object itself is used as the iterator.
   */
  private static void addSpans(
      Spans spans,
      List<DocIdSetIterator> allIterators,
      List<TwoPhaseIterator> twoPhaseIterators) {
    TwoPhaseIterator twoPhaseIter = spans.asTwoPhaseIterator();
    if (twoPhaseIter != null) {
      addTwoPhaseIterator(twoPhaseIter, allIterators, twoPhaseIterators);
    } else {
      addIterator(spans, allIterators, twoPhaseIterators);
    }
  }

  /**
   * Registers an iterator, flattening it when it is itself a conjunction produced by this class so
   * that the resulting conjunction is not needlessly nested.
   */
  private static void addIterator(
      DocIdSetIterator disi,
      List<DocIdSetIterator> allIterators,
      List<TwoPhaseIterator> twoPhaseIterators) {
    TwoPhaseIterator twoPhase = TwoPhaseIterator.unwrap(disi);
    if (twoPhase != null) {
      // The iterator wraps a two-phase iterator: register that instead.
      addTwoPhaseIterator(twoPhase, allIterators, twoPhaseIterators);
    } else if (disi.getClass() == ConjunctionDISI.class) {
      // Exact class check: ConjunctionDISI is final, so this matches precisely our own instances.
      ConjunctionDISI conjunction = (ConjunctionDISI) disi;
      allIterators.add(conjunction.lead1);
      allIterators.add(conjunction.lead2);
      Collections.addAll(allIterators, conjunction.others);
    } else if (disi.getClass() == BitSetConjunctionDISI.class) {
      BitSetConjunctionDISI conjunction = (BitSetConjunctionDISI) disi;
      allIterators.add(conjunction.lead);
      Collections.addAll(allIterators, conjunction.bitSetIterators);
    } else {
      allIterators.add(disi);
    }
  }

  /**
   * Registers a two-phase iterator: its approximation joins the plain iterators, and its
   * confirmation step joins {@code twoPhaseIterators} (flattened when it is itself a {@link
   * ConjunctionTwoPhaseIterator}).
   */
  private static void addTwoPhaseIterator(
      TwoPhaseIterator twoPhaseIter,
      List<DocIdSetIterator> allIterators,
      List<TwoPhaseIterator> twoPhaseIterators) {
    addIterator(twoPhaseIter.approximation(), allIterators, twoPhaseIterators);
    if (twoPhaseIter.getClass() == ConjunctionTwoPhaseIterator.class) {
      Collections.addAll(
          twoPhaseIterators, ((ConjunctionTwoPhaseIterator) twoPhaseIter).twoPhaseIterators);
    } else {
      twoPhaseIterators.add(twoPhaseIter);
    }
  }

  /**
   * Builds the final iterator out of the flattened inputs.
   *
   * @param allIterators all plain iterators and approximations, flattened; must not be empty
   * @param twoPhaseIterators confirmation steps to apply on top of the approximation; may be empty
   * @return the conjunction, wrapped in bit-set checks and two-phase confirmation as needed
   */
  private static DocIdSetIterator createConjunction(
      List<DocIdSetIterator> allIterators, List<TwoPhaseIterator> twoPhaseIterators) {
    long minCost = allIterators.stream().mapToLong(DocIdSetIterator::cost).min().getAsLong();

    // Split the inputs: BitSetIterators that are not (one of) the cheapest input are checked by
    // bit lookup instead of being advanced. Everything else is advanced normally. Any iterator
    // whose cost equals minCost stays here, so 'iterators' is never empty.
    List<BitSetIterator> bitSetIterators = new ArrayList<>();
    List<DocIdSetIterator> iterators = new ArrayList<>();
    for (DocIdSetIterator iterator : allIterators) {
      if (iterator.cost() > minCost && iterator instanceof BitSetIterator) {
        bitSetIterators.add((BitSetIterator) iterator);
      } else {
        iterators.add(iterator);
      }
    }

    DocIdSetIterator disi;
    if (iterators.size() == 1) {
      // A single advancing iterator needs no ConjunctionDISI wrapper.
      disi = iterators.get(0);
    } else {
      disi = new ConjunctionDISI(iterators);
    }

    if (bitSetIterators.size() > 0) {
      disi = new BitSetConjunctionDISI(disi, bitSetIterators);
    }

    if (twoPhaseIterators.isEmpty() == false) {
      disi =
          TwoPhaseIterator.asDocIdSetIterator(
              new ConjunctionTwoPhaseIterator(disi, twoPhaseIterators));
    }

    return disi;
  }

  /** The two cheapest sub-iterators (by {@link #cost()}); {@code lead1} drives the iteration. */
  final DocIdSetIterator lead1, lead2;

  /** The remaining sub-iterators, in ascending cost order. */
  final DocIdSetIterator[] others;

  /**
   * Sorts {@code iterators} in place by ascending cost and splits them into the two leads and the
   * rest.
   *
   * @param iterators at least two iterators; the list is reordered by this constructor
   */
  private ConjunctionDISI(List<? extends DocIdSetIterator> iterators) {
    assert iterators.size() >= 2;
    // Cheapest iterators first, so that the cheapest one leads the matching.
    final Comparator<DocIdSetIterator> byCost = Comparator.comparingLong(DocIdSetIterator::cost);
    CollectionUtil.timSort(iterators, byCost);
    lead1 = iterators.get(0);
    lead2 = iterators.get(1);
    others = iterators.subList(2, iterators.size()).toArray(new DocIdSetIterator[0]);
  }

  /**
   * Starting from {@code doc}, the document {@code lead1} is currently on, moves the sub-iterators
   * forward until they all agree on a document and returns it.
   */
  private int doNext(int doc) throws IOException {
    advanceHead:
    for (;;) {
      assert doc == lead1.docID();

      // First find agreement between the two leads. lead2 is advanced unconditionally, without
      // the 'docID() < doc' guard applied to the other iterators below.
      final int next2 = lead2.advance(doc);
      if (next2 != doc) {
        doc = lead1.advance(next2);
        if (next2 != doc) {
          continue;
        }
      }

      // Then bring the remaining iterators to the same document.
      for (DocIdSetIterator other : others) {
        // 'other' may already be on 'doc' from a previous round, in which case it must not be
        // advanced again.
        if (other.docID() < doc) {
          final int next = other.advance(doc);

          if (next > doc) {
            // 'other' skipped past 'doc': restart the agreement search from its position.
            doc = lead1.advance(next);
            continue advanceHead;
          }
        }
      }

      // Every sub-iterator is on 'doc'.
      return doc;
    }
  }

  @Override
  public int advance(int target) throws IOException {
    return doNext(lead1.advance(target));
  }

  @Override
  public int docID() {
    return lead1.docID();
  }

  @Override
  public int nextDoc() throws IOException {
    return doNext(lead1.nextDoc());
  }

  /** Returns the cost of {@code lead1}, the cheapest sub-iterator. */
  @Override
  public long cost() {
    return lead1.cost();
  }

  /**
   * Conjunction between a lead {@link DocIdSetIterator} and one or more {@link BitSetIterator}s.
   * Only the lead is advanced; the bit sets are consulted through {@link BitSet#get(int)}.
   */
  private static class BitSetConjunctionDISI extends DocIdSetIterator {

    private final DocIdSetIterator lead;
    private final BitSetIterator[] bitSetIterators;
    private final BitSet[] bitSets;
    /** Smallest {@code length()} across all bit sets; no document at or beyond it can match. */
    private final int minLength;

    /**
     * @param lead the iterator that drives the iteration
     * @param bitSetIterators the bit-set backed iterators to check against; must not be empty
     */
    BitSetConjunctionDISI(DocIdSetIterator lead, Collection<BitSetIterator> bitSetIterators) {
      this.lead = lead;
      assert bitSetIterators.size() > 0;
      this.bitSetIterators = bitSetIterators.toArray(new BitSetIterator[0]);
      // Check the bit sets in ascending cost order.
      ArrayUtil.timSort(this.bitSetIterators, (a, b) -> Long.compare(a.cost(), b.cost()));
      this.bitSets = new BitSet[this.bitSetIterators.length];
      int minLen = Integer.MAX_VALUE;
      for (int i = 0; i < this.bitSetIterators.length; ++i) {
        BitSet bitSet = this.bitSetIterators[i].getBitSet();
        this.bitSets[i] = bitSet;
        minLen = Math.min(minLen, bitSet.length());
      }
      this.minLength = minLen;
    }

    @Override
    public int docID() {
      return lead.docID();
    }

    @Override
    public int nextDoc() throws IOException {
      return doNext(lead.nextDoc());
    }

    @Override
    public int advance(int target) throws IOException {
      return doNext(lead.advance(target));
    }

    /**
     * Starting from {@code doc}, the document the lead is currently on, steps the lead forward
     * until it is on a document whose bit is set in every bit set, and returns that document.
     */
    private int doNext(int doc) throws IOException {
      advanceLead:
      for (;; doc = lead.nextDoc()) {
        if (doc >= minLength) {
          // No further match is possible. The lead may still be on a live document here; exhaust
          // it so that docID(), which delegates to the lead, agrees with the NO_MORE_DOCS
          // returned below instead of reporting a document that is not a match.
          if (doc != NO_MORE_DOCS) {
            lead.advance(NO_MORE_DOCS);
          }
          return NO_MORE_DOCS;
        }
        for (BitSet bitSet : bitSets) {
          if (bitSet.get(doc) == false) {
            continue advanceLead;
          }
        }
        // All bit sets contain 'doc': record the position on the wrapped iterators as well.
        for (BitSetIterator iterator : bitSetIterators) {
          iterator.setDocId(doc);
        }
        return doc;
      }
    }

    /** Returns the cost of the lead iterator. */
    @Override
    public long cost() {
      return lead.cost();
    }
  }

  /**
   * {@link TwoPhaseIterator} that matches only when every wrapped two-phase iterator matches. The
   * approximation is the conjunction of the individual approximations.
   */
  private static final class ConjunctionTwoPhaseIterator extends TwoPhaseIterator {

    /** Wrapped iterators in ascending {@code matchCost()} order. */
    private final TwoPhaseIterator[] twoPhaseIterators;

    /** Sum of the wrapped iterators' match costs. */
    private final float matchCost;

    /**
     * @param approximation the conjunction of all approximations
     * @param twoPhaseIterators the confirmation steps; must not be empty; the list is sorted in
     *     place by this constructor
     */
    private ConjunctionTwoPhaseIterator(
        DocIdSetIterator approximation, List<? extends TwoPhaseIterator> twoPhaseIterators) {
      super(approximation);
      assert twoPhaseIterators.size() > 0;

      // Cheapest confirmations first, so matches() can bail out before running the costly ones.
      final Comparator<TwoPhaseIterator> byMatchCost =
          (a, b) -> Float.compare(a.matchCost(), b.matchCost());
      CollectionUtil.timSort(twoPhaseIterators, byMatchCost);

      this.twoPhaseIterators = twoPhaseIterators.toArray(new TwoPhaseIterator[0]);

      // The array holds the same elements in the same sorted order as the list, so the float
      // summation order is unchanged.
      float totalMatchCost = 0;
      for (TwoPhaseIterator tpi : this.twoPhaseIterators) {
        totalMatchCost += tpi.matchCost();
      }
      matchCost = totalMatchCost;
    }

    /** Returns {@code true} only if every wrapped iterator matches; stops at the first miss. */
    @Override
    public boolean matches() throws IOException {
      for (TwoPhaseIterator twoPhaseIterator : twoPhaseIterators) {
        if (twoPhaseIterator.matches() == false) {
          return false;
        }
      }
      return true;
    }

    @Override
    public float matchCost() {
      return matchCost;
    }
  }
}