Skip to content

Commit 335f6ec

Browse files
committed
Support Structure kwargs, Task args/kwargs
1 parent 1dbe51a commit 335f6ec

File tree

9 files changed

+79
-13
lines changed

9 files changed

+79
-13
lines changed

CHANGELOG.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
2222

2323
## Unreleased
2424

25+
### Changed
26+
- `BaseTask.run` now accepts args and kwargs which are added to the Task's context under `args` and `kwargs` keys.
27+
- `Structure.run` now accepts kwargs which is added to the Task's context under `kwargs` key.
28+
29+
2530
## [1.2.0] - 2025-01-21
2631

2732
### Added

docs/griptape-framework/structures/tasks.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ Tasks that take input have a field [input](../../reference/griptape/tasks/base_t
1313
Within the [input](../../reference/griptape/tasks/base_text_input_task.md#griptape.tasks.base_text_input_task.BaseTextInputTask.input), you can access the following [context](../../reference/griptape/structures/structure.md#griptape.structures.structure.Structure.context) variables:
1414

1515
- `args`: an array of arguments passed to the `.run()` method.
16+
- `kwargs`: an array of keyword arguments passed to the `.run()` method.
1617
- `structure`: the structure that the task belongs to.
1718
- user defined context variables
1819

griptape/structures/structure.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,8 @@ class Structure(RuleMixin, SerializableMixin, RunnableMixin["Structure"], ABC):
4848
)
4949
meta_memory: MetaMemory = field(default=Factory(lambda: MetaMemory()), kw_only=True)
5050
fail_fast: bool = field(default=True, kw_only=True, metadata={"serializable": True})
51-
_execution_args: tuple = ()
51+
_execution_args: tuple = field(factory=tuple, init=False)
52+
_execution_kwargs: dict[str, Any] = field(factory=dict, init=False)
5253
_event_queue: Queue[BaseEvent] = field(default=Factory(lambda: Queue()), init=False)
5354

5455
def __attrs_post_init__(self) -> None:
@@ -74,6 +75,10 @@ def tasks(self) -> list[BaseTask]:
7475
def execution_args(self) -> tuple:
7576
return self._execution_args
7677

78+
@property
79+
def execution_kwargs(self) -> dict:
80+
return self._execution_kwargs
81+
7782
@property
7883
def input_task(self) -> Optional[BaseTask]:
7984
return self.tasks[0] if self.tasks else None
@@ -125,7 +130,7 @@ def add_tasks(self, *tasks: BaseTask | list[BaseTask]) -> list[BaseTask]:
125130
return added_tasks
126131

127132
def context(self, task: BaseTask) -> dict[str, Any]:
128-
return {"args": self.execution_args, "structure": self}
133+
return {"args": self.execution_args, "kwargs": self.execution_kwargs, "structure": self}
129134

130135
def resolve_relationships(self) -> None:
131136
task_by_id = {}
@@ -154,7 +159,6 @@ def resolve_relationships(self) -> None:
154159
@observable
155160
def before_run(self, args: Any) -> None:
156161
super().before_run(args)
157-
self._execution_args = args
158162

159163
[task.reset() for task in self.tasks]
160164

@@ -197,7 +201,9 @@ def after_run(self) -> None:
197201
def add_task(self, task: BaseTask) -> BaseTask: ...
198202

199203
@observable
200-
def run(self, *args) -> Structure:
204+
def run(self, *args, **kwargs) -> Structure:
205+
self._execution_args = args
206+
self._execution_kwargs = kwargs
201207
self.before_run(args)
202208

203209
result = self.try_run(*args)

griptape/tasks/base_task.py

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,16 @@ class State(Enum):
4040

4141
output: Optional[BaseArtifact] = field(default=None, init=False)
4242
context: dict[str, Any] = field(factory=dict, kw_only=True, metadata={"serializable": True})
43+
_execution_args: tuple = field(factory=tuple, init=False)
44+
_execution_kwargs: dict[str, Any] = field(factory=dict, init=False)
45+
46+
@property
47+
def execution_args(self) -> tuple:
48+
return self._execution_args
49+
50+
@property
51+
def execution_kwargs(self) -> dict:
52+
return self._execution_kwargs
4353

4454
def __rshift__(self, other: BaseTask | list[BaseTask]) -> BaseTask | list[BaseTask]:
4555
if isinstance(other, list):
@@ -160,8 +170,11 @@ def before_run(self) -> None:
160170
),
161171
)
162172

163-
def run(self) -> BaseArtifact:
173+
def run(self, *args, **kwargs) -> BaseArtifact:
164174
try:
175+
self._execution_args = args
176+
self._execution_kwargs = kwargs
177+
165178
self.state = BaseTask.State.RUNNING
166179

167180
self.before_run()
@@ -209,6 +222,8 @@ def can_run(self) -> bool:
209222
def reset(self) -> BaseTask:
210223
self.state = BaseTask.State.PENDING
211224
self.output = None
225+
self._execution_args = ()
226+
self._execution_kwargs = {}
212227

