Skip to content

Commit

Permalink
Merge pull request #249 from weaviate/#222/support-updating-generativ…
Browse files Browse the repository at this point in the history
…e-and-reranker

Improve generative/reranker config UX
  • Loading branch information
tsmith023 authored Jan 13, 2025
2 parents ef950c4 + 892ec12 commit 3164cfd
Show file tree
Hide file tree
Showing 6 changed files with 160 additions and 14 deletions.
2 changes: 1 addition & 1 deletion ci/docker-compose.yml
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ services:
AUTHENTICATION_ANONYMOUS_ACCESS_ENABLED: 'true'
PERSISTENCE_DATA_PATH: '/var/lib/weaviate'
DEFAULT_VECTORIZER_MODULE: 'text2vec-contextionary'
ENABLE_MODULES: text2vec-contextionary,backup-filesystem,img2vec-neural
ENABLE_MODULES: text2vec-contextionary,backup-filesystem,img2vec-neural,generative-cohere,reranker-cohere
BACKUP_FILESYSTEM_PATH: "/tmp/backups"
CLUSTER_GOSSIP_BIND_PORT: "7100"
CLUSTER_DATA_BIND_PORT: "7101"
Expand Down
44 changes: 43 additions & 1 deletion src/collections/config/classes.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ import { WeaviateInvalidInputError } from '../../errors.js';
import {
WeaviateClass,
WeaviateInvertedIndexConfig,
WeaviateModuleConfig,
WeaviateMultiTenancyConfig,
WeaviateReplicationConfig,
WeaviateVectorIndexConfig,
Expand All @@ -17,7 +18,15 @@ import {
VectorIndexConfigFlatUpdate,
VectorIndexConfigHNSWUpdate,
} from '../configure/types/index.js';
import { CollectionConfigUpdate, VectorIndexType } from './types/index.js';
import {
CollectionConfigUpdate,
GenerativeConfig,
GenerativeSearch,
ModuleConfig,
Reranker,
RerankerConfig,
VectorIndexType,
} from './types/index.js';

export class MergeWithExisting {
static schema(
Expand All @@ -27,6 +36,8 @@ export class MergeWithExisting {
): WeaviateClass {
if (update === undefined) return current;
if (update.description !== undefined) current.description = update.description;
if (update.generative !== undefined)
current.moduleConfig = MergeWithExisting.generative(current.moduleConfig, update.generative);
if (update.invertedIndex !== undefined)
current.invertedIndexConfig = MergeWithExisting.invertedIndex(
current.invertedIndexConfig,
Expand All @@ -42,6 +53,8 @@ export class MergeWithExisting {
current.replicationConfig!,
update.replication
);
if (update.reranker !== undefined)
current.moduleConfig = MergeWithExisting.reranker(current.moduleConfig, update.reranker);
if (update.vectorizers !== undefined) {
if (Array.isArray(update.vectorizers)) {
current.vectorConfig = MergeWithExisting.vectors(current.vectorConfig, update.vectorizers);
Expand All @@ -61,6 +74,35 @@ export class MergeWithExisting {
return current;
}

static generative(
current: WeaviateModuleConfig,
update: ModuleConfig<GenerativeSearch, GenerativeConfig>
): WeaviateModuleConfig {
if (current === undefined) throw Error('Module config is missing from the class schema.');
if (update === undefined) return current;
const generative = update.name === 'generative-azure-openai' ? 'generative-openai' : update.name;
const currentGenerative = current[generative] as Record<string, any>;
current[generative] = {
...currentGenerative,
...update.config,
};
return current;
}

static reranker(
current: WeaviateModuleConfig,
update: ModuleConfig<Reranker, RerankerConfig>
): WeaviateModuleConfig {
if (current === undefined) throw Error('Module config is missing from the class schema.');
if (update === undefined) return current;
const reranker = current[update.name] as Record<string, any>;
current[update.name] = {
...reranker,
...update.config,
};
return current;
}

static invertedIndex(
current: WeaviateInvertedIndexConfig,
update: InvertedIndexConfigUpdate
Expand Down
52 changes: 52 additions & 0 deletions src/collections/config/integration.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,11 @@
import { WeaviateUnsupportedFeatureError } from '../../errors.js';
import weaviate, { WeaviateClient, weaviateV2 } from '../../index.js';
import {
GenerativeCohereConfig,
ModuleConfig,
MultiTenancyConfig,
PropertyConfig,
RerankerCohereConfig,
VectorIndexConfigDynamic,
VectorIndexConfigHNSW,
} from './types/index.js';
Expand Down Expand Up @@ -621,4 +624,53 @@ describe('Testing of the collection.config namespace', () => {
expect(config.vectorizers.default.indexType).toEqual('hnsw');
expect(config.vectorizers.default.vectorizer.name).toEqual('none');
});

it('should be able to update the generative & reranker configs of a collection', async () => {
if ((await client.getWeaviateVersion()).isLowerThan(1, 25, 0)) {
console.warn('Skipping test because Weaviate version is lower than 1.25.0');
return;
}
const collectionName = 'TestCollectionConfigUpdateGenerative';
const collection = client.collections.get(collectionName);
await client.collections.create({
name: collectionName,
vectorizers: weaviate.configure.vectorizer.none(),
});
let config = await collection.config.get();
expect(config.generative).toBeUndefined();

await collection.config.update({
generative: weaviate.reconfigure.generative.cohere({
model: 'model',
}),
});

config = await collection.config.get();
expect(config.generative).toEqual<ModuleConfig<'generative-cohere', GenerativeCohereConfig>>({
name: 'generative-cohere',
config: {
model: 'model',
},
});

await collection.config.update({
reranker: weaviate.reconfigure.reranker.cohere({
model: 'model',
}),
});

config = await collection.config.get();
expect(config.generative).toEqual<ModuleConfig<'generative-cohere', GenerativeCohereConfig>>({
name: 'generative-cohere',
config: {
model: 'model',
},
});
expect(config.reranker).toEqual<ModuleConfig<'reranker-cohere', RerankerCohereConfig>>({
name: 'reranker-cohere',
config: {
model: 'model',
},
});
});
});
10 changes: 6 additions & 4 deletions src/collections/config/types/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@ import {
ReplicationConfigUpdate,
VectorConfigUpdate,
} from '../../configure/types/index.js';
import { GenerativeConfig } from './generative.js';
import { RerankerConfig } from './reranker.js';
import { GenerativeConfig, GenerativeSearch } from './generative.js';
import { Reranker, RerankerConfig } from './reranker.js';
import { VectorIndexType } from './vectorIndex.js';
import { VectorConfig } from './vectorizer.js';

Expand Down Expand Up @@ -93,22 +93,24 @@ export type ShardingConfig = {
export type CollectionConfig = {
name: string;
description?: string;
generative?: GenerativeConfig;
generative?: ModuleConfig<GenerativeSearch, GenerativeConfig>;
invertedIndex: InvertedIndexConfig;
multiTenancy: MultiTenancyConfig;
properties: PropertyConfig[];
references: ReferenceConfig[];
replication: ReplicationConfig;
reranker?: RerankerConfig;
reranker?: ModuleConfig<Reranker, RerankerConfig>;
sharding: ShardingConfig;
vectorizers: VectorConfig;
};

export type CollectionConfigUpdate = {
description?: string;
generative?: ModuleConfig<GenerativeSearch, GenerativeConfig>;
invertedIndex?: InvertedIndexConfigUpdate;
multiTenancy?: MultiTenancyConfigUpdate;
replication?: ReplicationConfigUpdate;
reranker?: ModuleConfig<Reranker, RerankerConfig>;
vectorizers?:
| VectorConfigUpdate<undefined, VectorIndexType>
| VectorConfigUpdate<string, VectorIndexType>[];
Expand Down
64 changes: 56 additions & 8 deletions src/collections/config/unit.test.ts
Original file line number Diff line number Diff line change
@@ -1,11 +1,15 @@
import {
WeaviateInvertedIndexConfig,
WeaviateModuleConfig,
WeaviateMultiTenancyConfig,
WeaviateVectorsConfig,
} from '../../openapi/types';
import { MergeWithExisting } from './classes';
import { GenerativeCohereConfig, RerankerCohereConfig } from './types';

describe('Unit testing of the MergeWithExisting class', () => {
const deepCopy = (config: any) => JSON.parse(JSON.stringify(config));

const invertedIndex: WeaviateInvertedIndexConfig = {
bm25: {
b: 0.8,
Expand Down Expand Up @@ -62,7 +66,7 @@ describe('Unit testing of the MergeWithExisting class', () => {
};

it('should merge a full invertedIndexUpdate with existing schema', () => {
const merged = MergeWithExisting.invertedIndex(JSON.parse(JSON.stringify(invertedIndex)), {
const merged = MergeWithExisting.invertedIndex(deepCopy(invertedIndex), {
bm25: {
b: 0.9,
k1: 1.4,
Expand Down Expand Up @@ -122,8 +126,20 @@ describe('Unit testing of the MergeWithExisting class', () => {
autoTenantCreation: false,
};

const moduleConfig: WeaviateModuleConfig = {
'generative-cohere': {
kProperty: 0.1,
model: 'model',
maxTokensProperty: '5',
returnLikelihoodsProperty: 'likelihoods',
stopSequencesProperty: ['and'],
temperatureProperty: 5.2,
},
'reranker-cohere': {},
};

it('should merge a partial invertedIndexUpdate with existing schema', () => {
const merged = MergeWithExisting.invertedIndex(JSON.parse(JSON.stringify(invertedIndex)), {
const merged = MergeWithExisting.invertedIndex(deepCopy(invertedIndex), {
bm25: {
b: 0.9,
},
Expand All @@ -147,7 +163,7 @@ describe('Unit testing of the MergeWithExisting class', () => {
});

it('should merge a no quantizer HNSW vectorIndexConfig with existing schema', () => {
const merged = MergeWithExisting.vectors(JSON.parse(JSON.stringify(hnswVectorConfig)), [
const merged = MergeWithExisting.vectors(deepCopy(hnswVectorConfig), [
{
name: 'name',
vectorIndex: {
Expand Down Expand Up @@ -196,7 +212,7 @@ describe('Unit testing of the MergeWithExisting class', () => {
});

it('should merge a PQ quantizer HNSW vectorIndexConfig with existing schema', () => {
const merged = MergeWithExisting.vectors(JSON.parse(JSON.stringify(hnswVectorConfig)), [
const merged = MergeWithExisting.vectors(deepCopy(hnswVectorConfig), [
{
name: 'name',
vectorIndex: {
Expand Down Expand Up @@ -245,7 +261,7 @@ describe('Unit testing of the MergeWithExisting class', () => {
});

it('should merge a BQ quantizer HNSW vectorIndexConfig with existing schema', () => {
const merged = MergeWithExisting.vectors(JSON.parse(JSON.stringify(hnswVectorConfig)), [
const merged = MergeWithExisting.vectors(deepCopy(hnswVectorConfig), [
{
name: 'name',
vectorIndex: {
Expand Down Expand Up @@ -280,7 +296,7 @@ describe('Unit testing of the MergeWithExisting class', () => {
});

it('should merge a SQ quantizer HNSW vectorIndexConfig with existing schema', () => {
const merged = MergeWithExisting.vectors(JSON.parse(JSON.stringify(hnswVectorConfig)), [
const merged = MergeWithExisting.vectors(deepCopy(hnswVectorConfig), [
{
name: 'name',
vectorIndex: {
Expand Down Expand Up @@ -317,7 +333,7 @@ describe('Unit testing of the MergeWithExisting class', () => {
});

it('should merge a BQ quantizer Flat vectorIndexConfig with existing schema', () => {
const merged = MergeWithExisting.vectors(JSON.parse(JSON.stringify(flatVectorConfig)), [
const merged = MergeWithExisting.vectors(deepCopy(flatVectorConfig), [
{
name: 'name',
vectorIndex: {
Expand Down Expand Up @@ -353,7 +369,7 @@ describe('Unit testing of the MergeWithExisting class', () => {
});

it('should merge full multi tenancy config with existing schema', () => {
const merged = MergeWithExisting.multiTenancy(JSON.parse(JSON.stringify(multiTenancyConfig)), {
const merged = MergeWithExisting.multiTenancy(deepCopy(multiTenancyConfig), {
autoTenantActivation: true,
autoTenantCreation: true,
});
Expand All @@ -363,4 +379,36 @@ describe('Unit testing of the MergeWithExisting class', () => {
autoTenantCreation: true,
});
});

it('should merge a generative config with existing schema', () => {
const merged = MergeWithExisting.generative(deepCopy(moduleConfig), {
name: 'generative-cohere',
config: {
kProperty: 0.2,
} as GenerativeCohereConfig,
});
expect(merged).toEqual({
...moduleConfig,
'generative-cohere': {
...(moduleConfig['generative-cohere'] as any),
kProperty: 0.2,
} as GenerativeCohereConfig,
});
});

it('should merge a reranker config with existing schema', () => {
const merged = MergeWithExisting.reranker(deepCopy(moduleConfig), {
name: 'reranker-cohere',
config: {
model: 'other',
} as RerankerCohereConfig,
});
expect(merged).toEqual({
...moduleConfig,
'reranker-cohere': {
...(moduleConfig['reranker-cohere'] as any),
model: 'other',
} as RerankerCohereConfig,
});
});
});
2 changes: 2 additions & 0 deletions src/collections/configure/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -261,6 +261,8 @@ const reconfigure = {
autoTenantCreation: options.autoTenantCreation,
};
},
generative: configure.generative,
reranker: configure.reranker,
};

export {
Expand Down

0 comments on commit 3164cfd

Please sign in to comment.