Skip to content

Commit 75506c4

Browse files
authored
feat: add StateBaseAgent (#529)
1 parent 20b14b4 commit 75506c4

File tree

16 files changed

+654
-203
lines changed

16 files changed

+654
-203
lines changed

docs/developer_guide.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -132,7 +132,7 @@ Choose the conversational_agent when beginning with RAI, and consider upgrading
132132
from myrobot import robot
133133

134134
from rai.agents.conversational_agent import create_conversational_agent
135-
from rai.agents.state_based import create_state_based_agent
135+
from rai.agents.langchain import create_state_based_runnable
136136
from rai import get_llm_model
137137

138138
SYSTEM_PROMPT = "You are a robot with interfaces..."
@@ -144,7 +144,7 @@ tools = [pick_up_object, scan_object, SayTool(robot=robot)]
144144
conversational_agent = create_conversational_agent(
145145
llm=llm, tools=tools, system_prompt=SYSTEM_PROMPT
146146
)
147-
state_based_agent = create_state_based_agent(
147+
state_based_agent = create_state_based_runnable(
148148
llm=llm, state_retriever=state_retriever, tools=tools
149149
)
150150

examples/agents/state_based.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
# Copyright (C) 2025 Robotec.AI
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language goveself.rning permissions and
13+
# limitations under the License.
14+
15+
from rai.agents import wait_for_shutdown
16+
from rai.agents.langchain import StateBasedConfig
17+
from rai.agents.ros2 import ROS2StateBasedAgent
18+
from rai.aggregators.ros2 import (
19+
ROS2ImgVLMDiffAggregator,
20+
ROS2LogsAggregator,
21+
)
22+
from rai.communication.ros2 import (
23+
ROS2Connector,
24+
ROS2Context,
25+
ROS2HRIConnector,
26+
)
27+
from rai.tools.ros2.generic.toolkit import ROS2Toolkit
28+
29+
30+
@ROS2Context()
31+
def main():
32+
hri_connector = ROS2HRIConnector()
33+
ros2_connector = ROS2Connector()
34+
35+
config = StateBasedConfig(
36+
aggregators={
37+
("/camera/camera/color/image_raw", "sensor_msgs/msg/Image"): [
38+
ROS2ImgVLMDiffAggregator()
39+
],
40+
"/rosout": [
41+
ROS2LogsAggregator()
42+
], # if msg_type is not provided, topic has to exist
43+
}
44+
)
45+
46+
agent = ROS2StateBasedAgent(
47+
config=config,
48+
target_connectors={"to_human": hri_connector},
49+
tools=ROS2Toolkit(connector=ros2_connector).get_tools(),
50+
)
51+
agent.subscribe_source("/from_human", hri_connector)
52+
agent.run()
53+
wait_for_shutdown([agent])
54+
55+
56+
if __name__ == "__main__":
57+
main()

src/rai_core/rai/agents/__init__.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616
from rai.agents.conversational_agent import create_conversational_agent
1717
from rai.agents.langchain.react_agent import ReActAgent
1818
from rai.agents.runner import AgentRunner, wait_for_shutdown
19-
from rai.agents.state_based import create_state_based_agent
2019
from rai.agents.tool_runner import ToolRunner
2120

2221
__all__ = [
@@ -25,6 +24,5 @@
2524
"ReActAgent",
2625
"ToolRunner",
2726
"create_conversational_agent",
28-
"create_state_based_agent",
2927
"wait_for_shutdown",
3028
]

src/rai_core/rai/agents/langchain/__init__.py

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,25 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
from .agent import BaseState, LangChainAgent, newMessageBehaviorType
1516
from .callback import HRICallbackHandler
16-
from .runnables import create_react_runnable
17+
from .react_agent import ReActAgent
18+
from .runnables import (
19+
ReActAgentState,
20+
create_react_runnable,
21+
create_state_based_runnable,
22+
)
23+
from .state_based_agent import BaseStateBasedAgent, StateBasedConfig
1724

18-
__all__ = ["HRICallbackHandler", "create_react_runnable"]
25+
__all__ = [
26+
"BaseState",
27+
"BaseStateBasedAgent",
28+
"HRICallbackHandler",
29+
"LangChainAgent",
30+
"ReActAgent",
31+
"ReActAgentState",
32+
"StateBasedConfig",
33+
"create_react_runnable",
34+
"create_state_based_runnable",
35+
"newMessageBehaviorType",
36+
]

src/rai_core/rai/agents/langchain/agent.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,12 +22,13 @@
2222
from langchain_core.messages import BaseMessage
2323
from langchain_core.runnables import Runnable
2424

25-
from rai.agents.base import BaseAgent
26-
from rai.agents.langchain import HRICallbackHandler
27-
from rai.agents.langchain.runnables import ReActAgentState
2825
from rai.communication.hri_connector import HRIConnector, HRIMessage
2926
from rai.initialization import get_tracing_callbacks
3027

28+
from ..base import BaseAgent
29+
from .callback import HRICallbackHandler
30+
from .runnables import ReActAgentState
31+
3132

3233
class BaseState(TypedDict):
3334
messages: List[BaseMessage]
@@ -208,6 +209,7 @@ def _run_loop(self):
208209
self._run_agent()
209210

210211
def stop(self):
212+
"""Stop the agent's execution loop."""
211213
self._stop_event.set()
212214
self._interrupt_event.set()
213215
self._agent_ready_event.wait()
@@ -216,6 +218,7 @@ def stop(self):
216218
self._thread.join()
217219
self._thread = None
218220
self.logger.info("Agent stopped")
221+
self._stop_event.clear()
219222

220223
@staticmethod
221224
def _apply_reduction_behavior(

src/rai_core/rai/agents/langchain/react_agent.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,11 @@
1717
from langchain_core.language_models import BaseChatModel
1818
from langchain_core.tools import BaseTool
1919

20-
from rai.agents.langchain import create_react_runnable
21-
from rai.agents.langchain.agent import LangChainAgent
22-
from rai.agents.langchain.runnables import ReActAgentState
2320
from rai.communication.hri_connector import HRIConnector, HRIMessage
2421

22+
from .agent import LangChainAgent
23+
from .runnables import ReActAgentState, create_react_runnable
24+
2525

2626
class ReActAgent(LangChainAgent):
2727
def __init__(

src/rai_core/rai/agents/langchain/runnables.py

Lines changed: 61 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,18 +12,29 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
import logging
1516
from functools import partial
16-
from typing import List, Optional, TypedDict, cast
17+
from typing import (
18+
Any,
19+
Callable,
20+
Dict,
21+
List,
22+
Optional,
23+
TypedDict,
24+
cast,
25+
)
1726

1827
from langchain_core.language_models import BaseChatModel
19-
from langchain_core.messages import BaseMessage, SystemMessage
28+
from langchain_core.messages import BaseMessage, HumanMessage, SystemMessage
2029
from langchain_core.runnables import Runnable
2130
from langchain_core.tools import BaseTool
2231
from langgraph.graph import START, StateGraph
2332
from langgraph.prebuilt.tool_node import tools_condition
2433

25-
from rai.agents.tool_runner import ToolRunner
2634
from rai.initialization import get_llm_model
35+
from rai.messages import HumanMultimodalMessage
36+
37+
from ..tool_runner import ToolRunner
2738

2839

2940
class ReActAgentState(TypedDict):
@@ -112,3 +123,50 @@ def create_react_runnable(
112123

113124
# Compile the graph
114125
return graph.compile()
126+
127+
128+
def retriever_wrapper(
129+
state_retriever: Callable[[], Dict[str, HumanMessage | HumanMultimodalMessage]],
130+
state: ReActAgentState,
131+
):
132+
"""This wrapper is used to put state messages into LLM context"""
133+
for source, message in state_retriever().items():
134+
message.content = f"{source}: {message.content}"
135+
logging.getLogger("state_retriever").debug(
136+
f"Adding state message:\n{message.pretty_repr()}"
137+
)
138+
state["messages"].append(message)
139+
return state
140+
141+
142+
def create_state_based_runnable(
143+
llm: Optional[BaseChatModel] = None,
144+
tools: Optional[List[BaseTool]] = None,
145+
system_prompt: Optional[str] = None,
146+
state_retriever: Optional[Callable[[], Dict[str, Any]]] = None,
147+
) -> Runnable[ReActAgentState, ReActAgentState]:
148+
if llm is None:
149+
llm = get_llm_model("complex_model", streaming=True)
150+
graph = StateGraph(ReActAgentState)
151+
graph.add_edge(START, "state_retriever")
152+
graph.add_edge("state_retriever", "llm")
153+
graph.add_conditional_edges(
154+
"llm",
155+
tools_condition,
156+
)
157+
graph.add_edge("tools", "state_retriever")
158+
159+
if state_retriever is None:
160+
state_retriever = lambda: {}
161+
162+
graph.add_node("state_retriever", partial(retriever_wrapper, state_retriever))
163+
164+
if tools is None:
165+
tools = []
166+
bound_llm = cast(BaseChatModel, llm.bind_tools(tools))
167+
graph.add_node("llm", partial(llm_node, bound_llm, system_prompt))
168+
169+
tool_runner = ToolRunner(tools)
170+
graph.add_node("tools", tool_runner)
171+
172+
return graph.compile()

0 commit comments

Comments
 (0)