package org.apache.lucene.search;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.PostingsEnum;
import org.apache.lucene.index.SlowImpactsEnum;
import org.apache.lucene.index.Term;
import org.apache.lucene.index.TermState;
import org.apache.lucene.index.TermStates;
import org.apache.lucene.index.Terms;
import org.apache.lucene.index.TermsEnum;
import org.apache.lucene.search.similarities.Similarity;
import org.apache.lucene.util.ArrayUtil;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.mutable.MutableValueBool;
public class PhraseWildcardQuery extends Query {
/** Query returned by {@link #rewrite} when this query has no phrase terms. */
protected static final Query NO_MATCH_QUERY = new MatchNoDocsQuery("Empty " + PhraseWildcardQuery.class.getSimpleName());
/** Name of the field this query runs on. */
protected final String field;
/** Terms of the phrase (single terms and multi-terms), in the order they were given. */
protected final List<PhraseTerm> phraseTerms;
/** Phrase slop; 0 selects the exact phrase matcher, any other value the sloppy one. */
protected final int slop;
/** Total number of expansions shared between all the multi-terms of the phrase. */
protected final int maxMultiTermExpansions;
/** Value returned by {@link #shouldOptimizeSegments()}. */
protected final boolean segmentOptimizationEnabled;
/**
 * Creates the query from already validated values; the arguments are stored as given,
 * without copy or check.
 *
 * @param field                      name of the field the phrase is searched in.
 * @param phraseTerms                terms of the phrase.
 * @param slop                       phrase slop.
 * @param maxMultiTermExpansions     expansion budget shared by all the multi-terms.
 * @param segmentOptimizationEnabled value returned by {@link #shouldOptimizeSegments()}.
 */
protected PhraseWildcardQuery(
String field,
List<PhraseTerm> phraseTerms,
int slop,
int maxMultiTermExpansions,
boolean segmentOptimizationEnabled) {
this.field = field;
this.phraseTerms = phraseTerms;
this.slop = slop;
this.maxMultiTermExpansions = maxMultiTermExpansions;
this.segmentOptimizationEnabled = segmentOptimizationEnabled;
}
/** Returns the name of the field this query was created with. */
public String getField() {
return field;
}
/**
 * Simplifies the degenerate phrases: an empty phrase rewrites to {@link #NO_MATCH_QUERY},
 * a one-term phrase rewrites to the query of that term, and any longer phrase is left
 * to the default rewrite.
 */
@Override
public Query rewrite(IndexReader reader) throws IOException {
  switch (phraseTerms.size()) {
    case 0:
      return NO_MATCH_QUERY;
    case 1:
      return phraseTerms.get(0).getQuery();
    default:
      return super.rewrite(reader);
  }
}
/**
 * Visits the query of every phrase term through a {@code MUST} sub-visitor.
 * Nothing is visited when the visitor does not accept the field of this query.
 */
@Override
public void visit(QueryVisitor visitor) {
  if (visitor.acceptField(field)) {
    QueryVisitor subVisitor = visitor.getSubVisitor(BooleanClause.Occur.MUST, this);
    for (int i = 0, size = phraseTerms.size(); i < size; i++) {
      phraseTerms.get(i).getQuery().visit(subVisitor);
    }
  }
}
/**
 * Collects the term data of every phrase term, then builds the phrase weight.
 * Single terms are analyzed first, multi-terms second; as soon as one term matches
 * nothing the analysis stops and a no-match weight is returned.
 */
@Override
public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float boost) throws IOException {
  // Work on a copy of the segment list ordered by the comparator (terms size).
  List<LeafReaderContext> segments =
      new SegmentTermsSizeComparator().createTermsSizeSortedCopyOf(searcher.getIndexReader().leaves());
  TermsData termsData = createTermsData(segments.size());

  // First pass: analyze the single terms and count the multi-terms.
  int multiTermCount = 0;
  for (PhraseTerm phraseTerm : phraseTerms) {
    if (phraseTerm.hasExpansions()) {
      multiTermCount++;
      continue;
    }
    assert TestCounters.get().incSingleTermAnalysisCount();
    if (phraseTerm.collectTermData(this, searcher, segments, termsData) == 0) {
      return earlyStopWeight();
    }
  }

  // Second pass: analyze the multi-terms. Each one gets an equal share of the
  // expansions still available, so what a term does not use goes to the next ones.
  int expansionBudget = maxMultiTermExpansions;
  int pendingMultiTerms = multiTermCount;
  for (PhraseTerm phraseTerm : phraseTerms) {
    if (!phraseTerm.hasExpansions()) {
      continue;
    }
    assert TestCounters.get().incMultiTermAnalysisCount();
    assert expansionBudget >= 0 && expansionBudget <= maxMultiTermExpansions;
    assert pendingMultiTerms > 0;
    int maxExpansionsForTerm = expansionBudget / pendingMultiTerms;
    int numExpansions = phraseTerm.collectTermData(
        this, searcher, segments, pendingMultiTerms, maxExpansionsForTerm, termsData);
    assert numExpansions >= 0 && numExpansions <= maxExpansionsForTerm;
    if (numExpansions == 0) {
      return earlyStopWeight();
    }
    expansionBudget -= numExpansions;
    pendingMultiTerms--;
  }
  assert pendingMultiTerms == 0;
  assert expansionBudget >= 0;

  if (termsData.areAllTermsMatching()) {
    return createPhraseWeight(searcher, scoreMode, boost, termsData);
  }
  return noMatchWeight();
}
/**
 * Creates the container of the term data collected for this query, sized for the
 * number of phrase terms and the given number of segments.
 */
