Skip to content

Commit

Permalink
HSEARCH-5021 Make dimension "optional" in IndexFieldTypeFactory
Browse files Browse the repository at this point in the history
  • Loading branch information
marko-bekhta authored and yrodiere committed Dec 21, 2023
1 parent 9b7b0fb commit b08e96e
Show file tree
Hide file tree
Showing 23 changed files with 239 additions and 78 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -236,17 +236,17 @@ public ScaledNumberIndexFieldTypeOptionsStep<?, BigInteger> asBigInteger() {
}

@Override
public <F> VectorFieldTypeOptionsStep<?, F> asVector(int dimension, Class<F> valueType) {
public <F> VectorFieldTypeOptionsStep<?, F> asVector(Class<F> valueType) {
throw new UnsupportedOperationException( "The Elasticsearch backend does not support vector field yet." );
}

@Override
public VectorFieldTypeOptionsStep<?, byte[]> asByteVector(int dimension) {
public VectorFieldTypeOptionsStep<?, byte[]> asByteVector() {
throw new UnsupportedOperationException( "The Elasticsearch backend does not support vector field yet." );
}

@Override
public VectorFieldTypeOptionsStep<?, float[]> asFloatVector(int dimension) {
public VectorFieldTypeOptionsStep<?, float[]> asFloatVector() {
throw new UnsupportedOperationException( "The Elasticsearch backend does not support vector field yet." );
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -727,4 +727,8 @@ SearchException vectorKnnMatchVectorTypeDiffersFromField(String absoluteFieldPat
+ " Matching against an array with length of '%3$s' is unsupported."
+ " Use the array of the same size as the vector field.")
SearchException vectorKnnMatchVectorDimensionDiffersFromField(String absoluteFieldPath, int expected, int actual);

@Message(id = ID_OFFSET + 179, value = "Invalid index field type: missing vector dimension."
+ " Define the vector dimension explicitly. %1$s")
SearchException nullVectorDimension(String hint, @Param EventContext eventContext);
}
Original file line number Diff line number Diff line change
Expand Up @@ -43,19 +43,15 @@ abstract class AbstractLuceneVectorFieldTypeOptionsStep<S extends AbstractLucene
private static final int MAX_MAX_CONNECTIONS = 512;

protected VectorSimilarity vectorSimilarity = VectorSimilarity.DEFAULT;
protected final int dimension;
protected Integer dimension;
protected int beamWidth = MAX_MAX_CONNECTIONS;
protected int maxConnections = 16;
private Projectable projectable = Projectable.DEFAULT;
private Searchable searchable = Searchable.DEFAULT;
private F indexNullAsValue = null;

AbstractLuceneVectorFieldTypeOptionsStep(LuceneIndexFieldTypeBuildContext buildContext, Class<F> valueType, int dimension) {
AbstractLuceneVectorFieldTypeOptionsStep(LuceneIndexFieldTypeBuildContext buildContext, Class<F> valueType) {
super( buildContext, valueType );
if ( dimension < 1 || dimension > DEFAULT_MAX_DIMENSIONS ) {
throw log.vectorPropertyUnsupportedValue( "dimension", dimension, DEFAULT_MAX_DIMENSIONS );
}
this.dimension = dimension;
}

@Override
Expand Down Expand Up @@ -94,6 +90,15 @@ public S maxConnections(int maxConnections) {
return thisAsS();
}

/**
 * Sets the dimension (array length) of the vectors held by this field type.
 * <p>
 * Values outside the range 1 to {@code DEFAULT_MAX_DIMENSIONS} (both inclusive) are rejected
 * by throwing the exception built by {@code log.vectorPropertyUnsupportedValue(..)};
 * in that case the previously stored dimension is left untouched.
 *
 * @param dimension The number of dimensions of vectors to be indexed.
 * @return {@code this}, for method chaining.
 */
@Override
public S dimension(int dimension) {
    boolean withinBounds = 1 <= dimension && dimension <= DEFAULT_MAX_DIMENSIONS;
    if ( !withinBounds ) {
        throw log.vectorPropertyUnsupportedValue( "dimension", dimension, DEFAULT_MAX_DIMENSIONS );
    }
    this.dimension = dimension;
    return thisAsS();
}

@Override
public S indexNullAs(F indexNullAsValue) {
this.indexNullAsValue = indexNullAsValue;
Expand All @@ -102,6 +107,9 @@ public S indexNullAs(F indexNullAsValue) {

@Override
public LuceneIndexValueFieldType<F> toIndexFieldType() {
if ( dimension == null ) {
throw log.nullVectorDimension( buildContext.hints().missingVectorDimension(), buildContext.getEventContext() );
}
VectorSimilarityFunction resolvedVectorSimilarity = resolveDefault( vectorSimilarity );
boolean resolvedProjectable = resolveDefault( projectable );
boolean resolvedSearchable = resolveDefault( searchable );
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@
class LuceneByteVectorFieldTypeOptionsStep
extends AbstractLuceneVectorFieldTypeOptionsStep<LuceneByteVectorFieldTypeOptionsStep, byte[]> {

LuceneByteVectorFieldTypeOptionsStep(LuceneIndexFieldTypeBuildContext buildContext, int dimension) {
super( buildContext, byte[].class, dimension );
LuceneByteVectorFieldTypeOptionsStep(LuceneIndexFieldTypeBuildContext buildContext) {
super( buildContext, byte[].class );
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@
class LuceneFloatVectorFieldTypeOptionsStep
extends AbstractLuceneVectorFieldTypeOptionsStep<LuceneFloatVectorFieldTypeOptionsStep, float[]> {

LuceneFloatVectorFieldTypeOptionsStep(LuceneIndexFieldTypeBuildContext buildContext, int dimension) {
super( buildContext, float[].class, dimension );
LuceneFloatVectorFieldTypeOptionsStep(LuceneIndexFieldTypeBuildContext buildContext) {
super( buildContext, float[].class );
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -128,12 +128,12 @@ else if ( BigInteger.class.equals( valueType ) ) {
}

@SuppressWarnings("unchecked")
public <F> VectorFieldTypeOptionsStep<?, F> asVector(int dimension, Class<F> valueType) {
public <F> VectorFieldTypeOptionsStep<?, F> asVector(Class<F> valueType) {
if ( byte[].class.equals( valueType ) ) {
return (VectorFieldTypeOptionsStep<?, F>) asByteVector( dimension );
return (VectorFieldTypeOptionsStep<?, F>) asByteVector();
}
else if ( float[].class.equals( valueType ) ) {
return (VectorFieldTypeOptionsStep<?, F>) asFloatVector( dimension );
return (VectorFieldTypeOptionsStep<?, F>) asFloatVector();
}
else {
throw log.cannotGuessVectorFieldType( valueType, getEventContext() );
Expand Down Expand Up @@ -246,13 +246,13 @@ public ScaledNumberIndexFieldTypeOptionsStep<?, BigInteger> asBigInteger() {
}

@Override
public VectorFieldTypeOptionsStep<?, byte[]> asByteVector(int dimension) {
return new LuceneByteVectorFieldTypeOptionsStep( this, dimension );
public VectorFieldTypeOptionsStep<?, byte[]> asByteVector() {
return new LuceneByteVectorFieldTypeOptionsStep( this );
}

@Override
public VectorFieldTypeOptionsStep<?, float[]> asFloatVector(int dimension) {
return new LuceneFloatVectorFieldTypeOptionsStep( this, dimension );
public VectorFieldTypeOptionsStep<?, float[]> asFloatVector() {
return new LuceneFloatVectorFieldTypeOptionsStep( this );
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,4 +23,7 @@ public interface BackendMappingHints {
@Message("")
String missingDecimalScale();

@Message("")
String missingVectorDimension();

}
Original file line number Diff line number Diff line change
Expand Up @@ -44,15 +44,15 @@ public interface IndexFieldTypeFactory {
/**
* Define a vector field type whose values are represented as a given type in Hibernate Search.
* <p>
* When possible, prefer the other methods such as {@link #asByteVector(int)} or {@link #asFloatVector(int)}
* When possible, prefer the other methods such as {@link #asByteVector()} or {@link #asFloatVector()}
* to avoid unnecessary type checks.
*
* @param <F> The type of values for this field type.
* @param dimension The number of dimensions (array length) of vectors to be indexed.
* @param valueType The type of values for this field type. Should be an array type like {@code byte[]} or {@code float[]}.
* @return A DSL step where the index vector field type can be defined in more detail.
*/
<F> VectorFieldTypeOptionsStep<?, F> asVector(int dimension, Class<F> valueType);
@Incubating
<F> VectorFieldTypeOptionsStep<?, F> asVector(Class<F> valueType);

/**
* Define a field type whose values are represented as a {@link String} in Hibernate Search.
Expand Down Expand Up @@ -180,26 +180,23 @@ public interface IndexFieldTypeFactory {
*/
ScaledNumberIndexFieldTypeOptionsStep<?, BigInteger> asBigInteger();


/**
* Define a field type intended for use in vector search
* and whose values are represented as a {@code byte[]} in Hibernate Search.
*
* @param dimension The number of dimensions (array length) of vectors to be indexed.
* @return A DSL step where the index field type can be defined in more detail.
*/
@Incubating
VectorFieldTypeOptionsStep<?, byte[]> asByteVector(int dimension);
VectorFieldTypeOptionsStep<?, byte[]> asByteVector();

/**
* Define a field type intended for use in vector search
* and whose values are represented as a {@code float[]} in Hibernate Search.
*
* @param dimension The number of dimensions (array length) of vectors to be indexed.
* @return A DSL step where the index field type can be defined in more detail.
*/
@Incubating
VectorFieldTypeOptionsStep<?, float[]> asFloatVector(int dimension);
VectorFieldTypeOptionsStep<?, float[]> asFloatVector();

/**
* Extend the current factory with the given extension,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,4 +37,10 @@ public interface VectorFieldTypeOptionsStep<S extends VectorFieldTypeOptionsStep
*/
S maxConnections(int maxConnections);

/**
* Define the number of dimensions (array length) of the vectors to be indexed with this field type.
*
* @param dimension The number of dimensions (array length) of vectors to be indexed.
* @return {@code this}, for method chaining.
*/
S dimension(int dimension);

}
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ private static class IndexBinding {
final IndexFieldReference<byte[]> vectorField;

IndexBinding(IndexSchemaElement root) {
vectorField = root.field( "vector", c -> c.asByteVector( 2 ).maxConnections( 10 ) ).toReference();
vectorField = root.field( "vector", c -> c.asByteVector().dimension( 2 ).maxConnections( 10 ) ).toReference();
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ void test(int dimension, int beamWidth, int maxConnections, String property, int
.withIndex( SimpleMappedIndex
.of( root -> root
.field( "vector",
f -> f.asByteVector( dimension ).beamWidth( beamWidth )
f -> f.asByteVector().dimension( dimension ).beamWidth( beamWidth )
.maxConnections( maxConnections ) )
.toReference() ) )
.setup()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -718,7 +718,7 @@ private static class PredicateIndexBinding {
PredicateIndexBinding(IndexSchemaElement root) {
parking = root.field( "parking", f -> f.asBoolean().projectable( Projectable.YES ) ).toReference();
rating = root.field( "rating", f -> f.asInteger().projectable( Projectable.YES ) ).toReference();
location = root.field( "location", f -> f.asFloatVector( 2 ).projectable( Projectable.YES )
location = root.field( "location", f -> f.asFloatVector().dimension( 2 ).projectable( Projectable.YES )
.maxConnections( 16 ).beamWidth( 100 ).vectorSimilarity( VectorSimilarity.L2 ) ).toReference();
}
}
Expand All @@ -727,7 +727,7 @@ private static class MultiValuedIndexBinding {
final IndexFieldReference<byte[]> vector;

private MultiValuedIndexBinding(IndexSchemaElement root) {
vector = root.field( "vector", f -> f.asByteVector( 2 ) ).multiValued().toReference();
vector = root.field( "vector", f -> f.asByteVector().dimension( 2 ) ).multiValued().toReference();
}
}

Expand All @@ -742,10 +742,10 @@ private static class NestedIndexBinding {
nested = nestedField.toReference();

byteVector = nestedField.field(
"byteVector", f -> f.asByteVector( 2 ).projectable( Projectable.YES ) )
"byteVector", f -> f.asByteVector().dimension( 2 ).projectable( Projectable.YES ) )
.toReference();
floatVector = nestedField
.field( "floatVector", f -> f.asFloatVector( 2 ).projectable( Projectable.YES ) )
.field( "floatVector", f -> f.asFloatVector().dimension( 2 ).projectable( Projectable.YES ) )
.toReference();
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ public VectorFieldTypeDescriptor<byte[]> withDimension(int dimension) {

@Override
public VectorFieldTypeOptionsStep<?, byte[]> configure(IndexFieldTypeFactory fieldContext) {
return fieldContext.asByteVector( size );
return fieldContext.asByteVector().dimension( size );
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ public VectorFieldTypeDescriptor<float[]> withDimension(int dimension) {

@Override
public VectorFieldTypeOptionsStep<?, float[]> configure(IndexFieldTypeFactory fieldContext) {
return fieldContext.asFloatVector( size );
return fieldContext.asFloatVector().dimension( size );
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -420,26 +420,6 @@ class IndexedEntity {
+ " or set the type parameter F to a definite, raw type." ) );
}

@Test
void customBridge_vectorDimensionUnknown() {
@Indexed(index = INDEX_NAME)
class IndexedEntity {
@DocumentId
Integer id;
@VectorField(valueBinder = @ValueBinderRef(type = ValidImplicitTypeBridge.ValidImplicitTypeBinder.class))
Collection<Float> floats;
}

assertThatThrownBy( () -> setupHelper.start().expectCustomBeans().setup( IndexedEntity.class ) )
.isInstanceOf( SearchException.class )
.satisfies( FailureReportUtils.hasFailureReport()
.typeContext( IndexedEntity.class.getName() )
.pathContext( ".floats" )
.failure( "Vector dimension is a required property. "
+ "Either specify it as an annotation property (@VectorField(dimension = somePositiveInteger)),"
+ " or define a value binder (@VectorField(valueBinder = @ValueBinderRef(..))) that explicitly declares a vector field specifying the dimension." ) );
}

@Test
void valueExtractorsEnabled() {
@Indexed(index = INDEX_NAME)
Expand Down Expand Up @@ -510,6 +490,24 @@ class IndexedEntity {
+ "e.g. @ContainerExtraction(extract = ContainerExtract.DEFAULT, value = { ... })." ) );
}

/**
 * Checks that a dimension declared on {@code @VectorField} ends up in the schema
 * when the custom value binder declares a float vector field type without setting a dimension itself.
 */
@Test
void customBridge_dimensionFromAnnotationTypeInBridge() {
@Indexed(index = INDEX_NAME)
class IndexedEntity {
@DocumentId
Integer id;
// The dimension is given on the annotation only; the referenced binder calls asFloatVector() with no dimension(..).
@VectorField(dimension = 3,
valueBinder = @ValueBinderRef(type = ListTypeBridgeDimensionFromAnnotation.ExplicitFieldTypeBinder.class))
List<Float> floats;
}

// The expectation is registered before setup: the "floats" field must be a float[] field with dimension 3.
backendMock.expectSchema( INDEX_NAME, b -> b
.field( "floats", float[].class, f -> f.dimension( 3 ) )
);
setupHelper.start().expectCustomBeans().setup( IndexedEntity.class );
backendMock.verifyExpectationsMet();
}

@SuppressWarnings("rawtypes")
public static class ValidTypeBridge implements ValueBridge<List, byte[]> {
@Override
Expand All @@ -528,7 +526,31 @@ public byte[] toIndexedValue(List value, ValueBridgeToIndexedValueContext contex
public static class ExplicitFieldTypeBinder implements ValueBinder {
@Override
public void bind(ValueBindingContext<?> context) {
context.bridge( List.class, new ValidTypeBridge(), context.typeFactory().asByteVector( 2 ) );
context.bridge( List.class, new ValidTypeBridge(), context.typeFactory().asByteVector().dimension( 2 ) );
}
}
}

/**
 * Test value bridge converting a raw {@link List} into a {@code float[]} vector.
 * <p>
 * The nested binder declares the field type through {@code asFloatVector()} without calling
 * {@code dimension(..)}, so the vector dimension has to be supplied from elsewhere
 * (e.g. {@code @VectorField(dimension = ...)} on the mapped property).
 */
@SuppressWarnings("rawtypes")
public static class ListTypeBridgeDimensionFromAnnotation implements ValueBridge<List, float[]> {
    /**
     * Converts each list element to a {@code float} by parsing its string representation.
     *
     * @param value The list to convert; may be {@code null}.
     * @param context The bridge context; not used by this implementation.
     * @return An array holding one float per list element, in iteration order,
     * or {@code null} if {@code value} is {@code null}.
     * @throws NumberFormatException If the string form of an element (including a {@code null} element)
     * cannot be parsed as a float.
     */
    @Override
    public float[] toIndexedValue(List value, ValueBridgeToIndexedValueContext context) {
        if ( value == null ) {
            return null;
        }
        float[] floats = new float[value.size()];
        int index = 0;
        for ( Object o : value ) {
            // Parse as float, not byte: the target is a float[] and the string form of a Float
            // (e.g. "1.0") is rejected by Byte.parseByte, as is anything outside -128..127.
            // Objects.toString(o) yields "null" for a null element, which still fails with
            // NumberFormatException rather than a NullPointerException.
            floats[index++] = Float.parseFloat( Objects.toString( o ) );
        }
        return floats;
    }

    /**
     * Binder registering {@link ListTypeBridgeDimensionFromAnnotation} with a float vector field type
     * on which no dimension is set explicitly.
     */
    public static class ExplicitFieldTypeBinder implements ValueBinder {
        @Override
        public void bind(ValueBindingContext<?> context) {
            context.bridge( List.class, new ListTypeBridgeDimensionFromAnnotation(),
                    context.typeFactory().asFloatVector() );
        }
    }
}
Expand All @@ -553,7 +575,7 @@ public static class ParametricBinder implements ValueBinder {
@Override
public void bind(ValueBindingContext<?> context) {
context.bridge( List.class, new ParametricBridge(),
context.typeFactory().asFloatVector( extractDimension( context ) )
context.typeFactory().asFloatVector().dimension( extractDimension( context ) )
);
}
}
Expand Down
2 changes: 2 additions & 0 deletions integrationtest/mapper/pojo-standalone-realbackend/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,8 @@
</systemPropertyVariables>
<excludes>
<exclude>org.hibernate.search.integrationtest.mapper.pojo.standalone.realbackend.schema.management.LuceneSchemaManagerExporterIT</exclude>
<!-- Include once HSEARCH-4950 is implemented -->
<exclude>org.hibernate.search.integrationtest.mapper.pojo.standalone.realbackend.mapping.VectorFieldIT</exclude>
</excludes>
</configuration>
</execution>
Expand Down
Loading

0 comments on commit b08e96e

Please sign in to comment.