14
14
15
15
import threading
16
16
import time
17
- from collections import defaultdict
18
17
from concurrent .futures import ThreadPoolExecutor , as_completed
19
- from functools import partial
20
- from typing import Any , Dict , List , Optional , Tuple
18
+ from typing import Dict , List , Optional , Tuple
21
19
22
20
from langchain_core .language_models import BaseChatModel
23
21
from langchain_core .messages import HumanMessage
24
22
from langchain_core .tools import BaseTool
25
23
from pydantic import BaseModel , ConfigDict , Field
26
- from rai .aggregators import BaseAggregator
27
24
from rclpy .callback_groups import ReentrantCallbackGroup
28
25
from rclpy .subscription import Subscription
29
26
30
27
from rai .agents .langchain .runnables import create_state_based_runnable
28
+ from rai .aggregators import BaseAggregator
31
29
from rai .communication .hri_connector import HRIConnector , HRIMessage
32
- from rai .communication .ros2 .connectors import ROS2Connector
33
30
from rai .messages .multimodal import HumanMultimodalMessage
34
31
35
32
from .react_agent import ReActAgent , ReActAgentState
@@ -48,7 +45,7 @@ class StateBasedConfig(BaseModel):
48
45
class StateBasedAgent (ReActAgent ):
49
46
"""
50
47
Agent that runs aggregators (config.aggregators) every config.time_interval seconds.
51
- Aggregators are registered to their sources using
48
+ Aggregators are registered to their sources using
52
49
:py:class:`~rai.communication.ros2.connectors.ROS2Connector`
53
50
54
51
Output from aggragators is called `state`. Such state is saved and can be
@@ -77,27 +74,19 @@ def __init__(
77
74
connectors , llm , tools , state , system_prompt , runnable = runnable
78
75
)
79
76
self .config = config
80
- self .ros2_connector = ros2_connector
81
77
82
78
self ._callback_group = ReentrantCallbackGroup ()
83
79
self ._subscriptions : Dict [str , Subscription ] = dict ()
84
80
85
81
self ._aggregation_results : Dict [str , HumanMessage | HumanMultimodalMessage ] = (
86
82
dict ()
87
83
)
88
- self ._db : Dict [str , List [Any ]] = defaultdict (list )
89
- self ._db_lock = threading .Lock ()
90
-
91
84
self ._aggregation_thread : threading .Thread | None = None
92
85
93
86
def _configure_state_sources (self ):
94
87
for source , aggregators in self .config .aggregators .items ():
95
88
for aggregator in aggregators :
96
- self .ros2_connector .register_callback (
97
- source ,
98
- aggregator ,
99
- raw = True
100
- )
89
+ self .ros2_connector .register_callback (source , aggregator , raw = True )
101
90
102
91
def run (self ):
103
92
super ().run ()
@@ -157,6 +146,6 @@ def stop(self):
157
146
if self ._aggregation_thread is not None :
158
147
self ._aggregation_thread .join ()
159
148
self ._aggregation_thread = None
160
- #TODO(boczekbartek): deregister aggregators
149
+ # TODO(boczekbartek): deregister aggregators
161
150
self ._stop_event .clear ()
162
151
self .logger .info ("Agent stopped" )
0 commit comments