Skip to content

Commit

Permalink
HSEARCH-5021 Make dimension "optional" in IndexFieldTypeFactory
Browse files Browse the repository at this point in the history
  • Loading branch information
marko-bekhta committed Dec 21, 2023
1 parent cd93e0a commit fdf049c
Show file tree
Hide file tree
Showing 11 changed files with 146 additions and 32 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -727,4 +727,14 @@ SearchException vectorKnnMatchVectorTypeDiffersFromField(String absoluteFieldPat
+ " Matching against an array with length of '%3$s' is unsupported."
+ " Use the array of the same size as the vector field.")
SearchException vectorKnnMatchVectorDimensionDiffersFromField(String absoluteFieldPath, int expected, int actual);

/**
 * Creates the exception reported when a vector field ends up without a dimension,
 * i.e. the dimension was provided neither through the annotation nor by a value binder.
 *
 * @return The exception, for the caller to throw.
 */
@Message(id = ID_OFFSET + 179, value = "Vector dimension is a required property. "
+ "Either specify it as an annotation property (@VectorField(dimension = somePositiveInteger)), "
+ "or define a value binder (@VectorField(valueBinder = @ValueBinderRef(..))) that explicitly declares a vector field specifying the dimension.")
SearchException vectorDimensionNotSpecified();

/**
 * Creates the exception reported when the dimension of a vector field is defined more than once,
 * e.g. both through the annotation and in the binder.
 *
 * @return The exception, for the caller to throw.
 */
@Message(id = ID_OFFSET + 180, value = "Vector field dimension can only be specified once."
+ " Either set it through an annotation e.g. `@VectorField(dimension=..)`,"
+ " or set it in the binder i.e. `context.typeFactory().asVector( ... )`.")
SearchException vectorDimensionCanOnlyBeSetOnce();
}
Original file line number Diff line number Diff line change
Expand Up @@ -43,19 +43,15 @@ abstract class AbstractLuceneVectorFieldTypeOptionsStep<S extends AbstractLucene
private static final int MAX_MAX_CONNECTIONS = 512;

protected VectorSimilarity vectorSimilarity = VectorSimilarity.DEFAULT;
protected final int dimension;
protected Integer dimension;
protected int beamWidth = MAX_MAX_CONNECTIONS;
protected int maxConnections = 16;
private Projectable projectable = Projectable.DEFAULT;
private Searchable searchable = Searchable.DEFAULT;
private F indexNullAsValue = null;

