Skip to content

Commit 350f169

Browse files
committed
Add hnsw.totalExploreAdditionalHits
1 parent 273ba4b commit 350f169

10 files changed

Lines changed: 79 additions & 36 deletions

File tree

container-search/abi-spec.json

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1024,6 +1024,7 @@
10241024
"public java.lang.String getIndexName()",
10251025
"public double getDistanceThreshold()",
10261026
"public int getHnswExploreAdditionalHits()",
1027+
"public java.lang.Integer getHnswTotalExploreAdditionalHits()",
10271028
"public boolean getAllowApproximate()",
10281029
"public java.lang.String getQueryTensorName()",
10291030
"public java.lang.Double getHnswApproximateThreshold()",
@@ -1037,6 +1038,7 @@
10371038
"public void setTotalTargetHits(java.lang.Integer)",
10381039
"public void setDistanceThreshold(double)",
10391040
"public void setHnswExploreAdditionalHits(int)",
1041+
"public void setHnswTotalExploreAdditionalHits(java.lang.Integer)",
10401042
"public void setAllowApproximate(boolean)",
10411043
"public void setHnswApproximateThreshold(java.lang.Double)",
10421044
"public void setHnswExplorationSlack(java.lang.Double)",

container-search/src/main/java/com/yahoo/prelude/query/NearestNeighborItem.java

