Skip to content

Commit

Permalink
Auto-fill APIKeys from GCP Secrets
Browse files Browse the repository at this point in the history
  • Loading branch information
kongzii committed Feb 15, 2025
1 parent 589a0ef commit ef2db48
Show file tree
Hide file tree
Showing 6 changed files with 182 additions and 115 deletions.
182 changes: 88 additions & 94 deletions poetry.lock

Large diffs are not rendered by default.

33 changes: 32 additions & 1 deletion prediction_market_agent_tooling/config.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,17 @@
import json
import typing as t
from copy import deepcopy

from eth_account.signers.local import LocalAccount
from pydantic import Field
from pydantic import Field, model_validator
from pydantic.types import SecretStr
from pydantic.v1.types import SecretStr as SecretStrV1
from pydantic_settings import BaseSettings, SettingsConfigDict
from safe_eth.eth import EthereumClient
from safe_eth.safe.safe import SafeV141
from web3 import Account

from prediction_market_agent_tooling.deploy.gcp.utils import gcp_get_secret_value
from prediction_market_agent_tooling.gtypes import (
ChainID,
ChecksumAddress,
Expand Down Expand Up @@ -60,6 +63,34 @@ class APIKeys(BaseSettings):
ENABLE_CACHE: bool = False
CACHE_DIR: str = "./.cache"

@model_validator(mode="before")
@classmethod
def _model_validator(cls, data: t.Any) -> t.Any:
data = deepcopy(data)
data = cls._replace_gcp_secrets(data)
return data

@staticmethod
def _replace_gcp_secrets(data: t.Any) -> t.Any:
if isinstance(data, dict):
for k, v in data.items():
# Check if the value is meant to be fetched from GCP Secret Manager, if so, replace it with it.
if isinstance(v, (str, SecretStr)):
secret_value = (
v.get_secret_value() if isinstance(v, SecretStr) else v
)
if secret_value.startswith("gcps:"):
# We assume that secrets are dictionaries and the value is a key in the dictionary,
# example usage: `BET_FROM_PRIVATE_KEY=gcps:my-agent:private_key`
_, secret_name, key_name = secret_value.split(":")
secret_data = json.loads(gcp_get_secret_value(secret_name))[
key_name
]
data[k] = secret_data
else:
raise ValueError("Data must be a dictionary.")
return data

@property
def manifold_user_id(self) -> str:
return get_authenticated_user(
Expand Down
7 changes: 2 additions & 5 deletions prediction_market_agent_tooling/deploy/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,7 @@
run_deployed_gcp_function,
schedule_deployed_gcp_function,
)
from prediction_market_agent_tooling.deploy.gcp.utils import (
gcp_function_is_active,
gcp_resolve_api_keys_secrets,
)
from prediction_market_agent_tooling.deploy.gcp.utils import gcp_function_is_active
from prediction_market_agent_tooling.deploy.trade_interval import (
FixedInterval,
TradeInterval,
Expand Down Expand Up @@ -228,7 +225,7 @@ def {entrypoint_function_name}(request) -> str:
monitor_agent = MARKET_TYPE_TO_DEPLOYED_AGENT[market_type].from_api_keys(
name=gcp_fname,
start_time=start_time or utcnow(),
api_keys=gcp_resolve_api_keys_secrets(api_keys),
api_keys=api_keys,
)
env_vars |= monitor_agent.model_dump_prefixed()

Expand Down
14 changes: 0 additions & 14 deletions prediction_market_agent_tooling/deploy/gcp/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
from google.cloud.functions_v2.types.functions import Function
from google.cloud.secretmanager import SecretManagerServiceClient

from prediction_market_agent_tooling.config import APIKeys
from prediction_market_agent_tooling.deploy.gcp.kubernetes_models import (
KubernetesCronJobsModel,
)
Expand Down Expand Up @@ -203,16 +202,3 @@ def gcp_get_secret_value(name: str, version: str = "latest") -> str:
return client.access_secret_version(
name=f"projects/{get_gcloud_project_id()}/secrets/{name}/versions/{version}"
).payload.data.decode("utf-8")


def gcp_resolve_api_keys_secrets(api_keys: APIKeys) -> APIKeys:
return APIKeys.model_validate(
api_keys.model_dump_public()
| {
k: gcp_get_secret_value(
name=v.rsplit(":", 1)[0],
version=v.rsplit(":", 1)[1],
)
for k, v in api_keys.model_dump_secrets().items()
}
)
4 changes: 3 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "prediction-market-agent-tooling"
version = "0.57.19"
version = "0.58.0"
description = "Tools to benchmark, deploy and monitor prediction market agents."
authors = ["Gnosis"]
readme = "README.md"
Expand Down Expand Up @@ -57,6 +57,8 @@ optuna = { version = "^4.1.0", optional = true}
httpx = ">=0.25.2,<1.0.0"
cowdao-cowpy = "^1.0.0rc1"
eth-keys = "^0.6.1"
proto-plus = "^1.0.0"
protobuf = "^4.0.0"

[tool.poetry.extras]
openai = ["openai"]
Expand Down
57 changes: 57 additions & 0 deletions tests/tools/test_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
from unittest.mock import patch

from pydantic import SecretStr

from prediction_market_agent_tooling.config import APIKeys
from prediction_market_agent_tooling.gtypes import PrivateKey


def test_gcp_secrets_empty() -> None:
with patch.dict("os.environ", {}, clear=True):
api_keys = APIKeys()
assert api_keys


def test_gcp_secrets_from_env_plain() -> None:
with patch.dict("os.environ", {"BET_FROM_PRIVATE_KEY": "secret"}, clear=True):
api_keys = APIKeys()
assert api_keys.bet_from_private_key.get_secret_value() == "secret"


def test_gcp_secrets_from_env_gcp() -> None:
with patch.dict(
"os.environ", {"BET_FROM_PRIVATE_KEY": "gcps:test:key"}, clear=True
), patch(
"prediction_market_agent_tooling.config.gcp_get_secret_value",
return_value='{"key": "test_secret"}',
):
api_keys = APIKeys()
assert api_keys.bet_from_private_key.get_secret_value() == "test_secret"


def test_gcp_secrets_from_kwargs_plain() -> None:
api_keys = APIKeys(BET_FROM_PRIVATE_KEY=PrivateKey(SecretStr("test_secret")))
assert api_keys.bet_from_private_key.get_secret_value() == "test_secret"


def test_gcp_secrets_from_kwargs_gcp() -> None:
with patch(
"prediction_market_agent_tooling.config.gcp_get_secret_value",
return_value='{"key": "test_secret"}',
):
api_keys = APIKeys(BET_FROM_PRIVATE_KEY=PrivateKey(SecretStr("gcps:test:key")))
assert api_keys.bet_from_private_key.get_secret_value() == "test_secret"


def test_gcp_secrets_from_dict_plain() -> None:
api_keys = APIKeys.model_validate({"BET_FROM_PRIVATE_KEY": "test_secret"})
assert api_keys.bet_from_private_key.get_secret_value() == "test_secret"


def test_gcp_secrets_from_dict_gcp() -> None:
with patch(
"prediction_market_agent_tooling.config.gcp_get_secret_value",
return_value='{"key": "test_secret"}',
):
api_keys = APIKeys.model_validate({"BET_FROM_PRIVATE_KEY": "gcps:test:key"})
assert api_keys.bet_from_private_key.get_secret_value() == "test_secret"

0 comments on commit ef2db48

Please sign in to comment.