Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix bytes offset bug and duplicate readers and add uTs for derived source #2494

Merged
merged 3 commits into from
Feb 6, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import org.opensearch.index.mapper.MappedFieldType;
import org.opensearch.index.mapper.MapperService;
import org.opensearch.knn.index.KNNSettings;
import org.opensearch.knn.index.codec.derivedsource.DerivedSourceReaders;
import org.opensearch.knn.index.codec.derivedsource.DerivedSourceReadersSupplier;
import org.opensearch.knn.index.mapper.KNNVectorFieldType;

Expand Down Expand Up @@ -55,11 +56,14 @@ public StoredFieldsReader fieldsReader(Directory directory, SegmentInfo segmentI
if (derivedVectorFields == null || derivedVectorFields.isEmpty()) {
return delegate.fieldsReader(directory, segmentInfo, fieldInfos, ioContext);
}

SegmentReadState segmentReadState = new SegmentReadState(directory, segmentInfo, fieldInfos, ioContext);
DerivedSourceReaders derivedSourceReaders = derivedSourceReadersSupplier.getReaders(segmentReadState);
return new DerivedSourceStoredFieldsReader(
delegate.fieldsReader(directory, segmentInfo, fieldInfos, ioContext),
derivedVectorFields,
derivedSourceReadersSupplier,
new SegmentReadState(directory, segmentInfo, fieldInfos, ioContext)
derivedSourceReaders,
segmentReadState
);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,23 +5,25 @@

package org.opensearch.knn.index.codec.KNN9120Codec;

import lombok.extern.log4j.Log4j2;
import org.apache.lucene.codecs.StoredFieldsReader;
import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.SegmentReadState;
import org.apache.lucene.index.StoredFieldVisitor;
import org.apache.lucene.util.IOUtils;
import org.opensearch.index.fieldvisitor.FieldsVisitor;
import org.opensearch.knn.index.codec.derivedsource.DerivedSourceReadersSupplier;
import org.opensearch.knn.index.codec.derivedsource.DerivedSourceReaders;
import org.opensearch.knn.index.codec.derivedsource.DerivedSourceStoredFieldVisitor;
import org.opensearch.knn.index.codec.derivedsource.DerivedSourceVectorInjector;

import java.io.IOException;
import java.util.List;

@Log4j2
public class DerivedSourceStoredFieldsReader extends StoredFieldsReader {
private final StoredFieldsReader delegate;
private final List<FieldInfo> derivedVectorFields;
private final DerivedSourceReadersSupplier derivedSourceReadersSupplier;
private final DerivedSourceReaders derivedSourceReaders;
private final SegmentReadState segmentReadState;
private final boolean shouldInject;

Expand All @@ -31,36 +33,36 @@ public class DerivedSourceStoredFieldsReader extends StoredFieldsReader {
*
* @param delegate delegate StoredFieldsReader
* @param derivedVectorFields List of fields that are derived source fields
* @param derivedSourceReadersSupplier Supplier for the derived source readers
* @param derivedSourceReaders Derived source readers
* @param segmentReadState SegmentReadState for the segment
* @throws IOException in case of I/O error
*/
public DerivedSourceStoredFieldsReader(
StoredFieldsReader delegate,
List<FieldInfo> derivedVectorFields,
DerivedSourceReadersSupplier derivedSourceReadersSupplier,
DerivedSourceReaders derivedSourceReaders,
SegmentReadState segmentReadState
) throws IOException {
this(delegate, derivedVectorFields, derivedSourceReadersSupplier, segmentReadState, true);
this(delegate, derivedVectorFields, derivedSourceReaders, segmentReadState, true);
}

private DerivedSourceStoredFieldsReader(
StoredFieldsReader delegate,
List<FieldInfo> derivedVectorFields,
DerivedSourceReadersSupplier derivedSourceReadersSupplier,
DerivedSourceReaders derivedSourceReaders,
SegmentReadState segmentReadState,
boolean shouldInject
) throws IOException {
this.delegate = delegate;
this.derivedVectorFields = derivedVectorFields;
this.derivedSourceReadersSupplier = derivedSourceReadersSupplier;
this.derivedSourceReaders = derivedSourceReaders;
this.segmentReadState = segmentReadState;
this.shouldInject = shouldInject;
this.derivedSourceVectorInjector = createDerivedSourceVectorInjector();
}

private DerivedSourceVectorInjector createDerivedSourceVectorInjector() throws IOException {
return new DerivedSourceVectorInjector(derivedSourceReadersSupplier, segmentReadState, derivedVectorFields);
private DerivedSourceVectorInjector createDerivedSourceVectorInjector() {
return new DerivedSourceVectorInjector(derivedSourceReaders, segmentReadState, derivedVectorFields);
}

@Override
Expand All @@ -86,7 +88,7 @@ public StoredFieldsReader clone() {
return new DerivedSourceStoredFieldsReader(
delegate.clone(),
derivedVectorFields,
derivedSourceReadersSupplier,
derivedSourceReaders.clone(),
segmentReadState,
shouldInject
);
Expand All @@ -102,6 +104,7 @@ public void checkIntegrity() throws IOException {

@Override
public void close() throws IOException {
log.debug("Closing derived source stored fields reader for segment: " + segmentReadState.segmentInfo.name);
IOUtils.close(delegate, derivedSourceVectorInjector);
}

Expand All @@ -117,7 +120,7 @@ private StoredFieldsReader cloneForMerge() {
return new DerivedSourceStoredFieldsReader(
delegate.getMergeInstance(),
derivedVectorFields,
derivedSourceReadersSupplier,
derivedSourceReaders.clone(),
segmentReadState,
false
);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ public void writeField(FieldInfo fieldInfo, BytesRef bytesRef) throws IOExceptio
// Reference:
// https://github.com/opensearch-project/OpenSearch/blob/2.18.0/server/src/main/java/org/opensearch/index/mapper/SourceFieldMapper.java#L322
Tuple<? extends MediaType, Map<String, Object>> mapTuple = XContentHelper.convertToMap(
BytesReference.fromByteBuffer(ByteBuffer.wrap(bytesRef.bytes)),
BytesReference.fromByteBuffer(ByteBuffer.wrap(bytesRef.bytes, bytesRef.offset, bytesRef.length)),
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we add a UT for this case?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

makes sense - added and validated fix before and after

true,
MediaTypeRegistry.JSON
);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ public class KNN9120Codec extends FilterCodec {
private static final KNNCodecVersion VERSION = KNNCodecVersion.V_9_12_0;
private final KNNFormatFacade knnFormatFacade;
private final PerFieldKnnVectorsFormat perFieldKnnVectorsFormat;
private final StoredFieldsFormat storedFieldsFormat;

private final MapperService mapperService;

Expand All @@ -48,6 +49,7 @@ protected KNN9120Codec(Codec delegate, PerFieldKnnVectorsFormat knnVectorsFormat
knnFormatFacade = VERSION.getKnnFormatFacadeSupplier().apply(delegate);
perFieldKnnVectorsFormat = knnVectorsFormat;
this.mapperService = mapperService;
this.storedFieldsFormat = getStoredFieldsFormat();
}

@Override
Expand All @@ -67,6 +69,10 @@ public KnnVectorsFormat knnVectorsFormat() {

@Override
public StoredFieldsFormat storedFieldsFormat() {
return storedFieldsFormat;
}

private StoredFieldsFormat getStoredFieldsFormat() {
DerivedSourceReadersSupplier derivedSourceReadersSupplier = new DerivedSourceReadersSupplier((segmentReadState) -> {
if (segmentReadState.fieldInfos.hasVectorValues()) {
return knnVectorsFormat().fieldsReader(segmentReadState);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import lombok.Getter;
import lombok.RequiredArgsConstructor;
import lombok.extern.log4j.Log4j2;
import org.apache.lucene.codecs.DocValuesProducer;
import org.apache.lucene.codecs.FieldsProducer;
import org.apache.lucene.codecs.KnnVectorsReader;
Expand All @@ -23,7 +24,8 @@
*/
@RequiredArgsConstructor
@Getter
public class DerivedSourceReaders implements Closeable {
@Log4j2
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[nit-pick] you are not using Log4j2 in the code

public class DerivedSourceReaders implements Cloneable, Closeable {
@Nullable
private final KnnVectorsReader knnVectorsReader;
@Nullable
Expand All @@ -33,8 +35,17 @@ public class DerivedSourceReaders implements Closeable {
@Nullable
private final NormsProducer normsProducer;

private final boolean isCloned;

@Override
public void close() throws IOException {
IOUtils.close(knnVectorsReader, docValuesProducer, fieldsProducer, normsProducer);
if (isCloned == false) {
IOUtils.close(knnVectorsReader, docValuesProducer, fieldsProducer, normsProducer);
}
}

@Override
public DerivedSourceReaders clone() {
return new DerivedSourceReaders(knnVectorsReader, docValuesProducer, fieldsProducer, normsProducer, true);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ public class DerivedSourceReadersSupplier {
private final DerivedSourceReaderSupplier<NormsProducer> normsProducer;

/**
* Get the readers for the segment
* Get the readers for the segment.
*
* @param state SegmentReadState
* @return DerivedSourceReaders
Expand All @@ -38,7 +38,8 @@ public DerivedSourceReaders getReaders(SegmentReadState state) throws IOExceptio
knnVectorsReaderSupplier.apply(state),
docValuesProducerSupplier.apply(state),
fieldsProducerSupplier.apply(state),
normsProducer.apply(state)
normsProducer.apply(state),
false
);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -42,16 +42,16 @@ public class DerivedSourceVectorInjector implements Closeable {
/**
* Constructor for DerivedSourceVectorInjector.
*
* @param derivedSourceReadersSupplier Supplier for the derived source readers.
* @param derivedSourceReaders Derived source readers.
* @param segmentReadState Segment read state
* @param fieldsToInjectVector List of fields to inject vectors into
*/
public DerivedSourceVectorInjector(
DerivedSourceReadersSupplier derivedSourceReadersSupplier,
DerivedSourceReaders derivedSourceReaders,
SegmentReadState segmentReadState,
List<FieldInfo> fieldsToInjectVector
) throws IOException {
this.derivedSourceReaders = derivedSourceReadersSupplier.getReaders(segmentReadState);
) {
this.derivedSourceReaders = derivedSourceReaders;
this.perFieldDerivedVectorInjectors = new ArrayList<>();
this.fieldNames = new HashSet<>();
for (FieldInfo fieldInfo : fieldsToInjectVector) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,12 @@ public class ParentChildHelper {
* this would return "parent.to".
*
* @param field nested field path
* @return parent field path without the child
* @return parent field path without the child. Null if no parent exists
*/
public static String getParentField(String field) {
if (field == null) {
return null;
}
int lastDot = field.lastIndexOf('.');
if (lastDot == -1) {
return null;
Expand All @@ -30,10 +33,16 @@ public static String getParentField(String field) {
* return "child".
*
* @param field nested field path
* @return child field path without the parent path
* @return child field path without the parent path. Null if no child exists
*/
public static String getChildField(String field) {
if (field == null) {
return null;
}
int lastDot = field.lastIndexOf('.');
if (lastDot == -1) {
return null;
}
return field.substring(lastDot + 1);
}

Expand All @@ -46,7 +55,11 @@ public static String getChildField(String field) {
* @return sibling field path
*/
public static String constructSiblingField(String field, String sibling) {
return getParentField(field) + "." + sibling;
String parent = getParentField(field);
if (parent == null) {
return sibling;
}
return parent + "." + sibling;
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
/**
* Provides different strategies to extract the vectors from different {@link KNNVectorValuesIterator}
*/
interface VectorValueExtractorStrategy {
public interface VectorValueExtractorStrategy {

/**
* Extract a float vector from KNNVectorValuesIterator.
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.knn.index.codec.KNN9120Codec;

import lombok.SneakyThrows;
import org.apache.lucene.codecs.StoredFieldsWriter;
import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.util.BytesRef;
import org.opensearch.common.io.stream.BytesStreamOutput;
import org.opensearch.core.xcontent.MediaTypeRegistry;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.knn.KNNTestCase;
import org.opensearch.knn.index.codec.KNNCodecTestUtil;

import java.util.List;
import java.util.Map;

import static org.mockito.Mockito.mock;

public class DerivedSourceStoredFieldsWriterTests extends KNNTestCase {

@SneakyThrows
public void testWriteField() {
StoredFieldsWriter delegate = mock(StoredFieldsWriter.class);
FieldInfo fieldInfo = KNNCodecTestUtil.FieldInfoBuilder.builder("_source").build();
List<String> fields = List.of("test");

DerivedSourceStoredFieldsWriter derivedSourceStoredFieldsWriter = new DerivedSourceStoredFieldsWriter(delegate, fields);

Map<String, Object> source = Map.of("test", new float[] { 1.0f, 2.0f, 3.0f }, "text_field", "text_value");
BytesStreamOutput bStream = new BytesStreamOutput();
XContentBuilder builder = MediaTypeRegistry.contentBuilder(MediaTypeRegistry.JSON, bStream).map(source);
builder.close();
byte[] originalBytes = bStream.bytes().toBytesRef().bytes;
byte[] shiftedBytes = new byte[originalBytes.length + 2];
System.arraycopy(originalBytes, 0, shiftedBytes, 1, originalBytes.length);
derivedSourceStoredFieldsWriter.writeField(fieldInfo, new BytesRef(shiftedBytes, 1, originalBytes.length));
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.knn.index.codec.derivedsource;

import org.apache.lucene.index.StoredFieldVisitor;
import org.opensearch.knn.KNNTestCase;
import org.opensearch.knn.index.codec.KNNCodecTestUtil;

import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyInt;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;

public class DerivedSourceStoredFieldVisitorTests extends KNNTestCase {

public void testBinaryField() throws Exception {
StoredFieldVisitor delegate = mock(StoredFieldVisitor.class);
doAnswer(invocationOnMock -> null).when(delegate).binaryField(any(), any());
DerivedSourceVectorInjector derivedSourceVectorInjector = mock(DerivedSourceVectorInjector.class);
when(derivedSourceVectorInjector.injectVectors(anyInt(), any())).thenReturn(new byte[0]);
DerivedSourceStoredFieldVisitor derivedSourceStoredFieldVisitor = new DerivedSourceStoredFieldVisitor(
delegate,
0,
derivedSourceVectorInjector
);

// When field is not _source, then do not call the injector
derivedSourceStoredFieldVisitor.binaryField(KNNCodecTestUtil.FieldInfoBuilder.builder("test").build(), null);
verify(derivedSourceVectorInjector, times(0)).injectVectors(anyInt(), any());
verify(delegate, times(1)).binaryField(any(), any());

// When field is not _source, then do call the injector
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: When the field is _source, then do call the injector

derivedSourceStoredFieldVisitor.binaryField(KNNCodecTestUtil.FieldInfoBuilder.builder("_source").build(), null);
verify(derivedSourceVectorInjector, times(1)).injectVectors(anyInt(), any());
verify(delegate, times(2)).binaryField(any(), any());
}
}
Loading
Loading