1414
1515import threading
1616import time
17- from collections import defaultdict
1817from concurrent .futures import ThreadPoolExecutor , as_completed
19- from functools import partial
20- from typing import Any , Dict , List , Optional , Tuple
18+ from typing import Dict , List , Optional , Tuple
2119
2220from langchain_core .language_models import BaseChatModel
2321from langchain_core .messages import HumanMessage
2422from langchain_core .tools import BaseTool
2523from pydantic import BaseModel , ConfigDict , Field
26- from rai .aggregators import BaseAggregator
2724from rclpy .callback_groups import ReentrantCallbackGroup
2825from rclpy .subscription import Subscription
2926
3027from rai .agents .langchain .runnables import create_state_based_runnable
28+ from rai .aggregators import BaseAggregator
3129from rai .communication .hri_connector import HRIConnector , HRIMessage
32- from rai .communication .ros2 .connectors import ROS2Connector
3330from rai .messages .multimodal import HumanMultimodalMessage
3431
3532from .react_agent import ReActAgent , ReActAgentState
@@ -48,7 +45,7 @@ class StateBasedConfig(BaseModel):
4845class StateBasedAgent (ReActAgent ):
4946 """
5047 Agent that runs aggregators (config.aggregators) every config.time_interval seconds.
51- Aggregators are registered to their sources using
48+ Aggregators are registered to their sources using
5249 :py:class:`~rai.communication.ros2.connectors.ROS2Connector`
5350
5451 Output from aggragators is called `state`. Such state is saved and can be
@@ -77,27 +74,19 @@ def __init__(
7774 connectors , llm , tools , state , system_prompt , runnable = runnable
7875 )
7976 self .config = config
80- self .ros2_connector = ros2_connector
8177
8278 self ._callback_group = ReentrantCallbackGroup ()
8379 self ._subscriptions : Dict [str , Subscription ] = dict ()
8480
8581 self ._aggregation_results : Dict [str , HumanMessage | HumanMultimodalMessage ] = (
8682 dict ()
8783 )
88- self ._db : Dict [str , List [Any ]] = defaultdict (list )
89- self ._db_lock = threading .Lock ()
90-
9184 self ._aggregation_thread : threading .Thread | None = None
9285
9386 def _configure_state_sources (self ):
9487 for source , aggregators in self .config .aggregators .items ():
9588 for aggregator in aggregators :
96- self .ros2_connector .register_callback (
97- source ,
98- aggregator ,
99- raw = True
100- )
89+ self .ros2_connector .register_callback (source , aggregator , raw = True )
10190
10291 def run (self ):
10392 super ().run ()
@@ -157,6 +146,6 @@ def stop(self):
157146 if self ._aggregation_thread is not None :
158147 self ._aggregation_thread .join ()
159148 self ._aggregation_thread = None
160- #TODO(boczekbartek): deregister aggregators
149+ # TODO(boczekbartek): deregister aggregators
161150 self ._stop_event .clear ()
162151 self .logger .info ("Agent stopped" )
0 commit comments