/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.lucene.search;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.Objects;
import java.util.Set;
import java.util.SortedSet;

import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.LeafReader;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.PostingsEnum;
import org.apache.lucene.index.PrefixCodedTerms;
import org.apache.lucene.index.PrefixCodedTerms.TermIterator;
import org.apache.lucene.index.Term;
import org.apache.lucene.index.TermState;
import org.apache.lucene.index.TermStates;
import org.apache.lucene.index.Terms;
import org.apache.lucene.index.TermsEnum;
import org.apache.lucene.search.BooleanClause.Occur;
import org.apache.lucene.util.Accountable;
import org.apache.lucene.util.ArrayUtil;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.BytesRefBuilder;
import org.apache.lucene.util.DocIdSetBuilder;
import org.apache.lucene.util.RamUsageEstimator;
import org.apache.lucene.util.automaton.Automata;
import org.apache.lucene.util.automaton.Automaton;
import org.apache.lucene.util.automaton.ByteRunAutomaton;
import org.apache.lucene.util.automaton.CompiledAutomaton;
import org.apache.lucene.util.automaton.Operations;

Specialization for a disjunction over many terms that behaves like a ConstantScoreQuery over a BooleanQuery containing only Occur.SHOULD clauses.

For instance in the following example, both q1 and q2 would yield the same scores:

Query q1 = new TermInSetQuery("field", new BytesRef("foo"), new BytesRef("bar"));
BooleanQuery bq = new BooleanQuery();
bq.add(new TermQuery(new Term("field", "foo")), Occur.SHOULD);
bq.add(new TermQuery(new Term("field", "bar")), Occur.SHOULD);
Query q2 = new ConstantScoreQuery(bq);

When there are few terms, this query executes like a regular disjunction. However, when there are many terms, instead of merging iterators on the fly, it will populate a bit set with matching docs and return a Scorer over this bit set.

NOTE: This query produces scores that are equal to its boost.

