Skip to content

feat: GenAI SDK client - Add support for context specs when creating agent engine instances #5551

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 28 additions & 0 deletions tests/unit/vertexai/genai/replays/test_create_agent_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,34 @@ def test_create_config_lightweight(client):
}


def test_create_with_context_spec(client):
project = "test-project"
location = "us-central1"
parent = f"projects/{project}/locations/{location}"
generation_model = f"{parent}/publishers/google/models/gemini-2.0-flash-001"
embedding_model = f"{parent}/publishers/google/models/text-embedding-005"

agent_engine = client.agent_engines.create(
config={
"context_spec": {
"memory_bank_config": {
"generation_config": {"model": generation_model},
"similarity_search_config": {
"embedding_model": embedding_model,
},
},
},
"http_options": {"api_version": "v1beta1"},
},
)
agent_engine = client.agent_engines.get(name=agent_engine.api_resource.name)
memory_bank_config = agent_engine.api_resource.context_spec.memory_bank_config
assert memory_bank_config.generation_config.model == generation_model
assert (
memory_bank_config.similarity_search_config.embedding_model == embedding_model
)


pytestmark = pytest_helper.setup(
file=__file__,
globals_for_file=globals(),
Expand Down
1 change: 1 addition & 0 deletions tests/unit/vertexai/genai/test_agent_engines.py
Original file line number Diff line number Diff line change
Expand Up @@ -1110,6 +1110,7 @@ def test_create_agent_engine_with_env_vars_dict(
gcs_dir_name=None,
extra_packages=[_TEST_AGENT_ENGINE_EXTRA_PACKAGE_PATH],
env_vars=_TEST_AGENT_ENGINE_ENV_VARS_INPUT,
context_spec=None,
)
request_mock.assert_called_with(
"post",
Expand Down
46 changes: 46 additions & 0 deletions vertexai/_genai/agent_engines.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,21 @@ def _ReasoningEngineSpec_to_vertex(
return to_object


def _ReasoningEngineContextSpec_to_vertex(
from_object: Union[dict[str, Any], object],
parent_object: Optional[dict[str, Any]] = None,
) -> dict[str, Any]:
to_object: dict[str, Any] = {}
if getv(from_object, ["memory_bank_config"]) is not None:
setv(
to_object,
["memoryBankConfig"],
getv(from_object, ["memory_bank_config"]),
)

return to_object


def _CreateAgentEngineConfig_to_vertex(
from_object: Union[dict[str, Any], object],
parent_object: Optional[dict[str, Any]] = None,
Expand All @@ -82,6 +97,15 @@ def _CreateAgentEngineConfig_to_vertex(
_ReasoningEngineSpec_to_vertex(getv(from_object, ["spec"]), to_object),
)

if getv(from_object, ["context_spec"]) is not None:
setv(
parent_object,
["contextSpec"],
_ReasoningEngineContextSpec_to_vertex(
getv(from_object, ["context_spec"]), to_object
),
)

return to_object


Expand Down Expand Up @@ -550,6 +574,15 @@ def _UpdateAgentEngineConfig_to_vertex(
_ReasoningEngineSpec_to_vertex(getv(from_object, ["spec"]), to_object),
)

if getv(from_object, ["context_spec"]) is not None:
setv(
parent_object,
["contextSpec"],
_ReasoningEngineContextSpec_to_vertex(
getv(from_object, ["context_spec"]), to_object
),
)

if getv(from_object, ["update_mask"]) is not None:
setv(
parent_object,
Expand Down Expand Up @@ -1976,6 +2009,10 @@ def create(
"config must be a dict or AgentEngineConfig, but got"
f" {type(config)}."
)
context_spec = config.context_spec
if context_spec is not None:
# Conversion to a dict for _create_config
context_spec = context_spec.model_dump()
api_config = self._create_config(
mode="create",
agent_engine=agent_engine,
Expand All @@ -1986,6 +2023,7 @@ def create(
gcs_dir_name=config.gcs_dir_name,
extra_packages=config.extra_packages,
env_vars=config.env_vars,
context_spec=context_spec,
)
operation = self._create(config=api_config)
# TODO: Use a more specific link.
Expand Down Expand Up @@ -2029,6 +2067,7 @@ def _create_config(
gcs_dir_name: Optional[str] = None,
extra_packages: Optional[Sequence[str]] = None,
env_vars: Optional[dict[str, Union[str, Any]]] = None,
context_spec: Optional[dict[str, Any]] = None,
):
import sys
from vertexai.agent_engines import _agent_engines
Expand All @@ -2049,6 +2088,8 @@ def _create_config(
if description is not None:
update_masks.append("description")
config["description"] = description
if context_spec is not None:
config["context_spec"] = context_spec
if agent_engine is not None:
sys_version = f"{sys.version_info.major}.{sys.version_info.minor}"
gcs_dir_name = gcs_dir_name or _agent_engines._DEFAULT_GCS_DIR_NAME
Expand Down Expand Up @@ -2307,6 +2348,10 @@ def update(
"config must be a dict or AgentEngineConfig, but got"
f" {type(config)}."
)
context_spec = config.context_spec
if context_spec is not None:
# Conversion to a dict for _create_config
context_spec = context_spec.model_dump()
api_config = self._create_config(
mode="update",
agent_engine=agent_engine,
Expand All @@ -2317,6 +2362,7 @@ def update(
gcs_dir_name=config.gcs_dir_name,
extra_packages=config.extra_packages,
env_vars=config.env_vars,
context_spec=context_spec,
)
operation = self._update(name=name, config=api_config)
logger.info(
Expand Down
Loading
Loading