Skip to content

Commit 1e39d26

Browse files
authored
feat(utils): add utility function for configuring Event Bus with GTC (#1651)
1 parent a3723a1 commit 1e39d26

File tree

7 files changed

+184
-5
lines changed

7 files changed

+184
-5
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
3232
- Support for OpenAi reasoning models, `o1` and `o3`.
3333
- Support for enums in `GriptapeCloudToolTool`.
3434
- `LocalRerankDriver` for reranking locally.
35+
- `griptape.utils.griptape_cloud.GriptapeCloudStructure` for automatically configuring Cloud-specific Drivers when in the Griptape Cloud Structures Runtime.
3536

3637
### Changed
3738

griptape/drivers/event_listener/griptape_cloud_event_listener_driver.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ class GriptapeCloudEventListenerDriver(BaseEventListenerDriver):
2626
default=Factory(lambda: os.getenv("GT_CLOUD_BASE_URL", "https://cloud.griptape.ai")),
2727
kw_only=True,
2828
)
29-
api_key: str = field(default=Factory(lambda: os.environ["GT_CLOUD_API_KEY"]), kw_only=True)
29+
api_key: Optional[str] = field(default=Factory(lambda: os.getenv("GT_CLOUD_API_KEY")), kw_only=True)
3030
headers: dict = field(
3131
default=Factory(lambda self: {"Authorization": f"Bearer {self.api_key}"}, takes_self=True),
3232
kw_only=True,
@@ -36,12 +36,22 @@ class GriptapeCloudEventListenerDriver(BaseEventListenerDriver):
3636
)
3737

3838
@structure_run_id.validator # pyright: ignore[reportAttributeAccessIssue, reportOptionalMemberAccess]
39-
def validate_run_id(self, _: Attribute, structure_run_id: str) -> None:
39+
def validate_run_id(self, _: Attribute, structure_run_id: Optional[str]) -> None:
4040
if structure_run_id is None:
4141
raise ValueError(
4242
"structure_run_id must be set either in the constructor or as an environment variable (GT_CLOUD_STRUCTURE_RUN_ID).",
4343
)
4444

45+
@api_key.validator # pyright: ignore[reportAttributeAccessIssue, reportOptionalMemberAccess]
46+
def validate_api_key(self, _: Attribute, api_key: Optional[str]) -> None:
47+
if api_key is None:
48+
raise ValueError(
49+
"No value was found for the 'GT_CLOUD_API_KEY' environment variable. "
50+
"This environment variable is required when running in Griptape Cloud for authorization. "
51+
"You can generate a Griptape Cloud API Key by visiting https://cloud.griptape.ai/keys . "
52+
"Specify it as an environment variable when creating a Managed Structure in Griptape Cloud."
53+
)
54+
4555
def publish_event(self, event: BaseEvent | dict) -> None:
4656
from griptape.observability.observability import Observability
4757

griptape/events/finish_structure_run_event.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,5 +13,5 @@
1313
@define
1414
class FinishStructureRunEvent(BaseEvent):
1515
structure_id: Optional[str] = field(kw_only=True, default=None, metadata={"serializable": True})
16-
output_task_input: BaseArtifact = field(kw_only=True, metadata={"serializable": True})
17-
output_task_output: Optional[BaseArtifact] = field(kw_only=True, metadata={"serializable": True})
16+
output_task_input: Optional[BaseArtifact] = field(kw_only=True, default=None, metadata={"serializable": True})
17+
output_task_output: Optional[BaseArtifact] = field(kw_only=True, default=None, metadata={"serializable": True})

griptape/utils/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
from .file_utils import get_mime_type
2525
from .contextvars_utils import with_contextvars
2626
from .json_schema_utils import build_strict_schema, resolve_refs
27+
from .griptape_cloud import GriptapeCloudStructure
2728

2829

2930
def minify_json(value: str) -> str:
@@ -58,4 +59,5 @@ def minify_json(value: str) -> str:
5859
"with_contextvars",
5960
"build_strict_schema",
6061
"resolve_refs",
62+
"GriptapeCloudStructure",
6163
]