/** * Specialization for a disjunction over many terms that behaves like a * {@link ConstantScoreQuery} over a {@link BooleanQuery} containing only * {@link org.apache.lucene.search.BooleanClause.Occur#SHOULD} clauses. * <p>For instance in the following example, both {@code q1} and {@code q2} * would yield the same scores: * <pre class="prettyprint"> * Query q1 = new TermInSetQuery("field", new BytesRef("foo"), new BytesRef("bar")); * * BooleanQuery bq = new BooleanQuery(); * bq.add(new TermQuery(new Term("field", "foo")), Occur.SHOULD); * bq.add(new TermQuery(new Term("field", "bar")), Occur.SHOULD); * Query q2 = new ConstantScoreQuery(bq); * </pre> * <p>When there are few terms, this query executes like a regular disjunction. * However, when there are many terms, instead of merging iterators on the fly, * it will populate a bit set with matching docs and return a {@link Scorer} * over this bit set. * <p>NOTE: This query produces scores that are equal to its boost */
public class TermInSetQuery extends Query implements Accountable { private static final long BASE_RAM_BYTES_USED = RamUsageEstimator.shallowSizeOfInstance(TermInSetQuery.class); // Same threshold as MultiTermQueryConstantScoreWrapper static final int BOOLEAN_REWRITE_TERM_COUNT_THRESHOLD = 16; private final String field; private final PrefixCodedTerms termData; private final int termDataHashCode; // cached hashcode of termData
Creates a new TermInSetQuery from the given collection of terms.
/** * Creates a new {@link TermInSetQuery} from the given collection of terms. */
public TermInSetQuery(String field, Collection<BytesRef> terms) { BytesRef[] sortedTerms = terms.toArray(new BytesRef[0]); // already sorted if we are a SortedSet with natural order boolean sorted = terms instanceof SortedSet && ((SortedSet<BytesRef>)terms).comparator() == null; if (!sorted) { ArrayUtil.timSort(sortedTerms); } PrefixCodedTerms.Builder builder = new PrefixCodedTerms.Builder(); BytesRefBuilder previous = null; for (BytesRef term : sortedTerms) { if (previous == null) { previous = new BytesRefBuilder(); } else if (previous.get().equals(term)) { continue; // deduplicate } builder.add(field, term); previous.copyBytes(term); } this.field = field; termData = builder.finish(); termDataHashCode = termData.hashCode(); }
Creates a new TermInSetQuery from the given array of terms.
/** * Creates a new {@link TermInSetQuery} from the given array of terms. */
public TermInSetQuery(String field, BytesRef...terms) { this(field, Arrays.asList(terms)); } @Override public Query rewrite(IndexReader reader) throws IOException { final int threshold = Math.min(BOOLEAN_REWRITE_TERM_COUNT_THRESHOLD, BooleanQuery.getMaxClauseCount()); if (termData.size() <= threshold) { BooleanQuery.Builder bq = new BooleanQuery.Builder(); TermIterator iterator = termData.iterator(); for (BytesRef term = iterator.next(); term != null; term = iterator.next()) { bq.add(new TermQuery(new Term(iterator.field(), BytesRef.deepCopyOf(term))), Occur.SHOULD); } return new ConstantScoreQuery(bq.build()); } return super.rewrite(reader); } @Override public void visit(QueryVisitor visitor) { if (visitor.acceptField(field) == false) { return; } if (termData.size() == 1) { visitor.consumeTerms(this, new Term(field, termData.iterator().next())); } if (termData.size() > 1) { visitor.consumeTermsMatching(this, field, this::asByteRunAutomaton); } } private ByteRunAutomaton asByteRunAutomaton() { TermIterator iterator = termData.iterator(); List<Automaton> automata = new ArrayList<>(); for (BytesRef term = iterator.next(); term != null; term = iterator.next()) { automata.add(Automata.makeBinary(term)); } return new CompiledAutomaton(Operations.union(automata)).runAutomaton; } @Override public boolean equals(Object other) { return sameClassAs(other) && equalsTo(getClass().cast(other)); } private boolean equalsTo(TermInSetQuery other) { // no need to check 'field' explicitly since it is encoded in 'termData' // termData might be heavy to compare so check the hash code first return termDataHashCode == other.termDataHashCode && termData.equals(other.termData); } @Override public int hashCode() { return 31 * classHash() + termDataHashCode; }
Returns the terms wrapped in a PrefixCodedTerms.
/** Returns the terms wrapped in a PrefixCodedTerms. */
public PrefixCodedTerms getTermData() { return termData; } @Override public String toString(String defaultField) { StringBuilder builder = new StringBuilder(); builder.append(field); builder.append(":("); TermIterator iterator = termData.iterator(); boolean first = true; for (BytesRef term = iterator.next(); term != null; term = iterator.next()) { if (!first) { builder.append(' '); } first = false; builder.append(Term.toString(term)); } builder.append(')'); return builder.toString(); } @Override public long ramBytesUsed() { return BASE_RAM_BYTES_USED + termData.ramBytesUsed(); } @Override public Collection<Accountable> getChildResources() { return Collections.emptyList(); } private static class TermAndState { final String field; final TermsEnum termsEnum; final BytesRef term; final TermState state; final int docFreq; final long totalTermFreq; TermAndState(String field, TermsEnum termsEnum) throws IOException { this.field = field; this.termsEnum = termsEnum; this.term = BytesRef.deepCopyOf(termsEnum.term()); this.state = termsEnum.termState(); this.docFreq = termsEnum.docFreq(); this.totalTermFreq = termsEnum.totalTermFreq(); } } private static class WeightOrDocIdSet { final Weight weight; final DocIdSet set; WeightOrDocIdSet(Weight weight) { this.weight = Objects.requireNonNull(weight); this.set = null; } WeightOrDocIdSet(DocIdSet bitset) { this.set = bitset; this.weight = null; } } @Override public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float boost) throws IOException { return new ConstantScoreWeight(this, boost) { @Override public void extractTerms(Set<Term> terms) { // no-op // This query is for abuse cases when the number of terms is too high to // run efficiently as a BooleanQuery. 
So likewise we hide its terms in // order to protect highlighters } @Override public Matches matches(LeafReaderContext context, int doc) throws IOException { Terms terms = context.reader().terms(field); if (terms == null || terms.hasPositions() == false) { return super.matches(context, doc); } return MatchesUtils.forField(field, () -> DisjunctionMatchesIterator.fromTermsEnum(context, doc, getQuery(), field, termData.iterator())); }
On the given leaf context, try to either rewrite to a disjunction if there are few matching terms, or build a bitset containing matching docs.
/** * On the given leaf context, try to either rewrite to a disjunction if * there are few matching terms, or build a bitset containing matching docs. */
private WeightOrDocIdSet rewrite(LeafReaderContext context) throws IOException { final LeafReader reader = context.reader(); Terms terms = reader.terms(field); if (terms == null) { return null; } TermsEnum termsEnum = terms.iterator(); PostingsEnum docs = null; TermIterator iterator = termData.iterator(); // We will first try to collect up to 'threshold' terms into 'matchingTerms' // if there are two many terms, we will fall back to building the 'builder' final int threshold = Math.min(BOOLEAN_REWRITE_TERM_COUNT_THRESHOLD, BooleanQuery.getMaxClauseCount()); assert termData.size() > threshold : "Query should have been rewritten"; List<TermAndState> matchingTerms = new ArrayList<>(threshold); DocIdSetBuilder builder = null; for (BytesRef term = iterator.next(); term != null; term = iterator.next()) { assert field.equals(iterator.field()); if (termsEnum.seekExact(term)) { if (matchingTerms == null) { docs = termsEnum.postings(docs, PostingsEnum.NONE); builder.add(docs); } else if (matchingTerms.size() < threshold) { matchingTerms.add(new TermAndState(field, termsEnum)); } else { assert matchingTerms.size() == threshold; builder = new DocIdSetBuilder(reader.maxDoc(), terms); docs = termsEnum.postings(docs, PostingsEnum.NONE); builder.add(docs); for (TermAndState t : matchingTerms) { t.termsEnum.seekExact(t.term, t.state); docs = t.termsEnum.postings(docs, PostingsEnum.NONE); builder.add(docs); } matchingTerms = null; } } } if (matchingTerms != null) { assert builder == null; BooleanQuery.Builder bq = new BooleanQuery.Builder(); for (TermAndState t : matchingTerms) { final TermStates termStates = new TermStates(searcher.getTopReaderContext()); termStates.register(t.state, context.ord, t.docFreq, t.totalTermFreq); bq.add(new TermQuery(new Term(t.field, t.term), termStates), Occur.SHOULD); } Query q = new ConstantScoreQuery(bq.build()); final Weight weight = searcher.rewrite(q).createWeight(searcher, scoreMode, score()); return new WeightOrDocIdSet(weight); } else { 
assert builder != null; return new WeightOrDocIdSet(builder.build()); } } private Scorer scorer(DocIdSet set) throws IOException { if (set == null) { return null; } final DocIdSetIterator disi = set.iterator(); if (disi == null) { return null; } return new ConstantScoreScorer(this, score(), scoreMode, disi); } @Override public BulkScorer bulkScorer(LeafReaderContext context) throws IOException { final WeightOrDocIdSet weightOrBitSet = rewrite(context); if (weightOrBitSet == null) { return null; } else if (weightOrBitSet.weight != null) { return weightOrBitSet.weight.bulkScorer(context); } else { final Scorer scorer = scorer(weightOrBitSet.set); if (scorer == null) { return null; } return new DefaultBulkScorer(scorer); } } @Override public Scorer scorer(LeafReaderContext context) throws IOException { final WeightOrDocIdSet weightOrBitSet = rewrite(context); if (weightOrBitSet == null) { return null; } else if (weightOrBitSet.weight != null) { return weightOrBitSet.weight.scorer(context); } else { return scorer(weightOrBitSet.set); } } @Override public boolean isCacheable(LeafReaderContext ctx) { // Only cache instances that have a reasonable size. Otherwise it might cause memory issues // with the query cache if most memory ends up being spent on queries rather than doc id sets. return ramBytesUsed() <= RamUsageEstimator.QUERY_DEFAULT_RAM_BYTES_USED; } }; } }