Skip to content

Commit fc546e4

Browse files
committed
Add support for pgvector.
[closes #612] Signed-off-by: Mark Paluch <[email protected]>
1 parent d193789 commit fc546e4

File tree

9 files changed

+949
-54
lines changed

9 files changed

+949
-54
lines changed

README.md

+50-48
Original file line numberDiff line numberDiff line change
@@ -425,54 +425,55 @@ When available, the driver registers also an array variant of the codec.
425425

426426
This reference table shows the type mapping between [PostgreSQL][p] and Java data types:
427427

428-
| PostgreSQL Type | Supported Data Type |
429-
|:------------------------------------------------|:----------------------------------------------------------------------------------------------------------------------------------------------|
430-
| [`bigint`][psql-bigint-ref] | [**`Long`**][java-long-ref], [`Boolean`][java-boolean-ref], [`Byte`][java-byte-ref], [`Short`][java-short-ref], [`Integer`][java-integer-ref], [`BigDecimal`][java-bigdecimal-ref], [`BigInteger`][java-biginteger-ref] |
431-
| [`bit`][psql-bit-ref] | Not yet supported.|
432-
| [`bit varying`][psql-bit-ref] | Not yet supported.|
433-
| [`boolean or bool`][psql-boolean-ref] | [`Boolean`][java-boolean-ref]|
434-
| [`box`][psql-box-ref] | **`Box`**|
435-
| [`bytea`][psql-bytea-ref] | [**`ByteBuffer`**][java-ByteBuffer-ref], [`byte[]`][java-byte-ref], [`Blob`][r2dbc-blob-ref]|
436-
| [`character`][psql-character-ref] | [`String`][java-string-ref]|
437-
| [`character varying`][psql-character-ref] | [`String`][java-string-ref]|
438-
| [`cidr`][psql-cidr-ref] | Not yet supported.|
439-
| [`circle`][psql-circle-ref] | **`Circle`**|
440-
| [`date`][psql-date-ref] | [`LocalDate`][java-ld-ref]|
441-
| [`double precision`][psql-floating-point-ref] | [**`Double`**][java-double-ref], [`Float`][java-float-ref], [`Boolean`][java-boolean-ref], [`Byte`][java-byte-ref], [`Short`][java-short-ref], [`Integer`][java-integer-ref], [`Long`][java-long-ref], [`BigDecimal`][java-bigdecimal-ref], [`BigInteger`][java-biginteger-ref]|
442-
| [enumerated types][psql-enum-ref] | Client code `Enum` types through `EnumCodec`|
443-
| [`geometry`][postgis-ref] | **`org.locationtech.jts.geom.Geometry`**|
444-
| [`hstore`][psql-hstore-ref] | [**`Map`**][java-map-ref]|
445-
| [`inet`][psql-inet-ref] | [**`InetAddress`**][java-inet-ref]|
446-
| [`integer`][psql-integer-ref] | [**`Integer`**][java-integer-ref], [`Boolean`][java-boolean-ref], [`Byte`][java-byte-ref], [`Short`][java-short-ref], [`Long`][java-long-ref], [`BigDecimal`][java-bigdecimal-ref], [`BigInteger`][java-biginteger-ref]|
447-
| [`interval`][psql-interval-ref] | **`Interval`**|
448-
| [`json`][psql-json-ref] | **`Json`**, [`String`][java-string-ref]. Reading: `ByteBuf`[`byte[]`][java-primitive-ref], [`ByteBuffer`][java-ByteBuffer-ref], [`String`][java-string-ref], [`InputStream`][java-inputstream-ref]|
449-
| [`jsonb`][psql-json-ref] | **`Json`**, [`String`][java-string-ref]. Reading: `ByteBuf`[`byte[]`][java-primitive-ref], [`ByteBuffer`][java-ByteBuffer-ref], [`String`][java-string-ref], [`InputStream`][java-inputstream-ref]|
450-
| [`line`][psql-line-ref] | **`Line`**|
451-
| [`lseg`][psql-lseq-ref] | **`Lseg`**|
452-
| [`macaddr`][psql-macaddr-ref] | Not yet supported.|
453-
| [`macaddr8`][psql-macaddr8-ref] | Not yet supported.|
454-
| [`money`][psql-money-ref] | Not yet supported. Please don't use this type. It is a very poor implementation. |
455-
| [`name`][psql-name-ref] | [**`String`**][java-string-ref]
456-
| [`numeric`][psql-bignumeric-ref] | [`BigDecimal`][java-bigdecimal-ref], [`Boolean`][java-boolean-ref], [`Byte`][java-byte-ref], [`Short`][java-short-ref], [`Integer`][java-integer-ref], [`Long`][java-long-ref], [`BigInteger`][java-biginteger-ref]|
457-
| [`oid`][psql-oid-ref] | [**`Integer`**][java-integer-ref], [`Boolean`][java-boolean-ref], [`Byte`][java-byte-ref], [`Short`][java-short-ref], [`Long`][java-long-ref], [`BigDecimal`][java-bigdecimal-ref], [`BigInteger`][java-biginteger-ref]|
458-
| [`path`][psql-path-ref] | **`Path`**|
459-
| [`pg_lsn`][psql-pg_lsn-ref] | Not yet supported.|
460-
| [`point`][psql-point-ref] | **`Point`**|
461-
| [`polygon`][psql-polygon-ref] | **`Polygon`**|
462-
| [`real`][psql-real-ref] | [**`Float`**][java-float-ref], [`Double`][java-double-ref], [`Boolean`][java-boolean-ref], [`Byte`][java-byte-ref], [`Short`][java-short-ref], [`Integer`][java-integer-ref], [`Long`][java-long-ref], [`BigDecimal`][java-bigdecimal-ref], [`BigInteger`][java-biginteger-ref]|
463-
| [`smallint`][psql-smallint-ref] | [**`Short`**][java-short-ref], [`Boolean`][java-boolean-ref], [`Byte`][java-byte-ref], [`Integer`][java-integer-ref], [`Long`][java-long-ref], [`BigDecimal`][java-bigdecimal-ref], [`BigInteger`][java-biginteger-ref]|
464-
| [`smallserial`][psql-smallserial-ref] | [**`Integer`**][java-integer-ref], [`Boolean`][java-boolean-ref], [`Byte`][java-byte-ref], [`Short`][java-short-ref], [`Long`][java-long-ref], [`BigDecimal`][java-bigdecimal-ref], [`BigInteger`][java-biginteger-ref]|
465-
| [`serial`][psql-serial-ref] | [**`Long`**][java-long-ref], [`Boolean`][java-boolean-ref], [`Byte`][java-byte-ref], [`Short`][java-short-ref], [`Integer`][java-integer-ref], [`BigDecimal`][java-bigdecimal-ref], [`BigInteger`][java-biginteger-ref]|
466-
| [`text`][psql-text-ref] | [**`String`**][java-string-ref], [`Clob`][r2dbc-clob-ref]|
467-
| [`time [without time zone]`][psql-time-ref] | [`LocalTime`][java-lt-ref]|
468-
| [`time [with time zone]`][psql-time-ref] | [`OffsetTime`][java-ot-ref]|
469-
| [`timestamp [without time zone]`][psql-time-ref]|[**`LocalDateTime`**][java-ldt-ref], [`LocalTime`][java-lt-ref], [`LocalDate`][java-ld-ref], [`java.util.Date`][java-legacy-date-ref]|
470-
| [`timestamp [with time zone]`][psql-time-ref] | [**`OffsetDatetime`**][java-odt-ref], [`ZonedDateTime`][java-zdt-ref], [`Instant`][java-instant-ref]|
471-
| [`tsquery`][psql-tsquery-ref] | Not yet supported.|
472-
| [`tsvector`][psql-tsvector-ref] | Not yet supported.|
473-
| [`txid_snapshot`][psql-txid_snapshot-ref] | Not yet supported.|
474-
| [`uuid`][psql-uuid-ref] | [**`UUID`**][java-uuid-ref], [`String`][java-string-ref]||
475-
| [`xml`][psql-xml-ref] | Not yet supported. |
428+
| PostgreSQL Type | Supported Data Type |
429+
|:-------------------------------------------------|:--------------------------------------------------------------------------------------------------------------------------------------------|
430+
| [`bigint`][psql-bigint-ref] | [**`Long`**][java-long-ref], [`Boolean`][java-boolean-ref], [`Byte`][java-byte-ref], [`Short`][java-short-ref], [`Integer`][java-integer-ref], [`BigDecimal`][java-bigdecimal-ref], [`BigInteger`][java-biginteger-ref] |
431+
| [`bit`][psql-bit-ref] | Not yet supported.|
432+
| [`bit varying`][psql-bit-ref] | Not yet supported.|
433+
| [`boolean or bool`][psql-boolean-ref] | [`Boolean`][java-boolean-ref]|
434+
| [`box`][psql-box-ref] | **`Box`**|
435+
| [`bytea`][psql-bytea-ref] | [**`ByteBuffer`**][java-ByteBuffer-ref], [`byte[]`][java-byte-ref], [`Blob`][r2dbc-blob-ref]|
436+
| [`character`][psql-character-ref] | [`String`][java-string-ref]|
437+
| [`character varying`][psql-character-ref] | [`String`][java-string-ref]|
438+
| [`cidr`][psql-cidr-ref] | Not yet supported.|
439+
| [`circle`][psql-circle-ref] | **`Circle`**|
440+
| [`date`][psql-date-ref] | [`LocalDate`][java-ld-ref]|
441+
| [`double precision`][psql-floating-point-ref] | [**`Double`**][java-double-ref], [`Float`][java-float-ref], [`Boolean`][java-boolean-ref], [`Byte`][java-byte-ref], [`Short`][java-short-ref], [`Integer`][java-integer-ref], [`Long`][java-long-ref], [`BigDecimal`][java-bigdecimal-ref], [`BigInteger`][java-biginteger-ref]|
442+
| [enumerated types][psql-enum-ref] | Client code `Enum` types through `EnumCodec`|
443+
| [`geometry`][postgis-ref] | **`org.locationtech.jts.geom.Geometry`**|
444+
| [`hstore`][psql-hstore-ref] | [**`Map`**][java-map-ref]|
445+
| [`inet`][psql-inet-ref] | [**`InetAddress`**][java-inet-ref]|
446+
| [`integer`][psql-integer-ref] | [**`Integer`**][java-integer-ref], [`Boolean`][java-boolean-ref], [`Byte`][java-byte-ref], [`Short`][java-short-ref], [`Long`][java-long-ref], [`BigDecimal`][java-bigdecimal-ref], [`BigInteger`][java-biginteger-ref]|
447+
| [`interval`][psql-interval-ref] | **`Interval`**|
448+
| [`json`][psql-json-ref] | **`Json`**, [`String`][java-string-ref]. Reading: `ByteBuf`, [`byte[]`][java-primitive-ref], [`ByteBuffer`][java-ByteBuffer-ref], [`String`][java-string-ref], [`InputStream`][java-inputstream-ref]|
449+
| [`jsonb`][psql-json-ref] | **`Json`**, [`String`][java-string-ref]. Reading: `ByteBuf`, [`byte[]`][java-primitive-ref], [`ByteBuffer`][java-ByteBuffer-ref], [`String`][java-string-ref], [`InputStream`][java-inputstream-ref]|
450+
| [`line`][psql-line-ref] | **`Line`**|
451+
| [`lseg`][psql-lseq-ref] | **`Lseg`**|
452+
| [`macaddr`][psql-macaddr-ref] | Not yet supported.|
453+
| [`macaddr8`][psql-macaddr8-ref] | Not yet supported.|
454+
| [`money`][psql-money-ref] | Not yet supported. Please don't use this type. It is a very poor implementation. |
455+
| [`name`][psql-name-ref] | [**`String`**][java-string-ref]|
456+
| [`numeric`][psql-bignumeric-ref] | [`BigDecimal`][java-bigdecimal-ref], [`Boolean`][java-boolean-ref], [`Byte`][java-byte-ref], [`Short`][java-short-ref], [`Integer`][java-integer-ref], [`Long`][java-long-ref], [`BigInteger`][java-biginteger-ref]|
457+
| [`oid`][psql-oid-ref] | [**`Integer`**][java-integer-ref], [`Boolean`][java-boolean-ref], [`Byte`][java-byte-ref], [`Short`][java-short-ref], [`Long`][java-long-ref], [`BigDecimal`][java-bigdecimal-ref], [`BigInteger`][java-biginteger-ref]|
458+
| [`path`][psql-path-ref] | **`Path`**|
459+
| [`pg_lsn`][psql-pg_lsn-ref] | Not yet supported.|
460+
| [`point`][psql-point-ref] | **`Point`**|
461+
| [`polygon`][psql-polygon-ref] | **`Polygon`**|
462+
| [`real`][psql-real-ref] | [**`Float`**][java-float-ref], [`Double`][java-double-ref], [`Boolean`][java-boolean-ref], [`Byte`][java-byte-ref], [`Short`][java-short-ref], [`Integer`][java-integer-ref], [`Long`][java-long-ref], [`BigDecimal`][java-bigdecimal-ref], [`BigInteger`][java-biginteger-ref]|
463+
| [`smallint`][psql-smallint-ref] | [**`Short`**][java-short-ref], [`Boolean`][java-boolean-ref], [`Byte`][java-byte-ref], [`Integer`][java-integer-ref], [`Long`][java-long-ref], [`BigDecimal`][java-bigdecimal-ref], [`BigInteger`][java-biginteger-ref]|
464+
| [`smallserial`][psql-smallserial-ref] | [**`Integer`**][java-integer-ref], [`Boolean`][java-boolean-ref], [`Byte`][java-byte-ref], [`Short`][java-short-ref], [`Long`][java-long-ref], [`BigDecimal`][java-bigdecimal-ref], [`BigInteger`][java-biginteger-ref]|
465+
| [`serial`][psql-serial-ref] | [**`Long`**][java-long-ref], [`Boolean`][java-boolean-ref], [`Byte`][java-byte-ref], [`Short`][java-short-ref], [`Integer`][java-integer-ref], [`BigDecimal`][java-bigdecimal-ref], [`BigInteger`][java-biginteger-ref]|
466+
| [`text`][psql-text-ref] | [**`String`**][java-string-ref], [`Clob`][r2dbc-clob-ref]|
467+
| [`time [without time zone]`][psql-time-ref] | [`LocalTime`][java-lt-ref]|
468+
| [`time [with time zone]`][psql-time-ref] | [`OffsetTime`][java-ot-ref]|
469+
| [`timestamp [without time zone]`][psql-time-ref] |[**`LocalDateTime`**][java-ldt-ref], [`LocalTime`][java-lt-ref], [`LocalDate`][java-ld-ref], [`java.util.Date`][java-legacy-date-ref]|
470+
| [`timestamp [with time zone]`][psql-time-ref] | [**`OffsetDateTime`**][java-odt-ref], [`ZonedDateTime`][java-zdt-ref], [`Instant`][java-instant-ref]|
471+
| [`tsquery`][psql-tsquery-ref] | Not yet supported.|
472+
| [`tsvector`][psql-tsvector-ref] | Not yet supported.|
473+
| [`txid_snapshot`][psql-txid_snapshot-ref] | Not yet supported.|
474+
| [`uuid`][psql-uuid-ref] | [**`UUID`**][java-uuid-ref], [`String`][java-string-ref]|
475+
| [`xml`][psql-xml-ref] | Not yet supported. |
476+
| [`vector`][psql-vector-ref] | **`Vector`**, [`float[]`][java-float-ref] |
476477

477478
Types in **bold** indicate the native (default) Java type.
478479

@@ -550,6 +551,7 @@ Support for the following single-dimensional arrays (read and write):
550551
[psql-xml-ref]: https://www.postgresql.org/docs/current/datatype-xml.html
551552
[psql-runtime-config]: https://www.postgresql.org/docs/current/runtime-config-client.html
552553
[postgis-ref]: http://postgis.net/workshops/postgis-intro/geometries.html
554+
[psql-vector-ref]: https://github.com/pgvector/pgvector
553555

554556
[r2dbc-blob-ref]: https://r2dbc.io/spec/0.9.0.RELEASE/api/io/r2dbc/spi/Blob.html
555557
[r2dbc-clob-ref]: https://r2dbc.io/spec/0.9.0.RELEASE/api/io/r2dbc/spi/Clob.html

src/main/java/io/r2dbc/postgresql/codec/BuiltinDynamicCodecs.java

+11-6
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
import reactor.util.annotation.Nullable;
2525

2626
import java.util.Arrays;
27+
import java.util.Collections;
2728
import java.util.stream.Collectors;
2829

2930
/**
@@ -44,21 +45,24 @@ enum BuiltinCodec {
4445
public boolean isSupported() {
4546
return this.jtsPresent;
4647
}
47-
};
48+
}, VECTOR("vector");
4849

4950
private final String name;
5051

5152
BuiltinCodec(String name) {
5253
this.name = name;
5354
}
5455

55-
public Codec<?> createCodec(ByteBufAllocator byteBufAllocator, int oid) {
56+
public Iterable<Codec<?>> createCodec(ByteBufAllocator byteBufAllocator, int oid, int typarray) {
5657

5758
switch (this) {
5859
case HSTORE:
59-
return new HStoreCodec(byteBufAllocator, oid);
60+
return Collections.singletonList(new HStoreCodec(byteBufAllocator, oid));
6061
case POSTGIS_GEOMETRY:
61-
return new PostgisGeometryCodec(oid);
62+
return Collections.singletonList(new PostgisGeometryCodec(oid));
63+
case VECTOR:
64+
VectorCodec vectorCodec = new VectorCodec(byteBufAllocator, oid, typarray);
65+
return Arrays.asList(vectorCodec, new VectorCodec.VectorArrayCodec(byteBufAllocator, vectorCodec), new VectorFloatCodec(byteBufAllocator, oid));
6266
default:
6367
throw new UnsupportedOperationException(String.format("Codec %s for OID %d not supported", name(), oid));
6468
}
@@ -93,11 +97,12 @@ public Publisher<Void> register(PostgresqlConnection connection, ByteBufAllocato
9397
.flatMap(it -> it.map((row, rowMetadata) -> {
9498

9599
int oid = PostgresqlObjectId.toInt(row.get("oid", Long.class));
100+
int typarray = PostgresqlObjectId.toInt(row.get("typarray", Long.class));
96101
String typname = row.get("typname", String.class);
97102

98103
BuiltinCodec lookup = BuiltinCodec.lookup(typname);
99104
if (lookup.isSupported()) {
100-
registry.addLast(lookup.createCodec(byteBufAllocator, oid));
105+
lookup.createCodec(byteBufAllocator, oid, typarray).forEach(registry::addLast);
101106
}
102107

103108
return EMPTY;
@@ -106,7 +111,7 @@ public Publisher<Void> register(PostgresqlConnection connection, ByteBufAllocato
106111
}
107112

108113
private PostgresqlStatement createQuery(PostgresqlConnection connection) {
109-
return connection.createStatement(String.format("SELECT oid, typname FROM pg_catalog.pg_type WHERE typname IN (%s)", getPlaceholders()));
114+
return connection.createStatement(String.format("SELECT oid, typname, typarray FROM pg_catalog.pg_type WHERE typname IN (%s)", getPlaceholders()));
110115
}
111116

112117
private static String getPlaceholders() {
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,128 @@
1+
/*
2+
* Copyright 2023 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package io.r2dbc.postgresql.codec;
18+
19+
import io.r2dbc.postgresql.util.Assert;
20+
21+
import java.util.Arrays;
22+
import java.util.Collection;
23+
24+
/**
25+
* Value object that maps to the {@code vector} datatype provided by Postgres pgvector.
26+
*
27+
* @since 1.0.3
28+
*/
29+
public class Vector {
30+
31+
private static final Vector EMPTY = new Vector(new float[0]);
32+
33+
private final float[] vec;
34+
35+
private Vector(float[] vec) {
36+
this.vec = Assert.requireNonNull(vec, "Vector must not be null");
37+
}
38+
39+
/**
40+
* Create a new empty {@link Vector}.
41+
*
42+
* @return the empty {@link Vector} object
43+
*/
44+
public static Vector empty() {
45+
return EMPTY;
46+
}
47+
48+
/**
49+
* Create a new {@link Vector} given {@code vector} points.
50+
*
51+
* @param vec the vector values
52+
* @return the new {@link Vector} object
53+
*/
54+
public static Vector of(float... vec) {
55+
Assert.requireNonNull(vec, "Vector must not be null");
56+
return vec.length == 0 ? empty() : new Vector(vec);
57+
}
58+
59+
/**
60+
* Create a new {@link Vector} given {@code vector} points.
61+
*
62+
* @param vec the vector values
63+
* @return the new {@link Vector} object
64+
*/
65+
public static Vector of(Collection<? extends Number> vec) {
66+
Assert.requireNonNull(vec, "Vector must not be null");
67+
68+
if (vec.isEmpty()) {
69+
return empty();
70+
}
71+
72+
float[] floats = new float[vec.size()];
73+
int index = 0;
74+
for (Number number : vec) {
75+
Number next = Assert.requireNonNull(number, "Vector must not contain null elements");
76+
floats[index++] = next.floatValue();
77+
}
78+
79+
return new Vector(floats);
80+
}
81+
82+
/**
83+
* Return the vector values.
84+
*
85+
* @return the vector values.
86+
*/
87+
public float[] getVector() {
88+
if (this.vec.length == 0) {
89+
return this.vec;
90+
}
91+
float[] copy = new float[this.vec.length];
92+
System.arraycopy(this.vec, 0, copy, 0, this.vec.length);
93+
return copy;
94+
}
95+
96+
@Override
97+
public boolean equals(Object o) {
98+
if (this == o) {
99+
return true;
100+
}
101+
if (o == null || getClass() != o.getClass()) {
102+
return false;
103+
}
104+
Vector other = (Vector) o;
105+
return Arrays.equals(this.vec, other.vec);
106+
}
107+
108+
@Override
109+
public int hashCode() {
110+
return Arrays.hashCode(this.vec);
111+
}
112+
113+
@Override
114+
public String toString() {
115+
StringBuilder builder = new StringBuilder();
116+
builder.append('[');
117+
118+
for (int i = 0; i < this.vec.length; i++) {
119+
if (i != 0) {
120+
builder.append(',');
121+
}
122+
builder.append(this.vec[i]);
123+
}
124+
builder.append(']');
125+
126+
return builder.toString();
127+
}
128+
}

0 commit comments

Comments
 (0)