|
8 | 8 | import org.hibernate.boot.model.FunctionContributor;
|
9 | 9 | import org.hibernate.dialect.Dialect;
|
10 | 10 | import org.hibernate.dialect.OracleDialect;
|
11 |
| -import org.hibernate.query.sqm.function.SqmFunctionRegistry; |
12 |
| -import org.hibernate.query.sqm.produce.function.StandardArgumentsValidators; |
13 |
| -import org.hibernate.query.sqm.produce.function.StandardFunctionReturnTypeResolvers; |
14 |
| -import org.hibernate.type.BasicType; |
15 |
| -import org.hibernate.type.BasicTypeRegistry; |
16 |
| -import org.hibernate.type.StandardBasicTypes; |
17 |
| -import org.hibernate.type.spi.TypeConfiguration; |
18 | 11 |
|
19 | 12 | public class OracleVectorFunctionContributor implements FunctionContributor {
|
20 | 13 |
|
21 | 14 | @Override
|
22 | 15 | public void contributeFunctions(FunctionContributions functionContributions) {
|
23 | 16 | final Dialect dialect = functionContributions.getDialect();
|
24 |
| - if (dialect instanceof OracleDialect) { |
25 |
| - final SqmFunctionRegistry functionRegistry = functionContributions.getFunctionRegistry(); |
26 |
| - final TypeConfiguration typeConfiguration = functionContributions.getTypeConfiguration(); |
27 |
| - final BasicTypeRegistry basicTypeRegistry = typeConfiguration.getBasicTypeRegistry(); |
28 |
| - final BasicType<Double> doubleType = basicTypeRegistry.resolve(StandardBasicTypes.DOUBLE); |
29 |
| - final BasicType<Integer> integerType = basicTypeRegistry.resolve(StandardBasicTypes.INTEGER); |
| 17 | + if ( dialect instanceof OracleDialect ) { |
| 18 | + final VectorFunctionFactory vectorFunctionFactory = new VectorFunctionFactory( functionContributions ); |
30 | 19 |
|
31 |
| - registerVectorDistanceFunction(functionRegistry, "cosine_distance", "vector_distance(?1, ?2, COSINE)", doubleType); |
32 |
| - registerVectorDistanceFunction(functionRegistry, "euclidean_distance", "vector_distance(?1, ?2, EUCLIDEAN)", doubleType); |
33 |
| - functionRegistry.registerAlternateKey("l2_distance", "euclidean_distance"); |
| 20 | + vectorFunctionFactory.cosineDistance( "vector_distance(?1,?2,COSINE)" ); |
| 21 | + vectorFunctionFactory.euclideanDistance( "vector_distance(?1,?2,EUCLIDEAN)" ); |
| 22 | + vectorFunctionFactory.l1Distance( "vector_distance(?1,?2,MANHATTAN)" ); |
| 23 | + vectorFunctionFactory.hammingDistance( "vector_distance(?1,?2,HAMMING)" ); |
34 | 24 |
|
35 |
| - registerVectorDistanceFunction(functionRegistry, "l1_distance", "vector_distance(?1, ?2, MANHATTAN)", doubleType); |
36 |
| - functionRegistry.registerAlternateKey("taxicab_distance", "l1_distance"); |
| 25 | + vectorFunctionFactory.innerProduct( "vector_distance(?1,?2,DOT)*-1" ); |
| 26 | + vectorFunctionFactory.negativeInnerProduct( "vector_distance(?1,?2,DOT)" ); |
37 | 27 |
|
38 |
| - registerVectorDistanceFunction(functionRegistry, "negative_inner_product", "vector_distance(?1, ?2, DOT)", doubleType); |
39 |
| - registerVectorDistanceFunction(functionRegistry, "inner_product", "vector_distance(?1, ?2, DOT)*-1", doubleType); |
40 |
| - registerVectorDistanceFunction(functionRegistry, "hamming_distance", "vector_distance(?1, ?2, HAMMING)", doubleType); |
41 |
| - |
42 |
| - registerNamedVectorFunction(functionRegistry, "vector_dims", integerType, 1); |
43 |
| - registerNamedVectorFunction(functionRegistry, "vector_norm", doubleType, 1); |
| 28 | + vectorFunctionFactory.vectorDimensions(); |
| 29 | + vectorFunctionFactory.vectorNorm(); |
44 | 30 | }
|
45 | 31 | }
|
46 | 32 |
|
47 |
| - private void registerVectorDistanceFunction( |
48 |
| - SqmFunctionRegistry functionRegistry, |
49 |
| - String functionName, |
50 |
| - String pattern, |
51 |
| - BasicType<?> returnType) { |
52 |
| - |
53 |
| - functionRegistry.patternDescriptorBuilder(functionName, pattern) |
54 |
| - .setArgumentsValidator(StandardArgumentsValidators.composite( |
55 |
| - StandardArgumentsValidators.exactly(2), |
56 |
| - VectorArgumentValidator.INSTANCE |
57 |
| - )) |
58 |
| - .setArgumentTypeResolver(VectorArgumentTypeResolver.INSTANCE) |
59 |
| - .setReturnTypeResolver(StandardFunctionReturnTypeResolvers.invariant(returnType)) |
60 |
| - .register(); |
61 |
| - } |
62 |
| - |
63 |
| - private void registerNamedVectorFunction( |
64 |
| - SqmFunctionRegistry functionRegistry, |
65 |
| - String functionName, |
66 |
| - BasicType<?> returnType, |
67 |
| - int argumentCount) { |
68 |
| - |
69 |
| - functionRegistry.namedDescriptorBuilder(functionName) |
70 |
| - .setArgumentsValidator(StandardArgumentsValidators.composite( |
71 |
| - StandardArgumentsValidators.exactly(argumentCount), |
72 |
| - VectorArgumentValidator.INSTANCE |
73 |
| - )) |
74 |
| - .setArgumentTypeResolver(VectorArgumentTypeResolver.INSTANCE) |
75 |
| - .setReturnTypeResolver(StandardFunctionReturnTypeResolvers.invariant(returnType)) |
76 |
| - .register(); |
77 |
| - } |
78 |
| - |
79 | 33 | @Override
|
80 | 34 | public int ordinal() {
|
81 | 35 | return 200;
|
|
0 commit comments