Skip to content

Commit eebb305

Browse files
committed
HHH-17357 Add hibernate-types module with pgvector support
1 parent 0291006 commit eebb305

22 files changed

+1178
-89
lines changed

docker_db.sh

+4
Original file line numberDiff line numberDiff line change
@@ -147,22 +147,26 @@ postgresql() {
147147
postgresql_12() {
148148
$CONTAINER_CLI rm -f postgres || true
149149
$CONTAINER_CLI run --name postgres -e POSTGRES_USER=hibernate_orm_test -e POSTGRES_PASSWORD=hibernate_orm_test -e POSTGRES_DB=hibernate_orm_test -p5432:5432 -d docker.io/postgis/postgis:12-3.4
150+
$CONTAINER_CLI exec postgres bash -c '/usr/share/postgresql-common/pgdg/apt.postgresql.org.sh -y && apt install postgresql-12-pgvector && psql -U hibernate_orm_test -d hibernate_orm_test -c "create extension vector;"'
150151
}
151152

152153
postgresql_13() {
153154
$CONTAINER_CLI rm -f postgres || true
154155
$CONTAINER_CLI run --name postgres -e POSTGRES_USER=hibernate_orm_test -e POSTGRES_PASSWORD=hibernate_orm_test -e POSTGRES_DB=hibernate_orm_test -p5432:5432 -d docker.io/postgis/postgis:13-3.1
156+
$CONTAINER_CLI exec postgres bash -c '/usr/share/postgresql-common/pgdg/apt.postgresql.org.sh -y && apt install postgresql-13-pgvector && psql -U hibernate_orm_test -d hibernate_orm_test -c "create extension vector;"'
155157
}
156158

157159
postgresql_14() {
158160
$CONTAINER_CLI rm -f postgres || true
159161
$CONTAINER_CLI run --name postgres -e POSTGRES_USER=hibernate_orm_test -e POSTGRES_PASSWORD=hibernate_orm_test -e POSTGRES_DB=hibernate_orm_test -p5432:5432 -d docker.io/postgis/postgis:14-3.3
162+
$CONTAINER_CLI exec postgres bash -c '/usr/share/postgresql-common/pgdg/apt.postgresql.org.sh -y && apt install postgresql-14-pgvector && psql -U hibernate_orm_test -d hibernate_orm_test -c "create extension vector;"'
160163
}
161164

162165
postgresql_15() {
163166
$CONTAINER_CLI rm -f postgres || true
164167
$CONTAINER_CLI run --name postgres -e POSTGRES_USER=hibernate_orm_test -e POSTGRES_PASSWORD=hibernate_orm_test -e POSTGRES_DB=hibernate_orm_test -p5432:5432 --tmpfs /pgtmpfs:size=131072k -d docker.io/postgis/postgis:15-3.3 \
165168
-c fsync=off -c synchronous_commit=off -c full_page_writes=off -c shared_buffers=256MB -c maintenance_work_mem=256MB -c max_wal_size=1GB -c checkpoint_timeout=1d
169+
$CONTAINER_CLI exec postgres bash -c '/usr/share/postgresql-common/pgdg/apt.postgresql.org.sh -y && apt install postgresql-15-pgvector && psql -U hibernate_orm_test -d hibernate_orm_test -c "create extension vector;"'
166170
}
167171

168172
edb() {

documentation/src/main/asciidoc/userguide/Hibernate_User_Guide.adoc

+1
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ include::chapters/query/hql/QueryLanguage.adoc[]
3434
include::chapters/query/criteria/Criteria.adoc[]
3535
include::chapters/query/native/Native.adoc[]
3636
include::chapters/query/spatial/Spatial.adoc[]
37+
include::chapters/query/types/TypesModule.adoc[]
3738
include::chapters/multitenancy/MultiTenancy.adoc[]
3839
include::chapters/envers/Envers.adoc[]
3940
include::chapters/beans/Beans.adoc[]
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,179 @@
1+
[[types-module]]
2+
== Hibernate Types module
3+
:root-project-dir: ../../../../../../../..
4+
:types-project-dir: {root-project-dir}/hibernate-types
5+
:example-dir-types: {types-project-dir}/src/test/java/org/hibernate/types
6+
:extrasdir: extras
7+
8+
[[types-module-overview]]
9+
=== Overview
10+
11+
The Hibernate ORM core module tries to be as minimal as possible and only model functionality
12+
that is somewhat "standard" in the SQL space or can only be modeled as part of the core module.
13+
To avoid growing that module further unnecessarily, support for certain special SQL types or functions
14+
is separated out into the Hibernate ORM types module.
15+
16+
[[types-module-setup]]
17+
=== Setup
18+
19+
You need to include the `hibernate-types` dependency in your build environment.
20+
For Maven, you need to add the following dependency:
21+
22+
[[types-module-setup-maven-example]]
23+
.Maven dependency
24+
====
25+
[source,xml]
26+
----
27+
<dependency>
28+
<groupId>org.hibernate.orm</groupId>
29+
<artifactId>hibernate-types</artifactId>
30+
<version>${hibernate.version}</version>
31+
</dependency>
32+
----
33+
====
34+
35+
The module contains service implementations that are picked up by the Java `ServiceLoader` automatically,
36+
so no further configuration is necessary to make the features available.
37+
38+
[[types-module-vector]]
39+
=== Vector type support
40+
41+
The Hibernate ORM types module comes with support for a special `vector` data type that essentially represents an array of floats.
42+
43+
So far, only the PostgreSQL extension `pgvector` is supported, but in theory,
44+
the vector specific functions could be implemented to work with every database that supports arrays.
45+
46+
For further details, refer to the https://github.com/pgvector/pgvector#querying[pgvector documentation].
47+
48+
[[types-module-vector-usage]]
49+
==== Usage
50+
51+
Annotate a persistent attribute with `@JdbcTypeCode(SqlTypes.VECTOR)` and specify the vector length with `@Array(length = ...)`.
52+
53+
[[types-module-vector-usage-example]]
54+
====
55+
[source, JAVA, indent=0]
56+
----
57+
include::{example-dir-types}/vector/PGVectorTest.java[tags=usage-example]
58+
----
59+
====
60+
61+
To cast the string representation of a vector to the vector data type, simply use an HQL cast i.e. `cast('[1,2,3]' as vector)`.
62+
63+
[[types-module-vector-functions]]
64+
==== Functions
65+
66+
Expressions of the vector type can be used with various vector functions.
67+
68+
[[types-module-vector-functions-overview]]
69+
|===
70+
| Function | Purpose
71+
72+
| `cosine_distance()` | Computes the https://en.wikipedia.org/wiki/Cosine_similarity[cosine distance] between two vectors. Maps to the `<``=``>` operator
73+
| `euclidean_distance()` | Computes the https://en.wikipedia.org/wiki/Euclidean_distance[euclidean distance] between two vectors. Maps to the `<``-``>` operator
74+
| `l2_distance()` | Alias for `euclidean_distance()`
75+
| `taxicab_distance()` | Computes the https://en.wikipedia.org/wiki/Taxicab_geometry[taxicab distance] between two vectors
76+
| `l1_distance()` | Alias for `taxicab_distance()`
77+
| `inner_product()` | Computes the https://en.wikipedia.org/wiki/Inner_product_space[inner product] between two vectors
78+
| `negative_inner_product()` | Computes the negative inner product. Maps to the `<``#``>` operator
79+
| `vector_dims()` | Determines the dimensions of a vector
80+
| `vector_norm()` | Computes the https://en.wikipedia.org/wiki/Euclidean_space#Euclidean_norm[Euclidean norm] of a vector
81+
|===
82+
83+
In addition to these special vector functions, it is also possible to use vectors with the following builtin operators
84+
85+
`<vector1> + <vector2> = <vector3>`:: Element-wise addition of vectors.
86+
`<vector1> - <vector2> = <vector3>`:: Element-wise subtraction of vectors.
87+
`<vector1> * <vector2> = <vector3>`:: Element-wise multiplication of vectors.
88+
`sum(<vector1>) = <vector2>`:: Aggregate function support for element-wise summation of vectors.
89+
`avg(<vector1>) = <vector2>`:: Aggregate function support for element-wise average of vectors.
90+
91+
[[types-module-vector-functions-cosine-distance]]
92+
===== `cosine_distance()`
93+
94+
Computes the https://en.wikipedia.org/wiki/Cosine_similarity[cosine distance] between two vectors,
95+
which is `1 - inner_product( v1, v2 ) / ( vector_norm( v1 ) * vector_norm( v2 ) )`. Maps to the `<``=``>` pgvector operator.
96+
97+
[[types-module-vector-functions-cosine-distance-example]]
98+
====
99+
[source, JAVA, indent=0]
100+
----
101+
include::{example-dir-types}/vector/PGVectorTest.java[tags=cosine-distance-example]
102+
----
103+
====
104+
105+
[[types-module-vector-functions-euclidean-distance]]
106+
===== `euclidean_distance()` and `l2_distance()`
107+
108+
Computes the https://en.wikipedia.org/wiki/Euclidean_distance[euclidean distance] between two vectors,
109+
which is `sqrt( sum( (v1_i - v2_i)^2 ) )`. Maps to the `<``-``>` pgvector operator.
110+
The `l2_distance()` function is an alias.
111+
112+
[[types-module-vector-functions-euclidean-distance-example]]
113+
====
114+
[source, JAVA, indent=0]
115+
----
116+
include::{example-dir-types}/vector/PGVectorTest.java[tags=euclidean-distance-example]
117+
----
118+
====
119+
120+
[[types-module-vector-functions-taxicab-distance]]
121+
===== `taxicab_distance()` and `l1_distance()`
122+
123+
Computes the https://en.wikipedia.org/wiki/Taxicab_geometry[taxicab distance] between two vectors,
124+
which is `vector_norm(v1) - vector_norm(v2)`.
125+
The `l1_distance()` function is an alias.
126+
127+
[[types-module-vector-functions-taxicab-distance-example]]
128+
====
129+
[source, JAVA, indent=0]
130+
----
131+
include::{example-dir-types}/vector/PGVectorTest.java[tags=taxicab-distance-example]
132+
----
133+
====
134+
135+
[[types-module-vector-functions-inner-product]]
136+
===== `inner_product()` and `negative_inner_product()`
137+
138+
Computes the https://en.wikipedia.org/wiki/Inner_product_space[inner product] between two vectors,
139+
which is `sum( v1_i * v2_i )`. The `negative_inner_product()` function maps to the `<``#``>` pgvector operator,
140+
and the `inner_product()` function as well, but multiplies the result time `-1`.
141+
142+
[[types-module-vector-functions-inner-product-example]]
143+
====
144+
[source, JAVA, indent=0]
145+
----
146+
include::{example-dir-types}/vector/PGVectorTest.java[tags=inner-product-example]
147+
----
148+
====
149+
150+
[[types-module-vector-functions-vector-dims]]
151+
===== `vector_dims()`
152+
153+
Determines the dimensions of a vector.
154+
155+
[[types-module-vector-functions-vector-dims-example]]
156+
====
157+
[source, JAVA, indent=0]
158+
----
159+
include::{example-dir-types}/vector/PGVectorTest.java[tags=vector-dims-example]
160+
----
161+
====
162+
163+
[[types-module-vector-functions-vector-norm]]
164+
===== `vector_norm()`
165+
166+
Computes the https://en.wikipedia.org/wiki/Euclidean_space#Euclidean_norm[Euclidean norm] of a vector,
167+
which is `sqrt( sum( v_i^2 ) )`.
168+
169+
[[types-module-vector-functions-vector-norm-example]]
170+
====
171+
[source, JAVA, indent=0]
172+
----
173+
include::{example-dir-types}/vector/PGVectorTest.java[tags=vector-norm-example]
174+
----
175+
====
176+
177+
178+
179+

hibernate-core/src/main/java/org/hibernate/dialect/function/AvgFunction.java

+130-7
Original file line numberDiff line numberDiff line change
@@ -8,16 +8,23 @@
88

99
import java.util.Arrays;
1010
import java.util.List;
11+
import java.util.Locale;
12+
import java.util.function.Supplier;
1113

1214
import org.hibernate.dialect.Dialect;
15+
import org.hibernate.metamodel.mapping.BasicValuedMapping;
1316
import org.hibernate.metamodel.mapping.JdbcMapping;
17+
import org.hibernate.metamodel.mapping.MappingModelExpressible;
18+
import org.hibernate.metamodel.model.domain.DomainType;
1419
import org.hibernate.query.ReturnableType;
20+
import org.hibernate.query.sqm.SqmExpressible;
1521
import org.hibernate.query.sqm.function.AbstractSqmSelfRenderingFunctionDescriptor;
1622
import org.hibernate.query.sqm.function.FunctionKind;
17-
import org.hibernate.query.sqm.produce.function.ArgumentTypesValidator;
18-
import org.hibernate.query.sqm.produce.function.StandardArgumentsValidators;
23+
import org.hibernate.query.sqm.produce.function.ArgumentsValidator;
24+
import org.hibernate.query.sqm.produce.function.FunctionArgumentException;
25+
import org.hibernate.query.sqm.produce.function.FunctionReturnTypeResolver;
1926
import org.hibernate.query.sqm.produce.function.StandardFunctionArgumentTypeResolvers;
20-
import org.hibernate.query.sqm.produce.function.StandardFunctionReturnTypeResolvers;
27+
import org.hibernate.query.sqm.tree.SqmTypedNode;
2128
import org.hibernate.sql.ast.Clause;
2229
import org.hibernate.sql.ast.SqlAstNodeRenderingMode;
2330
import org.hibernate.sql.ast.SqlAstTranslator;
@@ -27,8 +34,14 @@
2734
import org.hibernate.sql.ast.tree.expression.Distinct;
2835
import org.hibernate.sql.ast.tree.expression.Expression;
2936
import org.hibernate.sql.ast.tree.predicate.Predicate;
37+
import org.hibernate.type.BasicPluralType;
3038
import org.hibernate.type.BasicType;
39+
import org.hibernate.type.SqlTypes;
3140
import org.hibernate.type.StandardBasicTypes;
41+
import org.hibernate.type.descriptor.java.JavaType;
42+
import org.hibernate.type.descriptor.jdbc.ArrayJdbcType;
43+
import org.hibernate.type.descriptor.jdbc.JdbcType;
44+
import org.hibernate.type.descriptor.jdbc.ObjectJdbcType;
3245
import org.hibernate.type.spi.TypeConfiguration;
3346

3447
import static org.hibernate.query.sqm.produce.function.FunctionParameterType.NUMERIC;
@@ -49,10 +62,8 @@ public AvgFunction(
4962
super(
5063
"avg",
5164
FunctionKind.AGGREGATE,
52-
new ArgumentTypesValidator( StandardArgumentsValidators.exactly( 1 ), NUMERIC ),
53-
StandardFunctionReturnTypeResolvers.invariant(
54-
typeConfiguration.getBasicTypeRegistry().resolve( StandardBasicTypes.DOUBLE )
55-
),
65+
new Validator(),
66+
new ReturnTypeResolver( typeConfiguration ),
5667
StandardFunctionArgumentTypeResolvers.invariant( typeConfiguration, NUMERIC )
5768
);
5869
this.defaultArgumentRenderingMode = defaultArgumentRenderingMode;
@@ -131,4 +142,116 @@ public String getArgumentListSignature() {
131142
return "(NUMERIC arg)";
132143
}
133144

145+
public static class Validator implements ArgumentsValidator {
146+
147+
public static final ArgumentsValidator INSTANCE = new Validator();
148+
149+
@Override
150+
public void validate(
151+
List<? extends SqmTypedNode<?>> arguments,
152+
String functionName,
153+
TypeConfiguration typeConfiguration) {
154+
if ( arguments.size() != 1 ) {
155+
throw new FunctionArgumentException(
156+
String.format(
157+
Locale.ROOT,
158+
"Function %s() has %d parameters, but %d arguments given",
159+
functionName,
160+
1,
161+
arguments.size()
162+
)
163+
);
164+
}
165+
final SqmTypedNode<?> argument = arguments.get( 0 );
166+
final SqmExpressible<?> expressible = argument.getExpressible();
167+
final DomainType<?> domainType;
168+
if ( expressible != null && ( domainType = expressible.getSqmType() ) != null ) {
169+
final JdbcType jdbcType = getJdbcType( domainType, typeConfiguration );
170+
if ( !isNumeric( jdbcType ) ) {
171+
throw new FunctionArgumentException(
172+
String.format(
173+
"Parameter %d of function '%s()' has type '%s', but argument is of type '%s'",
174+
1,
175+
functionName,
176+
NUMERIC,
177+
domainType.getTypeName()
178+
)
179+
);
180+
}
181+
}
182+
}
183+
184+
private static boolean isNumeric(JdbcType jdbcType) {
185+
final int sqlTypeCode = jdbcType.getDefaultSqlTypeCode();
186+
if ( SqlTypes.isNumericType( sqlTypeCode ) ) {
187+
return true;
188+
}
189+
if ( jdbcType instanceof ArrayJdbcType ) {
190+
return isNumeric( ( (ArrayJdbcType) jdbcType ).getElementJdbcType() );
191+
}
192+
return false;
193+
}
194+
195+
private static JdbcType getJdbcType(DomainType<?> domainType, TypeConfiguration typeConfiguration) {
196+
if ( domainType instanceof JdbcMapping ) {
197+
return ( (JdbcMapping) domainType ).getJdbcType();
198+
}
199+
else {
200+
final JavaType<?> javaType = domainType.getExpressibleJavaType();
201+
if ( javaType.getJavaTypeClass().isEnum() ) {
202+
// we can't tell if the enum is mapped STRING or ORDINAL
203+
return ObjectJdbcType.INSTANCE;
204+
}
205+
else {
206+
return javaType.getRecommendedJdbcType( typeConfiguration.getCurrentBaseSqlTypeIndicators() );
207+
}
208+
}
209+
}
210+
211+
@Override
212+
public String getSignature() {
213+
return "(arg)";
214+
}
215+
}
216+
217+
public static class ReturnTypeResolver implements FunctionReturnTypeResolver {
218+
219+
private final BasicType<Double> doubleType;
220+
221+
public ReturnTypeResolver(TypeConfiguration typeConfiguration) {
222+
this.doubleType = typeConfiguration.getBasicTypeRegistry().resolve( StandardBasicTypes.DOUBLE );
223+
}
224+
225+
@Override
226+
public BasicValuedMapping resolveFunctionReturnType(
227+
Supplier<BasicValuedMapping> impliedTypeAccess,
228+
List<? extends SqlAstNode> arguments) {
229+
final BasicValuedMapping impliedType = impliedTypeAccess.get();
230+
if ( impliedType != null ) {
231+
return impliedType;
232+
}
233+
final JdbcMapping jdbcMapping = ( (Expression) arguments.get( 0 ) ).getExpressionType().getSingleJdbcMapping();
234+
if ( jdbcMapping instanceof BasicPluralType<?, ?> ) {
235+
return (BasicValuedMapping) jdbcMapping;
236+
}
237+
return doubleType;
238+
}
239+
240+
@Override
241+
public ReturnableType<?> resolveFunctionReturnType(
242+
ReturnableType<?> impliedType,
243+
Supplier<MappingModelExpressible<?>> inferredTypeSupplier,
244+
List<? extends SqmTypedNode<?>> arguments,
245+
TypeConfiguration typeConfiguration) {
246+
final SqmExpressible<?> expressible = arguments.get( 0 ).getExpressible();
247+
final DomainType<?> domainType;
248+
if ( expressible != null && ( domainType = expressible.getSqmType() ) != null ) {
249+
if ( domainType instanceof BasicPluralType<?, ?> ) {
250+
return (ReturnableType<?>) domainType;
251+
}
252+
}
253+
return typeConfiguration.getBasicTypeRegistry().resolve( StandardBasicTypes.DOUBLE );
254+
}
255+
}
256+
134257
}

0 commit comments

Comments
 (0)