Skip to content

Commit 3998a7b

Browse files
committed
cleanup
1 parent beddec0 commit 3998a7b

File tree

3 files changed

+10
-25
lines changed

3 files changed

+10
-25
lines changed

examples/agents/state_based.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -30,18 +30,14 @@ def main():
3030
connector = ROS2HRIConnector(sources=["/from_human"], targets=["/to_human"])
3131
ros2_connector = ROS2Connector()
3232

33-
aggregators={
33+
aggregators = {
3434
"/camera/camera/color/image_raw": [ROS2ImgVLMDiffAggregator()],
3535
"/rosout": [ROS2LogsAggregator()],
3636
}
3737

38-
for source, aggregators in self.config.aggregators.items():
38+
for source, aggregators in aggregators.items():
3939
for aggregator in aggregators:
40-
ros2_connector.register_callback(
41-
source,
42-
aggregator,
43-
raw=True
44-
)
40+
ros2_connector.register_callback(source, aggregator, raw=True)
4541

4642
agent = StateBasedAgent(
4743
config=StateBasedConfig(aggregators=aggregators),

src/rai_core/rai/agents/state_based_agent.py

Lines changed: 5 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -14,22 +14,19 @@
1414

1515
import threading
1616
import time
17-
from collections import defaultdict
1817
from concurrent.futures import ThreadPoolExecutor, as_completed
19-
from functools import partial
20-
from typing import Any, Dict, List, Optional, Tuple
18+
from typing import Dict, List, Optional, Tuple
2119

2220
from langchain_core.language_models import BaseChatModel
2321
from langchain_core.messages import HumanMessage
2422
from langchain_core.tools import BaseTool
2523
from pydantic import BaseModel, ConfigDict, Field
26-
from rai.aggregators import BaseAggregator
2724
from rclpy.callback_groups import ReentrantCallbackGroup
2825
from rclpy.subscription import Subscription
2926

3027
from rai.agents.langchain.runnables import create_state_based_runnable
28+
from rai.aggregators import BaseAggregator
3129
from rai.communication.hri_connector import HRIConnector, HRIMessage
32-
from rai.communication.ros2.connectors import ROS2Connector
3330
from rai.messages.multimodal import HumanMultimodalMessage
3431

3532
from .react_agent import ReActAgent, ReActAgentState
@@ -48,7 +45,7 @@ class StateBasedConfig(BaseModel):
4845
class StateBasedAgent(ReActAgent):
4946
"""
5047
Agent that runs aggregators (config.aggregators) every config.time_interval seconds.
51-
Aggregators are registered to their sources using
48+
Aggregators are registered to their sources using
5249
:py:class:`~rai.communication.ros2.connectors.ROS2Connector`
5350
5451
Output from aggragators is called `state`. Such state is saved and can be
@@ -77,27 +74,19 @@ def __init__(
7774
connectors, llm, tools, state, system_prompt, runnable=runnable
7875
)
7976
self.config = config
80-
self.ros2_connector = ros2_connector
8177

8278
self._callback_group = ReentrantCallbackGroup()
8379
self._subscriptions: Dict[str, Subscription] = dict()
8480

8581
self._aggregation_results: Dict[str, HumanMessage | HumanMultimodalMessage] = (
8682
dict()
8783
)
88-
self._db: Dict[str, List[Any]] = defaultdict(list)
89-
self._db_lock = threading.Lock()
90-
9184
self._aggregation_thread: threading.Thread | None = None
9285

9386
def _configure_state_sources(self):
9487
for source, aggregators in self.config.aggregators.items():
9588
for aggregator in aggregators:
96-
self.ros2_connector.register_callback(
97-
source,
98-
aggregator,
99-
raw=True
100-
)
89+
self.ros2_connector.register_callback(source, aggregator, raw=True)
10190

10291
def run(self):
10392
super().run()
@@ -157,6 +146,6 @@ def stop(self):
157146
if self._aggregation_thread is not None:
158147
self._aggregation_thread.join()
159148
self._aggregation_thread = None
160-
#TODO(boczekbartek): deregister aggregators
149+
# TODO(boczekbartek): deregister aggregators
161150
self._stop_event.clear()
162151
self.logger.info("Agent stopped")

src/rai_core/rai/tools/ros2/generic/topics.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,8 @@
2222
from pydantic import BaseModel, Field
2323

2424
from rai.communication.ros2 import ROS2Connector, ROS2Message
25-
from rai.communication.ros2.api.conversion import ros2_message_to_dict
26-
from rai.messages import MultimodalArtifact, preprocess_image
25+
from rai.communication.ros2.api.conversion import encode_ros2_img_to_base64, ros2_message_to_dict
26+
from rai.messages import MultimodalArtifact
2727
from rai.tools.ros2.base import BaseROS2Tool, BaseROS2Toolkit
2828

2929

0 commit comments

Comments
 (0)