package org.apache.lucene.index;
import java.io.IOException;
import org.apache.lucene.index.FilterLeafReader.FilterTerms;
import org.apache.lucene.index.FilterLeafReader.FilterTermsEnum;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.automaton.CompiledAutomaton;
/**
 * A {@link FilterDirectoryReader} that wraps every leaf in an {@link ExitableFilterAtomicReader},
 * so that iteration over terms, point values and doc values can be aborted.
 *
 * <p>The wrapped iterators consult the supplied {@link QueryTimeout}. Once {@link
 * QueryTimeout#shouldExit()} returns {@code true}, or the current thread is found to be
 * interrupted, an {@link ExitingReaderException} is thrown.
 */
public class ExitableDirectoryReader extends FilterDirectoryReader {

  private final QueryTimeout queryTimeout;

  /** Thrown when iteration over index data is aborted because of a timeout or an interrupt. */
  @SuppressWarnings("serial")
  public static class ExitingReaderException extends RuntimeException {

    /**
     * Creates a new exception.
     *
     * @param msg description of why iteration was aborted
     */
    public ExitingReaderException(String msg) {
      super(msg);
    }
  }

  /** {@link SubReaderWrapper} that wraps each leaf reader in an {@link ExitableFilterAtomicReader}. */
  public static class ExitableSubReaderWrapper extends SubReaderWrapper {
    private final QueryTimeout queryTimeout;

    /**
     * Creates a wrapper that applies the given timeout to every leaf it wraps.
     *
     * @param queryTimeout timeout handed to each {@link ExitableFilterAtomicReader}
     */
    public ExitableSubReaderWrapper(QueryTimeout queryTimeout) {
      this.queryTimeout = queryTimeout;
    }

    @Override
    public LeafReader wrap(LeafReader reader) {
      return new ExitableFilterAtomicReader(reader, queryTimeout);
    }
  }

  /**
   * Leaf reader whose terms, point values and doc values check the {@link QueryTimeout} while
   * they are iterated.
   *
   * <p>Wrapping only happens when {@link QueryTimeout#isTimeoutEnabled()} returns {@code true};
   * otherwise the instances from the underlying reader are returned unchanged.
   */
  public static class ExitableFilterAtomicReader extends FilterLeafReader {

    private final QueryTimeout queryTimeout;

    /** Number of doc IDs to move past between two consecutive timeout checks on doc values. */
    static final int DOCS_BETWEEN_TIMEOUT_CHECK = 1000;

    /**
     * Creates a new wrapping leaf reader.
     *
     * @param in the leaf reader to wrap
     * @param queryTimeout the timeout checked during iteration
     */
    public ExitableFilterAtomicReader(LeafReader in, QueryTimeout queryTimeout) {
      super(in);
      this.queryTimeout = queryTimeout;
    }

    @Override
    public PointValues getPointValues(String field) throws IOException {
      final PointValues pointValues = in.getPointValues(field);
      if (pointValues == null) {
        return null;
      }
      return queryTimeout.isTimeoutEnabled()
          ? new ExitablePointValues(pointValues, queryTimeout)
          : pointValues;
    }

    @Override
    public Terms terms(String field) throws IOException {
      final Terms terms = in.terms(field);
      if (terms == null) {
        return null;
      }
      return queryTimeout.isTimeoutEnabled() ? new ExitableTerms(terms, queryTimeout) : terms;
    }

    // Only wraps the data, not its content, so cache keys can be shared with the delegate.
    @Override
    public CacheHelper getReaderCacheHelper() {
      return in.getReaderCacheHelper();
    }

    @Override
    public CacheHelper getCoreCacheHelper() {
      return in.getCoreCacheHelper();
    }

