Skip to content

Commit 9aed9e1

Browse files
feat: Elastic Text-Embedding Model demo. (#3650)
1 parent c843b17 commit 9aed9e1

File tree

3 files changed

+107
-63
lines changed

3 files changed

+107
-63
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
/*
2+
* Copyright 2023 Google LLC
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
'use strict';
18+
19+
// [START generativeaionvertexai_sdk_embedding]
20+
async function main(
21+
project,
22+
model = 'text-embedding-preview-0409',
23+
texts = 'banana bread?;banana muffins?',
24+
task = 'QUESTION_ANSWERING',
25+
outputDimensionality = 256,
26+
apiEndpoint = 'us-central1-aiplatform.googleapis.com'
27+
) {
28+
const aiplatform = require('@google-cloud/aiplatform');
29+
const {PredictionServiceClient} = aiplatform.v1;
30+
const {helpers} = aiplatform; // helps construct protobuf.Value objects.
31+
const clientOptions = {apiEndpoint: apiEndpoint};
32+
const match = apiEndpoint.match(/(?<Location>\w+-\w+)/);
33+
const location = match ? match.groups.Location : 'us-centra11';
34+
const endpoint = `projects/${project}/locations/${location}/publishers/google/models/${model}`;
35+
const parameters = helpers.toValue(outputDimensionality);
36+
37+
async function callPredict() {
38+
const instances = texts
39+
.split(';')
40+
.map(e => helpers.toValue({content: e, taskType: task}));
41+
const request = {endpoint, instances, parameters};
42+
const client = new PredictionServiceClient(clientOptions);
43+
const [response] = await client.predict(request);
44+
console.log('Got predict response');
45+
const predictions = response.predictions;
46+
for (const prediction of predictions) {
47+
const embeddings = prediction.structValue.fields.embeddings;
48+
const values = embeddings.structValue.fields.values.listValue.values;
49+
console.log('Got prediction: ' + JSON.stringify(values));
50+
}
51+
}
52+
53+
callPredict();
54+
}
55+
// [END generativeaionvertexai_sdk_embedding]
56+
57+
process.on('unhandledRejection', err => {
58+
console.error(err.message);
59+
process.exitCode = 1;
60+
});
61+
62+
main(...process.argv.slice(2));

ai-platform/snippets/predict-text-embeddings.js

+24-53
Original file line numberDiff line numberDiff line change
@@ -16,70 +16,41 @@
1616

1717
'use strict';
1818

19-
async function main(project, location = 'us-central1') {
20-
// [START aiplatform_sdk_embedding]
21-
/**
22-
* TODO(developer): Uncomment these variables before running the sample.\
23-
* (Not necessary if passing values as arguments)
24-
*/
25-
// const project = 'YOUR_PROJECT_ID';
26-
// const location = 'YOUR_PROJECT_LOCATION';
19+
// [START aiplatform_sdk_embedding]
20+
async function main(
21+
project,
22+
model = 'textembedding-gecko@003',
23+
texts = 'banana bread?;banana muffins?',
24+
task = 'RETRIEVAL_DOCUMENT',
25+
apiEndpoint = 'us-central1-aiplatform.googleapis.com'
26+
) {
2727
const aiplatform = require('@google-cloud/aiplatform');
28-
29-
// Imports the Google Cloud Prediction service client
3028
const {PredictionServiceClient} = aiplatform.v1;
31-
32-
// Import the helper module for converting arbitrary protobuf.Value objects.
33-
const {helpers} = aiplatform;
34-
35-
// Specifies the location of the api endpoint
36-
const clientOptions = {
37-
apiEndpoint: 'us-central1-aiplatform.googleapis.com',
38-
};
39-
40-
const publisher = 'google';
41-
const model = 'textembedding-gecko@001';
42-
43-
// Instantiates a client
44-
const predictionServiceClient = new PredictionServiceClient(clientOptions);
29+
const {helpers} = aiplatform; // helps construct protobuf.Value objects.
30+
const clientOptions = {apiEndpoint: apiEndpoint};
31+
const match = apiEndpoint.match(/(?<Location>\w+-\w+)/);
32+
const location = match ? match.groups.Location : 'us-centra11';
33+
const endpoint = `projects/${project}/locations/${location}/publishers/google/models/${model}`;
4534

4635
async function callPredict() {
47-
// Configure the parent resource
48-
const endpoint = `projects/${project}/locations/${location}/publishers/${publisher}/models/${model}`;
49-
50-
const instance = {
51-
content: 'What is life?',
52-
};
53-
const instanceValue = helpers.toValue(instance);
54-
const instances = [instanceValue];
55-
56-
const parameter = {
57-
temperature: 0,
58-
maxOutputTokens: 256,
59-
topP: 0,
60-
topK: 1,
61-
};
62-
const parameters = helpers.toValue(parameter);
63-
64-
const request = {
65-
endpoint,
66-
instances,
67-
parameters,
68-
};
69-
70-
// Predict request
71-
const [response] = await predictionServiceClient.predict(request);
72-
console.log('Get text embeddings response');
36+
const instances = texts
37+
.split(';')
38+
.map(e => helpers.toValue({content: e, taskType: task}));
39+
const request = {endpoint, instances};
40+
const client = new PredictionServiceClient(clientOptions);
41+
const [response] = await client.predict(request);
42+
console.log('Got predict response');
7343
const predictions = response.predictions;
74-
console.log('\tPredictions :');
7544
for (const prediction of predictions) {
76-
console.log(`\t\tPrediction : ${JSON.stringify(prediction)}`);
45+
const embeddings = prediction.structValue.fields.embeddings;
46+
const values = embeddings.structValue.fields.values.listValue.values;
47+
console.log('Got prediction: ' + JSON.stringify(values));
7748
}
7849
}
7950

8051
callPredict();
81-
// [END aiplatform_sdk_embedding]
8252
}
53+
// [END aiplatform_sdk_embedding]
8354

8455
process.on('unhandledRejection', err => {
8556
console.error(err.message);
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,14 @@
11
/*
22
* Copyright 2023 Google LLC
33
*
4-
* Licensed under the Apache License, Version 2.0 (the "License");
4+
* Licensed under the Apache License, Version 2.0 (the 'License');
55
* you may not use this file except in compliance with the License.
66
* You may obtain a copy of the License at
77
*
88
* https://www.apache.org/licenses/LICENSE-2.0
99
*
1010
* Unless required by applicable law or agreed to in writing, software
11-
* distributed under the License is distributed on an "AS IS" BASIS,
11+
* distributed under the License is distributed on an 'AS IS' BASIS,
1212
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1313
* See the License for the specific language governing permissions and
1414
* limitations under the License.
@@ -25,16 +25,27 @@ const execSync = cmd => cp.execSync(cmd, {encoding: 'utf-8'});
2525
const cwd = path.join(__dirname, '..');
2626

2727
const project = process.env.CAIP_PROJECT_ID;
28-
const location = 'us-central1';
28+
const texts = [
29+
'banana bread?',
30+
'banana muffin?',
31+
'banana?',
32+
'recipe?',
33+
'muffin recipe?',
34+
].join(';');
2935

30-
describe('AI platform predict text embeddings', () => {
31-
it('should make predictions using a large language model', async () => {
36+
describe('predict text embeddings', () => {
37+
it('should get text embeddings using the latest model', async () => {
3238
const stdout = execSync(
33-
`node ./predict-text-embeddings.js ${project} ${location}`,
34-
{
35-
cwd,
36-
}
39+
`node ./predict-text-embeddings.js ${project} textembedding-gecko@003 '${texts}' RETRIEVAL_DOCUMENT`,
40+
{cwd}
3741
);
38-
assert.match(stdout, /Get text embeddings response/);
42+
assert.match(stdout, /Got predict response/);
43+
});
44+
it('should get text embeddings using the preview model', async () => {
45+
const stdout = execSync(
46+
`node ./predict-text-embeddings-preview.js ${project} text-embedding-preview-0409 '${texts}' QUESTION_ANSWERING 256`,
47+
{cwd}
48+
);
49+
assert.match(stdout, /Got predict response/);
3950
});
4051
});

0 commit comments

Comments
 (0)