From 4bce385c734792cd158ea6c9d0986871a76c18b6 Mon Sep 17 00:00:00 2001 From: John Mazanec Date: Mon, 3 Feb 2025 15:29:25 -0800 Subject: [PATCH 1/3] Fix bytes bug and add uTs for derived source Fixes a bug in the derived source writer where we are reading the entire bytes array from the bytes ref instead of just the offset+length. Along with that, touches up the ParentChildHelper (no prod impact) and also adds some unit tests. Signed-off-by: John Mazanec --- .../DerivedSourceStoredFieldsReader.java | 3 + .../DerivedSourceStoredFieldsWriter.java | 2 +- .../codec/KNN9120Codec/KNN9120Codec.java | 6 + .../derivedsource/DerivedSourceReaders.java | 3 + .../DerivedSourceVectorInjector.java | 4 + .../derivedsource/ParentChildHelper.java | 19 ++- .../VectorValueExtractorStrategy.java | 2 +- .../DerivedSourceStoredFieldVisitorTests.java | 43 ++++++ .../DerivedSourceVectorInjectorTests.java | 136 ++++++++++++++++++ .../derivedsource/ParentChildHelperTests.java | 42 ++++++ ...ieldDerivedVectorInjectorFactoryTests.java | 29 ++++ ...ootPerFieldDerivedVectorInjectorTests.java | 85 +++++++++++ 12 files changed, 369 insertions(+), 5 deletions(-) create mode 100644 src/test/java/org/opensearch/knn/index/codec/derivedsource/DerivedSourceStoredFieldVisitorTests.java create mode 100644 src/test/java/org/opensearch/knn/index/codec/derivedsource/DerivedSourceVectorInjectorTests.java create mode 100644 src/test/java/org/opensearch/knn/index/codec/derivedsource/ParentChildHelperTests.java create mode 100644 src/test/java/org/opensearch/knn/index/codec/derivedsource/PerFieldDerivedVectorInjectorFactoryTests.java create mode 100644 src/test/java/org/opensearch/knn/index/codec/derivedsource/RootPerFieldDerivedVectorInjectorTests.java diff --git a/src/main/java/org/opensearch/knn/index/codec/KNN9120Codec/DerivedSourceStoredFieldsReader.java b/src/main/java/org/opensearch/knn/index/codec/KNN9120Codec/DerivedSourceStoredFieldsReader.java index 24900eb19..03a89b089 100644 --- 
a/src/main/java/org/opensearch/knn/index/codec/KNN9120Codec/DerivedSourceStoredFieldsReader.java +++ b/src/main/java/org/opensearch/knn/index/codec/KNN9120Codec/DerivedSourceStoredFieldsReader.java @@ -5,6 +5,7 @@ package org.opensearch.knn.index.codec.KNN9120Codec; +import lombok.extern.log4j.Log4j2; import org.apache.lucene.codecs.StoredFieldsReader; import org.apache.lucene.index.FieldInfo; import org.apache.lucene.index.SegmentReadState; @@ -18,6 +19,7 @@ import java.io.IOException; import java.util.List; +@Log4j2 public class DerivedSourceStoredFieldsReader extends StoredFieldsReader { private final StoredFieldsReader delegate; private final List derivedVectorFields; @@ -102,6 +104,7 @@ public void checkIntegrity() throws IOException { @Override public void close() throws IOException { + log.info("Closing derived source stored fields reader for segment: " + segmentReadState.segmentInfo.name); IOUtils.close(delegate, derivedSourceVectorInjector); } diff --git a/src/main/java/org/opensearch/knn/index/codec/KNN9120Codec/DerivedSourceStoredFieldsWriter.java b/src/main/java/org/opensearch/knn/index/codec/KNN9120Codec/DerivedSourceStoredFieldsWriter.java index c585b09f7..0c43f6a49 100644 --- a/src/main/java/org/opensearch/knn/index/codec/KNN9120Codec/DerivedSourceStoredFieldsWriter.java +++ b/src/main/java/org/opensearch/knn/index/codec/KNN9120Codec/DerivedSourceStoredFieldsWriter.java @@ -79,7 +79,7 @@ public void writeField(FieldInfo fieldInfo, BytesRef bytesRef) throws IOExceptio // Reference: // https://github.com/opensearch-project/OpenSearch/blob/2.18.0/server/src/main/java/org/opensearch/index/mapper/SourceFieldMapper.java#L322 Tuple> mapTuple = XContentHelper.convertToMap( - BytesReference.fromByteBuffer(ByteBuffer.wrap(bytesRef.bytes)), + BytesReference.fromByteBuffer(ByteBuffer.wrap(bytesRef.bytes, bytesRef.offset, bytesRef.length)), true, MediaTypeRegistry.JSON ); diff --git a/src/main/java/org/opensearch/knn/index/codec/KNN9120Codec/KNN9120Codec.java 
b/src/main/java/org/opensearch/knn/index/codec/KNN9120Codec/KNN9120Codec.java index b8a5e6a12..5e40faf1a 100644 --- a/src/main/java/org/opensearch/knn/index/codec/KNN9120Codec/KNN9120Codec.java +++ b/src/main/java/org/opensearch/knn/index/codec/KNN9120Codec/KNN9120Codec.java @@ -25,6 +25,7 @@ public class KNN9120Codec extends FilterCodec { private static final KNNCodecVersion VERSION = KNNCodecVersion.V_9_12_0; private final KNNFormatFacade knnFormatFacade; private final PerFieldKnnVectorsFormat perFieldKnnVectorsFormat; + private final StoredFieldsFormat storedFieldsFormat; private final MapperService mapperService; @@ -48,6 +49,7 @@ protected KNN9120Codec(Codec delegate, PerFieldKnnVectorsFormat knnVectorsFormat knnFormatFacade = VERSION.getKnnFormatFacadeSupplier().apply(delegate); perFieldKnnVectorsFormat = knnVectorsFormat; this.mapperService = mapperService; + this.storedFieldsFormat = getStoredFieldsFormat(); } @Override @@ -67,6 +69,10 @@ public KnnVectorsFormat knnVectorsFormat() { @Override public StoredFieldsFormat storedFieldsFormat() { + return storedFieldsFormat; + } + + private StoredFieldsFormat getStoredFieldsFormat() { DerivedSourceReadersSupplier derivedSourceReadersSupplier = new DerivedSourceReadersSupplier((segmentReadState) -> { if (segmentReadState.fieldInfos.hasVectorValues()) { return knnVectorsFormat().fieldsReader(segmentReadState); diff --git a/src/main/java/org/opensearch/knn/index/codec/derivedsource/DerivedSourceReaders.java b/src/main/java/org/opensearch/knn/index/codec/derivedsource/DerivedSourceReaders.java index c7e472e60..76be2f16c 100644 --- a/src/main/java/org/opensearch/knn/index/codec/derivedsource/DerivedSourceReaders.java +++ b/src/main/java/org/opensearch/knn/index/codec/derivedsource/DerivedSourceReaders.java @@ -7,6 +7,7 @@ import lombok.Getter; import lombok.RequiredArgsConstructor; +import lombok.extern.log4j.Log4j2; import org.apache.lucene.codecs.DocValuesProducer; import org.apache.lucene.codecs.FieldsProducer; 
import org.apache.lucene.codecs.KnnVectorsReader; @@ -23,6 +24,7 @@ */ @RequiredArgsConstructor @Getter +@Log4j2 public class DerivedSourceReaders implements Closeable { @Nullable private final KnnVectorsReader knnVectorsReader; @@ -35,6 +37,7 @@ public class DerivedSourceReaders implements Closeable { @Override public void close() throws IOException { + log.info("Closing derived source readers"); IOUtils.close(knnVectorsReader, docValuesProducer, fieldsProducer, normsProducer); } } diff --git a/src/main/java/org/opensearch/knn/index/codec/derivedsource/DerivedSourceVectorInjector.java b/src/main/java/org/opensearch/knn/index/codec/derivedsource/DerivedSourceVectorInjector.java index d3b1fe846..4ac1d1058 100644 --- a/src/main/java/org/opensearch/knn/index/codec/derivedsource/DerivedSourceVectorInjector.java +++ b/src/main/java/org/opensearch/knn/index/codec/derivedsource/DerivedSourceVectorInjector.java @@ -39,6 +39,8 @@ public class DerivedSourceVectorInjector implements Closeable { private final List perFieldDerivedVectorInjectors; private final Set fieldNames; + private final String segmentName; + /** * Constructor for DerivedSourceVectorInjector. 
* @@ -60,6 +62,7 @@ public DerivedSourceVectorInjector( ); this.fieldNames.add(fieldInfo.name); } + this.segmentName = segmentReadState.segmentInfo.name; } /** @@ -131,6 +134,7 @@ public boolean shouldInject(String[] includes, String[] excludes) { @Override public void close() throws IOException { + log.info("Closing derived source injector reader for segment" + segmentName); IOUtils.close(derivedSourceReaders); } } diff --git a/src/main/java/org/opensearch/knn/index/codec/derivedsource/ParentChildHelper.java b/src/main/java/org/opensearch/knn/index/codec/derivedsource/ParentChildHelper.java index 534cf93d7..ae249b1b3 100644 --- a/src/main/java/org/opensearch/knn/index/codec/derivedsource/ParentChildHelper.java +++ b/src/main/java/org/opensearch/knn/index/codec/derivedsource/ParentChildHelper.java @@ -15,9 +15,12 @@ public class ParentChildHelper { * this would return "parent.to". * * @param field nested field path - * @return parent field path without the child + * @return parent field path without the child. Null if no parent exists */ public static String getParentField(String field) { + if (field == null) { + return null; + } int lastDot = field.lastIndexOf('.'); if (lastDot == -1) { return null; @@ -30,10 +33,16 @@ public static String getParentField(String field) { * return "child". * * @param field nested field path - * @return child field path without the parent path + * @return child field path without the parent path. Null if no child exists */ public static String getChildField(String field) { + if (field == null) { + return null; + } int lastDot = field.lastIndexOf('.'); + if (lastDot == -1) { + return null; + } return field.substring(lastDot + 1); } @@ -46,7 +55,11 @@ public static String getChildField(String field) { * @return sibling field path */ public static String constructSiblingField(String field, String sibling) { - return getParentField(field) + "." 
+ sibling; + String parent = getParentField(field); + if (parent == null) { + return sibling; + } + return parent + "." + sibling; } /** diff --git a/src/main/java/org/opensearch/knn/index/vectorvalues/VectorValueExtractorStrategy.java b/src/main/java/org/opensearch/knn/index/vectorvalues/VectorValueExtractorStrategy.java index 07db4e7f6..9952986c8 100644 --- a/src/main/java/org/opensearch/knn/index/vectorvalues/VectorValueExtractorStrategy.java +++ b/src/main/java/org/opensearch/knn/index/vectorvalues/VectorValueExtractorStrategy.java @@ -20,7 +20,7 @@ /** * Provides different strategies to extract the vectors from different {@link KNNVectorValuesIterator} */ -interface VectorValueExtractorStrategy { +public interface VectorValueExtractorStrategy { /** * Extract a float vector from KNNVectorValuesIterator. diff --git a/src/test/java/org/opensearch/knn/index/codec/derivedsource/DerivedSourceStoredFieldVisitorTests.java b/src/test/java/org/opensearch/knn/index/codec/derivedsource/DerivedSourceStoredFieldVisitorTests.java new file mode 100644 index 000000000..943146a91 --- /dev/null +++ b/src/test/java/org/opensearch/knn/index/codec/derivedsource/DerivedSourceStoredFieldVisitorTests.java @@ -0,0 +1,43 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.codec.derivedsource; + +import org.apache.lucene.index.StoredFieldVisitor; +import org.opensearch.knn.KNNTestCase; +import org.opensearch.knn.index.codec.KNNCodecTestUtil; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyInt; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +public class DerivedSourceStoredFieldVisitorTests extends KNNTestCase { + + public void testBinaryField() throws Exception { + StoredFieldVisitor delegate = 
mock(StoredFieldVisitor.class); + doAnswer(invocationOnMock -> null).when(delegate).binaryField(any(), any()); + DerivedSourceVectorInjector derivedSourceVectorInjector = mock(DerivedSourceVectorInjector.class); + when(derivedSourceVectorInjector.injectVectors(anyInt(), any())).thenReturn(new byte[0]); + DerivedSourceStoredFieldVisitor derivedSourceStoredFieldVisitor = new DerivedSourceStoredFieldVisitor( + delegate, + 0, + derivedSourceVectorInjector + ); + + // When field is not _source, then do not call the injector + derivedSourceStoredFieldVisitor.binaryField(KNNCodecTestUtil.FieldInfoBuilder.builder("test").build(), null); + verify(derivedSourceVectorInjector, times(0)).injectVectors(anyInt(), any()); + verify(delegate, times(1)).binaryField(any(), any()); + + // When field is _source, then do call the injector + derivedSourceStoredFieldVisitor.binaryField(KNNCodecTestUtil.FieldInfoBuilder.builder("_source").build(), null); + verify(derivedSourceVectorInjector, times(1)).injectVectors(anyInt(), any()); + verify(delegate, times(2)).binaryField(any(), any()); + } +} diff --git a/src/test/java/org/opensearch/knn/index/codec/derivedsource/DerivedSourceVectorInjectorTests.java b/src/test/java/org/opensearch/knn/index/codec/derivedsource/DerivedSourceVectorInjectorTests.java new file mode 100644 index 000000000..1fa4b9364 --- /dev/null +++ b/src/test/java/org/opensearch/knn/index/codec/derivedsource/DerivedSourceVectorInjectorTests.java @@ -0,0 +1,136 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.codec.derivedsource; + +import lombok.SneakyThrows; +import org.apache.lucene.index.FieldInfo; +import org.mockito.MockedStatic; +import org.mockito.Mockito; +import org.opensearch.common.collect.Tuple; +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.common.xcontent.XContentHelper; +import org.opensearch.core.common.bytes.BytesReference; +import 
org.opensearch.core.xcontent.MediaType; +import org.opensearch.core.xcontent.MediaTypeRegistry; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.knn.KNNTestCase; +import org.opensearch.knn.index.codec.KNNCodecTestUtil; + +import java.nio.ByteBuffer; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import static org.mockito.ArgumentMatchers.any; + +public class DerivedSourceVectorInjectorTests extends KNNTestCase { + + @SneakyThrows + @SuppressWarnings("unchecked") + public void testInjectVectors() { + List fields = List.of( + KNNCodecTestUtil.FieldInfoBuilder.builder("test1").build(), + KNNCodecTestUtil.FieldInfoBuilder.builder("test2").build(), + KNNCodecTestUtil.FieldInfoBuilder.builder("test3").build() + ); + + Map fieldToVector = Collections.unmodifiableMap(new HashMap<>() { + { + put("test1", new float[] { 1.0f, 2.0f, 3.0f }); + put("test2", new float[] { 4.0f, 5.0f, 6.0f, 7.0f }); + put("test3", new float[] { 7.0f, 8.0f, 9.0f, 1.0f, 3.0f, 4.0f }); + put("test4", null); + } + }); + + try (MockedStatic factory = Mockito.mockStatic(PerFieldDerivedVectorInjectorFactory.class)) { + factory.when(() -> PerFieldDerivedVectorInjectorFactory.create(any(), any(), any())).thenAnswer(invocation -> { + FieldInfo fieldInfo = invocation.getArgument(0); + return (PerFieldDerivedVectorInjector) (docId, sourceAsMap) -> { + float[] vector = fieldToVector.get(fieldInfo.name); + if (vector != null) { + sourceAsMap.put(fieldInfo.name, vector); + } + }; + }); + + DerivedSourceVectorInjector derivedSourceVectorInjector = new DerivedSourceVectorInjector( + new DerivedSourceReadersSupplier(s -> null, s -> null, s -> null, s -> null), + null, + fields + ); + + int docId = 2; + String existingFieldKey = "existingField"; + String existingFieldValue = "existingField"; + Map source = Map.of(existingFieldKey, existingFieldValue); + byte[] originalSourceBytes = mapToBytes(source); + byte[] 
modifiedSourceByttes = derivedSourceVectorInjector.injectVectors(docId, originalSourceBytes); + Map modifiedSource = bytesToMap(modifiedSourceByttes); + + assertEquals(existingFieldValue, modifiedSource.get(existingFieldKey)); + + assertArrayEquals(fieldToVector.get("test1"), toFloatArray((List) modifiedSource.get("test1")), 0.000001f); + assertArrayEquals(fieldToVector.get("test2"), toFloatArray((List) modifiedSource.get("test2")), 0.000001f); + assertArrayEquals(fieldToVector.get("test3"), toFloatArray((List) modifiedSource.get("test3")), 0.000001f); + assertFalse(modifiedSource.containsKey("test4")); + } + } + + @SneakyThrows + private byte[] mapToBytes(Map map) { + + BytesStreamOutput bStream = new BytesStreamOutput(1024); + XContentBuilder builder = MediaTypeRegistry.contentBuilder(MediaTypeRegistry.JSON, bStream).map(map); + builder.close(); + return BytesReference.toBytes(BytesReference.bytes(builder)); + } + + private float[] toFloatArray(List list) { + float[] array = new float[list.size()]; + for (int i = 0; i < list.size(); i++) { + array[i] = list.get(i).floatValue(); + } + return array; + } + + private Map bytesToMap(byte[] bytes) { + Tuple> mapTuple = XContentHelper.convertToMap( + BytesReference.fromByteBuffer(ByteBuffer.wrap(bytes)), + true, + MediaTypeRegistry.getDefaultMediaType() + ); + + return mapTuple.v2(); + } + + @SneakyThrows + public void testShouldInject() { + + List fields = List.of( + KNNCodecTestUtil.FieldInfoBuilder.builder("test1").build(), + KNNCodecTestUtil.FieldInfoBuilder.builder("test2").build(), + KNNCodecTestUtil.FieldInfoBuilder.builder("test3").build() + ); + + try ( + DerivedSourceVectorInjector vectorInjector = new DerivedSourceVectorInjector( + new DerivedSourceReadersSupplier(s -> null, s -> null, s -> null, s -> null), + null, + fields + ) + ) { + assertTrue(vectorInjector.shouldInject(null, null)); + assertTrue(vectorInjector.shouldInject(new String[] { "test1" }, null)); + assertTrue(vectorInjector.shouldInject(new 
String[] { "test1", "test2", "test3" }, null)); + assertTrue(vectorInjector.shouldInject(null, new String[] { "test2" })); + assertTrue(vectorInjector.shouldInject(new String[] { "test1" }, new String[] { "test2" })); + assertTrue(vectorInjector.shouldInject(new String[] { "test1" }, new String[] { "test2", "test3" })); + assertFalse(vectorInjector.shouldInject(null, new String[] { "test1", "test2", "test3" })); + } + } +} diff --git a/src/test/java/org/opensearch/knn/index/codec/derivedsource/ParentChildHelperTests.java b/src/test/java/org/opensearch/knn/index/codec/derivedsource/ParentChildHelperTests.java new file mode 100644 index 000000000..085222460 --- /dev/null +++ b/src/test/java/org/opensearch/knn/index/codec/derivedsource/ParentChildHelperTests.java @@ -0,0 +1,42 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.codec.derivedsource; + +import org.opensearch.knn.KNNTestCase; + +public class ParentChildHelperTests extends KNNTestCase { + + public void testGetParentField() { + assertEquals("parent.to", ParentChildHelper.getParentField("parent.to.child")); + assertEquals("parent", ParentChildHelper.getParentField("parent.to")); + assertNull(ParentChildHelper.getParentField("child")); + assertNull(ParentChildHelper.getParentField("")); + assertNull(ParentChildHelper.getParentField(null)); + } + + public void testGetChildField() { + assertEquals("child", ParentChildHelper.getChildField("parent.to.child")); + assertNull(ParentChildHelper.getChildField(null)); + assertNull(ParentChildHelper.getChildField("child")); + } + + public void testConstructSiblingField() { + assertEquals("parent.to.sibling", ParentChildHelper.constructSiblingField("parent.to.child", "sibling")); + assertEquals("sibling", ParentChildHelper.constructSiblingField("parent", "sibling")); + } + + public void testSplitPath() { + String[] path = ParentChildHelper.splitPath("parent.to.child"); + assertEquals(3, 
path.length); + assertEquals("parent", path[0]); + assertEquals("to", path[1]); + assertEquals("child", path[2]); + + path = ParentChildHelper.splitPath("parent"); + assertEquals(1, path.length); + assertEquals("parent", path[0]); + } +} diff --git a/src/test/java/org/opensearch/knn/index/codec/derivedsource/PerFieldDerivedVectorInjectorFactoryTests.java b/src/test/java/org/opensearch/knn/index/codec/derivedsource/PerFieldDerivedVectorInjectorFactoryTests.java new file mode 100644 index 000000000..f117db8c8 --- /dev/null +++ b/src/test/java/org/opensearch/knn/index/codec/derivedsource/PerFieldDerivedVectorInjectorFactoryTests.java @@ -0,0 +1,29 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.codec.derivedsource; + +import org.opensearch.knn.KNNTestCase; +import org.opensearch.knn.index.codec.KNNCodecTestUtil; + +public class PerFieldDerivedVectorInjectorFactoryTests extends KNNTestCase { + public void testCreate() { + // Non-nested case + PerFieldDerivedVectorInjector perFieldDerivedVectorInjector = PerFieldDerivedVectorInjectorFactory.create( + KNNCodecTestUtil.FieldInfoBuilder.builder("test").build(), + new DerivedSourceReaders(null, null, null, null), + null + ); + assertTrue(perFieldDerivedVectorInjector instanceof RootPerFieldDerivedVectorInjector); + + // Nested case + perFieldDerivedVectorInjector = PerFieldDerivedVectorInjectorFactory.create( + KNNCodecTestUtil.FieldInfoBuilder.builder("parent.test").build(), + new DerivedSourceReaders(null, null, null, null), + null + ); + assertTrue(perFieldDerivedVectorInjector instanceof NestedPerFieldDerivedVectorInjector); + } +} diff --git a/src/test/java/org/opensearch/knn/index/codec/derivedsource/RootPerFieldDerivedVectorInjectorTests.java b/src/test/java/org/opensearch/knn/index/codec/derivedsource/RootPerFieldDerivedVectorInjectorTests.java new file mode 100644 index 000000000..3e015b09c --- /dev/null +++ 
b/src/test/java/org/opensearch/knn/index/codec/derivedsource/RootPerFieldDerivedVectorInjectorTests.java @@ -0,0 +1,85 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.codec.derivedsource; + +import lombok.SneakyThrows; +import org.apache.lucene.index.FieldInfo; +import org.apache.lucene.search.DocIdSetIterator; +import org.mockito.MockedStatic; +import org.mockito.Mockito; +import org.opensearch.knn.KNNTestCase; +import org.opensearch.knn.index.codec.KNNCodecTestUtil; +import org.opensearch.knn.index.vectorvalues.KNNVectorValues; +import org.opensearch.knn.index.vectorvalues.KNNVectorValuesFactory; +import org.opensearch.knn.index.vectorvalues.KNNVectorValuesIterator; +import org.opensearch.knn.index.vectorvalues.VectorValueExtractorStrategy; + +import java.util.HashMap; +import java.util.Map; + +import static org.opensearch.knn.KNNRestTestCase.FIELD_NAME; + +public class RootPerFieldDerivedVectorInjectorTests extends KNNTestCase { + public static float[] TEST_VECTOR = new float[] { 1.0f, 2.0f, 3.0f, 4.0f }; + + @SneakyThrows + public void testInject() { + FieldInfo fieldInfo = KNNCodecTestUtil.FieldInfoBuilder.builder(FIELD_NAME).build(); + try (MockedStatic mockedKnnVectorValues = Mockito.mockStatic(KNNVectorValuesFactory.class)) { + mockedKnnVectorValues.when(() -> KNNVectorValuesFactory.getVectorValues(fieldInfo, null, null)) + .thenReturn(new KNNVectorValues(new KNNVectorValuesIterator() { + @Override + public int docId() { + return 0; + } + + @Override + public int advance(int docId) { + return 0; + } + + @Override + public int nextDoc() { + return 0; + } + + @Override + public DocIdSetIterator getDocIdSetIterator() { + return null; + } + + @Override + public long liveDocs() { + return 0; + } + + @Override + public VectorValueExtractorStrategy getVectorExtractorStrategy() { + return null; + } + }) { + + @Override + public float[] getVector() { + return TEST_VECTOR; + } + + @Override 
+ public float[] conditionalCloneVector() { + return TEST_VECTOR; + } + }); + PerFieldDerivedVectorInjector perFieldDerivedVectorInjector = new RootPerFieldDerivedVectorInjector( + fieldInfo, + new DerivedSourceReaders(null, null, null, null) + ); + + Map source = new HashMap<>(); + perFieldDerivedVectorInjector.inject(0, source); + assertArrayEquals(TEST_VECTOR, (float[]) source.get(FIELD_NAME), 0.0001f); + } + } +} From 91365c97214c0c50ecf1c87041fbfdcb78d82493 Mon Sep 17 00:00:00 2001 From: John Mazanec Date: Wed, 5 Feb 2025 19:30:28 -0800 Subject: [PATCH 2/3] Reuse readers to prevent memory leak Signed-off-by: John Mazanec --- .../DerivedSourceStoredFieldsFormat.java | 8 +++- .../DerivedSourceStoredFieldsReader.java | 24 +++++----- .../derivedsource/DerivedSourceReaders.java | 14 ++++-- .../DerivedSourceReadersSupplier.java | 5 +- .../DerivedSourceVectorInjector.java | 12 ++--- .../DerivedSourceVectorInjectorTests.java | 4 +- ...ieldDerivedVectorInjectorFactoryTests.java | 4 +- ...ootPerFieldDerivedVectorInjectorTests.java | 46 +++++-------------- 8 files changed, 52 insertions(+), 65 deletions(-) diff --git a/src/main/java/org/opensearch/knn/index/codec/KNN9120Codec/DerivedSourceStoredFieldsFormat.java b/src/main/java/org/opensearch/knn/index/codec/KNN9120Codec/DerivedSourceStoredFieldsFormat.java index e60b82b2e..f862ea3cb 100644 --- a/src/main/java/org/opensearch/knn/index/codec/KNN9120Codec/DerivedSourceStoredFieldsFormat.java +++ b/src/main/java/org/opensearch/knn/index/codec/KNN9120Codec/DerivedSourceStoredFieldsFormat.java @@ -19,6 +19,7 @@ import org.opensearch.index.mapper.MappedFieldType; import org.opensearch.index.mapper.MapperService; import org.opensearch.knn.index.KNNSettings; +import org.opensearch.knn.index.codec.derivedsource.DerivedSourceReaders; import org.opensearch.knn.index.codec.derivedsource.DerivedSourceReadersSupplier; import org.opensearch.knn.index.mapper.KNNVectorFieldType; @@ -55,11 +56,14 @@ public StoredFieldsReader 
fieldsReader(Directory directory, SegmentInfo segmentI if (derivedVectorFields == null || derivedVectorFields.isEmpty()) { return delegate.fieldsReader(directory, segmentInfo, fieldInfos, ioContext); } + + SegmentReadState segmentReadState = new SegmentReadState(directory, segmentInfo, fieldInfos, ioContext); + DerivedSourceReaders derivedSourceReaders = derivedSourceReadersSupplier.getReaders(segmentReadState); return new DerivedSourceStoredFieldsReader( delegate.fieldsReader(directory, segmentInfo, fieldInfos, ioContext), derivedVectorFields, - derivedSourceReadersSupplier, - new SegmentReadState(directory, segmentInfo, fieldInfos, ioContext) + derivedSourceReaders, + segmentReadState ); } diff --git a/src/main/java/org/opensearch/knn/index/codec/KNN9120Codec/DerivedSourceStoredFieldsReader.java b/src/main/java/org/opensearch/knn/index/codec/KNN9120Codec/DerivedSourceStoredFieldsReader.java index 03a89b089..65a55c998 100644 --- a/src/main/java/org/opensearch/knn/index/codec/KNN9120Codec/DerivedSourceStoredFieldsReader.java +++ b/src/main/java/org/opensearch/knn/index/codec/KNN9120Codec/DerivedSourceStoredFieldsReader.java @@ -12,7 +12,7 @@ import org.apache.lucene.index.StoredFieldVisitor; import org.apache.lucene.util.IOUtils; import org.opensearch.index.fieldvisitor.FieldsVisitor; -import org.opensearch.knn.index.codec.derivedsource.DerivedSourceReadersSupplier; +import org.opensearch.knn.index.codec.derivedsource.DerivedSourceReaders; import org.opensearch.knn.index.codec.derivedsource.DerivedSourceStoredFieldVisitor; import org.opensearch.knn.index.codec.derivedsource.DerivedSourceVectorInjector; @@ -23,7 +23,7 @@ public class DerivedSourceStoredFieldsReader extends StoredFieldsReader { private final StoredFieldsReader delegate; private final List derivedVectorFields; - private final DerivedSourceReadersSupplier derivedSourceReadersSupplier; + private final DerivedSourceReaders derivedSourceReaders; private final SegmentReadState segmentReadState; private 
final boolean shouldInject; @@ -33,36 +33,36 @@ public class DerivedSourceStoredFieldsReader extends StoredFieldsReader { * * @param delegate delegate StoredFieldsReader * @param derivedVectorFields List of fields that are derived source fields - * @param derivedSourceReadersSupplier Supplier for the derived source readers + * @param derivedSourceReaders Derived source readers * @param segmentReadState SegmentReadState for the segment * @throws IOException in case of I/O error */ public DerivedSourceStoredFieldsReader( StoredFieldsReader delegate, List derivedVectorFields, - DerivedSourceReadersSupplier derivedSourceReadersSupplier, + DerivedSourceReaders derivedSourceReaders, SegmentReadState segmentReadState ) throws IOException { - this(delegate, derivedVectorFields, derivedSourceReadersSupplier, segmentReadState, true); + this(delegate, derivedVectorFields, derivedSourceReaders, segmentReadState, true); } private DerivedSourceStoredFieldsReader( StoredFieldsReader delegate, List derivedVectorFields, - DerivedSourceReadersSupplier derivedSourceReadersSupplier, + DerivedSourceReaders derivedSourceReaders, SegmentReadState segmentReadState, boolean shouldInject ) throws IOException { this.delegate = delegate; this.derivedVectorFields = derivedVectorFields; - this.derivedSourceReadersSupplier = derivedSourceReadersSupplier; + this.derivedSourceReaders = derivedSourceReaders; this.segmentReadState = segmentReadState; this.shouldInject = shouldInject; this.derivedSourceVectorInjector = createDerivedSourceVectorInjector(); } - private DerivedSourceVectorInjector createDerivedSourceVectorInjector() throws IOException { - return new DerivedSourceVectorInjector(derivedSourceReadersSupplier, segmentReadState, derivedVectorFields); + private DerivedSourceVectorInjector createDerivedSourceVectorInjector() { + return new DerivedSourceVectorInjector(derivedSourceReaders, segmentReadState, derivedVectorFields); } @Override @@ -88,7 +88,7 @@ public StoredFieldsReader clone() { 
return new DerivedSourceStoredFieldsReader( delegate.clone(), derivedVectorFields, - derivedSourceReadersSupplier, + derivedSourceReaders.clone(), segmentReadState, shouldInject ); @@ -104,7 +104,7 @@ public void checkIntegrity() throws IOException { @Override public void close() throws IOException { - log.info("Closing derived source stored fields reader for segment: " + segmentReadState.segmentInfo.name); + log.debug("Closing derived source stored fields reader for segment: " + segmentReadState.segmentInfo.name); IOUtils.close(delegate, derivedSourceVectorInjector); } @@ -120,7 +120,7 @@ private StoredFieldsReader cloneForMerge() { return new DerivedSourceStoredFieldsReader( delegate.getMergeInstance(), derivedVectorFields, - derivedSourceReadersSupplier, + derivedSourceReaders.clone(), segmentReadState, false ); diff --git a/src/main/java/org/opensearch/knn/index/codec/derivedsource/DerivedSourceReaders.java b/src/main/java/org/opensearch/knn/index/codec/derivedsource/DerivedSourceReaders.java index 76be2f16c..80b63f02b 100644 --- a/src/main/java/org/opensearch/knn/index/codec/derivedsource/DerivedSourceReaders.java +++ b/src/main/java/org/opensearch/knn/index/codec/derivedsource/DerivedSourceReaders.java @@ -25,7 +25,7 @@ @RequiredArgsConstructor @Getter @Log4j2 -public class DerivedSourceReaders implements Closeable { +public class DerivedSourceReaders implements Cloneable, Closeable { @Nullable private final KnnVectorsReader knnVectorsReader; @Nullable @@ -35,9 +35,17 @@ public class DerivedSourceReaders implements Closeable { @Nullable private final NormsProducer normsProducer; + private final boolean isCloned; + @Override public void close() throws IOException { - log.info("Closing derived source readers"); - IOUtils.close(knnVectorsReader, docValuesProducer, fieldsProducer, normsProducer); + if (isCloned == false) { + IOUtils.close(knnVectorsReader, docValuesProducer, fieldsProducer, normsProducer); + } + } + + @Override + public DerivedSourceReaders 
clone() { + return new DerivedSourceReaders(knnVectorsReader, docValuesProducer, fieldsProducer, normsProducer, true); } } diff --git a/src/main/java/org/opensearch/knn/index/codec/derivedsource/DerivedSourceReadersSupplier.java b/src/main/java/org/opensearch/knn/index/codec/derivedsource/DerivedSourceReadersSupplier.java index 2dafa3af9..527f25fa3 100644 --- a/src/main/java/org/opensearch/knn/index/codec/derivedsource/DerivedSourceReadersSupplier.java +++ b/src/main/java/org/opensearch/knn/index/codec/derivedsource/DerivedSourceReadersSupplier.java @@ -27,7 +27,7 @@ public class DerivedSourceReadersSupplier { private final DerivedSourceReaderSupplier normsProducer; /** - * Get the readers for the segment + * Get the readers for the segment. * * @param state SegmentReadState * @return DerivedSourceReaders @@ -38,7 +38,8 @@ public DerivedSourceReaders getReaders(SegmentReadState state) throws IOExceptio knnVectorsReaderSupplier.apply(state), docValuesProducerSupplier.apply(state), fieldsProducerSupplier.apply(state), - normsProducer.apply(state) + normsProducer.apply(state), + false ); } } diff --git a/src/main/java/org/opensearch/knn/index/codec/derivedsource/DerivedSourceVectorInjector.java b/src/main/java/org/opensearch/knn/index/codec/derivedsource/DerivedSourceVectorInjector.java index 4ac1d1058..e7d78e251 100644 --- a/src/main/java/org/opensearch/knn/index/codec/derivedsource/DerivedSourceVectorInjector.java +++ b/src/main/java/org/opensearch/knn/index/codec/derivedsource/DerivedSourceVectorInjector.java @@ -39,21 +39,19 @@ public class DerivedSourceVectorInjector implements Closeable { private final List perFieldDerivedVectorInjectors; private final Set fieldNames; - private final String segmentName; - /** * Constructor for DerivedSourceVectorInjector. * - * @param derivedSourceReadersSupplier Supplier for the derived source readers. + * @param derivedSourceReaders Derived source readers. 
* @param segmentReadState Segment read state * @param fieldsToInjectVector List of fields to inject vectors into */ public DerivedSourceVectorInjector( - DerivedSourceReadersSupplier derivedSourceReadersSupplier, + DerivedSourceReaders derivedSourceReaders, SegmentReadState segmentReadState, List fieldsToInjectVector - ) throws IOException { - this.derivedSourceReaders = derivedSourceReadersSupplier.getReaders(segmentReadState); + ) { + this.derivedSourceReaders = derivedSourceReaders; this.perFieldDerivedVectorInjectors = new ArrayList<>(); this.fieldNames = new HashSet<>(); for (FieldInfo fieldInfo : fieldsToInjectVector) { @@ -62,7 +60,6 @@ public DerivedSourceVectorInjector( ); this.fieldNames.add(fieldInfo.name); } - this.segmentName = segmentReadState.segmentInfo.name; } /** @@ -134,7 +131,6 @@ public boolean shouldInject(String[] includes, String[] excludes) { @Override public void close() throws IOException { - log.info("Closing derived source injector reader for segment" + segmentName); IOUtils.close(derivedSourceReaders); } } diff --git a/src/test/java/org/opensearch/knn/index/codec/derivedsource/DerivedSourceVectorInjectorTests.java b/src/test/java/org/opensearch/knn/index/codec/derivedsource/DerivedSourceVectorInjectorTests.java index 1fa4b9364..f845e9d82 100644 --- a/src/test/java/org/opensearch/knn/index/codec/derivedsource/DerivedSourceVectorInjectorTests.java +++ b/src/test/java/org/opensearch/knn/index/codec/derivedsource/DerivedSourceVectorInjectorTests.java @@ -59,7 +59,7 @@ public void testInjectVectors() { }); DerivedSourceVectorInjector derivedSourceVectorInjector = new DerivedSourceVectorInjector( - new DerivedSourceReadersSupplier(s -> null, s -> null, s -> null, s -> null), + new DerivedSourceReaders(null, null, null, null, false), null, fields ); @@ -119,7 +119,7 @@ public void testShouldInject() { try ( DerivedSourceVectorInjector vectorInjector = new DerivedSourceVectorInjector( - new DerivedSourceReadersSupplier(s -> null, s -> null, s 
-> null, s -> null), + new DerivedSourceReaders(null, null, null, null, false), null, fields ) diff --git a/src/test/java/org/opensearch/knn/index/codec/derivedsource/PerFieldDerivedVectorInjectorFactoryTests.java b/src/test/java/org/opensearch/knn/index/codec/derivedsource/PerFieldDerivedVectorInjectorFactoryTests.java index f117db8c8..8b7d29221 100644 --- a/src/test/java/org/opensearch/knn/index/codec/derivedsource/PerFieldDerivedVectorInjectorFactoryTests.java +++ b/src/test/java/org/opensearch/knn/index/codec/derivedsource/PerFieldDerivedVectorInjectorFactoryTests.java @@ -13,7 +13,7 @@ public void testCreate() { // Non-nested case PerFieldDerivedVectorInjector perFieldDerivedVectorInjector = PerFieldDerivedVectorInjectorFactory.create( KNNCodecTestUtil.FieldInfoBuilder.builder("test").build(), - new DerivedSourceReaders(null, null, null, null), + new DerivedSourceReaders(null, null, null, null, false), null ); assertTrue(perFieldDerivedVectorInjector instanceof RootPerFieldDerivedVectorInjector); @@ -21,7 +21,7 @@ public void testCreate() { // Nested case perFieldDerivedVectorInjector = PerFieldDerivedVectorInjectorFactory.create( KNNCodecTestUtil.FieldInfoBuilder.builder("parent.test").build(), - new DerivedSourceReaders(null, null, null, null), + new DerivedSourceReaders(null, null, null, null, false), null ); assertTrue(perFieldDerivedVectorInjector instanceof NestedPerFieldDerivedVectorInjector); diff --git a/src/test/java/org/opensearch/knn/index/codec/derivedsource/RootPerFieldDerivedVectorInjectorTests.java b/src/test/java/org/opensearch/knn/index/codec/derivedsource/RootPerFieldDerivedVectorInjectorTests.java index 3e015b09c..a009091a2 100644 --- a/src/test/java/org/opensearch/knn/index/codec/derivedsource/RootPerFieldDerivedVectorInjectorTests.java +++ b/src/test/java/org/opensearch/knn/index/codec/derivedsource/RootPerFieldDerivedVectorInjectorTests.java @@ -7,7 +7,6 @@ import lombok.SneakyThrows; import org.apache.lucene.index.FieldInfo; -import 
org.apache.lucene.search.DocIdSetIterator; import org.mockito.MockedStatic; import org.mockito.Mockito; import org.opensearch.knn.KNNTestCase; @@ -15,11 +14,12 @@ import org.opensearch.knn.index.vectorvalues.KNNVectorValues; import org.opensearch.knn.index.vectorvalues.KNNVectorValuesFactory; import org.opensearch.knn.index.vectorvalues.KNNVectorValuesIterator; -import org.opensearch.knn.index.vectorvalues.VectorValueExtractorStrategy; import java.util.HashMap; import java.util.Map; +import static org.mockito.ArgumentMatchers.anyInt; +import static org.mockito.Mockito.when; import static org.opensearch.knn.KNNRestTestCase.FIELD_NAME; public class RootPerFieldDerivedVectorInjectorTests extends KNNTestCase { @@ -29,39 +29,17 @@ public class RootPerFieldDerivedVectorInjectorTests extends KNNTestCase { public void testInject() { FieldInfo fieldInfo = KNNCodecTestUtil.FieldInfoBuilder.builder(FIELD_NAME).build(); try (MockedStatic mockedKnnVectorValues = Mockito.mockStatic(KNNVectorValuesFactory.class)) { - mockedKnnVectorValues.when(() -> KNNVectorValuesFactory.getVectorValues(fieldInfo, null, null)) - .thenReturn(new KNNVectorValues(new KNNVectorValuesIterator() { - @Override - public int docId() { - return 0; - } - - @Override - public int advance(int docId) { - return 0; - } - - @Override - public int nextDoc() { - return 0; - } - @Override - public DocIdSetIterator getDocIdSetIterator() { - return null; - } - - @Override - public long liveDocs() { - return 0; - } - - @Override - public VectorValueExtractorStrategy getVectorExtractorStrategy() { - return null; - } - }) { + final KNNVectorValuesIterator vectorValuesIterator = Mockito.mock(KNNVectorValuesIterator.class); + when(vectorValuesIterator.docId()).thenReturn(0); + when(vectorValuesIterator.advance(anyInt())).thenReturn(0); + when(vectorValuesIterator.nextDoc()).thenReturn(0); + when(vectorValuesIterator.getDocIdSetIterator()).thenReturn(null); + when(vectorValuesIterator.liveDocs()).thenReturn(0L); + 
when(vectorValuesIterator.getVectorExtractorStrategy()).thenReturn(null); + mockedKnnVectorValues.when(() -> KNNVectorValuesFactory.getVectorValues(fieldInfo, null, null)) + .thenReturn(new KNNVectorValues(vectorValuesIterator) { @Override public float[] getVector() { return TEST_VECTOR; @@ -74,7 +52,7 @@ public float[] conditionalCloneVector() { }); PerFieldDerivedVectorInjector perFieldDerivedVectorInjector = new RootPerFieldDerivedVectorInjector( fieldInfo, - new DerivedSourceReaders(null, null, null, null) + new DerivedSourceReaders(null, null, null, null, false) ); Map source = new HashMap<>(); From 1471b4e3eae55d93233e085c337b4a8859344bf5 Mon Sep 17 00:00:00 2001 From: John Mazanec Date: Wed, 5 Feb 2025 22:04:22 -0800 Subject: [PATCH 3/3] Add uT Signed-off-by: John Mazanec --- .../DerivedSourceStoredFieldsWriterTests.java | 42 +++++++++++++++++++ 1 file changed, 42 insertions(+) create mode 100644 src/test/java/org/opensearch/knn/index/codec/KNN9120Codec/DerivedSourceStoredFieldsWriterTests.java diff --git a/src/test/java/org/opensearch/knn/index/codec/KNN9120Codec/DerivedSourceStoredFieldsWriterTests.java b/src/test/java/org/opensearch/knn/index/codec/KNN9120Codec/DerivedSourceStoredFieldsWriterTests.java new file mode 100644 index 000000000..2953539ad --- /dev/null +++ b/src/test/java/org/opensearch/knn/index/codec/KNN9120Codec/DerivedSourceStoredFieldsWriterTests.java @@ -0,0 +1,42 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.codec.KNN9120Codec; + +import lombok.SneakyThrows; +import org.apache.lucene.codecs.StoredFieldsWriter; +import org.apache.lucene.index.FieldInfo; +import org.apache.lucene.util.BytesRef; +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.core.xcontent.MediaTypeRegistry; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.knn.KNNTestCase; +import org.opensearch.knn.index.codec.KNNCodecTestUtil; + 
+import java.util.List;
+import java.util.Map;
+
+import static org.mockito.Mockito.mock;
+
+/**
+ * Unit tests for {@link DerivedSourceStoredFieldsWriter}.
+ */
+public class DerivedSourceStoredFieldsWriterTests extends KNNTestCase {
+
+    /**
+     * Calls {@code writeField} for the {@code _source} field with a {@link BytesRef} whose JSON payload starts at
+     * offset 1 of a backing array that is two bytes longer than the payload, i.e. the payload is surrounded by one
+     * padding byte on each side.
+     *
+     * <p>The test makes no assertion on what reaches the mocked delegate; it passes when {@code writeField}
+     * returns without throwing. It is meant as a regression guard for a writer that reads the whole backing array
+     * instead of only the {@code [offset, offset + length)} window.
+     */
+    @SneakyThrows
+    public void testWriteField() {
+        StoredFieldsWriter delegate = mock(StoredFieldsWriter.class);
+        FieldInfo fieldInfo = KNNCodecTestUtil.FieldInfoBuilder.builder("_source").build();
+        List<String> fields = List.of("test");
+
+        DerivedSourceStoredFieldsWriter derivedSourceStoredFieldsWriter = new DerivedSourceStoredFieldsWriter(delegate, fields);
+
+        Map<String, Object> source = Map.of("test", new float[] { 1.0f, 2.0f, 3.0f }, "text_field", "text_value");
+        BytesStreamOutput bStream = new BytesStreamOutput();
+        // Closing the builder (here via try-with-resources) finishes the JSON document in bStream.
+        try (XContentBuilder builder = MediaTypeRegistry.contentBuilder(MediaTypeRegistry.JSON, bStream)) {
+            builder.map(source);
+        }
+
+        // Honour the BytesRef's own offset/length: its backing array is not guaranteed to be exactly the
+        // serialized document, so copying the raw array could drag unused buffer bytes into the payload.
+        BytesRef serialized = bStream.bytes().toBytesRef();
+        byte[] shiftedBytes = new byte[serialized.length + 2];
+        System.arraycopy(serialized.bytes, serialized.offset, shiftedBytes, 1, serialized.length);
+        derivedSourceStoredFieldsWriter.writeField(fieldInfo, new BytesRef(shiftedBytes, 1, serialized.length));
+    }
+}