    @Override
    public NumericDocValues getNumericDocValues(String field) throws IOException {
      final NumericDocValues numericDocValues = super.getNumericDocValues(field);
      if (numericDocValues == null) {
        return null;
      }
      if (!queryTimeout.isTimeoutEnabled()) {
        return numericDocValues;
      }
      return new FilterNumericDocValues(numericDocValues) {
        private int docToCheck = 0;

        @Override
        public int advance(int target) throws IOException {
          final int advance = super.advance(target);
          // 'in' is the wrapped doc values instance, not the enclosing leaf reader.
          docToCheck = checkTimeoutIfDue(advance, docToCheck, in);
          return advance;
        }

        @Override
        public boolean advanceExact(int target) throws IOException {
          final boolean advanceExact = super.advanceExact(target);
          docToCheck = checkTimeoutIfDue(target, docToCheck, in);
          return advanceExact;
        }

        @Override
        public int nextDoc() throws IOException {
          final int nextDoc = super.nextDoc();
          docToCheck = checkTimeoutIfDue(nextDoc, docToCheck, in);
          return nextDoc;
        }
      };
    }

    @Override
    public BinaryDocValues getBinaryDocValues(String field) throws IOException {
      final BinaryDocValues binaryDocValues = super.getBinaryDocValues(field);
      if (binaryDocValues == null) {
        return null;
      }
      if (!queryTimeout.isTimeoutEnabled()) {
        return binaryDocValues;
      }
      return new FilterBinaryDocValues(binaryDocValues) {
        private int docToCheck = 0;

        @Override
        public int advance(int target) throws IOException {
          final int advance = super.advance(target);
          // Sample on the doc actually landed on, consistent with the other doc values wrappers.
          docToCheck = checkTimeoutIfDue(advance, docToCheck, in);
          return advance;
        }

        @Override
        public boolean advanceExact(int target) throws IOException {
          final boolean advanceExact = super.advanceExact(target);
          docToCheck = checkTimeoutIfDue(target, docToCheck, in);
          return advanceExact;
        }

        @Override
        public int nextDoc() throws IOException {
          final int nextDoc = super.nextDoc();
          docToCheck = checkTimeoutIfDue(nextDoc, docToCheck, in);
          return nextDoc;
        }
      };
    }

    @Override
    public SortedDocValues getSortedDocValues(String field) throws IOException {
      final SortedDocValues sortedDocValues = super.getSortedDocValues(field);
      if (sortedDocValues == null) {
        return null;
      }
      if (!queryTimeout.isTimeoutEnabled()) {
        return sortedDocValues;
      }
      return new FilterSortedDocValues(sortedDocValues) {
        private int docToCheck = 0;

        @Override
        public int advance(int target) throws IOException {
          final int advance = super.advance(target);
          docToCheck = checkTimeoutIfDue(advance, docToCheck, in);
          return advance;
        }

        @Override
        public boolean advanceExact(int target) throws IOException {
          final boolean advanceExact = super.advanceExact(target);
          docToCheck = checkTimeoutIfDue(target, docToCheck, in);
          return advanceExact;
        }

        @Override
        public int nextDoc() throws IOException {
          final int nextDoc = super.nextDoc();
          docToCheck = checkTimeoutIfDue(nextDoc, docToCheck, in);
          return nextDoc;
        }
      };
    }

    @Override
    public SortedNumericDocValues getSortedNumericDocValues(String field) throws IOException {
      final SortedNumericDocValues sortedNumericDocValues = super.getSortedNumericDocValues(field);
      if (sortedNumericDocValues == null) {
        return null;
      }
      if (!queryTimeout.isTimeoutEnabled()) {
        return sortedNumericDocValues;
      }
      return new FilterSortedNumericDocValues(sortedNumericDocValues) {
        private int docToCheck = 0;

        @Override
        public int advance(int target) throws IOException {
          final int advance = super.advance(target);
          docToCheck = checkTimeoutIfDue(advance, docToCheck, in);
          return advance;
        }

        @Override
        public boolean advanceExact(int target) throws IOException {
          final boolean advanceExact = super.advanceExact(target);
          docToCheck = checkTimeoutIfDue(target, docToCheck, in);
          return advanceExact;
        }

        @Override
        public int nextDoc() throws IOException {
          final int nextDoc = super.nextDoc();
          docToCheck = checkTimeoutIfDue(nextDoc, docToCheck, in);
          return nextDoc;
        }
      };
    }

