package org.apache.lucene.queries.intervals;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.function.Function;
import java.util.stream.Collectors;
import org.apache.lucene.search.BooleanQuery;
/**
 * Helpers that rewrite intervals sources so that their disjunctions are pulled up to the top
 * level, producing a list of alternatives that each contain one choice per disjunction.
 *
 * <p>For example, combining {@code a} with a disjunction {@code x OR y} through a combiner
 * {@code FUNC} yields {@code [FUNC(a, x), FUNC(a, y)]}. Disjuncts whose {@code minExtent()} is 1
 * are not expanded individually: they are regrouped into a single {@code Intervals.or(...)}
 * source, so only the remaining disjuncts cause the result to branch.
 */
final class Disjunctions {

  /**
   * Splits every element of {@code sources} into its disjuncts and applies {@code function} to
   * each combination (the cartesian product of the per-source disjunct lists).
   *
   * @param sources the sub-sources to combine, in order
   * @param function builds one source from a single combination of sub-sources
   * @return one source per combination; a singleton list when no expansion was needed
   * @throws IllegalArgumentException if the number of combinations would exceed {@code
   *     BooleanQuery.getMaxClauseCount()}
   */
  public static List<IntervalsSource> pullUp(
      List<IntervalsSource> sources, Function<List<IntervalsSource>, IntervalsSource> function) {
    // Start from a single empty combination; each source either extends every combination
    // in place or multiplies the number of combinations by its disjunct count.
    List<List<IntervalsSource>> rewritten = new ArrayList<>();
    rewritten.add(new ArrayList<>());
    for (IntervalsSource source : sources) {
      List<IntervalsSource> disjuncts = splitDisjunctions(source);
      if (disjuncts.size() == 1) {
        // No branching: append the single disjunct to every combination built so far.
        IntervalsSource only = disjuncts.get(0);
        for (List<IntervalsSource> combination : rewritten) {
          combination.add(only);
        }
        continue;
      }
      // Multiply in long arithmetic: an int product can wrap around (e.g. to 0 or a negative
      // value) and silently bypass the limit, leading to an enormous expansion.
      long expandedSize = (long) rewritten.size() * disjuncts.size();
      if (expandedSize > BooleanQuery.getMaxClauseCount()) {
        throw new IllegalArgumentException("Too many disjunctions to expand");
      }
      // Safe cast: expandedSize is bounded by getMaxClauseCount(), which is an int.
      List<List<IntervalsSource>> expanded = new ArrayList<>((int) expandedSize);
      for (IntervalsSource disj : disjuncts) {
        for (List<IntervalsSource> combination : rewritten) {
          List<IntervalsSource> copy = new ArrayList<>(combination);
          copy.add(disj);
          expanded.add(copy);
        }
      }
      rewritten = expanded;
    }
    if (rewritten.size() == 1) {
      return Collections.singletonList(function.apply(rewritten.get(0)));
    }
    return rewritten.stream().map(function).collect(Collectors.toList());
  }

  /**
   * Splits {@code source} into its disjuncts and applies {@code function} to each of them.
   *
   * @param source the source whose disjunctions should be pulled up
   * @param function transforms a single disjunct
   * @return one transformed source per disjunct; a singleton list when there is only one
   */
  public static List<IntervalsSource> pullUp(
      IntervalsSource source, Function<IntervalsSource, IntervalsSource> function) {
    List<IntervalsSource> disjuncts = splitDisjunctions(source);
    if (disjuncts.size() == 1) {
      return Collections.singletonList(function.apply(disjuncts.get(0)));
    }
    return disjuncts.stream().map(function).collect(Collectors.toList());
  }

  /**
   * Returns the disjuncts reported by {@code source.pullUpDisjunctions()}. Disjuncts with a
   * {@code minExtent()} of 1 are merged into one {@code Intervals.or(...)} placed first, followed
   * by the remaining disjuncts in their original order.
   */
  private static List<IntervalsSource> splitDisjunctions(IntervalsSource source) {
    List<IntervalsSource> singletons = new ArrayList<>();
    List<IntervalsSource> nonSingletons = new ArrayList<>();
    for (IntervalsSource disj : source.pullUpDisjunctions()) {
      if (disj.minExtent() == 1) {
        singletons.add(disj);
      } else {
        nonSingletons.add(disj);
      }
    }
    List<IntervalsSource> split = new ArrayList<>();
    if (!singletons.isEmpty()) {
      // Keep extent-1 disjuncts together rather than expanding each one separately.
      split.add(Intervals.or(singletons.toArray(new IntervalsSource[0])));
    }
    split.addAll(nonSingletons);
    return split;
  }
}