Skip to content

Commit 9e11408

Browse files
authored
refactor(firestore-vector-search): use Firebase Genkit wherever possible. (#603)
* refactor(firestore-vector-search): use genkit where possible
* chore(firestore-vector-search): update CHANGELOG and bump ext version
* test(firestore-vector-search): fix tests and add more coverage
1 parent 81cfd8d commit 9e11408

File tree

17 files changed

+2631
-518
lines changed

17 files changed

+2631
-518
lines changed

firestore-vector-search/CHANGELOG.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,7 @@
1+
## Version 0.0.6
2+
3+
refactor - use Firebase Genkit where possible
4+
15
## Version 0.0.5
26

37
fix - fix backfill and fix npm audit

firestore-vector-search/extension.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
# limitations under the License.
1414

1515
name: firestore-vector-search
16-
version: 0.0.5
16+
version: 0.0.6
1717
specVersion: v1beta
1818

1919
tags:

firestore-vector-search/functions/__tests__/__snapshots__/config.test.ts.snap

Lines changed: 0 additions & 31 deletions
This file was deleted.
Lines changed: 52 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1,61 +1,84 @@
1-
import {GeminiAITextEmbedClient} from '../../src/embeddings/client/text/gemini';
2-
// import { config } from "../../src/config";
1+
jest.resetModules();
2+
3+
// Mock GoogleGenerativeAI and its methods
4+
const mockGetGenerativeModel = jest.fn();
5+
const mockBatchEmbedContents = jest.fn();
36

4-
// mock config
5-
// jest.mock("../../src/config", () => ({
6-
// ...jest.requireActual("../../src/config"),
7-
// geminiApiKey: "test-api-key",
8-
// }));
7+
jest.mock('@google/generative-ai', () => ({
8+
GoogleGenerativeAI: jest.fn().mockImplementation(() => ({
9+
getGenerativeModel: mockGetGenerativeModel,
10+
})),
11+
}));
12+
13+
jest.mock('../../src/config', () => ({
14+
config: {
15+
geminiApiKey: 'test-api-key',
16+
},
17+
}));
18+
19+
import {GeminiAITextEmbedClient} from '../../src/embeddings/client/text/gemini';
20+
import {GoogleGenerativeAI} from '@google/generative-ai';
921

1022
describe('Gemini Embeddings', () => {
11-
let embedClient;
23+
let embedClient: GeminiAITextEmbedClient;
24+
25+
beforeEach(async () => {
26+
// Reset mocks
27+
jest.clearAllMocks();
1228

13-
beforeEach(() => {
29+
// Mock return value for getGenerativeModel
30+
mockGetGenerativeModel.mockReturnValue({
31+
batchEmbedContents: mockBatchEmbedContents,
32+
});
33+
34+
// Instantiate and initialize the client
1435
embedClient = new GeminiAITextEmbedClient();
36+
await embedClient.initialize();
1537
});
1638

1739
describe('initialize', () => {
1840
test('should properly initialize the client', async () => {
19-
await embedClient.initialize();
20-
2141
expect(embedClient.client).toBeDefined();
22-
// expect(GoogleGenerativeAI).toHaveBeenCalledWith(config.geminiApiKey);
42+
expect(GoogleGenerativeAI).toHaveBeenCalledWith('test-api-key');
2343
});
2444
});
2545

2646
describe('getEmbeddings', () => {
2747
test('should return embeddings for a batch of text', async () => {
28-
const mockEmbedContent = jest
29-
.fn()
30-
.mockResolvedValue({embedding: [1, 2, 3]});
31-
embedClient.client = {
32-
getGenerativeModel: jest.fn(() => ({
33-
embedContent: mockEmbedContent,
34-
})),
35-
};
48+
// Mock batchEmbedContents to resolve with embeddings
49+
mockBatchEmbedContents.mockResolvedValueOnce({
50+
embeddings: [{values: [1, 2, 3]}, {values: [4, 5, 6]}],
51+
});
3652

3753
const batch = ['text1', 'text2'];
3854
const results = await embedClient.getEmbeddings(batch);
3955

40-
expect(mockEmbedContent).toHaveBeenCalledTimes(batch.length);
56+
expect(mockBatchEmbedContents).toHaveBeenCalledWith({
57+
requests: [
58+
{content: {parts: [{text: 'text1'}], role: 'user'}},
59+
{content: {parts: [{text: 'text2'}], role: 'user'}},
60+
],
61+
});
62+
4163
expect(results).toEqual([
4264
[1, 2, 3],
43-
[1, 2, 3],
65+
[4, 5, 6],
4466
]);
4567
});
4668

4769
test('should throw an error if the embedding process fails', async () => {
48-
embedClient.client = {
49-
getGenerativeModel: jest.fn(() => ({
50-
embedContent: jest
51-
.fn()
52-
.mockRejectedValue(new Error('Embedding failed')),
53-
})),
54-
};
70+
// Mock batchEmbedContents to throw an error
71+
mockBatchEmbedContents.mockRejectedValueOnce(
72+
new Error('Embedding failed')
73+
);
5574

5675
await expect(embedClient.getEmbeddings(['text'])).rejects.toThrow(
5776
'Error with embedding'
5877
);
78+
79+
expect(mockBatchEmbedContents).toHaveBeenCalledWith({
80+
requests: [{content: {parts: [{text: 'text'}], role: 'user'}}],
81+
});
5982
});
6083
});
6184
});
Lines changed: 160 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,160 @@
1+
jest.resetModules();
2+
3+
// Mocking `@genkit-ai/googleai` and `@genkit-ai/vertexai`
4+
jest.mock('@genkit-ai/googleai', () => ({
5+
googleAI: jest.fn(),
6+
textEmbeddingGecko001: 'gecko-001-model',
7+
}));
8+
9+
jest.mock('@genkit-ai/vertexai', () => ({
10+
vertexAI: jest.fn(),
11+
textEmbedding004: 'text-embedding-004-model',
12+
}));
13+
14+
jest.mock('../../src/config', () => ({
15+
config: {
16+
geminiApiKey: 'test-api-key',
17+
location: 'us-central1',
18+
},
19+
}));
20+
21+
import {GenkitEmbedClient} from '../../src/embeddings/client/genkit';
22+
import {genkit} from 'genkit';
23+
import {vertexAI} from '@genkit-ai/vertexai';
24+
import {googleAI} from '@genkit-ai/googleai';
25+
26+
// Mock the genkit client with properly structured responses
27+
const mockEmbedMany = jest.fn();
28+
const mockEmbed = jest.fn();
29+
jest.mock('genkit', () => ({
30+
genkit: jest.fn().mockImplementation(() => ({
31+
embedMany: mockEmbedMany,
32+
embed: mockEmbed,
33+
})),
34+
}));
35+
36+
describe('GenkitEmbedClient', () => {
37+
let embedClient: GenkitEmbedClient;
38+
let mockVertexAI: jest.Mock;
39+
let mockGoogleAI: jest.Mock;
40+
41+
beforeEach(() => {
42+
jest.clearAllMocks();
43+
mockVertexAI = vertexAI as jest.Mock;
44+
mockGoogleAI = googleAI as jest.Mock;
45+
});
46+
47+
describe('constructor', () => {
48+
test('should initialize with Vertex AI provider', () => {
49+
embedClient = new GenkitEmbedClient({
50+
provider: 'vertexai',
51+
batchSize: 100,
52+
dimension: 768,
53+
});
54+
55+
expect(embedClient.provider).toBe('vertexai');
56+
expect(embedClient.embedder).toBe('text-embedding-004-model');
57+
expect(mockVertexAI).toHaveBeenCalledWith({
58+
location: 'us-central1',
59+
});
60+
expect(genkit).toHaveBeenCalledWith({
61+
plugins: [undefined], // because the mock returns undefined
62+
});
63+
});
64+
65+
test('should initialize with Google AI provider', () => {
66+
embedClient = new GenkitEmbedClient({
67+
provider: 'googleai',
68+
batchSize: 100,
69+
dimension: 768,
70+
});
71+
72+
expect(embedClient.provider).toBe('googleai');
73+
expect(embedClient.embedder).toBe('gecko-001-model');
74+
expect(mockGoogleAI).toHaveBeenCalledWith({
75+
apiKey: 'test-api-key',
76+
});
77+
expect(genkit).toHaveBeenCalledWith({
78+
plugins: [undefined], // because the mock returns undefined
79+
});
80+
});
81+
});
82+
83+
describe('getEmbeddings', () => {
84+
beforeEach(() => {
85+
embedClient = new GenkitEmbedClient({
86+
provider: 'vertexai',
87+
batchSize: 100,
88+
dimension: 768,
89+
});
90+
});
91+
92+
test('should return embeddings for a batch of inputs', async () => {
93+
const mockResults = [{embedding: [1, 2, 3]}, {embedding: [4, 5, 6]}];
94+
mockEmbedMany.mockResolvedValueOnce(mockResults);
95+
96+
const inputs = ['input1', 'input2'];
97+
const embeddings = await embedClient.getEmbeddings(inputs);
98+
99+
expect(mockEmbedMany).toHaveBeenCalledWith({
100+
embedder: embedClient.embedder,
101+
content: inputs,
102+
});
103+
104+
expect(embeddings).toEqual([
105+
[1, 2, 3],
106+
[4, 5, 6],
107+
]);
108+
});
109+
110+
test('should throw an error if embedding fails', async () => {
111+
mockEmbedMany.mockRejectedValueOnce(new Error('Embedding failed'));
112+
113+
await expect(embedClient.getEmbeddings(['input'])).rejects.toThrow(
114+
'Embedding failed'
115+
);
116+
117+
expect(mockEmbedMany).toHaveBeenCalledWith({
118+
embedder: embedClient.embedder,
119+
content: ['input'],
120+
});
121+
});
122+
});
123+
124+
describe('getSingleEmbedding', () => {
125+
beforeEach(() => {
126+
embedClient = new GenkitEmbedClient({
127+
provider: 'googleai',
128+
batchSize: 100,
129+
dimension: 768,
130+
});
131+
});
132+
133+
test('should return a single embedding for an input', async () => {
134+
mockEmbed.mockResolvedValueOnce([7, 8, 9]); // Changed to return array directly
135+
136+
const input = 'input1';
137+
const embedding = await embedClient.getSingleEmbedding(input);
138+
139+
expect(mockEmbed).toHaveBeenCalledWith({
140+
embedder: embedClient.embedder,
141+
content: input,
142+
});
143+
144+
expect(embedding).toEqual([7, 8, 9]);
145+
});
146+
147+
test('should throw an error if embedding fails', async () => {
148+
mockEmbed.mockRejectedValueOnce(new Error('Embedding failed'));
149+
150+
await expect(embedClient.getSingleEmbedding('input')).rejects.toThrow(
151+
'Embedding failed'
152+
);
153+
154+
expect(mockEmbed).toHaveBeenCalledWith({
155+
embedder: embedClient.embedder,
156+
content: 'input',
157+
});
158+
});
159+
});
160+
});

firestore-vector-search/functions/jest.config.js

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ module.exports = {
77
rootDir: './',
88
globals: {
99
'ts-jest': {
10-
tsConfig: '<rootDir>/__tests__/tsconfig.json',
10+
tsconfig: '<rootDir>/tsconfig.test.json', // Correct reference to test-specific config
1111
},
1212
fetch: global.fetch,
1313
},

0 commit comments

Comments
 (0)