    @Override
    public SortedSetDocValues getSortedSetDocValues(String field) throws IOException {
      final SortedSetDocValues sortedSetDocValues = super.getSortedSetDocValues(field);
      if (sortedSetDocValues == null) {
        return null;
      }
      if (!queryTimeout.isTimeoutEnabled()) {
        return sortedSetDocValues;
      }
      return new FilterSortedSetDocValues(sortedSetDocValues) {
        private int docToCheck = 0;

        @Override
        public int advance(int target) throws IOException {
          final int advance = super.advance(target);
          docToCheck = checkTimeoutIfDue(advance, docToCheck, in);
          return advance;
        }

        @Override
        public boolean advanceExact(int target) throws IOException {
          final boolean advanceExact = super.advanceExact(target);
          docToCheck = checkTimeoutIfDue(target, docToCheck, in);
          return advanceExact;
        }

        @Override
        public int nextDoc() throws IOException {
          final int nextDoc = super.nextDoc();
          docToCheck = checkTimeoutIfDue(nextDoc, docToCheck, in);
          return nextDoc;
        }
      };
    }

    /**
     * Runs a timeout check once {@code doc} has reached the sampling threshold.
     *
     * <p>Checks are sampled, roughly once every {@link #DOCS_BETWEEN_TIMEOUT_CHECK} doc IDs, so
     * that the cost of {@link QueryTimeout#shouldExit()} is not paid on every document.
     *
     * @param doc the doc ID the iterator moved to (or the target of {@code advanceExact})
     * @param docToCheck the doc ID at or beyond which the next check is due
     * @param iterator the doc values being iterated; only used in the exception message
     * @return the threshold to use for the next call
     * @throws ExitingReaderException if the timeout expired or the thread was interrupted
     */
    private int checkTimeoutIfDue(int doc, int docToCheck, DocIdSetIterator iterator) {
      if (doc < docToCheck) {
        return docToCheck;
      }
      checkAndThrow(iterator);
      return nextDocToCheck(doc);
    }

    /**
     * Returns {@code doc + DOCS_BETWEEN_TIMEOUT_CHECK}, saturating at {@link Integer#MAX_VALUE}.
     *
     * <p>This matters because an exhausted iterator returns {@code NO_MORE_DOCS} ({@code
     * Integer.MAX_VALUE}). A plain addition would overflow to a negative threshold, which would
     * force a check on every subsequent call.
     */
    private static int nextDocToCheck(int doc) {
      return doc > Integer.MAX_VALUE - DOCS_BETWEEN_TIMEOUT_CHECK
          ? Integer.MAX_VALUE
          : doc + DOCS_BETWEEN_TIMEOUT_CHECK;
    }

    /**
     * Throws if the timeout has expired or the current thread has been interrupted.
     *
     * <p>Note that {@link Thread#interrupted()} clears the thread's interrupt status.
     *
     * @param iterator the doc values being iterated; only used in the exception message
     * @throws ExitingReaderException if iteration must stop
     */
    private void checkAndThrow(DocIdSetIterator iterator) {
      if (queryTimeout.shouldExit()) {
        throw new ExitingReaderException(
            "The request took too long to iterate over doc values. Timeout: "
                + queryTimeout.toString()
                + ", DocValues="
                + iterator);
      } else if (Thread.interrupted()) {
        throw new ExitingReaderException(
            "Interrupted while iterating over doc values. DocValues=" + iterator);
      }
    }
  }

  /**
   * {@link PointValues} wrapper that checks the timeout on construction and on every delegated
   * call. It also wraps intersect visitors so the timeout is checked while intersecting.
   */
  private static class ExitablePointValues extends PointValues {

    private final PointValues in;
    private final QueryTimeout queryTimeout;