213228
return self
214229

@@ -219,7 +234,9 @@ def try_run(self) -> BaseArtifact: ...
219234
def full_context(self) -> dict[str, Any]:
220235
# Need to deep copy so that the serialized context doesn't contain non-serializable data
221236
context = deepcopy(self.context)
222-
if self.structure is not None:
237+
if self.structure is None:
238+
context.update({"args": self._execution_args, "kwargs": self._execution_kwargs})
239+
else:
223240
context.update(self.structure.context(self))
224241

225242
return context

tests/unit/structures/test_agent.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -215,11 +215,13 @@ def test_context(self):
215215

216216
agent.add_task(task)
217217

218-
agent.run("hello")
218+
agent.run("hello", foo="bar")
219219

220220
context = agent.context(task)
221221

222222
assert context["structure"] == agent
223+
assert context["args"] == ("hello",)
224+
assert context["kwargs"] == {"foo": "bar"}
223225

224226
def test_task_memory_defaults(self, mock_config):
225227
agent = Agent()

tests/unit/structures/test_pipeline.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -356,7 +356,7 @@ def test_context(self):
356356

357357
assert context["parent_output"] is None
358358

359-
pipeline.run()
359+
pipeline.run("hello", foo="bar")
360360

361361
context = pipeline.context(task)
362362

@@ -365,6 +365,8 @@ def test_context(self):
365365
assert context["structure"] == pipeline
366366
assert context["parent"] == parent
367367
assert context["child"] == child
368+
assert context["args"] == ("hello",)
369+
assert context["kwargs"] == {"foo": "bar"}
368370

369371
def test_run_with_error_artifact(self, error_artifact_task, waiting_task):
370372
end_task = PromptTask("end")

tests/unit/structures/test_workflow.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -739,7 +739,7 @@ def test_context(self):
739739

740740
assert context["parent_outputs"] == {}
741741

742-
workflow.run()
742+
workflow.run("hello", foo="bar")
743743

744744
context = workflow.context(task)
745745

@@ -749,6 +749,8 @@ def test_context(self):
749749
assert context["structure"] == workflow
750750
assert context["parents"] == {parent.id: parent}
751751
assert context["children"] == {child.id: child}
752+
assert context["args"] == ("hello",)
753+
assert context["kwargs"] == {"foo": "bar"}
752754

753755
def test_run_with_error_artifact(self, error_artifact_task, waiting_task):
754756
end_task = PromptTask("end")

tests/unit/tasks/test_base_task.py

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -229,8 +229,15 @@ def test_runnable_mixin(self):
229229
def test_full_context(self, task):
230230
task.structure = Agent()
231231
task.structure._execution_args = ("foo", "bar")
232+
task.structure._execution_kwargs = {"baz": "qux"}
232233

233-
assert task.full_context == {"args": ("foo", "bar"), "structure": task.structure}
234+
assert task.full_context == {"args": ("foo", "bar"), "kwargs": {"baz": "qux"}, "structure": task.structure}
235+
236+
task.structure = None
237+
task._execution_args = ("foo", "bar")
238+
task._execution_kwargs = {"baz": "qux"}
239+
240+
assert task.full_context == {"args": ("foo", "bar"), "kwargs": {"baz": "qux"}}
234241

235242
def test_is_pending(self, task):
236243
task.state = task.State.PENDING
@@ -249,3 +256,21 @@ def test___str__(self, task):
249256
assert str(task) == "foobar"
250257
task.output = None
251258
assert str(task) == ""
259+
260+
def test_run_args(self, task):
261+
task.run("foo", "bar")
262+
263+
assert task._execution_args == ("foo", "bar")
264+
265+
def test_run_kwargs(self, task):
266+
task.run(foo="bar")
267+
268+
assert task._execution_kwargs == {"foo": "bar"}
269+
270+
def test_args_full_context(self):
271+
task = MockTask()
272+
task.context = {"foo": "buzz"}
273+
task.run("foo", "bar", baz="qux")
274+
275+
assert task.full_context["args"] == ("foo", "bar")
276+
assert task.full_context["kwargs"] == {"baz": "qux"}

tests/unit/tasks/test_base_text_input_task.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -31,9 +31,15 @@ def test_full_context(self):
3131
subtask = MockTextInputTask("test", context={"foo": "bar"})
3232
child = MockTextInputTask("child")
3333

34-
assert parent.full_context == {}
35-
assert subtask.full_context == {"foo": "bar"}
36-
assert child.full_context == {}
34+
assert parent.full_context == {
35+
"args": (),
36+
"kwargs": {},
37+
}
38+
assert subtask.full_context == {"args": (), "kwargs": {}, "foo": "bar"}
39+
assert child.full_context == {
40+
"args": (),
41+
"kwargs": {},
42+
}
3743

3844
pipeline = Pipeline()
3945

0 commit comments

Comments
 (0)