25
25
import org .apache .lucene .util .IOUtils ;
26
26
import org .apache .lucene .util .RamUsageEstimator ;
27
27
import org .opensearch .common .StopWatch ;
28
+ import org .opensearch .knn .index .quantizationservice .QuantizationService ;
28
29
import org .opensearch .knn .index .VectorDataType ;
29
30
import org .opensearch .knn .index .codec .nativeindex .NativeIndexWriter ;
30
- import org .opensearch .knn .index .quantizationservice .QuantizationService ;
31
31
import org .opensearch .knn .index .vectorvalues .KNNVectorValues ;
32
+ import org .opensearch .knn .index .vectorvalues .KNNVectorValuesFactory ;
32
33
import org .opensearch .knn .plugin .stats .KNNGraphValue ;
33
34
import org .opensearch .knn .quantization .models .quantizationParams .QuantizationParams ;
34
35
import org .opensearch .knn .quantization .models .quantizationState .QuantizationState ;
35
36
36
37
import java .io .IOException ;
37
38
import java .util .ArrayList ;
38
39
import java .util .List ;
39
- import java .util .function .Supplier ;
40
40
41
41
import static org .opensearch .knn .common .FieldInfoExtractor .extractVectorDataType ;
42
- import static org .opensearch .knn .index .vectorvalues .KNNVectorValuesFactory .getVectorValues ;
43
42
44
43
/**
45
44
* A KNNVectorsWriter class for writing the vector data structures and flat vectors for Native Engines.
48
47
public class NativeEngines990KnnVectorsWriter extends KnnVectorsWriter {
49
48
private static final long SHALLOW_SIZE = RamUsageEstimator .shallowSizeOfInstance (NativeEngines990KnnVectorsWriter .class );
50
49
50
+ private static final String FLUSH_OPERATION = "flush" ;
51
+ private static final String MERGE_OPERATION = "merge" ;
52
+
51
53
private final SegmentWriteState segmentWriteState ;
52
54
private final FlatVectorsWriter flatVectorsWriter ;
53
55
private KNN990QuantizationStateWriter quantizationStateWriter ;
54
56
private final List <NativeEngineFieldVectorsWriter <?>> fields = new ArrayList <>();
55
57
private boolean finished ;
58
+ private final QuantizationService quantizationService = QuantizationService .getInstance ();
56
59
57
60
public NativeEngines990KnnVectorsWriter (SegmentWriteState segmentWriteState , FlatVectorsWriter flatVectorsWriter ) {
58
61
this .segmentWriteState = segmentWriteState ;
@@ -81,27 +84,14 @@ public void flush(int maxDoc, final Sorter.DocMap sortMap) throws IOException {
81
84
flatVectorsWriter .flush (maxDoc , sortMap );
82
85
83
86
for (final NativeEngineFieldVectorsWriter <?> field : fields ) {
84
- final FieldInfo fieldInfo = field .getFieldInfo ();
85
- final VectorDataType vectorDataType = extractVectorDataType (fieldInfo );
86
- int totalLiveDocs = field .getVectors ().size ();
87
- if (totalLiveDocs > 0 ) {
88
- final Supplier <KNNVectorValues <?>> knnVectorValuesSupplier = () -> getVectorValues (
89
- vectorDataType ,
90
- field .getDocsWithField (),
91
- field .getVectors ()
92
- );
93
- final QuantizationState quantizationState = train (field .getFieldInfo (), knnVectorValuesSupplier , totalLiveDocs );
94
- final NativeIndexWriter writer = NativeIndexWriter .getWriter (fieldInfo , segmentWriteState , quantizationState );
95
- final KNNVectorValues <?> knnVectorValues = knnVectorValuesSupplier .get ();
96
-
97
- StopWatch stopWatch = new StopWatch ().start ();
98
- writer .flushIndex (knnVectorValues , totalLiveDocs );
99
- long time_in_millis = stopWatch .stop ().totalTime ().millis ();
100
- KNNGraphValue .REFRESH_TOTAL_TIME_IN_MILLIS .incrementBy (time_in_millis );
101
- log .debug ("Flush took {} ms for vector field [{}]" , time_in_millis , fieldInfo .getName ());
102
- } else {
103
- log .debug ("[Flush] No live docs for field {}" , fieldInfo .getName ());
104
- }
87
+ trainAndIndex (
88
+ field .getFieldInfo (),
89
+ (vectorDataType , fieldInfo , fieldVectorsWriter ) -> getKNNVectorValues (vectorDataType , fieldVectorsWriter ),
90
+ NativeIndexWriter ::flushIndex ,
91
+ field ,
92
+ KNNGraphValue .REFRESH_TOTAL_TIME_IN_MILLIS ,
93
+ FLUSH_OPERATION
94
+ );
105
95
}
106
96
}
107
97
@@ -110,29 +100,15 @@ public void mergeOneField(final FieldInfo fieldInfo, final MergeState mergeState
110
100
// This will ensure that we are merging the FlatIndex during force merge.
111
101
flatVectorsWriter .mergeOneField (fieldInfo , mergeState );
112
102
113
- final VectorDataType vectorDataType = extractVectorDataType (fieldInfo );
114
- final Supplier <KNNVectorValues <?>> knnVectorValuesSupplier = () -> getKNNVectorValuesForMerge (
115
- vectorDataType ,
103
+ // For merge, pick values from flat vector and reindex again. This will use the flush operation to create graphs
104
+ trainAndIndex (
116
105
fieldInfo ,
117
- mergeState
106
+ this ::getKNNVectorValuesForMerge ,
107
+ NativeIndexWriter ::mergeIndex ,
108
+ mergeState ,
109
+ KNNGraphValue .MERGE_TOTAL_TIME_IN_MILLIS ,
110
+ MERGE_OPERATION
118
111
);
119
- int totalLiveDocs = getLiveDocs (knnVectorValuesSupplier .get ());
120
- if (totalLiveDocs == 0 ) {
121
- log .debug ("[Merge] No live docs for field {}" , fieldInfo .getName ());
122
- return ;
123
- }
124
-
125
- final QuantizationState quantizationState = train (fieldInfo , knnVectorValuesSupplier , totalLiveDocs );
126
- final NativeIndexWriter writer = NativeIndexWriter .getWriter (fieldInfo , segmentWriteState , quantizationState );
127
- final KNNVectorValues <?> knnVectorValues = knnVectorValuesSupplier .get ();
128
-
129
- StopWatch stopWatch = new StopWatch ().start ();
130
-
131
- writer .mergeIndex (knnVectorValues , totalLiveDocs );
132
-
133
- long time_in_millis = stopWatch .stop ().totalTime ().millis ();
134
- KNNGraphValue .MERGE_TOTAL_TIME_IN_MILLIS .incrementBy (time_in_millis );
135
- log .debug ("Merge took {} ms for vector field [{}]" , time_in_millis , fieldInfo .getName ());
136
112
}
137
113
138
114
/**
@@ -181,6 +157,18 @@ public long ramBytesUsed() {
181
157
.sum ();
182
158
}
183
159
160
+ /**
161
+ * Retrieves the {@link KNNVectorValues} for a specific field based on the vector data type and field writer.
162
+ *
163
+ * @param vectorDataType The {@link VectorDataType} representing the type of vectors stored.
164
+ * @param field The {@link NativeEngineFieldVectorsWriter} representing the field from which to retrieve vectors.
165
+ * @param <T> The type of vectors being processed.
166
+ * @return The {@link KNNVectorValues} associated with the field.
167
+ */
168
+ private <T > KNNVectorValues <T > getKNNVectorValues (final VectorDataType vectorDataType , final NativeEngineFieldVectorsWriter <?> field ) {
169
+ return (KNNVectorValues <T >) KNNVectorValuesFactory .getVectorValues (vectorDataType , field .getDocsWithField (), field .getVectors ());
170
+ }
171
+
184
172
/**
185
173
* Retrieves the {@link KNNVectorValues} for a specific field during a merge operation, based on the vector data type.
186
174
*
@@ -195,41 +183,89 @@ private <T> KNNVectorValues<T> getKNNVectorValuesForMerge(
195
183
final VectorDataType vectorDataType ,
196
184
final FieldInfo fieldInfo ,
197
185
final MergeState mergeState
198
- ) {
199
- try {
200
- switch (fieldInfo .getVectorEncoding ()) {
201
- case FLOAT32 :
202
- FloatVectorValues mergedFloats = MergedVectorValues .mergeFloatVectorValues (fieldInfo , mergeState );
203
- return getVectorValues (vectorDataType , mergedFloats );
204
- case BYTE :
205
- ByteVectorValues mergedBytes = MergedVectorValues .mergeByteVectorValues (fieldInfo , mergeState );
206
- return getVectorValues (vectorDataType , mergedBytes );
207
- default :
208
- throw new IllegalStateException ("Unsupported vector encoding [" + fieldInfo .getVectorEncoding () + "]" );
209
- }
210
- } catch (final IOException e ) {
211
- log .error ("Unable to merge vectors for field [{}]" , fieldInfo .getName (), e );
212
- throw new IllegalStateException ("Unable to merge vectors for field [" + fieldInfo .getName () + "]" , e );
186
+ ) throws IOException {
187
+ switch (fieldInfo .getVectorEncoding ()) {
188
+ case FLOAT32 :
189
+ FloatVectorValues mergedFloats = MergedVectorValues .mergeFloatVectorValues (fieldInfo , mergeState );
190
+ return (KNNVectorValues <T >) KNNVectorValuesFactory .getVectorValues (vectorDataType , mergedFloats );
191
+ case BYTE :
192
+ ByteVectorValues mergedBytes = MergedVectorValues .mergeByteVectorValues (fieldInfo , mergeState );
193
+ return (KNNVectorValues <T >) KNNVectorValuesFactory .getVectorValues (vectorDataType , mergedBytes );
194
+ default :
195
+ throw new IllegalStateException ("Unsupported vector encoding [" + fieldInfo .getVectorEncoding () + "]" );
213
196
}
214
197
}
215
198
216
- private QuantizationState train (
199
+ /**
200
+ * Functional interface representing an operation that indexes the provided {@link KNNVectorValues}.
201
+ *
202
+ * @param <T> The type of vectors being processed.
203
+ */
204
+ @ FunctionalInterface
205
+ private interface IndexOperation <T > {
206
+ void buildAndWrite (NativeIndexWriter writer , KNNVectorValues <T > knnVectorValues , int totalLiveDocs ) throws IOException ;
207
+ }
208
+
209
+ /**
210
+ * Functional interface representing a method that retrieves {@link KNNVectorValues} based on
211
+ * the vector data type, field information, and the merge state.
212
+ *
213
+ * @param <DataType> The type of the data representing the vector (e.g., {@link VectorDataType}).
214
+ * @param <FieldInfo> The metadata about the field.
215
+ * @param <MergeState> The state of the merge operation.
216
+ * @param <Result> The result of the retrieval, typically {@link KNNVectorValues}.
217
+ */
218
+ @ FunctionalInterface
219
+ private interface VectorValuesRetriever <DataType , FieldInfo , MergeState , Result > {
220
+ Result apply (DataType vectorDataType , FieldInfo fieldInfo , MergeState mergeState ) throws IOException ;
221
+ }
222
+
223
+ /**
224
+ * Unified method for processing a field during either the indexing or merge operation. This method retrieves vector values
225
+ * based on the provided vector data type and applies the specified index operation, potentially including quantization if needed.
226
+ *
227
+ * @param fieldInfo The {@link FieldInfo} object containing metadata about the field.
228
+ * @param vectorValuesRetriever A functional interface that retrieves {@link KNNVectorValues} based on the vector data type,
229
+ * field information, and additional context (e.g., merge state or field writer).
230
+ * @param indexOperation A functional interface that performs the indexing operation using the retrieved
231
+ * {@link KNNVectorValues}.
232
+ * @param VectorProcessingContext The additional context required for retrieving the vector values (e.g., {@link MergeState} or {@link NativeEngineFieldVectorsWriter}).
233
+ * From Flush we need NativeFieldWriter which contains total number of vectors while from Merge we need merge state which contains vector information
234
+ * @param <T> The type of vectors being processed.
235
+ * @param <C> The type of the context needed for retrieving the vector values.
236
+ * @throws IOException If an I/O error occurs during the processing.
237
+ */
238
+ private <T , C > void trainAndIndex (
217
239
final FieldInfo fieldInfo ,
218
- final Supplier <KNNVectorValues <?>> knnVectorValuesSupplier ,
219
- final int totalLiveDocs
240
+ final VectorValuesRetriever <VectorDataType , FieldInfo , C , KNNVectorValues <T >> vectorValuesRetriever ,
241
+ final IndexOperation <T > indexOperation ,
242
+ final C VectorProcessingContext ,
243
+ final KNNGraphValue graphBuildTime ,
244
+ final String operationName
220
245
) throws IOException {
221
-
222
- final QuantizationService quantizationService = QuantizationService . getInstance ( );
223
- final QuantizationParams quantizationParams = quantizationService .getQuantizationParams (fieldInfo );
246
+ final VectorDataType vectorDataType = extractVectorDataType ( fieldInfo );
247
+ KNNVectorValues < T > knnVectorValues = vectorValuesRetriever . apply ( vectorDataType , fieldInfo , VectorProcessingContext );
248
+ QuantizationParams quantizationParams = quantizationService .getQuantizationParams (fieldInfo );
224
249
QuantizationState quantizationState = null ;
250
+ // Count the docIds
251
+ int totalLiveDocs = getLiveDocs (vectorValuesRetriever .apply (vectorDataType , fieldInfo , VectorProcessingContext ));
225
252
if (quantizationParams != null && totalLiveDocs > 0 ) {
226
253
initQuantizationStateWriterIfNecessary ();
227
- KNNVectorValues <?> knnVectorValues = knnVectorValuesSupplier .get ();
228
254
quantizationState = quantizationService .train (quantizationParams , knnVectorValues , totalLiveDocs );
229
255
quantizationStateWriter .writeState (fieldInfo .getFieldNumber (), quantizationState );
230
256
}
257
+ NativeIndexWriter writer = (quantizationParams != null )
258
+ ? NativeIndexWriter .getWriter (fieldInfo , segmentWriteState , quantizationState )
259
+ : NativeIndexWriter .getWriter (fieldInfo , segmentWriteState );
260
+
261
+ knnVectorValues = vectorValuesRetriever .apply (vectorDataType , fieldInfo , VectorProcessingContext );
231
262
232
- return quantizationState ;
263
+ StopWatch stopWatch = new StopWatch ();
264
+ stopWatch .start ();
265
+ indexOperation .buildAndWrite (writer , knnVectorValues , totalLiveDocs );
266
+ long time_in_millis = stopWatch .totalTime ().millis ();
267
+ graphBuildTime .incrementBy (time_in_millis );
268
+ log .warn ("Graph build took " + time_in_millis + " ms for " + operationName );
233
269
}
234
270
235
271
/**
0 commit comments