    private ExitablePointValues(PointValues in, QueryTimeout queryTimeout) {
      this.in = in;
      this.queryTimeout = queryTimeout;
      checkAndThrow();
    }

    /**
     * Throws if the timeout has expired or the current thread has been interrupted.
     *
     * <p>{@link Thread#interrupted()} clears the thread's interrupt status.
     */
    private void checkAndThrow() {
      if (queryTimeout.shouldExit()) {
        throw new ExitingReaderException(
            "The request took too long to iterate over point values. Timeout: "
                + queryTimeout.toString()
                + ", PointValues="
                + in);
      } else if (Thread.interrupted()) {
        throw new ExitingReaderException(
            "Interrupted while iterating over point values. PointValues=" + in);
      }
    }

    @Override
    public void intersect(IntersectVisitor visitor) throws IOException {
      checkAndThrow();
      in.intersect(new ExitableIntersectVisitor(visitor, queryTimeout));
    }

    @Override
    public long estimatePointCount(IntersectVisitor visitor) {
      checkAndThrow();
      return in.estimatePointCount(visitor);
    }

    @Override
    public byte[] getMinPackedValue() throws IOException {
      checkAndThrow();
      return in.getMinPackedValue();
    }

    @Override
    public byte[] getMaxPackedValue() throws IOException {
      checkAndThrow();
      return in.getMaxPackedValue();
    }

    @Override
    public int getNumDimensions() throws IOException {
      checkAndThrow();
      return in.getNumDimensions();
    }

    @Override
    public int getNumIndexDimensions() throws IOException {
      checkAndThrow();
      return in.getNumIndexDimensions();
    }

    @Override
    public int getBytesPerDimension() throws IOException {
      checkAndThrow();
      return in.getBytesPerDimension();
    }

    @Override
    public long size() {
      checkAndThrow();
      return in.size();
    }

    @Override
    public int getDocCount() {
      checkAndThrow();
      return in.getDocCount();
    }
  }

  /**
   * {@link PointValues.IntersectVisitor} wrapper that checks the timeout at different rates:
   * on every {@code compare} and {@code grow} call, but only on one in every {@link
   * #MAX_CALLS_BEFORE_QUERY_TIMEOUT_CHECK} {@code visit} calls.
   */
  private static class ExitableIntersectVisitor implements PointValues.IntersectVisitor {

    private static final int MAX_CALLS_BEFORE_QUERY_TIMEOUT_CHECK = 10;

    private final PointValues.IntersectVisitor in;
    private final QueryTimeout queryTimeout;
    private int calls;

    private ExitableIntersectVisitor(PointValues.IntersectVisitor in, QueryTimeout queryTimeout) {
      this.in = in;
      this.queryTimeout = queryTimeout;
    }

    /** Checks the timeout on the first call and then on every 10th call. */
    private void checkAndThrowWithSampling() {
      if (calls++ % MAX_CALLS_BEFORE_QUERY_TIMEOUT_CHECK == 0) {
        checkAndThrow();
      }
    }

    /**
     * Throws if the timeout has expired or the current thread has been interrupted.
     *
     * <p>{@link Thread#interrupted()} clears the thread's interrupt status.
     */
    private void checkAndThrow() {
      if (queryTimeout.shouldExit()) {
        throw new ExitingReaderException(
            "The request took too long to intersect point values. Timeout: "
                + queryTimeout.toString()
                + ", PointValues="
                + in);
      } else if (Thread.interrupted()) {
        throw new ExitingReaderException(
            "Interrupted while intersecting point values. PointValues=" + in);
      }
    }

    @Override
    public void visit(int docID) throws IOException {
      checkAndThrowWithSampling();
      in.visit(docID);
    }

    @Override
    public void visit(int docID, byte[] packedValue) throws IOException {
      checkAndThrowWithSampling();
      in.visit(docID, packedValue);
    }

    @Override
    public PointValues.Relation compare(byte[] minPackedValue, byte[] maxPackedValue) {
      checkAndThrow();
      return in.compare(minPackedValue, maxPackedValue);
    }

