Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

HSEARCH-5021 Improve support for custom value bridges for vector fields #3863

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -236,17 +236,17 @@ public ScaledNumberIndexFieldTypeOptionsStep<?, BigInteger> asBigInteger() {
}

@Override
public <F> VectorFieldTypeOptionsStep<?, F> asVector(int dimension, Class<F> valueType) {
public <F> VectorFieldTypeOptionsStep<?, F> asVector(Class<F> valueType) {
throw new UnsupportedOperationException( "The Elasticsearch backend does not support vector field yet." );
}

@Override
public VectorFieldTypeOptionsStep<?, byte[]> asByteVector(int dimension) {
public VectorFieldTypeOptionsStep<?, byte[]> asByteVector() {
throw new UnsupportedOperationException( "The Elasticsearch backend does not support vector field yet." );
}

@Override
public VectorFieldTypeOptionsStep<?, float[]> asFloatVector(int dimension) {
public VectorFieldTypeOptionsStep<?, float[]> asFloatVector() {
throw new UnsupportedOperationException( "The Elasticsearch backend does not support vector field yet." );
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -727,4 +727,8 @@ SearchException vectorKnnMatchVectorTypeDiffersFromField(String absoluteFieldPat
+ " Matching against an array with length of '%3$s' is unsupported."
+ " Use the array of the same size as the vector field.")
SearchException vectorKnnMatchVectorDimensionDiffersFromField(String absoluteFieldPath, int expected, int actual);

@Message(id = ID_OFFSET + 179, value = "Invalid index field type: missing vector dimension."
+ " Define the vector dimension explicitly. %1$s")
SearchException nullVectorDimension(String hint, @Param EventContext eventContext);
}
Original file line number Diff line number Diff line change
Expand Up @@ -43,19 +43,15 @@ abstract class AbstractLuceneVectorFieldTypeOptionsStep<S extends AbstractLucene
private static final int MAX_MAX_CONNECTIONS = 512;

protected VectorSimilarity vectorSimilarity = VectorSimilarity.DEFAULT;
protected final int dimension;
protected Integer dimension;
protected int beamWidth = MAX_MAX_CONNECTIONS;
protected int maxConnections = 16;
private Projectable projectable = Projectable.DEFAULT;
private Searchable searchable = Searchable.DEFAULT;
private F indexNullAsValue = null;

AbstractLuceneVectorFieldTypeOptionsStep(LuceneIndexFieldTypeBuildContext buildContext, Class<F> valueType, int dimension) {
AbstractLuceneVectorFieldTypeOptionsStep(LuceneIndexFieldTypeBuildContext buildContext, Class<F> valueType) {
super( buildContext, valueType );
if ( dimension < 1 || dimension > DEFAULT_MAX_DIMENSIONS ) {
throw log.vectorPropertyUnsupportedValue( "dimension", dimension, DEFAULT_MAX_DIMENSIONS );
}
this.dimension = dimension;
}

@Override
Expand Down Expand Up @@ -94,6 +90,15 @@ public S maxConnections(int maxConnections) {
return thisAsS();
}

@Override
public S dimension(int dimension) {
if ( dimension < 1 || dimension > DEFAULT_MAX_DIMENSIONS ) {
throw log.vectorPropertyUnsupportedValue( "dimension", dimension, DEFAULT_MAX_DIMENSIONS );
}
this.dimension = dimension;
return thisAsS();
}

@Override
public S indexNullAs(F indexNullAsValue) {
this.indexNullAsValue = indexNullAsValue;
Expand All @@ -102,6 +107,9 @@ public S indexNullAs(F indexNullAsValue) {

@Override
public LuceneIndexValueFieldType<F> toIndexFieldType() {
if ( dimension == null ) {
throw log.nullVectorDimension( buildContext.hints().missingVectorDimension(), buildContext.getEventContext() );
}
VectorSimilarityFunction resolvedVectorSimilarity = resolveDefault( vectorSimilarity );
boolean resolvedProjectable = resolveDefault( projectable );
boolean resolvedSearchable = resolveDefault( searchable );
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@
class LuceneByteVectorFieldTypeOptionsStep
extends AbstractLuceneVectorFieldTypeOptionsStep<LuceneByteVectorFieldTypeOptionsStep, byte[]> {

LuceneByteVectorFieldTypeOptionsStep(LuceneIndexFieldTypeBuildContext buildContext, int dimension) {
super( buildContext, byte[].class, dimension );
LuceneByteVectorFieldTypeOptionsStep(LuceneIndexFieldTypeBuildContext buildContext) {
super( buildContext, byte[].class );
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@
class LuceneFloatVectorFieldTypeOptionsStep
extends AbstractLuceneVectorFieldTypeOptionsStep<LuceneFloatVectorFieldTypeOptionsStep, float[]> {

LuceneFloatVectorFieldTypeOptionsStep(LuceneIndexFieldTypeBuildContext buildContext, int dimension) {
super( buildContext, float[].class, dimension );
LuceneFloatVectorFieldTypeOptionsStep(LuceneIndexFieldTypeBuildContext buildContext) {
super( buildContext, float[].class );
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -128,12 +128,12 @@ else if ( BigInteger.class.equals( valueType ) ) {
}

@SuppressWarnings("unchecked")
public <F> VectorFieldTypeOptionsStep<?, F> asVector(int dimension, Class<F> valueType) {
public <F> VectorFieldTypeOptionsStep<?, F> asVector(Class<F> valueType) {
if ( byte[].class.equals( valueType ) ) {
return (VectorFieldTypeOptionsStep<?, F>) asByteVector( dimension );
return (VectorFieldTypeOptionsStep<?, F>) asByteVector();
}
else if ( float[].class.equals( valueType ) ) {
return (VectorFieldTypeOptionsStep<?, F>) asFloatVector( dimension );
return (VectorFieldTypeOptionsStep<?, F>) asFloatVector();
}
else {
throw log.cannotGuessVectorFieldType( valueType, getEventContext() );
Expand Down Expand Up @@ -246,13 +246,13 @@ public ScaledNumberIndexFieldTypeOptionsStep<?, BigInteger> asBigInteger() {
}

@Override
public VectorFieldTypeOptionsStep<?, byte[]> asByteVector(int dimension) {
return new LuceneByteVectorFieldTypeOptionsStep( this, dimension );
public VectorFieldTypeOptionsStep<?, byte[]> asByteVector() {
return new LuceneByteVectorFieldTypeOptionsStep( this );
}

@Override
public VectorFieldTypeOptionsStep<?, float[]> asFloatVector(int dimension) {
return new LuceneFloatVectorFieldTypeOptionsStep( this, dimension );
public VectorFieldTypeOptionsStep<?, float[]> asFloatVector() {
return new LuceneFloatVectorFieldTypeOptionsStep( this );
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,4 +23,7 @@ public interface BackendMappingHints {
@Message("")
String missingDecimalScale();

@Message("")
String missingVectorDimension();

}
Original file line number Diff line number Diff line change
Expand Up @@ -44,15 +44,15 @@ public interface IndexFieldTypeFactory {
/**
* Define a vector field type whose values are represented as a given type in Hibernate Search.
* <p>
* When possible, prefer the other methods such as {@link #asByteVector(int)} or {@link #asFloatVector(int)}
* When possible, prefer the other methods such as {@link #asByteVector()} or {@link #asFloatVector()}
* to avoid unnecessary type checks.
*
* @param <F> The type of values for this field type.
* @param dimension The number of dimensions (array length) of vectors to be indexed.
* @param valueType The type of values for this field type. Should be an array type like {@code byte[]} or {@code float[]}.
* @return A DSL step where the index vector field type can be defined in more details.
*/
<F> VectorFieldTypeOptionsStep<?, F> asVector(int dimension, Class<F> valueType);
@Incubating
<F> VectorFieldTypeOptionsStep<?, F> asVector(Class<F> valueType);

/**
* Define a field type whose values are represented as a {@link String} in Hibernate Search.
Expand Down Expand Up @@ -180,26 +180,23 @@ public interface IndexFieldTypeFactory {
*/
ScaledNumberIndexFieldTypeOptionsStep<?, BigInteger> asBigInteger();


/**
* Define a field type intended for use in vector search
* and whose values are represented as a {@code byte[]} in Hibernate Search.
*
* @param dimension The number of dimensions (array length) of vectors to be indexed.
* @return A DSL step where the index field type can be defined in more details.
*/
@Incubating
VectorFieldTypeOptionsStep<?, byte[]> asByteVector(int dimension);
VectorFieldTypeOptionsStep<?, byte[]> asByteVector();

/**
* Define a field type intended for use in vector search
* and whose values are represented as a {@code float[]} in Hibernate Search.
*
* @param dimension The number of dimensions (array length) of vectors to be indexed.
* @return A DSL step where the index field type can be defined in more details.
*/
@Incubating
VectorFieldTypeOptionsStep<?, float[]> asFloatVector(int dimension);
VectorFieldTypeOptionsStep<?, float[]> asFloatVector();

/**
* Extend the current factory with the given extension,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,4 +37,10 @@ public interface VectorFieldTypeOptionsStep<S extends VectorFieldTypeOptionsStep
*/
S maxConnections(int maxConnections);

/**
* @param dimension The number of dimensions (array length) of vectors to be indexed.
* @return {@code this}, for method chaining.
*/
S dimension(int dimension);

}
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ private static class IndexBinding {
final IndexFieldReference<byte[]> vectorField;

IndexBinding(IndexSchemaElement root) {
vectorField = root.field( "vector", c -> c.asByteVector( 2 ).maxConnections( 10 ) ).toReference();
vectorField = root.field( "vector", c -> c.asByteVector().dimension( 2 ).maxConnections( 10 ) ).toReference();
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ void test(int dimension, int beamWidth, int maxConnections, String property, int
.withIndex( SimpleMappedIndex
.of( root -> root
.field( "vector",
f -> f.asByteVector( dimension ).beamWidth( beamWidth )
f -> f.asByteVector().dimension( dimension ).beamWidth( beamWidth )
.maxConnections( maxConnections ) )
.toReference() ) )
.setup()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -718,7 +718,7 @@ private static class PredicateIndexBinding {
PredicateIndexBinding(IndexSchemaElement root) {
parking = root.field( "parking", f -> f.asBoolean().projectable( Projectable.YES ) ).toReference();
rating = root.field( "rating", f -> f.asInteger().projectable( Projectable.YES ) ).toReference();
location = root.field( "location", f -> f.asFloatVector( 2 ).projectable( Projectable.YES )
location = root.field( "location", f -> f.asFloatVector().dimension( 2 ).projectable( Projectable.YES )
.maxConnections( 16 ).beamWidth( 100 ).vectorSimilarity( VectorSimilarity.L2 ) ).toReference();
}
}
Expand All @@ -727,7 +727,7 @@ private static class MultiValuedIndexBinding {
final IndexFieldReference<byte[]> vector;

private MultiValuedIndexBinding(IndexSchemaElement root) {
vector = root.field( "vector", f -> f.asByteVector( 2 ) ).multiValued().toReference();
vector = root.field( "vector", f -> f.asByteVector().dimension( 2 ) ).multiValued().toReference();
}
}

Expand All @@ -742,10 +742,10 @@ private static class NestedIndexBinding {
nested = nestedField.toReference();

byteVector = nestedField.field(
"byteVector", f -> f.asByteVector( 2 ).projectable( Projectable.YES ) )
"byteVector", f -> f.asByteVector().dimension( 2 ).projectable( Projectable.YES ) )
.toReference();
floatVector = nestedField
.field( "floatVector", f -> f.asFloatVector( 2 ).projectable( Projectable.YES ) )
.field( "floatVector", f -> f.asFloatVector().dimension( 2 ).projectable( Projectable.YES ) )
.toReference();
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ public VectorFieldTypeDescriptor<byte[]> withDimension(int dimension) {

@Override
public VectorFieldTypeOptionsStep<?, byte[]> configure(IndexFieldTypeFactory fieldContext) {
return fieldContext.asByteVector( size );
return fieldContext.asByteVector().dimension( size );
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ public VectorFieldTypeDescriptor<float[]> withDimension(int dimension) {

@Override
public VectorFieldTypeOptionsStep<?, float[]> configure(IndexFieldTypeFactory fieldContext) {
return fieldContext.asFloatVector( size );
return fieldContext.asFloatVector().dimension( size );
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -345,15 +345,9 @@ class IndexedEntity {
.satisfies( FailureReportUtils.hasFailureReport()
.typeContext( IndexedEntity.class.getName() )
.pathContext( ".notVector" )
.failure(
"Unable to apply property mapping: this property mapping must target an index field of vector type",
"but the resolved field type is non-vector",
"This generally means you need to use a different field annotation"
+ " or to convert property values using a custom ValueBridge or ValueBinder",
"If you are already using a custom ValueBridge or ValueBinder, check its field type",
"encountered type DSL step '",
"expected interface '" + VectorFieldTypeOptionsStep.class.getName() + "'"
) );
// NOTE: this is an exception from the IndexFieldTypeFactory implementation, hence it is backend-specific and in this case
// it is from the stub-backend.
.failure( "No built-in vector index field type for class: 'java.lang.Integer'." ) );
}

@Test
Expand All @@ -371,14 +365,9 @@ class IndexedEntity {
.satisfies( FailureReportUtils.hasFailureReport()
.typeContext( IndexedEntity.class.getName() )
.pathContext( ".bytes" )
.failure(
"Unable to apply property mapping: this property mapping must target an index field of vector type, but the resolved field type is non-vector",
"This generally means you need to use a different field annotation"
+ " or to convert property values using a custom ValueBridge or ValueBinder",
"If you are already using a custom ValueBridge or ValueBinder, check its field type",
"encountered type DSL step '",
"expected interface '" + VectorFieldTypeOptionsStep.class.getName() + "'"
) );
// NOTE: this is an exception from the IndexFieldTypeFactory implementation, hence it is backend-specific and in this case
// it is from the stub-backend.
.failure( "No built-in vector index field type for class: 'java.lang.Integer'." ) );
}

@Test
Expand Down Expand Up @@ -431,26 +420,6 @@ class IndexedEntity {
+ " or set the type parameter F to a definite, raw type." ) );
}

@Test
void customBridge_vectorDimensionUnknown() {
@Indexed(index = INDEX_NAME)
class IndexedEntity {
@DocumentId
Integer id;
@VectorField(valueBinder = @ValueBinderRef(type = ValidImplicitTypeBridge.ValidImplicitTypeBinder.class))
Collection<Float> floats;
}

assertThatThrownBy( () -> setupHelper.start().expectCustomBeans().setup( IndexedEntity.class ) )
.isInstanceOf( SearchException.class )
.satisfies( FailureReportUtils.hasFailureReport()
.typeContext( IndexedEntity.class.getName() )
.pathContext( ".floats" )
.failure( "Vector dimension is a required property. "
+ "Either specify it as an annotation property (@VectorField(dimension = somePositiveInteger)),"
+ " or define a value binder (@VectorField(valueBinder = @ValueBinderRef(..))) that explicitly declares a vector field specifying the dimension." ) );
}

@Test
void valueExtractorsEnabled() {
@Indexed(index = INDEX_NAME)
Expand Down Expand Up @@ -521,6 +490,24 @@ class IndexedEntity {
+ "e.g. @ContainerExtraction(extract = ContainerExtract.DEFAULT, value = { ... })." ) );
}

@Test
void customBridge_dimensionFromAnnotationTypeInBridge() {
@Indexed(index = INDEX_NAME)
class IndexedEntity {
@DocumentId
Integer id;
@VectorField(dimension = 3,
valueBinder = @ValueBinderRef(type = ListTypeBridgeDimensionFromAnnotation.ExplicitFieldTypeBinder.class))
List<Float> floats;
}

backendMock.expectSchema( INDEX_NAME, b -> b
.field( "floats", float[].class, f -> f.dimension( 3 ) )
);
setupHelper.start().expectCustomBeans().setup( IndexedEntity.class );
backendMock.verifyExpectationsMet();
}

@SuppressWarnings("rawtypes")
public static class ValidTypeBridge implements ValueBridge<List, byte[]> {
@Override
Expand All @@ -539,7 +526,31 @@ public byte[] toIndexedValue(List value, ValueBridgeToIndexedValueContext contex
public static class ExplicitFieldTypeBinder implements ValueBinder {
@Override
public void bind(ValueBindingContext<?> context) {
context.bridge( List.class, new ValidTypeBridge(), context.typeFactory().asByteVector( 2 ) );
context.bridge( List.class, new ValidTypeBridge(), context.typeFactory().asByteVector().dimension( 2 ) );
}
}
}

@SuppressWarnings("rawtypes")
public static class ListTypeBridgeDimensionFromAnnotation implements ValueBridge<List, float[]> {
@Override
public float[] toIndexedValue(List value, ValueBridgeToIndexedValueContext context) {
if ( value == null ) {
return null;
}
float[] floats = new float[value.size()];
int index = 0;
for ( Object o : value ) {
floats[index++] = Byte.parseByte( Objects.toString( o, null ) );
}
return floats;
}

public static class ExplicitFieldTypeBinder implements ValueBinder {
@Override
public void bind(ValueBindingContext<?> context) {
context.bridge( List.class, new ListTypeBridgeDimensionFromAnnotation(),
context.typeFactory().asFloatVector() );
}
}
}
Expand All @@ -564,7 +575,7 @@ public static class ParametricBinder implements ValueBinder {
@Override
public void bind(ValueBindingContext<?> context) {
context.bridge( List.class, new ParametricBridge(),
context.typeFactory().asFloatVector( extractDimension( context ) )
context.typeFactory().asFloatVector().dimension( extractDimension( context ) )
);
}
}
Expand Down
Loading
Loading