protected TermsData createTermsData(int numSegments) {
return new TermsData(phraseTerms.size(), numSegments);
}
/**
 * Returns the no-match weight used when the term analysis stops early.
 * The early stop is recorded in the test counters, only when assertions are enabled.
 */
protected Weight earlyStopWeight() {
assert TestCounters.get().incQueryEarlyStopCount();
return noMatchWeight();
}
/**
 * Creates a weight that has no scorer for any segment (the scorer is always null)
 * and that reports itself as cacheable for every segment.
 */
protected Weight noMatchWeight() {
return new ConstantScoreWeight(this, 0) {
@Override
public Scorer scorer(LeafReaderContext leafReaderContext) {
// No scorer: nothing matches in any segment.
return null;
}
@Override
public boolean isCacheable(LeafReaderContext ctx) {
return true;
}
};
}
PhraseWeight createPhraseWeight(IndexSearcher searcher, ScoreMode scoreMode,
float boost, TermsData termsData) throws IOException {
return new PhraseWeight(this, field, searcher, scoreMode) {
@Override
protected Similarity.SimScorer getStats(IndexSearcher searcher) throws IOException {
if (termsData.termStatsList.isEmpty()) {
return null;
}
return searcher.getSimilarity().scorer(
boost,
searcher.collectionStatistics(field),
termsData.termStatsList.toArray(new TermStatistics[0]));
}
@Override
protected PhraseMatcher getPhraseMatcher(LeafReaderContext leafReaderContext, Similarity.SimScorer scorer, boolean exposeOffsets) throws IOException {
Terms fieldTerms = leafReaderContext.reader().terms(field);
if (fieldTerms == null) {
return null;
}
TermsEnum termsEnum = fieldTerms.iterator();
float totalMatchCost = 0;
PhraseQuery.PostingsAndFreq[] postingsFreqs = new PhraseQuery.PostingsAndFreq[phraseTerms.size()];
for (int termPosition = 0; termPosition < postingsFreqs.length; termPosition++) {
TermData termData = termsData.getTermData(termPosition);
assert termData != null;
List<TermBytesTermState> termStates = termData.getTermStatesForSegment(leafReaderContext);
if (termStates == null) {
return null;
}
assert !termStates.isEmpty();
List<PostingsEnum> postingsEnums = new ArrayList<>(termStates.size());
for (TermBytesTermState termBytesTermState : termStates) {
termsEnum.seekExact(termBytesTermState.termBytes, termBytesTermState.termState);
postingsEnums.add(termsEnum.postings(null, exposeOffsets ? PostingsEnum.ALL : PostingsEnum.POSITIONS));
totalMatchCost += PhraseQuery.termPositionsCost(termsEnum);
}
PostingsEnum unionPostingsEnum;
if (postingsEnums.size() == 1) {
unionPostingsEnum = postingsEnums.get(0);
} else {
unionPostingsEnum = exposeOffsets ? new MultiPhraseQuery.UnionFullPostingsEnum(postingsEnums) : new MultiPhraseQuery.UnionPostingsEnum(postingsEnums);
}
postingsFreqs[termPosition] = new PhraseQuery.PostingsAndFreq(unionPostingsEnum, new SlowImpactsEnum(unionPostingsEnum), termPosition, termData.terms);
}
if (slop == 0) {
ArrayUtil.timSort(postingsFreqs);
return new ExactPhraseMatcher(postingsFreqs, scoreMode, scorer, totalMatchCost);
} else {
return new SloppyPhraseMatcher(postingsFreqs, slop, scoreMode, scorer, totalMatchCost, exposeOffsets);
}
}
@Override
public void (Set<Term> terms) {
for (int i = 0, size = phraseTerms.size(); i < size; i++) {
terms.addAll(termsData.getTermData(i).terms);
}
}
};
}
/** Two phrase wildcard queries are equal when they have the same slop and equal phrase terms. */
@Override
public boolean equals(Object o) {
  if (o instanceof PhraseWildcardQuery) {
    PhraseWildcardQuery other = (PhraseWildcardQuery) o;
    return slop == other.slop && phraseTerms.equals(other.phraseTerms);
  }
  return false;
}
/** Combines, with XOR, the class hash, the slop and the hash code of the phrase terms. */
@Override
public int hashCode() {
return classHash() ^ slop ^ phraseTerms.hashCode();
}
/**
 * Formats the query as {@code phraseWildcard(field:"term1 term2"~slop)}.
 * The field prefix is left out when the field equals {@code omittedField}, and the
 * slop suffix is left out when the slop is 0.
 */
