12
12
# See the License for the specific language governing permissions and
13
13
# limitations under the License.
14
14
15
+ import copy
15
16
import logging
16
17
import time
17
18
import uuid
19
+ from concurrent .futures import ThreadPoolExecutor
18
20
from functools import partial
19
- from typing import Annotated , Any , Dict , List , Optional , Tuple , Type , TypedDict , cast
21
+ from typing import (
22
+ Annotated ,
23
+ Any ,
24
+ Callable ,
25
+ Dict ,
26
+ List ,
27
+ Optional ,
28
+ Tuple ,
29
+ Type ,
30
+ TypedDict ,
31
+ cast ,
32
+ )
20
33
21
34
import rclpy
22
35
import rclpy .callback_groups
@@ -146,7 +159,7 @@ def publish(
146
159
topic : str ,
147
160
msg_content : Dict [str , Any ],
148
161
msg_type : str ,
149
- * , # Force keyword arguments
162
+ * ,
150
163
auto_qos_matching : bool = True ,
151
164
qos_profile : Optional [QoSProfile ] = None ,
152
165
) -> None :
@@ -170,11 +183,20 @@ def publish(
170
183
publisher = self ._get_or_create_publisher (topic , type (msg ), qos_profile )
171
184
publisher .publish (msg )
172
185
186
+ def _verify_receive_args (
187
+ self , topic : str , auto_topic_type : bool , msg_type : Optional [str ]
188
+ ) -> None :
189
+ if auto_topic_type and msg_type is not None :
190
+ raise ValueError ("Cannot provide both auto_topic_type and msg_type" )
191
+ if not auto_topic_type and msg_type is None :
192
+ raise ValueError ("msg_type must be provided if auto_topic_type is False" )
193
+
173
194
def receive (
174
195
self ,
175
196
topic : str ,
176
- msg_type : str ,
177
- * , # Force keyword arguments
197
+ * ,
198
+ auto_topic_type : bool = True ,
199
+ msg_type : Optional [str ] = None ,
178
200
timeout_sec : float = 1.0 ,
179
201
auto_qos_matching : bool = True ,
180
202
qos_profile : Optional [QoSProfile ] = None ,
@@ -193,8 +215,20 @@ def receive(
193
215
194
216
Raises:
195
217
ValueError: If no publisher exists or no message is received within timeout
218
+ ValueError: If auto_topic_type is False and msg_type is not provided
219
+ ValueError: If auto_topic_type is True and msg_type is provided
196
220
"""
197
- self ._verify_publisher_exists (topic )
221
+ self ._verify_receive_args (topic , auto_topic_type , msg_type )
222
+ topic_endpoints = self ._verify_publisher_exists (topic )
223
+
224
+ # TODO: Verify publishers topic type consistency
225
+ if auto_topic_type :
226
+ msg_type = topic_endpoints [0 ].topic_type
227
+ else :
228
+ if msg_type is None :
229
+ raise ValueError (
230
+ "msg_type must be provided if auto_topic_type is False"
231
+ )
198
232
199
233
qos_profile = self ._resolve_qos_profile (
200
234
topic , auto_qos_matching , qos_profile , for_publisher = False
@@ -260,16 +294,18 @@ def _get_message_class(msg_type: str) -> Type[Any]:
260
294
"""Convert message type string to actual message class."""
261
295
return import_message_from_str (msg_type )
262
296
263
- def _verify_publisher_exists (self , topic : str ) -> None :
297
+ def _verify_publisher_exists (self , topic : str ) -> List [ TopicEndpointInfo ] :
264
298
"""Verify that at least one publisher exists for the given topic.
265
299
266
300
Raises:
267
301
ValueError: If no publisher exists for the topic
268
302
"""
269
- if not self ._node .get_publishers_info_by_topic (topic ):
303
+ topic_endpoints = self ._node .get_publishers_info_by_topic (topic )
304
+ if not topic_endpoints :
270
305
raise ValueError (f"No publisher found for topic: { topic } " )
306
+ return topic_endpoints
271
307
272
- def __del__ (self ) -> None :
308
+ def shutdown (self ) -> None :
273
309
"""Cleanup publishers when object is destroyed."""
274
310
for publisher in self ._publishers .values ():
275
311
publisher .destroy ()
@@ -324,18 +360,52 @@ def __init__(self, node: rclpy.node.Node) -> None:
324
360
self .node = node
325
361
self ._logger = node .get_logger ()
326
362
self .actions : Dict [str , ROS2ActionData ] = {}
363
+ self ._callback_executor = ThreadPoolExecutor (max_workers = 10 )
327
364
328
365
def _generate_handle (self ):
329
366
return str (uuid .uuid4 ())
330
367
331
368
def _generic_callback (self , handle : str , feedback_msg : Any ) -> None :
332
369
self .actions [handle ]["feedbacks" ].append (feedback_msg .feedback )
333
370
371
+ def _fan_out_feedback (
372
+ self , callbacks : List [Callable [[Any ], None ]], feedback_msg : Any
373
+ ) -> None :
374
+ """Fan out feedback message to multiple callbacks concurrently.
375
+
376
+ Args:
377
+ callbacks: List of callback functions to execute
378
+ feedback_msg: The feedback message to pass to each callback
379
+ """
380
+ for callback in callbacks :
381
+ self ._callback_executor .submit (
382
+ self ._safe_callback_wrapper , callback , feedback_msg
383
+ )
384
+
385
+ def _safe_callback_wrapper (
386
+ self , callback : Callable [[Any ], None ], feedback_msg : Any
387
+ ) -> None :
388
+ """Safely execute a callback with error handling.
389
+
390
+ Args:
391
+ callback: The callback function to execute
392
+ feedback_msg: The feedback message to pass to the callback
393
+ """
394
+ try :
395
+ callback (copy .deepcopy (feedback_msg ))
396
+ except Exception as e :
397
+ self ._logger .error (f"Error in feedback callback: { str (e )} " )
398
+
334
399
def send_goal (
335
400
self ,
336
401
action_name : str ,
337
402
action_type : str ,
338
403
goal : Dict [str , Any ],
404
+ * ,
405
+ feedback_callback : Callable [[Any ], None ] = lambda _ : None ,
406
+ done_callback : Callable [
407
+ [Any ], None
408
+ ] = lambda _ : None , # TODO: handle done callback
339
409
timeout_sec : float = 1.0 ,
340
410
) -> Tuple [bool , Annotated [str , "action handle" ]]:
341
411
handle = self ._generate_handle ()
@@ -355,8 +425,13 @@ def send_goal(
355
425
if not action_client .wait_for_server (timeout_sec = timeout_sec ): # type: ignore
356
426
return False , ""
357
427
428
+ feedback_callbacks = [
429
+ partial (self ._generic_callback , handle ),
430
+ feedback_callback ,
431
+ ]
358
432
send_goal_future : Future = action_client .send_goal_async (
359
- goal = action_goal , feedback_callback = partial (self ._generic_callback , handle )
433
+ goal = action_goal ,
434
+ feedback_callback = partial (self ._fan_out_feedback , feedback_callbacks ),
360
435
)
361
436
self .actions [handle ]["action_client" ] = action_client
362
437
self .actions [handle ]["goal_future" ] = send_goal_future
@@ -372,6 +447,7 @@ def send_goal(
372
447
return False , ""
373
448
374
449
get_result_future = cast (Future , goal_handle .get_result_async ()) # type: ignore
450
+ get_result_future .add_done_callback (done_callback ) # type: ignore
375
451
376
452
self .actions [handle ]["result_future" ] = get_result_future
377
453
self .actions [handle ]["client_goal_handle" ] = goal_handle
@@ -403,3 +479,8 @@ def get_result(self, handle: str) -> Any:
403
479
if self .actions [handle ]["result_future" ] is None :
404
480
raise ValueError (f"No result available for goal { handle } " )
405
481
return self .actions [handle ]["result_future" ].result ()
482
+
483
+ def shutdown (self ) -> None :
484
+ """Cleanup thread pool when object is destroyed."""
485
+ if hasattr (self , "_callback_executor" ):
486
+ self ._callback_executor .shutdown (wait = False )
0 commit comments