Skip to content

Commit

Permalink
Support Structure kwargs, Task args/kwargs
Browse files Browse the repository at this point in the history
  • Loading branch information
collindutter committed Jan 22, 2025
1 parent 1dbe51a commit f2711cd
Show file tree
Hide file tree
Showing 9 changed files with 83 additions and 13 deletions.
5 changes: 5 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## Unreleased

### Changed
- `BaseTask.run` now accepts args and kwargs which are added to the Task's context under `args` and `kwargs` keys.
- `Structure.run` now accepts kwargs which is added to the Task's context under `kwargs` key.


## [1.2.0] - 2025-01-21

### Added
Expand Down
1 change: 1 addition & 0 deletions docs/griptape-framework/structures/tasks.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ Tasks that take input have a field [input](../../reference/griptape/tasks/base_t
Within the [input](../../reference/griptape/tasks/base_text_input_task.md#griptape.tasks.base_text_input_task.BaseTextInputTask.input), you can access the following [context](../../reference/griptape/structures/structure.md#griptape.structures.structure.Structure.context) variables:

- `args`: an array of arguments passed to the `.run()` method.
- `kwargs`: an array of keyword arguments passed to the `.run()` method.
- `structure`: the structure that the task belongs to.
- user defined context variables

Expand Down
14 changes: 10 additions & 4 deletions griptape/structures/structure.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,8 @@ class Structure(RuleMixin, SerializableMixin, RunnableMixin["Structure"], ABC):
)
meta_memory: MetaMemory = field(default=Factory(lambda: MetaMemory()), kw_only=True)
fail_fast: bool = field(default=True, kw_only=True, metadata={"serializable": True})
_execution_args: tuple = ()
_execution_args: tuple = field(factory=tuple, init=False)
_execution_kwargs: dict[str, Any] = field(factory=dict, init=False)
_event_queue: Queue[BaseEvent] = field(default=Factory(lambda: Queue()), init=False)

def __attrs_post_init__(self) -> None:
Expand All @@ -74,6 +75,10 @@ def tasks(self) -> list[BaseTask]:
def execution_args(self) -> tuple:
return self._execution_args

@property
def execution_kwargs(self) -> dict:
return self._execution_kwargs

@property
def input_task(self) -> Optional[BaseTask]:
return self.tasks[0] if self.tasks else None
Expand Down Expand Up @@ -125,7 +130,7 @@ def add_tasks(self, *tasks: BaseTask | list[BaseTask]) -> list[BaseTask]:
return added_tasks

def context(self, task: BaseTask) -> dict[str, Any]:
return {"args": self.execution_args, "structure": self}
return {"args": self.execution_args, "kwargs": self.execution_kwargs, "structure": self}

def resolve_relationships(self) -> None:
task_by_id = {}
Expand Down Expand Up @@ -154,7 +159,6 @@ def resolve_relationships(self) -> None:
@observable
def before_run(self, args: Any) -> None:
super().before_run(args)
self._execution_args = args

[task.reset() for task in self.tasks]

Expand Down Expand Up @@ -197,7 +201,9 @@ def after_run(self) -> None:
def add_task(self, task: BaseTask) -> BaseTask: ...

@observable
def run(self, *args) -> Structure:
def run(self, *args, **kwargs) -> Structure:
self._execution_args = args
self._execution_kwargs = kwargs
self.before_run(args)

result = self.try_run(*args)
Expand Down
21 changes: 19 additions & 2 deletions griptape/tasks/base_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,16 @@ class State(Enum):

output: Optional[BaseArtifact] = field(default=None, init=False)
context: dict[str, Any] = field(factory=dict, kw_only=True, metadata={"serializable": True})
_execution_args: tuple = field(factory=tuple, init=False)
_execution_kwargs: dict[str, Any] = field(factory=dict, init=False)

@property
def execution_args(self) -> tuple:
return self._execution_args

@property
def execution_kwargs(self) -> dict:
return self._execution_kwargs

def __rshift__(self, other: BaseTask | list[BaseTask]) -> BaseTask | list[BaseTask]:
if isinstance(other, list):
Expand Down Expand Up @@ -160,8 +170,11 @@ def before_run(self) -> None:
),
)

def run(self) -> BaseArtifact:
def run(self, *args, **kwargs) -> BaseArtifact:
try:
self._execution_args = args
self._execution_kwargs = kwargs

self.state = BaseTask.State.RUNNING

self.before_run()
Expand Down Expand Up @@ -209,6 +222,8 @@ def can_run(self) -> bool:
def reset(self) -> BaseTask:
self.state = BaseTask.State.PENDING
self.output = None
self._execution_args = ()
self._execution_kwargs = {}