@Override
public final String toString(String omittedField) {
  StringBuilder sb = new StringBuilder("phraseWildcard(");
  boolean fieldOmitted = field != null && field.equals(omittedField);
  if (!fieldOmitted) {
    sb.append(field).append(':');
  }
  sb.append('"');
  boolean first = true;
  for (PhraseTerm phraseTerm : phraseTerms) {
    if (first) {
      first = false;
    } else {
      sb.append(' ');
    }
    phraseTerm.toString(sb);
  }
  sb.append('"');
  if (slop != 0) {
    sb.append('~').append(slop);
  }
  return sb.append(')').toString();
}
/**
 * Collects, segment by segment, the term state of a single term and registers its
 * term statistics.
 * When segment optimization is enabled, the segments where the term does not match
 * are removed from {@code segments}.
 *
 * @return the number of segments where the term matches.
 */
protected int collectSingleTermData(
    SingleTerm singleTerm,
    IndexSearcher searcher,
    List<LeafReaderContext> segments,
    TermsData termsData) throws IOException {
  TermData termData = termsData.getOrCreateTermData(singleTerm.termPosition);
  Term term = singleTerm.term;
  termData.terms.add(term);
  TermStates termStates = TermStates.build(searcher.getIndexReader().getContext(), term, true);

  int matchingSegments = 0;
  for (Iterator<LeafReaderContext> it = segments.iterator(); it.hasNext(); ) {
    LeafReaderContext segment = it.next();
    assert TestCounters.get().incSegmentUseCount();

    // The term state stays null when the segment has no terms for the field
    // or when the term is absent from the segment.
    TermState termState = null;
    Terms segmentTerms = segment.reader().terms(term.field());
    if (segmentTerms != null) {
      checkTermsHavePositions(segmentTerms);
      termState = termStates.get(segment);
    }

    if (termState != null) {
      matchingSegments++;
      termData.setTermStatesForSegment(
          segment, Collections.singletonList(new TermBytesTermState(term.bytes(), termState)));
    } else if (shouldOptimizeSegments()) {
      // No match here: the following terms do not need to look at this segment.
      it.remove();
      assert TestCounters.get().incSegmentSkipCount();
    }
  }

  int docFreq = termStates.docFreq();
  if (docFreq > 0) {
    termsData.termStatsList.add(searcher.termStatistics(term, docFreq, termStates.totalTermFreq()));
  }
  return matchingSegments;
}
/**
 * Collects, segment by segment, the expansions of a multi-term, up to
 * {@code maxExpansionsForTerm}, then registers the statistics of the expanded terms.
 * When segment optimization is enabled, the segments where the multi-term has no
 * expansion are removed from {@code segments}.
 *
 * @return the number of expansions collected over the visited segments.
 */