Lines changed: 32 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,9 @@ public class NearestNeighborItem extends SimpleTaggableItem {
2222

2323
private Integer targetHits = null;
2424
private Integer totalTargetHits = null;
25-
private int hnswExploreAdditionalHits = 0;
26-
private double distanceThreshold = Double.POSITIVE_INFINITY;
25+
private Integer hnswExploreAdditionalHits = null;
26+
private Integer hnswTotalExploreAdditionalHits = null;
27+
private double distanceThreshold = Double.POSITIVE_INFINITY;
2728
private boolean approximate = true;
2829
private String field;
2930
private final String queryTensorName;
@@ -59,8 +60,11 @@ public NearestNeighborItem(String fieldName, String queryTensorName) {
5960
/** Returns the distance threshold for nearest-neighbor hits */
6061
public double getDistanceThreshold () { return this.distanceThreshold ; }
6162

62-
/** Returns the number of extra hits to explore in HNSW algorithm */
63-
public int getHnswExploreAdditionalHits() { return hnswExploreAdditionalHits; }
63+
/** Returns the number of extra hits to explore in HNSW algorithm per node. */
64+
public int getHnswExploreAdditionalHits() { return hnswExploreAdditionalHits != null ? hnswExploreAdditionalHits : 0; }
65+
66+
/** Returns the total number of extra hits to explore in HNSW algorithm across all nodes, or null if not set. */
67+
public Integer getHnswTotalExploreAdditionalHits() { return hnswTotalExploreAdditionalHits; }
6468

6569
/** Returns whether approximation is allowed */
6670
public boolean getAllowApproximate() { return approximate; }
@@ -103,9 +107,12 @@ public NearestNeighborItem(String fieldName, String queryTensorName) {
103107
/** Set the distance threshold for nearest-neighbor hits */
104108
public void setDistanceThreshold(double threshold) { this.distanceThreshold = threshold; }
105109

106-
/** Set the number of extra hits to explore in HNSW algorithm */
110+
/** Set the number of extra hits to explore in HNSW algorithm per node. */
107111
public void setHnswExploreAdditionalHits(int num) { this.hnswExploreAdditionalHits = num; }
108112

113+
/** Set the total number of extra hits to explore in HNSW algorithm across all nodes. */
114+
public void setHnswTotalExploreAdditionalHits(Integer total) {this.hnswTotalExploreAdditionalHits = total; }
115+
109116
/** Set whether approximation is allowed */
110117
public void setAllowApproximate(boolean value) { this.approximate = value; }
111118

@@ -147,7 +154,7 @@ public int encode(ByteBuffer buffer, SerializationContext context) {
147154
int approxNum = (approximate ? 1 : 0);
148155
IntegerCompressor.putCompressedPositiveNumber(resolveTargetHits(context), buffer);
149156
IntegerCompressor.putCompressedPositiveNumber(approxNum, buffer);
150-
IntegerCompressor.putCompressedPositiveNumber(hnswExploreAdditionalHits, buffer);
157+
IntegerCompressor.putCompressedPositiveNumber(resolveHnswExploreAdditionalHits(context), buffer);
151158
buffer.putDouble(distanceThreshold);
152159
return 1; // number of encoded stack dump items
153160
}
@@ -156,13 +163,16 @@ public int encode(ByteBuffer buffer, SerializationContext context) {
156163
protected void appendBodyString(StringBuilder buffer) {
157164
buffer.append("{field=").append(field);
158165
buffer.append(",queryTensorName=").append(queryTensorName);
159-
buffer.append(",hnsw.exploreAdditionalHits=").append(hnswExploreAdditionalHits);
160166
buffer.append(",distanceThreshold=").append(distanceThreshold);
161167
buffer.append(",approximate=").append(approximate);
162168
if (targetHits != null)
163169
buffer.append(",targetHits=").append(targetHits);
164170
if (totalTargetHits != null)
165171
buffer.append(",totalTargetHits=").append(totalTargetHits);
172+
if (hnswExploreAdditionalHits != null)
173+
buffer.append(",hnsw.exploreAdditionalHits=").append(hnswExploreAdditionalHits);
174+
if (hnswTotalExploreAdditionalHits != null)
175+
buffer.append(",hnsw.totalExploreAdditionalHits=").append(hnswTotalExploreAdditionalHits);
166176
if (hnswApproximateThreshold != null)
167177
buffer.append(",hnsw.approximateThreshold=").append(hnswApproximateThreshold);
168178
if (hnswExplorationSlack != null)
@@ -183,13 +193,16 @@ public void disclose(Discloser discloser) {
183193
super.disclose(discloser);
184194
discloser.addProperty("field", field);
185195
discloser.addProperty("queryTensorName", queryTensorName);
186-
discloser.addProperty("hnsw.exploreAdditionalHits", hnswExploreAdditionalHits);
187196
discloser.addProperty("distanceThreshold", distanceThreshold);
188197
discloser.addProperty("approximate", approximate);
189198
if (targetHits != null)
190199
discloser.addProperty("targetHits", targetHits);
191200
if (totalTargetHits != null)
192201
discloser.addProperty("totalTargetHits", totalTargetHits);
202+
if (hnswExploreAdditionalHits != null)
203+
discloser.addProperty("hnsw.exploreAdditionalHits", hnswExploreAdditionalHits);
204+
if (hnswTotalExploreAdditionalHits != null)
205+
discloser.addProperty("hnsw.totalExploreAdditionalHits", hnswTotalExploreAdditionalHits);
193206
if (hnswApproximateThreshold != null)
194207
discloser.addProperty("hnsw.approximateThreshold", hnswApproximateThreshold);
195208
if (hnswExplorationSlack != null)
@@ -210,7 +223,8 @@ public boolean equals(Object o) {
210223
NearestNeighborItem other = (NearestNeighborItem)o;
211224
if ( ! Objects.equals(this.targetHits, other.targetHits)) return false;
212225
if ( ! Objects.equals(this.totalTargetHits, other.totalTargetHits)) return false;
213-
if (this.hnswExploreAdditionalHits != other.hnswExploreAdditionalHits) return false;
226+
if ( ! Objects.equals(this.hnswExploreAdditionalHits, other.hnswExploreAdditionalHits)) return false;
227+
if ( ! Objects.equals(this.hnswTotalExploreAdditionalHits, other.hnswTotalExploreAdditionalHits)) return false;
214228
if (this.distanceThreshold != other.distanceThreshold) return false;
215229
if (this.approximate != other.approximate) return false;
216230
if ( ! this.field.equals(other.field)) return false;
@@ -226,7 +240,8 @@ public boolean equals(Object o) {
226240

227241
@Override
228242
public int hashCode() {
229-
return Objects.hash(super.hashCode(), targetHits, totalTargetHits, hnswExploreAdditionalHits,
243+
return Objects.hash(super.hashCode(), targetHits, totalTargetHits,
244+
hnswExploreAdditionalHits, hnswTotalExploreAdditionalHits,
230245
distanceThreshold, approximate, field, queryTensorName,
231246
hnswApproximateThreshold, hnswExplorationSlack,
232247
hnswFilterFirstExploration, hnswFilterFirstThreshold,
@@ -240,7 +255,7 @@ SearchProtocol.QueryTreeItem toProtobuf(SerializationContext context) {
240255
builder.setQueryTensorName(queryTensorName);
241256
builder.setTargetNumHits(resolveTargetHits(context));
242257
builder.setAllowApproximate(approximate);
243-
builder.setExploreAdditionalHits(hnswExploreAdditionalHits);
258+
builder.setExploreAdditionalHits(resolveHnswExploreAdditionalHits(context));
244259
builder.setDistanceThreshold(distanceThreshold);
245260
if (hnswApproximateThreshold != null) {
246261
builder.setApproximateThreshold(hnswApproximateThreshold);
@@ -265,6 +280,12 @@ SearchProtocol.QueryTreeItem toProtobuf(SerializationContext context) {
265280
.build();
266281
}
267282

283+
private int resolveHnswExploreAdditionalHits(SerializationContext context) {
284+
if (hnswExploreAdditionalHits != null) return hnswExploreAdditionalHits;
285+
if (hnswTotalExploreAdditionalHits == null) return 0;
286+
return context.contentShareOf(hnswTotalExploreAdditionalHits);
287+
}
288+
268289
private int resolveTargetHits(SerializationContext context) {
269290
if (targetHits != null) return targetHits;
270291
if (totalTargetHits == null)

container-search/src/main/java/com/yahoo/search/query/SelectParser.java

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,6 @@
7171
import com.yahoo.slime.Type;
7272
import java.nio.charset.StandardCharsets;
7373
import java.util.ArrayList;
74-
import java.util.Collection;
7574
import java.util.HashMap;
7675
import java.util.List;
7776
import java.util.Map;
@@ -108,6 +107,7 @@
108107
import static com.yahoo.search.yql.YqlParser.GEO_LOCATION;
109108
import static com.yahoo.search.yql.YqlParser.HIT_LIMIT;
110109
import static com.yahoo.search.yql.YqlParser.HNSW_EXPLORE_ADDITIONAL_HITS;
110+
import static com.yahoo.search.yql.YqlParser.TOTAL_HNSW_EXPLORE_ADDITIONAL_HITS;
111111
import static com.yahoo.search.yql.YqlParser.IMPLICIT_TRANSFORMS;
112112
import static com.yahoo.search.yql.YqlParser.LABEL;
113113
import static com.yahoo.search.yql.YqlParser.NEAR;
@@ -533,8 +533,10 @@ private Item buildNearestNeighbor(String key, Inspector value) {
533533
item.setDistanceThreshold(distanceThreshold);
534534
}
535535
if (HNSW_EXPLORE_ADDITIONAL_HITS.equals(annotation_name)) {
536-
int hnswExploreAdditionalHits = (int)(annotation_value.asDouble());
537-
item.setHnswExploreAdditionalHits(hnswExploreAdditionalHits);
536+
item.setHnswExploreAdditionalHits((int)(annotation_value.asDouble()));
537+
}
538+
if (TOTAL_HNSW_EXPLORE_ADDITIONAL_HITS.equals(annotation_name)) {
539+
item.setHnswTotalExploreAdditionalHits((int)(annotation_value.asDouble()));
538540
}
539541
if (APPROXIMATE.equals(annotation_name)) {
540542
boolean allowApproximate = annotation_value.asBool();

container-search/src/main/java/com/yahoo/search/searchers/ValidateNearestNeighborSearcher.java

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,9 @@ private String validate(NearestNeighborItem item) {
122122
if (item.getTotalTargetHits() != null && item.getTotalTargetHits() < 1)
123123
return item + " has invalid totalTargetHits " + item.getTotalTargetHits() + ": Must be >= 1";
124124

125+
if (item.getHnswExploreAdditionalHits() != 0 && item.getHnswTotalExploreAdditionalHits() != null)
126+
return item + " cannot have both hnsw.exploreAdditionalHits and hnsw.totalExploreAdditionalHits set";
127+
125128
String queryFeatureName = "query(" + item.getQueryTensorName() + ")";
126129
Optional<Tensor> queryTensor = query.getRanking().getFeatures().getTensor(queryFeatureName);
127130
if (queryTensor.isEmpty())

container-search/src/main/java/com/yahoo/search/yql/VespaSerializer.java

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -776,8 +776,12 @@ boolean serialize(StringBuilder destination, NearestNeighborItem item, Boolean i
776776
int explore = item.getHnswExploreAdditionalHits();
777777
if (explore != 0) {
778778
comma(destination, initLen);
779-
String key = YqlParser.HNSW_EXPLORE_ADDITIONAL_HITS;
780-
annotationKey(destination, key).append(explore);
779+
annotationKey(destination, YqlParser.HNSW_EXPLORE_ADDITIONAL_HITS).append(explore);
780+
}
781+
Integer totalExplore = item.getHnswTotalExploreAdditionalHits();
782+
if (totalExplore != null) {
783+
comma(destination, initLen);
784+
annotationKey(destination, YqlParser.TOTAL_HNSW_EXPLORE_ADDITIONAL_HITS).append(totalExplore);
781785
}
782786
boolean allow_approx = item.getAllowApproximate();
783787
if (! allow_approx) {

container-search/src/main/java/com/yahoo/search/yql/YqlParser.java

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -185,6 +185,7 @@ private static class IndexNameExpander {
185185
public static final String HNSW_APPROXIMATE_THRESHOLD = "hnsw.approximateThreshold"; // default 0.05
186186
public static final String HNSW_EXPLORATION_SLACK = "hnsw.explorationSlack"; // 'adaptive beam', default 0.0 (aka off)
187187
public static final String HNSW_EXPLORE_ADDITIONAL_HITS = "hnsw.exploreAdditionalHits"; // 'ef' in HNSW
188+
public static final String TOTAL_HNSW_EXPLORE_ADDITIONAL_HITS = "hnsw.totalExploreAdditionalHits";
188189
public static final String HNSW_FILTER_FIRST_EXPLORATION = "hnsw.filterFirstExploration"; // acorn-1 aggression, default 0.3
189190
public static final String HNSW_FILTER_FIRST_THRESHOLD = "hnsw.filterFirstThreshold"; // 'acorn-1', default 0.0 (aka off)
190191
public static final String HNSW_POST_FILTER_THRESHOLD = "hnsw.postFilterThreshold"; // default 1.0 (aka off)
@@ -621,6 +622,8 @@ private Item buildNearestNeighbor(OperatorNode<ExpressionOperator> ast) {
621622
if (hnswExploreAdditionalHits != null) {
622623
item.setHnswExploreAdditionalHits(hnswExploreAdditionalHits);
623624
}
625+
item.setHnswTotalExploreAdditionalHits(getAnnotation(ast, TOTAL_HNSW_EXPLORE_ADDITIONAL_HITS,
626+
Integer.class, null, "total extra hits to explore for HNSW algorithm across all nodes"));
624627
Boolean allowApproximate = getAnnotation(ast, APPROXIMATE,
625628
Boolean.class, null, "allow approximate nearest neighbor search");
626629
if (allowApproximate != null) {

container-search/src/test/java/com/yahoo/search/dispatch/rpc/RpcSearchInvokerTest.java

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -97,15 +97,15 @@ void testUpdateOfRpcResourcePool() {
9797
}
9898

9999
@Test
100-
void contentShareIsUsedToSetTargetHits() throws IOException {
100+
void contentShareIsUsedToSetQueryParameters() throws IOException {
101101
// Total target is distributed proportional to content share (by active document count)
102-
assertAdjustedTotalTargetHits(List.of(46, 55), List.of(1000, 1200));
102+
assertAdjustedQueryParameters(List.of(46, 55), List.of(1000, 1200));
103103

104104
// Small differences (<5%) do not justify reserialization and so get the same value
105-
assertAdjustedTotalTargetHits(List.of(50, 50), List.of(1000, 1035));
105+
assertAdjustedQueryParameters(List.of(50, 50), List.of(1000, 1035));
106106

107107
// Nodes with 0 documents get default content share: 1/nodes
108-
assertAdjustedTotalTargetHits(List.of(49, 49, 20, 1, 1), List.of(1000, 1035, 0, 1, 13));
108+
assertAdjustedQueryParameters(List.of(49, 49, 20, 1, 1), List.of(1000, 1035, 0, 1, 13));
109109
}
110110

111111
@Test
@@ -316,7 +316,7 @@ private void assertProperty(String name, int value, SearchProtocol.SearchRequest
316316
fail("Property '" + name + "' is not present");
317317
}
318318

319-
private void assertAdjustedTotalTargetHits(List<Integer> expected, List<Integer> activeDocs) throws IOException {
319+
private void assertAdjustedQueryParameters(List<Integer> expected, List<Integer> activeDocs) throws IOException {
320320
Query query = new Query();
321321
var root = new OrItem();
322322

@@ -328,6 +328,7 @@ private void assertAdjustedTotalTargetHits(List<Integer> expected, List<Integer>
328328

329329
var nn = new NearestNeighborItem("myField", "myQueryTensor");
330330
nn.setTotalTargetHits(100);
331+
nn.setHnswTotalExploreAdditionalHits(100);
331332
root.addItem(nn);
332333

333334
query.getModel().getQueryTree().setRoot(root);
@@ -341,7 +342,9 @@ private void assertAdjustedTotalTargetHits(List<Integer> expected, List<Integer>
341342
assertEquals(expected.get(i), or.getChildren(0).getItemWeakAnd().getTargetNumHits(),
342343
"WeakAnd in node " + i);
343344
assertEquals(expected.get(i), or.getChildren(1).getItemNearestNeighbor().getTargetNumHits(),
344-
"NearestNeighbor in node " + i);
345+
"TargetNumHits in NearestNeighbor in node " + i);
346+
assertEquals(expected.get(i), or.getChildren(1).getItemNearestNeighbor().getExploreAdditionalHits(),
347+
"ExploreAdditionalHits in NearestNeighbor in node " + i);
345348
}
346349
}
347350

container-search/src/test/java/com/yahoo/search/searchers/ValidateNearestNeighborTestCase.java

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -141,7 +141,6 @@ static String desc(String field, String qt, int th, String errmsg) {
141141
return "NEAREST_NEIGHBOR {" +
142142
"field=" + field +
143143
",queryTensorName=" + qt +
144-
",hnsw.exploreAdditionalHits=0" +
145144
",distanceThreshold=Infinity" +
146145
",approximate=true" +
147146
(th != 0 ? ",targetHits=" + th : "") +

0 commit comments

Comments
 (0)