Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for primitive multi-vectors (colbert) #264

Open
wants to merge 6 commits into
base: dev/1.29
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion package.json
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,8 @@
"lint-staged": {
"*.{ts,js}": [
"npm run format:check",
"npm run lint -- --cache"
"npm run lint -- --cache",
"npm run prepack"
]
}
}
86 changes: 44 additions & 42 deletions src/collections/aggregate/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import { WeaviateInvalidInputError, WeaviateQueryError } from '../../errors.js';
import { Aggregator } from '../../graphql/index.js';
import { PrimitiveKeys, toBase64FromMedia } from '../../index.js';
import { Deserialize } from '../deserialize/index.js';
import { Bm25QueryProperty, NearVectorInputType } from '../query/types.js';
import { Bm25QueryProperty, NearVectorInputType, TargetVector } from '../query/types.js';
import { NearVectorInputGuards } from '../query/utils.js';
import { Serialize } from '../serialize/index.js';

Expand All @@ -31,27 +31,27 @@ export type GroupByAggregate<T> = {

export type AggregateOverAllOptions<M> = AggregateBaseOptions<M>;

export type AggregateNearOptions<M> = AggregateBaseOptions<M> & {
export type AggregateNearOptions<M, V> = AggregateBaseOptions<M> & {
certainty?: number;
distance?: number;
objectLimit?: number;
targetVector?: string;
targetVector?: TargetVector<V>;
};

export type AggregateHybridOptions<T, M> = AggregateBaseOptions<M> & {
export type AggregateHybridOptions<T, M, V> = AggregateBaseOptions<M> & {
alpha?: number;
maxVectorDistance?: number;
objectLimit?: number;
queryProperties?: (PrimitiveKeys<T> | Bm25QueryProperty<T>)[];
targetVector?: string;
targetVector?: TargetVector<V>;
vector?: number[];
};

export type AggregateGroupByHybridOptions<T, M> = AggregateHybridOptions<T, M> & {
export type AggregateGroupByHybridOptions<T, M, V> = AggregateHybridOptions<T, M, V> & {
groupBy: PropertyOf<T> | GroupByAggregate<T>;
};

export type AggregateGroupByNearOptions<T, M> = AggregateNearOptions<M> & {
export type AggregateGroupByNearOptions<T, M, V> = AggregateNearOptions<M, V> & {
groupBy: PropertyOf<T> | GroupByAggregate<T>;
};

Expand Down Expand Up @@ -346,9 +346,9 @@ export type AggregateGroupByResult<
};
};

class AggregateManager<T> implements Aggregate<T> {
class AggregateManager<T, V> implements Aggregate<T, V> {
connection: Connection;
groupBy: AggregateGroupBy<T>;
groupBy: AggregateGroupBy<T, V>;
name: string;
dbVersionSupport: DbVersionSupport;
consistencyLevel?: ConsistencyLevel;
Expand All @@ -373,14 +373,14 @@ class AggregateManager<T> implements Aggregate<T> {
this.groupBy = {
hybrid: async <M extends PropertiesMetrics<T>>(
query: string,
opts: AggregateGroupByHybridOptions<T, M>
opts: AggregateGroupByHybridOptions<T, M, V>
): Promise<AggregateGroupByResult<T, M>[]> => {
if (await this.grpcChecker) {
const group = typeof opts.groupBy === 'string' ? { property: opts.groupBy } : opts.groupBy;
return this.grpc()
.then((aggregate) =>
.then(async (aggregate) =>
aggregate.withHybrid({
...Serialize.aggregate.hybrid(query, opts),
...(await Serialize.aggregate.hybrid(query, opts)),
groupBy: Serialize.aggregate.groupBy(group),
limit: group.limit,
})
Expand All @@ -402,7 +402,7 @@ class AggregateManager<T> implements Aggregate<T> {
},
nearImage: async <M extends PropertiesMetrics<T>>(
image: string | Buffer,
opts: AggregateGroupByNearOptions<T, M>
opts: AggregateGroupByNearOptions<T, M, V>
): Promise<AggregateGroupByResult<T, M>[]> => {
const [b64, usesGrpc] = await Promise.all([await toBase64FromMedia(image), await this.grpcChecker]);
if (usesGrpc) {
Expand Down Expand Up @@ -430,7 +430,7 @@ class AggregateManager<T> implements Aggregate<T> {
},
nearObject: async <M extends PropertiesMetrics<T>>(
id: string,
opts: AggregateGroupByNearOptions<T, M>
opts: AggregateGroupByNearOptions<T, M, V>
): Promise<AggregateGroupByResult<T, M>[]> => {
if (await this.grpcChecker) {
const group = typeof opts.groupBy === 'string' ? { property: opts.groupBy } : opts.groupBy;
Expand All @@ -457,7 +457,7 @@ class AggregateManager<T> implements Aggregate<T> {
},
nearText: async <M extends PropertiesMetrics<T>>(
query: string | string[],
opts: AggregateGroupByNearOptions<T, M>
opts: AggregateGroupByNearOptions<T, M, V>
): Promise<AggregateGroupByResult<T, M>[]> => {
if (await this.grpcChecker) {
const group = typeof opts.groupBy === 'string' ? { property: opts.groupBy } : opts.groupBy;
Expand All @@ -484,14 +484,14 @@ class AggregateManager<T> implements Aggregate<T> {
},
nearVector: async <M extends PropertiesMetrics<T>>(
vector: number[],
opts: AggregateGroupByNearOptions<T, M>
opts: AggregateGroupByNearOptions<T, M, V>
): Promise<AggregateGroupByResult<T, M>[]> => {
if (await this.grpcChecker) {
const group = typeof opts.groupBy === 'string' ? { property: opts.groupBy } : opts.groupBy;
return this.grpc()
.then((aggregate) =>
.then(async (aggregate) =>
aggregate.withNearVector({
...Serialize.aggregate.nearVector(vector, opts),
...(await Serialize.aggregate.nearVector(vector, opts)),
groupBy: Serialize.aggregate.groupBy(group),
limit: group.limit,
})
Expand Down Expand Up @@ -593,23 +593,23 @@ class AggregateManager<T> implements Aggregate<T> {
return `${propertyName} { ${body} }`;
}

static use<T>(
static use<T, V>(
connection: Connection,
name: string,
dbVersionSupport: DbVersionSupport,
consistencyLevel?: ConsistencyLevel,
tenant?: string
): AggregateManager<T> {
return new AggregateManager<T>(connection, name, dbVersionSupport, consistencyLevel, tenant);
): AggregateManager<T, V> {
return new AggregateManager<T, V>(connection, name, dbVersionSupport, consistencyLevel, tenant);
}

async hybrid<M extends PropertiesMetrics<T>>(
query: string,
opts?: AggregateHybridOptions<T, M>
opts?: AggregateHybridOptions<T, M, V>
): Promise<AggregateResult<T, M>> {
if (await this.grpcChecker) {
return this.grpc()
.then((aggregate) => aggregate.withHybrid(Serialize.aggregate.hybrid(query, opts)))
.then(async (aggregate) => aggregate.withHybrid(await Serialize.aggregate.hybrid(query, opts)))
.then((reply) => Deserialize.aggregate(reply));
}
let builder = this.base(opts?.returnMetrics, opts?.filters).withHybrid({
Expand All @@ -628,7 +628,7 @@ class AggregateManager<T> implements Aggregate<T> {

async nearImage<M extends PropertiesMetrics<T>>(
image: string | Buffer,
opts?: AggregateNearOptions<M>
opts?: AggregateNearOptions<M, V>
): Promise<AggregateResult<T, M>> {
const [b64, usesGrpc] = await Promise.all([await toBase64FromMedia(image), await this.grpcChecker]);
if (usesGrpc) {
Expand All @@ -650,7 +650,7 @@ class AggregateManager<T> implements Aggregate<T> {

async nearObject<M extends PropertiesMetrics<T>>(
id: string,
opts?: AggregateNearOptions<M>
opts?: AggregateNearOptions<M, V>
): Promise<AggregateResult<T, M>> {
if (await this.grpcChecker) {
return this.grpc()
Expand All @@ -671,7 +671,7 @@ class AggregateManager<T> implements Aggregate<T> {

async nearText<M extends PropertiesMetrics<T>>(
query: string | string[],
opts?: AggregateNearOptions<M>
opts?: AggregateNearOptions<M, V>
): Promise<AggregateResult<T, M>> {
if (await this.grpcChecker) {
return this.grpc()
Expand All @@ -692,14 +692,16 @@ class AggregateManager<T> implements Aggregate<T> {

async nearVector<M extends PropertiesMetrics<T>>(
vector: NearVectorInputType,
opts?: AggregateNearOptions<M>
opts?: AggregateNearOptions<M, V>
): Promise<AggregateResult<T, M>> {
if (await this.grpcChecker) {
return this.grpc()
.then((aggregate) => aggregate.withNearVector(Serialize.aggregate.nearVector(vector, opts)))
.then(async (aggregate) =>
aggregate.withNearVector(await Serialize.aggregate.nearVector(vector, opts))
)
.then((reply) => Deserialize.aggregate(reply));
}
if (!NearVectorInputGuards.is1DArray(vector)) {
if (!NearVectorInputGuards.is1D(vector)) {
throw new WeaviateInvalidInputError(
'Vector can only be a 1D array of numbers when using `nearVector` with <1.29 Weaviate versions.'
);
Expand Down Expand Up @@ -768,9 +770,9 @@ class AggregateManager<T> implements Aggregate<T> {
};
}

export interface Aggregate<T> {
export interface Aggregate<T, V> {
/** This namespace contains methods that perform a group-by search while aggregating metrics. */
groupBy: AggregateGroupBy<T>;
groupBy: AggregateGroupBy<T, V>;
/**
* Aggregate metrics over the objects returned by a hybrid search on this collection.
*
Expand All @@ -782,7 +784,7 @@ export interface Aggregate<T> {
*/
hybrid<M extends PropertiesMetrics<T>>(
query: string,
opts?: AggregateHybridOptions<T, M>
opts?: AggregateHybridOptions<T, M, V>
): Promise<AggregateResult<T, M>>;
/**
* Aggregate metrics over the objects returned by a near image vector search on this collection.
Expand All @@ -797,7 +799,7 @@ export interface Aggregate<T> {
*/
nearImage<M extends PropertiesMetrics<T>>(
image: string | Buffer,
opts?: AggregateNearOptions<M>
opts?: AggregateNearOptions<M, V>
): Promise<AggregateResult<T, M>>;
/**
* Aggregate metrics over the objects returned by a near object search on this collection.
Expand All @@ -812,7 +814,7 @@ export interface Aggregate<T> {
*/
nearObject<M extends PropertiesMetrics<T>>(
id: string,
opts?: AggregateNearOptions<M>
opts?: AggregateNearOptions<M, V>
): Promise<AggregateResult<T, M>>;
/**
* Aggregate metrics over the objects returned by a near text vector search on this collection.
Expand All @@ -827,7 +829,7 @@ export interface Aggregate<T> {
*/
nearText<M extends PropertiesMetrics<T>>(
query: string | string[],
opts?: AggregateNearOptions<M>
opts?: AggregateNearOptions<M, V>
): Promise<AggregateResult<T, M>>;
/**
* Aggregate metrics over the objects returned by a near vector search on this collection.
Expand All @@ -842,7 +844,7 @@ export interface Aggregate<T> {
*/
nearVector<M extends PropertiesMetrics<T>>(
vector: number[],
opts?: AggregateNearOptions<M>
opts?: AggregateNearOptions<M, V>
): Promise<AggregateResult<T, M>>;
/**
* Aggregate metrics over all the objects in this collection without any vector search.
Expand All @@ -853,7 +855,7 @@ export interface Aggregate<T> {
overAll<M extends PropertiesMetrics<T>>(opts?: AggregateOverAllOptions<M>): Promise<AggregateResult<T, M>>;
}

export interface AggregateGroupBy<T> {
export interface AggregateGroupBy<T, V> {
/**
* Aggregate metrics over the objects grouped by a specified property and returned by a hybrid search on this collection.
*
Expand All @@ -865,7 +867,7 @@ export interface AggregateGroupBy<T> {
*/
hybrid<M extends PropertiesMetrics<T>>(
query: string,
opts: AggregateGroupByHybridOptions<T, M>
opts: AggregateGroupByHybridOptions<T, M, V>
): Promise<AggregateGroupByResult<T, M>[]>;
/**
* Aggregate metrics over the objects grouped by a specified property and returned by a near image vector search on this collection.
Expand All @@ -880,7 +882,7 @@ export interface AggregateGroupBy<T> {
*/
nearImage<M extends PropertiesMetrics<T>>(
image: string | Buffer,
opts: AggregateGroupByNearOptions<T, M>
opts: AggregateGroupByNearOptions<T, M, V>
): Promise<AggregateGroupByResult<T, M>[]>;
/**
* Aggregate metrics over the objects grouped by a specified property and returned by a near object search on this collection.
Expand All @@ -895,7 +897,7 @@ export interface AggregateGroupBy<T> {
*/
nearObject<M extends PropertiesMetrics<T>>(
id: string,
opts: AggregateGroupByNearOptions<T, M>
opts: AggregateGroupByNearOptions<T, M, V>
): Promise<AggregateGroupByResult<T, M>[]>;
/**
* Aggregate metrics over the objects grouped by a specified property and returned by a near text vector search on this collection.
Expand All @@ -910,7 +912,7 @@ export interface AggregateGroupBy<T> {
*/
nearText<M extends PropertiesMetrics<T>>(
query: string | string[],
opts: AggregateGroupByNearOptions<T, M>
opts: AggregateGroupByNearOptions<T, M, V>
): Promise<AggregateGroupByResult<T, M>[]>;
/**
* Aggregate metrics over the objects grouped by a specified property and returned by a near vector search on this collection.
Expand All @@ -925,7 +927,7 @@ export interface AggregateGroupBy<T> {
*/
nearVector<M extends PropertiesMetrics<T>>(
vector: number[],
opts: AggregateGroupByNearOptions<T, M>
opts: AggregateGroupByNearOptions<T, M, V>
): Promise<AggregateGroupByResult<T, M>[]>;
/**
* Aggregate metrics over all the objects in this collection grouped by a specified property without any vector search.
Expand Down
4 changes: 2 additions & 2 deletions src/collections/aggregate/integration.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -485,7 +485,7 @@ describe('Testing of collection.aggregate search methods', () => {

it('should return an aggregation on a nearVector search', async () => {
const obj = await collection.query.fetchObjectById(uuid, { includeVector: true });
const result = await collection.aggregate.nearVector(obj?.vectors.default!, {
const result = await collection.aggregate.nearVector(obj?.vectors.default as number[], {
objectLimit: 1000,
returnMetrics: collection.metrics.aggregate('text').text(['count']),
});
Expand All @@ -494,7 +494,7 @@ describe('Testing of collection.aggregate search methods', () => {

it('should return a grouped aggregation on a nearVector search', async () => {
const obj = await collection.query.fetchObjectById(uuid, { includeVector: true });
const result = await collection.aggregate.groupBy.nearVector(obj?.vectors.default!, {
const result = await collection.aggregate.groupBy.nearVector(obj?.vectors.default as number[], {
objectLimit: 1000,
groupBy: 'text',
returnMetrics: collection.metrics.aggregate('text').text(['count']),
Expand Down
Loading