return self

Expand All @@ -219,7 +234,9 @@ def try_run(self) -> BaseArtifact: ...
def full_context(self) -> dict[str, Any]:
# Need to deep copy so that the serialized context doesn't contain non-serializable data
context = deepcopy(self.context)
if self.structure is not None:
if self.structure is None:
context.update({"args": self._execution_args, "kwargs": self._execution_kwargs})
else:
context.update(self.structure.context(self))

return context
4 changes: 3 additions & 1 deletion tests/unit/structures/test_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,11 +215,13 @@ def test_context(self):

agent.add_task(task)

agent.run("hello")
agent.run("hello", foo="bar")

context = agent.context(task)

assert context["structure"] == agent
assert context["args"] == ("hello",)
assert context["kwargs"] == {"foo": "bar"}

def test_task_memory_defaults(self, mock_config):
agent = Agent()
Expand Down
4 changes: 3 additions & 1 deletion tests/unit/structures/test_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -356,7 +356,7 @@ def test_context(self):

assert context["parent_output"] is None

pipeline.run()
pipeline.run("hello", foo="bar")

context = pipeline.context(task)

Expand All @@ -365,6 +365,8 @@ def test_context(self):
assert context["structure"] == pipeline
assert context["parent"] == parent
assert context["child"] == child
assert context["args"] == ("hello",)
assert context["kwargs"] == {"foo": "bar"}

def test_run_with_error_artifact(self, error_artifact_task, waiting_task):
end_task = PromptTask("end")
Expand Down
4 changes: 3 additions & 1 deletion tests/unit/structures/test_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -739,7 +739,7 @@ def test_context(self):

assert context["parent_outputs"] == {}

workflow.run()
workflow.run("hello", foo="bar")

context = workflow.context(task)

Expand All @@ -749,6 +749,8 @@ def test_context(self):
assert context["structure"] == workflow
assert context["parents"] == {parent.id: parent}
assert context["children"] == {child.id: child}
assert context["args"] == ("hello",)
assert context["kwargs"] == {"foo": "bar"}

def test_run_with_error_artifact(self, error_artifact_task, waiting_task):
end_task = PromptTask("end")
Expand Down
31 changes: 30 additions & 1 deletion tests/unit/tasks/test_base_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,8 +229,19 @@ def test_runnable_mixin(self):
def test_full_context(self, task):
task.structure = Agent()
task.structure._execution_args = ("foo", "bar")
task.structure._execution_kwargs = {"baz": "qux"}

assert task.full_context == {"args": ("foo", "bar"), "structure": task.structure}
assert task.full_context == {"args": ("foo", "bar"), "kwargs": {"baz": "qux"}, "structure": task.structure}
assert task.structure.execution_args == ("foo", "bar")
assert task.structure.execution_kwargs == {"baz": "qux"}

task.structure = None
task._execution_args = ("foo", "bar")
task._execution_kwargs = {"baz": "qux"}

assert task.full_context == {"args": ("foo", "bar"), "kwargs": {"baz": "qux"}}
assert task.execution_args == ("foo", "bar")
assert task.execution_kwargs == {"baz": "qux"}

def test_is_pending(self, task):
task.state = task.State.PENDING
Expand All @@ -249,3 +260,21 @@ def test___str__(self, task):
assert str(task) == "foobar"
task.output = None
assert str(task) == ""

def test_run_args(self, task):
task.run("foo", "bar")

assert task._execution_args == ("foo", "bar")

def test_run_kwargs(self, task):
task.run(foo="bar")

assert task._execution_kwargs == {"foo": "bar"}

def test_args_full_context(self):
task = MockTask()
task.context = {"foo": "buzz"}
task.run("foo", "bar", baz="qux")

assert task.full_context["args"] == ("foo", "bar")
assert task.full_context["kwargs"] == {"baz": "qux"}
12 changes: 9 additions & 3 deletions tests/unit/tasks/test_base_text_input_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,15 @@ def test_full_context(self):
subtask = MockTextInputTask("test", context={"foo": "bar"})
child = MockTextInputTask("child")

assert parent.full_context == {}
assert subtask.full_context == {"foo": "bar"}
assert child.full_context == {}
assert parent.full_context == {
"args": (),
"kwargs": {},
}
assert subtask.full_context == {"args": (), "kwargs": {}, "foo": "bar"}
assert child.full_context == {
"args": (),
"kwargs": {},
}

pipeline = Pipeline()

Expand Down

0 comments on commit f2711cd

Please sign in to comment.