Skip to content

Commit 12d7644

Browse files
authored
Automatically create Thread in Griptape Cloud Assistant Driver (#1582)
1 parent 7c1381e commit 12d7644

File tree

5 files changed

+237
-90
lines changed

5 files changed

+237
-90
lines changed

Diff for: CHANGELOG.md

+5-4
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
1717

1818
### Changed
1919
- Task bitshift operators can now take a list of Tasks.
20+
- `GriptapeCloudAssistantDriver` and `OpenAiAssistantDriver` now automatically create a new Thread if one is not provided. Can be disabled with `auto_create_thread=False`.
21+
- `GriptapeCloudAssistantDriver` and `OpenAiAssistantDriver` now return metadata (`thread_id`) on the response Artifact.
22+
- `GriptapeCloudAssistantDriver` now accepts a `thread_alias` parameter for fetching a Thread by alias, creating one if it doesn't exist.
23+
- `EvalEngine` to use structured output when generating evaluation steps.
24+
- `GriptapeCloudVectorStoreDriver.query()` updated to non-deprecated Griptape Cloud API shape.
2025

2126
### Fixed
2227

@@ -34,10 +39,6 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
3439
### Deprecated
3540

3641
- `FuturesExecutorMixin.futures_executor`. Use `FuturesExecutorMixin.create_futures_executor` instead.
37-
### Changed
38-
39-
- `EvalEngine` to use structured output when generating evaluation steps.
40-
- `GriptapeCloudVectorStoreDriver.query()` updated to non-deprecated Griptape Cloud API shape.
4142

4243
## [1.1.1] - 2025-01-03
4344

Diff for: griptape/drivers/assistant/griptape_cloud_assistant_driver.py

+45-3
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import logging
44
import os
55
import time
6+
import uuid
67
from typing import Optional
78
from urllib.parse import urljoin
89

@@ -30,6 +31,7 @@ class GriptapeCloudAssistantDriver(BaseAssistantDriver):
3031
input: Optional[str] = field(default=None, kw_only=True)
3132
assistant_id: str = field(kw_only=True)
3233
thread_id: Optional[str] = field(default=None, kw_only=True)
34+
thread_alias: Optional[str] = field(default=None, kw_only=True)
3335
ruleset_ids: Optional[list[str]] = field(default=None, kw_only=True)
3436
additional_ruleset_ids: list[str] = field(factory=list, kw_only=True)
3537
knowledge_base_ids: Optional[list[str]] = field(default=None, kw_only=True)
@@ -41,8 +43,41 @@ class GriptapeCloudAssistantDriver(BaseAssistantDriver):
4143
stream: bool = field(default=False, kw_only=True)
4244
poll_interval: int = field(default=1, kw_only=True)
4345
max_attempts: int = field(default=20, kw_only=True)
46+
auto_create_thread: bool = field(default=True, kw_only=True)
4447

4548
def try_run(self, *args: BaseArtifact) -> BaseArtifact | InfoArtifact:
49+
if self.thread_id is None and self.auto_create_thread:
50+
self._create_or_find_thread(self.thread_alias)
51+
assistant_run_id = self._create_run(*args)
52+
run_result = self._get_run_result(assistant_run_id)
53+
54+
run_result.meta.update(
55+
{"assistant_id": self.assistant_id, "assistant_run_id": assistant_run_id, "thread_id": self.thread_id}
56+
)
57+
58+
return run_result
59+
60+
def _create_or_find_thread(self, thread_alias: Optional[str] = None) -> None:
61+
if thread_alias is None:
62+
self.thread_id = self._create_thread()
63+
else:
64+
thread = self._find_thread_by_alias(thread_alias)
65+
66+
if thread is None:
67+
self.thread_id = self._create_thread(thread_alias)
68+
69+
def _create_thread(self, thread_alias: Optional[str] = None) -> str:
70+
url = urljoin(self.base_url.strip("/"), "/api/threads")
71+
72+
body = {"name": uuid.uuid4().hex}
73+
if thread_alias is not None:
74+
body["alias"] = thread_alias
75+
76+
response = requests.post(url, json=body, headers=self.headers)
77+
response.raise_for_status()
78+
return response.json()["thread_id"]
79+
80+
def _create_run(self, *args: BaseArtifact) -> str:
4681
url = urljoin(self.base_url.strip("/"), f"/api/assistants/{self.assistant_id}/runs")
4782

4883
response = requests.post(
@@ -64,9 +99,7 @@ def try_run(self, *args: BaseArtifact) -> BaseArtifact | InfoArtifact:
6499
headers=self.headers,
65100
)
66101
response.raise_for_status()
67-
response_json = response.json()
68-
69-
return self._get_run_result(response_json["assistant_run_id"])
102+
return response.json()["assistant_run_id"]
70103

71104
def _get_run_result(self, assistant_run_id: str) -> BaseArtifact | InfoArtifact:
72105
events, next_offset = self._get_run_events(assistant_run_id)
@@ -105,3 +138,12 @@ def _get_run_events(self, assistant_run_id: str, offset: int = 0) -> tuple[list[
105138
next_offset = response_json.get("next_offset", 0)
106139

107140
return events, next_offset
141+
142+
def _find_thread_by_alias(self, thread_alias: str) -> Optional[dict]:
143+
url = urljoin(self.base_url.strip("/"), "/api/threads")
144+
response = requests.get(url, params={"alias": thread_alias}, headers=self.headers)
145+
response.raise_for_status()
146+
147+
threads = response.json()["threads"]
148+
149+
return next((thread for thread in threads if thread["alias"] == thread_alias), None)

Diff for: griptape/drivers/assistant/openai_assistant_driver.py

+18-3
Original file line numberDiff line numberDiff line change
@@ -39,11 +39,12 @@ def on_tool_call_delta(self, delta: ToolCallDelta, snapshot: ToolCall) -> None:
3939
base_url: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": True})
4040
api_key: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": False})
4141
organization: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": True})
42-
thread_id: Optional[str] = field(kw_only=True)
42+
thread_id: Optional[str] = field(default=None, kw_only=True)
4343
assistant_id: str = field(kw_only=True)
4444
event_handler: AssistantEventHandler = field(
4545
default=Factory(lambda: OpenAiAssistantDriver.EventHandler()), kw_only=True, metadata={"serializable": False}
4646
)
47+
auto_create_thread: bool = field(default=True, kw_only=True)
4748

4849
_client: openai.OpenAI = field(default=None, kw_only=True, alias="client", metadata={"serializable": False})
4950

@@ -56,8 +57,17 @@ def client(self) -> openai.OpenAI:
5657
)
5758

5859
def try_run(self, *args: BaseArtifact) -> BaseArtifact | InfoArtifact:
60+
if self.thread_id is None and self.auto_create_thread:
61+
self.thread_id = self.client.beta.threads.create().id
62+
response = self._create_run(*args)
63+
64+
response.meta.update({"assistant_id": self.assistant_id, "thread_id": self.thread_id})
65+
66+
return response
67+
68+
def _create_run(self, *args: BaseArtifact) -> TextArtifact:
5969
content = "\n".join(arg.value for arg in args)
60-
self.client.beta.threads.messages.create(thread_id=self.thread_id, role="user", content=content)
70+
message_id = self.client.beta.threads.messages.create(thread_id=self.thread_id, role="user", content=content)
6171
with self.client.beta.threads.runs.stream(
6272
thread_id=self.thread_id,
6373
assistant_id=self.assistant_id,
@@ -71,4 +81,9 @@ def try_run(self, *args: BaseArtifact) -> BaseArtifact | InfoArtifact:
7181
message_contents.append("".join(content.text.value for content in message.content))
7282
message_text = "\n".join(message_contents)
7383

74-
return TextArtifact(message_text)
84+
response = TextArtifact(message_text)
85+
86+
response.meta.update(
87+
{"assistant_id": self.assistant_id, "thread_id": self.thread_id, "message_id": message_id}
88+
)
89+
return response

0 commit comments

Comments
 (0)