feat: GenAI SDK client - add zero-shot prompt optimizer: an option to quickly improve or generate system instructions or a prompt. #5589

Open · wants to merge 1 commit into base: main
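For context, a minimal usage sketch of the feature this PR adds. This is an illustration, not part of the diff: it assumes `vertexai.Client` is the GenAI client entry point and uses placeholder project/location values; the streaming behavior follows the generator implemented below.

import vertexai

client = vertexai.Client(project="my-project", location="us-central1")

# optimize_prompt is a generator: iterate to stream optimized-prompt chunks.
for chunk in client.prompt_optimizer.optimize_prompt(
    prompt="Generate system instructions for analyzing medical articles"
):
    if chunk.content is not None and chunk.content.parts:
        print(chunk.content.parts[0].text, end="")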
@@ -0,0 +1,46 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# pylint: disable=protected-access,bad-continuation,missing-function-docstring

import logging

from tests.unit.vertexai.genai.replays import pytest_helper

# from vertexai._genai import types

logger = logging.getLogger("vertexai_genai.promptoptimizer")
logging.basicConfig(encoding="utf-8", level=logging.INFO, force=True)


def test_optimize_prompt(client):
"""Tests the optimize request parameters method."""

# client._api_client._http_options.base_url = (
# "https://us-central1-staging-aiplatform.sandbox.googleapis.com/"
# )
client._api_client._http_options.base_url = (
"https://us-central1-autopush-aiplatform.sandbox.googleapis.com"
)
test_prompt = "Generate system instructions for analyzing medical articles"
response = client.prompt_optimizer.optimize_prompt(prompt=test_prompt)
logger.info("response: %s", response)
# assert isinstance(response, types.OptimizeResponse)


pytestmark = pytest_helper.setup(
file=__file__,
globals_for_file=globals(),
test_method="prompt_optimizer.optimize_prompt",
)
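Note: because `optimize_prompt` is implemented below as a generator, `response` in this test is a generator object, so the `logger.info` call never shows the streamed text. A sketch of draining the stream first, under the same assumptions as the test:

# Hypothetical variant: materialize the stream before logging or asserting.
chunks = list(client.prompt_optimizer.optimize_prompt(prompt=test_prompt))
logger.info("chunks: %s", chunks)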
vertexai/_genai/prompt_optimizer.py (152 changes: 116 additions & 36 deletions)
@@ -19,7 +19,7 @@
import json
import logging
import time
from typing import Any, Optional, Union
from typing import Any, AsyncIterator, Awaitable, Iterator, Optional, Union
from urllib.parse import urlencode

from google.genai import _api_module
@@ -34,11 +34,14 @@
logger = logging.getLogger("vertexai_genai.promptoptimizer")


def _OptimizeRequestParameters_to_vertex(
def _OptimizePromptRequestParams_to_vertex(
from_object: Union[dict[str, Any], object],
parent_object: Optional[dict[str, Any]] = None,
) -> dict[str, Any]:
to_object: dict[str, Any] = {}
if getv(from_object, ["content"]) is not None:
setv(to_object, ["content"], getv(from_object, ["content"]))

if getv(from_object, ["config"]) is not None:
setv(to_object, ["config"], getv(from_object, ["config"]))

@@ -229,6 +232,8 @@ def _OptimizeResponse_from_vertex(
parent_object: Optional[dict[str, Any]] = None,
) -> dict[str, Any]:
to_object: dict[str, Any] = {}
if getv(from_object, ["content"]) is not None:
setv(to_object, ["content"], getv(from_object, ["content"]))

return to_object

@@ -383,25 +388,29 @@ def _CustomJob_from_vertex(
class PromptOptimizer(_api_module.BaseModule):
"""Prompt Optimizer"""

def _optimize_dummy(
self, *, config: Optional[types.OptimizeConfigOrDict] = None
) -> types.OptimizeResponse:
"""Optimize multiple prompts."""
def _optimize_prompt(
self,
*,
content: Optional[types.ContentOrDict] = None,
config: Optional[types.OptimizePromptConfigOrDict] = None,
) -> Iterator[types.OptimizeResponse]:
"""Optimize a single prompt."""

parameter_model = types._OptimizeRequestParameters(
parameter_model = types._OptimizePromptRequestParams(
content=content,
config=config,
)

request_url_dict: Optional[dict[str, str]]
if not self._api_client.vertexai:
raise ValueError("This method is only supported in the Vertex AI client.")
else:
request_dict = _OptimizeRequestParameters_to_vertex(parameter_model)
request_dict = _OptimizePromptRequestParams_to_vertex(parameter_model)
request_url_dict = request_dict.get("_url")
if request_url_dict:
path = ":optimize".format_map(request_url_dict)
path = "tuningJobs:optimizePrompt".format_map(request_url_dict)
else:
path = ":optimize"
path = "tuningJobs:optimizePrompt"

query_params = request_dict.get("_query")
if query_params:
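The path change above retargets the request from the bare ":optimize" suffix to the "tuningJobs:optimizePrompt" RPC. Assuming the usual google-genai Vertex URL construction (the client injects the project/location prefix; the surface version is not confirmed by this diff), the resulting call looks roughly like:

# POST https://{location}-aiplatform.googleapis.com/v1beta1/projects/{project}/locations/{location}/tuningJobs:optimizePrompt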
@@ -419,19 +428,27 @@ def _optimize_dummy
request_dict = _common.convert_to_dict(request_dict)
request_dict = _common.encode_unserializable_types(request_dict)

response = self._api_client.request("post", path, request_dict, http_options)
if config is not None and getattr(config, "should_return_http_response", None):
raise ValueError(
"Accessing the raw HTTP response is not supported in streaming"
" methods."
)

response_dict = "" if not response.body else json.loads(response.body)
for response in self._api_client.request_streamed(
"post", path, request_dict, http_options
):

if self._api_client.vertexai:
response_dict = _OptimizeResponse_from_vertex(response_dict)
response_dict = "" if not response.body else json.loads(response.body)

return_value = types.OptimizeResponse._from_response(
response=response_dict, kwargs=parameter_model.model_dump()
)
if self._api_client.vertexai:
response_dict = _OptimizeResponse_from_vertex(response_dict)

self._api_client._verify_response(return_value)
return return_value
return_value = types.OptimizeResponse._from_response(
response=response_dict, kwargs=parameter_model.model_dump()
)

self._api_client._verify_response(return_value)
yield return_value

def _create_custom_job_resource(
self,
@@ -660,29 +677,81 @@ def optimize(
job = self._wait_for_completion(job_id)
return job

def optimize_prompt(
self,
*,
prompt: str,
config: Optional[types.OptimizePromptConfig] = None,
) -> Iterator[types.OptimizeResponse]:
"""Call PO-zero prompt optimizer for a single prompt.

Args:
prompt: The prompt to optimize.
config: None for now. Support will be added in the future.

Yields:
Chunks of the optimized prompt.
"""
if config is not None:
raise ValueError(
"Currently, config is not supported for a single prompt"
" optimization."
)

prompt = types.Content(parts=[types.Part(text=prompt)], role="user")
remaining_remote_calls = 100
logger.info(
"Prompt Optimizer is enabled with max remote calls: %d.",
remaining_remote_calls,
)
full_response = ""
while remaining_remote_calls > 0:
response = self._optimize_prompt(content=prompt)
logger.info("streaming response: %s", response)
for chunk in response:
yield chunk
logger.info("streaming chunk: %s", chunk)
if (
chunk is not None
and isinstance(chunk.content, types.Content)
and chunk.content.parts is not None
):
full_response += chunk.content.parts[0].text or ""

remaining_remote_calls -= 1
if remaining_remote_calls == 0:
logger.info("Reached max remote calls for prompt optimizer.")
logger.info("response: %s", full_response)

return full_response
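Because the method both yields chunks and returns the aggregated string, the return value only surfaces as `StopIteration.value` once the generator is exhausted. A sketch of retrieving it (placeholder prompt; same client assumptions as above):

gen = client.prompt_optimizer.optimize_prompt(prompt="...")
try:
    while True:
        next(gen)  # stream and discard individual chunks
except StopIteration as stop:
    full_text = stop.value  # the aggregated response string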


class AsyncPromptOptimizer(_api_module.BaseModule):
"""Prompt Optimizer"""

async def _optimize_dummy(
self, *, config: Optional[types.OptimizeConfigOrDict] = None
) -> types.OptimizeResponse:
"""Optimize multiple prompts."""
async def _optimize_prompt(
self,
*,
content: Optional[types.ContentOrDict] = None,
config: Optional[types.OptimizePromptConfigOrDict] = None,
) -> Awaitable[AsyncIterator[types.OptimizeResponse]]:
"""Optimize a single prompt."""

parameter_model = types._OptimizeRequestParameters(
parameter_model = types._OptimizePromptRequestParams(
content=content,
config=config,
)

request_url_dict: Optional[dict[str, str]]
if not self._api_client.vertexai:
raise ValueError("This method is only supported in the Vertex AI client.")
else:
request_dict = _OptimizeRequestParameters_to_vertex(parameter_model)
request_dict = _OptimizePromptRequestParams_to_vertex(parameter_model)
request_url_dict = request_dict.get("_url")
if request_url_dict:
path = ":optimize".format_map(request_url_dict)
path = "tuningJobs:optimizePrompt".format_map(request_url_dict)
else:
path = ":optimize"
path = "tuningJobs:optimizePrompt"

query_params = request_dict.get("_query")
if query_params:
@@ -700,21 +769,32 @@ async def _optimize_dummy(
request_dict = _common.convert_to_dict(request_dict)
request_dict = _common.encode_unserializable_types(request_dict)

response = await self._api_client.async_request(
if config is not None and getattr(config, "should_return_http_response", None):
raise ValueError(
"Accessing the raw HTTP response is not supported in streaming"
" methods."
)

response_stream = await self._api_client.async_request_streamed(
"post", path, request_dict, http_options
)

response_dict = "" if not response.body else json.loads(response.body)
async def async_generator(): # type: ignore[no-untyped-def]
async for response in response_stream:

if self._api_client.vertexai:
response_dict = _OptimizeResponse_from_vertex(response_dict)
response_dict = "" if not response.body else json.loads(response.body)

return_value = types.OptimizeResponse._from_response(
response=response_dict, kwargs=parameter_model.model_dump()
)
if self._api_client.vertexai:
response_dict = _OptimizeResponse_from_vertex(response_dict)

self._api_client._verify_response(return_value)
return return_value
return_value = types.OptimizeResponse._from_response(
response=response_dict, kwargs=parameter_model.model_dump()
)

self._api_client._verify_response(return_value)
yield return_value

return async_generator() # type: ignore[no-untyped-call, no-any-return]
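For completeness, a hedged sketch of consuming the async stream. This diff only adds the private `_optimize_prompt`; the sketch assumes the async module is reachable as `client.aio.prompt_optimizer` (mirroring the google-genai aio layout) and that a public wrapper will follow:

import asyncio

async def main() -> None:
    stream = await client.aio.prompt_optimizer._optimize_prompt(
        content={"role": "user", "parts": [{"text": "Improve this prompt"}]}
    )
    async for chunk in stream:
        if chunk.content is not None and chunk.content.parts:
            print(chunk.content.parts[0].text, end="")

asyncio.run(main())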

async def _create_custom_job_resource(
self,