protected int collectMultiTermData(
    MultiTerm multiTerm,
    IndexSearcher searcher,
    List<LeafReaderContext> segments,
    int remainingMultiTerms,
    int maxExpansionsForTerm,
    TermsData termsData) throws IOException {
  TermData termData = termsData.getOrCreateTermData(multiTerm.termPosition);
  Map<BytesRef, TermStats> termStatsMap = createTermStatsMap(multiTerm);
  MutableValueBool stopIteration = new MutableValueBool();
  int collected = 0;

  for (Iterator<LeafReaderContext> it = segments.iterator(); it.hasNext() && !stopIteration.value; ) {
    LeafReaderContext segment = it.next();
    int budget = maxExpansionsForTerm - collected;
    assert budget >= 0;
    List<TermBytesTermState> segmentTermStates =
        collectMultiTermDataForSegment(multiTerm, segment, budget, stopIteration, termStatsMap);

    if (segmentTermStates.isEmpty()) {
      if (shouldOptimizeSegments()) {
        // No expansion here: the following terms do not need to look at this segment.
        it.remove();
        assert TestCounters.get().incSegmentSkipCount();
      }
    } else {
      assert segmentTermStates.size() <= budget;
      collected += segmentTermStates.size();
      assert collected <= maxExpansionsForTerm;
      termData.setTermStatesForSegment(segment, segmentTermStates);
    }
  }

  collectMultiTermStats(searcher, termStatsMap, termsData, termData);
  return collected;
}
/**
 * Tells whether a segment where a term does not match is removed from the segments
 * analyzed afterwards. Returns the flag given at construction.
 */
protected boolean shouldOptimizeSegments() {
return segmentOptimizationEnabled;
}
/**
 * Creates the map that accumulates, per term bytes, the statistics of the expansions
 * of a multi-term. This implementation returns a new empty {@link HashMap} and does
 * not use its argument.
 */
protected Map<BytesRef, TermStats> createTermStatsMap(MultiTerm multiTerm) {
return new HashMap<>();
}
/**
 * Collects the expansions of a multi-term in one segment, up to
 * {@code remainingExpansions}, and accumulates their statistics in {@code termStatsMap}.
 *
 * @param multiTerm                  the multi-term to expand.
 * @param leafReaderContext          the segment to expand the multi-term in.
 * @param remainingExpansions        maximum number of expansions to collect in this segment.
 * @param shouldStopSegmentIteration set to true when the expansion budget is exhausted at
 *                                   the end of this call; left untouched when no terms
 *                                   enum could be created for the segment.
 * @param termStatsMap               statistics per term bytes, shared between the segments.
 * @return the term states collected in this segment; empty when there is none.
 */
protected List<TermBytesTermState> collectMultiTermDataForSegment(
    MultiTerm multiTerm,
    LeafReaderContext leafReaderContext,
    int remainingExpansions,
    MutableValueBool shouldStopSegmentIteration,
    Map<BytesRef, TermStats> termStatsMap) throws IOException {
  TermsEnum termsEnum = createTermsEnum(multiTerm, leafReaderContext);
  if (termsEnum == null) {
    return Collections.emptyList();
  }
  assert TestCounters.get().incSegmentUseCount();
  List<TermBytesTermState> termStates = new ArrayList<>();
  // The budget is tested first, so the enum is never advanced once the budget is
  // exhausted: advancing it would be done for a term that is then thrown away.
  while (remainingExpansions > 0 && termsEnum.next() != null) {
    TermStats termStats = termStatsMap.get(termsEnum.term());
    if (termStats == null) {
      // The enum may reuse its BytesRef, so the map key must be a private copy.
      BytesRef termBytes = BytesRef.deepCopyOf(termsEnum.term());
      termStats = new TermStats(termBytes);
      termStatsMap.put(termBytes, termStats);
    }
    termStats.addStats(termsEnum.docFreq(), termsEnum.totalTermFreq());
    termStates.add(new TermBytesTermState(termStats.termBytes, termsEnum.termState()));
    remainingExpansions--;
    assert TestCounters.get().incExpansionCount();
  }
  assert remainingExpansions >= 0;
  shouldStopSegmentIteration.value = remainingExpansions == 0;
  return termStates;
}
/**
 * Creates the enum over the terms of a segment that the multi-term query accepts.
 *
 * @return the terms enum, or null when the segment has no terms for the field.
 * @throws IllegalStateException when the field was indexed without positions.
 */
