diff --git a/src/main/java/org/opensearch/knn/index/codec/KNN9120Codec/DerivedSourceStoredFieldsFormat.java b/src/main/java/org/opensearch/knn/index/codec/KNN9120Codec/DerivedSourceStoredFieldsFormat.java index e60b82b2e..f862ea3cb 100644 --- a/src/main/java/org/opensearch/knn/index/codec/KNN9120Codec/DerivedSourceStoredFieldsFormat.java +++ b/src/main/java/org/opensearch/knn/index/codec/KNN9120Codec/DerivedSourceStoredFieldsFormat.java @@ -19,6 +19,7 @@ import org.opensearch.index.mapper.MappedFieldType; import org.opensearch.index.mapper.MapperService; import org.opensearch.knn.index.KNNSettings; +import org.opensearch.knn.index.codec.derivedsource.DerivedSourceReaders; import org.opensearch.knn.index.codec.derivedsource.DerivedSourceReadersSupplier; import org.opensearch.knn.index.mapper.KNNVectorFieldType; @@ -55,11 +56,14 @@ public StoredFieldsReader fieldsReader(Directory directory, SegmentInfo segmentI if (derivedVectorFields == null || derivedVectorFields.isEmpty()) { return delegate.fieldsReader(directory, segmentInfo, fieldInfos, ioContext); } + + SegmentReadState segmentReadState = new SegmentReadState(directory, segmentInfo, fieldInfos, ioContext); + DerivedSourceReaders derivedSourceReaders = derivedSourceReadersSupplier.getReaders(segmentReadState); return new DerivedSourceStoredFieldsReader( delegate.fieldsReader(directory, segmentInfo, fieldInfos, ioContext), derivedVectorFields, - derivedSourceReadersSupplier, - new SegmentReadState(directory, segmentInfo, fieldInfos, ioContext) + derivedSourceReaders, + segmentReadState ); } diff --git a/src/main/java/org/opensearch/knn/index/codec/KNN9120Codec/DerivedSourceStoredFieldsReader.java b/src/main/java/org/opensearch/knn/index/codec/KNN9120Codec/DerivedSourceStoredFieldsReader.java index 24900eb19..65a55c998 100644 --- a/src/main/java/org/opensearch/knn/index/codec/KNN9120Codec/DerivedSourceStoredFieldsReader.java +++ 
b/src/main/java/org/opensearch/knn/index/codec/KNN9120Codec/DerivedSourceStoredFieldsReader.java @@ -5,23 +5,25 @@ package org.opensearch.knn.index.codec.KNN9120Codec; +import lombok.extern.log4j.Log4j2; import org.apache.lucene.codecs.StoredFieldsReader; import org.apache.lucene.index.FieldInfo; import org.apache.lucene.index.SegmentReadState; import org.apache.lucene.index.StoredFieldVisitor; import org.apache.lucene.util.IOUtils; import org.opensearch.index.fieldvisitor.FieldsVisitor; -import org.opensearch.knn.index.codec.derivedsource.DerivedSourceReadersSupplier; +import org.opensearch.knn.index.codec.derivedsource.DerivedSourceReaders; import org.opensearch.knn.index.codec.derivedsource.DerivedSourceStoredFieldVisitor; import org.opensearch.knn.index.codec.derivedsource.DerivedSourceVectorInjector; import java.io.IOException; import java.util.List; +@Log4j2 public class DerivedSourceStoredFieldsReader extends StoredFieldsReader { private final StoredFieldsReader delegate; private final List derivedVectorFields; - private final DerivedSourceReadersSupplier derivedSourceReadersSupplier; + private final DerivedSourceReaders derivedSourceReaders; private final SegmentReadState segmentReadState; private final boolean shouldInject; @@ -31,36 +33,36 @@ public class DerivedSourceStoredFieldsReader extends StoredFieldsReader { * * @param delegate delegate StoredFieldsReader * @param derivedVectorFields List of fields that are derived source fields - * @param derivedSourceReadersSupplier Supplier for the derived source readers + * @param derivedSourceReaders Derived source readers * @param segmentReadState SegmentReadState for the segment * @throws IOException in case of I/O error */ public DerivedSourceStoredFieldsReader( StoredFieldsReader delegate, List derivedVectorFields, - DerivedSourceReadersSupplier derivedSourceReadersSupplier, + DerivedSourceReaders derivedSourceReaders, SegmentReadState segmentReadState ) throws IOException { - this(delegate, 
derivedVectorFields, derivedSourceReadersSupplier, segmentReadState, true); + this(delegate, derivedVectorFields, derivedSourceReaders, segmentReadState, true); } private DerivedSourceStoredFieldsReader( StoredFieldsReader delegate, List derivedVectorFields, - DerivedSourceReadersSupplier derivedSourceReadersSupplier, + DerivedSourceReaders derivedSourceReaders, SegmentReadState segmentReadState, boolean shouldInject ) throws IOException { this.delegate = delegate; this.derivedVectorFields = derivedVectorFields; - this.derivedSourceReadersSupplier = derivedSourceReadersSupplier; + this.derivedSourceReaders = derivedSourceReaders; this.segmentReadState = segmentReadState; this.shouldInject = shouldInject; this.derivedSourceVectorInjector = createDerivedSourceVectorInjector(); } - private DerivedSourceVectorInjector createDerivedSourceVectorInjector() throws IOException { - return new DerivedSourceVectorInjector(derivedSourceReadersSupplier, segmentReadState, derivedVectorFields); + private DerivedSourceVectorInjector createDerivedSourceVectorInjector() { + return new DerivedSourceVectorInjector(derivedSourceReaders, segmentReadState, derivedVectorFields); } @Override @@ -86,7 +88,7 @@ public StoredFieldsReader clone() { return new DerivedSourceStoredFieldsReader( delegate.clone(), derivedVectorFields, - derivedSourceReadersSupplier, + derivedSourceReaders.clone(), segmentReadState, shouldInject ); @@ -102,6 +104,7 @@ public void checkIntegrity() throws IOException { @Override public void close() throws IOException { + log.debug("Closing derived source stored fields reader for segment: " + segmentReadState.segmentInfo.name); IOUtils.close(delegate, derivedSourceVectorInjector); } @@ -117,7 +120,7 @@ private StoredFieldsReader cloneForMerge() { return new DerivedSourceStoredFieldsReader( delegate.getMergeInstance(), derivedVectorFields, - derivedSourceReadersSupplier, + derivedSourceReaders.clone(), segmentReadState, false ); diff --git 
a/src/main/java/org/opensearch/knn/index/codec/KNN9120Codec/DerivedSourceStoredFieldsWriter.java b/src/main/java/org/opensearch/knn/index/codec/KNN9120Codec/DerivedSourceStoredFieldsWriter.java index c585b09f7..0c43f6a49 100644 --- a/src/main/java/org/opensearch/knn/index/codec/KNN9120Codec/DerivedSourceStoredFieldsWriter.java +++ b/src/main/java/org/opensearch/knn/index/codec/KNN9120Codec/DerivedSourceStoredFieldsWriter.java @@ -79,7 +79,7 @@ public void writeField(FieldInfo fieldInfo, BytesRef bytesRef) throws IOExceptio // Reference: // https://github.com/opensearch-project/OpenSearch/blob/2.18.0/server/src/main/java/org/opensearch/index/mapper/SourceFieldMapper.java#L322 Tuple> mapTuple = XContentHelper.convertToMap( - BytesReference.fromByteBuffer(ByteBuffer.wrap(bytesRef.bytes)), + BytesReference.fromByteBuffer(ByteBuffer.wrap(bytesRef.bytes, bytesRef.offset, bytesRef.length)), true, MediaTypeRegistry.JSON ); diff --git a/src/main/java/org/opensearch/knn/index/codec/KNN9120Codec/KNN9120Codec.java b/src/main/java/org/opensearch/knn/index/codec/KNN9120Codec/KNN9120Codec.java index b8a5e6a12..5e40faf1a 100644 --- a/src/main/java/org/opensearch/knn/index/codec/KNN9120Codec/KNN9120Codec.java +++ b/src/main/java/org/opensearch/knn/index/codec/KNN9120Codec/KNN9120Codec.java @@ -25,6 +25,7 @@ public class KNN9120Codec extends FilterCodec { private static final KNNCodecVersion VERSION = KNNCodecVersion.V_9_12_0; private final KNNFormatFacade knnFormatFacade; private final PerFieldKnnVectorsFormat perFieldKnnVectorsFormat; + private final StoredFieldsFormat storedFieldsFormat; private final MapperService mapperService; @@ -48,6 +49,7 @@ protected KNN9120Codec(Codec delegate, PerFieldKnnVectorsFormat knnVectorsFormat knnFormatFacade = VERSION.getKnnFormatFacadeSupplier().apply(delegate); perFieldKnnVectorsFormat = knnVectorsFormat; this.mapperService = mapperService; + this.storedFieldsFormat = getStoredFieldsFormat(); } @Override @@ -67,6 +69,10 @@ public 
KnnVectorsFormat knnVectorsFormat() { @Override public StoredFieldsFormat storedFieldsFormat() { + return storedFieldsFormat; + } + + private StoredFieldsFormat getStoredFieldsFormat() { DerivedSourceReadersSupplier derivedSourceReadersSupplier = new DerivedSourceReadersSupplier((segmentReadState) -> { if (segmentReadState.fieldInfos.hasVectorValues()) { return knnVectorsFormat().fieldsReader(segmentReadState); diff --git a/src/main/java/org/opensearch/knn/index/codec/derivedsource/DerivedSourceReaders.java b/src/main/java/org/opensearch/knn/index/codec/derivedsource/DerivedSourceReaders.java index c7e472e60..80b63f02b 100644 --- a/src/main/java/org/opensearch/knn/index/codec/derivedsource/DerivedSourceReaders.java +++ b/src/main/java/org/opensearch/knn/index/codec/derivedsource/DerivedSourceReaders.java @@ -7,6 +7,7 @@ import lombok.Getter; import lombok.RequiredArgsConstructor; +import lombok.extern.log4j.Log4j2; import org.apache.lucene.codecs.DocValuesProducer; import org.apache.lucene.codecs.FieldsProducer; import org.apache.lucene.codecs.KnnVectorsReader; @@ -23,7 +24,8 @@ */ @RequiredArgsConstructor @Getter -public class DerivedSourceReaders implements Closeable { +@Log4j2 +public class DerivedSourceReaders implements Cloneable, Closeable { @Nullable private final KnnVectorsReader knnVectorsReader; @Nullable @@ -33,8 +35,17 @@ public class DerivedSourceReaders implements Closeable { @Nullable private final NormsProducer normsProducer; + private final boolean isCloned; + @Override public void close() throws IOException { - IOUtils.close(knnVectorsReader, docValuesProducer, fieldsProducer, normsProducer); + if (isCloned == false) { + IOUtils.close(knnVectorsReader, docValuesProducer, fieldsProducer, normsProducer); + } + } + + @Override + public DerivedSourceReaders clone() { + return new DerivedSourceReaders(knnVectorsReader, docValuesProducer, fieldsProducer, normsProducer, true); } } diff --git 
a/src/main/java/org/opensearch/knn/index/codec/derivedsource/DerivedSourceReadersSupplier.java b/src/main/java/org/opensearch/knn/index/codec/derivedsource/DerivedSourceReadersSupplier.java index 2dafa3af9..527f25fa3 100644 --- a/src/main/java/org/opensearch/knn/index/codec/derivedsource/DerivedSourceReadersSupplier.java +++ b/src/main/java/org/opensearch/knn/index/codec/derivedsource/DerivedSourceReadersSupplier.java @@ -27,7 +27,7 @@ public class DerivedSourceReadersSupplier { private final DerivedSourceReaderSupplier normsProducer; /** - * Get the readers for the segment + * Get the readers for the segment. * * @param state SegmentReadState * @return DerivedSourceReaders @@ -38,7 +38,8 @@ public DerivedSourceReaders getReaders(SegmentReadState state) throws IOExceptio knnVectorsReaderSupplier.apply(state), docValuesProducerSupplier.apply(state), fieldsProducerSupplier.apply(state), - normsProducer.apply(state) + normsProducer.apply(state), + false ); } } diff --git a/src/main/java/org/opensearch/knn/index/codec/derivedsource/DerivedSourceVectorInjector.java b/src/main/java/org/opensearch/knn/index/codec/derivedsource/DerivedSourceVectorInjector.java index d3b1fe846..e7d78e251 100644 --- a/src/main/java/org/opensearch/knn/index/codec/derivedsource/DerivedSourceVectorInjector.java +++ b/src/main/java/org/opensearch/knn/index/codec/derivedsource/DerivedSourceVectorInjector.java @@ -42,16 +42,16 @@ public class DerivedSourceVectorInjector implements Closeable { /** * Constructor for DerivedSourceVectorInjector. * - * @param derivedSourceReadersSupplier Supplier for the derived source readers. + * @param derivedSourceReaders Derived source readers. 
* @param segmentReadState Segment read state * @param fieldsToInjectVector List of fields to inject vectors into */ public DerivedSourceVectorInjector( - DerivedSourceReadersSupplier derivedSourceReadersSupplier, + DerivedSourceReaders derivedSourceReaders, SegmentReadState segmentReadState, List fieldsToInjectVector - ) throws IOException { - this.derivedSourceReaders = derivedSourceReadersSupplier.getReaders(segmentReadState); + ) { + this.derivedSourceReaders = derivedSourceReaders; this.perFieldDerivedVectorInjectors = new ArrayList<>(); this.fieldNames = new HashSet<>(); for (FieldInfo fieldInfo : fieldsToInjectVector) { diff --git a/src/main/java/org/opensearch/knn/index/codec/derivedsource/ParentChildHelper.java b/src/main/java/org/opensearch/knn/index/codec/derivedsource/ParentChildHelper.java index 534cf93d7..ae249b1b3 100644 --- a/src/main/java/org/opensearch/knn/index/codec/derivedsource/ParentChildHelper.java +++ b/src/main/java/org/opensearch/knn/index/codec/derivedsource/ParentChildHelper.java @@ -15,9 +15,12 @@ public class ParentChildHelper { * this would return "parent.to". * * @param field nested field path - * @return parent field path without the child + * @return parent field path without the child. Null if no parent exists */ public static String getParentField(String field) { + if (field == null) { + return null; + } int lastDot = field.lastIndexOf('.'); if (lastDot == -1) { return null; @@ -30,10 +33,16 @@ public static String getParentField(String field) { * return "child". * * @param field nested field path - * @return child field path without the parent path + * @return child field path without the parent path. 
Null if no child exists */ public static String getChildField(String field) { + if (field == null) { + return null; + } int lastDot = field.lastIndexOf('.'); + if (lastDot == -1) { + return null; + } return field.substring(lastDot + 1); } @@ -46,7 +55,11 @@ public static String getChildField(String field) { * @return sibling field path */ public static String constructSiblingField(String field, String sibling) { - return getParentField(field) + "." + sibling; + String parent = getParentField(field); + if (parent == null) { + return sibling; + } + return parent + "." + sibling; } /** diff --git a/src/main/java/org/opensearch/knn/index/vectorvalues/VectorValueExtractorStrategy.java b/src/main/java/org/opensearch/knn/index/vectorvalues/VectorValueExtractorStrategy.java index 07db4e7f6..9952986c8 100644 --- a/src/main/java/org/opensearch/knn/index/vectorvalues/VectorValueExtractorStrategy.java +++ b/src/main/java/org/opensearch/knn/index/vectorvalues/VectorValueExtractorStrategy.java @@ -20,7 +20,7 @@ /** * Provides different strategies to extract the vectors from different {@link KNNVectorValuesIterator} */ -interface VectorValueExtractorStrategy { +public interface VectorValueExtractorStrategy { /** * Extract a float vector from KNNVectorValuesIterator. 
diff --git a/src/test/java/org/opensearch/knn/index/codec/KNN9120Codec/DerivedSourceStoredFieldsWriterTests.java b/src/test/java/org/opensearch/knn/index/codec/KNN9120Codec/DerivedSourceStoredFieldsWriterTests.java new file mode 100644 index 000000000..2953539ad --- /dev/null +++ b/src/test/java/org/opensearch/knn/index/codec/KNN9120Codec/DerivedSourceStoredFieldsWriterTests.java @@ -0,0 +1,42 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.codec.KNN9120Codec; + +import lombok.SneakyThrows; +import org.apache.lucene.codecs.StoredFieldsWriter; +import org.apache.lucene.index.FieldInfo; +import org.apache.lucene.util.BytesRef; +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.core.xcontent.MediaTypeRegistry; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.knn.KNNTestCase; +import org.opensearch.knn.index.codec.KNNCodecTestUtil; + +import java.util.List; +import java.util.Map; + +import static org.mockito.Mockito.mock; + +public class DerivedSourceStoredFieldsWriterTests extends KNNTestCase { + + @SneakyThrows + public void testWriteField() { + StoredFieldsWriter delegate = mock(StoredFieldsWriter.class); + FieldInfo fieldInfo = KNNCodecTestUtil.FieldInfoBuilder.builder("_source").build(); + List fields = List.of("test"); + + DerivedSourceStoredFieldsWriter derivedSourceStoredFieldsWriter = new DerivedSourceStoredFieldsWriter(delegate, fields); + + Map source = Map.of("test", new float[] { 1.0f, 2.0f, 3.0f }, "text_field", "text_value"); + BytesStreamOutput bStream = new BytesStreamOutput(); + XContentBuilder builder = MediaTypeRegistry.contentBuilder(MediaTypeRegistry.JSON, bStream).map(source); + builder.close(); + byte[] originalBytes = bStream.bytes().toBytesRef().bytes; + byte[] shiftedBytes = new byte[originalBytes.length + 2]; + System.arraycopy(originalBytes, 0, shiftedBytes, 1, originalBytes.length); + 
derivedSourceStoredFieldsWriter.writeField(fieldInfo, new BytesRef(shiftedBytes, 1, originalBytes.length)); + } +} diff --git a/src/test/java/org/opensearch/knn/index/codec/derivedsource/DerivedSourceStoredFieldVisitorTests.java b/src/test/java/org/opensearch/knn/index/codec/derivedsource/DerivedSourceStoredFieldVisitorTests.java new file mode 100644 index 000000000..943146a91 --- /dev/null +++ b/src/test/java/org/opensearch/knn/index/codec/derivedsource/DerivedSourceStoredFieldVisitorTests.java @@ -0,0 +1,43 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.codec.derivedsource; + +import org.apache.lucene.index.StoredFieldVisitor; +import org.opensearch.knn.KNNTestCase; +import org.opensearch.knn.index.codec.KNNCodecTestUtil; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyInt; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +public class DerivedSourceStoredFieldVisitorTests extends KNNTestCase { + + public void testBinaryField() throws Exception { + StoredFieldVisitor delegate = mock(StoredFieldVisitor.class); + doAnswer(invocationOnMock -> null).when(delegate).binaryField(any(), any()); + DerivedSourceVectorInjector derivedSourceVectorInjector = mock(DerivedSourceVectorInjector.class); + when(derivedSourceVectorInjector.injectVectors(anyInt(), any())).thenReturn(new byte[0]); + DerivedSourceStoredFieldVisitor derivedSourceStoredFieldVisitor = new DerivedSourceStoredFieldVisitor( + delegate, + 0, + derivedSourceVectorInjector + ); + + // When field is not _source, then do not call the injector + derivedSourceStoredFieldVisitor.binaryField(KNNCodecTestUtil.FieldInfoBuilder.builder("test").build(), null); + verify(derivedSourceVectorInjector, times(0)).injectVectors(anyInt(), 
any()); + verify(delegate, times(1)).binaryField(any(), any()); + + // When field is _source, then do call the injector + derivedSourceStoredFieldVisitor.binaryField(KNNCodecTestUtil.FieldInfoBuilder.builder("_source").build(), null); + verify(derivedSourceVectorInjector, times(1)).injectVectors(anyInt(), any()); + verify(delegate, times(2)).binaryField(any(), any()); + } +} diff --git a/src/test/java/org/opensearch/knn/index/codec/derivedsource/DerivedSourceVectorInjectorTests.java b/src/test/java/org/opensearch/knn/index/codec/derivedsource/DerivedSourceVectorInjectorTests.java new file mode 100644 index 000000000..f845e9d82 --- /dev/null +++ b/src/test/java/org/opensearch/knn/index/codec/derivedsource/DerivedSourceVectorInjectorTests.java @@ -0,0 +1,136 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.codec.derivedsource; + +import lombok.SneakyThrows; +import org.apache.lucene.index.FieldInfo; +import org.mockito.MockedStatic; +import org.mockito.Mockito; +import org.opensearch.common.collect.Tuple; +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.common.xcontent.XContentHelper; +import org.opensearch.core.common.bytes.BytesReference; +import org.opensearch.core.xcontent.MediaType; +import org.opensearch.core.xcontent.MediaTypeRegistry; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.knn.KNNTestCase; +import org.opensearch.knn.index.codec.KNNCodecTestUtil; + +import java.nio.ByteBuffer; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import static org.mockito.ArgumentMatchers.any; + +public class DerivedSourceVectorInjectorTests extends KNNTestCase { + + @SneakyThrows + @SuppressWarnings("unchecked") + public void testInjectVectors() { + List fields = List.of( + KNNCodecTestUtil.FieldInfoBuilder.builder("test1").build(), +
KNNCodecTestUtil.FieldInfoBuilder.builder("test2").build(), + KNNCodecTestUtil.FieldInfoBuilder.builder("test3").build() + ); + + Map fieldToVector = Collections.unmodifiableMap(new HashMap<>() { + { + put("test1", new float[] { 1.0f, 2.0f, 3.0f }); + put("test2", new float[] { 4.0f, 5.0f, 6.0f, 7.0f }); + put("test3", new float[] { 7.0f, 8.0f, 9.0f, 1.0f, 3.0f, 4.0f }); + put("test4", null); + } + }); + + try (MockedStatic factory = Mockito.mockStatic(PerFieldDerivedVectorInjectorFactory.class)) { + factory.when(() -> PerFieldDerivedVectorInjectorFactory.create(any(), any(), any())).thenAnswer(invocation -> { + FieldInfo fieldInfo = invocation.getArgument(0); + return (PerFieldDerivedVectorInjector) (docId, sourceAsMap) -> { + float[] vector = fieldToVector.get(fieldInfo.name); + if (vector != null) { + sourceAsMap.put(fieldInfo.name, vector); + } + }; + }); + + DerivedSourceVectorInjector derivedSourceVectorInjector = new DerivedSourceVectorInjector( + new DerivedSourceReaders(null, null, null, null, false), + null, + fields + ); + + int docId = 2; + String existingFieldKey = "existingField"; + String existingFieldValue = "existingField"; + Map source = Map.of(existingFieldKey, existingFieldValue); + byte[] originalSourceBytes = mapToBytes(source); + byte[] modifiedSourceByttes = derivedSourceVectorInjector.injectVectors(docId, originalSourceBytes); + Map modifiedSource = bytesToMap(modifiedSourceByttes); + + assertEquals(existingFieldValue, modifiedSource.get(existingFieldKey)); + + assertArrayEquals(fieldToVector.get("test1"), toFloatArray((List) modifiedSource.get("test1")), 0.000001f); + assertArrayEquals(fieldToVector.get("test2"), toFloatArray((List) modifiedSource.get("test2")), 0.000001f); + assertArrayEquals(fieldToVector.get("test3"), toFloatArray((List) modifiedSource.get("test3")), 0.000001f); + assertFalse(modifiedSource.containsKey("test4")); + } + } + + @SneakyThrows + private byte[] mapToBytes(Map map) { + + BytesStreamOutput bStream = new 
BytesStreamOutput(1024); + XContentBuilder builder = MediaTypeRegistry.contentBuilder(MediaTypeRegistry.JSON, bStream).map(map); + builder.close(); + return BytesReference.toBytes(BytesReference.bytes(builder)); + } + + private float[] toFloatArray(List list) { + float[] array = new float[list.size()]; + for (int i = 0; i < list.size(); i++) { + array[i] = list.get(i).floatValue(); + } + return array; + } + + private Map bytesToMap(byte[] bytes) { + Tuple> mapTuple = XContentHelper.convertToMap( + BytesReference.fromByteBuffer(ByteBuffer.wrap(bytes)), + true, + MediaTypeRegistry.getDefaultMediaType() + ); + + return mapTuple.v2(); + } + + @SneakyThrows + public void testShouldInject() { + + List fields = List.of( + KNNCodecTestUtil.FieldInfoBuilder.builder("test1").build(), + KNNCodecTestUtil.FieldInfoBuilder.builder("test2").build(), + KNNCodecTestUtil.FieldInfoBuilder.builder("test3").build() + ); + + try ( + DerivedSourceVectorInjector vectorInjector = new DerivedSourceVectorInjector( + new DerivedSourceReaders(null, null, null, null, false), + null, + fields + ) + ) { + assertTrue(vectorInjector.shouldInject(null, null)); + assertTrue(vectorInjector.shouldInject(new String[] { "test1" }, null)); + assertTrue(vectorInjector.shouldInject(new String[] { "test1", "test2", "test3" }, null)); + assertTrue(vectorInjector.shouldInject(null, new String[] { "test2" })); + assertTrue(vectorInjector.shouldInject(new String[] { "test1" }, new String[] { "test2" })); + assertTrue(vectorInjector.shouldInject(new String[] { "test1" }, new String[] { "test2", "test3" })); + assertFalse(vectorInjector.shouldInject(null, new String[] { "test1", "test2", "test3" })); + } + } +} diff --git a/src/test/java/org/opensearch/knn/index/codec/derivedsource/ParentChildHelperTests.java b/src/test/java/org/opensearch/knn/index/codec/derivedsource/ParentChildHelperTests.java new file mode 100644 index 000000000..085222460 --- /dev/null +++ 
b/src/test/java/org/opensearch/knn/index/codec/derivedsource/ParentChildHelperTests.java @@ -0,0 +1,42 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.codec.derivedsource; + +import org.opensearch.knn.KNNTestCase; + +public class ParentChildHelperTests extends KNNTestCase { + + public void testGetParentField() { + assertEquals("parent.to", ParentChildHelper.getParentField("parent.to.child")); + assertEquals("parent", ParentChildHelper.getParentField("parent.to")); + assertNull(ParentChildHelper.getParentField("child")); + assertNull(ParentChildHelper.getParentField("")); + assertNull(ParentChildHelper.getParentField(null)); + } + + public void testGetChildField() { + assertEquals("child", ParentChildHelper.getChildField("parent.to.child")); + assertNull(ParentChildHelper.getChildField(null)); + assertNull(ParentChildHelper.getChildField("child")); + } + + public void testConstructSiblingField() { + assertEquals("parent.to.sibling", ParentChildHelper.constructSiblingField("parent.to.child", "sibling")); + assertEquals("sibling", ParentChildHelper.constructSiblingField("parent", "sibling")); + } + + public void testSplitPath() { + String[] path = ParentChildHelper.splitPath("parent.to.child"); + assertEquals(3, path.length); + assertEquals("parent", path[0]); + assertEquals("to", path[1]); + assertEquals("child", path[2]); + + path = ParentChildHelper.splitPath("parent"); + assertEquals(1, path.length); + assertEquals("parent", path[0]); + } +} diff --git a/src/test/java/org/opensearch/knn/index/codec/derivedsource/PerFieldDerivedVectorInjectorFactoryTests.java b/src/test/java/org/opensearch/knn/index/codec/derivedsource/PerFieldDerivedVectorInjectorFactoryTests.java new file mode 100644 index 000000000..8b7d29221 --- /dev/null +++ b/src/test/java/org/opensearch/knn/index/codec/derivedsource/PerFieldDerivedVectorInjectorFactoryTests.java @@ -0,0 +1,29 @@ +/* + * Copyright OpenSearch 
Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.codec.derivedsource; + +import org.opensearch.knn.KNNTestCase; +import org.opensearch.knn.index.codec.KNNCodecTestUtil; + +public class PerFieldDerivedVectorInjectorFactoryTests extends KNNTestCase { + public void testCreate() { + // Non-nested case + PerFieldDerivedVectorInjector perFieldDerivedVectorInjector = PerFieldDerivedVectorInjectorFactory.create( + KNNCodecTestUtil.FieldInfoBuilder.builder("test").build(), + new DerivedSourceReaders(null, null, null, null, false), + null + ); + assertTrue(perFieldDerivedVectorInjector instanceof RootPerFieldDerivedVectorInjector); + + // Nested case + perFieldDerivedVectorInjector = PerFieldDerivedVectorInjectorFactory.create( + KNNCodecTestUtil.FieldInfoBuilder.builder("parent.test").build(), + new DerivedSourceReaders(null, null, null, null, false), + null + ); + assertTrue(perFieldDerivedVectorInjector instanceof NestedPerFieldDerivedVectorInjector); + } +} diff --git a/src/test/java/org/opensearch/knn/index/codec/derivedsource/RootPerFieldDerivedVectorInjectorTests.java b/src/test/java/org/opensearch/knn/index/codec/derivedsource/RootPerFieldDerivedVectorInjectorTests.java new file mode 100644 index 000000000..a009091a2 --- /dev/null +++ b/src/test/java/org/opensearch/knn/index/codec/derivedsource/RootPerFieldDerivedVectorInjectorTests.java @@ -0,0 +1,63 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.codec.derivedsource; + +import lombok.SneakyThrows; +import org.apache.lucene.index.FieldInfo; +import org.mockito.MockedStatic; +import org.mockito.Mockito; +import org.opensearch.knn.KNNTestCase; +import org.opensearch.knn.index.codec.KNNCodecTestUtil; +import org.opensearch.knn.index.vectorvalues.KNNVectorValues; +import org.opensearch.knn.index.vectorvalues.KNNVectorValuesFactory; +import 
org.opensearch.knn.index.vectorvalues.KNNVectorValuesIterator; + +import java.util.HashMap; +import java.util.Map; + +import static org.mockito.ArgumentMatchers.anyInt; +import static org.mockito.Mockito.when; +import static org.opensearch.knn.KNNRestTestCase.FIELD_NAME; + +public class RootPerFieldDerivedVectorInjectorTests extends KNNTestCase { + public static float[] TEST_VECTOR = new float[] { 1.0f, 2.0f, 3.0f, 4.0f }; + + @SneakyThrows + public void testInject() { + FieldInfo fieldInfo = KNNCodecTestUtil.FieldInfoBuilder.builder(FIELD_NAME).build(); + try (MockedStatic mockedKnnVectorValues = Mockito.mockStatic(KNNVectorValuesFactory.class)) { + + final KNNVectorValuesIterator vectorValuesIterator = Mockito.mock(KNNVectorValuesIterator.class); + when(vectorValuesIterator.docId()).thenReturn(0); + when(vectorValuesIterator.advance(anyInt())).thenReturn(0); + when(vectorValuesIterator.nextDoc()).thenReturn(0); + when(vectorValuesIterator.getDocIdSetIterator()).thenReturn(null); + when(vectorValuesIterator.liveDocs()).thenReturn(0L); + when(vectorValuesIterator.getVectorExtractorStrategy()).thenReturn(null); + + mockedKnnVectorValues.when(() -> KNNVectorValuesFactory.getVectorValues(fieldInfo, null, null)) + .thenReturn(new KNNVectorValues(vectorValuesIterator) { + @Override + public float[] getVector() { + return TEST_VECTOR; + } + + @Override + public float[] conditionalCloneVector() { + return TEST_VECTOR; + } + }); + PerFieldDerivedVectorInjector perFieldDerivedVectorInjector = new RootPerFieldDerivedVectorInjector( + fieldInfo, + new DerivedSourceReaders(null, null, null, null, false) + ); + + Map source = new HashMap<>(); + perFieldDerivedVectorInjector.inject(0, source); + assertArrayEquals(TEST_VECTOR, (float[]) source.get(FIELD_NAME), 0.0001f); + } + } +}