package org.apache.lucene.search.grouping;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Comparator;
import java.util.HashMap;
import java.util.TreeSet;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.search.FieldComparator;
import org.apache.lucene.search.LeafFieldComparator;
import org.apache.lucene.search.Scorable;
import org.apache.lucene.search.ScoreMode;
import org.apache.lucene.search.SimpleCollector;
import org.apache.lucene.search.Sort;
import org.apache.lucene.search.SortField;
/**
 * Collector that keeps the best {@code topNGroups} groups under a {@link Sort}, where a group's
 * position is determined by the most competitive document seen for it so far.
 *
 * <p>Group keys come from a {@link GroupSelector}. Each sort field has one {@link FieldComparator}
 * with {@code topNGroups + 1} slots: every retained group owns one slot, and the remaining slot
 * ({@link #spareSlot}) is scratch space used to evaluate a candidate document for a group that is
 * already retained. Slot ownership is swapped rather than copied when a group improves.
 *
 * <p>While fewer than {@code topNGroups} distinct groups have been seen, every new group is kept.
 * Once that many exist, the groups are also held in {@link #orderedGroups}, a {@link TreeSet}
 * ordered by the group sort. Its last element is the "bottom" group that new documents must beat.
 *
 * <p>The collector keeps mutable per-search state with no synchronization.
 *
 * @param <T> type of the group key produced by the {@link GroupSelector}
 */
public class FirstPassGroupingCollector<T> extends SimpleCollector {

  /** Supplies the group key for each collected document. */
  private final GroupSelector<T> groupSelector;

  /** One comparator per sort field, each allocated with {@code topNGroups + 1} slots. */
  private final FieldComparator<?>[] comparators;
  /** Per-segment counterparts of {@link #comparators}, refreshed in {@link #doSetNextReader}. */
  private final LeafFieldComparator[] leafComparators;
  /** {@code -1} for a reversed sort field, {@code 1} otherwise; multiplied into every comparison. */
  private final int[] reversed;
  private final int topNGroups;
  private final boolean needsScores;
  /** Retained groups, keyed by their copied group value. */
  private final HashMap<T, CollectedSearchGroup<T>> groupMap;
  /** Index of the last comparator; a tie on it means a tie on every sort field. */
  private final int compIDXEnd;

  /**
   * Retained groups in sort order (best first, bottom last). It is {@code null} until
   * {@code topNGroups} distinct groups have been collected, or until {@link #getTopGroups} builds
   * it.
   */
  protected TreeSet<CollectedSearchGroup<T>> orderedGroups;
  /** Offset added to segment-relative doc ids to form {@code topDoc} values. */
  private int docBase;
  /** Comparator slot not currently owned by any group; used as scratch space in {@link #collect}. */
  private int spareSlot;

  /**
   * Creates a collector that retains the best {@code topNGroups} groups.
   *
   * @param groupSelector supplies the group key of each collected document
   * @param groupSort ordering of groups; a group is ranked by its best document under this sort
   * @param topNGroups maximum number of groups to retain; must be at least 1
   * @throws IllegalArgumentException if {@code topNGroups < 1}
   */
  @SuppressWarnings({"unchecked", "rawtypes"})
  public FirstPassGroupingCollector(GroupSelector<T> groupSelector, Sort groupSort, int topNGroups) {
    this.groupSelector = groupSelector;
    if (topNGroups < 1) {
      throw new IllegalArgumentException("topNGroups must be >= 1 (got " + topNGroups + ")");
    }
    this.topNGroups = topNGroups;
    this.needsScores = groupSort.needsScores();
    final SortField[] sortFields = groupSort.getSort();
    comparators = new FieldComparator[sortFields.length];
    leafComparators = new LeafFieldComparator[sortFields.length];
    compIDXEnd = comparators.length - 1;
    reversed = new int[sortFields.length];
    for (int i = 0; i < sortFields.length; i++) {
      final SortField sortField = sortFields[i];
      // topNGroups slots for retained groups plus one spare slot.
      comparators[i] = sortField.getComparator(topNGroups + 1, i);
      reversed[i] = sortField.getReverse() ? -1 : 1;
    }
    // Slots 0 .. topNGroups-1 are handed out to new groups, so the extra slot starts as the spare.
    spareSlot = topNGroups;
    groupMap = new HashMap<>(topNGroups);
  }

  /** Requests scores only when the group sort needs them. */
  @Override
  public ScoreMode scoreMode() {
    return needsScores ? ScoreMode.COMPLETE : ScoreMode.COMPLETE_NO_SCORES;
  }