protected TermsEnum createTermsEnum(MultiTerm multiTerm, LeafReaderContext leafReaderContext) throws IOException {
  Terms segmentTerms = leafReaderContext.reader().terms(field);
  if (segmentTerms != null) {
    checkTermsHavePositions(segmentTerms);
    TermsEnum expansions = multiTerm.query.getTermsEnum(segmentTerms);
    assert expansions != null;
    return expansions;
  }
  return null;
}
/**
 * Registers every expanded term of a multi-term in {@code termData}, and adds its
 * term statistics to {@code termsData} when its document frequency is positive.
 */
protected void collectMultiTermStats(
    IndexSearcher searcher,
    Map<BytesRef, TermStats> termStatsMap,
    TermsData termsData,
    TermData termData) throws IOException {
  for (Map.Entry<BytesRef, TermStats> entry : termStatsMap.entrySet()) {
    TermStats stats = entry.getValue();
    Term expandedTerm = new Term(field, entry.getKey());
    termData.terms.add(expandedTerm);
    if (stats.docFreq <= 0) {
      continue;
    }
    termsData.termStatsList.add(
        searcher.termStatistics(expandedTerm, stats.docFreq, stats.totalTermFreq));
  }
}
/**
 * Verifies that the given terms were indexed with positions.
 *
 * @throws IllegalStateException when the terms have no positions.
 */
protected void checkTermsHavePositions(Terms terms) {
  if (terms.hasPositions()) {
    return;
  }
  String message = "field \"" + field + "\" was indexed without position data;"
      + " cannot run " + PhraseWildcardQuery.class.getSimpleName();
  throw new IllegalStateException(message);
}
/**
 * Builds a {@link PhraseWildcardQuery}: terms are added in phrase order, each one
 * taking the next position.
 */
public static class Builder {

  protected final String field;
  protected final List<PhraseTerm> phraseTerms;
  protected int slop;
  protected final int maxMultiTermExpansions;
  protected final boolean segmentOptimizationEnabled;

  /**
   * Creates a builder with segment optimization enabled.
   *
   * @param field                  name of the field of the query.
   * @param maxMultiTermExpansions expansion budget shared by all the multi-terms.
   */
  public Builder(String field, int maxMultiTermExpansions) {
    this(field, maxMultiTermExpansions, true);
  }

  /**
   * @param field                      name of the field of the query.
   * @param maxMultiTermExpansions     expansion budget shared by all the multi-terms.
   * @param segmentOptimizationEnabled segment optimization flag passed to the query.
   */
  public Builder(String field, int maxMultiTermExpansions, boolean segmentOptimizationEnabled) {
    this.field = field;
    this.maxMultiTermExpansions = maxMultiTermExpansions;
    this.segmentOptimizationEnabled = segmentOptimizationEnabled;
    phraseTerms = new ArrayList<>();
  }

  /** Adds a single term, given by its bytes, at the next position of the phrase. */
  public Builder addTerm(BytesRef termBytes) {
    return addTerm(new Term(field, termBytes));
  }

  /**
   * Adds a single term at the next position of the phrase.
   *
   * @throws IllegalArgumentException when the field of the term is not the field of this builder.
   */
  public Builder addTerm(Term term) {
    if (!term.field().equals(field)) {
      throw new IllegalArgumentException(term.getClass().getSimpleName()
          + " field \"" + term.field() + "\" cannot be different from the "
          + PhraseWildcardQuery.class.getSimpleName() + " field \"" + field + "\"");
    }
    phraseTerms.add(new SingleTerm(term, phraseTerms.size()));
    return this;
  }

  /**
   * Adds a multi-term at the next position of the phrase.
   *
   * @throws IllegalArgumentException when the field of the multi-term query is not the
   *                                  field of this builder.
   */
  public Builder addMultiTerm(MultiTermQuery multiTermQuery) {
    if (!multiTermQuery.getField().equals(field)) {
      throw new IllegalArgumentException(multiTermQuery.getClass().getSimpleName()
          + " field \"" + multiTermQuery.getField() + "\" cannot be different from the "
          + PhraseWildcardQuery.class.getSimpleName() + " field \"" + field + "\"");
    }
    phraseTerms.add(new MultiTerm(multiTermQuery, phraseTerms.size()));
    return this;
  }

  /**
   * Sets the phrase slop.
   *
   * @throws IllegalArgumentException when the slop is negative.
   */
  public Builder setSlop(int slop) {
    if (slop < 0) {
      throw new IllegalArgumentException("slop value cannot be negative");
    }
    this.slop = slop;
    return this;
  }