AbstractLuceneVectorFieldTypeOptionsStep(LuceneIndexFieldTypeBuildContext buildContext, Class<F> valueType, int dimension) {
AbstractLuceneVectorFieldTypeOptionsStep(LuceneIndexFieldTypeBuildContext buildContext, Class<F> valueType) {
super( buildContext, valueType );
if ( dimension < 1 || dimension > DEFAULT_MAX_DIMENSIONS ) {
throw log.vectorPropertyUnsupportedValue( "dimension", dimension, DEFAULT_MAX_DIMENSIONS );
}
this.dimension = dimension;
}

@Override
Expand Down Expand Up @@ -94,6 +90,18 @@ public S maxConnections(int maxConnections) {
return thisAsS();
}

/**
 * Sets the dimension of the vectors for this field type.
 * <p>
 * The value is validated against the supported range first; only then is it checked
 * that no dimension was assigned before, so an out-of-range value is always reported
 * as such, even on a repeated call.
 *
 * @param dimension The vector dimension, between 1 and {@code DEFAULT_MAX_DIMENSIONS} inclusive.
 * @return {@code this}, for method chaining.
 */
@Override
public S dimension(int dimension) {
	boolean withinSupportedRange = dimension >= 1 && dimension <= DEFAULT_MAX_DIMENSIONS;
	if ( !withinSupportedRange ) {
		throw log.vectorPropertyUnsupportedValue( "dimension", dimension, DEFAULT_MAX_DIMENSIONS );
	}

	boolean alreadyAssigned = this.dimension != null;
	if ( alreadyAssigned ) {
		throw log.vectorDimensionCanOnlyBeSetOnce();
	}

	this.dimension = dimension;
	return thisAsS();
}

@Override
public S indexNullAs(F indexNullAsValue) {
this.indexNullAsValue = indexNullAsValue;
Expand All @@ -102,6 +110,9 @@ public S indexNullAs(F indexNullAsValue) {

@Override
public LuceneIndexValueFieldType<F> toIndexFieldType() {
if ( dimension == null ) {
throw log.vectorDimensionNotSpecified();
}
VectorSimilarityFunction resolvedVectorSimilarity = resolveDefault( vectorSimilarity );
boolean resolvedProjectable = resolveDefault( projectable );
boolean resolvedSearchable = resolveDefault( searchable );
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@
class LuceneByteVectorFieldTypeOptionsStep
extends AbstractLuceneVectorFieldTypeOptionsStep<LuceneByteVectorFieldTypeOptionsStep, byte[]> {

LuceneByteVectorFieldTypeOptionsStep(LuceneIndexFieldTypeBuildContext buildContext, int dimension) {
super( buildContext, byte[].class, dimension );
LuceneByteVectorFieldTypeOptionsStep(LuceneIndexFieldTypeBuildContext buildContext) {
super( buildContext, byte[].class );
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@
class LuceneFloatVectorFieldTypeOptionsStep
extends AbstractLuceneVectorFieldTypeOptionsStep<LuceneFloatVectorFieldTypeOptionsStep, float[]> {

LuceneFloatVectorFieldTypeOptionsStep(LuceneIndexFieldTypeBuildContext buildContext, int dimension) {
super( buildContext, float[].class, dimension );
LuceneFloatVectorFieldTypeOptionsStep(LuceneIndexFieldTypeBuildContext buildContext) {
super( buildContext, float[].class );
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -128,12 +128,12 @@ else if ( BigInteger.class.equals( valueType ) ) {
}

@SuppressWarnings("unchecked")
public <F> VectorFieldTypeOptionsStep<?, F> asVector(int dimension, Class<F> valueType) {
public <F> VectorFieldTypeOptionsStep<?, F> asVector(Class<F> valueType) {
if ( byte[].class.equals( valueType ) ) {
return (VectorFieldTypeOptionsStep<?, F>) asByteVector( dimension );
return (VectorFieldTypeOptionsStep<?, F>) asByteVector();
}
else if ( float[].class.equals( valueType ) ) {
return (VectorFieldTypeOptionsStep<?, F>) asFloatVector( dimension );
return (VectorFieldTypeOptionsStep<?, F>) asFloatVector();
}
else {
throw log.cannotGuessVectorFieldType( valueType, getEventContext() );
Expand Down Expand Up @@ -246,13 +246,13 @@ public ScaledNumberIndexFieldTypeOptionsStep<?, BigInteger> asBigInteger() {
}

@Override
public VectorFieldTypeOptionsStep<?, byte[]> asByteVector(int dimension) {
return new LuceneByteVectorFieldTypeOptionsStep( this, dimension );
public VectorFieldTypeOptionsStep<?, byte[]> asByteVector() {
return new LuceneByteVectorFieldTypeOptionsStep( this );
}

@Override
public VectorFieldTypeOptionsStep<?, float[]> asFloatVector(int dimension) {
return new LuceneFloatVectorFieldTypeOptionsStep( this, dimension );
public VectorFieldTypeOptionsStep<?, float[]> asFloatVector() {
return new LuceneFloatVectorFieldTypeOptionsStep( this );
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,23 @@ public interface IndexFieldTypeFactory {
* @param valueType The type of values for this field type. Should be an array type like {@code byte[]} or {@code float[]}.
* @return A DSL step where the index vector field type can be defined in more details.
*/
<F> VectorFieldTypeOptionsStep<?, F> asVector(int dimension, Class<F> valueType);
@Incubating
default <F> VectorFieldTypeOptionsStep<?, F> asVector(int dimension, Class<F> valueType) {
return asVector( valueType ).dimension( dimension );
}

/**
 * Define a vector field type whose values are represented as a given type in Hibernate Search.
 * <p>
 * This method does not define the dimension of the vectors:
 * set it on the returned step with {@link VectorFieldTypeOptionsStep#dimension(int)},
 * or use {@link #asVector(int, Class)} to provide it upfront.
 * <p>
 * When possible, prefer the other methods such as {@link #asByteVector()} or {@link #asFloatVector()}
 * to avoid unnecessary type checks.
 *
 * @param <F> The type of values for this field type.
 * @param valueType The type of values for this field type. Should be an array type like {@code byte[]} or {@code float[]}.
 * @return A DSL step where the index vector field type can be defined in more details.
 */
@Incubating
<F> VectorFieldTypeOptionsStep<?, F> asVector(Class<F> valueType);

/**
* Define a field type whose values are represented as a {@link String} in Hibernate Search.
Expand Down Expand Up @@ -180,7 +196,6 @@ public interface IndexFieldTypeFactory {
*/
ScaledNumberIndexFieldTypeOptionsStep<?, BigInteger> asBigInteger();


/**
* Define a field type intended for use in vector search
* and whose values are represented as a {@code byte[]} in Hibernate Search.
Expand All @@ -189,7 +204,18 @@ public interface IndexFieldTypeFactory {
* @return A DSL step where the index field type can be defined in more details.
*/
@Incubating
VectorFieldTypeOptionsStep<?, byte[]> asByteVector(int dimension);
default VectorFieldTypeOptionsStep<?, byte[]> asByteVector(int dimension) {
return asByteVector().dimension( dimension );
}

/**
 * Define a field type intended for use in vector search
 * and whose values are represented as a {@code byte[]} in Hibernate Search.
 * <p>
 * This method does not define the dimension of the vectors:
 * set it on the returned step with {@link VectorFieldTypeOptionsStep#dimension(int)},
 * or use {@link #asByteVector(int)} to provide it upfront.
 *
 * @return A DSL step where the index field type can be defined in more details.
 */
@Incubating
VectorFieldTypeOptionsStep<?, byte[]> asByteVector();

/**
* Define a field type intended for use in vector search
Expand All @@ -199,7 +225,18 @@ public interface IndexFieldTypeFactory {
* @return A DSL step where the index field type can be defined in more details.
*/
@Incubating
VectorFieldTypeOptionsStep<?, float[]> asFloatVector(int dimension);
default VectorFieldTypeOptionsStep<?, float[]> asFloatVector(int dimension) {
return asFloatVector().dimension( dimension );
}

/**
 * Define a field type intended for use in vector search
 * and whose values are represented as a {@code float[]} in Hibernate Search.
 * <p>
 * This method does not define the dimension of the vectors:
 * set it on the returned step with {@link VectorFieldTypeOptionsStep#dimension(int)},
 * or use {@link #asFloatVector(int)} to provide it upfront.
 *
 * @return A DSL step where the index field type can be defined in more details.
 */
@Incubating
VectorFieldTypeOptionsStep<?, float[]> asFloatVector();

/**
* Extend the current factory with the given extension,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,4 +37,10 @@ public interface VectorFieldTypeOptionsStep<S extends VectorFieldTypeOptionsStep
*/
S maxConnections(int maxConnections);

/**
 * Define the dimension of the vectors held by fields of this type.
 *
 * @param dimension The number of dimensions (array length) of vectors to be indexed.
 * @return {@code this}, for method chaining.
 */
S dimension(int dimension);

}
Original file line number Diff line number Diff line change
Expand Up @@ -510,6 +510,24 @@ class IndexedEntity {
+ "e.g. @ContainerExtraction(extract = ContainerExtract.DEFAULT, value = { ... })." ) );
}

// The binder declares the field type through asFloatVector() without any dimension,
// so the dimension expected in the schema (3) can only come from the @VectorField annotation.
@Test
void customBridge_dimensionFromAnnotationTypeInBridge() {
@Indexed(index = INDEX_NAME)
class IndexedEntity {
@DocumentId
Integer id;
@VectorField(dimension = 3,
valueBinder = @ValueBinderRef(type = ListTypeBridgeDimensionFromAnnotation.ExplicitFieldTypeBinder.class))
List<Float> floats;
}

backendMock.expectSchema( INDEX_NAME, b -> b
.field( "floats", float[].class, f -> f.dimension( 3 ) )
);
setupHelper.start().expectCustomBeans().setup( IndexedEntity.class );
backendMock.verifyExpectationsMet();
}

@SuppressWarnings("rawtypes")
public static class ValidTypeBridge implements ValueBridge<List, byte[]> {
@Override
Expand All @@ -533,6 +551,30 @@ public void bind(ValueBindingContext<?> context) {
}
}

/**
 * Test bridge turning a raw {@link List} into a {@code float[]}.
 * <p>
 * The nested {@link ExplicitFieldTypeBinder} declares the field type without a dimension.
 */
@SuppressWarnings("rawtypes")
public static class ListTypeBridgeDimensionFromAnnotation implements ValueBridge<List, float[]> {
	/**
	 * Converts each element of the list to a float, based on its string representation.
	 *
	 * @param value The list to convert; may be {@code null}.
	 * @param context The bridge context (unused).
	 * @return An array with one float per list element, or {@code null} if {@code value} is {@code null}.
	 */
	@Override
	public float[] toIndexedValue(List value, ValueBridgeToIndexedValueContext context) {
		if ( value == null ) {
			return null;
		}
		float[] floats = new float[value.size()];
		int index = 0;
		for ( Object o : value ) {
			// The target is a float[]: parse as float rather than byte,
			// since a byte parser rejects strings such as "1.0" and values outside the byte range.
			floats[index++] = Float.parseFloat( Objects.toString( o, null ) );
		}
		return floats;
	}

	/**
	 * Binder that registers the bridge with a float vector field type
	 * created without a dimension.
	 */
	public static class ExplicitFieldTypeBinder implements ValueBinder {
		@Override
		public void bind(ValueBindingContext<?> context) {
			context.bridge( List.class, new ListTypeBridgeDimensionFromAnnotation(),
					context.typeFactory().asFloatVector() );
		}
	}
}

@SuppressWarnings("rawtypes")
public static class ParametricBridge implements ValueBridge<List, float[]> {

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,12 +34,15 @@ class PropertyMappingVectorFieldOptionsStepImpl
@Override
public <F> IndexFieldTypeOptionsStep<?, F> initiate(IndexFieldTypeFactory factory, Class<F> clazz) {
if ( dimension == null ) {
throw log.vectorDimensionNotSpecified();
return factory.asVector( clazz );
}
else {
return factory.asVector( dimension, clazz );
}
return factory.asVector( dimension, clazz );
}
},
FieldModelContributorContext::vectorTypeOptionsStep );
fieldModelContributor.add( c -> c.vectorTypeOptionsStep().dimension( dimension ) );
extractors( ContainerExtractorPath.noExtractors() );
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -161,24 +161,24 @@ public ScaledNumberIndexFieldTypeOptionsStep<?, BigInteger> asBigInteger() {

@Override
@SuppressWarnings("unchecked")
public <F> VectorFieldTypeOptionsStep<?, F> asVector(int dimension, Class<F> valueType) {
public <F> VectorFieldTypeOptionsStep<?, F> asVector(Class<F> valueType) {
if ( byte[].class.equals( valueType ) ) {
return (VectorFieldTypeOptionsStep<?, F>) asByteVector( dimension );
return (VectorFieldTypeOptionsStep<?, F>) asByteVector();
}
if ( float[].class.equals( valueType ) ) {
return (VectorFieldTypeOptionsStep<?, F>) asFloatVector( dimension );
return (VectorFieldTypeOptionsStep<?, F>) asFloatVector();
}
throw new SearchException( "No built-in vector index field type for class: '" + valueType.getName() + "'." );
}

@Override
public VectorFieldTypeOptionsStep<?, byte[]> asByteVector(int dimension) {
return new StubVectorFieldTypeOptionsStep<>( dimension, byte[].class );
public VectorFieldTypeOptionsStep<?, byte[]> asByteVector() {
return new StubVectorFieldTypeOptionsStep<>( byte[].class );
}

@Override
public VectorFieldTypeOptionsStep<?, float[]> asFloatVector(int dimension) {
return new StubVectorFieldTypeOptionsStep<>( dimension, float[].class );
public VectorFieldTypeOptionsStep<?, float[]> asFloatVector() {
return new StubVectorFieldTypeOptionsStep<>( float[].class );
}

public <T> IndexFieldTypeOptionsStep<?, T> asNonStandard(Class<T> fieldValueType) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,8 @@ class StubVectorFieldTypeOptionsStep<F>
extends AbstractStubSearchableProjectableIndexFieldTypeOptionsStep<StubVectorFieldTypeOptionsStep<F>, F>
implements VectorFieldTypeOptionsStep<StubVectorFieldTypeOptionsStep<F>, F> {

StubVectorFieldTypeOptionsStep(int dimension, Class<F> klass) {
StubVectorFieldTypeOptionsStep(Class<F> klass) {
super( klass );
builder.modifier( b -> b.dimension( dimension ) );
}

@Override
Expand All @@ -36,6 +35,12 @@ public StubVectorFieldTypeOptionsStep<F> maxConnections(int maxConnections) {
return this;
}

/**
 * Registers a modifier on the stub builder that applies the given dimension.
 *
 * @param dimension The dimension to apply through the builder modifier.
 * @return {@code this}, for method chaining.
 */
@Override
public StubVectorFieldTypeOptionsStep<F> dimension(int dimension) {
	builder.modifier( fieldBuilder -> fieldBuilder.dimension( dimension ) );
	return this;
}

@Override
StubVectorFieldTypeOptionsStep<F> thisAsS() {
return this;
Expand Down

0 comments on commit fdf049c

Please sign in to comment.