diff --git a/ci/docker-compose.yml b/ci/docker-compose.yml index 74746f97..b762e87c 100644 --- a/ci/docker-compose.yml +++ b/ci/docker-compose.yml @@ -21,7 +21,7 @@ services: AUTHENTICATION_ANONYMOUS_ACCESS_ENABLED: 'true' PERSISTENCE_DATA_PATH: '/var/lib/weaviate' DEFAULT_VECTORIZER_MODULE: 'text2vec-contextionary' - ENABLE_MODULES: text2vec-contextionary,backup-filesystem,img2vec-neural + ENABLE_MODULES: text2vec-contextionary,backup-filesystem,img2vec-neural,generative-cohere,reranker-cohere BACKUP_FILESYSTEM_PATH: "/tmp/backups" CLUSTER_GOSSIP_BIND_PORT: "7100" CLUSTER_DATA_BIND_PORT: "7101" diff --git a/src/collections/config/classes.ts b/src/collections/config/classes.ts index 1b023269..7c7fdc65 100644 --- a/src/collections/config/classes.ts +++ b/src/collections/config/classes.ts @@ -3,6 +3,7 @@ import { WeaviateInvalidInputError } from '../../errors.js'; import { WeaviateClass, WeaviateInvertedIndexConfig, + WeaviateModuleConfig, WeaviateMultiTenancyConfig, WeaviateReplicationConfig, WeaviateVectorIndexConfig, @@ -17,7 +18,15 @@ import { VectorIndexConfigFlatUpdate, VectorIndexConfigHNSWUpdate, } from '../configure/types/index.js'; -import { CollectionConfigUpdate, VectorIndexType } from './types/index.js'; +import { + CollectionConfigUpdate, + GenerativeConfig, + GenerativeSearch, + ModuleConfig, + Reranker, + RerankerConfig, + VectorIndexType, +} from './types/index.js'; export class MergeWithExisting { static schema( @@ -27,6 +36,8 @@ export class MergeWithExisting { ): WeaviateClass { if (update === undefined) return current; if (update.description !== undefined) current.description = update.description; + if (update.generative !== undefined) + current.moduleConfig = MergeWithExisting.generative(current.moduleConfig, update.generative); if (update.invertedIndex !== undefined) current.invertedIndexConfig = MergeWithExisting.invertedIndex( current.invertedIndexConfig, @@ -42,6 +53,8 @@ export class MergeWithExisting { current.replicationConfig!, update.replication ); + if (update.reranker !== undefined) + current.moduleConfig = MergeWithExisting.reranker(current.moduleConfig, update.reranker); if (update.vectorizers !== undefined) { if (Array.isArray(update.vectorizers)) { current.vectorConfig = MergeWithExisting.vectors(current.vectorConfig, update.vectorizers); @@ -61,6 +74,35 @@ export class MergeWithExisting { return current; } + static generative( + current: WeaviateModuleConfig, + update: ModuleConfig + ): WeaviateModuleConfig { + if (current === undefined) throw Error('Module config is missing from the class schema.'); + if (update === undefined) return current; + const generative = update.name === 'generative-azure-openai' ? 'generative-openai' : update.name; + const currentGenerative = current[generative] as Record; + current[generative] = { + ...currentGenerative, + ...update.config, + }; + return current; + } + + static reranker( + current: WeaviateModuleConfig, + update: ModuleConfig + ): WeaviateModuleConfig { + if (current === undefined) throw Error('Module config is missing from the class schema.'); + if (update === undefined) return current; + const reranker = current[update.name] as Record; + current[update.name] = { + ...reranker, + ...update.config, + }; + return current; + } + static invertedIndex( current: WeaviateInvertedIndexConfig, update: InvertedIndexConfigUpdate diff --git a/src/collections/config/integration.test.ts b/src/collections/config/integration.test.ts index e930a0ef..254c6112 100644 --- a/src/collections/config/integration.test.ts +++ b/src/collections/config/integration.test.ts @@ -2,8 +2,11 @@ import { WeaviateUnsupportedFeatureError } from '../../errors.js'; import weaviate, { WeaviateClient, weaviateV2 } from '../../index.js'; import { + GenerativeCohereConfig, + ModuleConfig, MultiTenancyConfig, PropertyConfig, + RerankerCohereConfig, VectorIndexConfigDynamic, VectorIndexConfigHNSW, } from './types/index.js'; @@ -621,4 +624,53 @@ describe('Testing of the collection.config namespace', () => { expect(config.vectorizers.default.indexType).toEqual('hnsw'); expect(config.vectorizers.default.vectorizer.name).toEqual('none'); }); + + it('should be able to update the generative & reranker configs of a collection', async () => { + if ((await client.getWeaviateVersion()).isLowerThan(1, 25, 0)) { + console.warn('Skipping test because Weaviate version is lower than 1.25.0'); + return; + } + const collectionName = 'TestCollectionConfigUpdateGenerative'; + const collection = client.collections.get(collectionName); + await client.collections.create({ + name: collectionName, + vectorizers: weaviate.configure.vectorizer.none(), + }); + let config = await collection.config.get(); + expect(config.generative).toBeUndefined(); + + await collection.config.update({ + generative: weaviate.reconfigure.generative.cohere({ + model: 'model', + }), + }); + + config = await collection.config.get(); + expect(config.generative).toEqual>({ + name: 'generative-cohere', + config: { + model: 'model', + }, + }); + + await collection.config.update({ + reranker: weaviate.reconfigure.reranker.cohere({ + model: 'model', + }), + }); + + config = await collection.config.get(); + expect(config.generative).toEqual>({ + name: 'generative-cohere', + config: { + model: 'model', + }, + }); + expect(config.reranker).toEqual>({ + name: 'reranker-cohere', + config: { + model: 'model', + }, + }); + }); }); diff --git a/src/collections/config/types/index.ts b/src/collections/config/types/index.ts index 8e5e1fb2..04630b1c 100644 --- a/src/collections/config/types/index.ts +++ b/src/collections/config/types/index.ts @@ -9,8 +9,8 @@ import { ReplicationConfigUpdate, VectorConfigUpdate, } from '../../configure/types/index.js'; -import { GenerativeConfig } from './generative.js'; -import { RerankerConfig } from './reranker.js'; +import { GenerativeConfig, GenerativeSearch } from './generative.js'; +import { Reranker, RerankerConfig } from './reranker.js'; import { VectorIndexType } from './vectorIndex.js'; import { VectorConfig } from './vectorizer.js'; @@ -93,22 +93,24 @@ export type ShardingConfig = { export type CollectionConfig = { name: string; description?: string; - generative?: GenerativeConfig; + generative?: ModuleConfig; invertedIndex: InvertedIndexConfig; multiTenancy: MultiTenancyConfig; properties: PropertyConfig[]; references: ReferenceConfig[]; replication: ReplicationConfig; - reranker?: RerankerConfig; + reranker?: ModuleConfig; sharding: ShardingConfig; vectorizers: VectorConfig; }; export type CollectionConfigUpdate = { description?: string; + generative?: ModuleConfig; invertedIndex?: InvertedIndexConfigUpdate; multiTenancy?: MultiTenancyConfigUpdate; replication?: ReplicationConfigUpdate; + reranker?: ModuleConfig; vectorizers?: | VectorConfigUpdate | VectorConfigUpdate[]; diff --git a/src/collections/config/unit.test.ts b/src/collections/config/unit.test.ts index fdc83c5d..b8fc8596 100644 --- a/src/collections/config/unit.test.ts +++ b/src/collections/config/unit.test.ts @@ -1,11 +1,15 @@ import { WeaviateInvertedIndexConfig, + WeaviateModuleConfig, WeaviateMultiTenancyConfig, WeaviateVectorsConfig, } from '../../openapi/types'; import { MergeWithExisting } from './classes'; +import { GenerativeCohereConfig, RerankerCohereConfig } from './types'; describe('Unit testing of the MergeWithExisting class', () => { + const deepCopy = (config: any) => JSON.parse(JSON.stringify(config)); + const invertedIndex: WeaviateInvertedIndexConfig = { bm25: { b: 0.8, @@ -62,7 +66,7 @@ describe('Unit testing of the MergeWithExisting class', () => { }; it('should merge a full invertedIndexUpdate with existing schema', () => { - const merged = MergeWithExisting.invertedIndex(JSON.parse(JSON.stringify(invertedIndex)), { + const merged = MergeWithExisting.invertedIndex(deepCopy(invertedIndex), { bm25: { b: 0.9, k1: 1.4, @@ -122,8 +126,20 @@ describe('Unit testing of the MergeWithExisting class', () => { autoTenantCreation: false, }; + const moduleConfig: WeaviateModuleConfig = { + 'generative-cohere': { + kProperty: 0.1, + model: 'model', + maxTokensProperty: '5', + returnLikelihoodsProperty: 'likelihoods', + stopSequencesProperty: ['and'], + temperatureProperty: 5.2, + }, + 'reranker-cohere': {}, + }; + it('should merge a partial invertedIndexUpdate with existing schema', () => { - const merged = MergeWithExisting.invertedIndex(JSON.parse(JSON.stringify(invertedIndex)), { + const merged = MergeWithExisting.invertedIndex(deepCopy(invertedIndex), { bm25: { b: 0.9, }, @@ -147,7 +163,7 @@ describe('Unit testing of the MergeWithExisting class', () => { }); it('should merge a no quantizer HNSW vectorIndexConfig with existing schema', () => { - const merged = MergeWithExisting.vectors(JSON.parse(JSON.stringify(hnswVectorConfig)), [ + const merged = MergeWithExisting.vectors(deepCopy(hnswVectorConfig), [ { name: 'name', vectorIndex: { @@ -196,7 +212,7 @@ describe('Unit testing of the MergeWithExisting class', () => { }); it('should merge a PQ quantizer HNSW vectorIndexConfig with existing schema', () => { - const merged = MergeWithExisting.vectors(JSON.parse(JSON.stringify(hnswVectorConfig)), [ + const merged = MergeWithExisting.vectors(deepCopy(hnswVectorConfig), [ { name: 'name', vectorIndex: { @@ -245,7 +261,7 @@ describe('Unit testing of the MergeWithExisting class', () => { }); it('should merge a BQ quantizer HNSW vectorIndexConfig with existing schema', () => { - const merged = MergeWithExisting.vectors(JSON.parse(JSON.stringify(hnswVectorConfig)), [ + const merged = MergeWithExisting.vectors(deepCopy(hnswVectorConfig), [ { name: 'name', vectorIndex: { @@ -280,7 +296,7 @@ describe('Unit testing of the MergeWithExisting class', () => { }); it('should merge a SQ quantizer HNSW vectorIndexConfig with existing schema', () => { - const merged = MergeWithExisting.vectors(JSON.parse(JSON.stringify(hnswVectorConfig)), [ + const merged = MergeWithExisting.vectors(deepCopy(hnswVectorConfig), [ { name: 'name', vectorIndex: { @@ -317,7 +333,7 @@ describe('Unit testing of the MergeWithExisting class', () => { }); it('should merge a BQ quantizer Flat vectorIndexConfig with existing schema', () => { - const merged = MergeWithExisting.vectors(JSON.parse(JSON.stringify(flatVectorConfig)), [ + const merged = MergeWithExisting.vectors(deepCopy(flatVectorConfig), [ { name: 'name', vectorIndex: { @@ -353,7 +369,7 @@ describe('Unit testing of the MergeWithExisting class', () => { }); it('should merge full multi tenancy config with existing schema', () => { - const merged = MergeWithExisting.multiTenancy(JSON.parse(JSON.stringify(multiTenancyConfig)), { + const merged = MergeWithExisting.multiTenancy(deepCopy(multiTenancyConfig), { autoTenantActivation: true, autoTenantCreation: true, }); @@ -363,4 +379,36 @@ describe('Unit testing of the MergeWithExisting class', () => { autoTenantCreation: true, }); }); + + it('should merge a generative config with existing schema', () => { + const merged = MergeWithExisting.generative(deepCopy(moduleConfig), { + name: 'generative-cohere', + config: { + kProperty: 0.2, + } as GenerativeCohereConfig, + }); + expect(merged).toEqual({ + ...moduleConfig, + 'generative-cohere': { + ...(moduleConfig['generative-cohere'] as any), + kProperty: 0.2, + } as GenerativeCohereConfig, + }); + }); + + it('should merge a reranker config with existing schema', () => { + const merged = MergeWithExisting.reranker(deepCopy(moduleConfig), { + name: 'reranker-cohere', + config: { + model: 'other', + } as RerankerCohereConfig, + }); + expect(merged).toEqual({ + ...moduleConfig, + 'reranker-cohere': { + ...(moduleConfig['reranker-cohere'] as any), + model: 'other', + } as RerankerCohereConfig, + }); + }); }); diff --git a/src/collections/configure/index.ts b/src/collections/configure/index.ts index 9619fe11..7a1ca15c 100644 --- a/src/collections/configure/index.ts +++ b/src/collections/configure/index.ts @@ -261,6 +261,8 @@ const reconfigure = { autoTenantCreation: options.autoTenantCreation, }; }, + generative: configure.generative, + reranker: configure.reranker, }; export {