package org.apache.lucene.search;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.Objects;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.PostingsEnum;
import org.apache.lucene.index.Term;
import org.apache.lucene.index.TermState;
import org.apache.lucene.index.TermStates;
import org.apache.lucene.index.Terms;
import org.apache.lucene.index.TermsEnum;
import org.apache.lucene.search.BooleanClause.Occur;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.DocIdSetBuilder;
final class MultiTermQueryConstantScoreWrapper<Q extends MultiTermQuery> extends Query {
// Upper bound on how many terms are buffered per leaf while trying to execute the
// wrapped query as a plain boolean disjunction; collectTerms() additionally caps it
// by BooleanQuery.getMaxClauseCount(). Leaves with more matching terms fall back to
// building a DocIdSet instead.
private static final int BOOLEAN_REWRITE_TERM_COUNT_THRESHOLD = 16;
/**
 * Holder for one term read from a {@link TermsEnum}, together with the per-term
 * data needed to revisit that term later without looking it up again.
 */
private static class TermAndState {

  /** Bytes of the term. */
  final BytesRef term;

  /** State reported by the enum for this term. */
  final TermState state;

  /** Document frequency reported by the enum for this term. */
  final int docFreq;

  /** Total term frequency reported by the enum for this term. */
  final long totalTermFreq;

  /**
   * Stores the four values as given; no copies are made here.
   */
  TermAndState(BytesRef term, TermState state, int docFreq, long totalTermFreq) {
    this.term = term;
    this.state = state;
    this.docFreq = docFreq;
    this.totalTermFreq = totalTermFreq;
  }
}
/**
 * Result of a per-leaf rewrite: carries either a {@link Weight} to delegate to, or a
 * {@link DocIdSet} of matching documents. At most one of the two fields is non-null;
 * when built from a {@link DocIdSet}, that set may itself be {@code null}.
 */
private static class WeightOrDocIdSet {

  /** Delegate weight, or {@code null} when this instance carries a doc id set. */
  final Weight weight;

  /** Matching documents, or {@code null}. */
  final DocIdSet set;

  /**
   * Creates a weight-carrying result.
   *
   * @throws NullPointerException if {@code weight} is {@code null}
   */
  WeightOrDocIdSet(Weight weight) {
    this.set = null;
    this.weight = Objects.requireNonNull(weight);
  }

