Skip to content

Commit c6c49b9

Browse files
authored
feat: ros2 connector (#379)
1 parent 1595c52 commit c6c49b9

File tree

8 files changed

+654
-157
lines changed

8 files changed

+654
-157
lines changed

src/rai/rai/communication/ari_connector.py

Lines changed: 9 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
from typing import Generic, Optional, TypeVar
16-
17-
from pydantic import Field
15+
from typing import Any, Dict, Generic, Optional, TypeVar
1816

1917
from .base_connector import BaseConnector, BaseMessage
2018

@@ -26,15 +24,14 @@ class ARIMessage(BaseMessage):
2624
Inherit from this class to create specific ARI message types.
2725
"""
2826

29-
30-
# TODO: Move this to ros2 module
31-
class ROS2RRIMessage(ARIMessage):
32-
ros_message_type: str = Field(
33-
description="The string representation of the ROS message type (e.g. 'std_msgs/msg/String')"
34-
)
35-
python_message_class: Optional[type] = Field(
36-
description="The Python class of the ROS message type", default=None
37-
)
27+
def __init__(
28+
self,
29+
payload: Any,
30+
metadata: Optional[Dict[str, Any]] = None,
31+
*args: Any,
32+
**kwargs: Any,
33+
):
34+
super().__init__(payload, metadata, *args, **kwargs)
3835

3936

4037
T = TypeVar("T", bound=ARIMessage)

src/rai/rai/communication/base_connector.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,15 +13,26 @@
1313
# limitations under the License.
1414

1515
from abc import abstractmethod
16-
from typing import Any, Callable, Generic, Optional, Protocol, TypeVar
16+
from typing import Any, Callable, Dict, Generic, Optional, TypeVar
1717
from uuid import uuid4
1818

1919

20-
class BaseMessage(Protocol):
20+
class BaseMessage:
2121
payload: Any
22+
metadata: Dict[str, Any]
2223

23-
def __init__(self, payload: Any, *args, **kwargs):
24+
def __init__(
25+
self,
26+
payload: Any,
27+
metadata: Optional[Dict[str, Any]] = None,
28+
*args: Any,
29+
**kwargs: Any,
30+
):
2431
self.payload = payload
32+
if metadata is None:
33+
self.metadata = {}
34+
else:
35+
self.metadata = metadata
2536

2637

2738
T = TypeVar("T", bound=BaseMessage)

src/rai/rai/communication/ros2/api.py

Lines changed: 90 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,24 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
import copy
1516
import logging
1617
import time
1718
import uuid
19+
from concurrent.futures import ThreadPoolExecutor
1820
from functools import partial
19-
from typing import Annotated, Any, Dict, List, Optional, Tuple, Type, TypedDict, cast
21+
from typing import (
22+
Annotated,
23+
Any,
24+
Callable,
25+
Dict,
26+
List,
27+
Optional,
28+
Tuple,
29+
Type,
30+
TypedDict,
31+
cast,
32+
)
2033

2134
import rclpy
2235
import rclpy.callback_groups
@@ -146,7 +159,7 @@ def publish(
146159
topic: str,
147160
msg_content: Dict[str, Any],
148161
msg_type: str,
149-
*, # Force keyword arguments
162+
*,
150163
auto_qos_matching: bool = True,
151164
qos_profile: Optional[QoSProfile] = None,
152165
) -> None:
@@ -170,11 +183,20 @@ def publish(
170183
publisher = self._get_or_create_publisher(topic, type(msg), qos_profile)
171184
publisher.publish(msg)
172185

186+
def _verify_receive_args(
187+
self, topic: str, auto_topic_type: bool, msg_type: Optional[str]
188+
) -> None:
189+
if auto_topic_type and msg_type is not None:
190+
raise ValueError("Cannot provide both auto_topic_type and msg_type")
191+
if not auto_topic_type and msg_type is None:
192+
raise ValueError("msg_type must be provided if auto_topic_type is False")
193+
173194
def receive(
174195
self,
175196
topic: str,
176-
msg_type: str,
177-
*, # Force keyword arguments
197+
*,
198+
auto_topic_type: bool = True,
199+
msg_type: Optional[str] = None,
178200
timeout_sec: float = 1.0,
179201
auto_qos_matching: bool = True,
180202
qos_profile: Optional[QoSProfile] = None,
@@ -193,8 +215,20 @@ def receive(
193215
194216
Raises:
195217
ValueError: If no publisher exists or no message is received within timeout
218+
ValueError: If auto_topic_type is False and msg_type is not provided
219+
ValueError: If auto_topic_type is True and msg_type is provided
196220
"""
197-
self._verify_publisher_exists(topic)
221+
self._verify_receive_args(topic, auto_topic_type, msg_type)
222+
topic_endpoints = self._verify_publisher_exists(topic)
223+
224+
# TODO: Verify publishers topic type consistency
225+
if auto_topic_type:
226+
msg_type = topic_endpoints[0].topic_type
227+
else:
228+
if msg_type is None:
229+
raise ValueError(
230+
"msg_type must be provided if auto_topic_type is False"
231+
)
198232

199233
qos_profile = self._resolve_qos_profile(
200234
topic, auto_qos_matching, qos_profile, for_publisher=False
@@ -260,16 +294,18 @@ def _get_message_class(msg_type: str) -> Type[Any]:
260294
"""Convert message type string to actual message class."""
261295
return import_message_from_str(msg_type)
262296

263-
def _verify_publisher_exists(self, topic: str) -> None:
297+
def _verify_publisher_exists(self, topic: str) -> List[TopicEndpointInfo]:
264298
"""Verify that at least one publisher exists for the given topic.
265299
266300
Raises:
267301
ValueError: If no publisher exists for the topic
268302
"""
269-
if not self._node.get_publishers_info_by_topic(topic):
303+
topic_endpoints = self._node.get_publishers_info_by_topic(topic)
304+
if not topic_endpoints:
270305
raise ValueError(f"No publisher found for topic: {topic}")
306+
return topic_endpoints
271307

272-
def __del__(self) -> None:
308+
def shutdown(self) -> None:
273309
"""Cleanup publishers when object is destroyed."""
274310
for publisher in self._publishers.values():
275311
publisher.destroy()
@@ -324,18 +360,52 @@ def __init__(self, node: rclpy.node.Node) -> None:
324360
self.node = node
325361
self._logger = node.get_logger()
326362
self.actions: Dict[str, ROS2ActionData] = {}
363+
self._callback_executor = ThreadPoolExecutor(max_workers=10)
327364

328365
def _generate_handle(self):
329366
return str(uuid.uuid4())
330367

331368
def _generic_callback(self, handle: str, feedback_msg: Any) -> None:
332369
self.actions[handle]["feedbacks"].append(feedback_msg.feedback)
333370

371+
def _fan_out_feedback(
372+
self, callbacks: List[Callable[[Any], None]], feedback_msg: Any
373+
) -> None:
374+
"""Fan out feedback message to multiple callbacks concurrently.
375+
376+
Args:
377+
callbacks: List of callback functions to execute
378+
feedback_msg: The feedback message to pass to each callback
379+
"""
380+
for callback in callbacks:
381+
self._callback_executor.submit(
382+
self._safe_callback_wrapper, callback, feedback_msg
383+
)
384+
385+
def _safe_callback_wrapper(
386+
self, callback: Callable[[Any], None], feedback_msg: Any
387+
) -> None:
388+
"""Safely execute a callback with error handling.
389+
390+
Args:
391+
callback: The callback function to execute
392+
feedback_msg: The feedback message to pass to the callback
393+
"""
394+
try:
395+
callback(copy.deepcopy(feedback_msg))
396+
except Exception as e:
397+
self._logger.error(f"Error in feedback callback: {str(e)}")
398+
334399
def send_goal(
335400
self,
336401
action_name: str,
337402
action_type: str,
338403
goal: Dict[str, Any],
404+
*,
405+
feedback_callback: Callable[[Any], None] = lambda _: None,
406+
done_callback: Callable[
407+
[Any], None
408+
] = lambda _: None, # TODO: handle done callback
339409
timeout_sec: float = 1.0,
340410
) -> Tuple[bool, Annotated[str, "action handle"]]:
341411
handle = self._generate_handle()
@@ -355,8 +425,13 @@ def send_goal(
355425
if not action_client.wait_for_server(timeout_sec=timeout_sec): # type: ignore
356426
return False, ""
357427

428+
feedback_callbacks = [
429+
partial(self._generic_callback, handle),
430+
feedback_callback,
431+
]
358432
send_goal_future: Future = action_client.send_goal_async(
359-
goal=action_goal, feedback_callback=partial(self._generic_callback, handle)
433+
goal=action_goal,
434+
feedback_callback=partial(self._fan_out_feedback, feedback_callbacks),
360435
)
361436
self.actions[handle]["action_client"] = action_client
362437
self.actions[handle]["goal_future"] = send_goal_future
@@ -372,6 +447,7 @@ def send_goal(
372447
return False, ""
373448

374449
get_result_future = cast(Future, goal_handle.get_result_async()) # type: ignore
450+
get_result_future.add_done_callback(done_callback) # type: ignore
375451

376452
self.actions[handle]["result_future"] = get_result_future
377453
self.actions[handle]["client_goal_handle"] = goal_handle
@@ -403,3 +479,8 @@ def get_result(self, handle: str) -> Any:
403479
if self.actions[handle]["result_future"] is None:
404480
raise ValueError(f"No result available for goal {handle}")
405481
return self.actions[handle]["result_future"].result()
482+
483+
def shutdown(self) -> None:
484+
"""Cleanup thread pool when object is destroyed."""
485+
if hasattr(self, "_callback_executor"):
486+
self._callback_executor.shutdown(wait=False)
Lines changed: 126 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,126 @@
1+
# Copyright (C) 2024 Robotec.AI
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import threading
16+
import uuid
17+
from typing import Any, Callable, Dict, Optional
18+
19+
from rclpy.executors import MultiThreadedExecutor
20+
from rclpy.node import Node
21+
22+
from rai.communication.ari_connector import ARIConnector, ARIMessage
23+
from rai.communication.ros2.api import ROS2ActionAPI, ROS2ServiceAPI, ROS2TopicAPI
24+
25+
26+
class ROS2ARIMessage(ARIMessage):
27+
def __init__(self, payload: Any, metadata: Optional[Dict[str, Any]] = None):
28+
super().__init__(payload, metadata)
29+
30+
31+
class ROS2ARIConnector(ARIConnector[ROS2ARIMessage]):
32+
def __init__(
33+
self, node_name: str = f"rai_ros2_ari_connector_{str(uuid.uuid4())[-12:]}"
34+
):
35+
super().__init__()
36+
self._node = Node(node_name)
37+
self._topic_api = ROS2TopicAPI(self._node)
38+
self._service_api = ROS2ServiceAPI(self._node)
39+
self._actions_api = ROS2ActionAPI(self._node)
40+
41+
self._executor = MultiThreadedExecutor()
42+
self._executor.add_node(self._node)
43+
self._thread = threading.Thread(target=self._executor.spin)
44+
self._thread.start()
45+
46+
def send_message(self, message: ROS2ARIMessage, target: str):
47+
auto_qos_matching = message.metadata.get("auto_qos_matching", True)
48+
qos_profile = message.metadata.get("qos_profile", None)
49+
msg_type = message.metadata.get("msg_type", None)
50+
51+
# TODO: allow msg_type to be None, add auto topic type detection
52+
if msg_type is None:
53+
raise ValueError("msg_type is required")
54+
55+
self._topic_api.publish(
56+
topic=target,
57+
msg_content=message.payload,
58+
msg_type=msg_type,
59+
auto_qos_matching=auto_qos_matching,
60+
qos_profile=qos_profile,
61+
)
62+
63+
def receive_message(
64+
self,
65+
source: str,
66+
timeout_sec: float = 1.0,
67+
msg_type: Optional[str] = None,
68+
auto_topic_type: bool = True,
69+
) -> ROS2ARIMessage:
70+
msg = self._topic_api.receive(
71+
topic=source,
72+
timeout_sec=timeout_sec,
73+
msg_type=msg_type,
74+
auto_topic_type=auto_topic_type,
75+
)
76+
return ROS2ARIMessage(
77+
payload=msg, metadata={"msg_type": str(type(msg)), "topic": source}
78+
)
79+
80+
def service_call(
81+
self, message: ROS2ARIMessage, target: str, timeout_sec: float = 1.0
82+
) -> ROS2ARIMessage:
83+
msg = self._service_api.call_service(
84+
service_name=target,
85+
service_type=message.metadata["msg_type"],
86+
request=message.payload,
87+
timeout_sec=timeout_sec,
88+
)
89+
return ROS2ARIMessage(
90+
payload=msg, metadata={"msg_type": str(type(msg)), "service": target}
91+
)
92+
93+
def start_action(
94+
self,
95+
action_data: Optional[ROS2ARIMessage],
96+
target: str,
97+
on_feedback: Callable[[Any], None] = lambda _: None,
98+
on_done: Callable[[Any], None] = lambda _: None,
99+
timeout_sec: float = 1.0,
100+
) -> str:
101+
if not isinstance(action_data, ROS2ARIMessage):
102+
raise ValueError("Action data must be of type ROS2ARIMessage")
103+
msg_type = action_data.metadata.get("msg_type", None)
104+
if msg_type is None:
105+
raise ValueError("msg_type is required")
106+
accepted, handle = self._actions_api.send_goal(
107+
action_name=target,
108+
action_type=msg_type,
109+
goal=action_data.payload,
110+
timeout_sec=timeout_sec,
111+
feedback_callback=on_feedback,
112+
done_callback=on_done,
113+
)
114+
if not accepted:
115+
raise RuntimeError("Action goal was not accepted")
116+
return handle
117+
118+
def terminate_action(self, action_handle: str):
119+
self._actions_api.terminate_goal(action_handle)
120+
121+
def shutdown(self):
122+
self._executor.shutdown()
123+
self._thread.join()
124+
self._actions_api.shutdown()
125+
self._topic_api.shutdown()
126+
self._node.destroy_node()

0 commit comments

Comments
 (0)