    @Override
    public void grow(int count) {
      checkAndThrow();
      in.grow(count);
    }
  }

  /** {@link Terms} wrapper whose enumerators check the timeout while iterating. */
  public static class ExitableTerms extends FilterTerms {

    private final QueryTimeout queryTimeout;

    /**
     * Creates a new wrapper.
     *
     * @param terms the terms to wrap
     * @param queryTimeout the timeout checked by every returned {@link TermsEnum}
     */
    public ExitableTerms(Terms terms, QueryTimeout queryTimeout) {
      super(terms);
      this.queryTimeout = queryTimeout;
    }

    @Override
    public TermsEnum intersect(CompiledAutomaton compiled, BytesRef startTerm) throws IOException {
      return new ExitableTermsEnum(in.intersect(compiled, startTerm), queryTimeout);
    }

    @Override
    public TermsEnum iterator() throws IOException {
      return new ExitableTermsEnum(in.iterator(), queryTimeout);
    }
  }

  /**
   * {@link TermsEnum} wrapper that checks the timeout on construction and on every call to
   * {@link #next()}.
   */
  public static class ExitableTermsEnum extends FilterTermsEnum {

    private final QueryTimeout queryTimeout;

    /**
     * Creates a new wrapper and performs an initial timeout check.
     *
     * @param termsEnum the enumerator to wrap
     * @param queryTimeout the timeout to check
     * @throws ExitingReaderException if the timeout already expired or the thread was interrupted
     */
    public ExitableTermsEnum(TermsEnum termsEnum, QueryTimeout queryTimeout) {
      super(termsEnum);
      this.queryTimeout = queryTimeout;
      checkAndThrow();
    }

    /**
     * Throws if the timeout has expired or the current thread has been interrupted.
     *
     * <p>{@link Thread#interrupted()} clears the thread's interrupt status.
     */
    private void checkAndThrow() {
      if (queryTimeout.shouldExit()) {
        throw new ExitingReaderException(
            "The request took too long to iterate over terms. Timeout: "
                + queryTimeout.toString()
                + ", TermsEnum="
                + in);
      } else if (Thread.interrupted()) {
        throw new ExitingReaderException(
            "Interrupted while iterating over terms. TermsEnum=" + in);
      }
    }

    @Override
    public BytesRef next() throws IOException {
      checkAndThrow();
      return in.next();
    }
  }

  /**
   * Wraps the given reader so that every leaf checks {@code queryTimeout} during iteration.
   *
   * @param in the reader to wrap
   * @param queryTimeout the timeout applied to all leaves
   * @throws IOException propagated from the {@link FilterDirectoryReader} constructor
   */
  public ExitableDirectoryReader(DirectoryReader in, QueryTimeout queryTimeout) throws IOException {
    super(in, new ExitableSubReaderWrapper(queryTimeout));
    this.queryTimeout = queryTimeout;
  }

  /** Wraps {@code in} using the same {@link QueryTimeout} as this reader. */
  @Override
  protected DirectoryReader doWrapDirectoryReader(DirectoryReader in) throws IOException {
    return new ExitableDirectoryReader(in, queryTimeout);
  }

  /**
   * Wraps the provided {@link DirectoryReader}; equivalent to calling the constructor.
   *
   * @param in the reader to wrap
   * @param queryTimeout the timeout applied to all leaves
   * @return a new {@link ExitableDirectoryReader} wrapping {@code in}
   * @throws IOException propagated from the constructor
   */
  public static DirectoryReader wrap(DirectoryReader in, QueryTimeout queryTimeout)
      throws IOException {
    return new ExitableDirectoryReader(in, queryTimeout);
  }

  @Override
  public CacheHelper getReaderCacheHelper() {
    return in.getReaderCacheHelper();
  }

  @Override
  public String toString() {
    return "ExitableDirectoryReader(" + in.toString() + ")";
  }
}