  /**
   * Creates the query from the terms added so far.
   * The query receives its own copy of the term list, so adding terms to this builder
   * afterwards does not change a query that was already built.
   */
  public PhraseWildcardQuery build() {
    return new PhraseWildcardQuery(
        field, new ArrayList<>(phraseTerms), slop, maxMultiTermExpansions, segmentOptimizationEnabled);
  }
}
/**
 * One element of the phrase, identified by its position in the phrase.
 * Subclasses say whether the element expands to several terms and how its term data
 * is collected.
 */
protected abstract static class PhraseTerm {
/** Position of this element in the phrase. */
protected final int termPosition;
protected PhraseTerm(int termPosition) {
this.termPosition = termPosition;
}
/** Tells whether this element is a multi-term; createWeight uses it to pick the collect method. */
protected abstract boolean hasExpansions();
/** Returns the query form of this element; used by rewrite for a one-term phrase and by visit. */
protected abstract Query getQuery();
/**
 * Collects the term data of an element without expansions.
 * This default implementation always throws; elements without expansions override it.
 *
 * @throws UnsupportedOperationException always, unless overridden.
 */
protected int collectTermData(
PhraseWildcardQuery query,
IndexSearcher searcher,
List<LeafReaderContext> segments,
TermsData termsData) throws IOException {
throw new UnsupportedOperationException();
}
/**
 * Collects the term data of this element, with the expansion limits computed by the
 * caller. The caller treats a returned value of 0 as "no match for this element".
 */
protected abstract int collectTermData(
PhraseWildcardQuery query,
IndexSearcher searcher,
List<LeafReaderContext> segments,
int remainingMultiTerms,
int maxExpansionsForTerm,
TermsData termsData) throws IOException;
/** Appends the text form of this element to the given builder. */
protected abstract void toString(StringBuilder builder);
@Override
public abstract boolean equals(Object o);
@Override
public abstract int hashCode();
}
/** Phrase element made of exactly one {@link Term}; it has no expansions. */
protected static class SingleTerm extends PhraseTerm {

  protected final Term term;

  protected SingleTerm(Term term, int termPosition) {
    super(termPosition);
    this.term = term;
  }

  @Override
  protected boolean hasExpansions() {
    return false;
  }

  /** Returns a new {@link TermQuery} on the term. */
  @Override
  protected Query getQuery() {
    return new TermQuery(term);
  }

  /** Delegates to the six-argument variant, with both expansion limits set to 0. */
  @Override
  protected int collectTermData(
      PhraseWildcardQuery query,
      IndexSearcher searcher,
      List<LeafReaderContext> segments,
      TermsData termsData) throws IOException {
    final int noMultiTerms = 0;
    final int noExpansions = 0;
    return collectTermData(query, searcher, segments, noMultiTerms, noExpansions, termsData);
  }

  /** Lets the query collect the single term data; the expansion limits are not used. */
  @Override
  protected int collectTermData(
      PhraseWildcardQuery query,
      IndexSearcher searcher,
      List<LeafReaderContext> segments,
      int remainingMultiTerms,
      int maxExpansionsForTerm,
      TermsData termsData) throws IOException {
    return query.collectSingleTermData(this, searcher, segments, termsData);
  }

  /** Appends the text of the term. */
  @Override
  protected void toString(StringBuilder builder) {
    builder.append(term.text());
  }

  /** Single terms are equal when their terms are equal; the position is not compared. */
  @Override
  public boolean equals(Object o) {
    return o instanceof SingleTerm && term.equals(((SingleTerm) o).term);
  }

  @Override
  public int hashCode() {
    return term.hashCode();
  }
}
/** Phrase element backed by a {@link MultiTermQuery}; it expands to the terms the query accepts. */
protected static class MultiTerm extends PhraseTerm {

  protected final MultiTermQuery query;

  protected MultiTerm(MultiTermQuery query, int termPosition) {
    super(termPosition);
    this.query = query;
  }

  @Override
  protected boolean hasExpansions() {
    return true;
  }

  /** Returns the multi-term query itself. */
  @Override
  protected Query getQuery() {
    return query;
  }

  /** Lets the phrase query collect the expansions of this multi-term. */
  @Override
  protected int collectTermData(
      PhraseWildcardQuery query,
      IndexSearcher searcher,
      List<LeafReaderContext> segments,
      int remainingMultiTerms,
      int maxExpansionsForTerm,
      TermsData termsData) throws IOException {
    return query.collectMultiTermData(
        this, searcher, segments, remainingMultiTerms, maxExpansionsForTerm, termsData);
  }

