@@ -22,8 +22,9 @@ public class NearestNeighborItem extends SimpleTaggableItem {
2222
2323 private Integer targetHits = null ;
2424 private Integer totalTargetHits = null ;
25- private int hnswExploreAdditionalHits = 0 ;
26- private double distanceThreshold = Double .POSITIVE_INFINITY ;
25+ private Integer hnswExploreAdditionalHits = null ;
26+ private Integer hnswTotalExploreAdditionalHits = null ;
27+ private double distanceThreshold = Double .POSITIVE_INFINITY ;
2728 private boolean approximate = true ;
2829 private String field ;
2930 private final String queryTensorName ;
@@ -59,8 +60,11 @@ public NearestNeighborItem(String fieldName, String queryTensorName) {
5960 /** Returns the distance threshold for nearest-neighbor hits */
6061 public double getDistanceThreshold () { return this .distanceThreshold ; }
6162
62- /** Returns the number of extra hits to explore in HNSW algorithm */
63- public int getHnswExploreAdditionalHits () { return hnswExploreAdditionalHits ; }
63+ /** Returns the number of extra hits to explore in HNSW algorithm per node. */
64+ public int getHnswExploreAdditionalHits () { return hnswExploreAdditionalHits != null ? hnswExploreAdditionalHits : 0 ; }
65+
66+ /** Returns the total number of extra hits to explore in HNSW algorithm across all nodes, or null if not set. */
67+ public Integer getHnswTotalExploreAdditionalHits () { return hnswTotalExploreAdditionalHits ; }
6468
6569 /** Returns whether approximation is allowed */
6670 public boolean getAllowApproximate () { return approximate ; }
@@ -103,9 +107,12 @@ public NearestNeighborItem(String fieldName, String queryTensorName) {
103107 /** Set the distance threshold for nearest-neighbor hits */
104108 public void setDistanceThreshold (double threshold ) { this .distanceThreshold = threshold ; }
105109
106- /** Set the number of extra hits to explore in HNSW algorithm */
110+ /** Set the number of extra hits to explore in HNSW algorithm per node. */
107111 public void setHnswExploreAdditionalHits (int num ) { this .hnswExploreAdditionalHits = num ; }
108112
113+ /** Set the total number of extra hits to explore in HNSW algorithm across all nodes. */
114+ public void setHnswTotalExploreAdditionalHits (Integer total ) {this .hnswTotalExploreAdditionalHits = total ; }
115+
109116 /** Set whether approximation is allowed */
110117 public void setAllowApproximate (boolean value ) { this .approximate = value ; }
111118
@@ -147,7 +154,7 @@ public int encode(ByteBuffer buffer, SerializationContext context) {
147154 int approxNum = (approximate ? 1 : 0 );
148155 IntegerCompressor .putCompressedPositiveNumber (resolveTargetHits (context ), buffer );
149156 IntegerCompressor .putCompressedPositiveNumber (approxNum , buffer );
150- IntegerCompressor .putCompressedPositiveNumber (hnswExploreAdditionalHits , buffer );
157+ IntegerCompressor .putCompressedPositiveNumber (resolveHnswExploreAdditionalHits ( context ) , buffer );
151158 buffer .putDouble (distanceThreshold );
152159 return 1 ; // number of encoded stack dump items
153160 }
@@ -156,13 +163,16 @@ public int encode(ByteBuffer buffer, SerializationContext context) {
156163 protected void appendBodyString (StringBuilder buffer ) {
157164 buffer .append ("{field=" ).append (field );
158165 buffer .append (",queryTensorName=" ).append (queryTensorName );
159- buffer .append (",hnsw.exploreAdditionalHits=" ).append (hnswExploreAdditionalHits );
160166 buffer .append (",distanceThreshold=" ).append (distanceThreshold );
161167 buffer .append (",approximate=" ).append (approximate );
162168 if (targetHits != null )
163169 buffer .append (",targetHits=" ).append (targetHits );
164170 if (totalTargetHits != null )
165171 buffer .append (",totalTargetHits=" ).append (totalTargetHits );
172+ if (hnswExploreAdditionalHits != null )
173+ buffer .append (",hnsw.exploreAdditionalHits=" ).append (hnswExploreAdditionalHits );
174+ if (hnswTotalExploreAdditionalHits != null )
175+ buffer .append (",hnsw.totalExploreAdditionalHits=" ).append (hnswTotalExploreAdditionalHits );
166176 if (hnswApproximateThreshold != null )
167177 buffer .append (",hnsw.approximateThreshold=" ).append (hnswApproximateThreshold );
168178 if (hnswExplorationSlack != null )
@@ -183,13 +193,16 @@ public void disclose(Discloser discloser) {
183193 super .disclose (discloser );
184194 discloser .addProperty ("field" , field );
185195 discloser .addProperty ("queryTensorName" , queryTensorName );
186- discloser .addProperty ("hnsw.exploreAdditionalHits" , hnswExploreAdditionalHits );
187196 discloser .addProperty ("distanceThreshold" , distanceThreshold );
188197 discloser .addProperty ("approximate" , approximate );
189198 if (targetHits != null )
190199 discloser .addProperty ("targetHits" , targetHits );
191200 if (totalTargetHits != null )
192201 discloser .addProperty ("totalTargetHits" , totalTargetHits );
202+ if (hnswExploreAdditionalHits != null )
203+ discloser .addProperty ("hnsw.exploreAdditionalHits" , hnswExploreAdditionalHits );
204+ if (hnswTotalExploreAdditionalHits != null )
205+ discloser .addProperty ("hnsw.totalExploreAdditionalHits" , hnswTotalExploreAdditionalHits );
193206 if (hnswApproximateThreshold != null )
194207 discloser .addProperty ("hnsw.approximateThreshold" , hnswApproximateThreshold );
195208 if (hnswExplorationSlack != null )
@@ -210,7 +223,8 @@ public boolean equals(Object o) {
210223 NearestNeighborItem other = (NearestNeighborItem )o ;
211224 if ( ! Objects .equals (this .targetHits , other .targetHits )) return false ;
212225 if ( ! Objects .equals (this .totalTargetHits , other .totalTargetHits )) return false ;
213- if (this .hnswExploreAdditionalHits != other .hnswExploreAdditionalHits ) return false ;
226+ if ( ! Objects .equals (this .hnswExploreAdditionalHits , other .hnswExploreAdditionalHits )) return false ;
227+ if ( ! Objects .equals (this .hnswTotalExploreAdditionalHits , other .hnswTotalExploreAdditionalHits )) return false ;
214228 if (this .distanceThreshold != other .distanceThreshold ) return false ;
215229 if (this .approximate != other .approximate ) return false ;
216230 if ( ! this .field .equals (other .field )) return false ;
@@ -226,7 +240,8 @@ public boolean equals(Object o) {
226240
227241 @ Override
228242 public int hashCode () {
229- return Objects .hash (super .hashCode (), targetHits , totalTargetHits , hnswExploreAdditionalHits ,
243+ return Objects .hash (super .hashCode (), targetHits , totalTargetHits ,
244+ hnswExploreAdditionalHits , hnswTotalExploreAdditionalHits ,
230245 distanceThreshold , approximate , field , queryTensorName ,
231246 hnswApproximateThreshold , hnswExplorationSlack ,
232247 hnswFilterFirstExploration , hnswFilterFirstThreshold ,
@@ -240,7 +255,7 @@ SearchProtocol.QueryTreeItem toProtobuf(SerializationContext context) {
240255 builder .setQueryTensorName (queryTensorName );
241256 builder .setTargetNumHits (resolveTargetHits (context ));
242257 builder .setAllowApproximate (approximate );
243- builder .setExploreAdditionalHits (hnswExploreAdditionalHits );
258+ builder .setExploreAdditionalHits (resolveHnswExploreAdditionalHits ( context ) );
244259 builder .setDistanceThreshold (distanceThreshold );
245260 if (hnswApproximateThreshold != null ) {
246261 builder .setApproximateThreshold (hnswApproximateThreshold );
@@ -265,6 +280,12 @@ SearchProtocol.QueryTreeItem toProtobuf(SerializationContext context) {
265280 .build ();
266281 }
267282
283+ private int resolveHnswExploreAdditionalHits (SerializationContext context ) {
284+ if (hnswExploreAdditionalHits != null ) return hnswExploreAdditionalHits ;
285+ if (hnswTotalExploreAdditionalHits == null ) return 0 ;
286+ return context .contentShareOf (hnswTotalExploreAdditionalHits );
287+ }
288+
268289 private int resolveTargetHits (SerializationContext context ) {
269290 if (targetHits != null ) return targetHits ;
270291 if (totalTargetHits == null )
0 commit comments