Skip to content

Commit f31ffff

Browse files
committed
refactor(state_based): adapt to register_callback api
1 parent ff99594 commit f31ffff

File tree

5 files changed

+44
-60
lines changed

5 files changed

+44
-60
lines changed

examples/agents/state_based.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,12 @@
1313
# limitations under the License.
1414

1515
from rai.agents import StateBasedAgent, StateBasedConfig, wait_for_shutdown
16-
from rai.agents.postprocessors.ros2 import (
17-
ROS2ImgVLMDiffPostprocessor,
18-
ROS2LogsPostprocessor,
16+
from rai.aggregators.ros2 import (
17+
ROS2ImgVLMDiffAggregator,
18+
ROS2LogsAggregator,
1919
)
2020
from rai.communication.ros2 import (
21-
ROS2ARIConnector,
21+
ROS2Connector,
2222
ROS2Context,
2323
ROS2HRIConnector,
2424
)
@@ -29,12 +29,12 @@
2929
@ROS2Context()
3030
def main():
3131
connector = ROS2HRIConnector(sources=["/from_human"], targets=["/to_human"])
32-
ari_connector = ROS2ARIConnector()
32+
ros2_connector = ROS2Connector()
3333

3434
config = StateBasedConfig(
35-
postprocessors={
36-
"/camera/camera/color/image_raw": [ROS2ImgVLMDiffPostprocessor()],
37-
"/rosout": [ROS2LogsPostprocessor()],
35+
aggregators={
36+
"/camera/camera/color/image_raw": [ROS2ImgVLMDiffAggregator()],
37+
"/rosout": [ROS2LogsAggregator()],
3838
},
3939
sources={
4040
"/camera/camera/color/image_raw": TopicConfig(
@@ -47,7 +47,7 @@ def main():
4747
agent = StateBasedAgent(
4848
config=config,
4949
connectors={"hri": connector},
50-
tools=ROS2Toolkit(connector=ari_connector).get_tools(),
50+
tools=ROS2Toolkit(connector=ros2_connector).get_tools(),
5151
) # type: ignore
5252
agent.run()
5353
wait_for_shutdown([agent])

src/rai_core/rai/agents/state_based_agent.py

Lines changed: 15 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -23,11 +23,11 @@
2323
from langchain_core.messages import HumanMessage
2424
from langchain_core.tools import BaseTool
2525
from pydantic import BaseModel, ConfigDict, Field
26+
from rai.aggregators import BaseAggregator
2627
from rclpy.callback_groups import ReentrantCallbackGroup
2728
from rclpy.subscription import Subscription
2829

2930
from rai.agents.langchain.runnables import create_state_based_runnable
30-
from rai.agents.postprocessors import BaseStatePostprocessor
3131
from rai.communication.hri_connector import HRIConnector, HRIMessage
3232
from rai.communication.ros2.api.conversion import import_message_from_str
3333
from rai.communication.ros2.api.topic import TopicConfig
@@ -38,7 +38,7 @@
3838

3939

4040
class StateBasedConfig(BaseModel):
41-
postprocessors: Dict[str, List[BaseStatePostprocessor]]
41+
aggregators: Dict[str, List[BaseAggregator]]
4242
sources: Dict[str, TopicConfig]
4343
time_interval: float = Field(default=5.0)
4444
max_workers: int = 8
@@ -96,26 +96,13 @@ def __init__(
9696
self._aggregation_thread: threading.Thread | None = None
9797

9898
def _configure_state_sources(self):
99-
for topic, config in self.config.sources.items():
100-
if topic in self._subscriptions:
101-
continue
102-
# NOTE(boczekbartek): refactor to use confugired_callbacks once implemented
103-
# in the connector
104-
qos_profile = self._ros2_connector._topic_api._resolve_qos_profile(
105-
topic, config.auto_qos_matching, config.qos_profile, for_publisher=False
106-
)
107-
msg_type = import_message_from_str(config.msg_type)
108-
self._subscriptions[topic] = self._ros2_connector.node.create_subscription(
109-
msg_type=msg_type,
110-
topic=topic,
111-
callback=partial(self._state_topic_callback, topic),
112-
qos_profile=qos_profile,
113-
callback_group=self._callback_group,
114-
)
115-
116-
def _state_topic_callback(self, topic_name: str, msg: Any):
117-
with self._db_lock:
118-
self._db[topic_name].append(msg)
99+
for source, aggregators in self.config.aggregators.items():
100+
for aggregator in aggregators:
101+
self._ros2_connector.register_callback(
102+
source,
103+
aggregator,
104+
raw=True
105+
)
119106

120107
def run(self):
121108
super().run()
@@ -146,30 +133,25 @@ def _on_aggregation_interval(self):
146133
"""Runs aggregation on collected data"""
147134

148135
def process_aggregator(
149-
source: str, postprocessor: BaseStatePostprocessor
136+
source: str, aggregator: BaseAggregator
150137
) -> Tuple[str, HumanMessage | HumanMultimodalMessage | None]:
151-
with self._db_lock:
152-
msgs = tuple(self._db[source]) # shallow, immutable copy
153138
self.logger.info(
154-
f"Running postprocessor: {postprocessor}(source={source}) on {len(msgs)} messages"
139+
f"Running postprocessor: {aggregator}(source={source}) on {len(aggregator.get_buffer())} messages"
155140
)
156-
return source, postprocessor(msgs)
141+
return source, aggregator.get()
157142

158143
with ThreadPoolExecutor(max_workers=self.config.max_workers) as executor:
159144
futures = list()
160-
for source, postprocessors in self.config.postprocessors.items():
161-
for postprocessor in postprocessors:
162-
future = executor.submit(process_aggregator, source, postprocessor)
145+
for source, aggregators in self.config.aggregators.items():
146+
for aggregator in aggregators:
147+
future = executor.submit(process_aggregator, source, aggregator)
163148
futures.append(future)
164149

165150
for future in as_completed(futures):
166151
source, output = future.result()
167152
if output is not None:
168153
self._aggregation_results[source] = output
169154

170-
with self._db_lock:
171-
self._db.clear()
172-
173155
def stop(self):
174156
"""Stop the agent's execution loop."""
175157
self.logger.info("Stopping the agent. Please wait...")

src/rai_core/rai/aggregators/base.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,8 @@
1616
from collections import deque
1717
from typing import Deque, Generic, TypeVar, List
1818

19-
from langchain_core.messages import BaseMessage
19+
from langchain_core.messages import HumanMessage
20+
from rai.messages.multimodal import HumanMultimodalMessage
2021

2122

2223
T = TypeVar("T")
@@ -42,7 +43,7 @@ def __call__(
4243
self._buffer.append(msg)
4344

4445
@abstractmethod
45-
def get(self) -> BaseMessage | None:
46+
def get(self) -> HumanMessage | HumanMultimodalMessage | None:
4647
""" Returns the aggregated message """
4748
pass
4849

src/rai_core/rai/aggregators/ros2/__init__.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -12,16 +12,16 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
from .postprocessors import (
16-
ROS2GetLastImagePostprocessor,
17-
ROS2ImgVLMDescriptionPostprocessor,
18-
ROS2ImgVLMDiffPostprocessor,
19-
ROS2LogsPostprocessor,
15+
from .aggregators import (
16+
ROS2GetLastImageAggregator,
17+
ROS2ImgVLMDescriptionAggregator,
18+
ROS2ImgVLMDiffAggregator,
19+
ROS2LogsAggregator,
2020
)
2121

2222
__all__ = [
23-
"ROS2GetLastImagePostprocessor",
24-
"ROS2ImgVLMDescriptionPostprocessor",
25-
"ROS2ImgVLMDiffPostprocessor",
26-
"ROS2LogsPostprocessor",
23+
"ROS2GetLastImageAggregator",
24+
"ROS2ImgVLMDescriptionAggregator",
25+
"ROS2ImgVLMDiffAggregator",
26+
"ROS2LogsAggregator",
2727
]

src/rai_core/rai/aggregators/ros2/aggregators.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
from rai.messages import HumanMultimodalMessage
2626

2727

28-
class ROS2LogsPostprocessor(BaseAggregator[Log]):
28+
class ROS2LogsAggregator(BaseAggregator[Log]):
2929
"""Returns only unique messages while keeping their order"""
3030

3131
levels = {10: "DEBUG", 20: "INFO", 30: "WARNING", 40: "ERROR", 50: "FATAL"}
@@ -51,7 +51,7 @@ def get(self) -> HumanMessage:
5151
return HumanMessage(content=result)
5252

5353

54-
class ROS2GetLastImagePostprocessor(BaseAggregator[Image | CompressedImage]):
54+
class ROS2GetLastImageAggregator(BaseAggregator[Image | CompressedImage]):
5555
"""Returns the last image from the buffer as base64 encoded string"""
5656

5757
def get(self) -> HumanMultimodalMessage | None:
@@ -63,7 +63,7 @@ def get(self) -> HumanMultimodalMessage | None:
6363
return HumanMultimodalMessage(content="", images=[b64_image])
6464

6565

66-
class ROS2ImgVLMDescriptionPostprocessor(
66+
class ROS2ImgVLMDescriptionAggregator(
6767
BaseAggregator[Image | CompressedImage]
6868
):
6969
"""
@@ -104,7 +104,7 @@ class ROS2ImgDescription(BaseModel):
104104
)
105105

106106

107-
class ROS2ImgVLMDiffPostprocessor(BaseAggregator[Image | CompressedImage]):
107+
class ROS2ImgVLMDiffAggregator(BaseAggregator[Image | CompressedImage]):
108108
"""
109109
Returns the LLM analysis of the differences between 3 images in the
110110
aggregation buffer: 1st, midden, last
@@ -125,10 +125,11 @@ def get_key_elements(elements: List[Any]) -> List[Any]:
125125
return [elements[0], elements[middle_index], elements[-1]]
126126

127127
def get(self) -> HumanMessage | None:
128-
if len(self.get_buffer()) == 0:
128+
msgs = self.get_buffer()
129+
if len(msgs) == 0:
129130
return None
130131

131-
b64_images = [encode_ros2_img_to_base64(msg) for msg in self._buffer]
132+
b64_images = [encode_ros2_img_to_base64(msg) for msg in msgs]
132133

133134
self.clear()
134135

@@ -154,5 +155,5 @@ class ROS2ImgDiffOutput(BaseModel):
154155
llm = self.llm.with_structured_output(ROS2ImgDiffOutput)
155156
response = cast(ROS2ImgDiffOutput, llm.invoke(task))
156157
return HumanMessage(
157-
content=f"Result of the analysis of the {len(b64_images)} keyframes selected from {len(msgs)} last images:\n{response}"
158+
content=f"Result of the analysis of the {len(b64_images)} keyframes selected from {len(b64_images)} last images:\n{response}"
158159
)

0 commit comments

Comments
 (0)