Skip to content

Commit ff99594

Browse files
committed
refactor: aggregators with state
1 parent f6fa467 commit ff99594

File tree

6 files changed

+87
-60
lines changed

6 files changed

+87
-60
lines changed

src/rai_core/rai/agents/postprocessors/base.py

Lines changed: 0 additions & 35 deletions
This file was deleted.

src/rai_core/rai/agents/state_based_agent.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@
3131
from rai.communication.hri_connector import HRIConnector, HRIMessage
3232
from rai.communication.ros2.api.conversion import import_message_from_str
3333
from rai.communication.ros2.api.topic import TopicConfig
34-
from rai.communication.ros2.connectors import ROS2ARIConnector
34+
from rai.communication.ros2.connectors import ROS2Connector
3535
from rai.messages.multimodal import HumanMultimodalMessage
3636

3737
from .react_agent import ReActAgent, ReActAgentState
@@ -81,7 +81,7 @@ def __init__(
8181
)
8282
self.config = config
8383

84-
self._ari_connector = ROS2ARIConnector()
84+
self._ros2_connector = ROS2Connector()
8585
self._callback_group = ReentrantCallbackGroup()
8686
self._subscriptions: Dict[str, Subscription] = dict()
8787

@@ -101,11 +101,11 @@ def _configure_state_sources(self):
101101
continue
102102
# NOTE(boczekbartek): refactor to use confugired_callbacks once implemented
103103
# in the connector
104-
qos_profile = self._ari_connector._topic_api._resolve_qos_profile(
104+
qos_profile = self._ros2_connector._topic_api._resolve_qos_profile(
105105
topic, config.auto_qos_matching, config.qos_profile, for_publisher=False
106106
)
107107
msg_type = import_message_from_str(config.msg_type)
108-
self._subscriptions[topic] = self._ari_connector.node.create_subscription(
108+
self._subscriptions[topic] = self._ros2_connector.node.create_subscription(
109109
msg_type=msg_type,
110110
topic=topic,
111111
callback=partial(self._state_topic_callback, topic),
@@ -182,5 +182,5 @@ def stop(self):
182182
self._aggregation_thread = None
183183
self._stop_event.clear()
184184
for subscription in self._subscriptions.values():
185-
self._ari_connector.node.destroy_subscription(subscription)
185+
self._ros2_connector.node.destroy_subscription(subscription)
186186
self.logger.info("Agent stopped")

src/rai_core/rai/agents/postprocessors/__init__.py renamed to src/rai_core/rai/aggregators/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,6 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
from .base import BaseStatePostprocessor
15+
from .base import BaseAggregator
1616

17-
__all__ = ["BaseStatePostprocessor"]
17+
__all__ = ["BaseAggregator"]

src/rai_core/rai/aggregators/base.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
# Copyright (C) 2025 Robotec.AI
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from abc import ABC, abstractmethod
16+
from collections import deque
17+
from typing import Deque, Generic, TypeVar, List
18+
19+
from langchain_core.messages import BaseMessage
20+
21+
22+
T = TypeVar("T")
23+
24+
25+
class BaseAggregator(ABC, Generic[T]):
26+
"""
27+
Interface for aggregators.
28+
29+
`__call__` method receives a message and appends it to the buffer.
30+
`get` method returns the aggregated message.
31+
"""
32+
def __init__(self, max_size: int | None=None) -> None:
33+
super().__init__()
34+
self._buffer: Deque[T] = deque()
35+
self.max_size = max_size
36+
37+
def __call__(
38+
self, msg: T
39+
) -> None:
40+
if self.max_size is not None and len(self._buffer) >= self.max_size:
41+
self._buffer.popleft()
42+
self._buffer.append(msg)
43+
44+
@abstractmethod
45+
def get(self) -> BaseMessage | None:
46+
""" Returns the aggregated message """
47+
pass
48+
49+
def clear(self) -> None:
50+
self._buffer.clear()
51+
52+
def get_buffer(self) -> List[T]:
53+
return list(self._buffer)
54+
55+
def __str__(self) -> str:
56+
return f"{self.__class__.__name__}"
57+

src/rai_core/rai/agents/postprocessors/ros2/postprocessors.py renamed to src/rai_core/rai/aggregators/ros2/aggregators.py

Lines changed: 23 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -19,18 +19,19 @@
1919
from rcl_interfaces.msg import Log
2020
from sensor_msgs.msg import CompressedImage, Image
2121

22-
from rai.agents.postprocessors import BaseStatePostprocessor
22+
from rai.aggregators import BaseAggregator
2323
from rai.communication.ros2.api.conversion import encode_ros2_img_to_base64
2424
from rai.initialization.model_initialization import get_llm_model
2525
from rai.messages import HumanMultimodalMessage
2626

2727

28-
class ROS2LogsPostprocessor(BaseStatePostprocessor[Log]):
28+
class ROS2LogsPostprocessor(BaseAggregator[Log]):
2929
"""Returns only unique messages while keeping their order"""
3030

3131
levels = {10: "DEBUG", 20: "INFO", 30: "WARNING", 40: "ERROR", 50: "FATAL"}
3232

33-
def __call__(self, msgs: Sequence[Log]) -> HumanMessage:
33+
def get(self) -> HumanMessage:
34+
msgs = self.get_buffer()
3435
buffer = []
3536
prev_parsed = None
3637
counter = 0
@@ -50,12 +51,11 @@ def __call__(self, msgs: Sequence[Log]) -> HumanMessage:
5051
return HumanMessage(content=result)
5152

5253

53-
class ROS2GetLastImagePostprocessor(BaseStatePostprocessor[Image | CompressedImage]):
54+
class ROS2GetLastImagePostprocessor(BaseAggregator[Image | CompressedImage]):
5455
"""Returns the last image from the buffer as base64 encoded string"""
5556

56-
def __call__(
57-
self, msgs: Sequence[Image | CompressedImage]
58-
) -> HumanMultimodalMessage | None:
57+
def get(self) -> HumanMultimodalMessage | None:
58+
msgs = self.get_buffer()
5959
if len(msgs) == 0:
6060
return None
6161
ros2_img = msgs[-1]
@@ -64,21 +64,23 @@ def __call__(
6464

6565

6666
class ROS2ImgVLMDescriptionPostprocessor(
67-
BaseStatePostprocessor[Image | CompressedImage]
67+
BaseAggregator[Image | CompressedImage]
6868
):
6969
"""
7070
Returns the VLM analysis of the last image in the aggregation buffer
7171
"""
7272

73-
def __init__(self) -> None:
74-
super().__init__()
73+
def __init__(self, max_size: int | None=None) -> None:
74+
super().__init__(max_size)
7575
self.llm = get_llm_model(model_type="simple_model", streaming=True)
7676

77-
def __call__(self, msgs: Sequence[Image | CompressedImage]) -> HumanMessage | None:
77+
def get(self) -> HumanMessage | None:
78+
msgs: List[Image | CompressedImage] = self.get_buffer()
7879
if len(msgs) == 0:
7980
return None
8081

8182
b64_images: List[str] = [encode_ros2_img_to_base64(msg) for msg in msgs]
83+
self.clear()
8284

8385
system_prompt = "You are an expert in image analysis and your speciality is the"
8486
"description of images"
@@ -102,18 +104,18 @@ class ROS2ImgDescription(BaseModel):
102104
)
103105

104106

105-
class ROS2ImgVLMDiffPostprocessor(BaseStatePostprocessor[Image | CompressedImage]):
107+
class ROS2ImgVLMDiffPostprocessor(BaseAggregator[Image | CompressedImage]):
106108
"""
107109
Returns the LLM analysis of the differences between 3 images in the
108110
aggregation buffer: 1st, midden, last
109111
"""
110112

111-
def __init__(self) -> None:
112-
super().__init__()
113+
def __init__(self, max_size: int | None=None) -> None:
114+
super().__init__(max_size)
113115
self.llm = get_llm_model(model_type="simple_model", streaming=True)
114116

115117
@staticmethod
116-
def get_key_elements(elements: Sequence[Any]) -> List[Any]:
118+
def get_key_elements(elements: List[Any]) -> List[Any]:
117119
"""
118120
Returns 1st, last and middle elements of the list
119121
"""
@@ -122,11 +124,14 @@ def get_key_elements(elements: Sequence[Any]) -> List[Any]:
122124
middle_index = len(elements) // 2
123125
return [elements[0], elements[middle_index], elements[-1]]
124126

125-
def __call__(self, msgs: Sequence[Any]) -> HumanMessage | None:
126-
if len(msgs) == 0:
127+
def get(self) -> HumanMessage | None:
128+
if len(self.get_buffer()) == 0:
127129
return None
128130

129-
b64_images = [encode_ros2_img_to_base64(msg) for msg in msgs]
131+
b64_images = [encode_ros2_img_to_base64(msg) for msg in self._buffer]
132+
133+
self.clear()
134+
130135
b64_images = self.get_key_elements(b64_images)
131136

132137
system_prompt = "You are an expert in image analysis and your speciality is the comparison of 2 images"

0 commit comments

Comments
 (0)