  /**
   * Creates a set-carrying result; {@code bitset} is stored as-is and may be {@code null}.
   */
  WeightOrDocIdSet(DocIdSet bitset) {
    this.weight = null;
    this.set = bitset;
  }
}
/** The wrapped multi-term query; every method of this class delegates to or reads from it. */
protected final Q query;
/**
 * Wraps the given multi-term query.
 *
 * @param query the query to wrap; must not be {@code null}
 * @throws NullPointerException if {@code query} is {@code null}
 */
protected MultiTermQueryConstantScoreWrapper(Q query) {
  // Fail fast: toString, equals, hashCode, getField, createWeight and visit all
  // dereference the wrapped query, so a null here would only surface later and
  // far away from the faulty call site.
  this.query = Objects.requireNonNull(query, "query");
}
/**
 * Renders this wrapper exactly as the wrapped query renders itself for {@code field}.
 */
@Override
public String toString(String field) {
  final String rendered = this.query.toString(field);
  return rendered;
}
/**
 * Two wrappers are equal when {@code sameClassAs} accepts the other object and the
 * wrapped queries are equal.
 */
@Override
public final boolean equals(final Object other) {
  if (sameClassAs(other) == false) {
    return false;
  }
  final MultiTermQueryConstantScoreWrapper<?> that = (MultiTermQueryConstantScoreWrapper<?>) other;
  return query.equals(that.query);
}
/**
 * Combines the class hash with the wrapped query's hash, consistent with
 * {@link #equals(Object)}.
 */
@Override
public final int hashCode() {
  int hash = classHash();
  hash = 31 * hash + query.hashCode();
  return hash;
}
/**
 * Returns the wrapped multi-term query.
 */
public Q getQuery() {
  return this.query;
}
/**
 * Returns the field reported by the wrapped query's {@code getField()}.
 */
public final String getField() {
  return this.query.getField();
}
@Override
public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float boost) throws IOException {
return new ConstantScoreWeight(this, boost) {
/**
 * Pulls terms from {@code termsEnum} into {@code terms}, stopping once
 * {@code min(BOOLEAN_REWRITE_TERM_COUNT_THRESHOLD, BooleanQuery.getMaxClauseCount())}
 * terms have been buffered.
 *
 * @param context   the leaf being processed (not read by this method)
 * @param termsEnum the enum to advance
 * @param terms     receives one {@link TermAndState} per term consumed
 * @return {@code true} if the enum was exhausted, i.e. every term was collected;
 *         {@code false} if more terms remain, in which case the enum is left
 *         positioned on the first term that was not collected
 */
private boolean collectTerms(LeafReaderContext context, TermsEnum termsEnum, List<TermAndState> terms) throws IOException {
  final int limit = Math.min(BOOLEAN_REWRITE_TERM_COUNT_THRESHOLD, BooleanQuery.getMaxClauseCount());
  int collected = 0;
  while (collected < limit) {
    final BytesRef current = termsEnum.next();
    if (current == null) {
      // Enum ran dry before the limit was reached.
      return true;
    }
    final TermState state = termsEnum.termState();
    // Keep a private copy of the bytes rather than the reference handed out by the enum.
    final BytesRef ownedCopy = BytesRef.deepCopyOf(current);
    final int docFreq = termsEnum.docFreq();
    final long totalTermFreq = termsEnum.totalTermFreq();
    terms.add(new TermAndState(ownedCopy, state, docFreq, totalTermFreq));
    ++collected;
  }
  // Probe one step further: only if nothing is left did we really see every term.
  final BytesRef overflow = termsEnum.next();
  return overflow == null;
}
/**
 * Decides how matches are produced for one leaf. If all matching terms fit under
 * the collection limit, they are turned into a constant-score boolean disjunction
 * and its {@link Weight} is returned; otherwise the postings of every matching term
 * are accumulated into a {@link DocIdSet}.
 *
 * @return a result holding a weight, a doc id set, or a {@code null} doc id set when
 *         the leaf reader has no {@link Terms} for the query's field
 */
private WeightOrDocIdSet rewrite(LeafReaderContext context) throws IOException {
  final Terms fieldTerms = context.reader().terms(query.field);
  if (fieldTerms == null) {
    // The reader exposes no terms for this field.
    return new WeightOrDocIdSet((DocIdSet) null);
  }

  final TermsEnum matchingTerms = query.getTermsEnum(fieldTerms);
  assert matchingTerms != null;

  PostingsEnum postings = null;

  final List<TermAndState> buffered = new ArrayList<>();
  final boolean sawEveryTerm = collectTerms(context, matchingTerms, buffered);
  if (sawEveryTerm) {
    // Few enough terms: one SHOULD clause per term, each seeded with the state
    // already captured for this leaf.
    final BooleanQuery.Builder disjunction = new BooleanQuery.Builder();
    for (TermAndState entry : buffered) {
      final TermStates states = new TermStates(searcher.getTopReaderContext());
      states.register(entry.state, context.ord, entry.docFreq, entry.totalTermFreq);
      final Term term = new Term(query.field, entry.term);
      disjunction.add(new TermQuery(term, states), Occur.SHOULD);
    }
    final Query constantScore = new ConstantScoreQuery(disjunction.build());
    final Query rewritten = searcher.rewrite(constantScore);
    final Weight delegate = rewritten.createWeight(searcher, scoreMode, score());
    return new WeightOrDocIdSet(delegate);
  }

  // Too many terms: accumulate doc ids instead, starting with the terms that were
  // already buffered. They are replayed through a separate iterator so that the
  // position of matchingTerms is left untouched.
  final DocIdSetBuilder docIds = new DocIdSetBuilder(context.reader().maxDoc(), fieldTerms);
  if (buffered.isEmpty() == false) {
    final TermsEnum replay = fieldTerms.iterator();
    for (TermAndState entry : buffered) {
      replay.seekExact(entry.term, entry.state);
      postings = replay.postings(postings, PostingsEnum.NONE);
      docIds.add(postings);
    }
  }

  // matchingTerms is still positioned on the first term that was not buffered, so
  // that term's postings are consumed before advancing.
  do {
    postings = matchingTerms.postings(postings, PostingsEnum.NONE);
    docIds.add(postings);
  } while (matchingTerms.next() != null);

  return new WeightOrDocIdSet(docIds.build());
}
/**
 * Wraps the iterator of {@code set} in a constant-score scorer.
 *
 * @return {@code null} when {@code set} is {@code null} or yields a {@code null}
 *         iterator; otherwise a {@link ConstantScoreScorer} over that iterator
 */
private Scorer scorer(DocIdSet set) throws IOException {
  final DocIdSetIterator iterator = (set == null) ? null : set.iterator();
  if (iterator == null) {
    return null;
  }
  return new ConstantScoreScorer(this, score(), scoreMode, iterator);
}
/**
 * Rewrites the query for this leaf and returns the delegate weight's bulk scorer
 * when a weight was produced; otherwise wraps the doc-id-set scorer in a
 * {@code DefaultBulkScorer}, or returns {@code null} if there is no scorer.
 */
@Override
public BulkScorer bulkScorer(LeafReaderContext context) throws IOException {
  final WeightOrDocIdSet rewritten = rewrite(context);
  final Weight delegate = rewritten.weight;
  if (delegate != null) {
    return delegate.bulkScorer(context);
  }
  final Scorer setScorer = scorer(rewritten.set);
  return setScorer == null ? null : new DefaultBulkScorer(setScorer);
}
/**
 * Reports matches for {@code doc} in this leaf.
 *
 * @return {@code null} when the reader has no {@link Terms} for the field; the
 *         superclass result when the field's terms have no positions; otherwise a
 *         field-level {@link Matches} built lazily from the query's terms enum
 */
@Override
public Matches matches(LeafReaderContext context, int doc) throws IOException {
  final Terms fieldTerms = context.reader().terms(query.field);
  if (fieldTerms == null) {
    return null;
  }
  if (fieldTerms.hasPositions()) {
    return MatchesUtils.forField(
        query.field,
        () -> DisjunctionMatchesIterator.fromTermsEnum(
            context, doc, query, query.field, query.getTermsEnum(fieldTerms)));
  }
  return super.matches(context, doc);
}
/**
 * Rewrites the query for this leaf and returns either the delegate weight's scorer
 * or a constant-score scorer over the collected doc id set (possibly {@code null}).
 */
@Override
public Scorer scorer(LeafReaderContext context) throws IOException {
  final WeightOrDocIdSet rewritten = rewrite(context);
  return rewritten.weight != null
      ? rewritten.weight.scorer(context)
      : scorer(rewritten.set);
}
/** Unconditionally reports this weight as cacheable, whatever leaf is passed in. */
@Override
public boolean isCacheable(LeafReaderContext ctx) {
return true;
}
};
}
/**
 * Lets the wrapped query visit a {@link Occur#FILTER} sub-visitor of {@code visitor},
 * provided the visitor accepts this query's field; does nothing otherwise.
 */
@Override
public void visit(QueryVisitor visitor) {
  if (visitor.acceptField(getField()) == false) {
    return;
  }
  query.visit(visitor.getSubVisitor(Occur.FILTER, this));
}
}