|
23 | 23 | from langchain_core.messages import HumanMessage |
24 | 24 | from langchain_core.tools import BaseTool |
25 | 25 | from pydantic import BaseModel, ConfigDict, Field |
| 26 | +from rai.aggregators import BaseAggregator |
26 | 27 | from rclpy.callback_groups import ReentrantCallbackGroup |
27 | 28 | from rclpy.subscription import Subscription |
28 | 29 |
|
29 | 30 | from rai.agents.langchain.runnables import create_state_based_runnable |
30 | | -from rai.agents.postprocessors import BaseStatePostprocessor |
31 | 31 | from rai.communication.hri_connector import HRIConnector, HRIMessage |
32 | 32 | from rai.communication.ros2.api.conversion import import_message_from_str |
33 | 33 | from rai.communication.ros2.api.topic import TopicConfig |
|
38 | 38 |
|
39 | 39 |
|
40 | 40 | class StateBasedConfig(BaseModel): |
41 | | - postprocessors: Dict[str, List[BaseStatePostprocessor]] |
| 41 | + aggregators: Dict[str, List[BaseAggregator]] |
42 | 42 | sources: Dict[str, TopicConfig] |
43 | 43 | time_interval: float = Field(default=5.0) |
44 | 44 | max_workers: int = 8 |
@@ -96,26 +96,13 @@ def __init__( |
96 | 96 | self._aggregation_thread: threading.Thread | None = None |
97 | 97 |
|
98 | 98 | def _configure_state_sources(self): |
99 | | - for topic, config in self.config.sources.items(): |
100 | | - if topic in self._subscriptions: |
101 | | - continue |
102 | | - # NOTE(boczekbartek): refactor to use confugired_callbacks once implemented |
103 | | - # in the connector |
104 | | - qos_profile = self._ros2_connector._topic_api._resolve_qos_profile( |
105 | | - topic, config.auto_qos_matching, config.qos_profile, for_publisher=False |
106 | | - ) |
107 | | - msg_type = import_message_from_str(config.msg_type) |
108 | | - self._subscriptions[topic] = self._ros2_connector.node.create_subscription( |
109 | | - msg_type=msg_type, |
110 | | - topic=topic, |
111 | | - callback=partial(self._state_topic_callback, topic), |
112 | | - qos_profile=qos_profile, |
113 | | - callback_group=self._callback_group, |
114 | | - ) |
115 | | - |
116 | | - def _state_topic_callback(self, topic_name: str, msg: Any): |
117 | | - with self._db_lock: |
118 | | - self._db[topic_name].append(msg) |
| 99 | + for source, aggregators in self.config.aggregators.items(): |
| 100 | + for aggregator in aggregators: |
| 101 | + self._ros2_connector.register_callback( |
| 102 | + source, |
| 103 | + aggregator, |
| 104 | + raw=True |
| 105 | + ) |
119 | 106 |
|
120 | 107 | def run(self): |
121 | 108 | super().run() |
@@ -146,30 +133,25 @@ def _on_aggregation_interval(self): |
146 | 133 | """Runs aggregation on collected data""" |
147 | 134 |
|
148 | 135 | def process_aggregator( |
149 | | - source: str, postprocessor: BaseStatePostprocessor |
| 136 | + source: str, aggregator: BaseAggregator |
150 | 137 | ) -> Tuple[str, HumanMessage | HumanMultimodalMessage | None]: |
151 | | - with self._db_lock: |
152 | | - msgs = tuple(self._db[source]) # shallow, immutable copy |
153 | 138 | self.logger.info( |
154 | | - f"Running postprocessor: {postprocessor}(source={source}) on {len(msgs)} messages" |
| 139 | + f"Running postprocessor: {aggregator}(source={source}) on {len(aggregator.get_buffer())} messages" |
155 | 140 | ) |
156 | | - return source, postprocessor(msgs) |
| 141 | + return source, aggregator.get() |
157 | 142 |
|
158 | 143 | with ThreadPoolExecutor(max_workers=self.config.max_workers) as executor: |
159 | 144 | futures = list() |
160 | | - for source, postprocessors in self.config.postprocessors.items(): |
161 | | - for postprocessor in postprocessors: |
162 | | - future = executor.submit(process_aggregator, source, postprocessor) |
| 145 | + for source, aggregators in self.config.aggregators.items(): |
| 146 | + for aggregator in aggregators: |
| 147 | + future = executor.submit(process_aggregator, source, aggregator) |
163 | 148 | futures.append(future) |
164 | 149 |
|
165 | 150 | for future in as_completed(futures): |
166 | 151 | source, output = future.result() |
167 | 152 | if output is not None: |
168 | 153 | self._aggregation_results[source] = output |
169 | 154 |
|
170 | | - with self._db_lock: |
171 | | - self._db.clear() |
172 | | - |
173 | 155 | def stop(self): |
174 | 156 | """Stop the agent's execution loop.""" |
175 | 157 | self.logger.info("Stopping the agent. Please wait...") |
|
0 commit comments