griptape/utils/griptape_cloud.py

Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,96 @@
1+
from __future__ import annotations
2+
3+
import os
4+
from typing import TYPE_CHECKING, Any, Optional, TypeVar
5+
6+
from attrs import define, field
7+
from typing_extensions import ParamSpec
8+
9+
from griptape.artifacts import BaseArtifact, GenericArtifact
10+
from griptape.events import EventBus, EventListener, FinishStructureRunEvent
11+
from griptape.utils.decorators import lazy_property
12+
13+
P = ParamSpec("P")
14+
T = TypeVar("T")
15+
16+
if TYPE_CHECKING:
17+
from types import TracebackType
18+
19+
from griptape.observability.observability import Observability
20+
21+
22+
@define()
23+
class GriptapeCloudStructure:
24+
"""Utility for working with Griptape Cloud Structures.
25+
26+
Attributes:
27+
_event_listener: Event Listener to use. Defaults to an EventListener with a GriptapeCloudEventListenerDriver.
28+
_observability: Observability to use. Defaults to an Observability with a GriptapeCloudObservabilityDriver.
29+
observe: Whether to enable observability. Enabling requires the `drivers-observability-griptape-cloud` extra.
30+
"""
31+
32+
_event_listener: EventListener = field(default=None, kw_only=True, alias="event_listener")
33+
_observability: Observability = field(default=None, kw_only=True, alias="observability")
34+
observe: bool = field(default=False, kw_only=True)
35+
_output: BaseArtifact = field(default=None, init=False)
36+
37+
@lazy_property()
38+
def event_listener(self) -> EventListener:
39+
from griptape.drivers.event_listener.griptape_cloud import GriptapeCloudEventListenerDriver
40+
41+
return EventListener(event_listener_driver=GriptapeCloudEventListenerDriver())
42+
43+
@lazy_property()
44+
def observability(self) -> Observability:
45+
from griptape.drivers.observability.griptape_cloud import GriptapeCloudObservabilityDriver
46+
from griptape.observability.observability import Observability
47+
48+
return Observability(observability_driver=GriptapeCloudObservabilityDriver())
49+
50+
@property
51+
def output(self) -> BaseArtifact:
52+
return self._output
53+
54+
@output.setter
55+
def output(self, value: BaseArtifact | Any) -> None:
56+
if not isinstance(value, BaseArtifact):
57+
self._output = GenericArtifact(value)
58+
else:
59+
self._output = value
60+
61+
@property
62+
def structure_run_id(self) -> str:
63+
return os.environ["GT_CLOUD_STRUCTURE_RUN_ID"]
64+
65+
@property
66+
def in_managed_environment(self) -> bool:
67+
return "GT_CLOUD_STRUCTURE_RUN_ID" in os.environ
68+
69+
def __enter__(self) -> GriptapeCloudStructure:
70+
from griptape.observability.observability import Observability
71+
72+
if self.in_managed_environment:
73+
EventBus.add_event_listener(self.event_listener)
74+
75+
if self.observe:
76+
Observability.set_global_driver(self.observability.observability_driver)
77+
self.observability.observability_driver.__enter__()
78+
79+
return self
80+
81+
def __exit__(
82+
self,
83+
exc_type: Optional[type[BaseException]],
84+
exc_value: Optional[BaseException],
85+
exc_traceback: Optional[TracebackType],
86+
) -> None:
87+
from griptape.observability.observability import Observability
88+
89+
if self.in_managed_environment:
90+
if self.output is not None:
91+
EventBus.publish_event(FinishStructureRunEvent(output_task_output=self.output), flush=True)
92+
EventBus.remove_event_listener(self.event_listener)
93+
94+
if self.observe:
95+
Observability.set_global_driver(None)
96+
self.observability.observability_driver.__exit__(exc_type, exc_value, exc_traceback)

tests/unit/drivers/event_listener/test_griptape_cloud_event_listener_driver.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,17 @@ def test_try_publish_event_payload(self, mock_post, driver):
8383
headers={"Authorization": "Bearer foo bar"},
8484
)
8585

86-
def try_publish_event_payload_batch(self, mock_post, driver):
86+
def test_validate_api_key(self):
87+
with pytest.raises(ValueError, match="No value was found"):
88+
GriptapeCloudEventListenerDriver()
89+
90+
def test_validate_run_id(self):
91+
os.environ["GT_CLOUD_API_KEY"] = "foo bar"
92+
93+
with pytest.raises(ValueError, match="structure_run_id must be set"):
94+
GriptapeCloudEventListenerDriver()
95+
96+
def test_try_publish_event_payload_batch(self, mock_post, driver):
8797
for _ in range(3):
8898
event = MockEvent()
8999
driver.try_publish_event_payload(event.to_dict())
Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
import os
2+
3+
import pytest
4+
5+
from griptape.artifacts.text_artifact import TextArtifact
6+
from griptape.drivers.event_listener.griptape_cloud import GriptapeCloudEventListenerDriver
7+
from griptape.events import EventListener
8+
from griptape.events.event_bus import EventBus
9+
from griptape.observability.observability import Observability
10+
from griptape.utils import GriptapeCloudStructure
11+
from tests.mocks.mock_event_listener_driver import MockEventListenerDriver
12+
13+
14+
class TestGriptapeCloudUtils:
15+
@pytest.fixture(autouse=True)
16+
def set_up_environment(self):
17+
os.environ["GT_CLOUD_API_KEY"] = "foo"
18+
os.environ["GT_CLOUD_STRUCTURE_RUN_ID"] = "bar"
19+
20+
yield
21+
22+
del os.environ["GT_CLOUD_API_KEY"]
23+
del os.environ["GT_CLOUD_STRUCTURE_RUN_ID"]
24+
25+
@pytest.mark.parametrize("observe", [True, False])
26+
def test_context_manager(self, observe):
27+
with GriptapeCloudStructure(observe=observe) as context:
28+
assert context.in_managed_environment
29+
30+
assert len(EventBus.event_listeners) == 1
31+
assert isinstance(EventBus.event_listeners[0].event_listener_driver, GriptapeCloudEventListenerDriver)
32+
33+
if observe:
34+
assert Observability.get_global_driver() == context.observability.observability_driver
35+
else:
36+
assert Observability.get_global_driver() is None
37+
38+
assert len(EventBus.event_listeners) == 0
39+
assert Observability.get_global_driver() is None
40+
41+
def test_in_managed_environment(self):
42+
context = GriptapeCloudStructure()
43+
44+
assert context.in_managed_environment
45+
46+
def test_structure_run_id(self):
47+
context = GriptapeCloudStructure()
48+
49+
assert context.structure_run_id == "bar"
50+
51+
def test_output(self):
52+
with GriptapeCloudStructure(
53+
event_listener=EventListener(event_listener_driver=MockEventListenerDriver())
54+
) as context:
55+
context.output = "foo"
56+
assert context.output.value == "foo"
57+
58+
output = TextArtifact("bar")
59+
context.output = output
60+
assert context.output is output

0 commit comments

Comments
 (0)