Skip to content

Commit 57195fa

Browse files
committed
feat(state_based): aggregators support optional msg type
1 parent 60977eb commit 57195fa

File tree

2 files changed

+12
-4
lines changed

2 files changed

+12
-4
lines changed

examples/agents/state_based.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,9 @@ def main():
3434

3535
config = StateBasedConfig(
3636
aggregators={
37-
"/camera/camera/color/image_raw": [ROS2ImgVLMDiffAggregator()],
37+
("/camera/camera/color/image_raw", "sensor_msgs/msg/Image"): [
38+
ROS2ImgVLMDiffAggregator()
39+
],
3840
"/rosout": [ROS2LogsAggregator()],
3941
}
4042
)

src/rai_core/rai/agents/langchain/state_based_agent.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
import time
1717
from abc import ABC, abstractmethod
1818
from concurrent.futures import ThreadPoolExecutor, as_completed
19-
from typing import Dict, List, Optional, Tuple
19+
from typing import Dict, List, Optional, Tuple, Union
2020

2121
from langchain_core.language_models import BaseChatModel
2222
from langchain_core.messages import BaseMessage, HumanMessage
@@ -33,7 +33,9 @@
3333

3434

3535
class StateBasedConfig(BaseModel):
36-
aggregators: Dict[str, List[BaseAggregator]]
36+
aggregators: Dict[Union[str, Tuple[str, str]], List[BaseAggregator]] = Field(
37+
description="Dict of topic : aggregator or (topic, msg_type) : aggragator"
38+
)
3739
time_interval: float = Field(default=5.0)
3840
max_workers: int = 8
3941

@@ -96,9 +98,13 @@ def setup_connector(self) -> BaseConnector:
9698

9799
def _configure_state_sources(self):
98100
for source, aggregators in self.config.aggregators.items():
101+
if isinstance(source, tuple):
102+
source, msg_type = source
103+
else:
104+
msg_type = None
99105
for aggregator in aggregators:
100106
callback_id = self._connector.register_callback(
101-
source, aggregator, raw=True
107+
source, aggregator, raw=True, msg_type=msg_type
102108
)
103109
self._registered_callbacks.add(callback_id)
104110

0 commit comments

Comments
 (0)