Skip to content

Commit b0d82b7

Browse files
Revert "Makes sure KNNVectorValues aren't recreated unnecessarily when quantization isn't needed (opensearch-project#2133) (opensearch-project#2140)" (opensearch-project#2161)
This reverts commit ca6b03f. Signed-off-by: Naveen Tatikonda <[email protected]>
1 parent fbec0aa commit b0d82b7

File tree

5 files changed

+105
-622
lines changed

5 files changed

+105
-622
lines changed

release-notes/opensearch-knn.release-notes-2.17.0.0.md

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@ Compatible with OpenSearch 2.17.0
2121
* Fix memory overflow caused by cache behavior [#2015](https://github.com/opensearch-project/k-NN/pull/2015)
2222
* Use correct type for binary vector in ivf training [#2086](https://github.com/opensearch-project/k-NN/pull/2086)
2323
* Switch MINGW32 to MINGW64 [#2090](https://github.com/opensearch-project/k-NN/pull/2090)
24-
* Does not create additional KNNVectorValues in NativeEngines990KNNVectorWriter when quantization is not needed [#2133](https://github.com/opensearch-project/k-NN/pull/2133)
2524
### Infrastructure
2625
* Parallelize make to reduce build time [#2006](https://github.com/opensearch-project/k-NN/pull/2006)
2726
### Maintenance

src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriter.java

Lines changed: 104 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -25,21 +25,20 @@
2525
import org.apache.lucene.util.IOUtils;
2626
import org.apache.lucene.util.RamUsageEstimator;
2727
import org.opensearch.common.StopWatch;
28+
import org.opensearch.knn.index.quantizationservice.QuantizationService;
2829
import org.opensearch.knn.index.VectorDataType;
2930
import org.opensearch.knn.index.codec.nativeindex.NativeIndexWriter;
30-
import org.opensearch.knn.index.quantizationservice.QuantizationService;
3131
import org.opensearch.knn.index.vectorvalues.KNNVectorValues;
32+
import org.opensearch.knn.index.vectorvalues.KNNVectorValuesFactory;
3233
import org.opensearch.knn.plugin.stats.KNNGraphValue;
3334
import org.opensearch.knn.quantization.models.quantizationParams.QuantizationParams;
3435
import org.opensearch.knn.quantization.models.quantizationState.QuantizationState;
3536

3637
import java.io.IOException;
3738
import java.util.ArrayList;
3839
import java.util.List;
39-
import java.util.function.Supplier;
4040

4141
import static org.opensearch.knn.common.FieldInfoExtractor.extractVectorDataType;
42-
import static org.opensearch.knn.index.vectorvalues.KNNVectorValuesFactory.getVectorValues;
4342

4443
/**
4544
* A KNNVectorsWriter class for writing the vector data structures and flat vectors for Native Engines.
@@ -48,11 +47,15 @@
4847
public class NativeEngines990KnnVectorsWriter extends KnnVectorsWriter {
4948
private static final long SHALLOW_SIZE = RamUsageEstimator.shallowSizeOfInstance(NativeEngines990KnnVectorsWriter.class);
5049

50+
private static final String FLUSH_OPERATION = "flush";
51+
private static final String MERGE_OPERATION = "merge";
52+
5153
private final SegmentWriteState segmentWriteState;
5254
private final FlatVectorsWriter flatVectorsWriter;
5355
private KNN990QuantizationStateWriter quantizationStateWriter;
5456
private final List<NativeEngineFieldVectorsWriter<?>> fields = new ArrayList<>();
5557
private boolean finished;
58+
private final QuantizationService quantizationService = QuantizationService.getInstance();
5659

5760
public NativeEngines990KnnVectorsWriter(SegmentWriteState segmentWriteState, FlatVectorsWriter flatVectorsWriter) {
5861
this.segmentWriteState = segmentWriteState;
@@ -81,27 +84,14 @@ public void flush(int maxDoc, final Sorter.DocMap sortMap) throws IOException {
8184
flatVectorsWriter.flush(maxDoc, sortMap);
8285

8386
for (final NativeEngineFieldVectorsWriter<?> field : fields) {
84-
final FieldInfo fieldInfo = field.getFieldInfo();
85-
final VectorDataType vectorDataType = extractVectorDataType(fieldInfo);
86-
int totalLiveDocs = field.getVectors().size();
87-
if (totalLiveDocs > 0) {
88-
final Supplier<KNNVectorValues<?>> knnVectorValuesSupplier = () -> getVectorValues(
89-
vectorDataType,
90-
field.getDocsWithField(),
91-
field.getVectors()
92-
);
93-
final QuantizationState quantizationState = train(field.getFieldInfo(), knnVectorValuesSupplier, totalLiveDocs);
94-
final NativeIndexWriter writer = NativeIndexWriter.getWriter(fieldInfo, segmentWriteState, quantizationState);
95-
final KNNVectorValues<?> knnVectorValues = knnVectorValuesSupplier.get();
96-
97-
StopWatch stopWatch = new StopWatch().start();
98-
writer.flushIndex(knnVectorValues, totalLiveDocs);
99-
long time_in_millis = stopWatch.stop().totalTime().millis();
100-
KNNGraphValue.REFRESH_TOTAL_TIME_IN_MILLIS.incrementBy(time_in_millis);
101-
log.debug("Flush took {} ms for vector field [{}]", time_in_millis, fieldInfo.getName());
102-
} else {
103-
log.debug("[Flush] No live docs for field {}", fieldInfo.getName());
104-
}
87+
trainAndIndex(
88+
field.getFieldInfo(),
89+
(vectorDataType, fieldInfo, fieldVectorsWriter) -> getKNNVectorValues(vectorDataType, fieldVectorsWriter),
90+
NativeIndexWriter::flushIndex,
91+
field,
92+
KNNGraphValue.REFRESH_TOTAL_TIME_IN_MILLIS,
93+
FLUSH_OPERATION
94+
);
10595
}
10696
}
10797

@@ -110,29 +100,15 @@ public void mergeOneField(final FieldInfo fieldInfo, final MergeState mergeState
110100
// This will ensure that we are merging the FlatIndex during force merge.
111101
flatVectorsWriter.mergeOneField(fieldInfo, mergeState);
112102

113-
final VectorDataType vectorDataType = extractVectorDataType(fieldInfo);
114-
final Supplier<KNNVectorValues<?>> knnVectorValuesSupplier = () -> getKNNVectorValuesForMerge(
115-
vectorDataType,
103+
// For merge, pick values from flat vector and reindex again. This will use the flush operation to create graphs
104+
trainAndIndex(
116105
fieldInfo,
117-
mergeState
106+
this::getKNNVectorValuesForMerge,
107+
NativeIndexWriter::mergeIndex,
108+
mergeState,
109+
KNNGraphValue.MERGE_TOTAL_TIME_IN_MILLIS,
110+
MERGE_OPERATION
118111
);
119-
int totalLiveDocs = getLiveDocs(knnVectorValuesSupplier.get());
120-
if (totalLiveDocs == 0) {
121-
log.debug("[Merge] No live docs for field {}", fieldInfo.getName());
122-
return;
123-
}
124-
125-
final QuantizationState quantizationState = train(fieldInfo, knnVectorValuesSupplier, totalLiveDocs);
126-
final NativeIndexWriter writer = NativeIndexWriter.getWriter(fieldInfo, segmentWriteState, quantizationState);
127-
final KNNVectorValues<?> knnVectorValues = knnVectorValuesSupplier.get();
128-
129-
StopWatch stopWatch = new StopWatch().start();
130-
131-
writer.mergeIndex(knnVectorValues, totalLiveDocs);
132-
133-
long time_in_millis = stopWatch.stop().totalTime().millis();
134-
KNNGraphValue.MERGE_TOTAL_TIME_IN_MILLIS.incrementBy(time_in_millis);
135-
log.debug("Merge took {} ms for vector field [{}]", time_in_millis, fieldInfo.getName());
136112
}
137113

138114
/**
@@ -181,6 +157,18 @@ public long ramBytesUsed() {
181157
.sum();
182158
}
183159

160+
/**
161+
* Retrieves the {@link KNNVectorValues} for a specific field based on the vector data type and field writer.
162+
*
163+
* @param vectorDataType The {@link VectorDataType} representing the type of vectors stored.
164+
* @param field The {@link NativeEngineFieldVectorsWriter} representing the field from which to retrieve vectors.
165+
* @param <T> The type of vectors being processed.
166+
* @return The {@link KNNVectorValues} associated with the field.
167+
*/
168+
private <T> KNNVectorValues<T> getKNNVectorValues(final VectorDataType vectorDataType, final NativeEngineFieldVectorsWriter<?> field) {
169+
return (KNNVectorValues<T>) KNNVectorValuesFactory.getVectorValues(vectorDataType, field.getDocsWithField(), field.getVectors());
170+
}
171+
184172
/**
185173
* Retrieves the {@link KNNVectorValues} for a specific field during a merge operation, based on the vector data type.
186174
*
@@ -195,41 +183,89 @@ private <T> KNNVectorValues<T> getKNNVectorValuesForMerge(
195183
final VectorDataType vectorDataType,
196184
final FieldInfo fieldInfo,
197185
final MergeState mergeState
198-
) {
199-
try {
200-
switch (fieldInfo.getVectorEncoding()) {
201-
case FLOAT32:
202-
FloatVectorValues mergedFloats = MergedVectorValues.mergeFloatVectorValues(fieldInfo, mergeState);
203-
return getVectorValues(vectorDataType, mergedFloats);
204-
case BYTE:
205-
ByteVectorValues mergedBytes = MergedVectorValues.mergeByteVectorValues(fieldInfo, mergeState);
206-
return getVectorValues(vectorDataType, mergedBytes);
207-
default:
208-
throw new IllegalStateException("Unsupported vector encoding [" + fieldInfo.getVectorEncoding() + "]");
209-
}
210-
} catch (final IOException e) {
211-
log.error("Unable to merge vectors for field [{}]", fieldInfo.getName(), e);
212-
throw new IllegalStateException("Unable to merge vectors for field [" + fieldInfo.getName() + "]", e);
186+
) throws IOException {
187+
switch (fieldInfo.getVectorEncoding()) {
188+
case FLOAT32:
189+
FloatVectorValues mergedFloats = MergedVectorValues.mergeFloatVectorValues(fieldInfo, mergeState);
190+
return (KNNVectorValues<T>) KNNVectorValuesFactory.getVectorValues(vectorDataType, mergedFloats);
191+
case BYTE:
192+
ByteVectorValues mergedBytes = MergedVectorValues.mergeByteVectorValues(fieldInfo, mergeState);
193+
return (KNNVectorValues<T>) KNNVectorValuesFactory.getVectorValues(vectorDataType, mergedBytes);
194+
default:
195+
throw new IllegalStateException("Unsupported vector encoding [" + fieldInfo.getVectorEncoding() + "]");
213196
}
214197
}
215198

216-
private QuantizationState train(
199+
/**
200+
* Functional interface representing an operation that indexes the provided {@link KNNVectorValues}.
201+
*
202+
* @param <T> The type of vectors being processed.
203+
*/
204+
@FunctionalInterface
205+
private interface IndexOperation<T> {
206+
void buildAndWrite(NativeIndexWriter writer, KNNVectorValues<T> knnVectorValues, int totalLiveDocs) throws IOException;
207+
}
208+
209+
/**
210+
* Functional interface representing a method that retrieves {@link KNNVectorValues} based on
211+
* the vector data type, field information, and the merge state.
212+
*
213+
* @param <DataType> The type of the data representing the vector (e.g., {@link VectorDataType}).
214+
* @param <FieldInfo> The metadata about the field.
215+
* @param <MergeState> The state of the merge operation.
216+
* @param <Result> The result of the retrieval, typically {@link KNNVectorValues}.
217+
*/
218+
@FunctionalInterface
219+
private interface VectorValuesRetriever<DataType, FieldInfo, MergeState, Result> {
220+
Result apply(DataType vectorDataType, FieldInfo fieldInfo, MergeState mergeState) throws IOException;
221+
}
222+
223+
/**
224+
* Unified method for processing a field during either the indexing or merge operation. This method retrieves vector values
225+
* based on the provided vector data type and applies the specified index operation, potentially including quantization if needed.
226+
*
227+
* @param fieldInfo The {@link FieldInfo} object containing metadata about the field.
228+
* @param vectorValuesRetriever A functional interface that retrieves {@link KNNVectorValues} based on the vector data type,
229+
* field information, and additional context (e.g., merge state or field writer).
230+
* @param indexOperation A functional interface that performs the indexing operation using the retrieved
231+
* {@link KNNVectorValues}.
232+
* @param VectorProcessingContext The additional context required for retrieving the vector values (e.g., {@link MergeState} or {@link NativeEngineFieldVectorsWriter}).
233+
* Flush passes the NativeEngineFieldVectorsWriter, which holds the field's vectors and the docs that have them; merge passes the MergeState, from which the merged vector values are read.
234+
* @param <T> The type of vectors being processed.
235+
* @param <C> The type of the context needed for retrieving the vector values.
236+
* @throws IOException If an I/O error occurs during the processing.
237+
*/
238+
private <T, C> void trainAndIndex(
217239
final FieldInfo fieldInfo,
218-
final Supplier<KNNVectorValues<?>> knnVectorValuesSupplier,
219-
final int totalLiveDocs
240+
final VectorValuesRetriever<VectorDataType, FieldInfo, C, KNNVectorValues<T>> vectorValuesRetriever,
241+
final IndexOperation<T> indexOperation,
242+
final C VectorProcessingContext,
243+
final KNNGraphValue graphBuildTime,
244+
final String operationName
220245
) throws IOException {
221-
222-
final QuantizationService quantizationService = QuantizationService.getInstance();
223-
final QuantizationParams quantizationParams = quantizationService.getQuantizationParams(fieldInfo);
246+
final VectorDataType vectorDataType = extractVectorDataType(fieldInfo);
247+
KNNVectorValues<T> knnVectorValues = vectorValuesRetriever.apply(vectorDataType, fieldInfo, VectorProcessingContext);
248+
QuantizationParams quantizationParams = quantizationService.getQuantizationParams(fieldInfo);
224249
QuantizationState quantizationState = null;
250+
// Count the docIds
251+
int totalLiveDocs = getLiveDocs(vectorValuesRetriever.apply(vectorDataType, fieldInfo, VectorProcessingContext));
225252
if (quantizationParams != null && totalLiveDocs > 0) {
226253
initQuantizationStateWriterIfNecessary();
227-
KNNVectorValues<?> knnVectorValues = knnVectorValuesSupplier.get();
228254
quantizationState = quantizationService.train(quantizationParams, knnVectorValues, totalLiveDocs);
229255
quantizationStateWriter.writeState(fieldInfo.getFieldNumber(), quantizationState);
230256
}
257+
NativeIndexWriter writer = (quantizationParams != null)
258+
? NativeIndexWriter.getWriter(fieldInfo, segmentWriteState, quantizationState)
259+
: NativeIndexWriter.getWriter(fieldInfo, segmentWriteState);
260+
261+
knnVectorValues = vectorValuesRetriever.apply(vectorDataType, fieldInfo, VectorProcessingContext);
231262

232-
return quantizationState;
263+
StopWatch stopWatch = new StopWatch();
264+
stopWatch.start();
265+
indexOperation.buildAndWrite(writer, knnVectorValues, totalLiveDocs);
266+
long time_in_millis = stopWatch.totalTime().millis();
267+
graphBuildTime.incrementBy(time_in_millis);
268+
log.warn("Graph build took " + time_in_millis + " ms for " + operationName);
233269
}
234270

235271
/**

0 commit comments

Comments
 (0)