  /**
   * Returns the retained groups in sort order, skipping the first {@code groupOffset} of them.
   *
   * <p>Each returned {@link SearchGroup} carries the group value and, for every sort field, the
   * value held in that group's comparator slot. If the sorted set has not been built yet (fewer
   * than {@code topNGroups} groups were collected), it is built here.
   *
   * @param groupOffset number of leading groups to skip; must be non-negative
   * @return the groups after the offset, or {@code null} if at most {@code groupOffset} groups
   *     were collected
   * @throws IllegalArgumentException if {@code groupOffset < 0}
   * @throws IOException if building the sorted set fails in a comparator
   */
  public Collection<SearchGroup<T>> getTopGroups(int groupOffset) throws IOException {
    if (groupOffset < 0) {
      throw new IllegalArgumentException("groupOffset must be >= 0 (got " + groupOffset + ")");
    }
    if (groupMap.size() <= groupOffset) {
      return null;
    }
    if (orderedGroups == null) {
      buildSortedSet();
    }
    // Exactly groupMap.size() - groupOffset groups are emitted; this is positive per the check above.
    final Collection<SearchGroup<T>> result = new ArrayList<>(groupMap.size() - groupOffset);
    int upto = 0;
    final int sortFieldCount = comparators.length;
    for (CollectedSearchGroup<T> group : orderedGroups) {
      if (upto++ < groupOffset) {
        continue;
      }
      SearchGroup<T> searchGroup = new SearchGroup<>();
      searchGroup.groupValue = group.groupValue;
      searchGroup.sortValues = new Object[sortFieldCount];
      for (int sortFieldIDX = 0; sortFieldIDX < sortFieldCount; sortFieldIDX++) {
        searchGroup.sortValues[sortFieldIDX] = comparators[sortFieldIDX].value(group.comparatorSlot);
      }
      result.add(searchGroup);
    }
    return result;
  }

  /** Forwards the scorer to the group selector and to every per-segment comparator. */
  @Override
  public void setScorer(Scorable scorer) throws IOException {
    groupSelector.setScorer(scorer);
    for (LeafFieldComparator comparator : leafComparators) {
      comparator.setScorer(scorer);
    }
  }

  /**
   * Cheap pre-check against the bottom group, made before the group key is looked up.
   *
   * <p>Always {@code true} while {@link #orderedGroups} is {@code null}, because every group is
   * still being kept. Otherwise the sign-adjusted result of {@code compareBottom} is read field by
   * field:
   *
   * <ul>
   *   <li>a negative result rejects the document;
   *   <li>a positive result accepts it;
   *   <li>a tie on every field rejects it.
   * </ul>
   *
   * <p>Rejecting on a full tie matches the tie-break in {@link #buildSortedSet}, which ranks the
   * smaller {@code topDoc} first. This presumes documents arrive in increasing global doc id order
   * (NOTE(review): confirm with callers).
   */
  private boolean isCompetitive(int doc) throws IOException {
    if (orderedGroups != null) {
      for (int compIDX = 0;; compIDX++) {
        final int c = reversed[compIDX] * leafComparators[compIDX].compareBottom(doc);
        if (c < 0) {
          return false;
        } else if (c > 0) {
          break;
        } else if (compIDX == compIDXEnd) {
          return false;
        }
      }
    }
    return true;
  }

