/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.search;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.stream.LongStream;
import java.util.stream.StreamSupport;
import org.apache.lucene.util.PriorityQueue;
import static org.apache.lucene.search.DisiPriorityQueue.leftNode;
import static org.apache.lucene.search.DisiPriorityQueue.parentNode;
import static org.apache.lucene.search.DisiPriorityQueue.rightNode;
/**
* A {@link Scorer} for {@link BooleanQuery} when
* {@link BooleanQuery.Builder#setMinimumNumberShouldMatch(int) minShouldMatch} is
* between 2 and the total number of clauses.
*
* This implementation keeps sub scorers in 3 different places:
* - lead: a linked list of scorers that are positioned on the desired doc ID
* - tail: a heap that contains at most minShouldMatch - 1 scorers that are
* behind the desired doc ID. These scorers are ordered by cost so that we
* can advance the least costly ones first.
* - head: a heap that contains scorers which are beyond the desired doc ID,
* ordered by doc ID in order to move quickly to the next candidate.
*
* Finding the next match consists of first setting the desired doc ID to the
* least entry in 'head' and then advancing 'tail' until there is a match.
*/
final class MinShouldMatchSumScorer extends Scorer {
static long cost(LongStream costs, int numScorers, int minShouldMatch) {
// the idea here is the following: a boolean query c1,c2,...cn with minShouldMatch=m
// could be rewritten to:
// (c1 AND (c2..cn|msm=m-1)) OR (!c1 AND (c2..cn|msm=m))
// if we assume that clauses come in ascending cost, then
// the cost of the first part is the cost of c1 (because the cost of a conjunction is
// the cost of the least costly clause)
// the cost of the second part is the cost of finding m matches among the c2...cn
// remaining clauses
// since it is a disjunction overall, the total cost is the sum of the costs of these
// two parts
// If we recurse infinitely, we find out that the cost of a msm query is the sum of the
// costs of the num_scorers - minShouldMatch + 1 least costly scorers
final PriorityQueue<Long> pq = new PriorityQueue<Long>(numScorers - minShouldMatch + 1) {
@Override
protected boolean lessThan(Long a, Long b) {
return a > b;
}
};
costs.forEach(pq::insertWithOverflow);
return StreamSupport.stream(pq.spliterator(), false).mapToLong(Number::longValue).sum();
}
final int minShouldMatch;
// list of scorers which 'lead' the iteration and are currently
// positioned on 'doc'
DisiWrapper lead;
int doc; // current doc ID of the leads
int freq; // number of scorers on the desired doc ID
// priority queue of scorers that are too advanced compared to the current
// doc. Ordered by doc ID.
final DisiPriorityQueue head;
// priority queue of scorers which are behind the current doc.
// Ordered by cost.
final DisiWrapper[] tail;
int tailSize;
final long cost;
MinShouldMatchSumScorer(Weight weight, Collection<Scorer> scorers, int minShouldMatch) {
super(weight);
if (minShouldMatch > scorers.size()) {
throw new IllegalArgumentException("minShouldMatch should be <= the number of scorers");
}
if (minShouldMatch < 1) {
throw new IllegalArgumentException("minShouldMatch should be >= 1");
}
this.minShouldMatch = minShouldMatch;
this.doc = -1;
head = new DisiPriorityQueue(scorers.size() - minShouldMatch + 1);
// there can be at most minShouldMatch - 1 scorers beyond the current position
// otherwise we might be skipping over matching documents
tail = new DisiWrapper[minShouldMatch - 1];
for (Scorer scorer : scorers) {
addLead(new DisiWrapper(scorer));
}
this.cost = cost(scorers.stream().map(Scorer::iterator).mapToLong(DocIdSetIterator::cost), scorers.size(), minShouldMatch);
}
@Override
public final Collection<ChildScorable> getChildren() throws IOException {
List<ChildScorable> matchingChildren = new ArrayList<>();
updateFreq();
for (DisiWrapper s = lead; s != null; s = s.next) {
matchingChildren.add(new ChildScorable(s.scorer, "SHOULD"));
}
return matchingChildren;
}
@Override
public DocIdSetIterator iterator() {
return TwoPhaseIterator.asDocIdSetIterator(twoPhaseIterator());
}
@Override
public TwoPhaseIterator twoPhaseIterator() {
DocIdSetIterator approximation = new DocIdSetIterator() {
@Override
public int docID() {
assert doc == lead.doc;
return doc;
}
@Override
public int nextDoc() throws IOException {
// We are moving to the next doc ID, so scorers in 'lead' need to go in
// 'tail'. If there is not enough space in 'tail', then we take the least
// costly scorers and advance them.
for (DisiWrapper s = lead; s != null; s = s.next) {
final DisiWrapper evicted = insertTailWithOverFlow(s);
if (evicted != null) {
if (evicted.doc == doc) {
evicted.doc = evicted.iterator.nextDoc();
} else {
evicted.doc = evicted.iterator.advance(doc + 1);
}
head.add(evicted);
}
}
setDocAndFreq();
// It would be correct to return doNextCandidate() at this point but if you
// call nextDoc as opposed to advance, it probably means that you really
// need the next match. Returning 'doc' here would lead to a similar
// iteration over sub postings overall except that the decision making would
// happen at a higher level where more abstractions are involved and
// benchmarks suggested it causes a significant performance hit.
return doNext();
}
@Override
public int advance(int target) throws IOException {
// Same logic as in nextDoc
for (DisiWrapper s = lead; s != null; s = s.next) {
final DisiWrapper evicted = insertTailWithOverFlow(s);
if (evicted != null) {
evicted.doc = evicted.iterator.advance(target);
head.add(evicted);
}
}
// But this time there might also be scorers in 'head' behind the desired
// target so we need to do the same thing that we did on 'lead' on 'head'
DisiWrapper headTop = head.top();
while (headTop.doc < target) {
final DisiWrapper evicted = insertTailWithOverFlow(headTop);
// We know that the tail is full since it contains at most
// minShouldMatch - 1 entries and we just moved at least minShouldMatch
// entries to it, so evicted is not null
evicted.doc = evicted.iterator.advance(target);
headTop = head.updateTop(evicted);
}
setDocAndFreq();
return doNextCandidate();
}
@Override
public long cost() {
return cost;
}
};
return new TwoPhaseIterator(approximation) {
@Override
public boolean matches() throws IOException {
while (freq < minShouldMatch) {
assert freq > 0;
if (freq + tailSize >= minShouldMatch) {
// a match on doc is still possible, try to
// advance scorers from the tail
advanceTail();
} else {
return false;
}
}
return true;
}
@Override
public float matchCost() {
// maximum number of scorer that matches() might advance
return tail.length;
}
};
}
private void addLead(DisiWrapper lead) {
lead.next = this.lead;
this.lead = lead;
freq += 1;
}
private void pushBackLeads() throws IOException {
for (DisiWrapper s = lead; s != null; s = s.next) {
addTail(s);
}
}
private void advanceTail(DisiWrapper top) throws IOException {
top.doc = top.iterator.advance(doc);
if (top.doc == doc) {
addLead(top);
} else {
head.add(top);
}
}
private void advanceTail() throws IOException {
final DisiWrapper top = popTail();
advanceTail(top);
}
Reinitializes head, freq and doc from 'head' /** Reinitializes head, freq and doc from 'head' */
private void setDocAndFreq() {
assert head.size() > 0;
// The top of `head` defines the next potential match
// pop all documents which are on this doc
lead = head.pop();
lead.next = null;
freq = 1;
doc = lead.doc;
while (head.size() > 0 && head.top().doc == doc) {
addLead(head.pop());
}
}
Advance tail to the lead until there is a match. /** Advance tail to the lead until there is a match. */
private int doNext() throws IOException {
while (freq < minShouldMatch) {
assert freq > 0;
if (freq + tailSize >= minShouldMatch) {
// a match on doc is still possible, try to
// advance scorers from the tail
advanceTail();
} else {
// no match on doc is possible anymore, move to the next potential match
pushBackLeads();
setDocAndFreq();
}
}
return doc;
}
Move iterators to the tail until the cumulated size of lead+tail is
greater than or equal to minShouldMath /** Move iterators to the tail until the cumulated size of lead+tail is
* greater than or equal to minShouldMath */
private int doNextCandidate() throws IOException {
while (freq + tailSize < minShouldMatch) {
// no match on doc is possible, move to the next potential match
pushBackLeads();
setDocAndFreq();
}
return doc;
}
Advance all entries from the tail to know about all matches on the
current doc. /** Advance all entries from the tail to know about all matches on the
* current doc. */
private void updateFreq() throws IOException {
assert freq >= minShouldMatch;
// we return the next doc when there are minShouldMatch matching clauses
// but some of the clauses in 'tail' might match as well
// in general we want to advance least-costly clauses first in order to
// skip over non-matching documents as fast as possible. However here,
// we are advancing everything anyway so iterating over clauses in
// (roughly) cost-descending order might help avoid some permutations in
// the head heap
for (int i = tailSize - 1; i >= 0; --i) {
advanceTail(tail[i]);
}
tailSize = 0;
}
@Override
public float score() throws IOException {
// we need to know about all matches
updateFreq();
double score = 0;
for (DisiWrapper s = lead; s != null; s = s.next) {
score += s.scorer.score();
}
return (float) score;
}
@Override
public float getMaxScore(int upTo) throws IOException {
// TODO: implement but be careful about floating-point errors.
return Float.POSITIVE_INFINITY;
}
@Override
public int docID() {
assert doc == lead.doc;
return doc;
}
Insert an entry in 'tail' and evict the least-costly scorer if full. /** Insert an entry in 'tail' and evict the least-costly scorer if full. */
private DisiWrapper insertTailWithOverFlow(DisiWrapper s) {
if (tailSize < tail.length) {
addTail(s);
return null;
} else if (tail.length >= 1) {
final DisiWrapper top = tail[0];
if (top.cost < s.cost) {
tail[0] = s;
downHeapCost(tail, tailSize);
return top;
}
}
return s;
}
Add an entry to 'tail'. Fails if over capacity. /** Add an entry to 'tail'. Fails if over capacity. */
private void addTail(DisiWrapper s) {
tail[tailSize] = s;
upHeapCost(tail, tailSize);
tailSize += 1;
}
Pop the least-costly scorer from 'tail'. /** Pop the least-costly scorer from 'tail'. */
private DisiWrapper popTail() {
assert tailSize > 0;
final DisiWrapper result = tail[0];
tail[0] = tail[--tailSize];
downHeapCost(tail, tailSize);
return result;
}
Heap helpers /** Heap helpers */
private static void upHeapCost(DisiWrapper[] heap, int i) {
final DisiWrapper node = heap[i];
final long nodeCost = node.cost;
int j = parentNode(i);
while (j >= 0 && nodeCost < heap[j].cost) {
heap[i] = heap[j];
i = j;
j = parentNode(j);
}
heap[i] = node;
}
private static void downHeapCost(DisiWrapper[] heap, int size) {
int i = 0;
final DisiWrapper node = heap[0];
int j = leftNode(i);
if (j < size) {
int k = rightNode(j);
if (k < size && heap[k].cost < heap[j].cost) {
j = k;
}
if (heap[j].cost < node.cost) {
do {
heap[i] = heap[j];
i = j;
j = leftNode(i);
k = rightNode(j);
if (k < size && heap[k].cost < heap[j].cost) {
j = k;
}
} while (j < size && heap[j].cost < node.cost);
heap[i] = node;
}
}
}
}