 'use strict';

-async function main(project, location = 'us-central1') {
-  // [START aiplatform_sdk_embedding]
-  /**
-   * TODO(developer): Uncomment these variables before running the sample.\
-   * (Not necessary if passing values as arguments)
-   */
-  // const project = 'YOUR_PROJECT_ID';
-  // const location = 'YOUR_PROJECT_LOCATION';
+// [START aiplatform_sdk_embedding]
+async function main(
+  project,
+  model = 'textembedding-gecko@003',
+  texts = 'banana bread?;banana muffins?',
+  task = 'RETRIEVAL_DOCUMENT',
+  apiEndpoint = 'us-central1-aiplatform.googleapis.com'
+) {
   const aiplatform = require('@google-cloud/aiplatform');
-
-  // Imports the Google Cloud Prediction service client
   const {PredictionServiceClient} = aiplatform.v1;
-
-  // Import the helper module for converting arbitrary protobuf.Value objects.
-  const {helpers} = aiplatform;
-
-  // Specifies the location of the api endpoint
-  const clientOptions = {
-    apiEndpoint: 'us-central1-aiplatform.googleapis.com',
-  };
-
-  const publisher = 'google';
-  const model = 'textembedding-gecko@001';
-
-  // Instantiates a client
-  const predictionServiceClient = new PredictionServiceClient(clientOptions);
+  const {helpers} = aiplatform; // helps construct protobuf.Value objects.
+  const clientOptions = {apiEndpoint: apiEndpoint};
+  // Derive the region (e.g. 'us-central1') from the API endpoint host.
+  const match = apiEndpoint.match(/(?<Location>\w+-\w+)/);
+  const location = match ? match.groups.Location : 'us-central1';
+  const endpoint = `projects/${project}/locations/${location}/publishers/google/models/${model}`;

   async function callPredict() {
-    // Configure the parent resource
-    const endpoint = `projects/${project}/locations/${location}/publishers/${publisher}/models/${model}`;
-
-    const instance = {
-      content: 'What is life?',
-    };
-    const instanceValue = helpers.toValue(instance);
-    const instances = [instanceValue];
-
-    const parameter = {
-      temperature: 0,
-      maxOutputTokens: 256,
-      topP: 0,
-      topK: 1,
-    };
-    const parameters = helpers.toValue(parameter);
-
-    const request = {
-      endpoint,
-      instances,
-      parameters,
-    };
-
-    // Predict request
-    const [response] = await predictionServiceClient.predict(request);
-    console.log('Get text embeddings response');
+    const instances = texts
+      .split(';')
+      .map(e => helpers.toValue({content: e, taskType: task}));
+    const request = {endpoint, instances};
+    const client = new PredictionServiceClient(clientOptions);
+    const [response] = await client.predict(request);
+    console.log('Got predict response');
     const predictions = response.predictions;
-    console.log('\tPredictions :');
     for (const prediction of predictions) {
-      console.log(`\t\tPrediction : ${JSON.stringify(prediction)}`);
+      const embeddings = prediction.structValue.fields.embeddings;
+      const values = embeddings.structValue.fields.values.listValue.values;
+      console.log('Got prediction: ' + JSON.stringify(values));
     }
   }

   callPredict();
-  // [END aiplatform_sdk_embedding]
 }
+// [END aiplatform_sdk_embedding]

 process.on('unhandledRejection', err => {
   console.error(err.message);