Skip to content

Commit 6ab59e4

Browse files
committed
feat(utils): add griptape cloud context util
WIP WIP
1 parent 4073605 commit 6ab59e4

File tree

7 files changed

+159
-15
lines changed

7 files changed

+159
-15
lines changed

CHANGELOG.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
3131
- Support for new serialization metadata, `serialization_key` and `deserialization_key` for more granular control over serialization.
3232
- Support for OpenAi reasoning models, `o1` and `o3`.
3333
- Support for enums in `GriptapeCloudToolTool`.
34-
- `griptape.utils.griptape_cloud_utils.configure_events` for automatically configuring the `EventBus` for use with Griptape Cloud Managed Structures.
34+
- `griptape.utils.griptape_cloud_utils.GriptapeCloud` for automatically configuring the `EventBus` for use with Griptape Cloud Managed Structures.
3535

3636
### Changed
3737

griptape/events/finish_structure_run_event.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,5 +13,5 @@
1313
@define
1414
class FinishStructureRunEvent(BaseEvent):
1515
structure_id: Optional[str] = field(kw_only=True, default=None, metadata={"serializable": True})
16-
output_task_input: BaseArtifact = field(kw_only=True, metadata={"serializable": True})
17-
output_task_output: Optional[BaseArtifact] = field(kw_only=True, metadata={"serializable": True})
16+
output_task_input: Optional[BaseArtifact] = field(kw_only=True, default=None, metadata={"serializable": True})
17+
output_task_output: Optional[BaseArtifact] = field(kw_only=True, default=None, metadata={"serializable": True})

griptape/utils/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
from .file_utils import get_mime_type
2525
from .contextvars_utils import with_contextvars
2626
from .json_schema_utils import build_strict_schema, resolve_refs
27+
from .griptape_cloud import GriptapeCloud
2728

2829

2930
def minify_json(value: str) -> str:
@@ -58,4 +59,5 @@ def minify_json(value: str) -> str:
5859
"with_contextvars",
5960
"build_strict_schema",
6061
"resolve_refs",
62+
"GriptapeCloud",
6163
]

griptape/utils/griptape_cloud.py

Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,96 @@
1+
from __future__ import annotations
2+
3+
import os
4+
from typing import TYPE_CHECKING, Any, Optional, TypeVar
5+
6+
from attrs import define, field
7+
from typing_extensions import ParamSpec
8+
9+
from griptape.artifacts import BaseArtifact, GenericArtifact
10+
from griptape.events import EventBus, EventListener, FinishStructureRunEvent
11+
from griptape.utils.decorators import lazy_property
12+
13+
P = ParamSpec("P")
14+
T = TypeVar("T")
15+
16+
if TYPE_CHECKING:
17+
from types import TracebackType
18+
19+
from griptape.observability.observability import Observability
20+
21+
22+
@define()
23+
class GriptapeCloud:
24+
"""Utility for working with Griptape Cloud.
25+
26+
Args:
27+
_event_listener (EventListener): Event Listener to use. Defaults to an EventListener with a GriptapeCloudEventListenerDriver.
28+
_observability (Observability): Observability to use. Defaults to an Observability with a GriptapeCloudObservabilityDriver.
29+
observe (bool): Whether to enable observability. Enabling requires the `drivers-observability-griptape-cloud` extra.
30+
"""
31+
32+
_event_listener: EventListener = field(default=None, kw_only=True, alias="event_listener")
33+
_observability: Observability = field(default=None, kw_only=True, alias="observability")
34+
observe: bool = field(default=False, kw_only=True)
35+
_output: BaseArtifact = field(default=None, init=False)
36+
37+
@lazy_property()
38+
def event_listener(self) -> EventListener:
39+
from griptape.drivers.event_listener.griptape_cloud import GriptapeCloudEventListenerDriver
40+
41+
return EventListener(event_listener_driver=GriptapeCloudEventListenerDriver())
42+
43+
@lazy_property()
44+
def observability(self) -> Observability:
45+
from griptape.drivers.observability.griptape_cloud import GriptapeCloudObservabilityDriver
46+
from griptape.observability.observability import Observability
47+
48+
return Observability(observability_driver=GriptapeCloudObservabilityDriver())
49+
50+
@property
51+
def output(self) -> BaseArtifact:
52+
return self._output
53+
54+
@output.setter
55+
def output(self, value: BaseArtifact | Any) -> None:
56+
if not isinstance(value, BaseArtifact):
57+
self._output = GenericArtifact(value)
58+
else:
59+
self._output = value
60+
61+
@property
62+
def structure_run_id(self) -> str:
63+
return os.environ["GT_CLOUD_STRUCTURE_RUN_ID"]
64+
65+
@property
66+
def in_managed_environment(self) -> bool:
67+
return "GT_CLOUD_STRUCTURE_RUN_ID" in os.environ
68+
69+
def __enter__(self) -> GriptapeCloud:
70+
from griptape.observability.observability import Observability
71+
72+
if self.in_managed_environment:
73+
EventBus.add_event_listener(self.event_listener)
74+
75+
if self.observe:
76+
Observability.set_global_driver(self.observability.observability_driver)
77+
self.observability.observability_driver.__enter__()
78+
79+
return self
80+
81+
def __exit__(
82+
self,
83+
exc_type: Optional[type[BaseException]],
84+
exc_value: Optional[BaseException],
85+
exc_traceback: Optional[TracebackType],
86+
) -> None:
87+
from griptape.observability.observability import Observability
88+
89+
if self.in_managed_environment:
90+
if self.output is not None:
91+
EventBus.publish_event(FinishStructureRunEvent(output_task_output=self.output), flush=True)
92+
EventBus.remove_event_listener(self.event_listener)
93+
94+
if self.observe:
95+
Observability.set_global_driver(None)
96+
self.observability.observability_driver.__exit__(exc_type, exc_value, exc_traceback)

