|
23 | 23 | from langchain_core.messages import HumanMessage
|
24 | 24 | from langchain_core.tools import BaseTool
|
25 | 25 | from pydantic import BaseModel, ConfigDict, Field
|
| 26 | +from rai.aggregators import BaseAggregator |
26 | 27 | from rclpy.callback_groups import ReentrantCallbackGroup
|
27 | 28 | from rclpy.subscription import Subscription
|
28 | 29 |
|
29 | 30 | from rai.agents.langchain.runnables import create_state_based_runnable
|
30 |
| -from rai.agents.postprocessors import BaseStatePostprocessor |
31 | 31 | from rai.communication.hri_connector import HRIConnector, HRIMessage
|
32 | 32 | from rai.communication.ros2.api.conversion import import_message_from_str
|
33 | 33 | from rai.communication.ros2.api.topic import TopicConfig
|
|
38 | 38 |
|
39 | 39 |
|
40 | 40 | class StateBasedConfig(BaseModel):
|
41 |
| - postprocessors: Dict[str, List[BaseStatePostprocessor]] |
| 41 | + aggregators: Dict[str, List[BaseAggregator]] |
42 | 42 | sources: Dict[str, TopicConfig]
|
43 | 43 | time_interval: float = Field(default=5.0)
|
44 | 44 | max_workers: int = 8
|
@@ -96,26 +96,13 @@ def __init__(
|
96 | 96 | self._aggregation_thread: threading.Thread | None = None
|
97 | 97 |
|
98 | 98 | def _configure_state_sources(self):
|
99 |
| - for topic, config in self.config.sources.items(): |
100 |
| - if topic in self._subscriptions: |
101 |
| - continue |
102 |
| - # NOTE(boczekbartek): refactor to use confugired_callbacks once implemented |
103 |
| - # in the connector |
104 |
| - qos_profile = self._ros2_connector._topic_api._resolve_qos_profile( |
105 |
| - topic, config.auto_qos_matching, config.qos_profile, for_publisher=False |
106 |
| - ) |
107 |
| - msg_type = import_message_from_str(config.msg_type) |
108 |
| - self._subscriptions[topic] = self._ros2_connector.node.create_subscription( |
109 |
| - msg_type=msg_type, |
110 |
| - topic=topic, |
111 |
| - callback=partial(self._state_topic_callback, topic), |
112 |
| - qos_profile=qos_profile, |
113 |
| - callback_group=self._callback_group, |
114 |
| - ) |
115 |
| - |
116 |
| - def _state_topic_callback(self, topic_name: str, msg: Any): |
117 |
| - with self._db_lock: |
118 |
| - self._db[topic_name].append(msg) |
| 99 | + for source, aggregators in self.config.aggregators.items(): |
| 100 | + for aggregator in aggregators: |
| 101 | + self._ros2_connector.register_callback( |
| 102 | + source, |
| 103 | + aggregator, |
| 104 | + raw=True |
| 105 | + ) |
119 | 106 |
|
120 | 107 | def run(self):
|
121 | 108 | super().run()
|
@@ -146,30 +133,25 @@ def _on_aggregation_interval(self):
|
146 | 133 | """Runs aggregation on collected data"""
|
147 | 134 |
|
148 | 135 | def process_aggregator(
|
149 |
| - source: str, postprocessor: BaseStatePostprocessor |
| 136 | + source: str, aggregator: BaseAggregator |
150 | 137 | ) -> Tuple[str, HumanMessage | HumanMultimodalMessage | None]:
|
151 |
| - with self._db_lock: |
152 |
| - msgs = tuple(self._db[source]) # shallow, immutable copy |
153 | 138 | self.logger.info(
|
154 |
| - f"Running postprocessor: {postprocessor}(source={source}) on {len(msgs)} messages" |
| 139 | + f"Running postprocessor: {aggregator}(source={source}) on {len(aggregator.get_buffer())} messages" |
155 | 140 | )
|
156 |
| - return source, postprocessor(msgs) |
| 141 | + return source, aggregator.get() |
157 | 142 |
|
158 | 143 | with ThreadPoolExecutor(max_workers=self.config.max_workers) as executor:
|
159 | 144 | futures = list()
|
160 |
| - for source, postprocessors in self.config.postprocessors.items(): |
161 |
| - for postprocessor in postprocessors: |
162 |
| - future = executor.submit(process_aggregator, source, postprocessor) |
| 145 | + for source, aggregators in self.config.aggregators.items(): |
| 146 | + for aggregator in aggregators: |
| 147 | + future = executor.submit(process_aggregator, source, aggregator) |
163 | 148 | futures.append(future)
|
164 | 149 |
|
165 | 150 | for future in as_completed(futures):
|
166 | 151 | source, output = future.result()
|
167 | 152 | if output is not None:
|
168 | 153 | self._aggregation_results[source] = output
|
169 | 154 |
|
170 |
| - with self._db_lock: |
171 |
| - self._db.clear() |
172 |
| - |
173 | 155 | def stop(self):
|
174 | 156 | """Stop the agent's execution loop."""
|
175 | 157 | self.logger.info("Stopping the agent. Please wait...")
|
|
0 commit comments