Skip to content

Commit e51368c

Browse files
committed
HHH-19284 Align vector function registration with common functions
1 parent 354e07c commit e51368c

File tree

4 files changed

+113
-141
lines changed

4 files changed

+113
-141
lines changed

hibernate-vector/src/main/java/org/hibernate/vector/MariaDBFunctionContributor.java

+4-31
Original file line numberDiff line numberDiff line change
@@ -8,46 +8,19 @@
88
import org.hibernate.boot.model.FunctionContributor;
99
import org.hibernate.dialect.Dialect;
1010
import org.hibernate.dialect.MariaDBDialect;
11-
import org.hibernate.query.sqm.function.SqmFunctionRegistry;
12-
import org.hibernate.query.sqm.produce.function.StandardArgumentsValidators;
13-
import org.hibernate.query.sqm.produce.function.StandardFunctionReturnTypeResolvers;
14-
import org.hibernate.type.BasicType;
15-
import org.hibernate.type.BasicTypeRegistry;
16-
import org.hibernate.type.StandardBasicTypes;
17-
import org.hibernate.type.spi.TypeConfiguration;
1811

1912
public class MariaDBFunctionContributor implements FunctionContributor {
2013
@Override
2114
public void contributeFunctions(FunctionContributions functionContributions) {
2215
final Dialect dialect = functionContributions.getDialect();
23-
if (dialect instanceof MariaDBDialect) {
24-
final SqmFunctionRegistry functionRegistry = functionContributions.getFunctionRegistry();
25-
final TypeConfiguration typeConfiguration = functionContributions.getTypeConfiguration();
26-
final BasicTypeRegistry basicTypeRegistry = typeConfiguration.getBasicTypeRegistry();
27-
final BasicType<Double> doubleType = basicTypeRegistry.resolve(StandardBasicTypes.DOUBLE);
16+
if ( dialect instanceof MariaDBDialect ) {
17+
final VectorFunctionFactory vectorFunctionFactory = new VectorFunctionFactory( functionContributions );
2818

29-
registerVectorDistanceFunction(functionRegistry, "cosine_distance", "vec_distance_cosine", doubleType);
30-
registerVectorDistanceFunction(functionRegistry, "euclidean_distance", "vec_distance_euclidean", doubleType);
31-
functionRegistry.registerAlternateKey("l2_distance", "euclidean_distance");
19+
vectorFunctionFactory.cosineDistance( "vec_distance_cosine(?1,?2)" );
20+
vectorFunctionFactory.euclideanDistance( "vec_distance_euclidean(?1,?2)" );
3221
}
3322
}
3423

35-
private void registerVectorDistanceFunction(
36-
SqmFunctionRegistry functionRegistry,
37-
String functionName,
38-
String templatePattern,
39-
BasicType<Double> returnType) {
40-
41-
functionRegistry.patternDescriptorBuilder(functionName, templatePattern + "(?1,?2)")
42-
.setArgumentsValidator(StandardArgumentsValidators.composite(
43-
StandardArgumentsValidators.exactly(2),
44-
VectorArgumentValidator.INSTANCE
45-
))
46-
.setArgumentTypeResolver(VectorArgumentTypeResolver.INSTANCE)
47-
.setReturnTypeResolver(StandardFunctionReturnTypeResolvers.invariant(returnType))
48-
.register();
49-
}
50-
5124
@Override
5225
public int ordinal() {
5326
return 200;

hibernate-vector/src/main/java/org/hibernate/vector/OracleVectorFunctionContributor.java

+10-56
Original file line numberDiff line numberDiff line change
@@ -8,74 +8,28 @@
88
import org.hibernate.boot.model.FunctionContributor;
99
import org.hibernate.dialect.Dialect;
1010
import org.hibernate.dialect.OracleDialect;
11-
import org.hibernate.query.sqm.function.SqmFunctionRegistry;
12-
import org.hibernate.query.sqm.produce.function.StandardArgumentsValidators;
13-
import org.hibernate.query.sqm.produce.function.StandardFunctionReturnTypeResolvers;
14-
import org.hibernate.type.BasicType;
15-
import org.hibernate.type.BasicTypeRegistry;
16-
import org.hibernate.type.StandardBasicTypes;
17-
import org.hibernate.type.spi.TypeConfiguration;
1811

1912
public class OracleVectorFunctionContributor implements FunctionContributor {
2013

2114
@Override
2215
public void contributeFunctions(FunctionContributions functionContributions) {
2316
final Dialect dialect = functionContributions.getDialect();
24-
if (dialect instanceof OracleDialect) {
25-
final SqmFunctionRegistry functionRegistry = functionContributions.getFunctionRegistry();
26-
final TypeConfiguration typeConfiguration = functionContributions.getTypeConfiguration();
27-
final BasicTypeRegistry basicTypeRegistry = typeConfiguration.getBasicTypeRegistry();
28-
final BasicType<Double> doubleType = basicTypeRegistry.resolve(StandardBasicTypes.DOUBLE);
29-
final BasicType<Integer> integerType = basicTypeRegistry.resolve(StandardBasicTypes.INTEGER);
17+
if ( dialect instanceof OracleDialect ) {
18+
final VectorFunctionFactory vectorFunctionFactory = new VectorFunctionFactory( functionContributions );
3019

31-
registerVectorDistanceFunction(functionRegistry, "cosine_distance", "vector_distance(?1, ?2, COSINE)", doubleType);
32-
registerVectorDistanceFunction(functionRegistry, "euclidean_distance", "vector_distance(?1, ?2, EUCLIDEAN)", doubleType);
33-
functionRegistry.registerAlternateKey("l2_distance", "euclidean_distance");
20+
vectorFunctionFactory.cosineDistance( "vector_distance(?1,?2,COSINE)" );
21+
vectorFunctionFactory.euclideanDistance( "vector_distance(?1,?2,EUCLIDEAN)" );
22+
vectorFunctionFactory.l1Distance( "vector_distance(?1,?2,MANHATTAN)" );
23+
vectorFunctionFactory.hammingDistance( "vector_distance(?1,?2,HAMMING)" );
3424

35-
registerVectorDistanceFunction(functionRegistry, "l1_distance", "vector_distance(?1, ?2, MANHATTAN)", doubleType);
36-
functionRegistry.registerAlternateKey("taxicab_distance", "l1_distance");
25+
vectorFunctionFactory.innerProduct( "vector_distance(?1,?2,DOT)*-1" );
26+
vectorFunctionFactory.negativeInnerProduct( "vector_distance(?1,?2,DOT)" );
3727

38-
registerVectorDistanceFunction(functionRegistry, "negative_inner_product", "vector_distance(?1, ?2, DOT)", doubleType);
39-
registerVectorDistanceFunction(functionRegistry, "inner_product", "vector_distance(?1, ?2, DOT)*-1", doubleType);
40-
registerVectorDistanceFunction(functionRegistry, "hamming_distance", "vector_distance(?1, ?2, HAMMING)", doubleType);
41-
42-
registerNamedVectorFunction(functionRegistry, "vector_dims", integerType, 1);
43-
registerNamedVectorFunction(functionRegistry, "vector_norm", doubleType, 1);
28+
vectorFunctionFactory.vectorDimensions();
29+
vectorFunctionFactory.vectorNorm();
4430
}
4531
}
4632

47-
private void registerVectorDistanceFunction(
48-
SqmFunctionRegistry functionRegistry,
49-
String functionName,
50-
String pattern,
51-
BasicType<?> returnType) {
52-
53-
functionRegistry.patternDescriptorBuilder(functionName, pattern)
54-
.setArgumentsValidator(StandardArgumentsValidators.composite(
55-
StandardArgumentsValidators.exactly(2),
56-
VectorArgumentValidator.INSTANCE
57-
))
58-
.setArgumentTypeResolver(VectorArgumentTypeResolver.INSTANCE)
59-
.setReturnTypeResolver(StandardFunctionReturnTypeResolvers.invariant(returnType))
60-
.register();
61-
}
62-
63-
private void registerNamedVectorFunction(
64-
SqmFunctionRegistry functionRegistry,
65-
String functionName,
66-
BasicType<?> returnType,
67-
int argumentCount) {
68-
69-
functionRegistry.namedDescriptorBuilder(functionName)
70-
.setArgumentsValidator(StandardArgumentsValidators.composite(
71-
StandardArgumentsValidators.exactly(argumentCount),
72-
VectorArgumentValidator.INSTANCE
73-
))
74-
.setArgumentTypeResolver(VectorArgumentTypeResolver.INSTANCE)
75-
.setReturnTypeResolver(StandardFunctionReturnTypeResolvers.invariant(returnType))
76-
.register();
77-
}
78-
7933
@Override
8034
public int ordinal() {
8135
return 200;

hibernate-vector/src/main/java/org/hibernate/vector/PGVectorFunctionContributor.java

+8-54
Original file line numberDiff line numberDiff line change
@@ -9,73 +9,27 @@
99
import org.hibernate.dialect.CockroachDialect;
1010
import org.hibernate.dialect.Dialect;
1111
import org.hibernate.dialect.PostgreSQLDialect;
12-
import org.hibernate.query.sqm.function.SqmFunctionRegistry;
13-
import org.hibernate.query.sqm.produce.function.StandardArgumentsValidators;
14-
import org.hibernate.query.sqm.produce.function.StandardFunctionReturnTypeResolvers;
15-
import org.hibernate.type.BasicType;
16-
import org.hibernate.type.BasicTypeRegistry;
17-
import org.hibernate.type.StandardBasicTypes;
18-
import org.hibernate.type.spi.TypeConfiguration;
1912

2013
public class PGVectorFunctionContributor implements FunctionContributor {
2114

2215
@Override
2316
public void contributeFunctions(FunctionContributions functionContributions) {
2417
final Dialect dialect = functionContributions.getDialect();
2518
if (dialect instanceof PostgreSQLDialect || dialect instanceof CockroachDialect) {
26-
final SqmFunctionRegistry functionRegistry = functionContributions.getFunctionRegistry();
27-
final TypeConfiguration typeConfiguration = functionContributions.getTypeConfiguration();
28-
final BasicTypeRegistry basicTypeRegistry = typeConfiguration.getBasicTypeRegistry();
29-
final BasicType<Double> doubleType = basicTypeRegistry.resolve(StandardBasicTypes.DOUBLE);
30-
final BasicType<Integer> integerType = basicTypeRegistry.resolve(StandardBasicTypes.INTEGER);
19+
final VectorFunctionFactory vectorFunctionFactory = new VectorFunctionFactory( functionContributions );
3120

32-
registerVectorDistanceFunction(functionRegistry, "cosine_distance", "?1<=>?2", doubleType);
33-
registerVectorDistanceFunction(functionRegistry, "euclidean_distance", "?1<->?2", doubleType);
34-
functionRegistry.registerAlternateKey("l2_distance", "euclidean_distance");
21+
vectorFunctionFactory.cosineDistance( "?1<=>?2" );
22+
vectorFunctionFactory.euclideanDistance( "?1<->?2" );
23+
vectorFunctionFactory.l1Distance( "l1_distance(?1,?2)" );
3524

36-
registerNamedVectorFunction(functionRegistry, "l1_distance", doubleType, 2);
37-
functionRegistry.registerAlternateKey("taxicab_distance", "l1_distance");
25+
vectorFunctionFactory.innerProduct( "(?1<#>?2)*-1" );
26+
vectorFunctionFactory.negativeInnerProduct( "?1<#>?2" );
3827

39-
registerVectorDistanceFunction(functionRegistry, "negative_inner_product", "?1<#>?2", doubleType);
40-
registerVectorDistanceFunction(functionRegistry, "inner_product", "(?1<#>?2)*-1", doubleType);
41-
42-
registerNamedVectorFunction(functionRegistry, "vector_dims", integerType, 1);
43-
registerNamedVectorFunction(functionRegistry, "vector_norm", doubleType, 1);
28+
vectorFunctionFactory.vectorDimensions();
29+
vectorFunctionFactory.vectorNorm();
4430
}
4531
}
4632

47-
private void registerVectorDistanceFunction(
48-
SqmFunctionRegistry functionRegistry,
49-
String functionName,
50-
String pattern,
51-
BasicType<?> returnType) {
52-
53-
functionRegistry.patternDescriptorBuilder(functionName, pattern)
54-
.setArgumentsValidator(StandardArgumentsValidators.composite(
55-
StandardArgumentsValidators.exactly(2),
56-
VectorArgumentValidator.INSTANCE
57-
))
58-
.setArgumentTypeResolver(VectorArgumentTypeResolver.INSTANCE)
59-
.setReturnTypeResolver(StandardFunctionReturnTypeResolvers.invariant(returnType))
60-
.register();
61-
}
62-
63-
private void registerNamedVectorFunction(
64-
SqmFunctionRegistry functionRegistry,
65-
String functionName,
66-
BasicType<?> returnType,
67-
int argumentCount) {
68-
69-
functionRegistry.namedDescriptorBuilder(functionName)
70-
.setArgumentsValidator(StandardArgumentsValidators.composite(
71-
StandardArgumentsValidators.exactly(argumentCount),
72-
VectorArgumentValidator.INSTANCE
73-
))
74-
.setArgumentTypeResolver(VectorArgumentTypeResolver.INSTANCE)
75-
.setReturnTypeResolver(StandardFunctionReturnTypeResolvers.invariant(returnType))
76-
.register();
77-
}
78-
7933
@Override
8034
public int ordinal() {
8135
return 200;
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,91 @@
1+
/*
2+
* SPDX-License-Identifier: Apache-2.0
3+
* Copyright Red Hat Inc. and Hibernate Authors
4+
*/
5+
package org.hibernate.vector;
6+
7+
import org.hibernate.boot.model.FunctionContributions;
8+
import org.hibernate.query.sqm.function.SqmFunctionRegistry;
9+
import org.hibernate.query.sqm.produce.function.StandardArgumentsValidators;
10+
import org.hibernate.query.sqm.produce.function.StandardFunctionReturnTypeResolvers;
11+
import org.hibernate.type.BasicType;
12+
import org.hibernate.type.BasicTypeRegistry;
13+
import org.hibernate.type.StandardBasicTypes;
14+
import org.hibernate.type.spi.TypeConfiguration;
15+
16+
/**
17+
* Enumerates common vector function template definitions.
18+
* Centralized for easier use from dialects.
19+
*/
20+
public class VectorFunctionFactory {
21+
22+
private final SqmFunctionRegistry functionRegistry;
23+
private final TypeConfiguration typeConfiguration;
24+
private final BasicType<Double> doubleType;
25+
private final BasicType<Integer> integerType;
26+
27+
public VectorFunctionFactory(FunctionContributions functionContributions) {
28+
this.functionRegistry = functionContributions.getFunctionRegistry();
29+
this.typeConfiguration = functionContributions.getTypeConfiguration();
30+
final BasicTypeRegistry basicTypeRegistry = typeConfiguration.getBasicTypeRegistry();
31+
this.doubleType = basicTypeRegistry.resolve( StandardBasicTypes.DOUBLE );
32+
this.integerType = basicTypeRegistry.resolve( StandardBasicTypes.INTEGER );
33+
}
34+
35+
public void cosineDistance(String pattern) {
36+
registerVectorDistanceFunction( "cosine_distance", pattern );
37+
}
38+
39+
public void euclideanDistance(String pattern) {
40+
registerVectorDistanceFunction( "euclidean_distance", pattern );
41+
functionRegistry.registerAlternateKey( "l2_distance", "euclidean_distance" );
42+
}
43+
44+
public void l1Distance(String pattern) {
45+
registerVectorDistanceFunction( "l1_distance", pattern );
46+
functionRegistry.registerAlternateKey( "taxicab_distance", "l1_distance" );
47+
}
48+
49+
public void innerProduct(String pattern) {
50+
registerVectorDistanceFunction( "inner_product", pattern );
51+
}
52+
53+
public void negativeInnerProduct(String pattern) {
54+
registerVectorDistanceFunction( "negative_inner_product", pattern );
55+
}
56+
57+
public void hammingDistance(String pattern) {
58+
registerVectorDistanceFunction( "hamming_distance", pattern );
59+
}
60+
61+
public void vectorDimensions() {
62+
registerNamedVectorFunction( "vector_dims", integerType, 1 );
63+
}
64+
65+
public void vectorNorm() {
66+
registerNamedVectorFunction( "vector_norm", integerType, 1 );
67+
}
68+
69+
public void registerVectorDistanceFunction(String functionName, String pattern) {
70+
functionRegistry.patternDescriptorBuilder( functionName, pattern )
71+
.setArgumentsValidator( StandardArgumentsValidators.composite(
72+
StandardArgumentsValidators.exactly( 2 ),
73+
VectorArgumentValidator.INSTANCE
74+
) )
75+
.setArgumentTypeResolver( VectorArgumentTypeResolver.INSTANCE )
76+
.setReturnTypeResolver( StandardFunctionReturnTypeResolvers.invariant( doubleType ) )
77+
.register();
78+
}
79+
80+
public void registerNamedVectorFunction(String functionName, BasicType<?> returnType, int argumentCount) {
81+
functionRegistry.namedDescriptorBuilder( functionName )
82+
.setArgumentsValidator( StandardArgumentsValidators.composite(
83+
StandardArgumentsValidators.exactly( argumentCount ),
84+
VectorArgumentValidator.INSTANCE
85+
) )
86+
.setArgumentTypeResolver( VectorArgumentTypeResolver.INSTANCE )
87+
.setReturnTypeResolver( StandardFunctionReturnTypeResolvers.invariant( returnType ) )
88+
.register();
89+
}
90+
91+
}

0 commit comments

Comments
 (0)