From c604c06bd7028cdaa78b52a25d9465ce862470ed Mon Sep 17 00:00:00 2001 From: Tommy Smith Date: Wed, 12 Feb 2025 09:51:28 +0000 Subject: [PATCH 1/4] Add implementation supporting multi-vectors (no tests yet) --- src/collections/aggregate/index.ts | 16 +- src/collections/query/check.ts | 20 ++ src/collections/query/factories.ts | 22 ++ src/collections/query/index.ts | 118 ++++---- src/collections/query/types.ts | 21 +- src/collections/query/utils.ts | 26 +- src/collections/serialize/index.ts | 268 ++++++++++++------ src/index.ts | 3 +- src/proto/v1/generative.ts | 419 ++++++++++++++++++++++++++++- src/utils/dbVersion.ts | 10 + 10 files changed, 783 insertions(+), 140 deletions(-) create mode 100644 src/collections/query/factories.ts diff --git a/src/collections/aggregate/index.ts b/src/collections/aggregate/index.ts index 1dad6460..06f86649 100644 --- a/src/collections/aggregate/index.ts +++ b/src/collections/aggregate/index.ts @@ -378,9 +378,9 @@ class AggregateManager implements Aggregate { if (await this.grpcChecker) { const group = typeof opts.groupBy === 'string' ? { property: opts.groupBy } : opts.groupBy; return this.grpc() - .then((aggregate) => + .then(async (aggregate) => aggregate.withHybrid({ - ...Serialize.aggregate.hybrid(query, opts), + ...(await Serialize.aggregate.hybrid(query, opts)), groupBy: Serialize.aggregate.groupBy(group), limit: group.limit, }) @@ -489,9 +489,9 @@ class AggregateManager implements Aggregate { if (await this.grpcChecker) { const group = typeof opts.groupBy === 'string' ? { property: opts.groupBy } : opts.groupBy; return this.grpc() - .then((aggregate) => + .then(async (aggregate) => aggregate.withNearVector({ - ...Serialize.aggregate.nearVector(vector, opts), + ...(await Serialize.aggregate.nearVector(vector, opts)), groupBy: Serialize.aggregate.groupBy(group), limit: group.limit, }) @@ -609,7 +609,7 @@ class AggregateManager implements Aggregate { ): Promise> { if (await this.grpcChecker) { return this.grpc() - .then((aggregate) => aggregate.withHybrid(Serialize.aggregate.hybrid(query, opts))) + .then(async (aggregate) => aggregate.withHybrid(await Serialize.aggregate.hybrid(query, opts))) .then((reply) => Deserialize.aggregate(reply)); } let builder = this.base(opts?.returnMetrics, opts?.filters).withHybrid({ @@ -696,10 +696,12 @@ class AggregateManager implements Aggregate { ): Promise> { if (await this.grpcChecker) { return this.grpc() - .then((aggregate) => aggregate.withNearVector(Serialize.aggregate.nearVector(vector, opts))) + .then(async (aggregate) => + aggregate.withNearVector(await Serialize.aggregate.nearVector(vector, opts)) + ) .then((reply) => Deserialize.aggregate(reply)); } - if (!NearVectorInputGuards.is1DArray(vector)) { + if (!NearVectorInputGuards.is1D(vector)) { throw new WeaviateInvalidInputError( 'Vector can only be a 1D array of numbers when using `nearVector` with <1.29 Weaviate versions.' ); diff --git a/src/collections/query/check.ts b/src/collections/query/check.ts index 291738de..cf437632 100644 --- a/src/collections/query/check.ts +++ b/src/collections/query/check.ts @@ -98,6 +98,18 @@ export class Check { return check.supports; }; + private checkSupportForVectors = async ( + vec?: NearVectorInputType | HybridNearVectorSubSearch | HybridNearTextSubSearch + ) => { + if (vec === undefined || Serialize.isHybridNearTextSearch(vec)) return false; + if (Serialize.isHybridNearVectorSearch(vec) && !Serialize.isMultiVectorPerTarget(vec.vector)) + return false; + if (Serialize.isHybridVectorSearch(vec) && !Serialize.isMultiVectorPerTarget(vec)) return false; + const check = await this.dbVersionSupport.supportsMultiVectorPerTargetSearch(); + if (!check.supports) throw new WeaviateUnsupportedFeatureError(check.message); + return check.supports; + }; + public nearSearch = (opts?: BaseNearOptions) => { return Promise.all([ this.getSearcher(), @@ -118,6 +130,7 @@ export class Check { this.checkSupportForMultiVectorSearch(vec), this.checkSupportForMultiVectorPerTargetSearch(vec), this.checkSupportForMultiWeightPerTargetSearch(opts), + this.checkSupportForVectors(), this.checkSupportForNamedVectors(opts), ]).then( ([ @@ -126,14 +139,17 @@ export class Check { supportsMultiVector, supportsVectorsForTargets, supportsWeightsForTargets, + supportsVectors, ]) => { const is126 = supportsMultiTarget || supportsMultiVector; const is127 = supportsVectorsForTargets || supportsWeightsForTargets; + const is129 = supportsVectors; return { search, supportsTargets: is126 || is127, supportsVectorsForTargets: is127, supportsWeightsForTargets: is127, + supportsVectors: is129, }; } ); @@ -146,6 +162,7 @@ export class Check { this.checkSupportForMultiVectorSearch(opts?.vector), this.checkSupportForMultiVectorPerTargetSearch(opts?.vector), this.checkSupportForMultiWeightPerTargetSearch(opts), + this.checkSupportForVectors(), this.checkSupportForNamedVectors(opts), this.checkSupportForBm25AndHybridGroupByQueries('Hybrid', opts), this.checkSupportForHybridNearTextAndNearVectorSubSearches(opts), @@ -156,14 +173,17 @@ export class Check { supportsMultiVector, supportsWeightsForTargets, supportsVectorsForTargets, + supportsVectors, ]) => { const is126 = supportsMultiTarget || supportsMultiVector; const is127 = supportsVectorsForTargets || supportsWeightsForTargets; + const is129 = supportsVectors; return { search, supportsTargets: is126 || is127, supportsWeightsForTargets: is127, supportsVectorsForTargets: is127, + supportsVectors: is129, }; } ); diff --git a/src/collections/query/factories.ts b/src/collections/query/factories.ts new file mode 100644 index 00000000..10d5c4b7 --- /dev/null +++ b/src/collections/query/factories.ts @@ -0,0 +1,22 @@ +import { ListOfVectors, PrimitiveVectorType } from './types.js'; +import { NearVectorInputGuards } from './utils.js'; + +const hybridVector = { + nearText: () => {}, + nearVector: () => {}, +}; + +const nearVector = { + listOfVectors: (...vectors: V[]): ListOfVectors => { + return { + kind: 'listOfVectors', + dimensionality: NearVectorInputGuards.is1D(vectors[0]) ? '1D' : '2D', + vectors, + }; + }, +}; + +export const queryFactory = { + hybridVector, + nearVector, +}; diff --git a/src/collections/query/index.ts b/src/collections/query/index.ts index 3be179d3..69adc646 100644 --- a/src/collections/query/index.ts +++ b/src/collections/query/index.ts @@ -95,14 +95,25 @@ class QueryManager implements Query { public hybrid(query: string, opts?: HybridOptions): QueryReturn { return this.check .hybridSearch(opts) - .then(({ search, supportsTargets, supportsWeightsForTargets, supportsVectorsForTargets }) => - search.withHybrid( - Serialize.search.hybrid( - { query, supportsTargets, supportsWeightsForTargets, supportsVectorsForTargets }, - opts - ) - ) + .then( + async ({ + search, + supportsTargets, + supportsWeightsForTargets, + supportsVectorsForTargets, + supportsVectors, + }) => ({ + search, + args: await Serialize.search.hybrid({ + query, + supportsTargets, + supportsWeightsForTargets, + supportsVectorsForTargets, + supportsVectors, + }), + }) ) + .then(({ search, args }) => search.withHybrid(args)) .then((reply) => this.parseGroupByReply(opts, reply)); } @@ -112,19 +123,19 @@ class QueryManager implements Query { return this.check .nearSearch(opts) .then(({ search, supportsTargets, supportsWeightsForTargets }) => { - return toBase64FromMedia(image).then((image) => - search.withNearImage( - Serialize.search.nearImage( - { - image, - supportsTargets, - supportsWeightsForTargets, - }, - opts - ) - ) - ); + return toBase64FromMedia(image).then((image) => ({ + search, + args: Serialize.search.nearImage( + { + image, + supportsTargets, + supportsWeightsForTargets, + }, + opts + ), + })); }) + .then(({ search, args }) => search.withNearImage(args)) .then((reply) => this.parseGroupByReply(opts, reply)); } @@ -183,18 +194,18 @@ class QueryManager implements Query { public nearObject(id: string, opts?: NearOptions): QueryReturn { return this.check .nearSearch(opts) - .then(({ search, supportsTargets, supportsWeightsForTargets }) => - search.withNearObject( - Serialize.search.nearObject( - { - id, - supportsTargets, - supportsWeightsForTargets, - }, - opts - ) - ) - ) + .then(({ search, supportsTargets, supportsWeightsForTargets }) => ({ + search, + args: Serialize.search.nearObject( + { + id, + supportsTargets, + supportsWeightsForTargets, + }, + opts + ), + })) + .then(({ search, args }) => search.withNearObject(args)) .then((reply) => this.parseGroupByReply(opts, reply)); } @@ -203,18 +214,18 @@ class QueryManager implements Query { public nearText(query: string | string[], opts?: NearTextOptions): QueryReturn { return this.check .nearSearch(opts) - .then(({ search, supportsTargets, supportsWeightsForTargets }) => - search.withNearText( - Serialize.search.nearText( - { - query, - supportsTargets, - supportsWeightsForTargets, - }, - opts - ) - ) - ) + .then(({ search, supportsTargets, supportsWeightsForTargets }) => ({ + search, + args: Serialize.search.nearText( + { + query, + supportsTargets, + supportsWeightsForTargets, + }, + opts + ), + })) + .then(({ search, args }) => search.withNearText(args)) .then((reply) => this.parseGroupByReply(opts, reply)); } @@ -223,25 +234,34 @@ class QueryManager implements Query { public nearVector(vector: NearVectorInputType, opts?: NearOptions): QueryReturn { return this.check .nearVector(vector, opts) - .then(({ search, supportsTargets, supportsVectorsForTargets, supportsWeightsForTargets }) => - search.withNearVector( - Serialize.search.nearVector( + .then( + async ({ + search, + supportsTargets, + supportsVectorsForTargets, + supportsWeightsForTargets, + supportsVectors, + }) => ({ + search, + args: await Serialize.search.nearVector( { vector, supportsTargets, supportsVectorsForTargets, supportsWeightsForTargets, + supportsVectors, }, opts - ) - ) + ), + }) ) + .then(({ search, args }) => search.withNearVector(args)) .then((reply) => this.parseGroupByReply(opts, reply)); } } export default QueryManager.use; - +export { queryFactory } from './factories.js'; export { BaseBm25Options, BaseHybridOptions, diff --git a/src/collections/query/types.ts b/src/collections/query/types.ts index b00fa7c8..3ad93d21 100644 --- a/src/collections/query/types.ts +++ b/src/collections/query/types.ts @@ -183,12 +183,27 @@ export type GroupByNearTextOptions = BaseNearTextOptions & { /** The type of the media to search for in the `query.nearMedia` method */ export type NearMediaType = 'audio' | 'depth' | 'image' | 'imu' | 'thermal' | 'video'; +/** The allowed types of primitive vectors as stored in Weaviate. + * + * These correspond to 1-dimensional vectors, created by modules named `x2vec-`, and 2-dimensional vectors, created by modules named `x2colbert-`. + */ +export type PrimitiveVectorType = number[] | number[][]; + +export type ListOfVectors = { + kind: 'listOfVectors'; + dimensionality: '1D' | '2D'; + vectors: V[]; +}; + /** * The vector(s) to search for in the `query/generate.nearVector` and `query/generate.hybrid` methods. One of: - * - a single vector, in which case pass a single number array. - * - multiple named vectors, in which case pass an object of type `Record`. + * - a single 1-dimensional vector, in which case pass a single number array. + * - a single 2-dimensional vector, in which case pas a single array of number arrays. + * - multiple named vectors, in which case pass an object of type `Record`. */ -export type NearVectorInputType = number[] | Record; +export type NearVectorInputType = + | PrimitiveVectorType + | Record | ListOfVectors>; /** * Over which vector spaces to perform the vector search query in the `nearX` search method. One of: diff --git a/src/collections/query/utils.ts b/src/collections/query/utils.ts index 4bbe9f76..fb986368 100644 --- a/src/collections/query/utils.ts +++ b/src/collections/query/utils.ts @@ -1,14 +1,34 @@ import { MultiTargetVectorJoin } from '../index.js'; -import { NearVectorInputType, TargetVectorInputType } from './types.js'; +import { ListOfVectors, NearVectorInputType, PrimitiveVectorType, TargetVectorInputType } from './types.js'; export class NearVectorInputGuards { - public static is1DArray(input: NearVectorInputType): input is number[] { + public static is1D(input: NearVectorInputType): input is number[] { return Array.isArray(input) && input.length > 0 && !Array.isArray(input[0]); } - public static isObject(input: NearVectorInputType): input is Record { + public static is2D(input: NearVectorInputType): input is number[][] { + return Array.isArray(input) && input.length > 0 && Array.isArray(input[0]) && input[0].length > 0; + } + + public static isObject( + input: NearVectorInputType + ): input is Record | ListOfVectors> { return !Array.isArray(input); } + + public static isListOf1D( + input: PrimitiveVectorType | ListOfVectors | ListOfVectors + ): input is ListOfVectors { + const i = input as ListOfVectors; + return !Array.isArray(input) && i.kind === 'listOfVectors' && i.dimensionality == '1D'; + } + + public static isListOf2D( + input: PrimitiveVectorType | ListOfVectors | ListOfVectors + ): input is ListOfVectors { + const i = input as ListOfVectors; + return !Array.isArray(input) && i.kind === 'listOfVectors' && i.dimensionality == '2D'; + } } export class ArrayInputGuards { diff --git a/src/collections/serialize/index.ts b/src/collections/serialize/index.ts index 747bf00a..961d8b70 100644 --- a/src/collections/serialize/index.ts +++ b/src/collections/serialize/index.ts @@ -83,7 +83,9 @@ import { ObjectProperties, ObjectPropertiesValue, TextArrayProperties, + Vectors, Vectors as VectorsGrpc, + Vectors_VectorType, } from '../../proto/v1/base.js'; import { FilterId } from '../filters/classes.js'; import { FilterValue, Filters } from '../filters/index.js'; @@ -397,18 +399,19 @@ class Aggregate { }); }; - public static hybrid = ( + public static hybrid = async ( query: string, opts?: AggregateHybridOptions> - ): AggregateHybridArgs => { + ): Promise => { return { ...Aggregate.common(opts), objectLimit: opts?.objectLimit, - hybrid: Serialize.hybridSearch({ + hybrid: await Serialize.hybridSearch({ query: query, supportsTargets: true, supportsVectorsForTargets: true, supportsWeightsForTargets: true, + supportsVectors: true, ...opts, }), }; @@ -462,18 +465,19 @@ class Aggregate { }; }; - public static nearVector = ( + public static nearVector = async ( vector: NearVectorInputType, opts?: AggregateNearOptions> - ): AggregateNearVectorArgs => { + ): Promise => { return { ...Aggregate.common(opts), objectLimit: opts?.objectLimit, - nearVector: Serialize.nearVectorSearch({ + nearVector: await Serialize.nearVectorSearch({ vector, supportsTargets: true, supportsVectorsForTargets: true, supportsWeightsForTargets: true, + supportsVectors: true, ...opts, }), }; @@ -645,18 +649,19 @@ class Search { }); }; - public static hybrid = ( + public static hybrid = async ( args: { query: string; supportsTargets: boolean; supportsVectorsForTargets: boolean; supportsWeightsForTargets: boolean; + supportsVectors: boolean; }, opts?: HybridOptions - ): SearchHybridArgs => { + ): Promise => { return { ...Search.common(opts), - hybridSearch: Serialize.hybridSearch({ ...args, ...opts }), + hybridSearch: await Serialize.hybridSearch({ ...args, ...opts }), groupBy: Search.isGroupBy>(opts) ? Search.groupBy(opts.groupBy) : undefined, }; }; @@ -758,18 +763,19 @@ class Search { }; }; - public static nearVector = ( + public static nearVector = async ( args: { vector: NearVectorInputType; supportsTargets: boolean; supportsVectorsForTargets: boolean; supportsWeightsForTargets: boolean; + supportsVectors: boolean; }, opts?: NearOptions - ): SearchNearVectorArgs => { + ): Promise => { return { ...Search.common(opts), - nearVector: Serialize.nearVectorSearch({ ...args, ...opts }), + nearVector: await Serialize.nearVectorSearch({ ...args, ...opts }), groupBy: Search.isGroupBy>(opts) ? Search.groupBy(opts.groupBy) : undefined, }; }; @@ -867,19 +873,21 @@ export class Serialize { return (vector as HybridNearVectorSubSearch)?.vector !== undefined; }; - private static hybridVector = (args: { + private static hybridVector = async (args: { supportsTargets: boolean; supportsVectorsForTargets: boolean; supportsWeightsForTargets: boolean; + supportsVectors: boolean; vector?: BaseHybridOptions['vector']; }) => { const vector = args.vector; if (Serialize.isHybridVectorSearch(vector)) { - const { targets, targetVectors, vectorBytes, vectorPerTarget, vectorForTargets } = Serialize.vectors({ - ...args, - argumentName: 'vector', - vector: vector, - }); + const { targets, targetVectors, vectorBytes, vectorPerTarget, vectorForTargets } = + await Serialize.vectors({ + ...args, + argumentName: 'vector', + vector: vector, + }); return vectorBytes !== undefined ? { vectorBytes, targetVectors, targets } : { @@ -904,11 +912,12 @@ export class Serialize { }), }; } else if (Serialize.isHybridNearVectorSearch(vector)) { - const { targetVectors, targets, vectorBytes, vectorPerTarget, vectorForTargets } = Serialize.vectors({ - ...args, - argumentName: 'vector', - vector: vector.vector, - }); + const { targetVectors, targets, vectorBytes, vectorPerTarget, vectorForTargets } = + await Serialize.vectors({ + ...args, + argumentName: 'vector', + vector: vector.vector, + }); return { targetVectors, targets, @@ -926,14 +935,15 @@ export class Serialize { } }; - public static hybridSearch = ( + public static hybridSearch = async ( args: { query: string; supportsTargets: boolean; supportsVectorsForTargets: boolean; supportsWeightsForTargets: boolean; + supportsVectors: boolean; } & HybridSearchOptions - ): Hybrid => { + ): Promise => { const fusionType = (fusionType?: string): Hybrid_FusionType => { switch (fusionType) { case 'Ranked': @@ -944,7 +954,7 @@ export class Serialize { return Hybrid_FusionType.FUSION_TYPE_UNSPECIFIED; } }; - const { targets, targetVectors, vectorBytes, nearText, nearVector } = Serialize.hybridVector(args); + const { targets, targetVectors, vectorBytes, nearText, nearVector } = await Serialize.hybridVector(args); return Hybrid.fromPartial({ query: args.query, alpha: args.alpha ? args.alpha : 0.5, @@ -1071,23 +1081,63 @@ export class Serialize { }); }; + private static vectorToBuffer = (vector: number[]): ArrayBufferLike => { + return new Float32Array(vector).buffer; + }; + private static vectorToBytes = (vector: number[]): Uint8Array => { - return new Uint8Array(new Float32Array(vector).buffer); + const uint32len = 4; + const dv = new DataView(new ArrayBuffer(vector.length * uint32len)); + vector.forEach((v, i) => dv.setFloat32(i * uint32len, v, true)); + return new Uint8Array(dv.buffer); + }; + + /** + * Convert a 2D array of numbers to a Uint8Array + * + * Defined as an async method so that control can be relinquished back to the event loop on each outer loop for large vectors + */ + private static vectorsToBytes = async (vectors: number[][]): Promise => { + if (vectors.length === 0) { + return new Uint8Array(); + } + if (vectors[0].length === 0) { + return new Uint8Array(); + } + + const uint16Len = 2; + const uint32len = 4; + const dim = vectors[0].length; + + const dv = new DataView(new ArrayBuffer(uint16Len + vectors.length * dim * uint32len)); + dv.setUint16(0, dim, true); + dv.setUint16(uint16Len, vectors.length, true); + await Promise.all( + vectors.map((vector, i) => + new Promise((resolve) => setTimeout(resolve, 0)).then(() => + vector.forEach((v, j) => dv.setFloat32(uint16Len + i * dim * uint32len + j * uint32len, v, true)) + ) + ) + ); + + return new Uint8Array(dv.buffer); }; - public static nearVectorSearch = (args: { + public static nearVectorSearch = async (args: { vector: NearVectorInputType; supportsTargets: boolean; supportsVectorsForTargets: boolean; supportsWeightsForTargets: boolean; + supportsVectors: boolean; certainty?: number; distance?: number; targetVector?: TargetVectorInputType; - }): NearVector => { - const { targetVectors, targets, vectorBytes, vectorPerTarget, vectorForTargets } = Serialize.vectors({ - ...args, - argumentName: 'nearVector', - }); + }): Promise => { + const { targetVectors, targets, vectorBytes, vectorPerTarget, vectorForTargets } = + await Serialize.vectors({ + ...args, + argumentName: 'nearVector', + }); return NearVector.fromPartial({ certainty: args.certainty, distance: args.distance, @@ -1127,20 +1177,22 @@ export class Serialize { } }; - private static vectors = (args: { + static vectors = async (args: { supportsTargets: boolean; supportsVectorsForTargets: boolean; supportsWeightsForTargets: boolean; + supportsVectors: boolean; argumentName: 'nearVector' | 'vector'; targetVector?: TargetVectorInputType; vector?: NearVectorInputType; - }): { + }): Promise<{ targetVectors?: string[]; targets?: Targets; vectorBytes?: Uint8Array; + vectors?: Vectors[]; vectorPerTarget?: Record; vectorForTargets?: VectorForTarget[]; - } => { + }> => { const invalidVectorError = new WeaviateInvalidInputError(`${args.argumentName} argument must be populated and: - an array of numbers (number[]) @@ -1154,38 +1206,16 @@ export class Serialize { if (Object.keys(args.vector).length === 0) { throw invalidVectorError; } - if (args.supportsVectorsForTargets) { - const vectorForTargets: VectorForTarget[] = Object.entries(args.vector) - .map(([target, vector]) => { - return { - target, - vector: vector, - }; - }) - .reduce((acc, { target, vector }) => { - return ArrayInputGuards.is2DArray(vector) - ? acc.concat( - vector.map((v) => ({ name: target, vectorBytes: Serialize.vectorToBytes(v), vectors: [] })) - ) - : acc.concat([{ name: target, vectorBytes: Serialize.vectorToBytes(vector), vectors: [] }]); - }, [] as VectorForTarget[]); - return args.targetVector !== undefined - ? { - ...Serialize.targetVector(args), - vectorForTargets, - } - : { - targetVectors: undefined, - targets: Targets.fromPartial({ - targetVectors: vectorForTargets.map((v) => v.name), - }), - vectorForTargets, - }; - } else { + if (!args.supportsVectorsForTargets) { const vectorPerTarget: Record = {}; Object.entries(args.vector).forEach(([k, v]) => { if (ArrayInputGuards.is2DArray(v)) { - return; + throw new WeaviateUnsupportedFeatureError('Multi-vectors are not supported in Weaviate <1.29.0'); + } + if (NearVectorInputGuards.isListOf1D(v) || NearVectorInputGuards.isListOf2D(v)) { + throw new WeaviateUnsupportedFeatureError( + 'Lists of vectors are not supported in Weaviate <1.29.0' + ); } vectorPerTarget[k] = Serialize.vectorToBytes(v); }); @@ -1210,21 +1240,107 @@ export class Serialize { }; } } - } else { - if (args.vector.length === 0) { - throw invalidVectorError; - } - if (NearVectorInputGuards.is1DArray(args.vector)) { - const { targetVectors, targets } = Serialize.targetVector(args); - const vectorBytes = Serialize.vectorToBytes(args.vector); - return { - targetVectors, - targets, - vectorBytes, + const vectorForTargets: VectorForTarget[] = []; + for (const [target, vector] of Object.entries(args.vector)) { + const vectorForTarget: VectorForTarget = { + name: target, + vectorBytes: new Uint8Array(), + vectors: [], }; + if (!args.supportsVectors) { + if (NearVectorInputGuards.isListOf2D(vector)) { + throw new WeaviateUnsupportedFeatureError( + 'Lists of multi-vectors are not supported in Weaviate <1.29.0' + ); + } + if (ArrayInputGuards.is2DArray(vector)) { + vector.forEach((v) => + vectorForTargets.push({ name: target, vectorBytes: Serialize.vectorToBytes(v), vectors: [] }) + ); + continue; + } + if (NearVectorInputGuards.isListOf1D(vector)) { + vector.vectors.forEach((v) => + vectorForTargets.push({ + name: target, + vectorBytes: Serialize.vectorToBytes(v), + vectors: [], + }) + ); + continue; + } + vectorForTargets.push({ name: target, vectorBytes: Serialize.vectorToBytes(vector), vectors: [] }); + } + if (ArrayInputGuards.is2DArray(vector)) { + vectorForTarget.vectors.push( + Vectors.fromPartial({ + type: Vectors_VectorType.VECTOR_TYPE_MULTI_FP32, + vectorBytes: await Serialize.vectorsToBytes(vector), // eslint-disable-line no-await-in-loop + }) + ); + } + if (NearVectorInputGuards.isListOf1D(vector)) { + vectorForTarget.vectors.push( + Vectors.fromPartial({ + type: Vectors_VectorType.VECTOR_TYPE_SINGLE_FP32, + vectorBytes: await Serialize.vectorsToBytes(vector.vectors), // eslint-disable-line no-await-in-loop + }) + ); + } + if (NearVectorInputGuards.isListOf2D(vector)) { + for (const v of vector.vectors) { + vectorForTarget.vectors.push( + Vectors.fromPartial({ + type: Vectors_VectorType.VECTOR_TYPE_MULTI_FP32, + vectorBytes: await Serialize.vectorsToBytes(v), // eslint-disable-line no-await-in-loop + }) + ); + } + vectorForTargets.push(vectorForTarget); + continue; + } } + return args.targetVector !== undefined + ? { + ...Serialize.targetVector(args), + vectorForTargets, + } + : { + targetVectors: undefined, + targets: Targets.fromPartial({ + targetVectors: vectorForTargets.map((v) => v.name), + }), + vectorForTargets, + }; + } + if (args.vector.length === 0) { throw invalidVectorError; } + if (NearVectorInputGuards.is1D(args.vector)) { + const { targetVectors, targets } = Serialize.targetVector(args); + const vectorBytes = Serialize.vectorToBytes(args.vector); + return args.supportsVectors + ? { + targets, + targetVectors, + vectors: [Vectors.fromPartial({ type: Vectors_VectorType.VECTOR_TYPE_SINGLE_FP32, vectorBytes })], + } + : { + targets, + targetVectors, + vectorBytes, + }; + } + if (NearVectorInputGuards.is2D(args.vector)) { + const { targetVectors, targets } = Serialize.targetVector(args); + const vectorBytes = await Serialize.vectorsToBytes(args.vector); + return { + targets, + targetVectors, + vectors: [Vectors.fromPartial({ type: Vectors_VectorType.VECTOR_TYPE_MULTI_FP32, vectorBytes })], + }; + } + throw invalidVectorError; }; private static targets = ( diff --git a/src/index.ts b/src/index.ts index 101bc14e..7e270ab9 100644 --- a/src/index.ts +++ b/src/index.ts @@ -3,7 +3,7 @@ import { Backup, backup } from './collections/backup/client.js'; import cluster, { Cluster } from './collections/cluster/index.js'; import { configGuards } from './collections/config/index.js'; import { configure, reconfigure } from './collections/configure/index.js'; -import collections, { Collections } from './collections/index.js'; +import collections, { Collections, queryFactory } from './collections/index.js'; import { AccessTokenCredentialsInput, ApiKey, @@ -251,6 +251,7 @@ const app = { configGuards, reconfigure, permissions, + query: queryFactory, }; export default app; diff --git a/src/proto/v1/generative.ts b/src/proto/v1/generative.ts index 12b1619f..bb6adb3b 100644 --- a/src/proto/v1/generative.ts +++ b/src/proto/v1/generative.ts @@ -51,6 +51,7 @@ export interface GenerativeProvider { google?: GenerativeGoogle | undefined; databricks?: GenerativeDatabricks | undefined; friendliai?: GenerativeFriendliAI | undefined; + nvidia?: GenerativeNvidia | undefined; } export interface GenerativeAnthropic { @@ -61,6 +62,7 @@ export interface GenerativeAnthropic { topK?: number | undefined; topP?: number | undefined; stopSequences?: TextArray | undefined; + images?: TextArray | undefined; } export interface GenerativeAnyscale { @@ -77,6 +79,7 @@ export interface GenerativeAWS { endpoint?: string | undefined; targetModel?: string | undefined; targetVariant?: string | undefined; + images?: TextArray | undefined; } export interface GenerativeCohere { @@ -106,6 +109,7 @@ export interface GenerativeOllama { apiEndpoint?: string | undefined; model?: string | undefined; temperature?: number | undefined; + images?: TextArray | undefined; } export interface GenerativeOpenAI { @@ -122,6 +126,7 @@ export interface GenerativeOpenAI { resourceName?: string | undefined; deploymentId?: string | undefined; isAzure?: boolean | undefined; + images?: TextArray | undefined; } export interface GenerativeGoogle { @@ -137,6 +142,7 @@ export interface GenerativeGoogle { projectId?: string | undefined; endpointId?: string | undefined; region?: string | undefined; + images?: TextArray | undefined; } export interface GenerativeDatabricks { @@ -162,6 +168,14 @@ export interface GenerativeFriendliAI { topP?: number | undefined; } +export interface GenerativeNvidia { + baseUrl?: string | undefined; + model?: string | undefined; + temperature?: number | undefined; + topP?: number | undefined; + maxTokens?: number | undefined; +} + export interface GenerativeAnthropicMetadata { usage: GenerativeAnthropicMetadata_Usage | undefined; } @@ -273,6 +287,16 @@ export interface GenerativeFriendliAIMetadata_Usage { totalTokens?: number | undefined; } +export interface GenerativeNvidiaMetadata { + usage?: GenerativeNvidiaMetadata_Usage | undefined; +} + +export interface GenerativeNvidiaMetadata_Usage { + promptTokens?: number | undefined; + completionTokens?: number | undefined; + totalTokens?: number | undefined; +} + export interface GenerativeMetadata { anthropic?: GenerativeAnthropicMetadata | undefined; anyscale?: GenerativeAnyscaleMetadata | undefined; @@ -285,6 +309,7 @@ export interface GenerativeMetadata { google?: GenerativeGoogleMetadata | undefined; databricks?: GenerativeDatabricksMetadata | undefined; friendliai?: GenerativeFriendliAIMetadata | undefined; + nvidia?: GenerativeNvidiaMetadata | undefined; } export interface GenerativeReply { @@ -630,6 +655,7 @@ function createBaseGenerativeProvider(): GenerativeProvider { google: undefined, databricks: undefined, friendliai: undefined, + nvidia: undefined, }; } @@ -671,6 +697,9 @@ export const GenerativeProvider = { if (message.friendliai !== undefined) { GenerativeFriendliAI.encode(message.friendliai, writer.uint32(98).fork()).ldelim(); } + if (message.nvidia !== undefined) { + GenerativeNvidia.encode(message.nvidia, writer.uint32(106).fork()).ldelim(); + } return writer; }, @@ -765,6 +794,13 @@ export const GenerativeProvider = { message.friendliai = GenerativeFriendliAI.decode(reader, reader.uint32()); continue; + case 13: + if (tag !== 106) { + break; + } + + message.nvidia = GenerativeNvidia.decode(reader, reader.uint32()); + continue; } if ((tag & 7) === 4 || tag === 0) { break; @@ -788,6 +824,7 @@ export const GenerativeProvider = { google: isSet(object.google) ? GenerativeGoogle.fromJSON(object.google) : undefined, databricks: isSet(object.databricks) ? GenerativeDatabricks.fromJSON(object.databricks) : undefined, friendliai: isSet(object.friendliai) ? GenerativeFriendliAI.fromJSON(object.friendliai) : undefined, + nvidia: isSet(object.nvidia) ? GenerativeNvidia.fromJSON(object.nvidia) : undefined, }; }, @@ -829,6 +866,9 @@ export const GenerativeProvider = { if (message.friendliai !== undefined) { obj.friendliai = GenerativeFriendliAI.toJSON(message.friendliai); } + if (message.nvidia !== undefined) { + obj.nvidia = GenerativeNvidia.toJSON(message.nvidia); + } return obj; }, @@ -869,6 +909,9 @@ export const GenerativeProvider = { message.friendliai = (object.friendliai !== undefined && object.friendliai !== null) ? GenerativeFriendliAI.fromPartial(object.friendliai) : undefined; + message.nvidia = (object.nvidia !== undefined && object.nvidia !== null) + ? GenerativeNvidia.fromPartial(object.nvidia) + : undefined; return message; }, }; @@ -882,6 +925,7 @@ function createBaseGenerativeAnthropic(): GenerativeAnthropic { topK: undefined, topP: undefined, stopSequences: undefined, + images: undefined, }; } @@ -908,6 +952,9 @@ export const GenerativeAnthropic = { if (message.stopSequences !== undefined) { TextArray.encode(message.stopSequences, writer.uint32(58).fork()).ldelim(); } + if (message.images !== undefined) { + TextArray.encode(message.images, writer.uint32(66).fork()).ldelim(); + } return writer; }, @@ -967,6 +1014,13 @@ export const GenerativeAnthropic = { message.stopSequences = TextArray.decode(reader, reader.uint32()); continue; + case 8: + if (tag !== 66) { + break; + } + + message.images = TextArray.decode(reader, reader.uint32()); + continue; } if ((tag & 7) === 4 || tag === 0) { break; @@ -985,6 +1039,7 @@ export const GenerativeAnthropic = { topK: isSet(object.topK) ? globalThis.Number(object.topK) : undefined, topP: isSet(object.topP) ? globalThis.Number(object.topP) : undefined, stopSequences: isSet(object.stopSequences) ? TextArray.fromJSON(object.stopSequences) : undefined, + images: isSet(object.images) ? TextArray.fromJSON(object.images) : undefined, }; }, @@ -1011,6 +1066,9 @@ export const GenerativeAnthropic = { if (message.stopSequences !== undefined) { obj.stopSequences = TextArray.toJSON(message.stopSequences); } + if (message.images !== undefined) { + obj.images = TextArray.toJSON(message.images); + } return obj; }, @@ -1028,6 +1086,9 @@ export const GenerativeAnthropic = { message.stopSequences = (object.stopSequences !== undefined && object.stopSequences !== null) ? TextArray.fromPartial(object.stopSequences) : undefined; + message.images = (object.images !== undefined && object.images !== null) + ? TextArray.fromPartial(object.images) + : undefined; return message; }, }; @@ -1130,6 +1191,7 @@ function createBaseGenerativeAWS(): GenerativeAWS { endpoint: undefined, targetModel: undefined, targetVariant: undefined, + images: undefined, }; } @@ -1156,6 +1218,9 @@ export const GenerativeAWS = { if (message.targetVariant !== undefined) { writer.uint32(106).string(message.targetVariant); } + if (message.images !== undefined) { + TextArray.encode(message.images, writer.uint32(114).fork()).ldelim(); + } return writer; }, @@ -1215,6 +1280,13 @@ export const GenerativeAWS = { message.targetVariant = reader.string(); continue; + case 14: + if (tag !== 114) { + break; + } + + message.images = TextArray.decode(reader, reader.uint32()); + continue; } if ((tag & 7) === 4 || tag === 0) { break; @@ -1233,6 +1305,7 @@ export const GenerativeAWS = { endpoint: isSet(object.endpoint) ? globalThis.String(object.endpoint) : undefined, targetModel: isSet(object.targetModel) ? globalThis.String(object.targetModel) : undefined, targetVariant: isSet(object.targetVariant) ? globalThis.String(object.targetVariant) : undefined, + images: isSet(object.images) ? TextArray.fromJSON(object.images) : undefined, }; }, @@ -1259,6 +1332,9 @@ export const GenerativeAWS = { if (message.targetVariant !== undefined) { obj.targetVariant = message.targetVariant; } + if (message.images !== undefined) { + obj.images = TextArray.toJSON(message.images); + } return obj; }, @@ -1274,6 +1350,9 @@ export const GenerativeAWS = { message.endpoint = object.endpoint ?? undefined; message.targetModel = object.targetModel ?? undefined; message.targetVariant = object.targetVariant ?? undefined; + message.images = (object.images !== undefined && object.images !== null) + ? TextArray.fromPartial(object.images) + : undefined; return message; }, }; @@ -1632,7 +1711,7 @@ export const GenerativeMistral = { }; function createBaseGenerativeOllama(): GenerativeOllama { - return { apiEndpoint: undefined, model: undefined, temperature: undefined }; + return { apiEndpoint: undefined, model: undefined, temperature: undefined, images: undefined }; } export const GenerativeOllama = { @@ -1646,6 +1725,9 @@ export const GenerativeOllama = { if (message.temperature !== undefined) { writer.uint32(25).double(message.temperature); } + if (message.images !== undefined) { + TextArray.encode(message.images, writer.uint32(34).fork()).ldelim(); + } return writer; }, @@ -1677,6 +1759,13 @@ export const GenerativeOllama = { message.temperature = reader.double(); continue; + case 4: + if (tag !== 34) { + break; + } + + message.images = TextArray.decode(reader, reader.uint32()); + continue; } if ((tag & 7) === 4 || tag === 0) { break; @@ -1691,6 +1780,7 @@ export const GenerativeOllama = { apiEndpoint: isSet(object.apiEndpoint) ? globalThis.String(object.apiEndpoint) : undefined, model: isSet(object.model) ? globalThis.String(object.model) : undefined, temperature: isSet(object.temperature) ? globalThis.Number(object.temperature) : undefined, + images: isSet(object.images) ? TextArray.fromJSON(object.images) : undefined, }; }, @@ -1705,6 +1795,9 @@ export const GenerativeOllama = { if (message.temperature !== undefined) { obj.temperature = message.temperature; } + if (message.images !== undefined) { + obj.images = TextArray.toJSON(message.images); + } return obj; }, @@ -1716,6 +1809,9 @@ export const GenerativeOllama = { message.apiEndpoint = object.apiEndpoint ?? undefined; message.model = object.model ?? undefined; message.temperature = object.temperature ?? undefined; + message.images = (object.images !== undefined && object.images !== null) + ? TextArray.fromPartial(object.images) + : undefined; return message; }, }; @@ -1735,6 +1831,7 @@ function createBaseGenerativeOpenAI(): GenerativeOpenAI { resourceName: undefined, deploymentId: undefined, isAzure: undefined, + images: undefined, }; } @@ -1779,6 +1876,9 @@ export const GenerativeOpenAI = { if (message.isAzure !== undefined) { writer.uint32(104).bool(message.isAzure); } + if (message.images !== undefined) { + TextArray.encode(message.images, writer.uint32(114).fork()).ldelim(); + } return writer; }, @@ -1880,6 +1980,13 @@ export const GenerativeOpenAI = { message.isAzure = reader.bool(); continue; + case 14: + if (tag !== 114) { + break; + } + + message.images = TextArray.decode(reader, reader.uint32()); + continue; } if ((tag & 7) === 4 || tag === 0) { break; @@ -1904,6 +2011,7 @@ export const GenerativeOpenAI = { resourceName: isSet(object.resourceName) ? globalThis.String(object.resourceName) : undefined, deploymentId: isSet(object.deploymentId) ? globalThis.String(object.deploymentId) : undefined, isAzure: isSet(object.isAzure) ? globalThis.Boolean(object.isAzure) : undefined, + images: isSet(object.images) ? TextArray.fromJSON(object.images) : undefined, }; }, @@ -1948,6 +2056,9 @@ export const GenerativeOpenAI = { if (message.isAzure !== undefined) { obj.isAzure = message.isAzure; } + if (message.images !== undefined) { + obj.images = TextArray.toJSON(message.images); + } return obj; }, @@ -1969,6 +2080,9 @@ export const GenerativeOpenAI = { message.resourceName = object.resourceName ?? undefined; message.deploymentId = object.deploymentId ?? undefined; message.isAzure = object.isAzure ?? undefined; + message.images = (object.images !== undefined && object.images !== null) + ? TextArray.fromPartial(object.images) + : undefined; return message; }, }; @@ -1987,6 +2101,7 @@ function createBaseGenerativeGoogle(): GenerativeGoogle { projectId: undefined, endpointId: undefined, region: undefined, + images: undefined, }; } @@ -2028,6 +2143,9 @@ export const GenerativeGoogle = { if (message.region !== undefined) { writer.uint32(98).string(message.region); } + if (message.images !== undefined) { + TextArray.encode(message.images, writer.uint32(106).fork()).ldelim(); + } return writer; }, @@ -2122,6 +2240,13 @@ export const GenerativeGoogle = { message.region = reader.string(); continue; + case 13: + if (tag !== 106) { + break; + } + + message.images = TextArray.decode(reader, reader.uint32()); + continue; } if ((tag & 7) === 4 || tag === 0) { break; @@ -2145,6 +2270,7 @@ export const GenerativeGoogle = { projectId: isSet(object.projectId) ? globalThis.String(object.projectId) : undefined, endpointId: isSet(object.endpointId) ? globalThis.String(object.endpointId) : undefined, region: isSet(object.region) ? globalThis.String(object.region) : undefined, + images: isSet(object.images) ? TextArray.fromJSON(object.images) : undefined, }; }, @@ -2186,6 +2312,9 @@ export const GenerativeGoogle = { if (message.region !== undefined) { obj.region = message.region; } + if (message.images !== undefined) { + obj.images = TextArray.toJSON(message.images); + } return obj; }, @@ -2208,6 +2337,9 @@ export const GenerativeGoogle = { message.projectId = object.projectId ?? undefined; message.endpointId = object.endpointId ?? undefined; message.region = object.region ?? undefined; + message.images = (object.images !== undefined && object.images !== null) + ? TextArray.fromPartial(object.images) + : undefined; return message; }, }; @@ -2574,6 +2706,125 @@ export const GenerativeFriendliAI = { }, }; +function createBaseGenerativeNvidia(): GenerativeNvidia { + return { baseUrl: undefined, model: undefined, temperature: undefined, topP: undefined, maxTokens: undefined }; +} + +export const GenerativeNvidia = { + encode(message: GenerativeNvidia, writer: _m0.Writer = _m0.Writer.create()): _m0.Writer { + if (message.baseUrl !== undefined) { + writer.uint32(10).string(message.baseUrl); + } + if (message.model !== undefined) { + writer.uint32(18).string(message.model); + } + if (message.temperature !== undefined) { + writer.uint32(25).double(message.temperature); + } + if (message.topP !== undefined) { + writer.uint32(33).double(message.topP); + } + if (message.maxTokens !== undefined) { + writer.uint32(40).int64(message.maxTokens); + } + return writer; + }, + + decode(input: _m0.Reader | Uint8Array, length?: number): GenerativeNvidia { + const reader = input instanceof _m0.Reader ? input : _m0.Reader.create(input); + let end = length === undefined ? reader.len : reader.pos + length; + const message = createBaseGenerativeNvidia(); + while (reader.pos < end) { + const tag = reader.uint32(); + switch (tag >>> 3) { + case 1: + if (tag !== 10) { + break; + } + + message.baseUrl = reader.string(); + continue; + case 2: + if (tag !== 18) { + break; + } + + message.model = reader.string(); + continue; + case 3: + if (tag !== 25) { + break; + } + + message.temperature = reader.double(); + continue; + case 4: + if (tag !== 33) { + break; + } + + message.topP = reader.double(); + continue; + case 5: + if (tag !== 40) { + break; + } + + message.maxTokens = longToNumber(reader.int64() as Long); + continue; + } + if ((tag & 7) === 4 || tag === 0) { + break; + } + reader.skipType(tag & 7); + } + return message; + }, + + fromJSON(object: any): GenerativeNvidia { + return { + baseUrl: isSet(object.baseUrl) ? globalThis.String(object.baseUrl) : undefined, + model: isSet(object.model) ? globalThis.String(object.model) : undefined, + temperature: isSet(object.temperature) ? globalThis.Number(object.temperature) : undefined, + topP: isSet(object.topP) ? globalThis.Number(object.topP) : undefined, + maxTokens: isSet(object.maxTokens) ? globalThis.Number(object.maxTokens) : undefined, + }; + }, + + toJSON(message: GenerativeNvidia): unknown { + const obj: any = {}; + if (message.baseUrl !== undefined) { + obj.baseUrl = message.baseUrl; + } + if (message.model !== undefined) { + obj.model = message.model; + } + if (message.temperature !== undefined) { + obj.temperature = message.temperature; + } + if (message.topP !== undefined) { + obj.topP = message.topP; + } + if (message.maxTokens !== undefined) { + obj.maxTokens = Math.round(message.maxTokens); + } + return obj; + }, + + create(base?: DeepPartial): GenerativeNvidia { + return GenerativeNvidia.fromPartial(base ?? {}); + }, + fromPartial(object: DeepPartial): GenerativeNvidia { + const message = createBaseGenerativeNvidia(); + message.baseUrl = object.baseUrl ?? undefined; + message.model = object.model ?? undefined; + message.temperature = object.temperature ?? undefined; + message.topP = object.topP ?? undefined; + message.maxTokens = object.maxTokens ?? undefined; + return message; + }, +}; + function createBaseGenerativeAnthropicMetadata(): GenerativeAnthropicMetadata { return { usage: undefined }; } @@ -4246,6 +4497,154 @@ export const GenerativeFriendliAIMetadata_Usage = { }, }; +function createBaseGenerativeNvidiaMetadata(): GenerativeNvidiaMetadata { + return { usage: undefined }; +} + +export const GenerativeNvidiaMetadata = { + encode(message: GenerativeNvidiaMetadata, writer: _m0.Writer = _m0.Writer.create()): _m0.Writer { + if (message.usage !== undefined) { + GenerativeNvidiaMetadata_Usage.encode(message.usage, writer.uint32(10).fork()).ldelim(); + } + return writer; + }, + + decode(input: _m0.Reader | Uint8Array, length?: number): GenerativeNvidiaMetadata { + const reader = input instanceof _m0.Reader ? input : _m0.Reader.create(input); + let end = length === undefined ? reader.len : reader.pos + length; + const message = createBaseGenerativeNvidiaMetadata(); + while (reader.pos < end) { + const tag = reader.uint32(); + switch (tag >>> 3) { + case 1: + if (tag !== 10) { + break; + } + + message.usage = GenerativeNvidiaMetadata_Usage.decode(reader, reader.uint32()); + continue; + } + if ((tag & 7) === 4 || tag === 0) { + break; + } + reader.skipType(tag & 7); + } + return message; + }, + + fromJSON(object: any): GenerativeNvidiaMetadata { + return { usage: isSet(object.usage) ? GenerativeNvidiaMetadata_Usage.fromJSON(object.usage) : undefined }; + }, + + toJSON(message: GenerativeNvidiaMetadata): unknown { + const obj: any = {}; + if (message.usage !== undefined) { + obj.usage = GenerativeNvidiaMetadata_Usage.toJSON(message.usage); + } + return obj; + }, + + create(base?: DeepPartial): GenerativeNvidiaMetadata { + return GenerativeNvidiaMetadata.fromPartial(base ?? {}); + }, + fromPartial(object: DeepPartial): GenerativeNvidiaMetadata { + const message = createBaseGenerativeNvidiaMetadata(); + message.usage = (object.usage !== undefined && object.usage !== null) + ? GenerativeNvidiaMetadata_Usage.fromPartial(object.usage) + : undefined; + return message; + }, +}; + +function createBaseGenerativeNvidiaMetadata_Usage(): GenerativeNvidiaMetadata_Usage { + return { promptTokens: undefined, completionTokens: undefined, totalTokens: undefined }; +} + +export const GenerativeNvidiaMetadata_Usage = { + encode(message: GenerativeNvidiaMetadata_Usage, writer: _m0.Writer = _m0.Writer.create()): _m0.Writer { + if (message.promptTokens !== undefined) { + writer.uint32(8).int64(message.promptTokens); + } + if (message.completionTokens !== undefined) { + writer.uint32(16).int64(message.completionTokens); + } + if (message.totalTokens !== undefined) { + writer.uint32(24).int64(message.totalTokens); + } + return writer; + }, + + decode(input: _m0.Reader | Uint8Array, length?: number): GenerativeNvidiaMetadata_Usage { + const reader = input instanceof _m0.Reader ? input : _m0.Reader.create(input); + let end = length === undefined ? reader.len : reader.pos + length; + const message = createBaseGenerativeNvidiaMetadata_Usage(); + while (reader.pos < end) { + const tag = reader.uint32(); + switch (tag >>> 3) { + case 1: + if (tag !== 8) { + break; + } + + message.promptTokens = longToNumber(reader.int64() as Long); + continue; + case 2: + if (tag !== 16) { + break; + } + + message.completionTokens = longToNumber(reader.int64() as Long); + continue; + case 3: + if (tag !== 24) { + break; + } + + message.totalTokens = longToNumber(reader.int64() as Long); + continue; + } + if ((tag & 7) === 4 || tag === 0) { + break; + } + reader.skipType(tag & 7); + } + return message; + }, + + fromJSON(object: any): GenerativeNvidiaMetadata_Usage { + return { + promptTokens: isSet(object.promptTokens) ? globalThis.Number(object.promptTokens) : undefined, + completionTokens: isSet(object.completionTokens) ? globalThis.Number(object.completionTokens) : undefined, + totalTokens: isSet(object.totalTokens) ? globalThis.Number(object.totalTokens) : undefined, + }; + }, + + toJSON(message: GenerativeNvidiaMetadata_Usage): unknown { + const obj: any = {}; + if (message.promptTokens !== undefined) { + obj.promptTokens = Math.round(message.promptTokens); + } + if (message.completionTokens !== undefined) { + obj.completionTokens = Math.round(message.completionTokens); + } + if (message.totalTokens !== undefined) { + obj.totalTokens = Math.round(message.totalTokens); + } + return obj; + }, + + create(base?: DeepPartial): GenerativeNvidiaMetadata_Usage { + return GenerativeNvidiaMetadata_Usage.fromPartial(base ?? {}); + }, + fromPartial(object: DeepPartial): GenerativeNvidiaMetadata_Usage { + const message = createBaseGenerativeNvidiaMetadata_Usage(); + message.promptTokens = object.promptTokens ?? undefined; + message.completionTokens = object.completionTokens ?? undefined; + message.totalTokens = object.totalTokens ?? undefined; + return message; + }, +}; + function createBaseGenerativeMetadata(): GenerativeMetadata { return { anthropic: undefined, @@ -4259,6 +4658,7 @@ function createBaseGenerativeMetadata(): GenerativeMetadata { google: undefined, databricks: undefined, friendliai: undefined, + nvidia: undefined, }; } @@ -4297,6 +4697,9 @@ export const GenerativeMetadata = { if (message.friendliai !== undefined) { GenerativeFriendliAIMetadata.encode(message.friendliai, writer.uint32(90).fork()).ldelim(); } + if (message.nvidia !== undefined) { + GenerativeNvidiaMetadata.encode(message.nvidia, writer.uint32(98).fork()).ldelim(); + } return writer; }, @@ -4384,6 +4787,13 @@ export const GenerativeMetadata = { message.friendliai = GenerativeFriendliAIMetadata.decode(reader, reader.uint32()); continue; + case 12: + if (tag !== 98) { + break; + } + + message.nvidia = GenerativeNvidiaMetadata.decode(reader, reader.uint32()); + continue; } if ((tag & 7) === 4 || tag === 0) { break; @@ -4406,6 +4816,7 @@ export const GenerativeMetadata = { google: isSet(object.google) ? GenerativeGoogleMetadata.fromJSON(object.google) : undefined, databricks: isSet(object.databricks) ? GenerativeDatabricksMetadata.fromJSON(object.databricks) : undefined, friendliai: isSet(object.friendliai) ? GenerativeFriendliAIMetadata.fromJSON(object.friendliai) : undefined, + nvidia: isSet(object.nvidia) ? GenerativeNvidiaMetadata.fromJSON(object.nvidia) : undefined, }; }, @@ -4444,6 +4855,9 @@ export const GenerativeMetadata = { if (message.friendliai !== undefined) { obj.friendliai = GenerativeFriendliAIMetadata.toJSON(message.friendliai); } + if (message.nvidia !== undefined) { + obj.nvidia = GenerativeNvidiaMetadata.toJSON(message.nvidia); + } return obj; }, @@ -4485,6 +4899,9 @@ export const GenerativeMetadata = { message.friendliai = (object.friendliai !== undefined && object.friendliai !== null) ? GenerativeFriendliAIMetadata.fromPartial(object.friendliai) : undefined; + message.nvidia = (object.nvidia !== undefined && object.nvidia !== null) + ? GenerativeNvidiaMetadata.fromPartial(object.nvidia) + : undefined; return message; }, }; diff --git a/src/utils/dbVersion.ts b/src/utils/dbVersion.ts index 279537e2..f2154836 100644 --- a/src/utils/dbVersion.ts +++ b/src/utils/dbVersion.ts @@ -219,6 +219,16 @@ export class DbVersionSupport { }; }); }; + + supportsVectorsFieldInGRPC = () => { + return this.dbVersionProvider.getVersion().then((version) => { + return { + version: version, + supports: version.isAtLeast(1, 29, 0), + message: undefined, + }; + }); + }; } const EMPTY_VERSION = ''; From cc4ce77c46bf5468c9e1fcd5ac2adbf055b9a596 Mon Sep 17 00:00:00 2001 From: Tommy Smith Date: Wed, 12 Feb 2025 11:20:29 +0000 Subject: [PATCH 2/4] Fix type errors, add build to pre-commit --- package.json | 3 +- src/collections/generate/index.ts | 140 ++++++++++++++++++------------ 2 files changed, 87 insertions(+), 56 deletions(-) diff --git a/package.json b/package.json index a6c4f981..6bd6f1ed 100644 --- a/package.json +++ b/package.json @@ -100,7 +100,8 @@ "lint-staged": { "*.{ts,js}": [ "npm run format:check", - "npm run lint -- --cache" + "npm run lint -- --cache", + "npm run prepack" ] } } diff --git a/src/collections/generate/index.ts b/src/collections/generate/index.ts index 3af6fef1..d783e61a 100644 --- a/src/collections/generate/index.ts +++ b/src/collections/generate/index.ts @@ -72,12 +72,14 @@ class GenerateManager implements Generate { ): Promise> { return this.check .fetchObjects(opts) - .then(({ search }) => - search.withFetch({ + .then(({ search }) => ({ + search, + args: { ...Serialize.search.fetchObjects(opts), generative: Serialize.generative(generate), - }) - ) + }, + })) + .then(({ search, args }) => search.withFetch(args)) .then((reply) => this.parseReply(reply)); } @@ -94,12 +96,14 @@ class GenerateManager implements Generate { public bm25(query: string, generate: GenerateOptions, opts?: Bm25Options): GenerateReturn { return this.check .bm25(opts) - .then(({ search }) => - search.withBm25({ + .then(({ search }) => ({ + search, + args: { ...Serialize.search.bm25(query, opts), generative: Serialize.generative(generate), - }) - ) + }, + })) + .then(({ search, args }) => search.withBm25(args)) .then((reply) => this.parseGroupByReply(opts, reply)); } @@ -116,20 +120,31 @@ class GenerateManager implements Generate { public hybrid(query: string, generate: GenerateOptions, opts?: HybridOptions): GenerateReturn { return this.check .hybridSearch(opts) - .then(({ search, supportsTargets, supportsVectorsForTargets, supportsWeightsForTargets }) => - search.withHybrid({ - ...Serialize.search.hybrid( - { - query, - supportsTargets, - supportsVectorsForTargets, - supportsWeightsForTargets, - }, - opts - ), - generative: Serialize.generative(generate), + .then( + async ({ + search, + supportsTargets, + supportsVectorsForTargets, + supportsWeightsForTargets, + supportsVectors, + }) => ({ + search, + args: { + ...(await Serialize.search.hybrid( + { + query, + supportsTargets, + supportsVectorsForTargets, + supportsWeightsForTargets, + supportsVectors, + }, + opts + )), + generative: Serialize.generative(generate), + }, }) ) + .then(({ search, args }) => search.withHybrid(args)) .then((reply) => this.parseGroupByReply(opts, reply)); } @@ -150,21 +165,21 @@ class GenerateManager implements Generate { ): GenerateReturn { return this.check .nearSearch(opts) - .then(({ search, supportsTargets, supportsWeightsForTargets }) => - toBase64FromMedia(image).then((image) => - search.withNearImage({ - ...Serialize.search.nearImage( - { - image, - supportsTargets, - supportsWeightsForTargets, - }, - opts - ), - generative: Serialize.generative(generate), - }) - ) - ) + .then(async ({ search, supportsTargets, supportsWeightsForTargets }) => ({ + search, + args: { + ...Serialize.search.nearImage( + { + image: await toBase64FromMedia(image), + supportsTargets, + supportsWeightsForTargets, + }, + opts + ), + generative: Serialize.generative(generate), + }, + })) + .then(({ search, args }) => search.withNearImage(args)) .then((reply) => this.parseGroupByReply(opts, reply)); } @@ -181,8 +196,9 @@ class GenerateManager implements Generate { public nearObject(id: string, generate: GenerateOptions, opts?: NearOptions): GenerateReturn { return this.check .nearSearch(opts) - .then(({ search, supportsTargets, supportsWeightsForTargets }) => - search.withNearObject({ + .then(({ search, supportsTargets, supportsWeightsForTargets }) => ({ + search, + args: { ...Serialize.search.nearObject( { id, @@ -192,8 +208,9 @@ class GenerateManager implements Generate { opts ), generative: Serialize.generative(generate), - }) - ) + }, + })) + .then(({ search, args }) => search.withNearObject(args)) .then((reply) => this.parseGroupByReply(opts, reply)); } @@ -214,8 +231,9 @@ class GenerateManager implements Generate { ): GenerateReturn { return this.check .nearSearch(opts) - .then(({ search, supportsTargets, supportsWeightsForTargets }) => - search.withNearText({ + .then(({ search, supportsTargets, supportsWeightsForTargets }) => ({ + search, + args: { ...Serialize.search.nearText( { query, @@ -225,8 +243,9 @@ class GenerateManager implements Generate { opts ), generative: Serialize.generative(generate), - }) - ) + }, + })) + .then(({ search, args }) => search.withNearText(args)) .then((reply) => this.parseGroupByReply(opts, reply)); } @@ -247,20 +266,31 @@ class GenerateManager implements Generate { ): GenerateReturn { return this.check .nearVector(vector, opts) - .then(({ search, supportsTargets, supportsVectorsForTargets, supportsWeightsForTargets }) => - search.withNearVector({ - ...Serialize.search.nearVector( - { - vector, - supportsTargets, - supportsVectorsForTargets, - supportsWeightsForTargets, - }, - opts - ), - generative: Serialize.generative(generate), + .then( + async ({ + search, + supportsTargets, + supportsVectorsForTargets, + supportsWeightsForTargets, + supportsVectors, + }) => ({ + search, + args: { + ...(await Serialize.search.nearVector( + { + vector, + supportsTargets, + supportsVectorsForTargets, + supportsWeightsForTargets, + supportsVectors, + }, + opts + )), + generative: Serialize.generative(generate), + }, }) ) + .then(({ search, args }) => search.withNearVector(args)) .then((reply) => this.parseGroupByReply(opts, reply)); } From 2c2964b759323b864d7885c75e389eaa9be6f722 Mon Sep 17 00:00:00 2001 From: Tommy Smith Date: Thu, 13 Feb 2025 16:05:36 +0000 Subject: [PATCH 3/4] Making sweeping changes: - Introduce `V` generic for collections allowing users to define types for their multiple vectors - Add support for creating multivector collections - Add BC support for querying multivector collections - Add yielding to de/ser logic of multivectors to avoid expensive blocking CPU loops - Update CI image --- .github/workflows/main.yaml | 2 +- package.json | 3 +- src/collections/aggregate/index.ts | 70 ++--- src/collections/aggregate/integration.test.ts | 4 +- src/collections/collection/index.ts | 51 ++-- src/collections/config/integration.test.ts | 4 + src/collections/config/types/vectorIndex.ts | 5 + src/collections/config/utils.ts | 14 + src/collections/configure/parsing.ts | 21 ++ .../configure/types/vectorIndex.ts | 5 + src/collections/configure/types/vectorizer.ts | 10 +- src/collections/configure/unit.test.ts | 20 ++ src/collections/configure/vectorIndex.ts | 20 ++ src/collections/data/integration.test.ts | 2 +- src/collections/deserialize/index.ts | 190 +++++++----- src/collections/filters/integration.test.ts | 2 +- src/collections/generate/index.ts | 112 +++---- src/collections/generate/integration.test.ts | 8 +- src/collections/generate/types.ts | 171 ++++++----- src/collections/index.ts | 70 +++-- src/collections/iterator/index.ts | 8 +- src/collections/iterator/integration.test.ts | 2 +- src/collections/journey.test.ts | 1 + src/collections/query/check.ts | 34 +-- src/collections/query/index.ts | 95 +++--- src/collections/query/integration.test.ts | 21 +- src/collections/query/types.ts | 228 +++++++------- src/collections/query/utils.ts | 40 ++- src/collections/references/classes.ts | 12 +- src/collections/serialize/index.ts | 288 ++++++++++-------- src/collections/serialize/unit.test.ts | 97 +++++- src/collections/types/generate.ts | 26 +- src/collections/types/internal.ts | 4 + src/collections/types/query.ts | 48 +-- src/collections/vectors/journey.test.ts | 159 ++++++++++ src/collections/vectors/multiTargetVector.ts | 46 +-- src/utils/yield.ts | 1 + 37 files changed, 1212 insertions(+), 682 deletions(-) create mode 100644 src/collections/vectors/journey.test.ts create mode 100644 src/utils/yield.ts diff --git a/.github/workflows/main.yaml b/.github/workflows/main.yaml index d63cd1f2..ed689977 100644 --- a/.github/workflows/main.yaml +++ b/.github/workflows/main.yaml @@ -12,7 +12,7 @@ env: WEAVIATE_126: 1.26.14 WEAVIATE_127: 1.27.11 WEAVIATE_128: 1.28.4 - WEAVIATE_129: 1.29.0-rc.0 + WEAVIATE_129: 1.29.0-rc.1 jobs: checks: diff --git a/package.json b/package.json index 6bd6f1ed..37112bb0 100644 --- a/package.json +++ b/package.json @@ -101,7 +101,8 @@ "*.{ts,js}": [ "npm run format:check", "npm run lint -- --cache", - "npm run prepack" + "npm run prepack", + "npm run docs" ] } } diff --git a/src/collections/aggregate/index.ts b/src/collections/aggregate/index.ts index 06f86649..a4cce65e 100644 --- a/src/collections/aggregate/index.ts +++ b/src/collections/aggregate/index.ts @@ -9,7 +9,7 @@ import { WeaviateInvalidInputError, WeaviateQueryError } from '../../errors.js'; import { Aggregator } from '../../graphql/index.js'; import { PrimitiveKeys, toBase64FromMedia } from '../../index.js'; import { Deserialize } from '../deserialize/index.js'; -import { Bm25QueryProperty, NearVectorInputType } from '../query/types.js'; +import { Bm25QueryProperty, NearVectorInputType, TargetVector } from '../query/types.js'; import { NearVectorInputGuards } from '../query/utils.js'; import { Serialize } from '../serialize/index.js'; @@ -31,27 +31,27 @@ export type GroupByAggregate = { export type AggregateOverAllOptions = AggregateBaseOptions; -export type AggregateNearOptions = AggregateBaseOptions & { +export type AggregateNearOptions = AggregateBaseOptions & { certainty?: number; distance?: number; objectLimit?: number; - targetVector?: string; + targetVector?: TargetVector; }; -export type AggregateHybridOptions = AggregateBaseOptions & { +export type AggregateHybridOptions = AggregateBaseOptions & { alpha?: number; maxVectorDistance?: number; objectLimit?: number; queryProperties?: (PrimitiveKeys | Bm25QueryProperty)[]; - targetVector?: string; + targetVector?: TargetVector; vector?: number[]; }; -export type AggregateGroupByHybridOptions = AggregateHybridOptions & { +export type AggregateGroupByHybridOptions = AggregateHybridOptions & { groupBy: PropertyOf | GroupByAggregate; }; -export type AggregateGroupByNearOptions = AggregateNearOptions & { +export type AggregateGroupByNearOptions = AggregateNearOptions & { groupBy: PropertyOf | GroupByAggregate; }; @@ -346,9 +346,9 @@ export type AggregateGroupByResult< }; }; -class AggregateManager implements Aggregate { +class AggregateManager implements Aggregate { connection: Connection; - groupBy: AggregateGroupBy; + groupBy: AggregateGroupBy; name: string; dbVersionSupport: DbVersionSupport; consistencyLevel?: ConsistencyLevel; @@ -373,7 +373,7 @@ class AggregateManager implements Aggregate { this.groupBy = { hybrid: async >( query: string, - opts: AggregateGroupByHybridOptions + opts: AggregateGroupByHybridOptions ): Promise[]> => { if (await this.grpcChecker) { const group = typeof opts.groupBy === 'string' ? { property: opts.groupBy } : opts.groupBy; @@ -402,7 +402,7 @@ class AggregateManager implements Aggregate { }, nearImage: async >( image: string | Buffer, - opts: AggregateGroupByNearOptions + opts: AggregateGroupByNearOptions ): Promise[]> => { const [b64, usesGrpc] = await Promise.all([await toBase64FromMedia(image), await this.grpcChecker]); if (usesGrpc) { @@ -430,7 +430,7 @@ class AggregateManager implements Aggregate { }, nearObject: async >( id: string, - opts: AggregateGroupByNearOptions + opts: AggregateGroupByNearOptions ): Promise[]> => { if (await this.grpcChecker) { const group = typeof opts.groupBy === 'string' ? { property: opts.groupBy } : opts.groupBy; @@ -457,7 +457,7 @@ class AggregateManager implements Aggregate { }, nearText: async >( query: string | string[], - opts: AggregateGroupByNearOptions + opts: AggregateGroupByNearOptions ): Promise[]> => { if (await this.grpcChecker) { const group = typeof opts.groupBy === 'string' ? { property: opts.groupBy } : opts.groupBy; @@ -484,7 +484,7 @@ class AggregateManager implements Aggregate { }, nearVector: async >( vector: number[], - opts: AggregateGroupByNearOptions + opts: AggregateGroupByNearOptions ): Promise[]> => { if (await this.grpcChecker) { const group = typeof opts.groupBy === 'string' ? { property: opts.groupBy } : opts.groupBy; @@ -593,19 +593,19 @@ class AggregateManager implements Aggregate { return `${propertyName} { ${body} }`; } - static use( + static use( connection: Connection, name: string, dbVersionSupport: DbVersionSupport, consistencyLevel?: ConsistencyLevel, tenant?: string - ): AggregateManager { - return new AggregateManager(connection, name, dbVersionSupport, consistencyLevel, tenant); + ): AggregateManager { + return new AggregateManager(connection, name, dbVersionSupport, consistencyLevel, tenant); } async hybrid>( query: string, - opts?: AggregateHybridOptions + opts?: AggregateHybridOptions ): Promise> { if (await this.grpcChecker) { return this.grpc() @@ -628,7 +628,7 @@ class AggregateManager implements Aggregate { async nearImage>( image: string | Buffer, - opts?: AggregateNearOptions + opts?: AggregateNearOptions ): Promise> { const [b64, usesGrpc] = await Promise.all([await toBase64FromMedia(image), await this.grpcChecker]); if (usesGrpc) { @@ -650,7 +650,7 @@ class AggregateManager implements Aggregate { async nearObject>( id: string, - opts?: AggregateNearOptions + opts?: AggregateNearOptions ): Promise> { if (await this.grpcChecker) { return this.grpc() @@ -671,7 +671,7 @@ class AggregateManager implements Aggregate { async nearText>( query: string | string[], - opts?: AggregateNearOptions + opts?: AggregateNearOptions ): Promise> { if (await this.grpcChecker) { return this.grpc() @@ -692,7 +692,7 @@ class AggregateManager implements Aggregate { async nearVector>( vector: NearVectorInputType, - opts?: AggregateNearOptions + opts?: AggregateNearOptions ): Promise> { if (await this.grpcChecker) { return this.grpc() @@ -770,9 +770,9 @@ class AggregateManager implements Aggregate { }; } -export interface Aggregate { +export interface Aggregate { /** This namespace contains methods perform a group by search while aggregating metrics. */ - groupBy: AggregateGroupBy; + groupBy: AggregateGroupBy; /** * Aggregate metrics over the objects returned by a hybrid search on this collection. * @@ -784,7 +784,7 @@ export interface Aggregate { */ hybrid>( query: string, - opts?: AggregateHybridOptions + opts?: AggregateHybridOptions ): Promise>; /** * Aggregate metrics over the objects returned by a near image vector search on this collection. @@ -799,7 +799,7 @@ export interface Aggregate { */ nearImage>( image: string | Buffer, - opts?: AggregateNearOptions + opts?: AggregateNearOptions ): Promise>; /** * Aggregate metrics over the objects returned by a near object search on this collection. @@ -814,7 +814,7 @@ export interface Aggregate { */ nearObject>( id: string, - opts?: AggregateNearOptions + opts?: AggregateNearOptions ): Promise>; /** * Aggregate metrics over the objects returned by a near vector search on this collection. @@ -829,7 +829,7 @@ export interface Aggregate { */ nearText>( query: string | string[], - opts?: AggregateNearOptions + opts?: AggregateNearOptions ): Promise>; /** * Aggregate metrics over the objects returned by a near vector search on this collection. @@ -844,7 +844,7 @@ export interface Aggregate { */ nearVector>( vector: number[], - opts?: AggregateNearOptions + opts?: AggregateNearOptions ): Promise>; /** * Aggregate metrics over all the objects in this collection without any vector search. @@ -855,7 +855,7 @@ export interface Aggregate { overAll>(opts?: AggregateOverAllOptions): Promise>; } -export interface AggregateGroupBy { +export interface AggregateGroupBy { /** * Aggregate metrics over the objects grouped by a specified property and returned by a hybrid search on this collection. * @@ -867,7 +867,7 @@ export interface AggregateGroupBy { */ hybrid>( query: string, - opts: AggregateGroupByHybridOptions + opts: AggregateGroupByHybridOptions ): Promise[]>; /** * Aggregate metrics over the objects grouped by a specified property and returned by a near image vector search on this collection. @@ -882,7 +882,7 @@ export interface AggregateGroupBy { */ nearImage>( image: string | Buffer, - opts: AggregateGroupByNearOptions + opts: AggregateGroupByNearOptions ): Promise[]>; /** * Aggregate metrics over the objects grouped by a specified property and returned by a near object search on this collection. @@ -897,7 +897,7 @@ export interface AggregateGroupBy { */ nearObject>( id: string, - opts: AggregateGroupByNearOptions + opts: AggregateGroupByNearOptions ): Promise[]>; /** * Aggregate metrics over the objects grouped by a specified property and returned by a near text vector search on this collection. @@ -912,7 +912,7 @@ export interface AggregateGroupBy { */ nearText>( query: string | string[], - opts: AggregateGroupByNearOptions + opts: AggregateGroupByNearOptions ): Promise[]>; /** * Aggregate metrics over the objects grouped by a specified property and returned by a near vector search on this collection. @@ -927,7 +927,7 @@ export interface AggregateGroupBy { */ nearVector>( vector: number[], - opts: AggregateGroupByNearOptions + opts: AggregateGroupByNearOptions ): Promise[]>; /** * Aggregate metrics over all the objects in this collection grouped by a specified property without any vector search. diff --git a/src/collections/aggregate/integration.test.ts b/src/collections/aggregate/integration.test.ts index f06017c0..592782ba 100644 --- a/src/collections/aggregate/integration.test.ts +++ b/src/collections/aggregate/integration.test.ts @@ -485,7 +485,7 @@ describe('Testing of collection.aggregate search methods', () => { it('should return an aggregation on a nearVector search', async () => { const obj = await collection.query.fetchObjectById(uuid, { includeVector: true }); - const result = await collection.aggregate.nearVector(obj?.vectors.default!, { + const result = await collection.aggregate.nearVector(obj?.vectors.default as number[], { objectLimit: 1000, returnMetrics: collection.metrics.aggregate('text').text(['count']), }); @@ -494,7 +494,7 @@ describe('Testing of collection.aggregate search methods', () => { it('should return a grouped aggregation on a nearVector search', async () => { const obj = await collection.query.fetchObjectById(uuid, { includeVector: true }); - const result = await collection.aggregate.groupBy.nearVector(obj?.vectors.default!, { + const result = await collection.aggregate.groupBy.nearVector(obj?.vectors.default as number[], { objectLimit: 1000, groupBy: 'text', returnMetrics: collection.metrics.aggregate('text').text(['count']), diff --git a/src/collections/collection/index.ts b/src/collections/collection/index.ts index c4f7ee31..3fa1108c 100644 --- a/src/collections/collection/index.ts +++ b/src/collections/collection/index.ts @@ -15,11 +15,12 @@ import query, { Query } from '../query/index.js'; import sort, { Sort } from '../sort/index.js'; import tenants, { TenantBase, Tenants } from '../tenants/index.js'; import { QueryMetadata, QueryProperty, QueryReference } from '../types/index.js'; +import { IncludeVector } from '../types/internal.js'; import multiTargetVector, { MultiTargetVector } from '../vectors/multiTargetVector.js'; -export interface Collection { +export interface Collection { /** This namespace includes all the querying methods available to you when using Weaviate's standard aggregation capabilities. */ - aggregate: Aggregate; + aggregate: Aggregate; /** This namespace includes all the backup methods available to you when backing up a collection in Weaviate. */ backup: BackupCollection; /** This namespace includes all the CRUD methods available to you when modifying the configuration of the collection in Weaviate. */ @@ -29,19 +30,19 @@ export interface Collection { /** This namespace includes the methods by which you can create the `FilterValue` values for use when filtering queries over your collection. */ filter: Filter; /** This namespace includes all the querying methods available to you when using Weaviate's generative capabilities. */ - generate: Generate; + generate: Generate; /** This namespace includes the methods by which you can create the `MetricsX` values for use when aggregating over your collection. */ metrics: Metrics; /** The name of the collection. */ name: N; /** This namespace includes all the querying methods available to you when using Weaviate's standard query capabilities. */ - query: Query; + query: Query; /** This namespaces includes the methods by which you can create the `Sorting` values for use when sorting queries over your collection. */ sort: Sort; /** This namespace includes all the CRUD methods available to you when modifying the tenants of a multi-tenancy-enabled collection in Weaviate. */ tenants: Tenants; /** This namespaces includes the methods by which you cna create the `MultiTargetVectorJoin` values for use when performing multi-target vector searches over your collection. */ - multiTargetVector: MultiTargetVector; + multiTargetVector: MultiTargetVector; /** * Use this method to check if the collection exists in Weaviate. * @@ -62,7 +63,7 @@ export interface Collection { * to request the vector back as well. In addition, if `return_references=None` then none of the references * are returned. Use `wvc.QueryReference` to specify which references to return. */ - iterator: (opts?: IteratorOptions) => Iterator; + iterator: (opts?: IteratorOptions) => Iterator; /** * Use this method to return the total number of objects in the collection. * @@ -77,9 +78,9 @@ export interface Collection { * This method does not send a request to Weaviate. It only returns a new collection object that is specific to the consistency level you specify. * * @param {ConsistencyLevel} consistencyLevel The consistency level to use. - * @returns {Collection} A new collection object specific to the consistency level you specified. + * @returns {Collection} A new collection object specific to the consistency level you specified. */ - withConsistency: (consistencyLevel: ConsistencyLevel) => Collection; + withConsistency: (consistencyLevel: ConsistencyLevel) => Collection; /** * Use this method to return a collection object specific to a single tenant. * @@ -89,13 +90,13 @@ export interface Collection { * * @typedef {TenantBase} TT A type that extends TenantBase. * @param {string | TT} tenant The tenant name or tenant object to use. - * @returns {Collection} A new collection object specific to the tenant you specified. + * @returns {Collection} A new collection object specific to the tenant you specified. */ - withTenant: (tenant: string | TT) => Collection; + withTenant: (tenant: string | TT) => Collection; } -export type IteratorOptions = { - includeVector?: boolean | string[]; +export type IteratorOptions = { + includeVector?: IncludeVector; returnMetadata?: QueryMetadata; returnProperties?: QueryProperty[]; returnReferences?: QueryReference[]; @@ -106,41 +107,47 @@ const isString = (value: any): value is string => typeof value === 'string'; const capitalizeCollectionName = (name: N): N => (name.charAt(0).toUpperCase() + name.slice(1)) as N; -const collection = ( +const collection = ( connection: Connection, name: N, dbVersionSupport: DbVersionSupport, consistencyLevel?: ConsistencyLevel, tenant?: string -): Collection => { +): Collection => { if (!isString(name)) { throw new WeaviateInvalidInputError(`The collection name must be a string, got: ${typeof name}`); } const capitalizedName = capitalizeCollectionName(name); - const aggregateCollection = aggregate( + const aggregateCollection = aggregate( + connection, + capitalizedName, + dbVersionSupport, + consistencyLevel, + tenant + ); + const queryCollection = query( connection, capitalizedName, dbVersionSupport, consistencyLevel, tenant ); - const queryCollection = query(connection, capitalizedName, dbVersionSupport, consistencyLevel, tenant); return { aggregate: aggregateCollection, backup: backupCollection(connection, capitalizedName), config: config(connection, capitalizedName, dbVersionSupport, tenant), data: data(connection, capitalizedName, dbVersionSupport, consistencyLevel, tenant), filter: filter(), - generate: generate(connection, capitalizedName, dbVersionSupport, consistencyLevel, tenant), + generate: generate(connection, capitalizedName, dbVersionSupport, consistencyLevel, tenant), metrics: metrics(), - multiTargetVector: multiTargetVector(), + multiTargetVector: multiTargetVector(), name: name, query: queryCollection, sort: sort(), tenants: tenants(connection, capitalizedName, dbVersionSupport), exists: () => new ClassExists(connection).withClassName(capitalizedName).do(), - iterator: (opts?: IteratorOptions) => - new Iterator((limit: number, after?: string) => + iterator: (opts?: IteratorOptions) => + new Iterator((limit: number, after?: string) => queryCollection .fetchObjects({ limit, @@ -154,9 +161,9 @@ const collection = ( ), length: () => aggregateCollection.overAll().then(({ totalCount }) => totalCount), withConsistency: (consistencyLevel: ConsistencyLevel) => - collection(connection, capitalizedName, dbVersionSupport, consistencyLevel, tenant), + collection(connection, capitalizedName, dbVersionSupport, consistencyLevel, tenant), withTenant: (tenant: string | TT) => - collection( + collection( connection, capitalizedName, dbVersionSupport, diff --git a/src/collections/config/integration.test.ts b/src/collections/config/integration.test.ts index 254c6112..77171f21 100644 --- a/src/collections/config/integration.test.ts +++ b/src/collections/config/integration.test.ts @@ -72,6 +72,7 @@ describe('Testing of the collection.config namespace', () => { filterStrategy: 'sweeping', flatSearchCutoff: 40000, distance: 'cosine', + multiVector: undefined, quantizer: undefined, type: 'hnsw', }); @@ -127,6 +128,7 @@ describe('Testing of the collection.config namespace', () => { filterStrategy: 'sweeping', flatSearchCutoff: 40000, distance: 'cosine', + multiVector: undefined, quantizer: undefined, type: 'hnsw', }); @@ -499,6 +501,7 @@ describe('Testing of the collection.config namespace', () => { filterStrategy: 'sweeping', flatSearchCutoff: 40000, distance: 'cosine', + multiVector: undefined, quantizer: { bitCompression: false, segments: 0, @@ -608,6 +611,7 @@ describe('Testing of the collection.config namespace', () => { filterStrategy: 'sweeping', flatSearchCutoff: 40000, distance: 'cosine', + multiVector: undefined, type: 'hnsw', quantizer: { bitCompression: false, diff --git a/src/collections/config/types/vectorIndex.ts b/src/collections/config/types/vectorIndex.ts index ddd6ea90..da42130f 100644 --- a/src/collections/config/types/vectorIndex.ts +++ b/src/collections/config/types/vectorIndex.ts @@ -9,6 +9,7 @@ export type VectorIndexConfigHNSW = { filterStrategy: VectorIndexFilterStrategy; flatSearchCutoff: number; maxConnections: number; + multiVector: MultiVectorConfig | undefined; quantizer: PQConfig | BQConfig | SQConfig | undefined; skip: boolean; vectorCacheMaxObjects: number; @@ -61,6 +62,10 @@ export type PQConfig = { type: 'pq'; }; +export type MultiVectorConfig = { + aggregation: 'maxSim' | string; +}; + export type PQEncoderConfig = { type: PQEncoderType; distribution: PQEncoderDistribution; diff --git a/src/collections/config/utils.ts b/src/collections/config/utils.ts index da6ab9d6..f11be817 100644 --- a/src/collections/config/utils.ts +++ b/src/collections/config/utils.ts @@ -27,6 +27,7 @@ import { InvertedIndexConfig, ModuleConfig, MultiTenancyConfig, + MultiVectorConfig, PQConfig, PQEncoderConfig, PQEncoderDistribution, @@ -380,12 +381,25 @@ class ConfigMapping { filterStrategy: exists(v.filterStrategy) ? v.filterStrategy : 'sweeping', flatSearchCutoff: v.flatSearchCutoff, maxConnections: v.maxConnections, + multiVector: exists(v.multivector) + ? ConfigMapping.multiVector(v.multivector) + : undefined, quantizer: quantizer, skip: v.skip, vectorCacheMaxObjects: v.vectorCacheMaxObjects, type: 'hnsw', }; } + static multiVector(v: Record): MultiVectorConfig | undefined { + if (!exists(v.enabled)) + throw new WeaviateDeserializationError('Multi vector enabled was not returned by Weaviate'); + if (v.enabled === false) return undefined; + if (!exists(v.aggregation)) + throw new WeaviateDeserializationError('Multi vector aggregation was not returned by Weaviate'); + return { + aggregation: v.aggregation, + }; + } static bq(v?: Record): BQConfig | undefined { if (v === undefined) throw new WeaviateDeserializationError('BQ was not returned by Weaviate'); if (!exists(v.enabled)) diff --git a/src/collections/configure/parsing.ts b/src/collections/configure/parsing.ts index 09319424..08cf2a33 100644 --- a/src/collections/configure/parsing.ts +++ b/src/collections/configure/parsing.ts @@ -5,6 +5,9 @@ import { PQConfigUpdate, SQConfigCreate, SQConfigUpdate, + VectorIndexConfigDynamicCreate, + VectorIndexConfigFlatCreate, + VectorIndexConfigHNSWCreate, } from './types/index.js'; type QuantizerConfig = @@ -36,6 +39,24 @@ export class QuantizerGuards { } } +type VectorIndexConfig = + | VectorIndexConfigHNSWCreate + | VectorIndexConfigFlatCreate + | VectorIndexConfigDynamicCreate + | Record; + +export class VectorIndexGuards { + static isHNSW(config?: VectorIndexConfig): config is VectorIndexConfigHNSWCreate { + return (config as VectorIndexConfigHNSWCreate)?.type === 'hnsw'; + } + static isFlat(config?: VectorIndexConfig): config is VectorIndexConfigFlatCreate { + return (config as VectorIndexConfigFlatCreate)?.type === 'flat'; + } + static isDynamic(config?: VectorIndexConfig): config is VectorIndexConfigDynamicCreate { + return (config as VectorIndexConfigDynamicCreate)?.type === 'dynamic'; + } +} + export function parseWithDefault(value: D | undefined, defaultValue: D): D { return value !== undefined ? value : defaultValue; } diff --git a/src/collections/configure/types/vectorIndex.ts b/src/collections/configure/types/vectorIndex.ts index 4f759a7f..2622abf5 100644 --- a/src/collections/configure/types/vectorIndex.ts +++ b/src/collections/configure/types/vectorIndex.ts @@ -1,6 +1,7 @@ import { BQConfig, ModuleConfig, + MultiVectorConfig, PQConfig, PQEncoderDistribution, PQEncoderType, @@ -46,6 +47,8 @@ export type SQConfigUpdate = { type: 'sq'; }; +export type MultiVectorConfigCreate = RecursivePartial; + export type VectorIndexConfigHNSWCreate = RecursivePartial; export type VectorIndexConfigDynamicCreate = RecursivePartial; @@ -130,6 +133,8 @@ export type VectorIndexConfigHNSWCreateOptions = { filterStrategy?: VectorIndexFilterStrategy; /** The maximum number of connections. Default is 64. */ maxConnections?: number; + /** The multi-vector configuration to use. Use `vectorIndex.multiVector` to make one. */ + multiVector?: MultiVectorConfigCreate; /** The quantizer configuration to use. Use `vectorIndex.quantizer.bq` or `vectorIndex.quantizer.pq` to make one. */ quantizer?: PQConfigCreate | BQConfigCreate | SQConfigCreate; /** Whether to skip the index. Default is false. */ diff --git a/src/collections/configure/types/vectorizer.ts b/src/collections/configure/types/vectorizer.ts index 5318e8f4..b6afdfc8 100644 --- a/src/collections/configure/types/vectorizer.ts +++ b/src/collections/configure/types/vectorizer.ts @@ -53,9 +53,13 @@ export type VectorConfigUpdate>; }; -export type VectorizersConfigCreate = - | VectorConfigCreate, undefined, VectorIndexType, Vectorizer> - | VectorConfigCreate, string, VectorIndexType, Vectorizer>[]; +export type VectorizersConfigCreate = V extends undefined + ? + | VectorConfigCreate, undefined, VectorIndexType, Vectorizer> + | VectorConfigCreate, string, VectorIndexType, Vectorizer>[] + : + | VectorConfigCreate, undefined, VectorIndexType, Vectorizer> + | VectorConfigCreate, keyof V & string, VectorIndexType, Vectorizer>[]; export type ConfigureNonTextVectorizerOptions< N extends string | undefined, diff --git a/src/collections/configure/unit.test.ts b/src/collections/configure/unit.test.ts index 93c50f65..7df30c2b 100644 --- a/src/collections/configure/unit.test.ts +++ b/src/collections/configure/unit.test.ts @@ -129,6 +129,7 @@ describe('Unit testing of the configure & reconfigure factory classes', () => { quantizer: { type: 'pq', }, + type: 'hnsw', }, }); }); @@ -183,6 +184,7 @@ describe('Unit testing of the configure & reconfigure factory classes', () => { type: 'pq', }, skip: true, + type: 'hnsw', vectorCacheMaxObjects: 2000000000000, }, }); @@ -196,6 +198,7 @@ describe('Unit testing of the configure & reconfigure factory classes', () => { quantizer: { type: 'bq', }, + type: 'flat', }, }); }); @@ -220,6 +223,7 @@ describe('Unit testing of the configure & reconfigure factory classes', () => { rescoreLimit: 100, type: 'bq', }, + type: 'flat', }, }); }); @@ -239,6 +243,22 @@ describe('Unit testing of the configure & reconfigure factory classes', () => { trainingLimit: 200, type: 'sq', }, + type: 'hnsw', + }, + }); + }); + + it('should create an hnsw VectorIndexConfig type with multivector enabled', () => { + const config = configure.vectorIndex.hnsw({ + multiVector: configure.vectorIndex.multiVector.multiVector({ aggregation: 'maxSim' }), + }); + expect(config).toEqual>({ + name: 'hnsw', + config: { + multiVector: { + aggregation: 'maxSim', + }, + type: 'hnsw', }, }); }); diff --git a/src/collections/configure/vectorIndex.ts b/src/collections/configure/vectorIndex.ts index a9e8790e..b61c2135 100644 --- a/src/collections/configure/vectorIndex.ts +++ b/src/collections/configure/vectorIndex.ts @@ -7,6 +7,7 @@ import { import { BQConfigCreate, BQConfigUpdate, + MultiVectorConfigCreate, PQConfigCreate, PQConfigUpdate, SQConfigCreate, @@ -44,6 +45,7 @@ const configure = { distance, vectorCacheMaxObjects, quantizer: quantizer, + type: 'flat', }, }; }, @@ -66,6 +68,7 @@ const configure = { ...rest, distance: distanceMetric, quantizer: rest.quantizer, + type: 'hnsw', } : undefined, }; @@ -89,10 +92,27 @@ const configure = { threshold: opts.threshold, hnsw: isModuleConfig(opts.hnsw) ? opts.hnsw.config : configure.hnsw(opts.hnsw).config, flat: isModuleConfig(opts.flat) ? opts.flat.config : configure.flat(opts.flat).config, + type: 'dynamic', } : undefined, }; }, + /** + * Define the configuration for a multi-vector index. + */ + multiVector: { + /** + * Create an object of type `MultiVectorConfigCreate` to be used when defining the configuration of a multi-vector index. + * + * @param {object} [options.aggregation] The aggregation method to use. Default is 'maxSim'. + * @returns {MultiVectorConfigCreate} The object of type `MultiVectorConfigCreate`. + */ + multiVector: (options?: { aggregation?: 'maxSim' | string }): MultiVectorConfigCreate => { + return { + aggregation: options?.aggregation, + }; + }, + }, /** * Define the quantizer configuration to use when creating a vector index. */ diff --git a/src/collections/data/integration.test.ts b/src/collections/data/integration.test.ts index 3b55c64f..8493d1aa 100644 --- a/src/collections/data/integration.test.ts +++ b/src/collections/data/integration.test.ts @@ -328,7 +328,7 @@ describe('Testing of the collection.data methods with a single target reference' collection.query.fetchObjectById(toBeReplacedID, { returnReferences: [{ linkOn: 'ref' }], }); - const assert = (obj: WeaviateObject | null, id: string) => { + const assert = (obj: WeaviateObject | null, id: string) => { expect(obj).not.toBeNull(); expect(obj?.references?.ref?.objects[0].uuid).toEqual(id); }; diff --git a/src/collections/deserialize/index.ts b/src/collections/deserialize/index.ts index 588c2642..54886eb4 100644 --- a/src/collections/deserialize/index.ts +++ b/src/collections/deserialize/index.ts @@ -11,12 +11,14 @@ import { AggregateReply_Aggregations_Aggregation_Text, AggregateReply_Group_GroupedBy, } from '../../proto/v1/aggregate.js'; +import { Vectors_VectorType } from '../../proto/v1/base.js'; import { BatchObject as BatchObjectGRPC, BatchObjectsReply } from '../../proto/v1/batch.js'; import { BatchDeleteReply } from '../../proto/v1/batch_delete.js'; import { ListValue, Properties as PropertiesGrpc, Value } from '../../proto/v1/properties.js'; import { MetadataResult, PropertiesResult, SearchReply } from '../../proto/v1/search_get.js'; import { TenantActivityStatus, TenantsGetReply } from '../../proto/v1/tenants.js'; import { DbVersionSupport } from '../../utils/dbVersion.js'; +import { yieldToEventLoop } from '../../utils/yield.js'; import { AggregateBoolean, AggregateDate, @@ -26,7 +28,10 @@ import { AggregateText, AggregateType, PropertiesMetrics, + Vectors, + WeaviateObject, } from '../index.js'; +import { MultiVectorType, SingleVectorType } from '../query/types.js'; import { referenceFromObjects } from '../references/utils.js'; import { Tenant } from '../tenants/index.js'; import { @@ -45,6 +50,9 @@ import { WeaviateReturn, } from '../types/index.js'; +const UINT16LEN = 2; +const UINT32LEN = 4; + export class Deserialize { private supports125ListValue: boolean; @@ -193,50 +201,57 @@ export class Deserialize { }); } - public query(reply: SearchReply): WeaviateReturn { + public async query(reply: SearchReply): Promise> { return { - objects: reply.results.map((result) => { - return { - metadata: Deserialize.metadata(result.metadata), - properties: this.properties(result.properties), - references: this.references(result.properties), - uuid: Deserialize.uuid(result.metadata), - vectors: Deserialize.vectors(result.metadata), - } as any; - }), + objects: await Promise.all( + reply.results.map(async (result) => { + return { + metadata: Deserialize.metadata(result.metadata), + properties: this.properties(result.properties), + references: await this.references(result.properties), + uuid: Deserialize.uuid(result.metadata), + vectors: await Deserialize.vectors(result.metadata), + } as unknown as WeaviateObject; + }) + ), }; } - public generate(reply: SearchReply): GenerativeReturn { + public async generate(reply: SearchReply): Promise> { return { - objects: reply.results.map((result) => { - return { - generated: result.metadata?.generativePresent ? result.metadata?.generative : undefined, - metadata: Deserialize.metadata(result.metadata), - properties: this.properties(result.properties), - references: this.references(result.properties), - uuid: Deserialize.uuid(result.metadata), - vectors: Deserialize.vectors(result.metadata), - } as any; - }), + objects: await Promise.all( + reply.results.map(async (result) => { + return { + generated: result.metadata?.generativePresent ? result.metadata?.generative : undefined, + metadata: Deserialize.metadata(result.metadata), + properties: this.properties(result.properties), + references: await this.references(result.properties), + uuid: Deserialize.uuid(result.metadata), + vectors: await Deserialize.vectors(result.metadata), + } as unknown as WeaviateObject; + }) + ), generated: reply.generativeGroupedResult, }; } - public queryGroupBy(reply: SearchReply): GroupByReturn { - const objects: GroupByObject[] = []; - const groups: Record> = {}; - reply.groupByResults.forEach((result) => { - const objs = result.objects.map((object) => { - return { - belongsToGroup: result.name, - metadata: Deserialize.metadata(object.metadata), - properties: this.properties(object.properties), - references: this.references(object.properties), - uuid: Deserialize.uuid(object.metadata), - vectors: Deserialize.vectors(object.metadata), - } as any; - }); + public async queryGroupBy(reply: SearchReply): Promise> { + const objects: GroupByObject[] = []; + const groups: Record> = {}; + for (const result of reply.groupByResults) { + // eslint-disable-next-line no-await-in-loop + const objs = await Promise.all( + result.objects.map(async (object) => { + return { + belongsToGroup: result.name, + metadata: Deserialize.metadata(object.metadata), + properties: this.properties(object.properties), + references: await this.references(object.properties), + uuid: Deserialize.uuid(object.metadata), + vectors: await Deserialize.vectors(object.metadata), + } as unknown as GroupByObject; + }) + ); groups[result.name] = { maxDistance: result.maxDistance, minDistance: result.minDistance, @@ -245,27 +260,30 @@ export class Deserialize { objects: objs, }; objects.push(...objs); - }); + } return { objects: objects, groups: groups, }; } - public generateGroupBy(reply: SearchReply): GenerativeGroupByReturn { - const objects: GroupByObject[] = []; - const groups: Record> = {}; - reply.groupByResults.forEach((result) => { - const objs = result.objects.map((object) => { - return { - belongsToGroup: result.name, - metadata: Deserialize.metadata(object.metadata), - properties: this.properties(object.properties), - references: this.references(object.properties), - uuid: Deserialize.uuid(object.metadata), - vectors: Deserialize.vectors(object.metadata), - } as any; - }); + public async generateGroupBy(reply: SearchReply): Promise> { + const objects: GroupByObject[] = []; + const groups: Record> = {}; + for (const result of reply.groupByResults) { + // eslint-disable-next-line no-await-in-loop + const objs = await Promise.all( + result.objects.map(async (object) => { + return { + belongsToGroup: result.name, + metadata: Deserialize.metadata(object.metadata), + properties: this.properties(object.properties), + references: await this.references(object.properties), + uuid: Deserialize.uuid(object.metadata), + vectors: await Deserialize.vectors(object.metadata), + } as unknown as GroupByObject; + }) + ); groups[result.name] = { maxDistance: result.maxDistance, minDistance: result.minDistance, @@ -275,7 +293,7 @@ export class Deserialize { generated: result.generative?.result, }; objects.push(...objs); - }); + } return { objects: objects, groups: groups, @@ -288,28 +306,31 @@ export class Deserialize { return this.objectProperties(properties.nonRefProps); } - private references(properties?: PropertiesResult) { + private async references(properties?: PropertiesResult) { if (!properties) return undefined; if (properties.refProps.length === 0) return properties.refPropsRequested ? {} : undefined; const out: any = {}; - properties.refProps.forEach((property) => { + for (const property of properties.refProps) { const uuids: string[] = []; out[property.propName] = referenceFromObjects( - property.properties.map((property) => { - const uuid = Deserialize.uuid(property.metadata); - uuids.push(uuid); - return { - metadata: Deserialize.metadata(property.metadata), - properties: this.properties(property), - references: this.references(property), - uuid: uuid, - vectors: Deserialize.vectors(property.metadata), - }; - }), + // eslint-disable-next-line no-await-in-loop + await Promise.all( + property.properties.map(async (property) => { + const uuid = Deserialize.uuid(property.metadata); + uuids.push(uuid); + return { + metadata: Deserialize.metadata(property.metadata), + properties: this.properties(property), + references: await this.references(property), + uuid: uuid, + vectors: await Deserialize.vectors(property.metadata), + }; + }) + ), property.properties.length > 0 ? property.properties[0].targetCollection : '', uuids ); - }); + } return out; } @@ -375,7 +396,33 @@ export class Deserialize { return metadata.id; } - private static vectorFromBytes(bytes: Uint8Array) { + /** + * Convert an Uint8Array into a 2D vector array. + * + * Defined as an async method so that control can be relinquished back to the event loop on each outer loop for large vectors. + */ + private static vectorsFromBytes(bytes: Uint8Array): Promise { + const dimOffset = UINT16LEN; + const dimBytes = Buffer.from(bytes.slice(0, dimOffset)); + const vectorDimension = dimBytes.readUInt16LE(0); + + const vecByteLength = UINT32LEN * vectorDimension; + const howMany = (bytes.byteLength - dimOffset) / vecByteLength; + + return Promise.all( + Array(howMany) + .fill(0) + .map((_, i) => + yieldToEventLoop().then(() => + Deserialize.vectorFromBytes( + bytes.slice(dimOffset + i * vecByteLength, dimOffset + (i + 1) * vecByteLength) + ) + ) + ) + ); + } + + private static vectorFromBytes(bytes: Uint8Array): SingleVectorType { const buffer = Buffer.from(bytes); const view = new Float32Array(buffer.buffer, buffer.byteOffset, buffer.byteLength / 4); // vector is float32 in weaviate return Array.from(view); @@ -393,14 +440,21 @@ export class Deserialize { return Array.from(view); } - private static vectors(metadata?: MetadataResult): Record { + private static async vectors(metadata?: MetadataResult): Promise { if (!metadata) return {}; if (metadata.vectorBytes.length === 0 && metadata.vector.length === 0 && metadata.vectors.length === 0) return {}; if (metadata.vectorBytes.length > 0) return { default: Deserialize.vectorFromBytes(metadata.vectorBytes) }; return Object.fromEntries( - metadata.vectors.map((vector) => [vector.name, Deserialize.vectorFromBytes(vector.vectorBytes)]) + await Promise.all( + metadata.vectors.map(async (vector) => [ + vector.name, + vector.type === Vectors_VectorType.VECTOR_TYPE_SINGLE_FP32 + ? Deserialize.vectorFromBytes(vector.vectorBytes) + : await Deserialize.vectorsFromBytes(vector.vectorBytes), + ]) + ) ); } diff --git a/src/collections/filters/integration.test.ts b/src/collections/filters/integration.test.ts index abfbd8ba..ff57aed3 100644 --- a/src/collections/filters/integration.test.ts +++ b/src/collections/filters/integration.test.ts @@ -95,7 +95,7 @@ describe('Testing of the filter class with a simple collection', () => { return uuids; }); const res = await collection.query.fetchObjectById(ids[0], { includeVector: true }); - vector = res?.vectors.default!; + vector = res?.vectors.default as number[]; }); it('should filter a fetch objects query with a single filter and generic collection', async () => { diff --git a/src/collections/generate/index.ts b/src/collections/generate/index.ts index d783e61a..50c9ef83 100644 --- a/src/collections/generate/index.ts +++ b/src/collections/generate/index.ts @@ -34,42 +34,44 @@ import { } from '../types/index.js'; import { Generate } from './types.js'; -class GenerateManager implements Generate { - private check: Check; +class GenerateManager implements Generate { + private check: Check; - private constructor(check: Check) { + private constructor(check: Check) { this.check = check; } - public static use( + public static use( connection: Connection, name: string, dbVersionSupport: DbVersionSupport, consistencyLevel?: ConsistencyLevel, tenant?: string - ): GenerateManager { - return new GenerateManager(new Check(connection, name, dbVersionSupport, consistencyLevel, tenant)); + ): GenerateManager { + return new GenerateManager( + new Check(connection, name, dbVersionSupport, consistencyLevel, tenant) + ); } private async parseReply(reply: SearchReply) { const deserialize = await Deserialize.use(this.check.dbVersionSupport); - return deserialize.generate(reply); + return deserialize.generate(reply); } private async parseGroupByReply( - opts: SearchOptions | GroupByOptions | undefined, + opts: SearchOptions | GroupByOptions | undefined, reply: SearchReply ) { const deserialize = await Deserialize.use(this.check.dbVersionSupport); return Serialize.search.isGroupBy(opts) - ? deserialize.generateGroupBy(reply) - : deserialize.generate(reply); + ? deserialize.generateGroupBy(reply) + : deserialize.generate(reply); } public fetchObjects( generate: GenerateOptions, - opts?: FetchObjectsOptions - ): Promise> { + opts?: FetchObjectsOptions + ): Promise> { return this.check .fetchObjects(opts) .then(({ search }) => ({ @@ -86,14 +88,14 @@ class GenerateManager implements Generate { public bm25( query: string, generate: GenerateOptions, - opts?: BaseBm25Options - ): Promise>; + opts?: BaseBm25Options + ): Promise>; public bm25( query: string, generate: GenerateOptions, - opts: GroupByBm25Options - ): Promise>; - public bm25(query: string, generate: GenerateOptions, opts?: Bm25Options): GenerateReturn { + opts: GroupByBm25Options + ): Promise>; + public bm25(query: string, generate: GenerateOptions, opts?: Bm25Options): GenerateReturn { return this.check .bm25(opts) .then(({ search }) => ({ @@ -110,14 +112,18 @@ class GenerateManager implements Generate { public hybrid( query: string, generate: GenerateOptions, - opts?: BaseHybridOptions - ): Promise>; + opts?: BaseHybridOptions + ): Promise>; public hybrid( query: string, generate: GenerateOptions, - opts: GroupByHybridOptions - ): Promise>; - public hybrid(query: string, generate: GenerateOptions, opts?: HybridOptions): GenerateReturn { + opts: GroupByHybridOptions + ): Promise>; + public hybrid( + query: string, + generate: GenerateOptions, + opts?: HybridOptions + ): GenerateReturn { return this.check .hybridSearch(opts) .then( @@ -151,18 +157,18 @@ class GenerateManager implements Generate { public nearImage( image: string | Buffer, generate: GenerateOptions, - opts?: BaseNearOptions - ): Promise>; + opts?: BaseNearOptions + ): Promise>; public nearImage( image: string | Buffer, generate: GenerateOptions, - opts: GroupByNearOptions - ): Promise>; + opts: GroupByNearOptions + ): Promise>; public nearImage( image: string | Buffer, generate: GenerateOptions, - opts?: NearOptions - ): GenerateReturn { + opts?: NearOptions + ): GenerateReturn { return this.check .nearSearch(opts) .then(async ({ search, supportsTargets, supportsWeightsForTargets }) => ({ @@ -186,14 +192,18 @@ class GenerateManager implements Generate { public nearObject( id: string, generate: GenerateOptions, - opts?: BaseNearOptions - ): Promise>; + opts?: BaseNearOptions + ): Promise>; + public nearObject( + id: string, + generate: GenerateOptions, + opts: GroupByNearOptions + ): Promise>; public nearObject( id: string, generate: GenerateOptions, - opts: GroupByNearOptions - ): Promise>; - public nearObject(id: string, generate: GenerateOptions, opts?: NearOptions): GenerateReturn { + opts?: NearOptions + ): GenerateReturn { return this.check .nearSearch(opts) .then(({ search, supportsTargets, supportsWeightsForTargets }) => ({ @@ -217,18 +227,18 @@ class GenerateManager implements Generate { public nearText( query: string | string[], generate: GenerateOptions, - opts?: BaseNearTextOptions - ): Promise>; + opts?: BaseNearTextOptions + ): Promise>; public nearText( query: string | string[], generate: GenerateOptions, - opts: GroupByNearTextOptions - ): Promise>; + opts: GroupByNearTextOptions + ): Promise>; public nearText( query: string | string[], generate: GenerateOptions, - opts?: NearOptions - ): GenerateReturn { + opts?: NearOptions + ): GenerateReturn { return this.check .nearSearch(opts) .then(({ search, supportsTargets, supportsWeightsForTargets }) => ({ @@ -252,18 +262,18 @@ class GenerateManager implements Generate { public nearVector( vector: number[], generate: GenerateOptions, - opts?: BaseNearOptions - ): Promise>; + opts?: BaseNearOptions + ): Promise>; public nearVector( vector: number[], generate: GenerateOptions, - opts: GroupByNearOptions - ): Promise>; + opts: GroupByNearOptions + ): Promise>; public nearVector( vector: number[], generate: GenerateOptions, - opts?: NearOptions - ): GenerateReturn { + opts?: NearOptions + ): GenerateReturn { return this.check .nearVector(vector, opts) .then( @@ -298,20 +308,20 @@ class GenerateManager implements Generate { media: string | Buffer, type: NearMediaType, generate: GenerateOptions, - opts?: BaseNearOptions - ): Promise>; + opts?: BaseNearOptions + ): Promise>; public nearMedia( media: string | Buffer, type: NearMediaType, generate: GenerateOptions, - opts: GroupByNearOptions - ): Promise>; + opts: GroupByNearOptions + ): Promise>; public nearMedia( media: string | Buffer, type: NearMediaType, generate: GenerateOptions, - opts?: NearOptions - ): GenerateReturn { + opts?: NearOptions + ): GenerateReturn { return this.check .nearSearch(opts) .then(({ search, supportsTargets, supportsWeightsForTargets }) => { diff --git a/src/collections/generate/integration.test.ts b/src/collections/generate/integration.test.ts index 1be98451..a3d0378d 100644 --- a/src/collections/generate/integration.test.ts +++ b/src/collections/generate/integration.test.ts @@ -65,7 +65,7 @@ maybe('Testing of the collection.generate methods with a simple collection', () }); }); const res = await collection.query.fetchObjectById(id, { includeVector: true }); - vector = res?.vectors.default!; + vector = res?.vectors.default as number[]; }); describe('using a non-generic collection', () => { @@ -206,7 +206,7 @@ maybe('Testing of the groupBy collection.generate methods with a simple collecti }); }); const res = await collection.query.fetchObjectById(id, { includeVector: true }); - vector = res?.vectors.default!; + vector = res?.vectors.default as number[]; }); // it('should groupBy without search', async () => { @@ -366,8 +366,8 @@ maybe('Testing of the collection.generate methods with a multi vector collection }, }); const res = await collection.query.fetchObjectById(id1, { includeVector: true }); - titleVector = res!.vectors.title!; - title2Vector = res!.vectors.title2!; + titleVector = res!.vectors.title as number[]; + title2Vector = res!.vectors.title2 as number[]; }); if (await client.getWeaviateVersion().then((ver) => ver.isLowerThan(1, 24, 0))) { await expect(query()).rejects.toThrow(WeaviateUnsupportedFeatureError); diff --git a/src/collections/generate/types.ts b/src/collections/generate/types.ts index b211a46a..78ffe7ec 100644 --- a/src/collections/generate/types.ts +++ b/src/collections/generate/types.ts @@ -22,7 +22,7 @@ import { GenerativeReturn, } from '../types/index.js'; -interface Bm25 { +interface Bm25 { /** * Perform retrieval-augmented generation (RaG) on the results of a keyword-based BM25 search of objects in this collection. * @@ -32,10 +32,14 @@ interface Bm25 { * * @param {string} query - The query to search for. * @param {GenerateOptions} generate - The available options for performing the generation. - * @param {BaseBm25Options} [opts] - The available options for performing the BM25 search. - * @return {Promise>} - The results of the search including the generated data. + * @param {BaseBm25Options} [opts] - The available options for performing the BM25 search. + * @return {Promise>} - The results of the search including the generated data. */ - bm25(query: string, generate: GenerateOptions, opts?: BaseBm25Options): Promise>; + bm25( + query: string, + generate: GenerateOptions, + opts?: BaseBm25Options + ): Promise>; /** * Perform retrieval-augmented generation (RaG) on the results of a keyword-based BM25 search of objects in this collection. * @@ -45,14 +49,14 @@ interface Bm25 { * * @param {string} query - The query to search for. * @param {GenerateOptions} generate - The available options for performing the generation. - * @param {GroupByBm25Options} opts - The available options for performing the BM25 search. - * @return {Promise>} - The results of the search including the generated data grouped by the specified properties. + * @param {GroupByBm25Options} opts - The available options for performing the BM25 search. + * @return {Promise>} - The results of the search including the generated data grouped by the specified properties. */ bm25( query: string, generate: GenerateOptions, - opts: GroupByBm25Options - ): Promise>; + opts: GroupByBm25Options + ): Promise>; /** * Perform retrieval-augmented generation (RaG) on the results of a keyword-based BM25 search of objects in this collection. * @@ -62,13 +66,13 @@ interface Bm25 { * * @param {string} query - The query to search for. * @param {GenerateOptions} generate - The available options for performing the generation. - * @param {Bm25Options} [opts] - The available options for performing the BM25 search. - * @return {GenerateReturn} - The results of the search including the generated data. + * @param {Bm25Options} [opts] - The available options for performing the BM25 search. + * @return {GenerateReturn} - The results of the search including the generated data. */ - bm25(query: string, generate: GenerateOptions, opts?: Bm25Options): GenerateReturn; + bm25(query: string, generate: GenerateOptions, opts?: Bm25Options): GenerateReturn; } -interface Hybrid { +interface Hybrid { /** * Perform retrieval-augmented generation (RaG) on the results of an object search in this collection using the hybrid algorithm blending keyword-based BM25 and vector-based similarity. * @@ -78,14 +82,14 @@ interface Hybrid { * * @param {string} query - The query to search for. * @param {GenerateOptions} generate - The available options for performing the generation. - * @param {BaseHybridOptions} [opts] - The available options for performing the hybrid search. - * @return {Promise>} - The results of the search including the generated data. + * @param {BaseHybridOptions} [opts] - The available options for performing the hybrid search. + * @return {Promise>} - The results of the search including the generated data. */ hybrid( query: string, generate: GenerateOptions, - opts?: BaseHybridOptions - ): Promise>; + opts?: BaseHybridOptions + ): Promise>; /** * Perform retrieval-augmented generation (RaG) on the results of an object search in this collection using the hybrid algorithm blending keyword-based BM25 and vector-based similarity. * @@ -95,14 +99,14 @@ interface Hybrid { * * @param {string} query - The query to search for. * @param {GenerateOptions} generate - The available options for performing the generation. - * @param {GroupByHybridOptions} opts - The available options for performing the hybrid search. - * @return {Promise>} - The results of the search including the generated data grouped by the specified properties. + * @param {GroupByHybridOptions} opts - The available options for performing the hybrid search. + * @return {Promise>} - The results of the search including the generated data grouped by the specified properties. */ hybrid( query: string, generate: GenerateOptions, - opts: GroupByHybridOptions - ): Promise>; + opts: GroupByHybridOptions + ): Promise>; /** * Perform retrieval-augmented generation (RaG) on the results of an object search in this collection using the hybrid algorithm blending keyword-based BM25 and vector-based similarity. * @@ -112,13 +116,13 @@ interface Hybrid { * * @param {string} query - The query to search for. * @param {GenerateOptions} generate - The available options for performing the generation. - * @param {HybridOptions} [opts] - The available options for performing the hybrid search. - * @return {GenerateReturn} - The results of the search including the generated data. + * @param {HybridOptions} [opts] - The available options for performing the hybrid search. + * @return {GenerateReturn} - The results of the search including the generated data. */ - hybrid(query: string, generate: GenerateOptions, opts?: HybridOptions): GenerateReturn; + hybrid(query: string, generate: GenerateOptions, opts?: HybridOptions): GenerateReturn; } -interface NearMedia { +interface NearMedia { /** * Perform retrieval-augmented generation (RaG) on the results of a by-audio object search in this collection using an audio-capable vectorization module and vector-based similarity search. * @@ -131,15 +135,15 @@ interface NearMedia { * @param {string | Buffer} media - The media file to search on. This can be a base64 string, a file path string, or a buffer. * @param {NearMediaType} type - The type of media to search on. * @param {GenerateOptions} generate - The available options for performing the generation. - * @param {BaseNearOptions} [opts] - The available options for performing the near-media search. - * @return {Promise>} - The results of the search including the generated data. + * @param {BaseNearOptions} [opts] - The available options for performing the near-media search. + * @return {Promise>} - The results of the search including the generated data. */ nearMedia( media: string | Buffer, type: NearMediaType, generate: GenerateOptions, - opts?: BaseNearOptions - ): Promise>; + opts?: BaseNearOptions + ): Promise>; /** * Perform retrieval-augmented generation (RaG) on the results of a by-audio object search in this collection using an audio-capable vectorization module and vector-based similarity search. * @@ -152,15 +156,15 @@ interface NearMedia { * @param {string | Buffer} media - The media file to search on. This can be a base64 string, a file path string, or a buffer. * @param {NearMediaType} type - The type of media to search on. * @param {GenerateOptions} generate - The available options for performing the generation. - * @param {GroupByNearOptions} opts - The available options for performing the near-media search. - * @return {Promise>} - The results of the search including the generated data grouped by the specified properties. + * @param {GroupByNearOptions} opts - The available options for performing the near-media search. + * @return {Promise>} - The results of the search including the generated data grouped by the specified properties. */ nearMedia( media: string | Buffer, type: NearMediaType, generate: GenerateOptions, - opts: GroupByNearOptions - ): Promise>; + opts: GroupByNearOptions + ): Promise>; /** * Perform retrieval-augmented generation (RaG) on the results of a by-audio object search in this collection using an audio-capable vectorization module and vector-based similarity search. * @@ -173,18 +177,18 @@ interface NearMedia { * @param {string | Buffer} media - The media to search on. This can be a base64 string, a file path string, or a buffer. * @param {NearMediaType} type - The type of media to search on. * @param {GenerateOptions} generate - The available options for performing the generation. - * @param {NearOptions} [opts] - The available options for performing the near-media search. - * @return {GenerateReturn} - The results of the search including the generated data. + * @param {NearOptions} [opts] - The available options for performing the near-media search. + * @return {GenerateReturn} - The results of the search including the generated data. */ nearMedia( media: string | Buffer, type: NearMediaType, generate: GenerateOptions, - opts?: NearOptions - ): GenerateReturn; + opts?: NearOptions + ): GenerateReturn; } -interface NearObject { +interface NearObject { /** * Perform retrieval-augmented generation (RaG) on the results of a by-object object search in this collection using a vector-based similarity search. * @@ -194,14 +198,14 @@ interface NearObject { * * @param {string} id - The ID of the object to search for. * @param {GenerateOptions} generate - The available options for performing the generation. - * @param {BaseNearOptions} [opts] - The available options for performing the near-object search. - * @return {Promise>} - The results of the search including the generated data. + * @param {BaseNearOptions} [opts] - The available options for performing the near-object search. + * @return {Promise>} - The results of the search including the generated data. */ nearObject( id: string, generate: GenerateOptions, - opts?: BaseNearOptions - ): Promise>; + opts?: BaseNearOptions + ): Promise>; /** * Perform retrieval-augmented generation (RaG) on the results of a by-object object search in this collection using a vector-based similarity search. * @@ -211,14 +215,14 @@ interface NearObject { * * @param {string} id - The ID of the object to search for. * @param {GenerateOptions} generate - The available options for performing the generation. - * @param {GroupByNearOptions} opts - The available options for performing the near-object search. - * @return {Promise>} - The results of the search including the generated data grouped by the specified properties. + * @param {GroupByNearOptions} opts - The available options for performing the near-object search. + * @return {Promise>} - The results of the search including the generated data grouped by the specified properties. */ nearObject( id: string, generate: GenerateOptions, - opts: GroupByNearOptions - ): Promise>; + opts: GroupByNearOptions + ): Promise>; /** * Perform retrieval-augmented generation (RaG) on the results of a by-object object search in this collection using a vector-based similarity search. * @@ -228,13 +232,13 @@ interface NearObject { * * @param {string} id - The ID of the object to search for. * @param {GenerateOptions} generate - The available options for performing the generation. - * @param {NearOptions} [opts] - The available options for performing the near-object search. - * @return {GenerateReturn} - The results of the search including the generated data. + * @param {NearOptions} [opts] - The available options for performing the near-object search. + * @return {GenerateReturn} - The results of the search including the generated data. */ - nearObject(id: string, generate: GenerateOptions, opts?: NearOptions): GenerateReturn; + nearObject(id: string, generate: GenerateOptions, opts?: NearOptions): GenerateReturn; } -interface NearText { +interface NearText { /** * Perform retrieval-augmented generation (RaG) on the results of a by-image object search in this collection using the image-capable vectorization module and vector-based similarity search. * @@ -246,14 +250,14 @@ interface NearText { * * @param {string | string[]} query - The query to search for. * @param {GenerateOptions} generate - The available options for performing the generation. - * @param {BaseNearTextOptions} [opts] - The available options for performing the near-text search. - * @return {Promise>} - The results of the search including the generated data. + * @param {BaseNearTextOptions} [opts] - The available options for performing the near-text search. + * @return {Promise>} - The results of the search including the generated data. */ nearText( query: string | string[], generate: GenerateOptions, - opts?: BaseNearTextOptions - ): Promise>; + opts?: BaseNearTextOptions + ): Promise>; /** * Perform retrieval-augmented generation (RaG) on the results of a by-image object search in this collection using the image-capable vectorization module and vector-based similarity search. * @@ -265,14 +269,14 @@ interface NearText { * * @param {string | string[]} query - The query to search for. * @param {GenerateOptions} generate - The available options for performing the generation. - * @param {GroupByNearTextOptions} opts - The available options for performing the near-text search. - * @return {Promise>} - The results of the search including the generated data grouped by the specified properties. + * @param {GroupByNearTextOptions} opts - The available options for performing the near-text search. + * @return {Promise>} - The results of the search including the generated data grouped by the specified properties. */ nearText( query: string | string[], generate: GenerateOptions, - opts: GroupByNearTextOptions - ): Promise>; + opts: GroupByNearTextOptions + ): Promise>; /** * Perform retrieval-augmented generation (RaG) on the results of a by-image object search in this collection using the image-capable vectorization module and vector-based similarity search. * @@ -284,17 +288,17 @@ interface NearText { * * @param {string | string[]} query - The query to search for. * @param {GenerateOptions} generate - The available options for performing the generation. - * @param {NearTextOptions} [opts] - The available options for performing the near-text search. - * @return {GenerateReturn} - The results of the search including the generated data. + * @param {NearTextOptions} [opts] - The available options for performing the near-text search. + * @return {GenerateReturn} - The results of the search including the generated data. */ nearText( query: string | string[], generate: GenerateOptions, - opts?: NearTextOptions - ): GenerateReturn; + opts?: NearTextOptions + ): GenerateReturn; } -interface NearVector { +interface NearVector { /** * Perform retrieval-augmented generation (RaG) on the results of a by-vector object search in this collection using vector-based similarity search. * @@ -304,14 +308,14 @@ interface NearVector { * * @param {NearVectorInputType} vector - The vector(s) to search for. * @param {GenerateOptions} generate - The available options for performing the generation. - * @param {BaseNearOptions} [opts] - The available options for performing the near-vector search. - * @return {Promise>} - The results of the search including the generated data. + * @param {BaseNearOptions} [opts] - The available options for performing the near-vector search. + * @return {Promise>} - The results of the search including the generated data. */ nearVector( vector: NearVectorInputType, generate: GenerateOptions, - opts?: BaseNearOptions - ): Promise>; + opts?: BaseNearOptions + ): Promise>; /** * Perform retrieval-augmented generation (RaG) on the results of a by-vector object search in this collection using vector-based similarity search. * @@ -321,14 +325,14 @@ interface NearVector { * * @param {NearVectorInputType} vector - The vector(s) to search for. * @param {GenerateOptions} generate - The available options for performing the generation. - * @param {GroupByNearOptions} opts - The available options for performing the near-vector search. - * @return {Promise>} - The results of the search including the generated data grouped by the specified properties. + * @param {GroupByNearOptions} opts - The available options for performing the near-vector search. + * @return {Promise>} - The results of the search including the generated data grouped by the specified properties. */ nearVector( vector: NearVectorInputType, generate: GenerateOptions, - opts: GroupByNearOptions - ): Promise>; + opts: GroupByNearOptions + ): Promise>; /** * Perform retrieval-augmented generation (RaG) on the results of a by-vector object search in this collection using vector-based similarity search. * @@ -338,22 +342,25 @@ interface NearVector { * * @param {NearVectorInputType} vector - The vector(s) to search for. * @param {GenerateOptions} generate - The available options for performing the generation. - * @param {NearOptions} [opts] - The available options for performing the near-vector search. - * @return {GenerateReturn} - The results of the search including the generated data. + * @param {NearOptions} [opts] - The available options for performing the near-vector search. + * @return {GenerateReturn} - The results of the search including the generated data. */ nearVector( vector: NearVectorInputType, generate: GenerateOptions, - opts?: NearOptions - ): GenerateReturn; + opts?: NearOptions + ): GenerateReturn; } -export interface Generate - extends Bm25, - Hybrid, - NearMedia, - NearObject, - NearText, - NearVector { - fetchObjects: (generate: GenerateOptions, opts?: FetchObjectsOptions) => Promise>; +export interface Generate + extends Bm25, + Hybrid, + NearMedia, + NearObject, + NearText, + NearVector { + fetchObjects: ( + generate: GenerateOptions, + opts?: FetchObjectsOptions + ) => Promise>; } diff --git a/src/collections/index.ts b/src/collections/index.ts index 0472c459..55ed194a 100644 --- a/src/collections/index.ts +++ b/src/collections/index.ts @@ -6,7 +6,7 @@ import { ClassCreator, ClassDeleter, ClassGetter, SchemaGetter } from '../schema import { DbVersionSupport } from '../utils/dbVersion.js'; import collection, { Collection } from './collection/index.js'; import { classToCollection, resolveProperty, resolveReference } from './config/utils.js'; -import { QuantizerGuards } from './configure/parsing.js'; +import { QuantizerGuards, VectorIndexGuards } from './configure/parsing.js'; import { configGuards } from './index.js'; import { CollectionConfig, @@ -24,13 +24,13 @@ import { ShardingConfigCreate, VectorConfigCreate, VectorIndexConfigCreate, - VectorIndexConfigDynamicCreate, VectorIndexConfigFlatCreate, VectorIndexConfigHNSWCreate, VectorIndexType, Vectorizer, VectorizerConfig, VectorizersConfigCreate, + Vectors, } from './types/index.js'; import { PrimitiveKeys } from './types/internal.js'; @@ -40,7 +40,7 @@ import { PrimitiveKeys } from './types/internal.js'; * Inspect [the docs](https://weaviate.io/developers/weaviate/configuration) for more information on the * different configuration options and how they affect the behavior of your collection. */ -export type CollectionConfigCreate = { +export type CollectionConfigCreate = { /** The name of the collection. */ name: N; /** The description of the collection. */ @@ -62,23 +62,36 @@ export type CollectionConfigCreate = { /** The configuration for Weaviate's sharding strategy. Is mutually exclusive with `replication`. */ sharding?: ShardingConfigCreate; /** The configuration for Weaviate's vectorizer(s) capabilities. */ - vectorizers?: VectorizersConfigCreate; + vectorizers?: VectorizersConfigCreate; }; const parseVectorIndex = (module: ModuleConfig): any => { if (module.config === undefined) return undefined; - if (module.name === 'dynamic') { - const { hnsw, flat, ...conf } = module.config as VectorIndexConfigDynamicCreate; + if (VectorIndexGuards.isDynamic(module.config)) { + const { hnsw, flat, ...conf } = module.config; return { ...conf, hnsw: parseVectorIndex({ name: 'hnsw', config: hnsw }), flat: parseVectorIndex({ name: 'flat', config: flat }), }; } - const { quantizer, ...conf } = module.config as + + let multiVector; + if (VectorIndexGuards.isHNSW(module.config) && module.config.multiVector !== undefined) { + multiVector = { + ...module.config.multiVector, + enabled: true, + }; + } + + const { quantizer, ...rest } = module.config as | VectorIndexConfigFlatCreate | VectorIndexConfigHNSWCreate | Record; + const conf = { + ...rest, + multivector: multiVector, + }; if (quantizer === undefined) return conf; if (QuantizerGuards.isBQCreate(quantizer)) { const { type, ...quant } = quantizer; @@ -118,9 +131,11 @@ const collections = (connection: Connection, dbVersionSupport: DbVersionSupport) .then((schema) => (schema.classes ? schema.classes.map(classToCollection) : [])); const deleteCollection = (name: string) => new ClassDeleter(connection).withClassName(name).do(); return { - create: async function ( - config: CollectionConfigCreate - ) { + create: async function < + TProperties extends Properties | undefined = undefined, + TName = string, + TVectors extends Vectors | undefined = undefined + >(config: CollectionConfigCreate) { const { name, invertedIndex, multiTenancy, replication, sharding, ...rest } = config; const supportsDynamicVectorIndex = await dbVersionSupport.supportsDynamicVectorIndex(); @@ -137,7 +152,7 @@ const collections = (connection: Connection, dbVersionSupport: DbVersionSupport) moduleConfig[config.reranker.name] = config.reranker.config ? config.reranker.config : {}; } - const makeVectorsConfig = (configVectorizers: VectorizersConfigCreate) => { + const makeVectorsConfig = (configVectorizers: VectorizersConfigCreate) => { let vectorizers: string[] = []; const vectorsConfig: Record = {}; const vectorizersConfig = Array.isArray(configVectorizers) @@ -258,11 +273,11 @@ const collections = (connection: Connection, dbVersionSupport: DbVersionSupport) schema.properties = [...properties, ...references]; await new ClassCreator(connection).withClass(schema).do(); - return collection(connection, name, dbVersionSupport); + return collection(connection, name, dbVersionSupport); }, createFromSchema: async function (config: WeaviateClass) { const { class: name } = await new ClassCreator(connection).withClass(config).do(); - return collection(connection, name as string, dbVersionSupport); + return collection(connection, name as string, dbVersionSupport); }, delete: deleteCollection, deleteAll: () => listAll().then((configs) => Promise.all(configs?.map((c) => deleteCollection(c.name)))), @@ -275,14 +290,25 @@ const collections = (connection: Connection, dbVersionSupport: DbVersionSupport) listAll: listAll, get: ( name: TName - ) => collection(connection, name, dbVersionSupport), + ) => collection(connection, name, dbVersionSupport), + use: < + TProperties extends Properties | undefined = undefined, + TName extends string = string, + TVectors extends Vectors | undefined = undefined + >( + name: TName + ) => collection(connection, name, dbVersionSupport), }; }; export interface Collections { - create( - config: CollectionConfigCreate - ): Promise>; + create< + TProperties extends Properties | undefined = undefined, + TName = string, + TVectors extends Vectors | undefined = undefined + >( + config: CollectionConfigCreate + ): Promise>; createFromSchema(config: WeaviateClass): Promise>; delete(collection: string): Promise; deleteAll(): Promise; @@ -292,9 +318,13 @@ export interface Collections { name: TName ): Collection; listAll(): Promise; - // use( - // name: TName - // ): Collection; + use< + TName extends string = string, + TProperties extends Properties | undefined = undefined, + TVectors extends Vectors | undefined = undefined + >( + name: TName + ): Collection; } export default collections; diff --git a/src/collections/iterator/index.ts b/src/collections/iterator/index.ts index 0e631bb9..edb35762 100644 --- a/src/collections/iterator/index.ts +++ b/src/collections/iterator/index.ts @@ -3,16 +3,16 @@ import { WeaviateObject } from '../types/index.js'; const ITERATOR_CACHE_SIZE = 100; -export class Iterator { - private cache: WeaviateObject[] = []; +export class Iterator { + private cache: WeaviateObject[] = []; private last: string | undefined = undefined; - constructor(private query: (limit: number, after?: string) => Promise[]>) { + constructor(private query: (limit: number, after?: string) => Promise[]>) { this.query = query; } [Symbol.asyncIterator]() { return { - next: async (): Promise>> => { + next: async (): Promise>> => { const objects = await this.query(ITERATOR_CACHE_SIZE, this.last); this.cache = objects; if (this.cache.length == 0) { diff --git a/src/collections/iterator/integration.test.ts b/src/collections/iterator/integration.test.ts index 6937b270..d621da49 100644 --- a/src/collections/iterator/integration.test.ts +++ b/src/collections/iterator/integration.test.ts @@ -45,7 +45,7 @@ describe('Testing of the collection.iterator method with a simple collection', ( }); }); const res = await collection.query.fetchObjectById(id, { includeVector: true }); - vector = res?.vectors.default!; + vector = res?.vectors.default as number[]; }); it('should iterate through the collection with no options returning the objects', async () => { diff --git a/src/collections/journey.test.ts b/src/collections/journey.test.ts index 85da3a8f..5a545f9d 100644 --- a/src/collections/journey.test.ts +++ b/src/collections/journey.test.ts @@ -187,6 +187,7 @@ describe('Journey testing of the client using a WCD cluster', () => { maxConnections: (await client.getWeaviateVersion().then((ver) => ver.isLowerThan(1, 26, 0))) ? 64 : 32, + multiVector: undefined, skip: false, vectorCacheMaxObjects: 1000000000000, quantizer: undefined, diff --git a/src/collections/query/check.ts b/src/collections/query/check.ts index cf437632..41a4606c 100644 --- a/src/collections/query/check.ts +++ b/src/collections/query/check.ts @@ -17,7 +17,7 @@ import { SearchOptions, } from './types.js'; -export class Check { +export class Check { private connection: Connection; private name: string; public dbVersionSupport: DbVersionSupport; @@ -40,7 +40,7 @@ export class Check { private getSearcher = () => this.connection.search(this.name, this.consistencyLevel, this.tenant); - private checkSupportForNamedVectors = async (opts?: BaseNearOptions) => { + private checkSupportForNamedVectors = async (opts?: BaseNearOptions) => { if (!Serialize.isNamedVectors(opts)) return; const check = await this.dbVersionSupport.supportsNamedVectors(); if (!check.supports) throw new WeaviateUnsupportedFeatureError(check.message); @@ -48,20 +48,20 @@ export class Check { private checkSupportForBm25AndHybridGroupByQueries = async ( query: 'Bm25' | 'Hybrid', - opts?: SearchOptions | GroupByOptions + opts?: SearchOptions | GroupByOptions ) => { if (!Serialize.search.isGroupBy(opts)) return; const check = await this.dbVersionSupport.supportsBm25AndHybridGroupByQueries(); if (!check.supports) throw new WeaviateUnsupportedFeatureError(check.message(query)); }; - private checkSupportForHybridNearTextAndNearVectorSubSearches = async (opts?: HybridOptions) => { + private checkSupportForHybridNearTextAndNearVectorSubSearches = async (opts?: HybridOptions) => { if (opts?.vector === undefined || Array.isArray(opts.vector)) return; const check = await this.dbVersionSupport.supportsHybridNearTextAndNearVectorSubsearchQueries(); if (!check.supports) throw new WeaviateUnsupportedFeatureError(check.message); }; - private checkSupportForMultiTargetSearch = async (opts?: BaseNearOptions) => { + private checkSupportForMultiTargetSearch = async (opts?: BaseNearOptions) => { if (!Serialize.isMultiTarget(opts)) return false; const check = await this.dbVersionSupport.supportsMultiTargetVectorSearch(); if (!check.supports) throw new WeaviateUnsupportedFeatureError(check.message); @@ -79,7 +79,7 @@ export class Check { return check.supports; }; - private checkSupportForMultiWeightPerTargetSearch = async (opts?: BaseNearOptions) => { + private checkSupportForMultiWeightPerTargetSearch = async (opts?: BaseNearOptions) => { if (!Serialize.isMultiWeightPerTarget(opts)) return false; const check = await this.dbVersionSupport.supportsMultiWeightsPerTargetSearch(); if (!check.supports) throw new WeaviateUnsupportedFeatureError(check.message); @@ -102,15 +102,11 @@ export class Check { vec?: NearVectorInputType | HybridNearVectorSubSearch | HybridNearTextSubSearch ) => { if (vec === undefined || Serialize.isHybridNearTextSearch(vec)) return false; - if (Serialize.isHybridNearVectorSearch(vec) && !Serialize.isMultiVectorPerTarget(vec.vector)) - return false; - if (Serialize.isHybridVectorSearch(vec) && !Serialize.isMultiVectorPerTarget(vec)) return false; - const check = await this.dbVersionSupport.supportsMultiVectorPerTargetSearch(); - if (!check.supports) throw new WeaviateUnsupportedFeatureError(check.message); + const check = await this.dbVersionSupport.supportsVectorsFieldInGRPC(); return check.supports; }; - public nearSearch = (opts?: BaseNearOptions) => { + public nearSearch = (opts?: BaseNearOptions) => { return Promise.all([ this.getSearcher(), this.checkSupportForMultiTargetSearch(opts), @@ -123,14 +119,14 @@ export class Check { }); }; - public nearVector = (vec: NearVectorInputType, opts?: BaseNearOptions) => { + public nearVector = (vec: NearVectorInputType, opts?: BaseNearOptions) => { return Promise.all([ this.getSearcher(), this.checkSupportForMultiTargetSearch(opts), this.checkSupportForMultiVectorSearch(vec), this.checkSupportForMultiVectorPerTargetSearch(vec), this.checkSupportForMultiWeightPerTargetSearch(opts), - this.checkSupportForVectors(), + this.checkSupportForVectors(vec), this.checkSupportForNamedVectors(opts), ]).then( ([ @@ -155,14 +151,14 @@ export class Check { ); }; - public hybridSearch = (opts?: BaseHybridOptions) => { + public hybridSearch = (opts?: BaseHybridOptions) => { return Promise.all([ this.getSearcher(), this.checkSupportForMultiTargetSearch(opts), this.checkSupportForMultiVectorSearch(opts?.vector), this.checkSupportForMultiVectorPerTargetSearch(opts?.vector), this.checkSupportForMultiWeightPerTargetSearch(opts), - this.checkSupportForVectors(), + this.checkSupportForVectors(opts?.vector), this.checkSupportForNamedVectors(opts), this.checkSupportForBm25AndHybridGroupByQueries('Hybrid', opts), this.checkSupportForHybridNearTextAndNearVectorSubSearches(opts), @@ -189,19 +185,19 @@ export class Check { ); }; - public fetchObjects = (opts?: FetchObjectsOptions) => { + public fetchObjects = (opts?: FetchObjectsOptions) => { return Promise.all([this.getSearcher(), this.checkSupportForNamedVectors(opts)]).then(([search]) => { return { search }; }); }; - public fetchObjectById = (opts?: FetchObjectByIdOptions) => { + public fetchObjectById = (opts?: FetchObjectByIdOptions) => { return Promise.all([this.getSearcher(), this.checkSupportForNamedVectors(opts)]).then(([search]) => { return { search }; }); }; - public bm25 = (opts?: BaseBm25Options) => { + public bm25 = (opts?: BaseBm25Options) => { return Promise.all([ this.getSearcher(), this.checkSupportForNamedVectors(opts), diff --git a/src/collections/query/index.ts b/src/collections/query/index.ts index 69adc646..1dda49bf 100644 --- a/src/collections/query/index.ts +++ b/src/collections/query/index.ts @@ -34,39 +34,44 @@ import { SearchOptions, } from './types.js'; -class QueryManager implements Query { - private check: Check; +class QueryManager implements Query { + private check: Check; - private constructor(check: Check) { + private constructor(check: Check) { this.check = check; } - public static use( + public static use( connection: Connection, name: string, dbVersionSupport: DbVersionSupport, consistencyLevel?: ConsistencyLevel, tenant?: string - ): QueryManager { - return new QueryManager(new Check(connection, name, dbVersionSupport, consistencyLevel, tenant)); + ): QueryManager { + return new QueryManager( + new Check(connection, name, dbVersionSupport, consistencyLevel, tenant) + ); } private async parseReply(reply: SearchReply) { const deserialize = await Deserialize.use(this.check.dbVersionSupport); - return deserialize.query(reply); + return deserialize.query(reply); } private async parseGroupByReply( - opts: SearchOptions | GroupByOptions | undefined, + opts: SearchOptions | GroupByOptions | undefined, reply: SearchReply ) { const deserialize = await Deserialize.use(this.check.dbVersionSupport); return Serialize.search.isGroupBy(opts) - ? deserialize.queryGroupBy(reply) - : deserialize.query(reply); + ? deserialize.queryGroupBy(reply) + : deserialize.query(reply); } - public fetchObjectById(id: string, opts?: FetchObjectByIdOptions): Promise | null> { + public fetchObjectById( + id: string, + opts?: FetchObjectByIdOptions + ): Promise | null> { return this.check .fetchObjectById(opts) .then(({ search }) => search.withFetch(Serialize.search.fetchObjectById({ id, ...opts }))) @@ -74,25 +79,25 @@ class QueryManager implements Query { .then((ret) => (ret.objects.length === 1 ? ret.objects[0] : null)); } - public fetchObjects(opts?: FetchObjectsOptions): Promise> { + public fetchObjects(opts?: FetchObjectsOptions): Promise> { return this.check .fetchObjects(opts) .then(({ search }) => search.withFetch(Serialize.search.fetchObjects(opts))) .then((reply) => this.parseReply(reply)); } - public bm25(query: string, opts?: BaseBm25Options): Promise>; - public bm25(query: string, opts: GroupByBm25Options): Promise>; - public bm25(query: string, opts?: Bm25Options): QueryReturn { + public bm25(query: string, opts?: BaseBm25Options): Promise>; + public bm25(query: string, opts: GroupByBm25Options): Promise>; + public bm25(query: string, opts?: Bm25Options): QueryReturn { return this.check .bm25(opts) .then(({ search }) => search.withBm25(Serialize.search.bm25(query, opts))) .then((reply) => this.parseGroupByReply(opts, reply)); } - public hybrid(query: string, opts?: BaseHybridOptions): Promise>; - public hybrid(query: string, opts: GroupByHybridOptions): Promise>; - public hybrid(query: string, opts?: HybridOptions): QueryReturn { + public hybrid(query: string, opts?: BaseHybridOptions): Promise>; + public hybrid(query: string, opts: GroupByHybridOptions): Promise>; + public hybrid(query: string, opts?: HybridOptions): QueryReturn { return this.check .hybridSearch(opts) .then( @@ -104,22 +109,25 @@ class QueryManager implements Query { supportsVectors, }) => ({ search, - args: await Serialize.search.hybrid({ - query, - supportsTargets, - supportsWeightsForTargets, - supportsVectorsForTargets, - supportsVectors, - }), + args: await Serialize.search.hybrid( + { + query, + supportsTargets, + supportsWeightsForTargets, + supportsVectorsForTargets, + supportsVectors, + }, + opts + ), }) ) .then(({ search, args }) => search.withHybrid(args)) .then((reply) => this.parseGroupByReply(opts, reply)); } - public nearImage(image: string | Buffer, opts?: BaseNearOptions): Promise>; - public nearImage(image: string | Buffer, opts: GroupByNearOptions): Promise>; - public nearImage(image: string | Buffer, opts?: NearOptions): QueryReturn { + public nearImage(image: string | Buffer, opts?: BaseNearOptions): Promise>; + public nearImage(image: string | Buffer, opts: GroupByNearOptions): Promise>; + public nearImage(image: string | Buffer, opts?: NearOptions): QueryReturn { return this.check .nearSearch(opts) .then(({ search, supportsTargets, supportsWeightsForTargets }) => { @@ -142,14 +150,14 @@ class QueryManager implements Query { public nearMedia( media: string | Buffer, type: NearMediaType, - opts?: BaseNearOptions - ): Promise>; + opts?: BaseNearOptions + ): Promise>; public nearMedia( media: string | Buffer, type: NearMediaType, - opts: GroupByNearOptions - ): Promise>; - public nearMedia(media: string | Buffer, type: NearMediaType, opts?: NearOptions): QueryReturn { + opts: GroupByNearOptions + ): Promise>; + public nearMedia(media: string | Buffer, type: NearMediaType, opts?: NearOptions): QueryReturn { return this.check .nearSearch(opts) .then(({ search, supportsTargets, supportsWeightsForTargets }) => { @@ -189,9 +197,9 @@ class QueryManager implements Query { .then((reply) => this.parseGroupByReply(opts, reply)); } - public nearObject(id: string, opts?: BaseNearOptions): Promise>; - public nearObject(id: string, opts: GroupByNearOptions): Promise>; - public nearObject(id: string, opts?: NearOptions): QueryReturn { + public nearObject(id: string, opts?: BaseNearOptions): Promise>; + public nearObject(id: string, opts: GroupByNearOptions): Promise>; + public nearObject(id: string, opts?: NearOptions): QueryReturn { return this.check .nearSearch(opts) .then(({ search, supportsTargets, supportsWeightsForTargets }) => ({ @@ -209,9 +217,9 @@ class QueryManager implements Query { .then((reply) => this.parseGroupByReply(opts, reply)); } - public nearText(query: string | string[], opts?: BaseNearTextOptions): Promise>; - public nearText(query: string | string[], opts: GroupByNearTextOptions): Promise>; - public nearText(query: string | string[], opts?: NearTextOptions): QueryReturn { + public nearText(query: string | string[], opts?: BaseNearTextOptions): Promise>; + public nearText(query: string | string[], opts: GroupByNearTextOptions): Promise>; + public nearText(query: string | string[], opts?: NearTextOptions): QueryReturn { return this.check .nearSearch(opts) .then(({ search, supportsTargets, supportsWeightsForTargets }) => ({ @@ -229,9 +237,12 @@ class QueryManager implements Query { .then((reply) => this.parseGroupByReply(opts, reply)); } - public nearVector(vector: NearVectorInputType, opts?: BaseNearOptions): Promise>; - public nearVector(vector: NearVectorInputType, opts: GroupByNearOptions): Promise>; - public nearVector(vector: NearVectorInputType, opts?: NearOptions): QueryReturn { + public nearVector(vector: NearVectorInputType, opts?: BaseNearOptions): Promise>; + public nearVector( + vector: NearVectorInputType, + opts: GroupByNearOptions + ): Promise>; + public nearVector(vector: NearVectorInputType, opts?: NearOptions): QueryReturn { return this.check .nearVector(vector, opts) .then( diff --git a/src/collections/query/integration.test.ts b/src/collections/query/integration.test.ts index 49a2e236..72380fa8 100644 --- a/src/collections/query/integration.test.ts +++ b/src/collections/query/integration.test.ts @@ -62,7 +62,7 @@ describe('Testing of the collection.query methods with a simple collection', () }); }); const res = await collection.query.fetchObjectById(id, { includeVector: true }); - vector = res?.vectors.default!; + vector = res?.vectors.default as number[]; }); it('should fetch an object by its id', async () => { @@ -556,16 +556,25 @@ describe('Testing of the collection.query methods with a collection with a neste describe('Testing of the collection.query methods with a collection with a multiple vectors', () => { let client: WeaviateClient; - let collection: Collection; + let collection: Collection< + TestCollectionQueryWithMultiVectorProps, + 'TestCollectionQueryWithMultiVector', + TestCollectionQueryWithMultiVectorVectors + >; const collectionName = 'TestCollectionQueryWithMultiVector'; let id1: string; let id2: string; - type TestCollectionQueryWithMultiVector = { + type TestCollectionQueryWithMultiVectorProps = { title: string; }; + type TestCollectionQueryWithMultiVectorVectors = { + title: number[]; + title2: number[]; + }; + afterAll(() => { return client.collections.delete(collectionName).catch((err) => { console.error(err); @@ -575,10 +584,10 @@ describe('Testing of the collection.query methods with a collection with a multi beforeAll(async () => { client = await weaviate.connectToLocal(); - collection = client.collections.get(collectionName); + collection = client.collections.use(collectionName); const query = () => client.collections - .create({ + .create({ name: collectionName, properties: [ { @@ -1107,7 +1116,7 @@ describe('Testing of the groupBy collection.query methods with a simple collecti }); }); const res = await collection.query.fetchObjectById(id, { includeVector: true }); - vector = res?.vectors.default!; + vector = res?.vectors.default as number[]; }); // it('should groupBy without search', async () => { diff --git a/src/collections/query/types.ts b/src/collections/query/types.ts index 3ad93d21..be5b4600 100644 --- a/src/collections/query/types.ts +++ b/src/collections/query/types.ts @@ -11,12 +11,12 @@ import { WeaviateObject, WeaviateReturn, } from '../types/index.js'; -import { PrimitiveKeys } from '../types/internal.js'; +import { IncludeVector, PrimitiveKeys } from '../types/internal.js'; /** Options available in the `query.fetchObjectById` method */ -export type FetchObjectByIdOptions = { +export type FetchObjectByIdOptions = { /** Whether to include the vector of the object in the response. If using named vectors, pass an array of strings to include only specific vectors. */ - includeVector?: boolean | string[]; + includeVector?: IncludeVector; /** * Which properties of the object to return. Can be primitive, in which case specify their names, or nested, in which case * use the QueryNested type. If not specified, all properties are returned. @@ -27,7 +27,7 @@ export type FetchObjectByIdOptions = { }; /** Options available in the `query.fetchObjects` method */ -export type FetchObjectsOptions = { +export type FetchObjectsOptions = { /** How many objects to return in the query */ limit?: number; /** How many objects to skip in the query. Incompatible with the `after` cursor */ @@ -39,7 +39,7 @@ export type FetchObjectsOptions = { /** The sorting to be applied to the query. Use `weaviate.sort.*` to create sorting */ sort?: Sorting; /** Whether to include the vector of the object in the response. If using named vectors, pass an array of strings to include only specific vectors. */ - includeVector?: boolean | string[]; + includeVector?: IncludeVector; /** Which metadata of the object to return. If not specified, no metadata is returned. */ returnMetadata?: QueryMetadata; /** @@ -52,7 +52,7 @@ export type FetchObjectsOptions = { }; /** Base options available to all the query methods that involve searching. */ -export type SearchOptions = { +export type SearchOptions = { /** How many objects to return in the query */ limit?: number; /** How many objects to skip in the query. Incompatible with the `after` cursor */ @@ -64,7 +64,7 @@ export type SearchOptions = { /** How to rerank the query results. Requires a configured [reranking](https://weaviate.io/developers/weaviate/concepts/reranking) module. */ rerank?: RerankOptions; /** Whether to include the vector of the object in the response. If using named vectors, pass an array of strings to include only specific vectors. */ - includeVector?: boolean | string[]; + includeVector?: IncludeVector; /** Which metadata of the object to return. If not specified, no metadata is returned. */ returnMetadata?: QueryMetadata; /** @@ -90,19 +90,19 @@ export type Bm25SearchOptions = { }; /** Base options available in the `query.bm25` method */ -export type BaseBm25Options = SearchOptions & Bm25SearchOptions; +export type BaseBm25Options = SearchOptions & Bm25SearchOptions; /** Options available in the `query.bm25` method when specifying the `groupBy` parameter. */ -export type GroupByBm25Options = BaseBm25Options & { +export type GroupByBm25Options = BaseBm25Options & { /** The group by options to apply to the search. */ groupBy: GroupByOptions; }; /** Options available in the `query.bm25` method */ -export type Bm25Options = BaseBm25Options | GroupByBm25Options | undefined; +export type Bm25Options = BaseBm25Options | GroupByBm25Options | undefined; /** Options available to the hybrid search type only */ -export type HybridSearchOptions = { +export type HybridSearchOptions = { /** The weight of the BM25 score. If not specified, the default weight specified by the server is used. */ alpha?: number; /** The type of fusion to apply. If not specified, the default fusion type specified by the server is used. */ @@ -112,13 +112,13 @@ export type HybridSearchOptions = { /** The properties to search in. If not specified, all properties are searched. */ queryProperties?: (PrimitiveKeys | Bm25QueryProperty)[]; /** Specify which vector(s) to search on if using named vectors. */ - targetVector?: TargetVectorInputType; + targetVector?: TargetVectorInputType; /** The specific vector to search for or a specific vector subsearch. If not specified, the query is vectorized and used in the similarity search. */ vector?: NearVectorInputType | HybridNearTextSubSearch | HybridNearVectorSubSearch; }; /** Base options available in the `query.hybrid` method */ -export type BaseHybridOptions = SearchOptions & HybridSearchOptions; +export type BaseHybridOptions = SearchOptions & HybridSearchOptions; export type HybridSubSearchBase = { certainty?: number; @@ -136,28 +136,28 @@ export type HybridNearVectorSubSearch = HybridSubSearchBase & { }; /** Options available in the `query.hybrid` method when specifying the `groupBy` parameter. */ -export type GroupByHybridOptions = BaseHybridOptions & { +export type GroupByHybridOptions = BaseHybridOptions & { /** The group by options to apply to the search. */ groupBy: GroupByOptions; }; /** Options available in the `query.hybrid` method */ -export type HybridOptions = BaseHybridOptions | GroupByHybridOptions | undefined; +export type HybridOptions = BaseHybridOptions | GroupByHybridOptions | undefined; -export type NearSearchOptions = { +export type NearSearchOptions = { /** The minimum similarity score to return. Incompatible with the `distance` param. */ certainty?: number; /** The maximum distance to search. Incompatible with the `certainty` param. */ distance?: number; /** Specify which vector to search on if using named vectors. */ - targetVector?: TargetVectorInputType; + targetVector?: TargetVectorInputType; }; /** Base options for the near search queries. */ -export type BaseNearOptions = SearchOptions & NearSearchOptions; +export type BaseNearOptions = SearchOptions & NearSearchOptions; /** Options available in the near search queries when specifying the `groupBy` parameter. */ -export type GroupByNearOptions = BaseNearOptions & { +export type GroupByNearOptions = BaseNearOptions & { /** The group by options to apply to the search. */ groupBy: GroupByOptions; }; @@ -170,24 +170,28 @@ export type MoveOptions = { }; /** Base options for the `query.nearText` method. */ -export type BaseNearTextOptions = BaseNearOptions & { +export type BaseNearTextOptions = BaseNearOptions & { moveTo?: MoveOptions; moveAway?: MoveOptions; }; /** Options available in the near text search queries when specifying the `groupBy` parameter. */ -export type GroupByNearTextOptions = BaseNearTextOptions & { +export type GroupByNearTextOptions = BaseNearTextOptions & { groupBy: GroupByOptions; }; /** The type of the media to search for in the `query.nearMedia` method */ export type NearMediaType = 'audio' | 'depth' | 'image' | 'imu' | 'thermal' | 'video'; +export type SingleVectorType = number[]; + +export type MultiVectorType = number[][]; + /** The allowed types of primitive vectors as stored in Weaviate. * * These correspond to 1-dimensional vectors, created by modules named `x2vec-`, and 2-dimensional vectors, created by modules named `x2colbert-`. */ -export type PrimitiveVectorType = number[] | number[][]; +export type PrimitiveVectorType = SingleVectorType | MultiVectorType; export type ListOfVectors = { kind: 'listOfVectors'; @@ -203,7 +207,7 @@ export type ListOfVectors = { */ export type NearVectorInputType = | PrimitiveVectorType - | Record | ListOfVectors>; + | Record | ListOfVectors>; /** * Over which vector spaces to perform the vector search query in the `nearX` search method. One of: @@ -211,9 +215,11 @@ export type NearVectorInputType = * - a multi-vector space search, in which case pass an array of strings with the names of the vector spaces to search in. * - a weighted multi-vector space search, in which case pass an object of type `MultiTargetVectorJoin` detailing the vector spaces to search in. */ -export type TargetVectorInputType = string | string[] | MultiTargetVectorJoin; +export type TargetVectorInputType = TargetVector | TargetVector[] | MultiTargetVectorJoin; + +export type TargetVector = V extends undefined ? string : keyof V & string; -interface Bm25 { +interface Bm25 { /** * Search for objects in this collection using the keyword-based BM25 algorithm. * @@ -222,10 +228,10 @@ interface Bm25 { * This overload is for performing a search without the `groupBy` param. * * @param {string} query - The query to search for. - * @param {BaseBm25Options} [opts] - The available options for the search excluding the `groupBy` param. - * @returns {Promise>} - The result of the search within the fetched collection. + * @param {BaseBm25Options} [opts] - The available options for the search excluding the `groupBy` param. + * @returns {Promise>} - The result of the search within the fetched collection. */ - bm25(query: string, opts?: BaseBm25Options): Promise>; + bm25(query: string, opts?: BaseBm25Options): Promise>; /** * Search for objects in this collection using the keyword-based BM25 algorithm. * @@ -234,10 +240,10 @@ interface Bm25 { * This overload is for performing a search with the `groupBy` param. * * @param {string} query - The query to search for. - * @param {GroupByBm25Options} opts - The available options for the search including the `groupBy` param. - * @returns {Promise>} - The result of the search within the fetched collection. + * @param {GroupByBm25Options} opts - The available options for the search including the `groupBy` param. + * @returns {Promise>} - The result of the search within the fetched collection. */ - bm25(query: string, opts: GroupByBm25Options): Promise>; + bm25(query: string, opts: GroupByBm25Options): Promise>; /** * Search for objects in this collection using the keyword-based BM25 algorithm. * @@ -246,13 +252,13 @@ interface Bm25 { * This overload is for performing a search with a programmatically defined `opts` param. * * @param {string} query - The query to search for. - * @param {Bm25Options} [opts] - The available options for the search including the `groupBy` param. - * @returns {Promise>} - The result of the search within the fetched collection. + * @param {Bm25Options} [opts] - The available options for the search including the `groupBy` param. + * @returns {Promise>} - The result of the search within the fetched collection. */ - bm25(query: string, opts?: Bm25Options): QueryReturn; + bm25(query: string, opts?: Bm25Options): QueryReturn; } -interface Hybrid { +interface Hybrid { /** * Search for objects in this collection using the hybrid algorithm blending keyword-based BM25 and vector-based similarity. * @@ -261,10 +267,10 @@ interface Hybrid { * This overload is for performing a search without the `groupBy` param. * * @param {string} query - The query to search for in the BM25 keyword search.. - * @param {BaseHybridOptions} [opts] - The available options for the search excluding the `groupBy` param. - * @returns {Promise>} - The result of the search within the fetched collection. + * @param {BaseHybridOptions} [opts] - The available options for the search excluding the `groupBy` param. + * @returns {Promise>} - The result of the search within the fetched collection. */ - hybrid(query: string, opts?: BaseHybridOptions): Promise>; + hybrid(query: string, opts?: BaseHybridOptions): Promise>; /** * Search for objects in this collection using the hybrid algorithm blending keyword-based BM25 and vector-based similarity. * @@ -273,10 +279,10 @@ interface Hybrid { * This overload is for performing a search with the `groupBy` param. * * @param {string} query - The query to search for in the BM25 keyword search.. - * @param {GroupByHybridOptions} opts - The available options for the search including the `groupBy` param. - * @returns {Promise>} - The result of the search within the fetched collection. + * @param {GroupByHybridOptions} opts - The available options for the search including the `groupBy` param. + * @returns {Promise>} - The result of the search within the fetched collection. */ - hybrid(query: string, opts: GroupByHybridOptions): Promise>; + hybrid(query: string, opts: GroupByHybridOptions): Promise>; /** * Search for objects in this collection using the hybrid algorithm blending keyword-based BM25 and vector-based similarity. * @@ -285,13 +291,13 @@ interface Hybrid { * This overload is for performing a search with a programmatically defined `opts` param. * * @param {string} query - The query to search for in the BM25 keyword search.. - * @param {HybridOptions} [opts] - The available options for the search including the `groupBy` param. - * @returns {Promise>} - The result of the search within the fetched collection. + * @param {HybridOptions} [opts] - The available options for the search including the `groupBy` param. + * @returns {Promise>} - The result of the search within the fetched collection. */ - hybrid(query: string, opts?: HybridOptions): QueryReturn; + hybrid(query: string, opts?: HybridOptions): QueryReturn; } -interface NearImage { +interface NearImage { /** * Search for objects by image in this collection using an image-capable vectorization module and vector-based similarity search. * You must have an image-capable vectorization module installed in order to use this method, @@ -302,10 +308,10 @@ interface NearImage { * This overload is for performing a search without the `groupBy` param. * * @param {string | Buffer} image - The image to search on. This can be a base64 string, a file path string, or a buffer. - * @param {BaseNearOptions} [opts] - The available options for the search excluding the `groupBy` param. - * @returns {Promise>} - The result of the search within the fetched collection. + * @param {BaseNearOptions} [opts] - The available options for the search excluding the `groupBy` param. + * @returns {Promise>} - The result of the search within the fetched collection. */ - nearImage(image: string | Buffer, opts?: BaseNearOptions): Promise>; + nearImage(image: string | Buffer, opts?: BaseNearOptions): Promise>; /** * Search for objects by image in this collection using an image-capable vectorization module and vector-based similarity search. * You must have an image-capable vectorization module installed in order to use this method, @@ -316,10 +322,10 @@ interface NearImage { * This overload is for performing a search with the `groupBy` param. * * @param {string | Buffer} image - The image to search on. This can be a base64 string, a file path string, or a buffer. - * @param {GroupByNearOptions} opts - The available options for the search including the `groupBy` param. - * @returns {Promise>} - The group by result of the search within the fetched collection. + * @param {GroupByNearOptions} opts - The available options for the search including the `groupBy` param. + * @returns {Promise>} - The group by result of the search within the fetched collection. */ - nearImage(image: string | Buffer, opts: GroupByNearOptions): Promise>; + nearImage(image: string | Buffer, opts: GroupByNearOptions): Promise>; /** * Search for objects by image in this collection using an image-capable vectorization module and vector-based similarity search. * You must have an image-capable vectorization module installed in order to use this method, @@ -330,13 +336,13 @@ interface NearImage { * This overload is for performing a search with a programmatically defined `opts` param. * * @param {string | Buffer} image - The image to search on. This can be a base64 string, a file path string, or a buffer. - * @param {NearOptions} [opts] - The available options for the search. - * @returns {QueryReturn} - The result of the search within the fetched collection. + * @param {NearOptions} [opts] - The available options for the search. + * @returns {QueryReturn} - The result of the search within the fetched collection. */ - nearImage(image: string | Buffer, opts?: NearOptions): QueryReturn; + nearImage(image: string | Buffer, opts?: NearOptions): QueryReturn; } -interface NearMedia { +interface NearMedia { /** * Search for objects by image in this collection using an image-capable vectorization module and vector-based similarity search. * You must have a multi-media-capable vectorization module installed in order to use this method, e.g. `multi2vec-bind` or `multi2vec-palm`. @@ -347,14 +353,14 @@ interface NearMedia { * * @param {string | Buffer} media - The media to search on. This can be a base64 string, a file path string, or a buffer. * @param {NearMediaType} type - The type of media to search for, e.g. 'audio'. - * @param {BaseNearOptions} [opts] - The available options for the search excluding the `groupBy` param. - * @returns {Promise>} - The result of the search within the fetched collection. + * @param {BaseNearOptions} [opts] - The available options for the search excluding the `groupBy` param. + * @returns {Promise>} - The result of the search within the fetched collection. */ nearMedia( media: string | Buffer, type: NearMediaType, - opts?: BaseNearOptions - ): Promise>; + opts?: BaseNearOptions + ): Promise>; /** * Search for objects by image in this collection using an image-capable vectorization module and vector-based similarity search. * You must have a multi-media-capable vectorization module installed in order to use this method, e.g. `multi2vec-bind` or `multi2vec-palm`. @@ -365,14 +371,14 @@ interface NearMedia { * * @param {string | Buffer} media - The media to search on. This can be a base64 string, a file path string, or a buffer. * @param {NearMediaType} type - The type of media to search for, e.g. 'audio'. - * @param {GroupByNearOptions} opts - The available options for the search including the `groupBy` param. - * @returns {Promise>} - The group by result of the search within the fetched collection. + * @param {GroupByNearOptions} opts - The available options for the search including the `groupBy` param. + * @returns {Promise>} - The group by result of the search within the fetched collection. */ nearMedia( media: string | Buffer, type: NearMediaType, - opts: GroupByNearOptions - ): Promise>; + opts: GroupByNearOptions + ): Promise>; /** * Search for objects by image in this collection using an image-capable vectorization module and vector-based similarity search. * You must have a multi-media-capable vectorization module installed in order to use this method, e.g. `multi2vec-bind` or `multi2vec-palm`. @@ -383,13 +389,13 @@ interface NearMedia { * * @param {string | Buffer} media - The media to search on. This can be a base64 string, a file path string, or a buffer. * @param {NearMediaType} type - The type of media to search for, e.g. 'audio'. - * @param {NearOptions} [opts] - The available options for the search. - * @returns {QueryReturn} - The result of the search within the fetched collection. + * @param {NearOptions} [opts] - The available options for the search. + * @returns {QueryReturn} - The result of the search within the fetched collection. */ - nearMedia(media: string | Buffer, type: NearMediaType, opts?: NearOptions): QueryReturn; + nearMedia(media: string | Buffer, type: NearMediaType, opts?: NearOptions): QueryReturn; } -interface NearObject { +interface NearObject { /** * Search for objects in this collection by another object using a vector-based similarity search. * @@ -398,10 +404,10 @@ interface NearObject { * This overload is for performing a search without the `groupBy` param. * * @param {string} id - The UUID of the object to search for. - * @param {BaseNearOptions} [opts] - The available options for the search excluding the `groupBy` param. - * @returns {Promise>} - The result of the search within the fetched collection. + * @param {BaseNearOptions} [opts] - The available options for the search excluding the `groupBy` param. + * @returns {Promise>} - The result of the search within the fetched collection. */ - nearObject(id: string, opts?: BaseNearOptions): Promise>; + nearObject(id: string, opts?: BaseNearOptions): Promise>; /** * Search for objects in this collection by another object using a vector-based similarity search. * @@ -410,10 +416,10 @@ interface NearObject { * This overload is for performing a search with the `groupBy` param. * * @param {string} id - The UUID of the object to search for. - * @param {GroupByNearOptions} opts - The available options for the search including the `groupBy` param. - * @returns {Promise>} - The group by result of the search within the fetched collection. + * @param {GroupByNearOptions} opts - The available options for the search including the `groupBy` param. + * @returns {Promise>} - The group by result of the search within the fetched collection. */ - nearObject(id: string, opts: GroupByNearOptions): Promise>; + nearObject(id: string, opts: GroupByNearOptions): Promise>; /** * Search for objects in this collection by another object using a vector-based similarity search. * @@ -422,13 +428,13 @@ interface NearObject { * This overload is for performing a search with a programmatically defined `opts` param. * * @param {number[]} id - The UUID of the object to search for. - * @param {NearOptions} [opts] - The available options for the search. - * @returns {QueryReturn} - The result of the search within the fetched collection. + * @param {NearOptions} [opts] - The available options for the search. + * @returns {QueryReturn} - The result of the search within the fetched collection. */ - nearObject(id: string, opts?: NearOptions): QueryReturn; + nearObject(id: string, opts?: NearOptions): QueryReturn; } -interface NearText { +interface NearText { /** * Search for objects in this collection by text using text-capable vectorization module and vector-based similarity search. * You must have a text-capable vectorization module installed in order to use this method, @@ -439,10 +445,10 @@ interface NearText { * This overload is for performing a search without the `groupBy` param. * * @param {string | string[]} query - The text query to search for. - * @param {BaseNearTextOptions} [opts] - The available options for the search excluding the `groupBy` param. - * @returns {Promise>} - The result of the search within the fetched collection. + * @param {BaseNearTextOptions} [opts] - The available options for the search excluding the `groupBy` param. + * @returns {Promise>} - The result of the search within the fetched collection. */ - nearText(query: string | string[], opts?: BaseNearTextOptions): Promise>; + nearText(query: string | string[], opts?: BaseNearTextOptions): Promise>; /** * Search for objects in this collection by text using text-capable vectorization module and vector-based similarity search. * You must have a text-capable vectorization module installed in order to use this method, @@ -453,10 +459,10 @@ interface NearText { * This overload is for performing a search with the `groupBy` param. * * @param {string | string[]} query - The text query to search for. - * @param {GroupByNearTextOptions} opts - The available options for the search including the `groupBy` param. - * @returns {Promise>} - The group by result of the search within the fetched collection. + * @param {GroupByNearTextOptions} opts - The available options for the search including the `groupBy` param. + * @returns {Promise>} - The group by result of the search within the fetched collection. */ - nearText(query: string | string[], opts: GroupByNearTextOptions): Promise>; + nearText(query: string | string[], opts: GroupByNearTextOptions): Promise>; /** * Search for objects in this collection by text using text-capable vectorization module and vector-based similarity search. * You must have a text-capable vectorization module installed in order to use this method, @@ -467,13 +473,13 @@ interface NearText { * This overload is for performing a search with a programmatically defined `opts` param. * * @param {string | string[]} query - The text query to search for. - * @param {NearTextOptions} [opts] - The available options for the search. - * @returns {QueryReturn} - The result of the search within the fetched collection. + * @param {NearTextOptions} [opts] - The available options for the search. + * @returns {QueryReturn} - The result of the search within the fetched collection. */ - nearText(query: string | string[], opts?: NearTextOptions): QueryReturn; + nearText(query: string | string[], opts?: NearTextOptions): QueryReturn; } -interface NearVector { +interface NearVector { /** * Search for objects by vector in this collection using a vector-based similarity search. * @@ -482,10 +488,10 @@ interface NearVector { * This overload is for performing a search without the `groupBy` param. * * @param {NearVectorInputType} vector - The vector(s) to search on. - * @param {BaseNearOptions} [opts] - The available options for the search excluding the `groupBy` param. - * @returns {Promise>} - The result of the search within the fetched collection. + * @param {BaseNearOptions} [opts] - The available options for the search excluding the `groupBy` param. + * @returns {Promise>} - The result of the search within the fetched collection. */ - nearVector(vector: NearVectorInputType, opts?: BaseNearOptions): Promise>; + nearVector(vector: NearVectorInputType, opts?: BaseNearOptions): Promise>; /** * Search for objects by vector in this collection using a vector-based similarity search. * @@ -494,10 +500,10 @@ interface NearVector { * This overload is for performing a search with the `groupBy` param. * * @param {NearVectorInputType} vector - The vector(s) to search for. - * @param {GroupByNearOptions} opts - The available options for the search including the `groupBy` param. - * @returns {Promise>} - The group by result of the search within the fetched collection. + * @param {GroupByNearOptions} opts - The available options for the search including the `groupBy` param. + * @returns {Promise>} - The group by result of the search within the fetched collection. */ - nearVector(vector: NearVectorInputType, opts: GroupByNearOptions): Promise>; + nearVector(vector: NearVectorInputType, opts: GroupByNearOptions): Promise>; /** * Search for objects by vector in this collection using a vector-based similarity search. * @@ -506,43 +512,43 @@ interface NearVector { * This overload is for performing a search with a programmatically defined `opts` param. * * @param {NearVectorInputType} vector - The vector(s) to search for. - * @param {NearOptions} [opts] - The available options for the search. - * @returns {QueryReturn} - The result of the search within the fetched collection. + * @param {NearOptions} [opts] - The available options for the search. + * @returns {QueryReturn} - The result of the search within the fetched collection. */ - nearVector(vector: NearVectorInputType, opts?: NearOptions): QueryReturn; + nearVector(vector: NearVectorInputType, opts?: NearOptions): QueryReturn; } /** All the available methods on the `.query` namespace. */ -export interface Query - extends Bm25, - Hybrid, - NearImage, - NearMedia, - NearObject, - NearText, - NearVector { +export interface Query + extends Bm25, + Hybrid, + NearImage, + NearMedia, + NearObject, + NearText, + NearVector { /** * Retrieve an object from the server by its UUID. * * @param {string} id - The UUID of the object to retrieve. * @param {FetchObjectByIdOptions} [opts] - The available options for fetching the object. - * @returns {Promise | null>} - The object with the given UUID, or null if it does not exist. + * @returns {Promise | null>} - The object with the given UUID, or null if it does not exist. */ - fetchObjectById: (id: string, opts?: FetchObjectByIdOptions) => Promise | null>; + fetchObjectById: (id: string, opts?: FetchObjectByIdOptions) => Promise | null>; /** * Retrieve objects from the server without searching. * * @param {FetchObjectsOptions} [opts] - The available options for fetching the objects. - * @returns {Promise>} - The objects within the fetched collection. + * @returns {Promise>} - The objects within the fetched collection. */ - fetchObjects: (opts?: FetchObjectsOptions) => Promise>; + fetchObjects: (opts?: FetchObjectsOptions) => Promise>; } /** Options available in the `query.nearImage`, `query.nearMedia`, `query.nearObject`, and `query.nearVector` methods */ -export type NearOptions = BaseNearOptions | GroupByNearOptions | undefined; +export type NearOptions = BaseNearOptions | GroupByNearOptions | undefined; /** Options available in the `query.nearText` method */ -export type NearTextOptions = BaseNearTextOptions | GroupByNearTextOptions | undefined; +export type NearTextOptions = BaseNearTextOptions | GroupByNearTextOptions | undefined; /** The return type of the `query` methods. It is a union of a standard query and a group by query due to function overloading. */ -export type QueryReturn = Promise> | Promise>; +export type QueryReturn = Promise> | Promise>; diff --git a/src/collections/query/utils.ts b/src/collections/query/utils.ts index fb986368..912d8e46 100644 --- a/src/collections/query/utils.ts +++ b/src/collections/query/utils.ts @@ -1,32 +1,42 @@ -import { MultiTargetVectorJoin } from '../index.js'; -import { ListOfVectors, NearVectorInputType, PrimitiveVectorType, TargetVectorInputType } from './types.js'; +import { MultiTargetVectorJoin, Vectors } from '../index.js'; +import { + ListOfVectors, + MultiVectorType, + NearVectorInputType, + PrimitiveVectorType, + SingleVectorType, + TargetVectorInputType, +} from './types.js'; export class NearVectorInputGuards { - public static is1D(input: NearVectorInputType): input is number[] { + public static is1D(input: NearVectorInputType): input is SingleVectorType { return Array.isArray(input) && input.length > 0 && !Array.isArray(input[0]); } - public static is2D(input: NearVectorInputType): input is number[][] { + public static is2D(input: NearVectorInputType): input is MultiVectorType { return Array.isArray(input) && input.length > 0 && Array.isArray(input[0]) && input[0].length > 0; } public static isObject( input: NearVectorInputType - ): input is Record | ListOfVectors> { + ): input is Record< + string, + PrimitiveVectorType | ListOfVectors | ListOfVectors + > { return !Array.isArray(input); } public static isListOf1D( - input: PrimitiveVectorType | ListOfVectors | ListOfVectors - ): input is ListOfVectors { - const i = input as ListOfVectors; + input: PrimitiveVectorType | ListOfVectors | ListOfVectors + ): input is ListOfVectors { + const i = input as ListOfVectors; return !Array.isArray(input) && i.kind === 'listOfVectors' && i.dimensionality == '1D'; } public static isListOf2D( - input: PrimitiveVectorType | ListOfVectors | ListOfVectors - ): input is ListOfVectors { - const i = input as ListOfVectors; + input: PrimitiveVectorType | ListOfVectors | ListOfVectors + ): input is ListOfVectors { + const i = input as ListOfVectors; return !Array.isArray(input) && i.kind === 'listOfVectors' && i.dimensionality == '2D'; } } @@ -41,16 +51,16 @@ export class ArrayInputGuards { } export class TargetVectorInputGuards { - public static isSingle(input: TargetVectorInputType): input is string { + public static isSingle(input: TargetVectorInputType): input is string { return typeof input === 'string'; } - public static isMulti(input: TargetVectorInputType): input is string[] { + public static isMulti(input: TargetVectorInputType): input is string[] { return Array.isArray(input); } - public static isMultiJoin(input: TargetVectorInputType): input is MultiTargetVectorJoin { - const i = input as MultiTargetVectorJoin; + public static isMultiJoin(input: TargetVectorInputType): input is MultiTargetVectorJoin { + const i = input as MultiTargetVectorJoin; return i.combination !== undefined && i.targetVectors !== undefined; } } diff --git a/src/collections/references/classes.ts b/src/collections/references/classes.ts index 0200b7dc..b762d336 100644 --- a/src/collections/references/classes.ts +++ b/src/collections/references/classes.ts @@ -1,13 +1,19 @@ -import { Properties, ReferenceInput, ReferenceToMultiTarget, WeaviateObject } from '../types/index.js'; +import { + Properties, + ReferenceInput, + ReferenceToMultiTarget, + Vectors, + WeaviateObject, +} from '../types/index.js'; import { Beacon } from './types.js'; import { uuidToBeacon } from './utils.js'; export class ReferenceManager { - public objects: WeaviateObject[]; + public objects: WeaviateObject[]; public targetCollection: string; public uuids?: string[]; - constructor(targetCollection: string, objects?: WeaviateObject[], uuids?: string[]) { + constructor(targetCollection: string, objects?: WeaviateObject[], uuids?: string[]) { this.objects = objects ?? []; this.targetCollection = targetCollection; this.uuids = uuids; diff --git a/src/collections/serialize/index.ts b/src/collections/serialize/index.ts index 961d8b70..9c9181e3 100644 --- a/src/collections/serialize/index.ts +++ b/src/collections/serialize/index.ts @@ -87,6 +87,7 @@ import { Vectors as VectorsGrpc, Vectors_VectorType, } from '../../proto/v1/base.js'; +import { yieldToEventLoop } from '../../utils/yield.js'; import { FilterId } from '../filters/classes.js'; import { FilterValue, Filters } from '../filters/index.js'; import { @@ -119,10 +120,14 @@ import { HybridNearVectorSubSearch, HybridOptions, HybridSearchOptions, + ListOfVectors, + MultiVectorType, NearOptions, NearTextOptions, NearVectorInputType, + PrimitiveVectorType, SearchOptions, + SingleVectorType, TargetVectorInputType, } from '../query/types.js'; import { ArrayInputGuards, NearVectorInputGuards, TargetVectorInputGuards } from '../query/utils.js'; @@ -399,9 +404,9 @@ class Aggregate { }); }; - public static hybrid = async ( + public static hybrid = async ( query: string, - opts?: AggregateHybridOptions> + opts?: AggregateHybridOptions, V> ): Promise => { return { ...Aggregate.common(opts), @@ -417,9 +422,9 @@ class Aggregate { }; }; - public static nearImage = ( + public static nearImage = ( image: string, - opts?: AggregateNearOptions> + opts?: AggregateNearOptions, V> ): AggregateNearImageArgs => { return { ...Aggregate.common(opts), @@ -433,9 +438,9 @@ class Aggregate { }; }; - public static nearObject = ( + public static nearObject = ( id: string, - opts?: AggregateNearOptions> + opts?: AggregateNearOptions, V> ): AggregateNearObjectArgs => { return { ...Aggregate.common(opts), @@ -449,9 +454,9 @@ class Aggregate { }; }; - public static nearText = ( + public static nearText = ( query: string | string[], - opts?: AggregateNearOptions> + opts?: AggregateNearOptions, V> ): AggregateNearTextArgs => { return { ...Aggregate.common(opts), @@ -465,9 +470,9 @@ class Aggregate { }; }; - public static nearVector = async ( + public static nearVector = async ( vector: NearVectorInputType, - opts?: AggregateNearOptions> + opts?: AggregateNearOptions, V> ): Promise => { return { ...Aggregate.common(opts), @@ -605,7 +610,7 @@ class Search { return args.groupBy !== undefined; }; - private static common = (args?: SearchOptions): BaseSearchArgs => { + private static common = (args?: SearchOptions): BaseSearchArgs => { const out: BaseSearchArgs = { autocut: args?.autoLimit, limit: args?.limit, @@ -623,15 +628,15 @@ class Search { return out; }; - public static bm25 = (query: string, opts?: Bm25Options): SearchBm25Args => { + public static bm25 = (query: string, opts?: Bm25Options): SearchBm25Args => { return { ...Search.common(opts), bm25Search: Serialize.bm25Search({ query, ...opts }), - groupBy: Search.isGroupBy>(opts) ? Search.groupBy(opts.groupBy) : undefined, + groupBy: Search.isGroupBy>(opts) ? Search.groupBy(opts.groupBy) : undefined, }; }; - public static fetchObjects = (args?: FetchObjectsOptions): SearchFetchArgs => { + public static fetchObjects = (args?: FetchObjectsOptions): SearchFetchArgs => { return { ...Search.common(args), after: args?.after, @@ -639,7 +644,9 @@ class Search { }; }; - public static fetchObjectById = (args: { id: string } & FetchObjectByIdOptions): SearchFetchArgs => { + public static fetchObjectById = ( + args: { id: string } & FetchObjectByIdOptions + ): SearchFetchArgs => { return Search.common({ filters: new FilterId().equal(args.id), includeVector: args.includeVector, @@ -649,7 +656,7 @@ class Search { }); }; - public static hybrid = async ( + public static hybrid = async ( args: { query: string; supportsTargets: boolean; @@ -657,113 +664,113 @@ class Search { supportsWeightsForTargets: boolean; supportsVectors: boolean; }, - opts?: HybridOptions + opts?: HybridOptions ): Promise => { return { - ...Search.common(opts), + ...Search.common(opts), hybridSearch: await Serialize.hybridSearch({ ...args, ...opts }), - groupBy: Search.isGroupBy>(opts) ? Search.groupBy(opts.groupBy) : undefined, + groupBy: Search.isGroupBy>(opts) ? Search.groupBy(opts.groupBy) : undefined, }; }; - public static nearAudio = ( + public static nearAudio = ( args: { audio: string; supportsTargets: boolean; supportsWeightsForTargets: boolean; }, - opts?: NearOptions + opts?: NearOptions ): SearchNearAudioArgs => { return { ...Search.common(opts), nearAudio: Serialize.nearAudioSearch({ ...args, ...opts }), - groupBy: Search.isGroupBy>(opts) ? Search.groupBy(opts.groupBy) : undefined, + groupBy: Search.isGroupBy>(opts) ? Search.groupBy(opts.groupBy) : undefined, }; }; - public static nearDepth = ( + public static nearDepth = ( args: { depth: string; supportsTargets: boolean; supportsWeightsForTargets: boolean; }, - opts?: NearOptions + opts?: NearOptions ): SearchNearDepthArgs => { return { ...Search.common(opts), nearDepth: Serialize.nearDepthSearch({ ...args, ...opts }), - groupBy: Search.isGroupBy>(opts) ? Search.groupBy(opts.groupBy) : undefined, + groupBy: Search.isGroupBy>(opts) ? Search.groupBy(opts.groupBy) : undefined, }; }; - public static nearImage = ( + public static nearImage = ( args: { image: string; supportsTargets: boolean; supportsWeightsForTargets: boolean; }, - opts?: NearOptions + opts?: NearOptions ): SearchNearImageArgs => { return { ...Search.common(opts), nearImage: Serialize.nearImageSearch({ ...args, ...opts }), - groupBy: Search.isGroupBy>(opts) ? Search.groupBy(opts.groupBy) : undefined, + groupBy: Search.isGroupBy>(opts) ? Search.groupBy(opts.groupBy) : undefined, }; }; - public static nearIMU = ( + public static nearIMU = ( args: { imu: string; supportsTargets: boolean; supportsWeightsForTargets: boolean; }, - opts?: NearOptions + opts?: NearOptions ): SearchNearIMUArgs => { return { ...Search.common(opts), nearIMU: Serialize.nearIMUSearch({ ...args, ...opts }), - groupBy: Search.isGroupBy>(opts) ? Search.groupBy(opts.groupBy) : undefined, + groupBy: Search.isGroupBy>(opts) ? Search.groupBy(opts.groupBy) : undefined, }; }; - public static nearObject = ( + public static nearObject = ( args: { id: string; supportsTargets: boolean; supportsWeightsForTargets: boolean }, - opts?: NearOptions + opts?: NearOptions ): SearchNearObjectArgs => { return { ...Search.common(opts), nearObject: Serialize.nearObjectSearch({ ...args, ...opts }), - groupBy: Search.isGroupBy>(opts) ? Search.groupBy(opts.groupBy) : undefined, + groupBy: Search.isGroupBy>(opts) ? Search.groupBy(opts.groupBy) : undefined, }; }; - public static nearText = ( + public static nearText = ( args: { query: string | string[]; supportsTargets: boolean; supportsWeightsForTargets: boolean; }, - opts?: NearTextOptions + opts?: NearTextOptions ): SearchNearTextArgs => { return { ...Search.common(opts), nearText: Serialize.nearTextSearch({ ...args, ...opts }), - groupBy: Search.isGroupBy>(opts) ? Search.groupBy(opts.groupBy) : undefined, + groupBy: Search.isGroupBy>(opts) ? Search.groupBy(opts.groupBy) : undefined, }; }; - public static nearThermal = ( + public static nearThermal = ( args: { thermal: string; supportsTargets: boolean; supportsWeightsForTargets: boolean }, - opts?: NearOptions + opts?: NearOptions ): SearchNearThermalArgs => { return { ...Search.common(opts), nearThermal: Serialize.nearThermalSearch({ ...args, ...opts }), - groupBy: Search.isGroupBy>(opts) ? Search.groupBy(opts.groupBy) : undefined, + groupBy: Search.isGroupBy>(opts) ? Search.groupBy(opts.groupBy) : undefined, }; }; - public static nearVector = async ( + public static nearVector = async ( args: { vector: NearVectorInputType; supportsTargets: boolean; @@ -771,22 +778,22 @@ class Search { supportsWeightsForTargets: boolean; supportsVectors: boolean; }, - opts?: NearOptions + opts?: NearOptions ): Promise => { return { ...Search.common(opts), nearVector: await Serialize.nearVectorSearch({ ...args, ...opts }), - groupBy: Search.isGroupBy>(opts) ? Search.groupBy(opts.groupBy) : undefined, + groupBy: Search.isGroupBy>(opts) ? Search.groupBy(opts.groupBy) : undefined, }; }; - public static nearVideo = ( + public static nearVideo = ( args: { video: string; supportsTargets: boolean; supportsWeightsForTargets: boolean }, - opts?: NearOptions + opts?: NearOptions ): SearchNearVideoArgs => { return { ...Search.common(opts), nearVideo: Serialize.nearVideoSearch({ ...args, ...opts }), - groupBy: Search.isGroupBy>(opts) ? Search.groupBy(opts.groupBy) : undefined, + groupBy: Search.isGroupBy>(opts) ? Search.groupBy(opts.groupBy) : undefined, }; }; } @@ -795,15 +802,15 @@ export class Serialize { static aggregate = Aggregate; static search = Search; - public static isNamedVectors = (opts?: BaseNearOptions): boolean => { + public static isNamedVectors = (opts?: BaseNearOptions): boolean => { return Array.isArray(opts?.includeVector) || opts?.targetVector !== undefined; }; - public static isMultiTarget = (opts?: BaseNearOptions): boolean => { + public static isMultiTarget = (opts?: BaseNearOptions): boolean => { return opts?.targetVector !== undefined && !TargetVectorInputGuards.isSingle(opts.targetVector); }; - public static isMultiWeightPerTarget = (opts?: BaseNearOptions): boolean => { + public static isMultiWeightPerTarget = (opts?: BaseNearOptions): boolean => { return ( opts?.targetVector !== undefined && TargetVectorInputGuards.isMultiJoin(opts.targetVector) && @@ -851,9 +858,14 @@ export class Serialize { }); }; - public static isHybridVectorSearch = ( - vector: BaseHybridOptions['vector'] - ): vector is number[] | Record => { + public static isHybridVectorSearch = ( + vector: BaseHybridOptions['vector'] + ): vector is + | PrimitiveVectorType + | Record< + string, + PrimitiveVectorType | ListOfVectors | ListOfVectors + > => { return ( vector !== undefined && !Serialize.isHybridNearTextSearch(vector) && @@ -861,28 +873,28 @@ export class Serialize { ); }; - public static isHybridNearTextSearch = ( - vector: BaseHybridOptions['vector'] + public static isHybridNearTextSearch = ( + vector: BaseHybridOptions['vector'] ): vector is HybridNearTextSubSearch => { return (vector as HybridNearTextSubSearch)?.query !== undefined; }; - public static isHybridNearVectorSearch = ( - vector: BaseHybridOptions['vector'] + public static isHybridNearVectorSearch = ( + vector: BaseHybridOptions['vector'] ): vector is HybridNearVectorSubSearch => { return (vector as HybridNearVectorSubSearch)?.vector !== undefined; }; - private static hybridVector = async (args: { + private static hybridVector = async (args: { supportsTargets: boolean; supportsVectorsForTargets: boolean; supportsWeightsForTargets: boolean; supportsVectors: boolean; - vector?: BaseHybridOptions['vector']; + vector?: BaseHybridOptions['vector']; }) => { const vector = args.vector; if (Serialize.isHybridVectorSearch(vector)) { - const { targets, targetVectors, vectorBytes, vectorPerTarget, vectorForTargets } = + const { targets, targetVectors, vectorBytes, vectorPerTarget, vectorForTargets, vectors } = await Serialize.vectors({ ...args, argumentName: 'vector', @@ -893,10 +905,14 @@ export class Serialize { : { targetVectors, targets, - nearVector: NearVector.fromPartial({ - vectorForTargets, - vectorPerTarget, - }), + nearVector: + vectorForTargets != undefined || vectorPerTarget != undefined + ? NearVector.fromPartial({ + vectorForTargets, + vectorPerTarget, + }) + : undefined, + vectors, }; } else if (Serialize.isHybridNearTextSearch(vector)) { const { targetVectors, targets } = Serialize.targetVector(args); @@ -912,7 +928,7 @@ export class Serialize { }), }; } else if (Serialize.isHybridNearVectorSearch(vector)) { - const { targetVectors, targets, vectorBytes, vectorPerTarget, vectorForTargets } = + const { targetVectors, targets, vectorBytes, vectorPerTarget, vectorForTargets, vectors } = await Serialize.vectors({ ...args, argumentName: 'vector', @@ -927,6 +943,7 @@ export class Serialize { vectorBytes, vectorPerTarget, vectorForTargets, + vectors, }), }; } else { @@ -935,14 +952,14 @@ export class Serialize { } }; - public static hybridSearch = async ( + public static hybridSearch = async ( args: { query: string; supportsTargets: boolean; supportsVectorsForTargets: boolean; supportsWeightsForTargets: boolean; supportsVectors: boolean; - } & HybridSearchOptions + } & HybridSearchOptions ): Promise => { const fusionType = (fusionType?: string): Hybrid_FusionType => { switch (fusionType) { @@ -954,7 +971,8 @@ export class Serialize { return Hybrid_FusionType.FUSION_TYPE_UNSPECIFIED; } }; - const { targets, targetVectors, vectorBytes, nearText, nearVector } = await Serialize.hybridVector(args); + const { targets, targetVectors, vectorBytes, nearText, nearVector, vectors } = + await Serialize.hybridVector(args); return Hybrid.fromPartial({ query: args.query, alpha: args.alpha ? args.alpha : 0.5, @@ -966,11 +984,12 @@ export class Serialize { targets, nearText, nearVector, + vectors, }); }; - public static nearAudioSearch = ( - args: { audio: string; supportsTargets: boolean; supportsWeightsForTargets: boolean } & NearOptions + public static nearAudioSearch = ( + args: { audio: string; supportsTargets: boolean; supportsWeightsForTargets: boolean } & NearOptions ): NearAudioSearch => { const { targets, targetVectors } = Serialize.targetVector(args); return NearAudioSearch.fromPartial({ @@ -982,8 +1001,8 @@ export class Serialize { }); }; - public static nearDepthSearch = ( - args: { depth: string; supportsTargets: boolean; supportsWeightsForTargets: boolean } & NearOptions + public static nearDepthSearch = ( + args: { depth: string; supportsTargets: boolean; supportsWeightsForTargets: boolean } & NearOptions ): NearDepthSearch => { const { targets, targetVectors } = Serialize.targetVector(args); return NearDepthSearch.fromPartial({ @@ -995,8 +1014,8 @@ export class Serialize { }); }; - public static nearImageSearch = ( - args: { image: string; supportsTargets: boolean; supportsWeightsForTargets: boolean } & NearOptions + public static nearImageSearch = ( + args: { image: string; supportsTargets: boolean; supportsWeightsForTargets: boolean } & NearOptions ): NearImageSearch => { const { targets, targetVectors } = Serialize.targetVector(args); return NearImageSearch.fromPartial({ @@ -1008,8 +1027,8 @@ export class Serialize { }); }; - public static nearIMUSearch = ( - args: { imu: string; supportsTargets: boolean; supportsWeightsForTargets: boolean } & NearOptions + public static nearIMUSearch = ( + args: { imu: string; supportsTargets: boolean; supportsWeightsForTargets: boolean } & NearOptions ): NearIMUSearch => { const { targets, targetVectors } = Serialize.targetVector(args); return NearIMUSearch.fromPartial({ @@ -1021,8 +1040,8 @@ export class Serialize { }); }; - public static nearObjectSearch = ( - args: { id: string; supportsTargets: boolean; supportsWeightsForTargets: boolean } & NearOptions + public static nearObjectSearch = ( + args: { id: string; supportsTargets: boolean; supportsWeightsForTargets: boolean } & NearOptions ): NearObject => { const { targets, targetVectors } = Serialize.targetVector(args); return NearObject.fromPartial({ @@ -1034,11 +1053,11 @@ export class Serialize { }); }; - public static nearTextSearch = (args: { + public static nearTextSearch = (args: { query: string | string[]; supportsTargets: boolean; supportsWeightsForTargets: boolean; - targetVector?: TargetVectorInputType; + targetVector?: TargetVectorInputType; certainty?: number; distance?: number; moveAway?: { concepts?: string[]; force?: number; objects?: string[] }; @@ -1068,8 +1087,11 @@ export class Serialize { }); }; - public static nearThermalSearch = ( - args: { thermal: string; supportsTargets: boolean; supportsWeightsForTargets: boolean } & NearOptions + public static nearThermalSearch = ( + args: { thermal: string; supportsTargets: boolean; supportsWeightsForTargets: boolean } & NearOptions< + T, + V + > ): NearThermalSearch => { const { targets, targetVectors } = Serialize.targetVector(args); return NearThermalSearch.fromPartial({ @@ -1114,7 +1136,7 @@ export class Serialize { dv.setUint16(uint16Len, vectors.length, true); await Promise.all( vectors.map((vector, i) => - new Promise((resolve) => setTimeout(resolve, 0)).then(() => + yieldToEventLoop().then(() => vector.forEach((v, j) => dv.setFloat32(uint16Len + i * dim * uint32len + j * uint32len, v, true)) ) ) @@ -1123,7 +1145,7 @@ export class Serialize { return new Uint8Array(dv.buffer); }; - public static nearVectorSearch = async (args: { + public static nearVectorSearch = async (args: { vector: NearVectorInputType; supportsTargets: boolean; supportsVectorsForTargets: boolean; @@ -1131,28 +1153,21 @@ export class Serialize { supportsVectors: boolean; certainty?: number; distance?: number; - targetVector?: TargetVectorInputType; - }): Promise => { - const { targetVectors, targets, vectorBytes, vectorPerTarget, vectorForTargets } = - await Serialize.vectors({ - ...args, - argumentName: 'nearVector', - }); - return NearVector.fromPartial({ + targetVector?: TargetVectorInputType; + }): Promise => + NearVector.fromPartial({ certainty: args.certainty, distance: args.distance, - targetVectors, - targets, - vectorPerTarget, - vectorBytes, - vectorForTargets, + ...(await Serialize.vectors({ + ...args, + argumentName: 'nearVector', + })), }); - }; - public static targetVector = (args: { + public static targetVector = (args: { supportsTargets: boolean; supportsWeightsForTargets: boolean; - targetVector?: TargetVectorInputType; + targetVector?: TargetVectorInputType; }): { targets?: Targets; targetVectors?: string[] } => { if (args.targetVector === undefined) { return {}; @@ -1177,13 +1192,13 @@ export class Serialize { } }; - static vectors = async (args: { + static vectors = async (args: { supportsTargets: boolean; supportsVectorsForTargets: boolean; supportsWeightsForTargets: boolean; supportsVectors: boolean; argumentName: 'nearVector' | 'vector'; - targetVector?: TargetVectorInputType; + targetVector?: TargetVectorInputType; vector?: NearVectorInputType; }): Promise<{ targetVectors?: string[]; @@ -1242,11 +1257,6 @@ export class Serialize { } const vectorForTargets: VectorForTarget[] = []; for (const [target, vector] of Object.entries(args.vector)) { - const vectorForTarget: VectorForTarget = { - name: target, - vectorBytes: new Uint8Array(), - vectors: [], - }; if (!args.supportsVectors) { if (NearVectorInputGuards.isListOf2D(vector)) { throw new WeaviateUnsupportedFeatureError( @@ -1270,15 +1280,13 @@ export class Serialize { continue; } vectorForTargets.push({ name: target, vectorBytes: Serialize.vectorToBytes(vector), vectors: [] }); + continue; } - if (ArrayInputGuards.is2DArray(vector)) { - vectorForTarget.vectors.push( - Vectors.fromPartial({ - type: Vectors_VectorType.VECTOR_TYPE_MULTI_FP32, - vectorBytes: await Serialize.vectorsToBytes(vector), // eslint-disable-line no-await-in-loop - }) - ); - } + const vectorForTarget: VectorForTarget = { + name: target, + vectorBytes: new Uint8Array(), + vectors: [], + }; if (NearVectorInputGuards.isListOf1D(vector)) { vectorForTarget.vectors.push( Vectors.fromPartial({ @@ -1286,8 +1294,7 @@ export class Serialize { vectorBytes: await Serialize.vectorsToBytes(vector.vectors), // eslint-disable-line no-await-in-loop }) ); - } - if (NearVectorInputGuards.isListOf2D(vector)) { + } else if (NearVectorInputGuards.isListOf2D(vector)) { for (const v of vector.vectors) { vectorForTarget.vectors.push( Vectors.fromPartial({ @@ -1296,9 +1303,22 @@ export class Serialize { }) ); } - vectorForTargets.push(vectorForTarget); - continue; + } else if (ArrayInputGuards.is2DArray(vector)) { + vectorForTarget.vectors.push( + Vectors.fromPartial({ + type: Vectors_VectorType.VECTOR_TYPE_MULTI_FP32, + vectorBytes: await Serialize.vectorsToBytes(vector), // eslint-disable-line no-await-in-loop + }) + ); + } else { + vectorForTarget.vectors.push( + Vectors.fromPartial({ + type: Vectors_VectorType.VECTOR_TYPE_SINGLE_FP32, + vectorBytes: Serialize.vectorToBytes(vector), + }) + ); } + vectorForTargets.push(vectorForTarget); } return args.targetVector !== undefined ? { @@ -1332,6 +1352,9 @@ export class Serialize { }; } if (NearVectorInputGuards.is2D(args.vector)) { + if (!args.supportsVectors) { + throw new WeaviateUnsupportedFeatureError('Multi-vectors are not supported in Weaviate <1.29.0'); + } const { targetVectors, targets } = Serialize.targetVector(args); const vectorBytes = await Serialize.vectorsToBytes(args.vector); return { @@ -1343,8 +1366,8 @@ export class Serialize { throw invalidVectorError; }; - private static targets = ( - targets: MultiTargetVectorJoin, + private static targets = ( + targets: MultiTargetVectorJoin, supportsWeightsForTargets: boolean ): { combination: CombinationMethod; @@ -1377,7 +1400,7 @@ export class Serialize { .map(([target, weight]) => { return { target, - weight, + weight: weight as number | number[], }; }) .reduce((acc, { target, weight }) => { @@ -1413,8 +1436,8 @@ export class Serialize { } }; - public static nearVideoSearch = ( - args: { video: string; supportsTargets: boolean; supportsWeightsForTargets: boolean } & NearOptions + public static nearVideoSearch = ( + args: { video: string; supportsTargets: boolean; supportsWeightsForTargets: boolean } & NearOptions ): NearVideoSearch => { const { targets, targetVectors } = Serialize.targetVector(args); return NearVideoSearch.fromPartial({ @@ -1824,11 +1847,20 @@ export class Serialize { let vectorBytes: Uint8Array | undefined; let vectors: VectorsGrpc[] | undefined; if (obj.vectors !== undefined && !Array.isArray(obj.vectors)) { - vectors = Object.entries(obj.vectors).map(([k, v]) => - VectorsGrpc.fromPartial({ - vectorBytes: Serialize.vectorToBytes(v), - name: k, - }) + vectors = Object.entries(obj.vectors).flatMap(([k, v]) => + NearVectorInputGuards.is1D(v) + ? [ + VectorsGrpc.fromPartial({ + vectorBytes: Serialize.vectorToBytes(v), + name: k, + }), + ] + : v.map((vv) => + VectorsGrpc.fromPartial({ + vectorBytes: Serialize.vectorToBytes(vv), + name: k, + }) + ) ); } else if (Array.isArray(obj.vectors) && requiresInsertFix) { vectors = [ diff --git a/src/collections/serialize/unit.test.ts b/src/collections/serialize/unit.test.ts index 721d1e46..7eb468a6 100644 --- a/src/collections/serialize/unit.test.ts +++ b/src/collections/serialize/unit.test.ts @@ -12,7 +12,7 @@ import { SearchNearVectorArgs, SearchNearVideoArgs, } from '../../grpc/searcher.js'; -import { Filters, Filters_Operator } from '../../proto/v1/base.js'; +import { Filters, Filters_Operator, Vectors, Vectors_VectorType } from '../../proto/v1/base.js'; import { BM25, CombinationMethod, @@ -143,13 +143,14 @@ describe('Unit testing of Serialize', () => { }); }); - it('should parse args for simple hybrid', () => { - const args = Serialize.search.hybrid( + it('should parse args for simple hybrid <1.29', async () => { + const args = await Serialize.search.hybrid( { query: 'test', supportsTargets: false, supportsVectorsForTargets: false, supportsWeightsForTargets: false, + supportsVectors: false, }, { queryProperties: ['name'], @@ -174,13 +175,53 @@ describe('Unit testing of Serialize', () => { }); }); - it('should parse args for multi-vector & multi-target hybrid', () => { - const args = Serialize.search.hybrid( + it('should parse args for simple hybrid >=1.29', async () => { + const args = await Serialize.search.hybrid( { query: 'test', supportsTargets: true, supportsVectorsForTargets: true, supportsWeightsForTargets: true, + supportsVectors: true, + }, + { + queryProperties: ['name'], + alpha: 0.6, + vector: [1, 2, 3], + targetVector: 'title', + fusionType: 'Ranked', + maxVectorDistance: 0.4, + } + ); + expect(args).toEqual({ + hybridSearch: Hybrid.fromPartial({ + query: 'test', + properties: ['name'], + alpha: 0.6, + vectors: [ + Vectors.fromPartial({ + type: Vectors_VectorType.VECTOR_TYPE_SINGLE_FP32, + vectorBytes: new Uint8Array(new Float32Array([1, 2, 3]).buffer), + }), + ], + targets: { + targetVectors: ['title'], + }, + fusionType: Hybrid_FusionType.FUSION_TYPE_RANKED, + vectorDistance: 0.4, + }), + metadata: MetadataRequest.fromPartial({ uuid: true }), + }); + }); + + it('should parse args for multi-vector & multi-target hybrid', async () => { + const args = await Serialize.search.hybrid( + { + query: 'test', + supportsTargets: true, + supportsVectorsForTargets: true, + supportsWeightsForTargets: true, + supportsVectors: false, }, { queryProperties: ['name'], @@ -364,12 +405,13 @@ describe('Unit testing of Serialize', () => { }); }); - it('should parse args for nearVector with single vector', () => { - const args = Serialize.search.nearVector({ + it('should parse args for nearVector with single vector <1.29', async () => { + const args = await Serialize.search.nearVector({ vector: [1, 2, 3], supportsTargets: false, supportsVectorsForTargets: false, supportsWeightsForTargets: false, + supportsVectors: false, }); expect(args).toEqual({ nearVector: NearVector.fromPartial({ @@ -379,8 +421,29 @@ describe('Unit testing of Serialize', () => { }); }); - it('should parse args for nearVector with two named vectors and supportsTargets (<1.27.0)', () => { - const args = Serialize.search.nearVector({ + it('should parse args for nearVector with single vector >=1.29', async () => { + const args = await Serialize.search.nearVector({ + vector: [1, 2, 3], + supportsTargets: false, + supportsVectorsForTargets: false, + supportsWeightsForTargets: false, + supportsVectors: true, + }); + expect(args).toEqual({ + nearVector: NearVector.fromPartial({ + vectors: [ + Vectors.fromPartial({ + type: Vectors_VectorType.VECTOR_TYPE_SINGLE_FP32, + vectorBytes: new Uint8Array(new Float32Array([1, 2, 3]).buffer), + }), + ], + }), + metadata: MetadataRequest.fromPartial({ uuid: true }), + }); + }); + + it('should parse args for nearVector with two named vectors and supportsTargets (<1.27.0)', async () => { + const args = await Serialize.search.nearVector({ vector: { a: [1, 2, 3], b: [4, 5, 6], @@ -388,6 +451,7 @@ describe('Unit testing of Serialize', () => { supportsTargets: true, supportsVectorsForTargets: false, supportsWeightsForTargets: false, + supportsVectors: false, }); expect(args).toEqual({ nearVector: NearVector.fromPartial({ @@ -401,8 +465,8 @@ describe('Unit testing of Serialize', () => { }); }); - it('should parse args for nearVector with two named vectors and all supports (==1.27.x)', () => { - const args = Serialize.search.nearVector({ + it('should parse args for nearVector with two named vectors and all supports (==1.27.x)', async () => { + const args = await Serialize.search.nearVector({ vector: { a: [ [1, 2, 3], @@ -413,6 +477,7 @@ describe('Unit testing of Serialize', () => { supportsTargets: true, supportsVectorsForTargets: true, supportsWeightsForTargets: true, + supportsVectors: false, }); expect(args).toEqual({ nearVector: NearVector.fromPartial({ @@ -656,7 +721,7 @@ describe('Unit testing of Serialize', () => { }; type Test = { name: string; - targetVector: TargetVectorInputType; + targetVector: TargetVectorInputType; supportsTargets: boolean; supportsWeightsForTargets: boolean; out: Out; @@ -682,7 +747,7 @@ describe('Unit testing of Serialize', () => { }, { name: 'should parse MultiTargetJoin sum', - targetVector: multiTargetVector().average(['a', 'b']), + targetVector: multiTargetVector().average(['a', 'b']), supportsTargets: true, supportsWeightsForTargets: false, out: { @@ -724,7 +789,7 @@ describe('Unit testing of Serialize', () => { }, { name: 'should parse MultiTargetJoin minimum', - targetVector: multiTargetVector().minimum(['a', 'b']), + targetVector: multiTargetVector().minimum(['a', 'b']), supportsTargets: true, supportsWeightsForTargets: false, out: { @@ -736,7 +801,7 @@ describe('Unit testing of Serialize', () => { }, { name: 'should parse MultiTargetJoin sum', - targetVector: multiTargetVector().average(['a', 'b']), + targetVector: multiTargetVector().average(['a', 'b']), supportsTargets: true, supportsWeightsForTargets: false, out: { @@ -778,7 +843,7 @@ describe('Unit testing of Serialize', () => { }, { name: 'should parse MultiTargetJoin sum', - targetVector: multiTargetVector().sum(['a', 'b']), + targetVector: multiTargetVector().sum(['a', 'b']), supportsTargets: true, supportsWeightsForTargets: false, out: { diff --git a/src/collections/types/generate.ts b/src/collections/types/generate.ts index b3f6bac2..82907cd1 100644 --- a/src/collections/types/generate.ts +++ b/src/collections/types/generate.ts @@ -1,6 +1,6 @@ import { GroupByObject, GroupByResult, WeaviateGenericObject, WeaviateNonGenericObject } from './query.js'; -export type GenerativeGenericObject = WeaviateGenericObject & { +export type GenerativeGenericObject = WeaviateGenericObject & { /** The LLM-generated output applicable to this single object. */ generated?: string; }; @@ -15,28 +15,32 @@ export type GenerativeNonGenericObject = WeaviateNonGenericObject & { * Depending on the generic type `T`, the object will have subfields that map from `T`'s specific type definition. * If not, then the object will be non-generic and have a `properties` field that maps from a generic string to a `WeaviateField`. */ -export type GenerativeObject = T extends undefined - ? GenerativeNonGenericObject - : GenerativeGenericObject; +export type GenerativeObject = T extends undefined + ? V extends undefined + ? GenerativeNonGenericObject + : GenerativeGenericObject + : V extends undefined + ? GenerativeGenericObject + : GenerativeGenericObject; /** The return of a query method in the `collection.generate` namespace. */ -export type GenerativeReturn = { +export type GenerativeReturn = { /** The objects that were found by the query. */ - objects: GenerativeObject[]; + objects: GenerativeObject[]; /** The LLM-generated output applicable to this query as a whole. */ generated?: string; }; -export type GenerativeGroupByResult = GroupByResult & { +export type GenerativeGroupByResult = GroupByResult & { generated?: string; }; /** The return of a query method in the `collection.generate` namespace where the `groupBy` argument was specified. */ -export type GenerativeGroupByReturn = { +export type GenerativeGroupByReturn = { /** The objects that were found by the query. */ - objects: GroupByObject[]; + objects: GroupByObject[]; /** The groups that were created by the query. */ - groups: Record>; + groups: Record>; /** The LLM-generated output applicable to this query as a whole. */ generated?: string; }; @@ -51,4 +55,4 @@ export type GenerateOptions = { groupedProperties?: T extends undefined ? string[] : (keyof T)[]; }; -export type GenerateReturn = Promise> | Promise>; +export type GenerateReturn = Promise> | Promise>; diff --git a/src/collections/types/internal.ts b/src/collections/types/internal.ts index 001c24ab..0227d769 100644 --- a/src/collections/types/internal.ts +++ b/src/collections/types/internal.ts @@ -40,6 +40,10 @@ export type QueryReference = T extends undefined ? RefPropertyDefault : RefPr export type NonRefProperty = keyof T | QueryNested; export type NonPrimitiveProperty = RefProperty | QueryNested; +export type QueryVector = V extends undefined ? string : keyof V & string; + +export type IncludeVector = boolean | QueryVector[] | undefined; + export type IsEmptyType = keyof T extends never ? true : false; export type ReferenceInput = diff --git a/src/collections/types/query.ts b/src/collections/types/query.ts index c356aa38..52def317 100644 --- a/src/collections/types/query.ts +++ b/src/collections/types/query.ts @@ -1,4 +1,5 @@ import { WeaviateField } from '../index.js'; +import { PrimitiveVectorType } from '../query/types.js'; import { CrossReferenceDefault } from '../references/index.js'; import { ExtractCrossReferenceType, @@ -26,7 +27,7 @@ export type QueryMetadata = 'all' | MetadataKeys | undefined; export type ReturnMetadata = Partial; -export type WeaviateGenericObject = { +export type WeaviateGenericObject = { /** The generic returned properties of the object derived from the type `T`. */ properties: ReturnProperties; /** The returned metadata of the object. */ @@ -36,7 +37,7 @@ export type WeaviateGenericObject = { /** The UUID of the object. */ uuid: string; /** The returned vectors of the object. */ - vectors: Vectors; + vectors: V; }; export type WeaviateNonGenericObject = { @@ -56,45 +57,58 @@ export type ReturnProperties = Pick>; export type ReturnReferences = Pick>; -export type Vectors = Record; +export interface Vectors { + [k: string]: PrimitiveVectorType; +} -export type ReturnVectors = V extends string[] - ? { [Key in V[number]]: number[] } - : Record; +export type ReturnVectors = I extends true + ? V + : I extends Array + ? Pick< + V, + { + [Key in keyof V]: Key extends U ? Key : never; + }[keyof V] + > + : never; /** An object belonging to a collection as returned by the methods in the `collection.query` namespace. * * Depending on the generic type `T`, the object will have subfields that map from `T`'s specific type definition. * If not, then the object will be non-generic and have a `properties` field that maps from a generic string to a `WeaviateField`. */ -export type WeaviateObject = T extends undefined // need this instead of Properties to avoid circular type reference - ? WeaviateNonGenericObject - : WeaviateGenericObject; +export type WeaviateObject = T extends undefined // need this instead of Properties to avoid circular type reference + ? V extends undefined + ? WeaviateNonGenericObject + : WeaviateGenericObject + : V extends undefined + ? WeaviateGenericObject + : WeaviateGenericObject; /** The return of a query method in the `collection.query` namespace. */ -export type WeaviateReturn = { +export type WeaviateReturn = { /** The objects that were found by the query. */ - objects: WeaviateObject[]; + objects: WeaviateObject[]; }; -export type GroupByObject = WeaviateObject & { +export type GroupByObject = WeaviateObject & { belongsToGroup: string; }; -export type GroupByResult = { +export type GroupByResult = { name: string; minDistance: number; maxDistance: number; numberOfObjects: number; - objects: WeaviateObject[]; + objects: WeaviateObject[]; }; /** The return of a query method in the `collection.query` namespace where the `groupBy` argument was specified. */ -export type GroupByReturn = { +export type GroupByReturn = { /** The objects that were found by the query. */ - objects: GroupByObject[]; + objects: GroupByObject[]; /** The groups that were created by the query. */ - groups: Record>; + groups: Record>; }; export type GroupByOptions = T extends undefined diff --git a/src/collections/vectors/journey.test.ts b/src/collections/vectors/journey.test.ts new file mode 100644 index 00000000..2b0cd08f --- /dev/null +++ b/src/collections/vectors/journey.test.ts @@ -0,0 +1,159 @@ +import weaviate, { + VectorIndexConfigHNSW, + WeaviateClient, + WeaviateField, + WeaviateGenericObject, +} from '../../index.js'; +import { DbVersion } from '../../utils/dbVersion.js'; +import { Collection } from '../collection/index.js'; +import { MultiVectorType, SingleVectorType } from '../query/types.js'; + +const only = DbVersion.fromString(`v${process.env.WEAVIATE_VERSION!}`).isAtLeast(1, 29, 0) + ? describe + : describe.skip; + +only('Testing of the collection.query methods with a collection with multvectors', () => { + let client: WeaviateClient; + let collection: Collection; + const collectionName = 'TestCollectionQueryWithMultiVectors'; + + let id1: string; + let id2: string; + + let singleVector: SingleVectorType; + let multiVector: MultiVectorType; + + type MyVectors = { + regular: SingleVectorType; + colbert: MultiVectorType; + }; + + afterAll(() => { + return client.collections.delete(collectionName).catch((err) => { + console.error(err); + throw err; + }); + }); + + beforeAll(async () => { + client = await weaviate.connectToLocal(); + collection = client.collections.use(collectionName); + }); + + afterAll(() => client.collections.delete(collectionName)); + + it('should be able to create a collection including multivectors', async () => { + const { hnsw } = weaviate.configure.vectorIndex; + const { multiVector } = weaviate.configure.vectorIndex.multiVector; + collection = await client.collections.create({ + name: collectionName, + vectorizers: [ + weaviate.configure.vectorizer.none({ + name: 'regular', + }), + weaviate.configure.vectorizer.none({ + name: 'colbert', + vectorIndexConfig: hnsw({ + multiVector: multiVector(), + }), + }), + ], + }); + }); + + it('should be able to get the config of the created collection', async () => { + const config = await collection.config.get(); + expect(config.vectorizers.regular).toBeDefined(); + expect(config.vectorizers.colbert).toBeDefined(); + expect((config.vectorizers.regular.indexConfig as VectorIndexConfigHNSW).multiVector).toBeUndefined(); + expect((config.vectorizers.colbert.indexConfig as VectorIndexConfigHNSW).multiVector).toBeDefined(); + }); + + it('should be able to insert one object with multiple multivectors', async () => { + id1 = await collection.data.insert({ + vectors: { + regular: [1, 2, 3, 4], + colbert: [ + [1, 2], + [3, 4], + ], + }, + }); + }); + + it('should be able to get the inserted object with its vectors', async () => { + const obj = await collection.query.fetchObjectById(id1, { includeVector: true }); + const assert = (obj: any): obj is WeaviateGenericObject, MyVectors> => { + expect(obj).not.toBeNull(); + return true; + }; + if (assert(obj)) { + singleVector = obj.vectors.regular; + multiVector = obj.vectors.colbert; + expect(obj.uuid).toBe(id1); + expect(obj.vectors).toBeDefined(); + expect(obj.vectors.regular).toEqual([1, 2, 3, 4]); + expect(obj.vectors.colbert).toEqual([ + [1, 2], + [3, 4], + ]); + } + }); + + it('should be able to query with hybrid for the inserted object over the single vector space', async () => { + const result = await collection.query.hybrid('', { + alpha: 1, + vector: singleVector, + targetVector: ['regular'], + }); + expect(result.objects.length).toBe(1); + expect(result.objects[0].uuid).toBe(id1); + }); + + it('should be able to query with hybrid for the inserted object over the multi vector space', async () => { + const result = await collection.query.hybrid('', { + alpha: 1, + vector: multiVector, + targetVector: ['colbert'], + }); + expect(result.objects.length).toBe(1); + expect(result.objects[0].uuid).toBe(id1); + }); + + it('should be able to query with hybrid for the inserted object over both spaces simultaneously', async () => { + const result = await collection.query.hybrid('', { + alpha: 1, + vector: { regular: singleVector, colbert: multiVector }, + targetVector: collection.multiTargetVector.sum(['regular', 'colbert']), + }); + expect(result.objects.length).toBe(1); + expect(result.objects[0].uuid).toBe(id1); + }); + + it('should be able to query with nearVector for the inserted object over the single vector space', async () => { + const result = await collection.query.nearVector(singleVector, { + certainty: 0.5, + targetVector: ['regular'], + }); + expect(result.objects.length).toBe(1); + expect(result.objects[0].uuid).toBe(id1); + }); + + it('should be able to query with nearVector for the inserted object over the multi vector space', async () => { + const result = await collection.query.nearVector(multiVector, { + certainty: 0.5, + targetVector: ['colbert'], + }); + expect(result.objects.length).toBe(1); + expect(result.objects[0].uuid).toBe(id1); + }); + + it('should be able to query with nearVector for the inserted object over both spaces simultaneously', async () => { + const result = await collection.query.nearVector( + { regular: singleVector, colbert: multiVector }, + { targetVector: collection.multiTargetVector.sum(['regular', 'colbert']) } + ); + expect(result.objects.length).toBe(1); + expect(result.objects[0].uuid).toBe(id1); + }); +}); diff --git a/src/collections/vectors/multiTargetVector.ts b/src/collections/vectors/multiTargetVector.ts index f9d2abdf..2ce41fbb 100644 --- a/src/collections/vectors/multiTargetVector.ts +++ b/src/collections/vectors/multiTargetVector.ts @@ -1,3 +1,5 @@ +import { TargetVector } from '../query/types.js'; + /** The allowed combination methods for multi-target vector joins */ export type MultiTargetVectorJoinCombination = | 'sum' @@ -7,55 +9,63 @@ export type MultiTargetVectorJoinCombination = | 'manual-weights'; /** Weights for each target vector in a multi-target vector join */ -export type MultiTargetVectorWeights = Record; +export type MultiTargetVectorWeights = Partial, number | number[]>>; /** A multi-target vector join used when specifying a vector-based query */ -export type MultiTargetVectorJoin = { +export type MultiTargetVectorJoin = { /** The combination method to use for the target vectors */ combination: MultiTargetVectorJoinCombination; /** The target vectors to combine */ - targetVectors: string[]; + targetVectors: TargetVector[]; /** The weights to use for each target vector */ - weights?: MultiTargetVectorWeights; + weights?: MultiTargetVectorWeights; }; -export default () => { +export default () => { return { - sum: (targetVectors: string[]): MultiTargetVectorJoin => { + sum: []>(targetVectors: T): MultiTargetVectorJoin => { return { combination: 'sum' as MultiTargetVectorJoinCombination, targetVectors }; }, - average: (targetVectors: string[]): MultiTargetVectorJoin => { + average: []>(targetVectors: T): MultiTargetVectorJoin => { return { combination: 'average' as MultiTargetVectorJoinCombination, targetVectors }; }, - minimum: (targetVectors: string[]): MultiTargetVectorJoin => { + minimum: []>(targetVectors: T): MultiTargetVectorJoin => { return { combination: 'minimum' as MultiTargetVectorJoinCombination, targetVectors }; }, - relativeScore: (weights: MultiTargetVectorWeights): MultiTargetVectorJoin => { + relativeScore: []>( + weights: MultiTargetVectorWeights + ): MultiTargetVectorJoin => { return { combination: 'relative-score' as MultiTargetVectorJoinCombination, - targetVectors: Object.keys(weights), + targetVectors: Object.keys(weights) as T, weights, }; }, - manualWeights: (weights: MultiTargetVectorWeights): MultiTargetVectorJoin => { + manualWeights: []>( + weights: MultiTargetVectorWeights + ): MultiTargetVectorJoin => { return { combination: 'manual-weights' as MultiTargetVectorJoinCombination, - targetVectors: Object.keys(weights), + targetVectors: Object.keys(weights) as T, weights, }; }, }; }; -export interface MultiTargetVector { +export interface MultiTargetVector { /** Create a multi-target vector join that sums the vector scores over the target vectors */ - sum: (targetVectors: string[]) => MultiTargetVectorJoin; + sum: []>(targetVectors: T) => MultiTargetVectorJoin; /** Create a multi-target vector join that averages the vector scores over the target vectors */ - average: (targetVectors: string[]) => MultiTargetVectorJoin; + average: []>(targetVectors: T) => MultiTargetVectorJoin; /** Create a multi-target vector join that takes the minimum vector score over the target vectors */ - minimum: (targetVectors: string[]) => MultiTargetVectorJoin; + minimum: []>(targetVectors: T) => MultiTargetVectorJoin; /** Create a multi-target vector join that uses relative weights for each target vector */ - relativeScore: (weights: MultiTargetVectorWeights) => MultiTargetVectorJoin; + relativeScore: []>( + weights: MultiTargetVectorWeights + ) => MultiTargetVectorJoin; /** Create a multi-target vector join that uses manual weights for each target vector */ - manualWeights: (weights: MultiTargetVectorWeights) => MultiTargetVectorJoin; + manualWeights: []>( + weights: MultiTargetVectorWeights + ) => MultiTargetVectorJoin; } diff --git a/src/utils/yield.ts b/src/utils/yield.ts new file mode 100644 index 00000000..cf3a911c --- /dev/null +++ b/src/utils/yield.ts @@ -0,0 +1 @@ +export const yieldToEventLoop = () => new Promise((resolve) => setTimeout(resolve, 0)); From 47e8098b8e94981af1a83c2f1eb9d93cff5870f4 Mon Sep 17 00:00:00 2001 From: Tommy Smith Date: Thu, 13 Feb 2025 17:08:50 +0000 Subject: [PATCH 4/4] Invert parsing of vector type logic to achieve BC --- src/collections/deserialize/index.ts | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/collections/deserialize/index.ts b/src/collections/deserialize/index.ts index 54886eb4..a6f33fe1 100644 --- a/src/collections/deserialize/index.ts +++ b/src/collections/deserialize/index.ts @@ -450,9 +450,9 @@ export class Deserialize { await Promise.all( metadata.vectors.map(async (vector) => [ vector.name, - vector.type === Vectors_VectorType.VECTOR_TYPE_SINGLE_FP32 - ? Deserialize.vectorFromBytes(vector.vectorBytes) - : await Deserialize.vectorsFromBytes(vector.vectorBytes), + vector.type === Vectors_VectorType.VECTOR_TYPE_MULTI_FP32 + ? await Deserialize.vectorsFromBytes(vector.vectorBytes) + : Deserialize.vectorFromBytes(vector.vectorBytes), ]) ) );