Skip to content

Commit c604c06

Browse files
committed
Add implementation supporting multi-vectors (no tests yet)
1 parent c7be3f2 commit c604c06

File tree

10 files changed

+783
-140
lines changed

10 files changed

+783
-140
lines changed

src/collections/aggregate/index.ts

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -378,9 +378,9 @@ class AggregateManager<T> implements Aggregate<T> {
378378
if (await this.grpcChecker) {
379379
const group = typeof opts.groupBy === 'string' ? { property: opts.groupBy } : opts.groupBy;
380380
return this.grpc()
381-
.then((aggregate) =>
381+
.then(async (aggregate) =>
382382
aggregate.withHybrid({
383-
...Serialize.aggregate.hybrid(query, opts),
383+
...(await Serialize.aggregate.hybrid(query, opts)),
384384
groupBy: Serialize.aggregate.groupBy(group),
385385
limit: group.limit,
386386
})
@@ -489,9 +489,9 @@ class AggregateManager<T> implements Aggregate<T> {
489489
if (await this.grpcChecker) {
490490
const group = typeof opts.groupBy === 'string' ? { property: opts.groupBy } : opts.groupBy;
491491
return this.grpc()
492-
.then((aggregate) =>
492+
.then(async (aggregate) =>
493493
aggregate.withNearVector({
494-
...Serialize.aggregate.nearVector(vector, opts),
494+
...(await Serialize.aggregate.nearVector(vector, opts)),
495495
groupBy: Serialize.aggregate.groupBy(group),
496496
limit: group.limit,
497497
})
@@ -609,7 +609,7 @@ class AggregateManager<T> implements Aggregate<T> {
609609
): Promise<AggregateResult<T, M>> {
610610
if (await this.grpcChecker) {
611611
return this.grpc()
612-
.then((aggregate) => aggregate.withHybrid(Serialize.aggregate.hybrid(query, opts)))
612+
.then(async (aggregate) => aggregate.withHybrid(await Serialize.aggregate.hybrid(query, opts)))
613613
.then((reply) => Deserialize.aggregate(reply));
614614
}
615615
let builder = this.base(opts?.returnMetrics, opts?.filters).withHybrid({
@@ -696,10 +696,12 @@ class AggregateManager<T> implements Aggregate<T> {
696696
): Promise<AggregateResult<T, M>> {
697697
if (await this.grpcChecker) {
698698
return this.grpc()
699-
.then((aggregate) => aggregate.withNearVector(Serialize.aggregate.nearVector(vector, opts)))
699+
.then(async (aggregate) =>
700+
aggregate.withNearVector(await Serialize.aggregate.nearVector(vector, opts))
701+
)
700702
.then((reply) => Deserialize.aggregate(reply));
701703
}
702-
if (!NearVectorInputGuards.is1DArray(vector)) {
704+
if (!NearVectorInputGuards.is1D(vector)) {
703705
throw new WeaviateInvalidInputError(
704706
'Vector can only be a 1D array of numbers when using `nearVector` with <1.29 Weaviate versions.'
705707
);

src/collections/query/check.ts

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,18 @@ export class Check<T> {
9898
return check.supports;
9999
};
100100

101+
private checkSupportForVectors = async (
102+
vec?: NearVectorInputType | HybridNearVectorSubSearch | HybridNearTextSubSearch
103+
) => {
104+
if (vec === undefined || Serialize.isHybridNearTextSearch(vec)) return false;
105+
if (Serialize.isHybridNearVectorSearch(vec) && !Serialize.isMultiVectorPerTarget(vec.vector))
106+
return false;
107+
if (Serialize.isHybridVectorSearch(vec) && !Serialize.isMultiVectorPerTarget(vec)) return false;
108+
const check = await this.dbVersionSupport.supportsMultiVectorPerTargetSearch();
109+
if (!check.supports) throw new WeaviateUnsupportedFeatureError(check.message);
110+
return check.supports;
111+
};
112+
101113
public nearSearch = (opts?: BaseNearOptions<T>) => {
102114
return Promise.all([
103115
this.getSearcher(),
@@ -118,6 +130,7 @@ export class Check<T> {
118130
this.checkSupportForMultiVectorSearch(vec),
119131
this.checkSupportForMultiVectorPerTargetSearch(vec),
120132
this.checkSupportForMultiWeightPerTargetSearch(opts),
133+
this.checkSupportForVectors(),
121134
this.checkSupportForNamedVectors(opts),
122135
]).then(
123136
([
@@ -126,14 +139,17 @@ export class Check<T> {
126139
supportsMultiVector,
127140
supportsVectorsForTargets,
128141
supportsWeightsForTargets,
142+
supportsVectors,
129143
]) => {
130144
const is126 = supportsMultiTarget || supportsMultiVector;
131145
const is127 = supportsVectorsForTargets || supportsWeightsForTargets;
146+
const is129 = supportsVectors;
132147
return {
133148
search,
134149
supportsTargets: is126 || is127,
135150
supportsVectorsForTargets: is127,
136151
supportsWeightsForTargets: is127,
152+
supportsVectors: is129,
137153
};
138154
}
139155
);
@@ -146,6 +162,7 @@ export class Check<T> {
146162
this.checkSupportForMultiVectorSearch(opts?.vector),
147163
this.checkSupportForMultiVectorPerTargetSearch(opts?.vector),
148164
this.checkSupportForMultiWeightPerTargetSearch(opts),
165+
this.checkSupportForVectors(),
149166
this.checkSupportForNamedVectors(opts),
150167
this.checkSupportForBm25AndHybridGroupByQueries('Hybrid', opts),
151168
this.checkSupportForHybridNearTextAndNearVectorSubSearches(opts),
@@ -156,14 +173,17 @@ export class Check<T> {
156173
supportsMultiVector,
157174
supportsWeightsForTargets,
158175
supportsVectorsForTargets,
176+
supportsVectors,
159177
]) => {
160178
const is126 = supportsMultiTarget || supportsMultiVector;
161179
const is127 = supportsVectorsForTargets || supportsWeightsForTargets;
180+
const is129 = supportsVectors;
162181
return {
163182
search,
164183
supportsTargets: is126 || is127,
165184
supportsWeightsForTargets: is127,
166185
supportsVectorsForTargets: is127,
186+
supportsVectors: is129,
167187
};
168188
}
169189
);

src/collections/query/factories.ts

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
import { ListOfVectors, PrimitiveVectorType } from './types.js';
2+
import { NearVectorInputGuards } from './utils.js';
3+
4+
const hybridVector = {
5+
nearText: () => {},
6+
nearVector: () => {},
7+
};
8+
9+
const nearVector = {
10+
listOfVectors: <V extends PrimitiveVectorType>(...vectors: V[]): ListOfVectors<V> => {
11+
return {
12+
kind: 'listOfVectors',
13+
dimensionality: NearVectorInputGuards.is1D(vectors[0]) ? '1D' : '2D',
14+
vectors,
15+
};
16+
},
17+
};
18+
19+
export const queryFactory = {
20+
hybridVector,
21+
nearVector,
22+
};

src/collections/query/index.ts

Lines changed: 69 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -95,14 +95,25 @@ class QueryManager<T> implements Query<T> {
9595
public hybrid(query: string, opts?: HybridOptions<T>): QueryReturn<T> {
9696
return this.check
9797
.hybridSearch(opts)
98-
.then(({ search, supportsTargets, supportsWeightsForTargets, supportsVectorsForTargets }) =>
99-
search.withHybrid(
100-
Serialize.search.hybrid(
101-
{ query, supportsTargets, supportsWeightsForTargets, supportsVectorsForTargets },
102-
opts
103-
)
104-
)
98+
.then(
99+
async ({
100+
search,
101+
supportsTargets,
102+
supportsWeightsForTargets,
103+
supportsVectorsForTargets,
104+
supportsVectors,
105+
}) => ({
106+
search,
107+
args: await Serialize.search.hybrid({
108+
query,
109+
supportsTargets,
110+
supportsWeightsForTargets,
111+
supportsVectorsForTargets,
112+
supportsVectors,
113+
}),
114+
})
105115
)
116+
.then(({ search, args }) => search.withHybrid(args))
106117
.then((reply) => this.parseGroupByReply(opts, reply));
107118
}
108119

@@ -112,19 +123,19 @@ class QueryManager<T> implements Query<T> {
112123
return this.check
113124
.nearSearch(opts)
114125
.then(({ search, supportsTargets, supportsWeightsForTargets }) => {
115-
return toBase64FromMedia(image).then((image) =>
116-
search.withNearImage(
117-
Serialize.search.nearImage(
118-
{
119-
image,
120-
supportsTargets,
121-
supportsWeightsForTargets,
122-
},
123-
opts
124-
)
125-
)
126-
);
126+
return toBase64FromMedia(image).then((image) => ({
127+
search,
128+
args: Serialize.search.nearImage(
129+
{
130+
image,
131+
supportsTargets,
132+
supportsWeightsForTargets,
133+
},
134+
opts
135+
),
136+
}));
127137
})
138+
.then(({ search, args }) => search.withNearImage(args))
128139
.then((reply) => this.parseGroupByReply(opts, reply));
129140
}
130141

@@ -183,18 +194,18 @@ class QueryManager<T> implements Query<T> {
183194
public nearObject(id: string, opts?: NearOptions<T>): QueryReturn<T> {
184195
return this.check
185196
.nearSearch(opts)
186-
.then(({ search, supportsTargets, supportsWeightsForTargets }) =>
187-
search.withNearObject(
188-
Serialize.search.nearObject(
189-
{
190-
id,
191-
supportsTargets,
192-
supportsWeightsForTargets,
193-
},
194-
opts
195-
)
196-
)
197-
)
197+
.then(({ search, supportsTargets, supportsWeightsForTargets }) => ({
198+
search,
199+
args: Serialize.search.nearObject(
200+
{
201+
id,
202+
supportsTargets,
203+
supportsWeightsForTargets,
204+
},
205+
opts
206+
),
207+
}))
208+
.then(({ search, args }) => search.withNearObject(args))
198209
.then((reply) => this.parseGroupByReply(opts, reply));
199210
}
200211

@@ -203,18 +214,18 @@ class QueryManager<T> implements Query<T> {
203214
public nearText(query: string | string[], opts?: NearTextOptions<T>): QueryReturn<T> {
204215
return this.check
205216
.nearSearch(opts)
206-
.then(({ search, supportsTargets, supportsWeightsForTargets }) =>
207-
search.withNearText(
208-
Serialize.search.nearText(
209-
{
210-
query,
211-
supportsTargets,
212-
supportsWeightsForTargets,
213-
},
214-
opts
215-
)
216-
)
217-
)
217+
.then(({ search, supportsTargets, supportsWeightsForTargets }) => ({
218+
search,
219+
args: Serialize.search.nearText(
220+
{
221+
query,
222+
supportsTargets,
223+
supportsWeightsForTargets,
224+
},
225+
opts
226+
),
227+
}))
228+
.then(({ search, args }) => search.withNearText(args))
218229
.then((reply) => this.parseGroupByReply(opts, reply));
219230
}
220231

@@ -223,25 +234,34 @@ class QueryManager<T> implements Query<T> {
223234
public nearVector(vector: NearVectorInputType, opts?: NearOptions<T>): QueryReturn<T> {
224235
return this.check
225236
.nearVector(vector, opts)
226-
.then(({ search, supportsTargets, supportsVectorsForTargets, supportsWeightsForTargets }) =>
227-
search.withNearVector(
228-
Serialize.search.nearVector(
237+
.then(
238+
async ({
239+
search,
240+
supportsTargets,
241+
supportsVectorsForTargets,
242+
supportsWeightsForTargets,
243+
supportsVectors,
244+
}) => ({
245+
search,
246+
args: await Serialize.search.nearVector(
229247
{
230248
vector,
231249
supportsTargets,
232250
supportsVectorsForTargets,
233251
supportsWeightsForTargets,
252+
supportsVectors,
234253
},
235254
opts
236-
)
237-
)
255+
),
256+
})
238257
)
258+
.then(({ search, args }) => search.withNearVector(args))
239259
.then((reply) => this.parseGroupByReply(opts, reply));
240260
}
241261
}
242262

243263
export default QueryManager.use;
244-
264+
export { queryFactory } from './factories.js';
245265
export {
246266
BaseBm25Options,
247267
BaseHybridOptions,

src/collections/query/types.ts

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -183,12 +183,27 @@ export type GroupByNearTextOptions<T> = BaseNearTextOptions<T> & {
183183
/** The type of the media to search for in the `query.nearMedia` method */
184184
export type NearMediaType = 'audio' | 'depth' | 'image' | 'imu' | 'thermal' | 'video';
185185

186+
/** The allowed types of primitive vectors as stored in Weaviate.
187+
*
188+
* These correspond to 1-dimensional vectors, created by modules named `x2vec-`, and 2-dimensional vectors, created by modules named `x2colbert-`.
189+
*/
190+
export type PrimitiveVectorType = number[] | number[][];
191+
192+
export type ListOfVectors<V extends PrimitiveVectorType> = {
193+
kind: 'listOfVectors';
194+
dimensionality: '1D' | '2D';
195+
vectors: V[];
196+
};
197+
186198
/**
187199
* The vector(s) to search for in the `query/generate.nearVector` and `query/generate.hybrid` methods. One of:
188-
* - a single vector, in which case pass a single number array.
189-
* - multiple named vectors, in which case pass an object of type `Record<string, number[] | number[][]>`.
200+
* - a single 1-dimensional vector, in which case pass a single number array.
201+
* - a single 2-dimensional vector, in which case pass a single array of number arrays.
202+
* - multiple named vectors, in which case pass an object of type `Record<string, PrimitiveVectorType>`.
190203
*/
191-
export type NearVectorInputType = number[] | Record<string, number[] | number[][]>;
204+
export type NearVectorInputType =
205+
| PrimitiveVectorType
206+
| Record<string, PrimitiveVectorType | ListOfVectors<number[]> | ListOfVectors<number[][]>>;
192207

193208
/**
194209
* Over which vector spaces to perform the vector search query in the `nearX` search method. One of:

src/collections/query/utils.ts

Lines changed: 23 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,34 @@
11
import { MultiTargetVectorJoin } from '../index.js';
2-
import { NearVectorInputType, TargetVectorInputType } from './types.js';
2+
import { ListOfVectors, NearVectorInputType, PrimitiveVectorType, TargetVectorInputType } from './types.js';
33

44
export class NearVectorInputGuards {
5-
public static is1DArray(input: NearVectorInputType): input is number[] {
5+
public static is1D(input: NearVectorInputType): input is number[] {
66
return Array.isArray(input) && input.length > 0 && !Array.isArray(input[0]);
77
}
88

9-
public static isObject(input: NearVectorInputType): input is Record<string, number[] | number[][]> {
9+
public static is2D(input: NearVectorInputType): input is number[][] {
10+
return Array.isArray(input) && input.length > 0 && Array.isArray(input[0]) && input[0].length > 0;
11+
}
12+
13+
public static isObject(
14+
input: NearVectorInputType
15+
): input is Record<string, PrimitiveVectorType | ListOfVectors<number[]> | ListOfVectors<number[][]>> {
1016
return !Array.isArray(input);
1117
}
18+
19+
public static isListOf1D(
20+
input: PrimitiveVectorType | ListOfVectors<number[]> | ListOfVectors<number[][]>
21+
): input is ListOfVectors<number[]> {
22+
const i = input as ListOfVectors<number[]>;
23+
return !Array.isArray(input) && i.kind === 'listOfVectors' && i.dimensionality == '1D';
24+
}
25+
26+
public static isListOf2D(
27+
input: PrimitiveVectorType | ListOfVectors<number[]> | ListOfVectors<number[][]>
28+
): input is ListOfVectors<number[][]> {
29+
const i = input as ListOfVectors<number[][]>;
30+
return !Array.isArray(input) && i.kind === 'listOfVectors' && i.dimensionality == '2D';
31+
}
1232
}
1333

1434
export class ArrayInputGuards {

0 commit comments

Comments
 (0)