Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix derived source for binary and byte vectors #2533

Merged
merged 1 commit into from
Feb 18, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
### Features
### Enhancements
### Bug Fixes
* Fix derived source for binary and byte vectors [#2533](https://github.com/opensearch-project/k-NN/pull/2533/)
### Infrastructure
### Documentation
### Maintenance
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.knn.index.codec.derivedsource;

import lombok.extern.log4j.Log4j2;
import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.util.BytesRef;
import org.opensearch.knn.common.FieldInfoExtractor;
import org.opensearch.knn.index.VectorDataType;
import org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil;
import org.opensearch.knn.index.vectorvalues.KNNVectorValues;

import java.io.IOException;

@Log4j2
abstract class AbstractPerFieldDerivedVectorInjector implements PerFieldDerivedVectorInjector {
/**
* Utility method for formatting the vector values based on the vector data type. KNNVectorValues must be advanced
* to the correct position.
*
* @param fieldInfo fieldinfo for the vector field
* @param vectorValues vector values of the field. getVector or getConditionalVector should return expected vector.
* @return vector formatted based on the vector data type
* @throws IOException if unable to deserialize stored vector
*/
protected Object formatVector(FieldInfo fieldInfo, KNNVectorValues<?> vectorValues) throws IOException {
Object vectorValue = vectorValues.getVector();
// If the vector value is a byte[], we must deserialize
if (vectorValue instanceof byte[]) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can we use the datatype of the field here, rather than instance of check on byte[].

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We need a byte[] in order to deserialize, so this check is required. In terms of displaying, deserializeStoredVector takes the vectorDataType, so we can be sure that it will format it properly.

BytesRef vectorBytesRef = new BytesRef((byte[]) vectorValue);
VectorDataType vectorDataType = FieldInfoExtractor.extractVectorDataType(fieldInfo);
return KNNVectorFieldMapperUtil.deserializeStoredVector(vectorBytesRef, vectorDataType);
Comment on lines +34 to +35
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you elaborate why we need to do this? I am trying to understand this like why we need it, since we already have byte[]

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure, this is what the IT I added looks like before:

  2> java.lang.AssertionError: Docs do not match: 1 expected:<{test_vector=[115, -43, 26, -69, -40, -100, -72, 25, 111, 14, -5, 104, -110, -7, 77, 104]}> but was:<{test_vector=c9Uau9icuBlvDvtokvlNaA==}>

Basically, source is expected to be an int array, but because we are adding a byte array, it gets serialized as a byte string

}
return vectorValues.conditionalCloneVector();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@

@Log4j2
@AllArgsConstructor
public class NestedPerFieldDerivedVectorInjector implements PerFieldDerivedVectorInjector {
public class NestedPerFieldDerivedVectorInjector extends AbstractPerFieldDerivedVectorInjector {

private final FieldInfo childFieldInfo;
private final DerivedSourceReaders derivedSourceReaders;
Expand Down Expand Up @@ -116,7 +116,7 @@ public void inject(int parentDocId, Map<String, Object> sourceAsMap) throws IOEx
reconstructedSource.add(position, new HashMap<>());
positions.add(position, docId);
}
reconstructedSource.get(position).put(childFieldName, vectorValues.conditionalCloneVector());
reconstructedSource.get(position).put(childFieldName, formatVector(childFieldInfo, vectorValues));
offsetPositionsIndex = position + 1;
}
sourceAsMap.put(parentFieldName, reconstructedSource);
Expand All @@ -137,7 +137,7 @@ private void injectObject(int docId, Map<String, Object> sourceAsMap) throws IOE
String field = fields[i];
currentMap = (Map<String, Object>) currentMap.computeIfAbsent(field, k -> new HashMap<>());
}
currentMap.put(fields[fields.length - 1], vectorValues.getVector());
currentMap.put(fields[fields.length - 1], formatVector(childFieldInfo, vectorValues));
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
/**
* {@link PerFieldDerivedVectorInjector} for root fields (i.e. non nested fields).
*/
class RootPerFieldDerivedVectorInjector implements PerFieldDerivedVectorInjector {
class RootPerFieldDerivedVectorInjector extends AbstractPerFieldDerivedVectorInjector {

private final FieldInfo fieldInfo;
private final CheckedSupplier<KNNVectorValues<?>, IOException> vectorValuesSupplier;
Expand All @@ -40,7 +40,7 @@ public RootPerFieldDerivedVectorInjector(FieldInfo fieldInfo, DerivedSourceReade
public void inject(int docId, Map<String, Object> sourceAsMap) throws IOException {
KNNVectorValues<?> vectorValues = vectorValuesSupplier.get();
if (vectorValues.docId() == docId || vectorValues.advance(docId) == docId) {
sourceAsMap.put(fieldInfo.name, vectorValues.conditionalCloneVector());
sourceAsMap.put(fieldInfo.name, formatVector(fieldInfo, vectorValues));
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -86,10 +86,10 @@ public static StoredField createStoredFieldForFloatVector(String name, float[] v
* @return either int[] or float[] of corresponding vector
*/
public static Object deserializeStoredVector(BytesRef storedVector, VectorDataType vectorDataType) {
if (VectorDataType.BYTE == vectorDataType) {
if (VectorDataType.BYTE == vectorDataType || VectorDataType.BINARY == vectorDataType) {
byte[] bytes = storedVector.bytes;
int[] byteAsIntArray = new int[bytes.length];
Arrays.setAll(byteAsIntArray, i -> bytes[i]);
int[] byteAsIntArray = new int[storedVector.length];
Arrays.setAll(byteAsIntArray, i -> bytes[i + storedVector.offset]);
return byteAsIntArray;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,17 @@ public void testStoredFields_whenVectorIsByteType_thenSucceed() {
assertArrayEquals(byteAsIntArray, (int[]) vector);
}

public void testStoredFields_whenVectorIsBinaryType_thenSucceed() {
StoredField storedField = KNNVectorFieldMapperUtil.createStoredFieldForByteVector(TEST_FIELD_NAME, TEST_BYTE_VECTOR);
assertEquals(TEST_FIELD_NAME, storedField.name());
assertEquals(TEST_BYTE_VECTOR, storedField.binaryValue().bytes);
Object vector = KNNVectorFieldMapperUtil.deserializeStoredVector(storedField.binaryValue(), VectorDataType.BINARY);
assertTrue(vector instanceof int[]);
int[] byteAsIntArray = new int[TEST_BYTE_VECTOR.length];
Arrays.setAll(byteAsIntArray, i -> TEST_BYTE_VECTOR[i]);
assertArrayEquals(byteAsIntArray, (int[]) vector);
}

public void testStoredFields_whenVectorIsFloatType_thenSucceed() {
StoredField storedField = KNNVectorFieldMapperUtil.createStoredFieldForFloatVector(TEST_FIELD_NAME, TEST_FLOAT_VECTOR);
assertEquals(TEST_FIELD_NAME, storedField.name());
Expand Down
Loading
Loading