2 changes: 2 additions & 0 deletions lucene/CHANGES.txt
@@ -75,6 +75,8 @@ Optimizations

* GITHUB#15085, GITHUB#15092: Hunspell suggestions: Ensure candidate roots are not worse before updating. (Ilia Permiashkin)

* GITHUB#15210: Enable pruning and skipping support for FirstPassGroupingCollector (Alexander Mueller)

Bug Fixes
---------------------
* GITHUB#14049: Randomize KNN codec params in RandomCodec. Fixes scalar quantization div-by-zero
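The collector changes below let the first-pass grouping collector advertise a top-docs/top-scores ScoreMode instead of always collecting exhaustively, so the searcher may skip documents that cannot enter the top N groups. As context, a minimal usage sketch (the "category" field, the group count of 10, and the searcher/query setup are illustrative assumptions, not part of this change):

import java.io.IOException;
import java.util.Collection;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.Sort;
import org.apache.lucene.search.grouping.FirstPassGroupingCollector;
import org.apache.lucene.search.grouping.SearchGroup;
import org.apache.lucene.search.grouping.TermGroupSelector;
import org.apache.lucene.util.BytesRef;

class FirstPassUsageSketch {
  // Find the top 10 groups of the "category" field, ranking groups by relevance.
  // With this change the collector reports ScoreMode.TOP_SCORES for a relevance-led
  // group sort, so the searcher is allowed to skip non-competitive documents.
  static Collection<SearchGroup<BytesRef>> topGroups(IndexSearcher searcher, Query query)
      throws IOException {
    FirstPassGroupingCollector<BytesRef> firstPass =
        new FirstPassGroupingCollector<>(new TermGroupSelector("category"), Sort.RELEVANCE, 10);
    searcher.search(query, firstPass);
    return firstPass.getTopGroups(0);
  }
}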
@@ -155,6 +155,17 @@ protected void collect(LeafCollector collector, int i) throws IOException {
}
}

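// The caching collector must record every hit so the cache can be replayed later, so this
// wrapper swallows min-competitive-score updates from the wrapped collector, which may now
// advertise a skipping-capable score mode (e.g. the first-pass grouping collector).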
private static class NoSkippingScorable extends FilterScorable {
public NoSkippingScorable(Scorable in) {
super(in);
}

@Override
public void setMinCompetitiveScore(float minScore) {
// ignore to enforce exhaustive hits
}
}

private class NoScoreCachingLeafCollector extends FilterLeafCollector {

final int maxDocsToCache;
@@ -185,6 +196,11 @@ protected void buffer(int doc) throws IOException {
docs[docCount] = doc;
}

@Override
public void setScorer(Scorable scorer) throws IOException {
super.setScorer(new NoSkippingScorable(scorer));
}

@Override
public void collect(int doc) throws IOException {
if (docs != null) {
@@ -23,6 +23,7 @@
import java.util.HashMap;
import java.util.TreeSet;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.FieldComparator;
import org.apache.lucene.search.LeafFieldComparator;
import org.apache.lucene.search.Pruning;
@@ -49,9 +50,10 @@ public class FirstPassGroupingCollector<T> extends SimpleCollector {
private final LeafFieldComparator[] leafComparators;
private final int[] reversed;
private final int topNGroups;
private final boolean needsScores;
private final HashMap<T, CollectedSearchGroup<T>> groupMap;
private final int compIDXEnd;
private final ScoreMode scoreMode;
private final boolean canSetMinScore;

// Set once we reach topNGroups unique groups:
/**
@@ -61,6 +63,9 @@ public class FirstPassGroupingCollector<T> extends SimpleCollector {

private int docBase;
private int spareSlot;
private Scorable scorer;
private int bottomSlot;
private float minCompetitiveScore;

/**
* Create the first pass collector.
@@ -82,7 +87,6 @@ public FirstPassGroupingCollector(
// and specialize it?

this.topNGroups = topNGroups;
this.needsScores = groupSort.needsScores();
final SortField[] sortFields = groupSort.getSort();
comparators = new FieldComparator<?>[sortFields.length];
leafComparators = new LeafFieldComparator[sortFields.length];
@@ -91,19 +95,34 @@
for (int i = 0; i < sortFields.length; i++) {
final SortField sortField = sortFields[i];

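// Only the leading comparator is allowed to prune; the remaining comparators act as
// tie-breakers and therefore must not skip any candidate hit.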
final Pruning pruning;
if (i == 0) {
pruning = compIDXEnd >= 0 ? Pruning.GREATER_THAN : Pruning.GREATER_THAN_OR_EQUAL_TO;
} else {
pruning = Pruning.NONE;
}

// use topNGroups + 1 so we have a spare slot to use for comparing (tracked by
// this.spareSlot):
comparators[i] = sortField.getComparator(topNGroups + 1, Pruning.NONE);
comparators[i] = sortField.getComparator(topNGroups + 1, pruning);
reversed[i] = sortField.getReverse() ? -1 : 1;
}

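// A relevance-led group sort lets us raise the scorer's minimum competitive score once
// topNGroups groups have been collected; otherwise skipping is driven only by the leading
// comparator's competitive iterator.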
if (SortField.FIELD_SCORE.equals(sortFields[0]) == true) {
scoreMode = ScoreMode.TOP_SCORES;
canSetMinScore = true;
} else {
scoreMode = groupSort.needsScores() ? ScoreMode.TOP_DOCS_WITH_SCORES : ScoreMode.TOP_DOCS;
canSetMinScore = false;
}

spareSlot = topNGroups;
groupMap = CollectionUtil.newHashMap(topNGroups);
}

@Override
public ScoreMode scoreMode() {
return needsScores ? ScoreMode.COMPLETE : ScoreMode.COMPLETE_NO_SCORES;
return scoreMode;
}

/**
@@ -154,10 +173,12 @@ public Collection<SearchGroup<T>> getTopGroups(int groupOffset) throws IOException {

@Override
public void setScorer(Scorable scorer) throws IOException {
this.scorer = scorer;
groupSelector.setScorer(scorer);
for (LeafFieldComparator comparator : leafComparators) {
comparator.setScorer(scorer);
}
setMinCompetitiveScore(scorer);
}

private boolean isCompetitive(int doc) throws IOException {
@@ -238,6 +259,9 @@ private void collectNewGroup(final int doc) throws IOException {
// number of groups; from here on we will drop
// bottom group when we insert new one:
buildSortedSet();

// Allow pruning for compatible leaf comparators.
leafComparators[0].setHitsThresholdReached();
}

} else {
@@ -262,9 +286,7 @@ private void collectNewGroup(final int doc) throws IOException {
assert orderedGroups.size() == topNGroups;

final int lastComparatorSlot = orderedGroups.last().comparatorSlot;
for (LeafFieldComparator fc : leafComparators) {
fc.setBottom(lastComparatorSlot);
}
setBottomSlot(lastComparatorSlot);
}
}

@@ -320,13 +342,16 @@ private void collectExistingGroup(final int doc, final CollectedSearchGroup<T> group)
// If we changed the value of the last group, or changed which group was last, then update
// bottom:
if (group == newLast || prevLast != newLast) {
for (LeafFieldComparator fc : leafComparators) {
fc.setBottom(newLast.comparatorSlot);
}
setBottomSlot(newLast.comparatorSlot);
}
}
}

@Override
public DocIdSetIterator competitiveIterator() throws IOException {
return leafComparators[0].competitiveIterator();
}

private void buildSortedSet() throws IOException {
final Comparator<CollectedSearchGroup<?>> comparator =
new Comparator<>() {
@@ -348,13 +373,12 @@ public int compare(CollectedSearchGroup<?> o1, CollectedSearchGroup<?> o2) {
orderedGroups.addAll(groupMap.values());
assert orderedGroups.size() > 0;

for (LeafFieldComparator fc : leafComparators) {
fc.setBottom(orderedGroups.last().comparatorSlot);
}
setBottomSlot(orderedGroups.last().comparatorSlot);
}

@Override
protected void doSetNextReader(LeafReaderContext readerContext) throws IOException {
minCompetitiveScore = 0f;
docBase = readerContext.docBase;
for (int i = 0; i < comparators.length; i++) {
leafComparators[i] = comparators[i].getLeafComparator(readerContext);
@@ -372,4 +396,25 @@ public GroupSelector<T> getGroupSelector() {
private boolean isGroupMapFull() {
return groupMap.size() >= topNGroups;
}

private void setBottomSlot(final int bottomSlot) throws IOException {
for (LeafFieldComparator fc : leafComparators) {
fc.setBottom(bottomSlot);
}

this.bottomSlot = bottomSlot;
setMinCompetitiveScore(scorer);
}

private void setMinCompetitiveScore(final Scorable scorer) throws IOException {
if (canSetMinScore == false || isGroupMapFull() == false) {
return;
}

final float minScore = (float) comparators[0].value(bottomSlot);
if (minScore > minCompetitiveScore) {
scorer.setMinCompetitiveScore(minScore);
minCompetitiveScore = minScore;
}
}
}
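Taken together, the collector now derives its score mode from the leading group-sort field, signals setHitsThresholdReached() to the leading comparator once topNGroups groups exist, exposes that comparator's competitiveIterator(), and raises the scorer's minimum competitive score whenever the bottom group changes. A hedged sketch of the resulting score modes (group and sort field names are illustrative; same imports as the sketch above plus org.apache.lucene.search.SortField):

// Relevance-led group sort: scorer-side skipping via setMinCompetitiveScore.
FirstPassGroupingCollector<BytesRef> byScore =
    new FirstPassGroupingCollector<>(new TermGroupSelector("brand"), Sort.RELEVANCE, 5);
// byScore.scoreMode() == ScoreMode.TOP_SCORES

// Field-led group sort: skipping is left to the leading comparator's competitive iterator.
Sort byPrice = new Sort(new SortField("price", SortField.Type.LONG));
FirstPassGroupingCollector<BytesRef> byPriceCollector =
    new FirstPassGroupingCollector<>(new TermGroupSelector("brand"), byPrice, 5);
// byPriceCollector.scoreMode() == ScoreMode.TOP_DOCS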
@@ -16,6 +16,8 @@
*/
package org.apache.lucene.search.grouping;

import static org.hamcrest.Matchers.lessThanOrEqualTo;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
@@ -1373,13 +1375,6 @@ private TopGroups<BytesRef> searchShards(
+ canUseIDV);
}
// Run 1st pass collector to get top groups per shard
final Weight w =
topSearcher.createWeight(
topSearcher.rewrite(query),
groupSort.needsScores() || docSort.needsScores() || getMaxScores
? ScoreMode.COMPLETE
: ScoreMode.COMPLETE_NO_SCORES,
1);
final List<Collection<SearchGroup<BytesRef>>> shardGroups = new ArrayList<>();
List<FirstPassGroupingCollector<?>> firstPassGroupingCollectors = new ArrayList<>();
FirstPassGroupingCollector<?> firstPassCollector = null;
@@ -1403,6 +1398,10 @@ private TopGroups<BytesRef> searchShards(
System.out.println(" 1st pass collector=" + firstPassCollector);
}
firstPassGroupingCollectors.add(firstPassCollector);

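// Build the Weight from the collector's own score mode so skipping and pruning can
// actually take effect during the shard search.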
final Weight w =
topSearcher.createWeight(topSearcher.rewrite(query), firstPassCollector.scoreMode(), 1);

subSearchers[shardIDX].search(w, firstPassCollector);
final Collection<SearchGroup<BytesRef>> topGroups = getSearchGroups(firstPassCollector, 0);
if (topGroups != null) {
@@ -1460,6 +1459,11 @@ private TopGroups<BytesRef> searchShards(
docSort,
docOffset + topNDocs,
getMaxScores);

final Weight w =
topSearcher.createWeight(
topSearcher.rewrite(query), secondPassCollector.scoreMode(), 1);

subSearchers[shardIDX].search(w, secondPassCollector);
shardTopGroups[shardIDX] = getTopGroups(secondPassCollector, 0);
if (VERBOSE) {
@@ -1520,14 +1524,14 @@ private void assertEquals(
"expected.groups.length != actual.groups.length",
expected.groups.length,
actual.groups.length);
assertEquals(
"expected.totalHitCount != actual.totalHitCount",
expected.totalHitCount,
actual.totalHitCount);
assertEquals(
"expected.totalGroupedHitCount != actual.totalGroupedHitCount",
expected.totalGroupedHitCount,
actual.totalGroupedHitCount);
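// With skipping enabled a shard may not count every matching hit, so hit counts become
// lower bounds rather than exact values.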
assertThat(
"expected.totalHitCount >= actual.totalHitCount",
actual.totalHitCount,
lessThanOrEqualTo(expected.totalHitCount));
assertThat(
"expected.totalGroupedHitCount >= actual.totalGroupedHitCount",
actual.totalGroupedHitCount,
lessThanOrEqualTo(expected.totalGroupedHitCount));
if (expected.totalGroupCount != null && verifyTotalGroupCount) {
assertEquals(
"expected.totalGroupCount != actual.totalGroupCount",
@@ -1556,7 +1560,10 @@ private void assertEquals(

// TODO
// assertEquals(expectedGroup.maxScore, actualGroup.maxScore);
assertEquals(expectedGroup.totalHits().value(), actualGroup.totalHits().value());
assertThat(
"expectedGroup.totalHits().value() >= actualGroup.totalHits().value()",
actualGroup.totalHits().value(),
lessThanOrEqualTo(expectedGroup.totalHits().value()));

final ScoreDoc[] expectedFDs = expectedGroup.scoreDocs();
final ScoreDoc[] actualFDs = actualGroup.scoreDocs();