
Commit bdb7502

feat(genai): add Gemini batch prediction samples (#3905)
* feat(genai): add Gemini batch prediction samples
* update console log
* update BigQuery output URI
1 parent 7a89075 commit bdb7502

File tree

3 files changed: +282 -0 lines changed
Lines changed: 95 additions & 0 deletions
@@ -0,0 +1,95 @@
/*
 * Copyright 2024 Google LLC
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * https://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

'use strict';

async function main(projectId, outputUri) {
  // [START generativeaionvertexai_batch_predict_gemini_createjob_bigquery]
  // Import the aiplatform library
  const aiplatformLib = require('@google-cloud/aiplatform');
  const aiplatform = aiplatformLib.protos.google.cloud.aiplatform.v1;

  /**
   * TODO(developer): Uncomment/update these variables before running the sample.
   */
  // projectId = 'YOUR_PROJECT_ID';
  // URI of the output BigQuery table.
  // E.g. "bq://[PROJECT].[DATASET].[TABLE]"
  // outputUri = 'bq://projectid.dataset.table';

  // URI of the multimodal input BigQuery table.
  // E.g. "bq://[PROJECT].[DATASET].[TABLE]"
  const inputUri =
    'bq://storage-samples.generative_ai.batch_requests_for_multimodal_input';
  const location = 'us-central1';
  const parent = `projects/${projectId}/locations/${location}`;
  const modelName = `${parent}/publishers/google/models/gemini-1.5-flash-002`;

  // Specify the location of the api endpoint.
  const clientOptions = {
    apiEndpoint: `${location}-aiplatform.googleapis.com`,
  };

  // Instantiate the client.
  const jobServiceClient = new aiplatformLib.JobServiceClient(clientOptions);

  // Create a Gemini batch prediction job using BigQuery input and output datasets.
  async function create_batch_prediction_gemini_bq() {
    const bqSource = new aiplatform.BigQuerySource({
      inputUri: inputUri,
    });

    const inputConfig = new aiplatform.BatchPredictionJob.InputConfig({
      bigquerySource: bqSource,
      instancesFormat: 'bigquery',
    });

    const bqDestination = new aiplatform.BigQueryDestination({
      outputUri: outputUri,
    });

    const outputConfig = new aiplatform.BatchPredictionJob.OutputConfig({
      bigqueryDestination: bqDestination,
      predictionsFormat: 'bigquery',
    });

    const batchPredictionJob = new aiplatform.BatchPredictionJob({
      displayName: 'Batch predict with Gemini - BigQuery',
      model: modelName, // Add model parameters per request in the input BigQuery table.
      inputConfig: inputConfig,
      outputConfig: outputConfig,
    });

    const request = {
      parent: parent,
      batchPredictionJob,
    };

    // Create batch prediction job request
    const [response] = await jobServiceClient.createBatchPredictionJob(request);
    console.log('Response name: ', response.name);
    // Example response:
    // Response name: projects/<project>/locations/us-central1/batchPredictionJobs/<job-id>
  }

  await create_batch_prediction_gemini_bq();
  // [END generativeaionvertexai_batch_predict_gemini_createjob_bigquery]
}

main(...process.argv.slice(2)).catch(err => {
  console.error(err.message);
  process.exitCode = 1;
});
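
The BigQuery sample prints the new job's resource name and returns without waiting for the batch run to finish. A minimal, self-contained polling sketch (not part of this commit; it assumes the resource name printed above and uses the client's getBatchPredictionJob call) might look like this:

'use strict';

const {JobServiceClient, protos} = require('@google-cloud/aiplatform');

// Poll a batch prediction job until it reaches a terminal state.
// `jobName` is the full resource name printed by the sample, e.g.
// projects/<project>/locations/us-central1/batchPredictionJobs/<job-id>
async function waitForBatchJob(jobName, location = 'us-central1') {
  const {JobState} = protos.google.cloud.aiplatform.v1;
  const client = new JobServiceClient({
    apiEndpoint: `${location}-aiplatform.googleapis.com`,
  });

  for (;;) {
    const [job] = await client.getBatchPredictionJob({name: jobName});
    // Depending on the client version, the state may surface as an enum
    // number or its name; normalize to the name before checking.
    const state =
      typeof job.state === 'number' ? JobState[job.state] : job.state;
    console.log('Job state:', state);
    if (
      state === 'JOB_STATE_SUCCEEDED' ||
      state === 'JOB_STATE_FAILED' ||
      state === 'JOB_STATE_CANCELLED'
    ) {
      return job;
    }
    // Batch jobs can take several minutes; wait 30 seconds between polls.
    await new Promise(resolve => setTimeout(resolve, 30000));
  }
}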
Lines changed: 98 additions & 0 deletions
@@ -0,0 +1,98 @@
/*
 * Copyright 2024 Google LLC
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * https://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

'use strict';

async function main(projectId, outputUri) {
  // [START generativeaionvertexai_batch_predict_gemini_createjob_gcs]
  // Import the aiplatform library
  const aiplatformLib = require('@google-cloud/aiplatform');
  const aiplatform = aiplatformLib.protos.google.cloud.aiplatform.v1;

  /**
   * TODO(developer): Uncomment/update these variables before running the sample.
   */
  // projectId = 'YOUR_PROJECT_ID';
  // URI of the output folder in Google Cloud Storage.
  // E.g. "gs://[BUCKET]/[OUTPUT]"
  // outputUri = 'gs://my-bucket';

  // URI of the input file in Google Cloud Storage.
  // E.g. "gs://[BUCKET]/[DATASET].jsonl"
  // Or try:
  // "gs://cloud-samples-data/generative-ai/batch/gemini_multimodal_batch_predict.jsonl"
  // for a batch prediction that uses audio, video, and an image.
  const inputUri =
    'gs://cloud-samples-data/generative-ai/batch/batch_requests_for_multimodal_input.jsonl';
  const location = 'us-central1';
  const parent = `projects/${projectId}/locations/${location}`;
  const modelName = `${parent}/publishers/google/models/gemini-1.5-flash-002`;

  // Specify the location of the api endpoint.
  const clientOptions = {
    apiEndpoint: `${location}-aiplatform.googleapis.com`,
  };

  // Instantiate the client.
  const jobServiceClient = new aiplatformLib.JobServiceClient(clientOptions);

  // Create a Gemini batch prediction job using Google Cloud Storage input and output buckets.
  async function create_batch_prediction_gemini_gcs() {
    const gcsSource = new aiplatform.GcsSource({
      uris: [inputUri],
    });

    const inputConfig = new aiplatform.BatchPredictionJob.InputConfig({
      gcsSource: gcsSource,
      instancesFormat: 'jsonl',
    });

    const gcsDestination = new aiplatform.GcsDestination({
      outputUriPrefix: outputUri,
    });

    const outputConfig = new aiplatform.BatchPredictionJob.OutputConfig({
      gcsDestination: gcsDestination,
      predictionsFormat: 'jsonl',
    });

    const batchPredictionJob = new aiplatform.BatchPredictionJob({
      displayName: 'Batch predict with Gemini - GCS',
      model: modelName,
      inputConfig: inputConfig,
      outputConfig: outputConfig,
    });

    const request = {
      parent: parent,
      batchPredictionJob,
    };

    // Create batch prediction job request
    const [response] = await jobServiceClient.createBatchPredictionJob(request);
    console.log('Response name: ', response.name);
    // Example response:
    // Response name: projects/<project>/locations/us-central1/batchPredictionJobs/<job-id>
  }

  await create_batch_prediction_gemini_gcs();
  // [END generativeaionvertexai_batch_predict_gemini_createjob_gcs]
}

main(...process.argv.slice(2)).catch(err => {
  console.error(err.message);
  process.exitCode = 1;
});
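
The GCS sample expects a JSONL input file in which every line is a standalone JSON object whose request field mirrors a Gemini GenerateContent request. The sketch below writes one such line; the bucket, file names, and exact field shape (contents, parts, fileData with fileUri/mimeType) are illustrative assumptions, so check the batch prediction docs for the model version you use.

'use strict';

const fs = require('fs');

// One request per line; this object becomes a single line of the JSONL file.
const line = {
  request: {
    contents: [
      {
        role: 'user',
        parts: [
          {text: 'Describe what is shown in this image.'},
          {
            // Placeholder bucket/object; point this at your own input file.
            fileData: {
              fileUri: 'gs://my-input-bucket/images/example.jpg',
              mimeType: 'image/jpeg',
            },
          },
        ],
      },
    ],
  },
};

fs.appendFileSync('batch_requests.jsonl', JSON.stringify(line) + '\n');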
Lines changed: 89 additions & 0 deletions
@@ -0,0 +1,89 @@
/*
 * Copyright 2024 Google LLC
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * https://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

'use strict';

const {assert} = require('chai');
const {after, describe, it} = require('mocha');
const cp = require('child_process');
const {JobServiceClient} = require('@google-cloud/aiplatform');

const execSync = cmd => cp.execSync(cmd, {encoding: 'utf-8'});

describe('Batch predict with Gemini', async () => {
  const projectId = process.env.CAIP_PROJECT_ID;
  const outputGCSUri = 'gs://ucaip-samples-test-output/';
  const outputBqUri = `bq://${process.env.CAIP_PROJECT_ID}.gen_ai_batch_prediction.predictions_${Date.now()}`;
  const location = 'us-central1';

  const jobServiceClient = new JobServiceClient({
    apiEndpoint: `${location}-aiplatform.googleapis.com`,
  });
  let batchPredictionGcsJobId;
  let batchPredictionBqJobId;

  after(async () => {
    let name = jobServiceClient.batchPredictionJobPath(
      projectId,
      location,
      batchPredictionGcsJobId
    );
    cancelAndDeleteJob(name);

    name = jobServiceClient.batchPredictionJobPath(
      projectId,
      location,
      batchPredictionBqJobId
    );
    cancelAndDeleteJob(name);

    function cancelAndDeleteJob(name) {
      const cancelRequest = {
        name,
      };

      jobServiceClient.cancelBatchPredictionJob(cancelRequest).then(() => {
        const deleteRequest = {
          name,
        };

        return jobServiceClient.deleteBatchPredictionJob(deleteRequest);
      });
    }
  });

  it('should create Batch prediction Gemini job with GCS', async () => {
    const response = execSync(
      `node ./batch-prediction/batch-predict-gcs.js ${projectId} ${outputGCSUri}`
    );

    assert.match(response, new RegExp('/batchPredictionJobs/'));
    batchPredictionGcsJobId = response
      .split('/locations/us-central1/batchPredictionJobs/')[1]
      .split('\n')[0];
  }).timeout(10000);

  it('should create Batch prediction Gemini job with BigQuery', async () => {
    const response = execSync(
      `node ./batch-prediction/batch-predict-bq.js ${projectId} ${outputBqUri}`
    );

    assert.match(response, new RegExp('/batchPredictionJobs/'));
    batchPredictionBqJobId = response
      .split('/locations/us-central1/batchPredictionJobs/')[1]
      .split('\n')[0];
  }).timeout(10000);
});
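
Note that the after hook above fires cancelBatchPredictionJob/deleteBatchPredictionJob without awaiting them, so cleanup may still be in flight when mocha exits. An awaited variant (a sketch only, reusing the same client, project, location, and job IDs as the test) could look like this:

after(async () => {
  for (const jobId of [batchPredictionGcsJobId, batchPredictionBqJobId]) {
    const name = jobServiceClient.batchPredictionJobPath(
      projectId,
      location,
      jobId
    );
    // Wait for each cancel and delete call to complete before moving on.
    await jobServiceClient.cancelBatchPredictionJob({name});
    await jobServiceClient.deleteBatchPredictionJob({name});
  }
});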
