diff --git a/CHANGELOG.md b/CHANGELOG.md index f3fe9185e..98bcb8423 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -22,6 +22,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## Unreleased +### Changed +- `BaseTask.run` now accepts args and kwargs which are added to the Task's context under `args` and `kwargs` keys. +- `Structure.run` now accepts kwargs which is added to the Task's context under `kwargs` key. + + ## [1.2.0] - 2025-01-21 ### Added diff --git a/docs/griptape-framework/structures/tasks.md b/docs/griptape-framework/structures/tasks.md index 5cdced709..5210ba552 100644 --- a/docs/griptape-framework/structures/tasks.md +++ b/docs/griptape-framework/structures/tasks.md @@ -13,6 +13,7 @@ Tasks that take input have a field [input](../../reference/griptape/tasks/base_t Within the [input](../../reference/griptape/tasks/base_text_input_task.md#griptape.tasks.base_text_input_task.BaseTextInputTask.input), you can access the following [context](../../reference/griptape/structures/structure.md#griptape.structures.structure.Structure.context) variables: - `args`: an array of arguments passed to the `.run()` method. +- `kwargs`: an array of keyword arguments passed to the `.run()` method. - `structure`: the structure that the task belongs to. - user defined context variables diff --git a/griptape/structures/structure.py b/griptape/structures/structure.py index 24d9d5233..4804ab2c9 100644 --- a/griptape/structures/structure.py +++ b/griptape/structures/structure.py @@ -48,7 +48,8 @@ class Structure(RuleMixin, SerializableMixin, RunnableMixin["Structure"], ABC): ) meta_memory: MetaMemory = field(default=Factory(lambda: MetaMemory()), kw_only=True) fail_fast: bool = field(default=True, kw_only=True, metadata={"serializable": True}) - _execution_args: tuple = () + _execution_args: tuple = field(factory=tuple, init=False) + _execution_kwargs: dict[str, Any] = field(factory=dict, init=False) _event_queue: Queue[BaseEvent] = field(default=Factory(lambda: Queue()), init=False) def __attrs_post_init__(self) -> None: @@ -74,6 +75,10 @@ def tasks(self) -> list[BaseTask]: def execution_args(self) -> tuple: return self._execution_args + @property + def execution_kwargs(self) -> dict: + return self._execution_kwargs + @property def input_task(self) -> Optional[BaseTask]: return self.tasks[0] if self.tasks else None @@ -125,7 +130,7 @@ def add_tasks(self, *tasks: BaseTask | list[BaseTask]) -> list[BaseTask]: return added_tasks def context(self, task: BaseTask) -> dict[str, Any]: - return {"args": self.execution_args, "structure": self} + return {"args": self.execution_args, "kwargs": self.execution_kwargs, "structure": self} def resolve_relationships(self) -> None: task_by_id = {} @@ -154,7 +159,6 @@ def resolve_relationships(self) -> None: @observable def before_run(self, args: Any) -> None: super().before_run(args) - self._execution_args = args [task.reset() for task in self.tasks] @@ -197,7 +201,9 @@ def after_run(self) -> None: def add_task(self, task: BaseTask) -> BaseTask: ... @observable - def run(self, *args) -> Structure: + def run(self, *args, **kwargs) -> Structure: + self._execution_args = args + self._execution_kwargs = kwargs self.before_run(args) result = self.try_run(*args) diff --git a/griptape/tasks/base_task.py b/griptape/tasks/base_task.py index e0036b1a5..2ba998e8f 100644 --- a/griptape/tasks/base_task.py +++ b/griptape/tasks/base_task.py @@ -40,6 +40,16 @@ class State(Enum): output: Optional[BaseArtifact] = field(default=None, init=False) context: dict[str, Any] = field(factory=dict, kw_only=True, metadata={"serializable": True}) + _execution_args: tuple = field(factory=tuple, init=False) + _execution_kwargs: dict[str, Any] = field(factory=dict, init=False) + + @property + def execution_args(self) -> tuple: + return self._execution_args + + @property + def execution_kwargs(self) -> dict: + return self._execution_kwargs def __rshift__(self, other: BaseTask | list[BaseTask]) -> BaseTask | list[BaseTask]: if isinstance(other, list): @@ -160,8 +170,11 @@ def before_run(self) -> None: ), ) - def run(self) -> BaseArtifact: + def run(self, *args, **kwargs) -> BaseArtifact: try: + self._execution_args = args + self._execution_kwargs = kwargs + self.state = BaseTask.State.RUNNING self.before_run() @@ -209,6 +222,8 @@ def can_run(self) -> bool: def reset(self) -> BaseTask: self.state = BaseTask.State.PENDING self.output = None + self._execution_args = () + self._execution_kwargs = {} return self @@ -219,7 +234,9 @@ def try_run(self) -> BaseArtifact: ... def full_context(self) -> dict[str, Any]: # Need to deep copy so that the serialized context doesn't contain non-serializable data context = deepcopy(self.context) - if self.structure is not None: + if self.structure is None: + context.update({"args": self._execution_args, "kwargs": self._execution_kwargs}) + else: context.update(self.structure.context(self)) return context diff --git a/tests/unit/structures/test_agent.py b/tests/unit/structures/test_agent.py index 442f654d5..ee90a6435 100644 --- a/tests/unit/structures/test_agent.py +++ b/tests/unit/structures/test_agent.py @@ -215,11 +215,13 @@ def test_context(self): agent.add_task(task) - agent.run("hello") + agent.run("hello", foo="bar") context = agent.context(task) assert context["structure"] == agent + assert context["args"] == ("hello",) + assert context["kwargs"] == {"foo": "bar"} def test_task_memory_defaults(self, mock_config): agent = Agent() diff --git a/tests/unit/structures/test_pipeline.py b/tests/unit/structures/test_pipeline.py index d461b5bbf..04b548e2e 100644 --- a/tests/unit/structures/test_pipeline.py +++ b/tests/unit/structures/test_pipeline.py @@ -356,7 +356,7 @@ def test_context(self): assert context["parent_output"] is None - pipeline.run() + pipeline.run("hello", foo="bar") context = pipeline.context(task) @@ -365,6 +365,8 @@ def test_context(self): assert context["structure"] == pipeline assert context["parent"] == parent assert context["child"] == child + assert context["args"] == ("hello",) + assert context["kwargs"] == {"foo": "bar"} def test_run_with_error_artifact(self, error_artifact_task, waiting_task): end_task = PromptTask("end") diff --git a/tests/unit/structures/test_workflow.py b/tests/unit/structures/test_workflow.py index a40b20b93..b71c0e2b5 100644 --- a/tests/unit/structures/test_workflow.py +++ b/tests/unit/structures/test_workflow.py @@ -739,7 +739,7 @@ def test_context(self): assert context["parent_outputs"] == {} - workflow.run() + workflow.run("hello", foo="bar") context = workflow.context(task) @@ -749,6 +749,8 @@ def test_context(self): assert context["structure"] == workflow assert context["parents"] == {parent.id: parent} assert context["children"] == {child.id: child} + assert context["args"] == ("hello",) + assert context["kwargs"] == {"foo": "bar"} def test_run_with_error_artifact(self, error_artifact_task, waiting_task): end_task = PromptTask("end") diff --git a/tests/unit/tasks/test_base_task.py b/tests/unit/tasks/test_base_task.py index 8ceea4c77..517bfd7c5 100644 --- a/tests/unit/tasks/test_base_task.py +++ b/tests/unit/tasks/test_base_task.py @@ -229,8 +229,19 @@ def test_runnable_mixin(self): def test_full_context(self, task): task.structure = Agent() task.structure._execution_args = ("foo", "bar") + task.structure._execution_kwargs = {"baz": "qux"} - assert task.full_context == {"args": ("foo", "bar"), "structure": task.structure} + assert task.full_context == {"args": ("foo", "bar"), "kwargs": {"baz": "qux"}, "structure": task.structure} + assert task.structure.execution_args == ("foo", "bar") + assert task.structure.execution_kwargs == {"baz": "qux"} + + task.structure = None + task._execution_args = ("foo", "bar") + task._execution_kwargs = {"baz": "qux"} + + assert task.full_context == {"args": ("foo", "bar"), "kwargs": {"baz": "qux"}} + assert task.execution_args == ("foo", "bar") + assert task.execution_kwargs == {"baz": "qux"} def test_is_pending(self, task): task.state = task.State.PENDING @@ -249,3 +260,21 @@ def test___str__(self, task): assert str(task) == "foobar" task.output = None assert str(task) == "" + + def test_run_args(self, task): + task.run("foo", "bar") + + assert task._execution_args == ("foo", "bar") + + def test_run_kwargs(self, task): + task.run(foo="bar") + + assert task._execution_kwargs == {"foo": "bar"} + + def test_args_full_context(self): + task = MockTask() + task.context = {"foo": "buzz"} + task.run("foo", "bar", baz="qux") + + assert task.full_context["args"] == ("foo", "bar") + assert task.full_context["kwargs"] == {"baz": "qux"} diff --git a/tests/unit/tasks/test_base_text_input_task.py b/tests/unit/tasks/test_base_text_input_task.py index 1be3904c5..8c79d0892 100644 --- a/tests/unit/tasks/test_base_text_input_task.py +++ b/tests/unit/tasks/test_base_text_input_task.py @@ -31,9 +31,15 @@ def test_full_context(self): subtask = MockTextInputTask("test", context={"foo": "bar"}) child = MockTextInputTask("child") - assert parent.full_context == {} - assert subtask.full_context == {"foo": "bar"} - assert child.full_context == {} + assert parent.full_context == { + "args": (), + "kwargs": {}, + } + assert subtask.full_context == {"args": (), "kwargs": {}, "foo": "bar"} + assert child.full_context == { + "args": (), + "kwargs": {}, + } pipeline = Pipeline()