@@ -20,84 +20,62 @@ public class OracleVectorFunctionContributor implements FunctionContributor {
20
20
21
21
@ Override
22
22
public void contributeFunctions (FunctionContributions functionContributions ) {
23
- final SqmFunctionRegistry functionRegistry = functionContributions .getFunctionRegistry ();
24
- final TypeConfiguration typeConfiguration = functionContributions .getTypeConfiguration ();
25
- final BasicTypeRegistry basicTypeRegistry = typeConfiguration .getBasicTypeRegistry ();
26
23
final Dialect dialect = functionContributions .getDialect ();
27
- if ( dialect instanceof OracleDialect ) {
28
- final BasicType <Double > doubleType = basicTypeRegistry .resolve ( StandardBasicTypes .DOUBLE );
29
- final BasicType <Integer > integerType = basicTypeRegistry .resolve ( StandardBasicTypes .INTEGER );
30
- functionRegistry .patternDescriptorBuilder ( "cosine_distance" , "vector_distance(?1, ?2, COSINE)" )
31
- .setArgumentsValidator ( StandardArgumentsValidators .composite (
32
- StandardArgumentsValidators .exactly ( 2 ),
33
- VectorArgumentValidator .INSTANCE
34
- ) )
35
- .setArgumentTypeResolver ( VectorArgumentTypeResolver .INSTANCE )
36
- .setReturnTypeResolver ( StandardFunctionReturnTypeResolvers .invariant ( doubleType ) )
37
- .register ();
38
- functionRegistry .patternDescriptorBuilder ( "euclidean_distance" , "vector_distance(?1, ?2, EUCLIDEAN)" )
39
- .setArgumentsValidator ( StandardArgumentsValidators .composite (
40
- StandardArgumentsValidators .exactly ( 2 ),
41
- VectorArgumentValidator .INSTANCE
42
- ) )
43
- .setArgumentTypeResolver ( VectorArgumentTypeResolver .INSTANCE )
44
- .setReturnTypeResolver ( StandardFunctionReturnTypeResolvers .invariant ( doubleType ) )
45
- .register ();
46
- functionRegistry .registerAlternateKey ( "l2_distance" , "euclidean_distance" );
24
+ if (dialect instanceof OracleDialect ) {
25
+ final SqmFunctionRegistry functionRegistry = functionContributions .getFunctionRegistry ();
26
+ final TypeConfiguration typeConfiguration = functionContributions .getTypeConfiguration ();
27
+ final BasicTypeRegistry basicTypeRegistry = typeConfiguration .getBasicTypeRegistry ();
28
+ final BasicType <Double > doubleType = basicTypeRegistry .resolve (StandardBasicTypes .DOUBLE );
29
+ final BasicType <Integer > integerType = basicTypeRegistry .resolve (StandardBasicTypes .INTEGER );
47
30
48
- functionRegistry .patternDescriptorBuilder ( "l1_distance" , "vector_distance(?1, ?2, MANHATTAN)" )
49
- .setArgumentsValidator ( StandardArgumentsValidators .composite (
50
- StandardArgumentsValidators .exactly ( 2 ),
51
- VectorArgumentValidator .INSTANCE
52
- ) )
53
- .setArgumentTypeResolver ( VectorArgumentTypeResolver .INSTANCE )
54
- .setReturnTypeResolver ( StandardFunctionReturnTypeResolvers .invariant ( doubleType ) )
55
- .register ();
56
- functionRegistry .registerAlternateKey ( "taxicab_distance" , "l1_distance" );
31
+ registerVectorDistanceFunction (functionRegistry , "cosine_distance" , "vector_distance(?1, ?2, COSINE)" , doubleType );
32
+ registerVectorDistanceFunction (functionRegistry , "euclidean_distance" , "vector_distance(?1, ?2, EUCLIDEAN)" , doubleType );
33
+ functionRegistry .registerAlternateKey ("l2_distance" , "euclidean_distance" );
57
34
58
- functionRegistry .patternDescriptorBuilder ( "negative_inner_product" , "vector_distance(?1, ?2, DOT)" )
59
- .setArgumentsValidator ( StandardArgumentsValidators .composite (
60
- StandardArgumentsValidators .exactly ( 2 ),
61
- VectorArgumentValidator .INSTANCE
62
- ) )
63
- .setArgumentTypeResolver ( VectorArgumentTypeResolver .INSTANCE )
64
- .setReturnTypeResolver ( StandardFunctionReturnTypeResolvers .invariant ( doubleType ) )
65
- .register ();
66
- functionRegistry .patternDescriptorBuilder ( "inner_product" , "vector_distance(?1, ?2, DOT)*-1" )
67
- .setArgumentsValidator ( StandardArgumentsValidators .composite (
68
- StandardArgumentsValidators .exactly ( 2 ),
69
- VectorArgumentValidator .INSTANCE
70
- ) )
71
- .setArgumentTypeResolver ( VectorArgumentTypeResolver .INSTANCE )
72
- .setReturnTypeResolver ( StandardFunctionReturnTypeResolvers .invariant ( doubleType ) )
73
- .register ();
74
- functionRegistry .patternDescriptorBuilder ( "hamming_distance" , "vector_distance(?1, ?2, HAMMING)" )
75
- .setArgumentsValidator ( StandardArgumentsValidators .composite (
76
- StandardArgumentsValidators .exactly ( 2 ),
77
- VectorArgumentValidator .INSTANCE
78
- ) )
79
- .setArgumentTypeResolver ( VectorArgumentTypeResolver .INSTANCE )
80
- .setReturnTypeResolver ( StandardFunctionReturnTypeResolvers .invariant ( doubleType ) )
81
- .register ();
82
- functionRegistry .namedDescriptorBuilder ( "vector_dims" )
83
- .setArgumentsValidator ( StandardArgumentsValidators .composite (
84
- StandardArgumentsValidators .exactly ( 1 ),
85
- VectorArgumentValidator .INSTANCE
86
- ) )
87
- .setArgumentTypeResolver ( VectorArgumentTypeResolver .INSTANCE )
88
- .setReturnTypeResolver ( StandardFunctionReturnTypeResolvers .invariant ( integerType ) )
89
- .register ();
90
- functionRegistry .namedDescriptorBuilder ( "vector_norm" )
91
- .setArgumentsValidator ( StandardArgumentsValidators .composite (
92
- StandardArgumentsValidators .exactly ( 1 ),
93
- VectorArgumentValidator .INSTANCE
94
- ) )
95
- .setArgumentTypeResolver ( VectorArgumentTypeResolver .INSTANCE )
96
- .setReturnTypeResolver ( StandardFunctionReturnTypeResolvers .invariant ( doubleType ) )
97
- .register ();
35
+ registerVectorDistanceFunction (functionRegistry , "l1_distance" , "vector_distance(?1, ?2, MANHATTAN)" , doubleType );
36
+ functionRegistry .registerAlternateKey ("taxicab_distance" , "l1_distance" );
37
+
38
+ registerVectorDistanceFunction (functionRegistry , "negative_inner_product" , "vector_distance(?1, ?2, DOT)" , doubleType );
39
+ registerVectorDistanceFunction (functionRegistry , "inner_product" , "vector_distance(?1, ?2, DOT)*-1" , doubleType );
40
+ registerVectorDistanceFunction (functionRegistry , "hamming_distance" , "vector_distance(?1, ?2, HAMMING)" , doubleType );
41
+
42
+ registerNamedVectorFunction (functionRegistry , "vector_dims" , integerType , 1 );
43
+ registerNamedVectorFunction (functionRegistry , "vector_norm" , doubleType , 1 );
98
44
}
99
45
}
100
46
47
+ private void registerVectorDistanceFunction (
48
+ SqmFunctionRegistry functionRegistry ,
49
+ String functionName ,
50
+ String pattern ,
51
+ BasicType <?> returnType ) {
52
+
53
+ functionRegistry .patternDescriptorBuilder (functionName , pattern )
54
+ .setArgumentsValidator (StandardArgumentsValidators .composite (
55
+ StandardArgumentsValidators .exactly (2 ),
56
+ VectorArgumentValidator .INSTANCE
57
+ ))
58
+ .setArgumentTypeResolver (VectorArgumentTypeResolver .INSTANCE )
59
+ .setReturnTypeResolver (StandardFunctionReturnTypeResolvers .invariant (returnType ))
60
+ .register ();
61
+ }
62
+
63
+ private void registerNamedVectorFunction (
64
+ SqmFunctionRegistry functionRegistry ,
65
+ String functionName ,
66
+ BasicType <?> returnType ,
67
+ int argumentCount ) {
68
+
69
+ functionRegistry .namedDescriptorBuilder (functionName )
70
+ .setArgumentsValidator (StandardArgumentsValidators .composite (
71
+ StandardArgumentsValidators .exactly (argumentCount ),
72
+ VectorArgumentValidator .INSTANCE
73
+ ))
74
+ .setArgumentTypeResolver (VectorArgumentTypeResolver .INSTANCE )
75
+ .setReturnTypeResolver (StandardFunctionReturnTypeResolvers .invariant (returnType ))
76
+ .register ();
77
+ }
78
+
101
79
@ Override
102
80
public int ordinal () {
103
81
return 200 ;
0 commit comments