Skip to content

Commit

Permalink
Merge branch 'opensearch-project:main' into Refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
Vikasht34 authored Feb 18, 2025
2 parents cc2df76 + 0df5f62 commit 48c5fa2
Show file tree
Hide file tree
Showing 9 changed files with 410 additions and 73 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
### Features
### Enhancements
### Bug Fixes
* Fix derived source for binary and byte vectors [#2533](https://github.com/opensearch-project/k-NN/pull/2533/)
### Infrastructure
### Documentation
### Maintenance
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.knn.index.codec.derivedsource;

import lombok.extern.log4j.Log4j2;
import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.util.BytesRef;
import org.opensearch.knn.common.FieldInfoExtractor;
import org.opensearch.knn.index.VectorDataType;
import org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil;
import org.opensearch.knn.index.vectorvalues.KNNVectorValues;

import java.io.IOException;

@Log4j2
abstract class AbstractPerFieldDerivedVectorInjector implements PerFieldDerivedVectorInjector {
/**
* Utility method for formatting the vector values based on the vector data type. KNNVectorValues must be advanced
* to the correct position.
*
* @param fieldInfo fieldinfo for the vector field
* @param vectorValues vector values of the field. getVector or getConditionalVector should return expected vector.
* @return vector formatted based on the vector data type
* @throws IOException if unable to deserialize stored vector
*/
protected Object formatVector(FieldInfo fieldInfo, KNNVectorValues<?> vectorValues) throws IOException {
Object vectorValue = vectorValues.getVector();
// If the vector value is a byte[], we must deserialize
if (vectorValue instanceof byte[]) {
BytesRef vectorBytesRef = new BytesRef((byte[]) vectorValue);
VectorDataType vectorDataType = FieldInfoExtractor.extractVectorDataType(fieldInfo);
return KNNVectorFieldMapperUtil.deserializeStoredVector(vectorBytesRef, vectorDataType);
}
return vectorValues.conditionalCloneVector();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@

@Log4j2
@AllArgsConstructor
public class NestedPerFieldDerivedVectorInjector implements PerFieldDerivedVectorInjector {
public class NestedPerFieldDerivedVectorInjector extends AbstractPerFieldDerivedVectorInjector {

private final FieldInfo childFieldInfo;
private final DerivedSourceReaders derivedSourceReaders;
Expand Down Expand Up @@ -116,7 +116,7 @@ public void inject(int parentDocId, Map<String, Object> sourceAsMap) throws IOEx
reconstructedSource.add(position, new HashMap<>());
positions.add(position, docId);
}
reconstructedSource.get(position).put(childFieldName, vectorValues.conditionalCloneVector());
reconstructedSource.get(position).put(childFieldName, formatVector(childFieldInfo, vectorValues));
offsetPositionsIndex = position + 1;
}
sourceAsMap.put(parentFieldName, reconstructedSource);
Expand All @@ -137,7 +137,7 @@ private void injectObject(int docId, Map<String, Object> sourceAsMap) throws IOE
String field = fields[i];
currentMap = (Map<String, Object>) currentMap.computeIfAbsent(field, k -> new HashMap<>());
}
currentMap.put(fields[fields.length - 1], vectorValues.getVector());
currentMap.put(fields[fields.length - 1], formatVector(childFieldInfo, vectorValues));
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
/**
* {@link PerFieldDerivedVectorInjector} for root fields (i.e. non nested fields).
*/
class RootPerFieldDerivedVectorInjector implements PerFieldDerivedVectorInjector {
class RootPerFieldDerivedVectorInjector extends AbstractPerFieldDerivedVectorInjector {

private final FieldInfo fieldInfo;
private final CheckedSupplier<KNNVectorValues<?>, IOException> vectorValuesSupplier;
Expand All @@ -40,7 +40,7 @@ public RootPerFieldDerivedVectorInjector(FieldInfo fieldInfo, DerivedSourceReade
public void inject(int docId, Map<String, Object> sourceAsMap) throws IOException {
KNNVectorValues<?> vectorValues = vectorValuesSupplier.get();
if (vectorValues.docId() == docId || vectorValues.advance(docId) == docId) {
sourceAsMap.put(fieldInfo.name, vectorValues.conditionalCloneVector());
sourceAsMap.put(fieldInfo.name, formatVector(fieldInfo, vectorValues));
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -86,10 +86,10 @@ public static StoredField createStoredFieldForFloatVector(String name, float[] v
* @return either int[] or float[] of corresponding vector
*/
public static Object deserializeStoredVector(BytesRef storedVector, VectorDataType vectorDataType) {
if (VectorDataType.BYTE == vectorDataType) {
if (VectorDataType.BYTE == vectorDataType || VectorDataType.BINARY == vectorDataType) {
byte[] bytes = storedVector.bytes;
int[] byteAsIntArray = new int[bytes.length];
Arrays.setAll(byteAsIntArray, i -> bytes[i]);
int[] byteAsIntArray = new int[storedVector.length];
Arrays.setAll(byteAsIntArray, i -> bytes[i + storedVector.offset]);
return byteAsIntArray;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,17 @@ public void testStoredFields_whenVectorIsByteType_thenSucceed() {
assertArrayEquals(byteAsIntArray, (int[]) vector);
}

public void testStoredFields_whenVectorIsBinaryType_thenSucceed() {
StoredField storedField = KNNVectorFieldMapperUtil.createStoredFieldForByteVector(TEST_FIELD_NAME, TEST_BYTE_VECTOR);
assertEquals(TEST_FIELD_NAME, storedField.name());
assertEquals(TEST_BYTE_VECTOR, storedField.binaryValue().bytes);
Object vector = KNNVectorFieldMapperUtil.deserializeStoredVector(storedField.binaryValue(), VectorDataType.BINARY);
assertTrue(vector instanceof int[]);
int[] byteAsIntArray = new int[TEST_BYTE_VECTOR.length];
Arrays.setAll(byteAsIntArray, i -> TEST_BYTE_VECTOR[i]);
assertArrayEquals(byteAsIntArray, (int[]) vector);
}

public void testStoredFields_whenVectorIsFloatType_thenSucceed() {
StoredField storedField = KNNVectorFieldMapperUtil.createStoredFieldForFloatVector(TEST_FIELD_NAME, TEST_FLOAT_VECTOR);
assertEquals(TEST_FIELD_NAME, storedField.name());
Expand Down
Loading

0 comments on commit 48c5fa2

Please sign in to comment.