griptape/utils/griptape_cloud_utils.py

Lines changed: 0 additions & 6 deletions
This file was deleted.

tests/unit/drivers/event_listener/test_griptape_cloud_event_listener_driver.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,17 @@ def test_try_publish_event_payload(self, mock_post, driver):
8383
headers={"Authorization": "Bearer foo bar"},
8484
)
8585

86-
def try_publish_event_payload_batch(self, mock_post, driver):
86+
def test_validate_api_key(self):
87+
with pytest.raises(ValueError, match="No value was found"):
88+
GriptapeCloudEventListenerDriver()
89+
90+
def test_validate_run_id(self):
91+
os.environ["GT_CLOUD_API_KEY"] = "foo bar"
92+
93+
with pytest.raises(ValueError, match="structure_run_id must be set"):
94+
GriptapeCloudEventListenerDriver()
95+
96+
def test_try_publish_event_payload_batch(self, mock_post, driver):
8797
for _ in range(3):
8898
event = MockEvent()
8999
driver.try_publish_event_payload(event.to_dict())
Lines changed: 47 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,58 @@
11
import os
22

3+
import pytest
4+
5+
from griptape.artifacts.text_artifact import TextArtifact
36
from griptape.drivers.event_listener.griptape_cloud import GriptapeCloudEventListenerDriver
7+
from griptape.events import EventListener
48
from griptape.events.event_bus import EventBus
5-
from griptape.utils.griptape_cloud_utils import configure_events
9+
from griptape.observability.observability import Observability
10+
from griptape.utils import GriptapeCloud
11+
from tests.mocks.mock_event_listener_driver import MockEventListenerDriver
612

713

814
class TestGriptapeCloudUtils:
9-
def test_configure_events(self):
15+
@pytest.fixture(autouse=True)
16+
def set_up_environment(self):
1017
os.environ["GT_CLOUD_API_KEY"] = "foo"
1118
os.environ["GT_CLOUD_STRUCTURE_RUN_ID"] = "bar"
1219

13-
configure_events()
20+
yield
21+
22+
del os.environ["GT_CLOUD_API_KEY"]
23+
del os.environ["GT_CLOUD_STRUCTURE_RUN_ID"]
24+
25+
@pytest.mark.parametrize("observe", [True, False])
26+
def test_context_manager(self, observe):
27+
with GriptapeCloud(observe=observe) as context:
28+
assert context.in_managed_environment
29+
30+
assert len(EventBus.event_listeners) == 1
31+
assert isinstance(EventBus.event_listeners[0].event_listener_driver, GriptapeCloudEventListenerDriver)
32+
33+
if observe:
34+
assert Observability.get_global_driver() == context.observability.observability_driver
35+
else:
36+
assert Observability.get_global_driver() is None
37+
38+
assert len(EventBus.event_listeners) == 0
39+
assert Observability.get_global_driver() is None
40+
41+
def test_in_managed_environment(self):
42+
context = GriptapeCloud()
43+
44+
assert context.in_managed_environment
45+
46+
def test_structure_run_id(self):
47+
context = GriptapeCloud()
48+
49+
assert context.structure_run_id == "bar"
50+
51+
def test_output(self):
52+
with GriptapeCloud(event_listener=EventListener(event_listener_driver=MockEventListenerDriver())) as context:
53+
context.output = "foo"
54+
assert context.output.value == "foo"
1455

15-
assert len(EventBus.event_listeners) == 1
16-
assert isinstance(EventBus.event_listeners[0].event_listener_driver, GriptapeCloudEventListenerDriver)
56+
output = TextArtifact("bar")
57+
context.output = output
58+
assert context.output is output

0 commit comments

Comments
 (0)