  /** Appends the text of the multi-term query, formatted with its own field as omitted field. */
  @Override
  protected void toString(StringBuilder builder) {
    String queryText = query.toString(query.field);
    builder.append(queryText);
  }

  /** Multi-terms are equal when their queries are equal; the position is not compared. */
  @Override
  public boolean equals(Object o) {
    return o instanceof MultiTerm && query.equals(((MultiTerm) o).query);
  }

  @Override
  public int hashCode() {
    return query.hashCode();
  }
}
/**
 * Data collected for all the terms of the phrase: one {@link TermData} per phrase
 * position, plus the statistics of every collected term.
 */
protected static class TermsData {

  protected final int numTerms;
  protected final int numSegments;
  protected final List<TermStatistics> termStatsList;
  protected final TermData[] termDataPerPosition;
  /** Number of phrase positions that received term states for at least one segment. */
  protected int numTermsMatching;

  protected TermsData(int numTerms, int numSegments) {
    this.numTerms = numTerms;
    this.numSegments = numSegments;
    termStatsList = new ArrayList<>();
    termDataPerPosition = new TermData[numTerms];
  }

  /** Returns the term data of a phrase position, creating it on first access. */
  protected TermData getOrCreateTermData(int termPosition) {
    if (termDataPerPosition[termPosition] == null) {
      termDataPerPosition[termPosition] = new TermData(numSegments, this);
    }
    return termDataPerPosition[termPosition];
  }

  /** Returns the term data of a phrase position, or null when none was created. */
  protected TermData getTermData(int termPosition) {
    return termDataPerPosition[termPosition];
  }

  /** Tells whether every phrase position has term states in at least one segment. */
  protected boolean areAllTermsMatching() {
    assert numTermsMatching <= numTerms;
    return numTermsMatching == numTerms;
  }

  @Override
  public String toString() {
    StringBuilder sb = new StringBuilder("TermsData(")
        .append("numSegments=").append(numSegments)
        .append(", termDataPerPosition=").append(Arrays.asList(termDataPerPosition))
        .append(", termsStatsList=[");
    for (TermStatistics stats : termStatsList) {
      sb.append("{").append(stats.term().utf8ToString());
      sb.append(", ").append(stats.docFreq());
      sb.append(", ").append(stats.totalTermFreq());
      sb.append("}");
    }
    return sb.append("]").append(")").toString();
  }
}
/**
 * Data collected for one phrase position: the terms at that position and, per segment
 * ordinal, the term states found in the segment.
 */
protected static class TermData {

  protected final int numSegments;
  protected final TermsData termsData;
  /** Indexed by segment ordinal; stays null until term states are set for a first segment. */
  protected List<TermBytesTermState>[] termStatesPerSegment;
  protected final List<Term> terms;

  protected TermData(int numSegments, TermsData termsData) {
    this.numSegments = numSegments;
    this.termsData = termsData;
    terms = new ArrayList<>();
  }

  /**
   * Stores the term states of a segment. On the first call the per-segment array is
   * allocated and the position is counted as matching in the enclosing terms data.
   */
  @SuppressWarnings("unchecked")
  protected void setTermStatesForSegment(LeafReaderContext leafReaderContext, List<TermBytesTermState> termStates) {
    List<TermBytesTermState>[] perSegment = termStatesPerSegment;
    if (perSegment == null) {
      perSegment = (List<TermBytesTermState>[]) new List[numSegments];
      termStatesPerSegment = perSegment;
      termsData.numTermsMatching++;
    }
    perSegment[leafReaderContext.ord] = termStates;
  }

  /** Returns the term states stored for a segment, or null when none was stored for it. */
  protected List<TermBytesTermState> getTermStatesForSegment(LeafReaderContext leafReaderContext) {
    assert termStatesPerSegment != null : "No TermState for any segment; the query should have been stopped before";
    return termStatesPerSegment[leafReaderContext.ord];
  }

  @Override
  public String toString() {
    Object termStatesText = termStatesPerSegment == null ? "null" : Arrays.asList(termStatesPerSegment);
    return new StringBuilder("TermData(")
        .append("termStates=").append(termStatesText)
        .append(", terms=").append(terms)
        .append(")")
        .toString();
  }
}
/** Pair made of the bytes of a term and the state of that term in one segment. */
public static class TermBytesTermState {
protected final BytesRef termBytes;
protected final TermState termState;
public TermBytesTermState(BytesRef termBytes, TermState termState) {
this.termBytes = termBytes;
this.termState = termState;
}
/** Formats the pair as {@code "term"->termState}, decoding the term bytes as UTF-8. */
@Override
public String toString() {
return "\"" + termBytes.utf8ToString() + "\"->" + termState;
}
}
/** Statistics of one term, accumulated over the segments where the term was found. */
public static class TermStats {

