From aa3551bc367a7c8c338af947b4587e93f849e0c8 Mon Sep 17 00:00:00 2001 From: John Mazanec Date: Mon, 3 Feb 2025 15:29:25 -0800 Subject: [PATCH] Fix bytes bug and add uTs for derived source Fixes a bug in the derived source writer where we are reading the entire bytes array from the bytes ref instead of just the offset+length. Along with that, touches up the ParentChildHelper (no prod impact) and also adds some unit tests. Signed-off-by: John Mazanec --- .../codec/KNN10010Codec/KNN10010Codec.java | 45 +++++- .../DerivedSourceStoredFieldsWriter.java | 2 +- .../codec/KNN9120Codec/KNN9120Codec.java | 6 + .../knn/index/codec/KNNCodecVersion.java | 1 + .../derivedsource/ParentChildHelper.java | 19 ++- .../VectorValueExtractorStrategy.java | 2 +- .../DerivedSourceStoredFieldsWriterTests.java | 42 ++++++ .../DerivedSourceStoredFieldVisitorTests.java | 43 ++++++ .../DerivedSourceVectorInjectorTests.java | 136 ++++++++++++++++++ .../derivedsource/ParentChildHelperTests.java | 42 ++++++ ...ieldDerivedVectorInjectorFactoryTests.java | 29 ++++ ...ootPerFieldDerivedVectorInjectorTests.java | 85 +++++++++++ 12 files changed, 445 insertions(+), 7 deletions(-) create mode 100644 src/test/java/org/opensearch/knn/index/codec/KNN9120Codec/DerivedSourceStoredFieldsWriterTests.java create mode 100644 src/test/java/org/opensearch/knn/index/codec/derivedsource/DerivedSourceStoredFieldVisitorTests.java create mode 100644 src/test/java/org/opensearch/knn/index/codec/derivedsource/DerivedSourceVectorInjectorTests.java create mode 100644 src/test/java/org/opensearch/knn/index/codec/derivedsource/ParentChildHelperTests.java create mode 100644 src/test/java/org/opensearch/knn/index/codec/derivedsource/PerFieldDerivedVectorInjectorFactoryTests.java create mode 100644 src/test/java/org/opensearch/knn/index/codec/derivedsource/RootPerFieldDerivedVectorInjectorTests.java diff --git a/src/main/java/org/opensearch/knn/index/codec/KNN10010Codec/KNN10010Codec.java b/src/main/java/org/opensearch/knn/index/codec/KNN10010Codec/KNN10010Codec.java index daacedcbc..97848bb35 100644 --- a/src/main/java/org/opensearch/knn/index/codec/KNN10010Codec/KNN10010Codec.java +++ b/src/main/java/org/opensearch/knn/index/codec/KNN10010Codec/KNN10010Codec.java @@ -11,9 +11,13 @@ import org.apache.lucene.codecs.DocValuesFormat; import org.apache.lucene.codecs.FilterCodec; import org.apache.lucene.codecs.KnnVectorsFormat; +import org.apache.lucene.codecs.StoredFieldsFormat; import org.apache.lucene.codecs.perfield.PerFieldKnnVectorsFormat; +import org.opensearch.index.mapper.MapperService; +import org.opensearch.knn.index.codec.KNN9120Codec.DerivedSourceStoredFieldsFormat; import org.opensearch.knn.index.codec.KNNCodecVersion; import org.opensearch.knn.index.codec.KNNFormatFacade; +import org.opensearch.knn.index.codec.derivedsource.DerivedSourceReadersSupplier; /** * KNN Codec that wraps the Lucene Codec which is part of Lucene 10.0.1 @@ -24,12 +28,15 @@ public class KNN10010Codec extends FilterCodec { private static final KNNCodecVersion VERSION = KNNCodecVersion.V_10_01_0; private final KNNFormatFacade knnFormatFacade; private final PerFieldKnnVectorsFormat perFieldKnnVectorsFormat; + private final StoredFieldsFormat storedFieldsFormat; + + private final MapperService mapperService; /** * No arg constructor that uses Lucene99 as the delegate */ public KNN10010Codec() { - this(VERSION.getDefaultCodecDelegate(), VERSION.getPerFieldKnnVectorsFormat()); + this(VERSION.getDefaultCodecDelegate(), VERSION.getPerFieldKnnVectorsFormat(), null); } /** @@ -40,10 +47,12 @@ public KNN10010Codec() { * @param knnVectorsFormat per field format for KnnVector */ @Builder - protected KNN10010Codec(Codec delegate, PerFieldKnnVectorsFormat knnVectorsFormat) { + protected KNN10010Codec(Codec delegate, PerFieldKnnVectorsFormat knnVectorsFormat, MapperService mapperService) { super(VERSION.getCodecName(), delegate); knnFormatFacade = VERSION.getKnnFormatFacadeSupplier().apply(delegate); perFieldKnnVectorsFormat = knnVectorsFormat; + this.mapperService = mapperService; + this.storedFieldsFormat = getStoredFieldsFormat(); } @Override @@ -60,4 +69,36 @@ public CompoundFormat compoundFormat() { public KnnVectorsFormat knnVectorsFormat() { return perFieldKnnVectorsFormat; } + + @Override + public StoredFieldsFormat storedFieldsFormat() { + return storedFieldsFormat; + } + + private StoredFieldsFormat getStoredFieldsFormat() { + DerivedSourceReadersSupplier derivedSourceReadersSupplier = new DerivedSourceReadersSupplier((segmentReadState) -> { + if (segmentReadState.fieldInfos.hasVectorValues()) { + return knnVectorsFormat().fieldsReader(segmentReadState); + } + return null; + }, (segmentReadState) -> { + if (segmentReadState.fieldInfos.hasDocValues()) { + return docValuesFormat().fieldsProducer(segmentReadState); + } + return null; + + }, (segmentReadState) -> { + if (segmentReadState.fieldInfos.hasPostings()) { + return postingsFormat().fieldsProducer(segmentReadState); + } + return null; + + }, (segmentReadState -> { + if (segmentReadState.fieldInfos.hasNorms()) { + return normsFormat().normsProducer(segmentReadState); + } + return null; + })); + return new DerivedSourceStoredFieldsFormat(delegate.storedFieldsFormat(), derivedSourceReadersSupplier, mapperService); + } } diff --git a/src/main/java/org/opensearch/knn/index/codec/KNN9120Codec/DerivedSourceStoredFieldsWriter.java b/src/main/java/org/opensearch/knn/index/codec/KNN9120Codec/DerivedSourceStoredFieldsWriter.java index c585b09f7..0c43f6a49 100644 --- a/src/main/java/org/opensearch/knn/index/codec/KNN9120Codec/DerivedSourceStoredFieldsWriter.java +++ b/src/main/java/org/opensearch/knn/index/codec/KNN9120Codec/DerivedSourceStoredFieldsWriter.java @@ -79,7 +79,7 @@ public void writeField(FieldInfo fieldInfo, BytesRef bytesRef) throws IOExceptio // Reference: // https://github.com/opensearch-project/OpenSearch/blob/2.18.0/server/src/main/java/org/opensearch/index/mapper/SourceFieldMapper.java#L322 Tuple> mapTuple = XContentHelper.convertToMap( - BytesReference.fromByteBuffer(ByteBuffer.wrap(bytesRef.bytes)), + BytesReference.fromByteBuffer(ByteBuffer.wrap(bytesRef.bytes, bytesRef.offset, bytesRef.length)), true, MediaTypeRegistry.JSON ); diff --git a/src/main/java/org/opensearch/knn/index/codec/KNN9120Codec/KNN9120Codec.java b/src/main/java/org/opensearch/knn/index/codec/KNN9120Codec/KNN9120Codec.java index b8a5e6a12..5e40faf1a 100644 --- a/src/main/java/org/opensearch/knn/index/codec/KNN9120Codec/KNN9120Codec.java +++ b/src/main/java/org/opensearch/knn/index/codec/KNN9120Codec/KNN9120Codec.java @@ -25,6 +25,7 @@ public class KNN9120Codec extends FilterCodec { private static final KNNCodecVersion VERSION = KNNCodecVersion.V_9_12_0; private final KNNFormatFacade knnFormatFacade; private final PerFieldKnnVectorsFormat perFieldKnnVectorsFormat; + private final StoredFieldsFormat storedFieldsFormat; private final MapperService mapperService; @@ -48,6 +49,7 @@ protected KNN9120Codec(Codec delegate, PerFieldKnnVectorsFormat knnVectorsFormat knnFormatFacade = VERSION.getKnnFormatFacadeSupplier().apply(delegate); perFieldKnnVectorsFormat = knnVectorsFormat; this.mapperService = mapperService; + this.storedFieldsFormat = getStoredFieldsFormat(); } @Override @@ -67,6 +69,10 @@ public KnnVectorsFormat knnVectorsFormat() { @Override public StoredFieldsFormat storedFieldsFormat() { + return storedFieldsFormat; + } + + private StoredFieldsFormat getStoredFieldsFormat() { DerivedSourceReadersSupplier derivedSourceReadersSupplier = new DerivedSourceReadersSupplier((segmentReadState) -> { if (segmentReadState.fieldInfos.hasVectorValues()) { return knnVectorsFormat().fieldsReader(segmentReadState); diff --git a/src/main/java/org/opensearch/knn/index/codec/KNNCodecVersion.java b/src/main/java/org/opensearch/knn/index/codec/KNNCodecVersion.java index f80dd23cb..0f03170c2 100644 --- a/src/main/java/org/opensearch/knn/index/codec/KNNCodecVersion.java +++ b/src/main/java/org/opensearch/knn/index/codec/KNNCodecVersion.java @@ -143,6 +143,7 @@ public enum KNNCodecVersion { (userCodec, mapperService) -> KNN10010Codec.builder() .delegate(userCodec) .knnVectorsFormat(new KNN9120PerFieldKnnVectorsFormat(Optional.ofNullable(mapperService))) + .mapperService(mapperService) .build(), KNN10010Codec::new ); diff --git a/src/main/java/org/opensearch/knn/index/codec/derivedsource/ParentChildHelper.java b/src/main/java/org/opensearch/knn/index/codec/derivedsource/ParentChildHelper.java index 534cf93d7..ae249b1b3 100644 --- a/src/main/java/org/opensearch/knn/index/codec/derivedsource/ParentChildHelper.java +++ b/src/main/java/org/opensearch/knn/index/codec/derivedsource/ParentChildHelper.java @@ -15,9 +15,12 @@ public class ParentChildHelper { * this would return "parent.to". * * @param field nested field path - * @return parent field path without the child + * @return parent field path without the child. Null if no parent exists */ public static String getParentField(String field) { + if (field == null) { + return null; + } int lastDot = field.lastIndexOf('.'); if (lastDot == -1) { return null; @@ -30,10 +33,16 @@ public static String getParentField(String field) { * return "child". * * @param field nested field path - * @return child field path without the parent path + * @return child field path without the parent path. Null if no child exists */ public static String getChildField(String field) { + if (field == null) { + return null; + } int lastDot = field.lastIndexOf('.'); + if (lastDot == -1) { + return null; + } return field.substring(lastDot + 1); } @@ -46,7 +55,11 @@ public static String getChildField(String field) { * @return sibling field path */ public static String constructSiblingField(String field, String sibling) { - return getParentField(field) + "." + sibling; + String parent = getParentField(field); + if (parent == null) { + return sibling; + } + return parent + "." + sibling; } /** diff --git a/src/main/java/org/opensearch/knn/index/vectorvalues/VectorValueExtractorStrategy.java b/src/main/java/org/opensearch/knn/index/vectorvalues/VectorValueExtractorStrategy.java index 04ed007e1..7aafae308 100644 --- a/src/main/java/org/opensearch/knn/index/vectorvalues/VectorValueExtractorStrategy.java +++ b/src/main/java/org/opensearch/knn/index/vectorvalues/VectorValueExtractorStrategy.java @@ -21,7 +21,7 @@ /** * Provides different strategies to extract the vectors from different {@link KNNVectorValuesIterator} */ -interface VectorValueExtractorStrategy { +public interface VectorValueExtractorStrategy { /** * Extract a float vector from KNNVectorValuesIterator. diff --git a/src/test/java/org/opensearch/knn/index/codec/KNN9120Codec/DerivedSourceStoredFieldsWriterTests.java b/src/test/java/org/opensearch/knn/index/codec/KNN9120Codec/DerivedSourceStoredFieldsWriterTests.java new file mode 100644 index 000000000..2953539ad --- /dev/null +++ b/src/test/java/org/opensearch/knn/index/codec/KNN9120Codec/DerivedSourceStoredFieldsWriterTests.java @@ -0,0 +1,42 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.codec.KNN9120Codec; + +import lombok.SneakyThrows; +import org.apache.lucene.codecs.StoredFieldsWriter; +import org.apache.lucene.index.FieldInfo; +import org.apache.lucene.util.BytesRef; +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.core.xcontent.MediaTypeRegistry; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.knn.KNNTestCase; +import org.opensearch.knn.index.codec.KNNCodecTestUtil; + +import java.util.List; +import java.util.Map; + +import static org.mockito.Mockito.mock; + +public class DerivedSourceStoredFieldsWriterTests extends KNNTestCase { + + @SneakyThrows + public void testWriteField() { + StoredFieldsWriter delegate = mock(StoredFieldsWriter.class); + FieldInfo fieldInfo = KNNCodecTestUtil.FieldInfoBuilder.builder("_source").build(); + List fields = List.of("test"); + + DerivedSourceStoredFieldsWriter derivedSourceStoredFieldsWriter = new DerivedSourceStoredFieldsWriter(delegate, fields); + + Map source = Map.of("test", new float[] { 1.0f, 2.0f, 3.0f }, "text_field", "text_value"); + BytesStreamOutput bStream = new BytesStreamOutput(); + XContentBuilder builder = MediaTypeRegistry.contentBuilder(MediaTypeRegistry.JSON, bStream).map(source); + builder.close(); + byte[] originalBytes = bStream.bytes().toBytesRef().bytes; + byte[] shiftedBytes = new byte[originalBytes.length + 2]; + System.arraycopy(originalBytes, 0, shiftedBytes, 1, originalBytes.length); + derivedSourceStoredFieldsWriter.writeField(fieldInfo, new BytesRef(shiftedBytes, 1, originalBytes.length)); + } +} diff --git a/src/test/java/org/opensearch/knn/index/codec/derivedsource/DerivedSourceStoredFieldVisitorTests.java b/src/test/java/org/opensearch/knn/index/codec/derivedsource/DerivedSourceStoredFieldVisitorTests.java new file mode 100644 index 000000000..943146a91 --- /dev/null +++ b/src/test/java/org/opensearch/knn/index/codec/derivedsource/DerivedSourceStoredFieldVisitorTests.java @@ -0,0 +1,43 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.codec.derivedsource; + +import org.apache.lucene.index.StoredFieldVisitor; +import org.opensearch.knn.KNNTestCase; +import org.opensearch.knn.index.codec.KNNCodecTestUtil; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyInt; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +public class DerivedSourceStoredFieldVisitorTests extends KNNTestCase { + + public void testBinaryField() throws Exception { + StoredFieldVisitor delegate = mock(StoredFieldVisitor.class); + doAnswer(invocationOnMock -> null).when(delegate).binaryField(any(), any()); + DerivedSourceVectorInjector derivedSourceVectorInjector = mock(DerivedSourceVectorInjector.class); + when(derivedSourceVectorInjector.injectVectors(anyInt(), any())).thenReturn(new byte[0]); + DerivedSourceStoredFieldVisitor derivedSourceStoredFieldVisitor = new DerivedSourceStoredFieldVisitor( + delegate, + 0, + derivedSourceVectorInjector + ); + + // When field is not _source, then do not call the injector + derivedSourceStoredFieldVisitor.binaryField(KNNCodecTestUtil.FieldInfoBuilder.builder("test").build(), null); + verify(derivedSourceVectorInjector, times(0)).injectVectors(anyInt(), any()); + verify(delegate, times(1)).binaryField(any(), any()); + + // When field is not _source, then do call the injector + derivedSourceStoredFieldVisitor.binaryField(KNNCodecTestUtil.FieldInfoBuilder.builder("_source").build(), null); + verify(derivedSourceVectorInjector, times(1)).injectVectors(anyInt(), any()); + verify(delegate, times(2)).binaryField(any(), any()); + } +} diff --git a/src/test/java/org/opensearch/knn/index/codec/derivedsource/DerivedSourceVectorInjectorTests.java b/src/test/java/org/opensearch/knn/index/codec/derivedsource/DerivedSourceVectorInjectorTests.java new file mode 100644 index 000000000..1fa4b9364 --- /dev/null +++ b/src/test/java/org/opensearch/knn/index/codec/derivedsource/DerivedSourceVectorInjectorTests.java @@ -0,0 +1,136 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.codec.derivedsource; + +import lombok.SneakyThrows; +import org.apache.lucene.index.FieldInfo; +import org.mockito.MockedStatic; +import org.mockito.Mockito; +import org.opensearch.common.collect.Tuple; +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.common.xcontent.XContentHelper; +import org.opensearch.core.common.bytes.BytesReference; +import org.opensearch.core.xcontent.MediaType; +import org.opensearch.core.xcontent.MediaTypeRegistry; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.knn.KNNTestCase; +import org.opensearch.knn.index.codec.KNNCodecTestUtil; + +import java.nio.ByteBuffer; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import static org.mockito.ArgumentMatchers.any; + +public class DerivedSourceVectorInjectorTests extends KNNTestCase { + + @SneakyThrows + @SuppressWarnings("unchecked") + public void testInjectVectors() { + List fields = List.of( + KNNCodecTestUtil.FieldInfoBuilder.builder("test1").build(), + KNNCodecTestUtil.FieldInfoBuilder.builder("test2").build(), + KNNCodecTestUtil.FieldInfoBuilder.builder("test3").build() + ); + + Map fieldToVector = Collections.unmodifiableMap(new HashMap<>() { + { + put("test1", new float[] { 1.0f, 2.0f, 3.0f }); + put("test2", new float[] { 4.0f, 5.0f, 6.0f, 7.0f }); + put("test3", new float[] { 7.0f, 8.0f, 9.0f, 1.0f, 3.0f, 4.0f }); + put("test4", null); + } + }); + + try (MockedStatic factory = Mockito.mockStatic(PerFieldDerivedVectorInjectorFactory.class)) { + factory.when(() -> PerFieldDerivedVectorInjectorFactory.create(any(), any(), any())).thenAnswer(invocation -> { + FieldInfo fieldInfo = invocation.getArgument(0); + return (PerFieldDerivedVectorInjector) (docId, sourceAsMap) -> { + float[] vector = fieldToVector.get(fieldInfo.name); + if (vector != null) { + sourceAsMap.put(fieldInfo.name, vector); + } + }; + }); + + DerivedSourceVectorInjector derivedSourceVectorInjector = new DerivedSourceVectorInjector( + new DerivedSourceReadersSupplier(s -> null, s -> null, s -> null, s -> null), + null, + fields + ); + + int docId = 2; + String existingFieldKey = "existingField"; + String existingFieldValue = "existingField"; + Map source = Map.of(existingFieldKey, existingFieldValue); + byte[] originalSourceBytes = mapToBytes(source); + byte[] modifiedSourceByttes = derivedSourceVectorInjector.injectVectors(docId, originalSourceBytes); + Map modifiedSource = bytesToMap(modifiedSourceByttes); + + assertEquals(existingFieldValue, modifiedSource.get(existingFieldKey)); + + assertArrayEquals(fieldToVector.get("test1"), toFloatArray((List) modifiedSource.get("test1")), 0.000001f); + assertArrayEquals(fieldToVector.get("test2"), toFloatArray((List) modifiedSource.get("test2")), 0.000001f); + assertArrayEquals(fieldToVector.get("test3"), toFloatArray((List) modifiedSource.get("test3")), 0.000001f); + assertFalse(modifiedSource.containsKey("test4")); + } + } + + @SneakyThrows + private byte[] mapToBytes(Map map) { + + BytesStreamOutput bStream = new BytesStreamOutput(1024); + XContentBuilder builder = MediaTypeRegistry.contentBuilder(MediaTypeRegistry.JSON, bStream).map(map); + builder.close(); + return BytesReference.toBytes(BytesReference.bytes(builder)); + } + + private float[] toFloatArray(List list) { + float[] array = new float[list.size()]; + for (int i = 0; i < list.size(); i++) { + array[i] = list.get(i).floatValue(); + } + return array; + } + + private Map bytesToMap(byte[] bytes) { + Tuple> mapTuple = XContentHelper.convertToMap( + BytesReference.fromByteBuffer(ByteBuffer.wrap(bytes)), + true, + MediaTypeRegistry.getDefaultMediaType() + ); + + return mapTuple.v2(); + } + + @SneakyThrows + public void testShouldInject() { + + List fields = List.of( + KNNCodecTestUtil.FieldInfoBuilder.builder("test1").build(), + KNNCodecTestUtil.FieldInfoBuilder.builder("test2").build(), + KNNCodecTestUtil.FieldInfoBuilder.builder("test3").build() + ); + + try ( + DerivedSourceVectorInjector vectorInjector = new DerivedSourceVectorInjector( + new DerivedSourceReadersSupplier(s -> null, s -> null, s -> null, s -> null), + null, + fields + ) + ) { + assertTrue(vectorInjector.shouldInject(null, null)); + assertTrue(vectorInjector.shouldInject(new String[] { "test1" }, null)); + assertTrue(vectorInjector.shouldInject(new String[] { "test1", "test2", "test3" }, null)); + assertTrue(vectorInjector.shouldInject(null, new String[] { "test2" })); + assertTrue(vectorInjector.shouldInject(new String[] { "test1" }, new String[] { "test2" })); + assertTrue(vectorInjector.shouldInject(new String[] { "test1" }, new String[] { "test2", "test3" })); + assertFalse(vectorInjector.shouldInject(null, new String[] { "test1", "test2", "test3" })); + } + } +} diff --git a/src/test/java/org/opensearch/knn/index/codec/derivedsource/ParentChildHelperTests.java b/src/test/java/org/opensearch/knn/index/codec/derivedsource/ParentChildHelperTests.java new file mode 100644 index 000000000..085222460 --- /dev/null +++ b/src/test/java/org/opensearch/knn/index/codec/derivedsource/ParentChildHelperTests.java @@ -0,0 +1,42 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.codec.derivedsource; + +import org.opensearch.knn.KNNTestCase; + +public class ParentChildHelperTests extends KNNTestCase { + + public void testGetParentField() { + assertEquals("parent.to", ParentChildHelper.getParentField("parent.to.child")); + assertEquals("parent", ParentChildHelper.getParentField("parent.to")); + assertNull(ParentChildHelper.getParentField("child")); + assertNull(ParentChildHelper.getParentField("")); + assertNull(ParentChildHelper.getParentField(null)); + } + + public void testGetChildField() { + assertEquals("child", ParentChildHelper.getChildField("parent.to.child")); + assertNull(ParentChildHelper.getChildField(null)); + assertNull(ParentChildHelper.getChildField("child")); + } + + public void testConstructSiblingField() { + assertEquals("parent.to.sibling", ParentChildHelper.constructSiblingField("parent.to.child", "sibling")); + assertEquals("sibling", ParentChildHelper.constructSiblingField("parent", "sibling")); + } + + public void testSplitPath() { + String[] path = ParentChildHelper.splitPath("parent.to.child"); + assertEquals(3, path.length); + assertEquals("parent", path[0]); + assertEquals("to", path[1]); + assertEquals("child", path[2]); + + path = ParentChildHelper.splitPath("parent"); + assertEquals(1, path.length); + assertEquals("parent", path[0]); + } +} diff --git a/src/test/java/org/opensearch/knn/index/codec/derivedsource/PerFieldDerivedVectorInjectorFactoryTests.java b/src/test/java/org/opensearch/knn/index/codec/derivedsource/PerFieldDerivedVectorInjectorFactoryTests.java new file mode 100644 index 000000000..f117db8c8 --- /dev/null +++ b/src/test/java/org/opensearch/knn/index/codec/derivedsource/PerFieldDerivedVectorInjectorFactoryTests.java @@ -0,0 +1,29 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.codec.derivedsource; + +import org.opensearch.knn.KNNTestCase; +import org.opensearch.knn.index.codec.KNNCodecTestUtil; + +public class PerFieldDerivedVectorInjectorFactoryTests extends KNNTestCase { + public void testCreate() { + // Non-nested case + PerFieldDerivedVectorInjector perFieldDerivedVectorInjector = PerFieldDerivedVectorInjectorFactory.create( + KNNCodecTestUtil.FieldInfoBuilder.builder("test").build(), + new DerivedSourceReaders(null, null, null, null), + null + ); + assertTrue(perFieldDerivedVectorInjector instanceof RootPerFieldDerivedVectorInjector); + + // Nested case + perFieldDerivedVectorInjector = PerFieldDerivedVectorInjectorFactory.create( + KNNCodecTestUtil.FieldInfoBuilder.builder("parent.test").build(), + new DerivedSourceReaders(null, null, null, null), + null + ); + assertTrue(perFieldDerivedVectorInjector instanceof NestedPerFieldDerivedVectorInjector); + } +} diff --git a/src/test/java/org/opensearch/knn/index/codec/derivedsource/RootPerFieldDerivedVectorInjectorTests.java b/src/test/java/org/opensearch/knn/index/codec/derivedsource/RootPerFieldDerivedVectorInjectorTests.java new file mode 100644 index 000000000..3e015b09c --- /dev/null +++ b/src/test/java/org/opensearch/knn/index/codec/derivedsource/RootPerFieldDerivedVectorInjectorTests.java @@ -0,0 +1,85 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.codec.derivedsource; + +import lombok.SneakyThrows; +import org.apache.lucene.index.FieldInfo; +import org.apache.lucene.search.DocIdSetIterator; +import org.mockito.MockedStatic; +import org.mockito.Mockito; +import org.opensearch.knn.KNNTestCase; +import org.opensearch.knn.index.codec.KNNCodecTestUtil; +import org.opensearch.knn.index.vectorvalues.KNNVectorValues; +import org.opensearch.knn.index.vectorvalues.KNNVectorValuesFactory; +import org.opensearch.knn.index.vectorvalues.KNNVectorValuesIterator; +import org.opensearch.knn.index.vectorvalues.VectorValueExtractorStrategy; + +import java.util.HashMap; +import java.util.Map; + +import static org.opensearch.knn.KNNRestTestCase.FIELD_NAME; + +public class RootPerFieldDerivedVectorInjectorTests extends KNNTestCase { + public static float[] TEST_VECTOR = new float[] { 1.0f, 2.0f, 3.0f, 4.0f }; + + @SneakyThrows + public void testInject() { + FieldInfo fieldInfo = KNNCodecTestUtil.FieldInfoBuilder.builder(FIELD_NAME).build(); + try (MockedStatic mockedKnnVectorValues = Mockito.mockStatic(KNNVectorValuesFactory.class)) { + mockedKnnVectorValues.when(() -> KNNVectorValuesFactory.getVectorValues(fieldInfo, null, null)) + .thenReturn(new KNNVectorValues(new KNNVectorValuesIterator() { + @Override + public int docId() { + return 0; + } + + @Override + public int advance(int docId) { + return 0; + } + + @Override + public int nextDoc() { + return 0; + } + + @Override + public DocIdSetIterator getDocIdSetIterator() { + return null; + } + + @Override + public long liveDocs() { + return 0; + } + + @Override + public VectorValueExtractorStrategy getVectorExtractorStrategy() { + return null; + } + }) { + + @Override + public float[] getVector() { + return TEST_VECTOR; + } + + @Override + public float[] conditionalCloneVector() { + return TEST_VECTOR; + } + }); + PerFieldDerivedVectorInjector perFieldDerivedVectorInjector = new RootPerFieldDerivedVectorInjector( + fieldInfo, + new DerivedSourceReaders(null, null, null, null) + ); + + Map source = new HashMap<>(); + perFieldDerivedVectorInjector.inject(0, source); + assertArrayEquals(TEST_VECTOR, (float[]) source.get(FIELD_NAME), 0.0001f); + } + } +}