Skip to content

Commit 49a56bb

Browse files
committed
feat: add StateBasedAgent and rai.aggregators
1 parent ab85485 commit 49a56bb

File tree

11 files changed

+92
-212
lines changed

11 files changed

+92
-212
lines changed

docs/developer_guide.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,7 @@ Choose the conversational_agent when beginning with RAI, and consider upgrading
131131
from myrobot import robot
132132

133133
from rai.agents.conversational_agent import create_conversational_agent
134-
from rai.agents.state_based import create_state_based_agent
134+
from rai.agents.langchain import create_state_based_runnable
135135
from rai import get_llm_model
136136

137137
SYSTEM_PROMPT = "You are a robot with interfaces..."
@@ -143,7 +143,7 @@ tools = [pick_up_object, scan_object, SayTool(robot=robot)]
143143
conversational_agent = create_conversational_agent(
144144
llm=llm, tools=tools, system_prompt=SYSTEM_PROMPT
145145
)
146-
state_based_agent = create_state_based_agent(
146+
state_based_agent = create_state_based_runnable(
147147
llm=llm, state_retriever=state_retriever, tools=tools
148148
)
149149

examples/agents/state_based.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
# Copyright (C) 2025 Robotec.AI
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language goveself.rning permissions and
13+
# limitations under the License.
14+
15+
from rai.agents import wait_for_shutdown
16+
from rai.agents.langchain import StateBasedConfig
17+
from rai.agents.ros2 import ROS2StateBasedAgent
18+
from rai.aggregators.ros2 import (
19+
ROS2ImgVLMDiffAggregator,
20+
ROS2LogsAggregator,
21+
)
22+
from rai.communication.ros2 import (
23+
ROS2Connector,
24+
ROS2Context,
25+
ROS2HRIConnector,
26+
)
27+
from rai.tools.ros2.generic.toolkit import ROS2Toolkit
28+
29+
30+
@ROS2Context()
31+
def main():
32+
hri_connector = ROS2HRIConnector()
33+
ros2_connector = ROS2Connector()
34+
35+
config = StateBasedConfig(
36+
aggregators={
37+
"/camera/camera/color/image_raw": [ROS2ImgVLMDiffAggregator()],
38+
"/rosout": [ROS2LogsAggregator()],
39+
}
40+
)
41+
42+
agent = ROS2StateBasedAgent(
43+
config=config,
44+
target_connectors={"to_human": hri_connector},
45+
tools=ROS2Toolkit(connector=ros2_connector).get_tools(),
46+
)
47+
agent.subscribe_source("/from_human", hri_connector)
48+
agent.run()
49+
wait_for_shutdown([agent])
50+
51+
52+
if __name__ == "__main__":
53+
main()

src/rai_core/rai/agents/__init__.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616
from rai.agents.conversational_agent import create_conversational_agent
1717
from rai.agents.langchain.react_agent import ReActAgent
1818
from rai.agents.runner import AgentRunner, wait_for_shutdown
19-
from rai.agents.state_based import create_state_based_agent
2019
from rai.agents.tool_runner import ToolRunner
2120

2221
__all__ = [
@@ -25,6 +24,5 @@
2524
"ReActAgent",
2625
"ToolRunner",
2726
"create_conversational_agent",
28-
"create_state_based_agent",
2927
"wait_for_shutdown",
3028
]

src/rai_core/rai/agents/langchain/__init__.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,16 +12,25 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
from .agent import BaseState, LangChainAgent
15+
from .agent import BaseState, LangChainAgent, newMessageBehaviorType
1616
from .callback import HRICallbackHandler
1717
from .react_agent import ReActAgent
18-
from .runnables import ReActAgentState, create_react_runnable
18+
from .runnables import (
19+
ReActAgentState,
20+
create_react_runnable,
21+
create_state_based_runnable,
22+
)
23+
from .state_based_agent import BaseStateBasedAgent, StateBasedConfig
1924

2025
__all__ = [
2126
"BaseState",
27+
"BaseStateBasedAgent",
2228
"HRICallbackHandler",
2329
"LangChainAgent",
2430
"ReActAgent",
2531
"ReActAgentState",
32+
"StateBasedConfig",
2633
"create_react_runnable",
34+
"create_state_based_runnable",
35+
"newMessageBehaviorType",
2736
]

src/rai_core/rai/agents/langchain/agent.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,12 +22,13 @@
2222
from langchain_core.messages import BaseMessage
2323
from langchain_core.runnables import Runnable
2424

25-
from rai.agents.base import BaseAgent
26-
from rai.agents.langchain import HRICallbackHandler
27-
from rai.agents.langchain.runnables import ReActAgentState
2825
from rai.communication.hri_connector import HRIConnector, HRIMessage
2926
from rai.initialization import get_tracing_callbacks
3027

28+
from ..base import BaseAgent
29+
from .callback import HRICallbackHandler
30+
from .runnables import ReActAgentState
31+
3132

3233
class BaseState(TypedDict):
3334
messages: List[BaseMessage]

src/rai_core/rai/agents/langchain/react_agent.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,11 @@
1717
from langchain_core.language_models import BaseChatModel
1818
from langchain_core.tools import BaseTool
1919

20-
from rai.agents.langchain import create_react_runnable
21-
from rai.agents.langchain.agent import LangChainAgent
22-
from rai.agents.langchain.runnables import ReActAgentState
2320
from rai.communication.hri_connector import HRIConnector, HRIMessage
2421

22+
from .agent import LangChainAgent
23+
from .runnables import ReActAgentState, create_react_runnable
24+
2525

2626
class ReActAgent(LangChainAgent):
2727
def __init__(

src/rai_core/rai/agents/langchain/runnables.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,10 +31,11 @@
3131
from langgraph.graph import START, StateGraph
3232
from langgraph.prebuilt.tool_node import tools_condition
3333

34-
from rai.agents.tool_runner import ToolRunner
3534
from rai.initialization import get_llm_model
3635
from rai.messages import HumanMultimodalMessage
3736

37+
from ..tool_runner import ToolRunner
38+
3839

3940
class ReActAgentState(TypedDict):
4041
"""State type for the react agent.

src/rai_core/rai/agents/base_state_based_agent.py renamed to src/rai_core/rai/agents/langchain/state_based_agent.py

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -22,16 +22,14 @@
2222
from langchain_core.messages import BaseMessage, HumanMessage
2323
from langchain_core.tools import BaseTool
2424
from pydantic import BaseModel, ConfigDict, Field
25-
from rclpy.callback_groups import ReentrantCallbackGroup
26-
from rclpy.subscription import Subscription
2725

28-
from rai.agents.langchain import create_state_based_runnable
2926
from rai.aggregators import BaseAggregator
3027
from rai.communication.base_connector import BaseConnector
3128
from rai.communication.hri_connector import HRIConnector, HRIMessage
3229
from rai.messages.multimodal import HumanMultimodalMessage
3330

34-
from .langchain import ReActAgent, ReActAgentState, create_state_based_runnable
31+
from .agent import LangChainAgent, newMessageBehaviorType
32+
from .runnables import ReActAgentState, create_state_based_runnable
3533

3634

3735
class StateBasedConfig(BaseModel):
@@ -44,7 +42,7 @@ class StateBasedConfig(BaseModel):
4442
)
4543

4644

47-
class BaseStateBasedAgent(ReActAgent, ABC):
45+
class BaseStateBasedAgent(LangChainAgent, ABC):
4846
"""
4947
Agent that runs aggregators (config.aggregators) every config.time_interval seconds.
5048
Aggregators are registered to their sources using
@@ -59,12 +57,14 @@ class BaseStateBasedAgent(ReActAgent, ABC):
5957

6058
def __init__(
6159
self,
62-
connectors: dict[str, HRIConnector[HRIMessage]],
6360
config: StateBasedConfig,
61+
target_connectors: Dict[str, HRIConnector[HRIMessage]],
6462
llm: Optional[BaseChatModel] = None,
6563
tools: Optional[List[BaseTool]] = None,
6664
state: Optional[ReActAgentState] = None,
6765
system_prompt: Optional[str] = None,
66+
new_message_behavior: newMessageBehaviorType = "interrupt_keep_last",
67+
max_size: int = 100,
6868
):
6969
runnable = create_state_based_runnable(
7070
llm=llm,
@@ -73,13 +73,14 @@ def __init__(
7373
state_retriever=self.get_state,
7474
)
7575
super().__init__(
76-
connectors, llm, tools, state, system_prompt, runnable=runnable
76+
target_connectors=target_connectors,
77+
runnable=runnable,
78+
state=state,
79+
new_message_behavior=new_message_behavior,
80+
max_size=max_size,
7781
)
7882
self.config = config
7983

80-
self._callback_group = ReentrantCallbackGroup()
81-
self._subscriptions: Dict[str, Subscription] = dict()
82-
8384
self._aggregation_results: Dict[str, HumanMessage | HumanMultimodalMessage] = (
8485
dict()
8586
)

src/rai_core/rai/agents/ros2/state_based_agent.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,10 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
from rai.agents import BaseStateBasedAgent
1615
from rai.communication.ros2 import ROS2Connector
1716

17+
from ..langchain import BaseStateBasedAgent
18+
1819

1920
class ROS2StateBasedAgent(BaseStateBasedAgent):
2021
def setup_connector(self):

0 commit comments

Comments
 (0)