Skip to content

Commit

Permalink
Fix derived source for binary and byte vectors
Browse files Browse the repository at this point in the history
For binary and byte vectors, for derived source, we were not formatting
them before adding them back to the source. Thus, they were binary
strings in the source. This change fixes this formatting to format them
as ints before adding back.

Signed-off-by: John Mazanec <[email protected]>
  • Loading branch information
jmazanec15 committed Feb 18, 2025
1 parent 45ecb5b commit 63cf35c
Show file tree
Hide file tree
Showing 8 changed files with 409 additions and 73 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.knn.index.codec.derivedsource;

import lombok.extern.log4j.Log4j2;
import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.util.BytesRef;
import org.opensearch.knn.common.FieldInfoExtractor;
import org.opensearch.knn.index.VectorDataType;
import org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil;
import org.opensearch.knn.index.vectorvalues.KNNVectorValues;

import java.io.IOException;

@Log4j2
abstract class AbstractPerFieldDerivedVectorInjector implements PerFieldDerivedVectorInjector {
/**
* Utility method for formatting the vector values based on the vector data type. KNNVectorValues must be advanced
* to the correct position.
*
* @param fieldInfo fieldinfo for the vector field
* @param vectorValues vector values of the field. getVector or getConditionalVector should return expected vector.
* @return vector formatted based on the vector data type
* @throws IOException if unable to deserialize stored vector
*/
protected Object formatVector(FieldInfo fieldInfo, KNNVectorValues<?> vectorValues) throws IOException {
Object vectorValue = vectorValues.getVector();
// If the vector value is a byte[], we must deserialize
if (vectorValue instanceof byte[]) {
BytesRef vectorBytesRef = new BytesRef((byte[]) vectorValue);
VectorDataType vectorDataType = FieldInfoExtractor.extractVectorDataType(fieldInfo);
return KNNVectorFieldMapperUtil.deserializeStoredVector(vectorBytesRef, vectorDataType);
}
return vectorValues.conditionalCloneVector();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@

@Log4j2
@AllArgsConstructor
public class NestedPerFieldDerivedVectorInjector implements PerFieldDerivedVectorInjector {
public class NestedPerFieldDerivedVectorInjector extends AbstractPerFieldDerivedVectorInjector {

private final FieldInfo childFieldInfo;
private final DerivedSourceReaders derivedSourceReaders;
Expand Down Expand Up @@ -116,7 +116,7 @@ public void inject(int parentDocId, Map<String, Object> sourceAsMap) throws IOEx
reconstructedSource.add(position, new HashMap<>());
positions.add(position, docId);
}
reconstructedSource.get(position).put(childFieldName, vectorValues.conditionalCloneVector());
reconstructedSource.get(position).put(childFieldName, formatVector(childFieldInfo, vectorValues));
offsetPositionsIndex = position + 1;
}
sourceAsMap.put(parentFieldName, reconstructedSource);
Expand All @@ -137,7 +137,7 @@ private void injectObject(int docId, Map<String, Object> sourceAsMap) throws IOE
String field = fields[i];
currentMap = (Map<String, Object>) currentMap.computeIfAbsent(field, k -> new HashMap<>());
}
currentMap.put(fields[fields.length - 1], vectorValues.getVector());
currentMap.put(fields[fields.length - 1], formatVector(childFieldInfo, vectorValues));
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
/**
* {@link PerFieldDerivedVectorInjector} for root fields (i.e. non nested fields).
*/
class RootPerFieldDerivedVectorInjector implements PerFieldDerivedVectorInjector {
class RootPerFieldDerivedVectorInjector extends AbstractPerFieldDerivedVectorInjector {

private final FieldInfo fieldInfo;
private final CheckedSupplier<KNNVectorValues<?>, IOException> vectorValuesSupplier;
Expand All @@ -40,7 +40,7 @@ public RootPerFieldDerivedVectorInjector(FieldInfo fieldInfo, DerivedSourceReade
public void inject(int docId, Map<String, Object> sourceAsMap) throws IOException {
KNNVectorValues<?> vectorValues = vectorValuesSupplier.get();
if (vectorValues.docId() == docId || vectorValues.advance(docId) == docId) {
sourceAsMap.put(fieldInfo.name, vectorValues.conditionalCloneVector());
sourceAsMap.put(fieldInfo.name, formatVector(fieldInfo, vectorValues));
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -86,10 +86,10 @@ public static StoredField createStoredFieldForFloatVector(String name, float[] v
* @return either int[] or float[] of corresponding vector
*/
public static Object deserializeStoredVector(BytesRef storedVector, VectorDataType vectorDataType) {
if (VectorDataType.BYTE == vectorDataType) {
if (VectorDataType.BYTE == vectorDataType || VectorDataType.BINARY == vectorDataType) {
byte[] bytes = storedVector.bytes;
int[] byteAsIntArray = new int[bytes.length];
Arrays.setAll(byteAsIntArray, i -> bytes[i]);
int[] byteAsIntArray = new int[storedVector.length];
Arrays.setAll(byteAsIntArray, i -> bytes[i + storedVector.offset]);
return byteAsIntArray;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,17 @@ public void testStoredFields_whenVectorIsByteType_thenSucceed() {
assertArrayEquals(byteAsIntArray, (int[]) vector);
}

public void testStoredFields_whenVectorIsBinaryType_thenSucceed() {
StoredField storedField = KNNVectorFieldMapperUtil.createStoredFieldForByteVector(TEST_FIELD_NAME, TEST_BYTE_VECTOR);
assertEquals(TEST_FIELD_NAME, storedField.name());
assertEquals(TEST_BYTE_VECTOR, storedField.binaryValue().bytes);
Object vector = KNNVectorFieldMapperUtil.deserializeStoredVector(storedField.binaryValue(), VectorDataType.BINARY);
assertTrue(vector instanceof int[]);
int[] byteAsIntArray = new int[TEST_BYTE_VECTOR.length];
Arrays.setAll(byteAsIntArray, i -> TEST_BYTE_VECTOR[i]);
assertArrayEquals(byteAsIntArray, (int[]) vector);
}

public void testStoredFields_whenVectorIsFloatType_thenSucceed() {
StoredField storedField = KNNVectorFieldMapperUtil.createStoredFieldForFloatVector(TEST_FIELD_NAME, TEST_FLOAT_VECTOR);
assertEquals(TEST_FIELD_NAME, storedField.name());
Expand Down
Loading

0 comments on commit 63cf35c

Please sign in to comment.