Skip to content

Commit 354e07c

Browse files
0xC0FFE2beikov
authored andcommitted
HHH-19284 : Extract duplicated vector distance function registration
1 parent 22c4c60 commit 354e07c

File tree

3 files changed

+124
-161
lines changed

3 files changed

+124
-161
lines changed

hibernate-vector/src/main/java/org/hibernate/vector/MariaDBFunctionContributor.java

+24-24
Original file line numberDiff line numberDiff line change
@@ -17,37 +17,37 @@
1717
import org.hibernate.type.spi.TypeConfiguration;
1818

1919
public class MariaDBFunctionContributor implements FunctionContributor {
20-
2120
@Override
2221
public void contributeFunctions(FunctionContributions functionContributions) {
23-
final SqmFunctionRegistry functionRegistry = functionContributions.getFunctionRegistry();
24-
final TypeConfiguration typeConfiguration = functionContributions.getTypeConfiguration();
25-
final BasicTypeRegistry basicTypeRegistry = typeConfiguration.getBasicTypeRegistry();
2622
final Dialect dialect = functionContributions.getDialect();
27-
if ( dialect instanceof MariaDBDialect ) {
28-
final BasicType<Double> doubleType = basicTypeRegistry.resolve( StandardBasicTypes.DOUBLE );
29-
30-
functionRegistry.patternDescriptorBuilder( "cosine_distance", "vec_distance_cosine(?1,?2)" )
31-
.setArgumentsValidator( StandardArgumentsValidators.composite(
32-
StandardArgumentsValidators.exactly( 2 ),
33-
VectorArgumentValidator.INSTANCE
34-
) )
35-
.setArgumentTypeResolver( VectorArgumentTypeResolver.INSTANCE )
36-
.setReturnTypeResolver( StandardFunctionReturnTypeResolvers.invariant( doubleType ) )
37-
.register();
38-
functionRegistry.patternDescriptorBuilder( "euclidean_distance", "vec_distance_euclidean(?1,?2)" )
39-
.setArgumentsValidator( StandardArgumentsValidators.composite(
40-
StandardArgumentsValidators.exactly( 2 ),
41-
VectorArgumentValidator.INSTANCE
42-
) )
43-
.setArgumentTypeResolver( VectorArgumentTypeResolver.INSTANCE )
44-
.setReturnTypeResolver( StandardFunctionReturnTypeResolvers.invariant( doubleType ) )
45-
.register();
46-
functionRegistry.registerAlternateKey( "l2_distance", "euclidean_distance" );
23+
if (dialect instanceof MariaDBDialect) {
24+
final SqmFunctionRegistry functionRegistry = functionContributions.getFunctionRegistry();
25+
final TypeConfiguration typeConfiguration = functionContributions.getTypeConfiguration();
26+
final BasicTypeRegistry basicTypeRegistry = typeConfiguration.getBasicTypeRegistry();
27+
final BasicType<Double> doubleType = basicTypeRegistry.resolve(StandardBasicTypes.DOUBLE);
4728

29+
registerVectorDistanceFunction(functionRegistry, "cosine_distance", "vec_distance_cosine", doubleType);
30+
registerVectorDistanceFunction(functionRegistry, "euclidean_distance", "vec_distance_euclidean", doubleType);
31+
functionRegistry.registerAlternateKey("l2_distance", "euclidean_distance");
4832
}
4933
}
5034

35+
private void registerVectorDistanceFunction(
36+
SqmFunctionRegistry functionRegistry,
37+
String functionName,
38+
String templatePattern,
39+
BasicType<Double> returnType) {
40+
41+
functionRegistry.patternDescriptorBuilder(functionName, templatePattern + "(?1,?2)")
42+
.setArgumentsValidator(StandardArgumentsValidators.composite(
43+
StandardArgumentsValidators.exactly(2),
44+
VectorArgumentValidator.INSTANCE
45+
))
46+
.setArgumentTypeResolver(VectorArgumentTypeResolver.INSTANCE)
47+
.setReturnTypeResolver(StandardFunctionReturnTypeResolvers.invariant(returnType))
48+
.register();
49+
}
50+
5151
@Override
5252
public int ordinal() {
5353
return 200;

hibernate-vector/src/main/java/org/hibernate/vector/OracleVectorFunctionContributor.java

+50-72
Original file line numberDiff line numberDiff line change
@@ -20,84 +20,62 @@ public class OracleVectorFunctionContributor implements FunctionContributor {
2020

2121
@Override
2222
public void contributeFunctions(FunctionContributions functionContributions) {
23-
final SqmFunctionRegistry functionRegistry = functionContributions.getFunctionRegistry();
24-
final TypeConfiguration typeConfiguration = functionContributions.getTypeConfiguration();
25-
final BasicTypeRegistry basicTypeRegistry = typeConfiguration.getBasicTypeRegistry();
2623
final Dialect dialect = functionContributions.getDialect();
27-
if ( dialect instanceof OracleDialect ) {
28-
final BasicType<Double> doubleType = basicTypeRegistry.resolve( StandardBasicTypes.DOUBLE );
29-
final BasicType<Integer> integerType = basicTypeRegistry.resolve( StandardBasicTypes.INTEGER );
30-
functionRegistry.patternDescriptorBuilder( "cosine_distance", "vector_distance(?1, ?2, COSINE)" )
31-
.setArgumentsValidator( StandardArgumentsValidators.composite(
32-
StandardArgumentsValidators.exactly( 2 ),
33-
VectorArgumentValidator.INSTANCE
34-
) )
35-
.setArgumentTypeResolver( VectorArgumentTypeResolver.INSTANCE )
36-
.setReturnTypeResolver( StandardFunctionReturnTypeResolvers.invariant( doubleType ) )
37-
.register();
38-
functionRegistry.patternDescriptorBuilder( "euclidean_distance", "vector_distance(?1, ?2, EUCLIDEAN)" )
39-
.setArgumentsValidator( StandardArgumentsValidators.composite(
40-
StandardArgumentsValidators.exactly( 2 ),
41-
VectorArgumentValidator.INSTANCE
42-
) )
43-
.setArgumentTypeResolver( VectorArgumentTypeResolver.INSTANCE )
44-
.setReturnTypeResolver( StandardFunctionReturnTypeResolvers.invariant( doubleType ) )
45-
.register();
46-
functionRegistry.registerAlternateKey( "l2_distance", "euclidean_distance" );
24+
if (dialect instanceof OracleDialect) {
25+
final SqmFunctionRegistry functionRegistry = functionContributions.getFunctionRegistry();
26+
final TypeConfiguration typeConfiguration = functionContributions.getTypeConfiguration();
27+
final BasicTypeRegistry basicTypeRegistry = typeConfiguration.getBasicTypeRegistry();
28+
final BasicType<Double> doubleType = basicTypeRegistry.resolve(StandardBasicTypes.DOUBLE);
29+
final BasicType<Integer> integerType = basicTypeRegistry.resolve(StandardBasicTypes.INTEGER);
4730

48-
functionRegistry.patternDescriptorBuilder( "l1_distance" , "vector_distance(?1, ?2, MANHATTAN)")
49-
.setArgumentsValidator( StandardArgumentsValidators.composite(
50-
StandardArgumentsValidators.exactly( 2 ),
51-
VectorArgumentValidator.INSTANCE
52-
) )
53-
.setArgumentTypeResolver( VectorArgumentTypeResolver.INSTANCE )
54-
.setReturnTypeResolver( StandardFunctionReturnTypeResolvers.invariant( doubleType ) )
55-
.register();
56-
functionRegistry.registerAlternateKey( "taxicab_distance", "l1_distance" );
31+
registerVectorDistanceFunction(functionRegistry, "cosine_distance", "vector_distance(?1, ?2, COSINE)", doubleType);
32+
registerVectorDistanceFunction(functionRegistry, "euclidean_distance", "vector_distance(?1, ?2, EUCLIDEAN)", doubleType);
33+
functionRegistry.registerAlternateKey("l2_distance", "euclidean_distance");
5734

58-
functionRegistry.patternDescriptorBuilder( "negative_inner_product", "vector_distance(?1, ?2, DOT)" )
59-
.setArgumentsValidator( StandardArgumentsValidators.composite(
60-
StandardArgumentsValidators.exactly( 2 ),
61-
VectorArgumentValidator.INSTANCE
62-
) )
63-
.setArgumentTypeResolver( VectorArgumentTypeResolver.INSTANCE )
64-
.setReturnTypeResolver( StandardFunctionReturnTypeResolvers.invariant( doubleType ) )
65-
.register();
66-
functionRegistry.patternDescriptorBuilder( "inner_product", "vector_distance(?1, ?2, DOT)*-1" )
67-
.setArgumentsValidator( StandardArgumentsValidators.composite(
68-
StandardArgumentsValidators.exactly( 2 ),
69-
VectorArgumentValidator.INSTANCE
70-
) )
71-
.setArgumentTypeResolver( VectorArgumentTypeResolver.INSTANCE )
72-
.setReturnTypeResolver( StandardFunctionReturnTypeResolvers.invariant( doubleType ) )
73-
.register();
74-
functionRegistry.patternDescriptorBuilder( "hamming_distance", "vector_distance(?1, ?2, HAMMING)" )
75-
.setArgumentsValidator( StandardArgumentsValidators.composite(
76-
StandardArgumentsValidators.exactly( 2 ),
77-
VectorArgumentValidator.INSTANCE
78-
) )
79-
.setArgumentTypeResolver( VectorArgumentTypeResolver.INSTANCE )
80-
.setReturnTypeResolver( StandardFunctionReturnTypeResolvers.invariant( doubleType ) )
81-
.register();
82-
functionRegistry.namedDescriptorBuilder( "vector_dims" )
83-
.setArgumentsValidator( StandardArgumentsValidators.composite(
84-
StandardArgumentsValidators.exactly( 1 ),
85-
VectorArgumentValidator.INSTANCE
86-
) )
87-
.setArgumentTypeResolver( VectorArgumentTypeResolver.INSTANCE )
88-
.setReturnTypeResolver( StandardFunctionReturnTypeResolvers.invariant( integerType ) )
89-
.register();
90-
functionRegistry.namedDescriptorBuilder( "vector_norm" )
91-
.setArgumentsValidator( StandardArgumentsValidators.composite(
92-
StandardArgumentsValidators.exactly( 1 ),
93-
VectorArgumentValidator.INSTANCE
94-
) )
95-
.setArgumentTypeResolver( VectorArgumentTypeResolver.INSTANCE )
96-
.setReturnTypeResolver( StandardFunctionReturnTypeResolvers.invariant( doubleType ) )
97-
.register();
35+
registerVectorDistanceFunction(functionRegistry, "l1_distance", "vector_distance(?1, ?2, MANHATTAN)", doubleType);
36+
functionRegistry.registerAlternateKey("taxicab_distance", "l1_distance");
37+
38+
registerVectorDistanceFunction(functionRegistry, "negative_inner_product", "vector_distance(?1, ?2, DOT)", doubleType);
39+
registerVectorDistanceFunction(functionRegistry, "inner_product", "vector_distance(?1, ?2, DOT)*-1", doubleType);
40+
registerVectorDistanceFunction(functionRegistry, "hamming_distance", "vector_distance(?1, ?2, HAMMING)", doubleType);
41+
42+
registerNamedVectorFunction(functionRegistry, "vector_dims", integerType, 1);
43+
registerNamedVectorFunction(functionRegistry, "vector_norm", doubleType, 1);
9844
}
9945
}
10046

47+
private void registerVectorDistanceFunction(
48+
SqmFunctionRegistry functionRegistry,
49+
String functionName,
50+
String pattern,
51+
BasicType<?> returnType) {
52+
53+
functionRegistry.patternDescriptorBuilder(functionName, pattern)
54+
.setArgumentsValidator(StandardArgumentsValidators.composite(
55+
StandardArgumentsValidators.exactly(2),
56+
VectorArgumentValidator.INSTANCE
57+
))
58+
.setArgumentTypeResolver(VectorArgumentTypeResolver.INSTANCE)
59+
.setReturnTypeResolver(StandardFunctionReturnTypeResolvers.invariant(returnType))
60+
.register();
61+
}
62+
63+
private void registerNamedVectorFunction(
64+
SqmFunctionRegistry functionRegistry,
65+
String functionName,
66+
BasicType<?> returnType,
67+
int argumentCount) {
68+
69+
functionRegistry.namedDescriptorBuilder(functionName)
70+
.setArgumentsValidator(StandardArgumentsValidators.composite(
71+
StandardArgumentsValidators.exactly(argumentCount),
72+
VectorArgumentValidator.INSTANCE
73+
))
74+
.setArgumentTypeResolver(VectorArgumentTypeResolver.INSTANCE)
75+
.setReturnTypeResolver(StandardFunctionReturnTypeResolvers.invariant(returnType))
76+
.register();
77+
}
78+
10179
@Override
10280
public int ordinal() {
10381
return 200;

hibernate-vector/src/main/java/org/hibernate/vector/PGVectorFunctionContributor.java

+50-65
Original file line numberDiff line numberDiff line change
@@ -21,76 +21,61 @@ public class PGVectorFunctionContributor implements FunctionContributor {
2121

2222
@Override
2323
public void contributeFunctions(FunctionContributions functionContributions) {
24-
final SqmFunctionRegistry functionRegistry = functionContributions.getFunctionRegistry();
25-
final TypeConfiguration typeConfiguration = functionContributions.getTypeConfiguration();
26-
final BasicTypeRegistry basicTypeRegistry = typeConfiguration.getBasicTypeRegistry();
2724
final Dialect dialect = functionContributions.getDialect();
28-
if ( dialect instanceof PostgreSQLDialect ||
29-
dialect instanceof CockroachDialect ) {
30-
final BasicType<Double> doubleType = basicTypeRegistry.resolve( StandardBasicTypes.DOUBLE );
31-
final BasicType<Integer> integerType = basicTypeRegistry.resolve( StandardBasicTypes.INTEGER );
32-
functionRegistry.patternDescriptorBuilder( "cosine_distance", "?1<=>?2" )
33-
.setArgumentsValidator( StandardArgumentsValidators.composite(
34-
StandardArgumentsValidators.exactly( 2 ),
35-
VectorArgumentValidator.INSTANCE
36-
) )
37-
.setArgumentTypeResolver( VectorArgumentTypeResolver.INSTANCE )
38-
.setReturnTypeResolver( StandardFunctionReturnTypeResolvers.invariant( doubleType ) )
39-
.register();
40-
functionRegistry.patternDescriptorBuilder( "euclidean_distance", "?1<->?2" )
41-
.setArgumentsValidator( StandardArgumentsValidators.composite(
42-
StandardArgumentsValidators.exactly( 2 ),
43-
VectorArgumentValidator.INSTANCE
44-
) )
45-
.setArgumentTypeResolver( VectorArgumentTypeResolver.INSTANCE )
46-
.setReturnTypeResolver( StandardFunctionReturnTypeResolvers.invariant( doubleType ) )
47-
.register();
48-
functionRegistry.registerAlternateKey( "l2_distance", "euclidean_distance" );
49-
functionRegistry.namedDescriptorBuilder( "l1_distance" )
50-
.setArgumentsValidator( StandardArgumentsValidators.composite(
51-
StandardArgumentsValidators.exactly( 2 ),
52-
VectorArgumentValidator.INSTANCE
53-
) )
54-
.setArgumentTypeResolver( VectorArgumentTypeResolver.INSTANCE )
55-
.setReturnTypeResolver( StandardFunctionReturnTypeResolvers.invariant( doubleType ) )
56-
.register();
57-
functionRegistry.registerAlternateKey( "taxicab_distance", "l1_distance" );
25+
if (dialect instanceof PostgreSQLDialect || dialect instanceof CockroachDialect) {
26+
final SqmFunctionRegistry functionRegistry = functionContributions.getFunctionRegistry();
27+
final TypeConfiguration typeConfiguration = functionContributions.getTypeConfiguration();
28+
final BasicTypeRegistry basicTypeRegistry = typeConfiguration.getBasicTypeRegistry();
29+
final BasicType<Double> doubleType = basicTypeRegistry.resolve(StandardBasicTypes.DOUBLE);
30+
final BasicType<Integer> integerType = basicTypeRegistry.resolve(StandardBasicTypes.INTEGER);
5831

59-
functionRegistry.patternDescriptorBuilder( "negative_inner_product", "?1<#>?2" )
60-
.setArgumentsValidator( StandardArgumentsValidators.composite(
61-
StandardArgumentsValidators.exactly( 2 ),
62-
VectorArgumentValidator.INSTANCE
63-
) )
64-
.setArgumentTypeResolver( VectorArgumentTypeResolver.INSTANCE )
65-
.setReturnTypeResolver( StandardFunctionReturnTypeResolvers.invariant( doubleType ) )
66-
.register();
67-
functionRegistry.patternDescriptorBuilder( "inner_product", "(?1<#>?2)*-1" )
68-
.setArgumentsValidator( StandardArgumentsValidators.composite(
69-
StandardArgumentsValidators.exactly( 2 ),
70-
VectorArgumentValidator.INSTANCE
71-
) )
72-
.setArgumentTypeResolver( VectorArgumentTypeResolver.INSTANCE )
73-
.setReturnTypeResolver( StandardFunctionReturnTypeResolvers.invariant( doubleType ) )
74-
.register();
75-
functionRegistry.namedDescriptorBuilder( "vector_dims" )
76-
.setArgumentsValidator( StandardArgumentsValidators.composite(
77-
StandardArgumentsValidators.exactly( 1 ),
78-
VectorArgumentValidator.INSTANCE
79-
) )
80-
.setArgumentTypeResolver( VectorArgumentTypeResolver.INSTANCE )
81-
.setReturnTypeResolver( StandardFunctionReturnTypeResolvers.invariant( integerType ) )
82-
.register();
83-
functionRegistry.namedDescriptorBuilder( "vector_norm" )
84-
.setArgumentsValidator( StandardArgumentsValidators.composite(
85-
StandardArgumentsValidators.exactly( 1 ),
86-
VectorArgumentValidator.INSTANCE
87-
) )
88-
.setArgumentTypeResolver( VectorArgumentTypeResolver.INSTANCE )
89-
.setReturnTypeResolver( StandardFunctionReturnTypeResolvers.invariant( doubleType ) )
90-
.register();
32+
registerVectorDistanceFunction(functionRegistry, "cosine_distance", "?1<=>?2", doubleType);
33+
registerVectorDistanceFunction(functionRegistry, "euclidean_distance", "?1<->?2", doubleType);
34+
functionRegistry.registerAlternateKey("l2_distance", "euclidean_distance");
35+
36+
registerNamedVectorFunction(functionRegistry, "l1_distance", doubleType, 2);
37+
functionRegistry.registerAlternateKey("taxicab_distance", "l1_distance");
38+
39+
registerVectorDistanceFunction(functionRegistry, "negative_inner_product", "?1<#>?2", doubleType);
40+
registerVectorDistanceFunction(functionRegistry, "inner_product", "(?1<#>?2)*-1", doubleType);
41+
42+
registerNamedVectorFunction(functionRegistry, "vector_dims", integerType, 1);
43+
registerNamedVectorFunction(functionRegistry, "vector_norm", doubleType, 1);
9144
}
9245
}
9346

47+
private void registerVectorDistanceFunction(
48+
SqmFunctionRegistry functionRegistry,
49+
String functionName,
50+
String pattern,
51+
BasicType<?> returnType) {
52+
53+
functionRegistry.patternDescriptorBuilder(functionName, pattern)
54+
.setArgumentsValidator(StandardArgumentsValidators.composite(
55+
StandardArgumentsValidators.exactly(2),
56+
VectorArgumentValidator.INSTANCE
57+
))
58+
.setArgumentTypeResolver(VectorArgumentTypeResolver.INSTANCE)
59+
.setReturnTypeResolver(StandardFunctionReturnTypeResolvers.invariant(returnType))
60+
.register();
61+
}
62+
63+
private void registerNamedVectorFunction(
64+
SqmFunctionRegistry functionRegistry,
65+
String functionName,
66+
BasicType<?> returnType,
67+
int argumentCount) {
68+
69+
functionRegistry.namedDescriptorBuilder(functionName)
70+
.setArgumentsValidator(StandardArgumentsValidators.composite(
71+
StandardArgumentsValidators.exactly(argumentCount),
72+
VectorArgumentValidator.INSTANCE
73+
))
74+
.setArgumentTypeResolver(VectorArgumentTypeResolver.INSTANCE)
75+
.setReturnTypeResolver(StandardFunctionReturnTypeResolvers.invariant(returnType))
76+
.register();
77+
}
78+
9479
@Override
9580
public int ordinal() {
9681
return 200;

0 commit comments

Comments
 (0)