  protected final BytesRef termBytes;
  protected int docFreq;
  /** Sum of the total term frequencies, or -1 once any of the added values was negative. */
  protected long totalTermFreq;

  protected TermStats(BytesRef termBytes) {
    this.termBytes = termBytes;
  }

  public BytesRef getTermBytes() {
    return termBytes;
  }

  /**
   * Adds the statistics of the term in one more segment. A negative total term
   * frequency, added now or earlier, turns the accumulated value into -1.
   */
  protected void addStats(int docFreq, long totalTermFreq) {
    this.docFreq += docFreq;
    if (this.totalTermFreq < 0 || totalTermFreq < 0) {
      this.totalTermFreq = -1;
    } else {
      this.totalTermFreq += totalTermFreq;
    }
  }
}
/**
 * Orders the segments by increasing size of their terms for the field of the query.
 * A segment without terms for the field counts as size 0.
 */
protected class SegmentTermsSizeComparator implements Comparator<LeafReaderContext> {

  /** Marks the runtime exceptions that wrap an IOException raised during the comparison. */
  private static final String COMPARISON_ERROR_MESSAGE = "Segment comparison error";

  /** Compares the terms sizes; an IOException is wrapped, since compare cannot throw it. */
  @Override
  public int compare(LeafReaderContext leafReaderContext1, LeafReaderContext leafReaderContext2) {
    long size1;
    long size2;
    try {
      size1 = getTermsSize(leafReaderContext1);
      size2 = getTermsSize(leafReaderContext2);
    } catch (IOException e) {
      throw new RuntimeException(COMPARISON_ERROR_MESSAGE, e);
    }
    return Long.compare(size1, size2);
  }

  /**
   * Returns a sorted copy of the given segments; the given list is left unchanged.
   * An IOException wrapped by {@link #compare} during the sort is unwrapped and rethrown.
   */
  protected List<LeafReaderContext> createTermsSizeSortedCopyOf(List<LeafReaderContext> segments) throws IOException {
    List<LeafReaderContext> sorted = new ArrayList<>(segments);
    try {
      sorted.sort(this);
    } catch (RuntimeException e) {
      if (!COMPARISON_ERROR_MESSAGE.equals(e.getMessage())) {
        throw e;
      }
      throw (IOException) e.getCause();
    }
    return sorted;
  }

  private long getTermsSize(LeafReaderContext leafReaderContext) throws IOException {
    Terms terms = leafReaderContext.reader().terms(field);
    if (terms == null) {
      return 0;
    }
    return terms.size();
  }
}
/**
 * Counters shared through a single instance and incremented from assert statements.
 * Every increment method returns true so that it can be used as the condition of an
 * assert; the counters therefore only move when assertions are enabled.
 */
protected static class TestCounters {

  private static final TestCounters SINGLETON = new TestCounters();

  protected long singleTermAnalysisCount;
  protected long multiTermAnalysisCount;
  protected long expansionCount;
  protected long segmentUseCount;
  protected long segmentSkipCount;
  protected long queryEarlyStopCount;

  /** Returns the shared instance. */
  protected static TestCounters get() {
    return SINGLETON;
  }

  protected boolean incSingleTermAnalysisCount() {
    singleTermAnalysisCount += 1;
    return true;
  }

  protected boolean incMultiTermAnalysisCount() {
    multiTermAnalysisCount += 1;
    return true;
  }

  protected boolean incExpansionCount() {
    expansionCount += 1;
    return true;
  }

  protected boolean incSegmentUseCount() {
    segmentUseCount += 1;
    return true;
  }

  protected boolean incSegmentSkipCount() {
    segmentSkipCount += 1;
    return true;
  }

  protected boolean incQueryEarlyStopCount() {
    queryEarlyStopCount += 1;
    return true;
  }

  /** Resets every counter to 0. */
  protected void clear() {
    singleTermAnalysisCount = multiTermAnalysisCount = expansionCount = 0;
    segmentUseCount = segmentSkipCount = queryEarlyStopCount = 0;
  }
}
}