
Commit ab85485

committed: wip
1 parent bb98682 commit ab85485

File tree: 10 files changed (+562, -12 lines)

Lines changed: 179 additions & 0 deletions
@@ -0,0 +1,179 @@
# Copyright (C) 2025 Robotec.AI
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import threading
import time
from abc import ABC, abstractmethod
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import Dict, List, Optional, Tuple

from langchain_core.language_models import BaseChatModel
from langchain_core.messages import BaseMessage, HumanMessage
from langchain_core.tools import BaseTool
from pydantic import BaseModel, ConfigDict, Field
from rclpy.callback_groups import ReentrantCallbackGroup
from rclpy.subscription import Subscription

from rai.agents.langchain import create_state_based_runnable
from rai.aggregators import BaseAggregator
from rai.communication.base_connector import BaseConnector
from rai.communication.hri_connector import HRIConnector, HRIMessage
from rai.messages.multimodal import HumanMultimodalMessage

from .langchain import ReActAgent, ReActAgentState, create_state_based_runnable


class StateBasedConfig(BaseModel):
    aggregators: Dict[str, List[BaseAggregator]]
    time_interval: float = Field(default=5.0)
    max_workers: int = 8

    model_config = ConfigDict(
        arbitrary_types_allowed=True,
    )


class BaseStateBasedAgent(ReActAgent, ABC):
    """
    Agent that runs aggregators (config.aggregators) every config.time_interval seconds.
    Aggregators are registered to their sources using
    :py:class:`~rai.communication.ros2.connectors.ROS2Connector`.

    Output from aggregators is called `state`. Such state is saved and can be
    retrieved by the `get_state` method.

    In `BaseStateBasedAgent`, state is added to the LLM history. For more details about
    the LLM agent see :py:func:`~rai.agents.langchain.runnables.create_state_based_runnable`
    """

    def __init__(
        self,
        connectors: dict[str, HRIConnector[HRIMessage]],
        config: StateBasedConfig,
        llm: Optional[BaseChatModel] = None,
        tools: Optional[List[BaseTool]] = None,
        state: Optional[ReActAgentState] = None,
        system_prompt: Optional[str] = None,
    ):
        runnable = create_state_based_runnable(
            llm=llm,
            tools=tools,
            system_prompt=system_prompt,
            state_retriever=self.get_state,
        )
        super().__init__(
            connectors, llm, tools, state, system_prompt, runnable=runnable
        )
        self.config = config

        self._callback_group = ReentrantCallbackGroup()
        self._subscriptions: Dict[str, Subscription] = dict()

        self._aggregation_results: Dict[str, HumanMessage | HumanMultimodalMessage] = (
            dict()
        )
        self._aggregation_thread: threading.Thread | None = None

        self._registered_callbacks = set()
        self._connector = self.setup_connector()
        self._configure_state_sources()

    @abstractmethod
    def setup_connector(self) -> BaseConnector:
        pass

    def _configure_state_sources(self):
        for source, aggregators in self.config.aggregators.items():
            for aggregator in aggregators:
                callback_id = self._connector.register_callback(
                    source, aggregator, raw=True
                )
                self._registered_callbacks.add(callback_id)

    def run(self):
        super().run()
        self._aggregation_thread = threading.Thread(target=self._run_state_loop)
        self._aggregation_thread.start()

    def get_state(self) -> Dict[str, HumanMessage | HumanMultimodalMessage]:
        """Returns output for all aggregators"""
        return self._aggregation_results

    def _run_state_loop(self):
        """Runs aggregation on collected data"""
        while not self._stop_event.is_set():
            ts = time.perf_counter()
            self.logger.debug("Starting aggregation interval")
            self._on_aggregation_interval()
            elapsed_time = time.perf_counter() - ts
            self.logger.debug(f"Aggregation done in: {elapsed_time:.2f}s")
            if elapsed_time > self.config.time_interval:
                self.logger.warning(
                    "State aggregation time interval exceeded. Expected "
                    f"{self.config.time_interval:.2f}s, got {elapsed_time:.2f}s. Consider "
                    f"increasing {self.__class__.__name__}.config.time_interval."
                )
            time.sleep(max(0, self.config.time_interval - elapsed_time))

    def _on_aggregation_interval(self):
        """Runs aggregation on collected data"""

        def process_aggregator(
            source: str, aggregator: BaseAggregator
        ) -> Tuple[str, BaseMessage | None]:
            self.logger.info(
                f"Running aggregator: {aggregator}(source={source}) on {len(aggregator.get_buffer())} messages"
            )
            ts = time.perf_counter()

            output = aggregator.get()

            self.logger.debug(
                f'Aggregator "{aggregator}(source={source})" done in {time.perf_counter() - ts:.2f}s'
            )
            return source, output

        with ThreadPoolExecutor(max_workers=self.config.max_workers) as executor:
            futures = list()
            for source, aggregators in self.config.aggregators.items():
                for aggregator in aggregators:
                    future = executor.submit(process_aggregator, source, aggregator)
                    futures.append(future)

            for future in as_completed(futures):
                try:
                    source, output = future.result()
                except Exception as e:
                    self.logger.error(f"Aggregator crashed: {e}")
                    continue

                if output is None:
                    continue
                self._aggregation_results[source] = output

    def stop(self):
        """Stop the agent's execution loop."""
        self.logger.info("Stopping the agent. Please wait...")
        self._stop_event.set()
        if self.thread is not None:
            self.thread.join()
            self.thread = None
        if self._aggregation_thread is not None:
            self._aggregation_thread.join()
            self._aggregation_thread = None
        self._stop_event.clear()
        for callback_id in self._registered_callbacks:
            self._connector.unregister_callback(callback_id)
        self._connector.shutdown()
        self.logger.info("Agent stopped")
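
The docstring above explains the aggregation loop; a minimal sketch of the expected `StateBasedConfig` shape may help when reviewing it. The topic names and aggregator instances below are placeholders, not part of this commit, and `StateBasedConfig` is assumed to be importable from the module added above; any concrete `BaseAggregator` implementations would fill those slots.

# Sketch only: a hypothetical StateBasedConfig. Topic names and aggregator
# instances are placeholders for concrete BaseAggregator implementations.
from rai.aggregators import BaseAggregator

image_aggregator: BaseAggregator = ...  # e.g. an aggregator summarizing camera frames
scan_aggregator: BaseAggregator = ...   # e.g. an aggregator condensing laser scans

config = StateBasedConfig(
    aggregators={
        "/camera/color/image_raw": [image_aggregator],
        "/scan": [scan_aggregator],
    },
    time_interval=5.0,  # each source's aggregators are drained every 5 s
    max_workers=8,      # aggregation runs in a ThreadPoolExecutor of this size
)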

src/rai_core/rai/agents/langchain/__init__.py

Lines changed: 11 additions & 2 deletions
@@ -12,7 +12,16 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from .agent import BaseState, LangChainAgent
 from .callback import HRICallbackHandler
-from .runnables import create_react_runnable
+from .react_agent import ReActAgent
+from .runnables import ReActAgentState, create_react_runnable
 
-__all__ = ["HRICallbackHandler", "create_react_runnable"]
+__all__ = [
+    "BaseState",
+    "HRICallbackHandler",
+    "LangChainAgent",
+    "ReActAgent",
+    "ReActAgentState",
+    "create_react_runnable",
+]

src/rai_core/rai/agents/langchain/callback.py

Lines changed: 9 additions & 8 deletions
@@ -14,7 +14,7 @@
 
 import logging
 import threading
-from typing import Dict, List, Optional
+from typing import List, Optional
 from uuid import UUID
 
 from langchain_core.callbacks import BaseCallbackHandler
@@ -27,7 +27,7 @@
 class HRICallbackHandler(BaseCallbackHandler):
     def __init__(
         self,
-        connectors: Dict[str, HRIConnector[HRIMessage]],
+        connectors: dict[str, HRIConnector[HRIMessage]],
         aggregate_chunks: bool = False,
         splitting_chars: Optional[List[str]] = None,
         max_buffer_size: int = 200,
@@ -47,20 +47,21 @@ def _should_split(self, token: str) -> bool:
         return token in self.splitting_chars
 
     def _send_all_targets(self, tokens: str, done: bool = False):
-        for target, connector in self.connectors.items():
-            self.logger.info(f"Sending {len(tokens)} tokens to target: {target}")
+        self.logger.info(
+            f"Sending {len(tokens)} tokens to {len(self.connectors)} connectors"
+        )
+        for connector_name, connector in self.connectors.items():
             try:
-                to_send: HRIMessage = connector.build_message(
+                connector.send_all_targets(
                     AIMessage(content=tokens),
                     self.current_conversation_id,
                     self.current_chunk_id,
                     done,
                 )
-                connector.send_message(to_send, target)
-                self.logger.debug(f"Sent {len(tokens)} tokens to hri_connector.")
+                self.logger.debug(f"Sent {len(tokens)} tokens to {connector_name}")
             except Exception as e:
                 self.logger.error(
-                    f"Failed to send {len(tokens)} tokens to hri_connector: {e}"
+                    f"Failed to send {len(tokens)} tokens to {connector_name}: {e}"
                 )
 
     def on_llm_new_token(self, token: str, *, run_id: UUID, **kwargs):

src/rai_core/rai/agents/langchain/runnables.py

Lines changed: 59 additions & 2 deletions
@@ -12,18 +12,28 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import logging
 from functools import partial
-from typing import List, Optional, TypedDict, cast
+from typing import (
+    Any,
+    Callable,
+    Dict,
+    List,
+    Optional,
+    TypedDict,
+    cast,
+)
 
 from langchain_core.language_models import BaseChatModel
-from langchain_core.messages import BaseMessage, SystemMessage
+from langchain_core.messages import BaseMessage, HumanMessage, SystemMessage
 from langchain_core.runnables import Runnable
 from langchain_core.tools import BaseTool
 from langgraph.graph import START, StateGraph
 from langgraph.prebuilt.tool_node import tools_condition
 
 from rai.agents.tool_runner import ToolRunner
 from rai.initialization import get_llm_model
+from rai.messages import HumanMultimodalMessage
 
 
 class ReActAgentState(TypedDict):
@@ -112,3 +122,50 @@ def create_react_runnable(
 
     # Compile the graph
     return graph.compile()
+
+
+def retriever_wrapper(
+    state_retriever: Callable[[], Dict[str, HumanMessage | HumanMultimodalMessage]],
+    state: ReActAgentState,
+):
+    """This wrapper is used to put state messages into LLM context"""
+    for source, message in state_retriever().items():
+        message.content = f"{source}: {message.content}"
+        logging.getLogger("state_retriever").debug(
+            f"Adding state message:\n{message.pretty_repr()}"
+        )
+        state["messages"].append(message)
+    return state
+
+
+def create_state_based_runnable(
+    llm: Optional[BaseChatModel] = None,
+    tools: Optional[List[BaseTool]] = None,
+    system_prompt: Optional[str] = None,
+    state_retriever: Optional[Callable[[], Dict[str, Any]]] = None,
+) -> Runnable[ReActAgentState, ReActAgentState]:
+    if llm is None:
+        llm = get_llm_model("complex_model", streaming=True)
+    graph = StateGraph(ReActAgentState)
+    graph.add_edge(START, "state_retriever")
+    graph.add_edge("state_retriever", "llm")
+    graph.add_conditional_edges(
+        "llm",
+        tools_condition,
+    )
+    graph.add_edge("tools", "state_retriever")
+
+    if state_retriever is None:
+        state_retriever = lambda: {}
+
+    graph.add_node("state_retriever", partial(retriever_wrapper, state_retriever))
+
+    if tools is None:
+        tools = []
+    bound_llm = cast(BaseChatModel, llm.bind_tools(tools))
+    graph.add_node("llm", partial(llm_node, bound_llm, system_prompt))
+
+    tool_runner = ToolRunner(tools)
+    graph.add_node("tools", tool_runner)
+
+    return graph.compile()
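
Because `create_state_based_runnable` is only exercised indirectly through the agent, here is a small, hedged sketch of calling it directly. It assumes `ReActAgentState` carries a `messages` list (as `retriever_wrapper` implies) and that an LLM is resolvable via `get_llm_model("complex_model")` or passed explicitly; the topic key and prompt text are made up for illustration.

# Sketch only: invoking the state-based runnable outside of an agent.
from langchain_core.messages import HumanMessage

from rai.agents.langchain.runnables import create_state_based_runnable


def fake_state():
    # Stands in for BaseStateBasedAgent.get_state(): one message per source.
    return {"/camera/color/image_raw": HumanMessage(content="a red cube on the table")}


runnable = create_state_based_runnable(
    tools=[],  # with no tool calls, tools_condition routes the llm node to END
    system_prompt="Describe the robot's surroundings.",
    state_retriever=fake_state,
)
result = runnable.invoke({"messages": [HumanMessage(content="What do you see?")]})
print(result["messages"][-1].content)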
Lines changed: 16 additions & 0 deletions
@@ -0,0 +1,16 @@
# Copyright (C) 2024 Robotec.AI
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .state_based_agent import ROS2StateBasedAgent

__all__ = ["ROS2StateBasedAgent"]
Lines changed: 21 additions & 0 deletions
@@ -0,0 +1,21 @@
# Copyright (C) 2025 Robotec.AI
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from rai.agents import BaseStateBasedAgent
from rai.communication.ros2 import ROS2Connector


class ROS2StateBasedAgent(BaseStateBasedAgent):
    def setup_connector(self):
        return ROS2Connector()
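
To tie the new pieces together, a hedged end-to-end sketch of the ROS 2 agent's lifecycle. The import paths for `ROS2StateBasedAgent` and `StateBasedConfig` are assumptions (the new `__init__` files do not show their package paths), the empty `connectors` dict is only for illustration, and a running ROS 2 environment is assumed.

# Sketch only: lifecycle of the new ROS2 state-based agent.
from rai.agents import StateBasedConfig          # export path assumed
from rai.agents.ros2 import ROS2StateBasedAgent  # package path assumed

config = StateBasedConfig(aggregators={})  # fill with source -> [BaseAggregator], see earlier sketch

agent = ROS2StateBasedAgent(
    connectors={},  # HRI connectors for user-facing input/output (none in this sketch)
    config=config,
    system_prompt="Summarize the robot's state.",
)
agent.run()   # starts the ReAct loop plus the aggregation thread
# ... state is aggregated every config.time_interval seconds ...
agent.stop()  # joins both threads, unregisters callbacks, shuts down the connector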
Lines changed: 17 additions & 0 deletions
@@ -0,0 +1,17 @@
# Copyright (C) 2025 Robotec.AI
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .base import BaseAggregator

__all__ = ["BaseAggregator"]