  /**
   * Offers {@code doc} (segment-relative) to its group. One of four things happens:
   *
   * <ul>
   *   <li>The document is ignored if it cannot beat the bottom group.
   *   <li>A new group is added while fewer than {@code topNGroups} exist.
   *   <li>The bottom group's record is recycled for a new group once the set is full.
   *   <li>An existing group's slot is replaced when this document sorts ahead of its current
   *       best document.
   * </ul>
   */
  @Override
  public void collect(int doc) throws IOException {
    if (isCompetitive(doc) == false) {
      return;
    }
    groupSelector.advanceTo(doc);
    T groupValue = groupSelector.currentValue();
    final CollectedSearchGroup<T> group = groupMap.get(groupValue);
    if (group == null) {
      if (groupMap.size() < topNGroups) {
        // Start-up phase: keep every new group, handing out slots 0, 1, 2, ... in order.
        CollectedSearchGroup<T> sg = new CollectedSearchGroup<>();
        // Stored keys use copyValue(); presumably currentValue() may be reused by the selector.
        sg.groupValue = groupSelector.copyValue();
        sg.comparatorSlot = groupMap.size();
        sg.topDoc = docBase + doc;
        for (LeafFieldComparator fc : leafComparators) {
          fc.copy(sg.comparatorSlot, doc);
        }
        groupMap.put(sg.groupValue, sg);
        if (groupMap.size() == topNGroups) {
          // The set is now full: from here on, new groups must displace the bottom group.
          buildSortedSet();
        }
        return;
      }
      // isCompetitive already showed this doc beats the bottom group, so evict the bottom group
      // and reuse its record and comparator slot for the new group.
      final CollectedSearchGroup<T> bottomGroup = orderedGroups.pollLast();
      assert orderedGroups.size() == topNGroups - 1;
      groupMap.remove(bottomGroup.groupValue);
      bottomGroup.groupValue = groupSelector.copyValue();
      bottomGroup.topDoc = docBase + doc;
      for (LeafFieldComparator fc : leafComparators) {
        fc.copy(bottomGroup.comparatorSlot, doc);
      }
      groupMap.put(bottomGroup.groupValue, bottomGroup);
      orderedGroups.add(bottomGroup);
      assert orderedGroups.size() == topNGroups;
      final int lastComparatorSlot = orderedGroups.last().comparatorSlot;
      for (LeafFieldComparator fc : leafComparators) {
        fc.setBottom(lastComparatorSlot);
      }
      return;
    }
    // Existing group: load this doc's values into the spare slot, field by field, and compare
    // them against the group's current slot.
    for (int compIDX = 0;; compIDX++) {
      leafComparators[compIDX].copy(spareSlot, doc);
      final int c = reversed[compIDX] * comparators[compIDX].compare(group.comparatorSlot, spareSlot);
      if (c < 0) {
        // The group's current document is better; nothing to change.
        return;
      } else if (c > 0) {
        // This doc is better; finish filling the spare slot for the remaining fields.
        for (int compIDX2 = compIDX + 1; compIDX2 < comparators.length; compIDX2++) {
          leafComparators[compIDX2].copy(spareSlot, doc);
        }
        break;
      } else if (compIDX == compIDXEnd) {
        // Full tie: keep the group's existing document.
        return;
      }
    }
    // The TreeSet orders by comparatorSlot and topDoc, so the group must be removed before those
    // fields change and then re-added afterwards.
    final CollectedSearchGroup<T> prevLast;
    if (orderedGroups != null) {
      prevLast = orderedGroups.last();
      orderedGroups.remove(group);
      assert orderedGroups.size() == topNGroups - 1;
    } else {
      prevLast = null;
    }
    group.topDoc = docBase + doc;
    // The spare slot now holds this doc's values: give it to the group, and recycle the group's
    // old slot as the new spare.
    final int tmp = spareSlot;
    spareSlot = group.comparatorSlot;
    group.comparatorSlot = tmp;
    if (orderedGroups != null) {
      orderedGroups.add(group);
      assert orderedGroups.size() == topNGroups;
      final CollectedSearchGroup<?> newLast = orderedGroups.last();
      // Refresh the bottom if the updated group is the bottom (its slot changed) or if a
      // different group became the bottom.
      if (group == newLast || prevLast != newLast) {
        for (LeafFieldComparator fc : leafComparators) {
          fc.setBottom(newLast.comparatorSlot);
        }
      }
    }
  }

  /**
   * Builds {@link #orderedGroups} from every retained group and points each per-segment
   * comparator's bottom at the last (worst) group.
   *
   * <p>Groups are ordered by the sign-adjusted comparators in sort-field order. A full tie is
   * broken by the smaller {@code topDoc}.
   */
  private void buildSortedSet() throws IOException {
    final Comparator<CollectedSearchGroup<?>> comparator = new Comparator<CollectedSearchGroup<?>>() {
      @Override
      public int compare(CollectedSearchGroup<?> o1, CollectedSearchGroup<?> o2) {
        for (int compIDX = 0;; compIDX++) {
          FieldComparator<?> fc = comparators[compIDX];
          final int c = reversed[compIDX] * fc.compare(o1.comparatorSlot, o2.comparatorSlot);
          if (c != 0) {
            return c;
          } else if (compIDX == compIDXEnd) {
            // Integer.compare rather than subtraction, so the result never depends on overflow.
            return Integer.compare(o1.topDoc, o2.topDoc);
          }
        }
      }
    };
    orderedGroups = new TreeSet<>(comparator);
    orderedGroups.addAll(groupMap.values());
    assert orderedGroups.size() > 0;
    final int bottomSlot = orderedGroups.last().comparatorSlot;
    for (LeafFieldComparator fc : leafComparators) {
      fc.setBottom(bottomSlot);
    }
  }

  /**
   * Switches to a new segment. It records the segment's doc base, obtains per-segment
   * comparators, and forwards the context to the group selector.
   */
  @Override
  protected void doSetNextReader(LeafReaderContext readerContext) throws IOException {
    docBase = readerContext.docBase;
    for (int i = 0; i < comparators.length; i++) {
      leafComparators[i] = comparators[i].getLeafComparator(readerContext);
    }
    groupSelector.setNextReader(readerContext);
  }

  /** Returns the {@link GroupSelector} this collector was constructed with. */
  public GroupSelector<T> getGroupSelector() {
    return groupSelector;
  }
}