Skip to content

Commit 7e3005a

Browse files
Add types to waitable.py (#1328)
* add types Signed-off-by: Michael Carlstrom <[email protected]> * move typing into string Signed-off-by: Michael Carlstrom <[email protected]> * move Future type into string Signed-off-by: Michael Carlstrom <[email protected]> * flake8 fixes Signed-off-by: Michael Carlstrom <[email protected]> * move typedicts to outside TYPE_CHECKING Signed-off-by: Michael Carlstrom <[email protected]> * rerun stuck ci Signed-off-by: Michael Carlstrom <[email protected]> * undo accidental removal Signed-off-by: Michael Carlstrom <[email protected]> * add functions Signed-off-by: Michael Carlstrom <[email protected]> --------- Signed-off-by: Michael Carlstrom <[email protected]> Co-authored-by: Shane Loretz <[email protected]>
1 parent 1eb4208 commit 7e3005a

9 files changed

+126
-37
lines changed

rclpy/rclpy/action/client.py

+14-4
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414

1515
import threading
1616
import time
17+
from typing import Any
18+
from typing import TypedDict
1719
import uuid
1820
import weakref
1921

@@ -32,6 +34,14 @@
3234
from unique_identifier_msgs.msg import UUID
3335

3436

37+
class ClientGoalHandleDict(TypedDict, total=False):
38+
goal: Any
39+
cancel: Any
40+
result: Any
41+
feedback: Any
42+
status: Any
43+
44+
3545
class ClientGoalHandle():
3646
"""Goal handle for working with Action Clients."""
3747

@@ -108,7 +118,7 @@ def get_result_async(self):
108118
return self._action_client._get_result_async(self)
109119

110120

111-
class ActionClient(Waitable):
121+
class ActionClient(Waitable[ClientGoalHandleDict]):
112122
"""ROS Action client."""
113123

114124
def __init__(
@@ -237,9 +247,9 @@ def is_ready(self, wait_set):
237247
self._is_result_response_ready = ready_entities[4]
238248
return any(ready_entities)
239249

240-
def take_data(self):
250+
def take_data(self) -> ClientGoalHandleDict:
241251
"""Take stuff from lower level so the wait set doesn't immediately wake again."""
242-
data = {}
252+
data: ClientGoalHandleDict = {}
243253
if self._is_goal_response_ready:
244254
taken_data = self._client_handle.take_goal_response(
245255
self._action_type.Impl.SendGoalService.Response)
@@ -277,7 +287,7 @@ def take_data(self):
277287

278288
return data
279289

280-
async def execute(self, taken_data):
290+
async def execute(self, taken_data: ClientGoalHandleDict) -> None:
281291
"""
282292
Execute work after data has been taken from a ready wait set.
283293

rclpy/rclpy/action/server.py

+13-4
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717
import threading
1818
import traceback
1919

20+
from typing import Any, TypedDict
21+
2022
from action_msgs.msg import GoalInfo, GoalStatus
2123

2224
from rclpy.executors import await_or_execute
@@ -49,6 +51,13 @@ class CancelResponse(Enum):
4951
GoalEvent = _rclpy.GoalEvent
5052

5153

54+
class ServerGoalHandleDict(TypedDict, total=False):
55+
goal: Any
56+
cancel: Any
57+
result: Any
58+
expired: Any
59+
60+
5261
class ServerGoalHandle:
5362
"""Goal handle for working with Action Servers."""
5463

@@ -178,7 +187,7 @@ def default_cancel_callback(cancel_request):
178187
return CancelResponse.REJECT
179188

180189

181-
class ActionServer(Waitable):
190+
class ActionServer(Waitable[ServerGoalHandleDict]):
182191
"""ROS Action server."""
183192

184193
def __init__(
@@ -446,9 +455,9 @@ def is_ready(self, wait_set):
446455
self._is_goal_expired = ready_entities[3]
447456
return any(ready_entities)
448457

449-
def take_data(self):
458+
def take_data(self) -> ServerGoalHandleDict:
450459
"""Take stuff from lower level so the wait set doesn't immediately wake again."""
451-
data = {}
460+
data: ServerGoalHandleDict = {}
452461
if self._is_goal_request_ready:
453462
with self._lock:
454463
taken_data = self._handle.take_goal_request(
@@ -482,7 +491,7 @@ def take_data(self):
482491

483492
return data
484493

485-
async def execute(self, taken_data):
494+
async def execute(self, taken_data: ServerGoalHandleDict) -> None:
486495
"""
487496
Execute work after data has been taken from a ready wait set.
488497

rclpy/rclpy/callback_groups.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
# limitations under the License.
1414

1515
from threading import Lock
16-
from typing import Literal, Optional, TYPE_CHECKING, Union
16+
from typing import Any, Literal, Optional, TYPE_CHECKING, Union
1717
import weakref
1818

1919

@@ -23,7 +23,7 @@
2323
from rclpy.client import Client
2424
from rclpy.service import Service
2525
from rclpy.waitable import Waitable
26-
Entity = Union[Subscription, Timer, Client, Service, Waitable]
26+
Entity = Union[Subscription, Timer, Client, Service, Waitable[Any]]
2727

2828

2929
class CallbackGroup:

rclpy/rclpy/event_handler.py

+10-3
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
# limitations under the License.
1414

1515
from enum import IntEnum
16+
from typing import Any
1617
from typing import Callable
1718
from typing import List
1819
from typing import Optional
@@ -27,6 +28,9 @@
2728
from rclpy.waitable import NumberOfEntities
2829
from rclpy.waitable import Waitable
2930

31+
if TYPE_CHECKING:
32+
from typing import TypeAlias
33+
3034

3135
if TYPE_CHECKING:
3236
from rclpy.subscription import SubscriptionHandle
@@ -75,7 +79,10 @@
7579
UnsupportedEventTypeError = _rclpy.UnsupportedEventTypeError
7680

7781

78-
class EventHandler(Waitable):
82+
EventHandlerData: 'TypeAlias' = Optional[Any]
83+
84+
85+
class EventHandler(Waitable[EventHandlerData]):
7986
"""Waitable type to handle QoS events."""
8087

8188
def __init__(
@@ -106,15 +113,15 @@ def is_ready(self, wait_set):
106113
self._ready_to_take_data = True
107114
return self._ready_to_take_data
108115

109-
def take_data(self):
116+
def take_data(self) -> EventHandlerData:
110117
"""Take stuff from lower level so the wait set doesn't immediately wake again."""
111118
if self._ready_to_take_data:
112119
self._ready_to_take_data = False
113120
with self.__event:
114121
return self.__event.take_event()
115122
return None
116123

117-
async def execute(self, taken_data):
124+
async def execute(self, taken_data: EventHandlerData) -> None:
118125
"""Execute work after data has been taken from a ready wait set."""
119126
if not taken_data:
120127
return

rclpy/rclpy/executors.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -610,7 +610,7 @@ def _wait_for_ready_callbacks(
610610
timers: List[Timer] = []
611611
clients: List[Client] = []
612612
services: List[Service] = []
613-
waitables: List[Waitable] = []
613+
waitables: List[Waitable[Any]] = []
614614
for node in nodes_to_use:
615615
subscriptions.extend(filter(self.can_execute, node.subscriptions))
616616
timers.extend(filter(self.can_execute, node.timers))

rclpy/rclpy/node.py

+5-4
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
import time
1717

1818
from types import TracebackType
19+
from typing import Any
1920
from typing import Callable
2021
from typing import Dict
2122
from typing import Iterator
@@ -181,7 +182,7 @@ def __init__(
181182
self._services: List[Service] = []
182183
self._timers: List[Timer] = []
183184
self._guards: List[GuardCondition] = []
184-
self.__waitables: List[Waitable] = []
185+
self.__waitables: List[Waitable[Any]] = []
185186
self._default_callback_group = MutuallyExclusiveCallbackGroup()
186187
self._pre_set_parameters_callbacks: List[Callable[[List[Parameter]], List[Parameter]]] = []
187188
self._on_set_parameters_callbacks: \
@@ -290,7 +291,7 @@ def guards(self) -> Iterator[GuardCondition]:
290291
yield from self._guards
291292

292293
@property
293-
def waitables(self) -> Iterator[Waitable]:
294+
def waitables(self) -> Iterator[Waitable[Any]]:
294295
"""Get waitables that have been created on this node."""
295296
yield from self.__waitables
296297

@@ -1485,7 +1486,7 @@ def _validate_qos_or_depth_parameter(self, qos_or_depth) -> QoSProfile:
14851486
raise TypeError(
14861487
'Expected QoSProfile or int, but received {!r}'.format(type(qos_or_depth)))
14871488

1488-
def add_waitable(self, waitable: Waitable) -> None:
1489+
def add_waitable(self, waitable: Waitable[Any]) -> None:
14891490
"""
14901491
Add a class that is capable of adding things to the wait set.
14911492
@@ -1494,7 +1495,7 @@ def add_waitable(self, waitable: Waitable) -> None:
14941495
self.__waitables.append(waitable)
14951496
self._wake_executor()
14961497

1497-
def remove_waitable(self, waitable: Waitable) -> None:
1498+
def remove_waitable(self, waitable: Waitable[Any]) -> None:
14981499
"""
14991500
Remove a Waitable that was previously added to the node.
15001501

rclpy/rclpy/waitable.py

+39-19
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,21 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
from types import TracebackType
16+
from typing import Any, Generic, List, Optional, Type, TYPE_CHECKING, TypeVar
17+
18+
19+
from rclpy.impl.implementation_singleton import rclpy_implementation as _rclpy
20+
21+
T = TypeVar('T')
22+
23+
24+
if TYPE_CHECKING:
25+
from typing_extensions import Self
26+
27+
from rclpy.callback_groups import CallbackGroup
28+
from rclpy.task import Future
29+
1530

1631
class NumberOfEntities:
1732

@@ -24,8 +39,8 @@ class NumberOfEntities:
2439
'num_events']
2540

2641
def __init__(
27-
self, num_subs=0, num_gcs=0, num_timers=0,
28-
num_clients=0, num_services=0, num_events=0
42+
self, num_subs: int = 0, num_gcs: int = 0, num_timers: int = 0,
43+
num_clients: int = 0, num_services: int = 0, num_events: int = 0
2944
):
3045
self.num_subscriptions = num_subs
3146
self.num_guard_conditions = num_gcs
@@ -34,7 +49,7 @@ def __init__(
3449
self.num_services = num_services
3550
self.num_events = num_events
3651

37-
def __add__(self, other):
52+
def __add__(self, other: 'NumberOfEntities') -> 'NumberOfEntities':
3853
result = self.__class__()
3954
result.num_subscriptions = self.num_subscriptions + other.num_subscriptions
4055
result.num_guard_conditions = self.num_guard_conditions + other.num_guard_conditions
@@ -44,7 +59,7 @@ def __add__(self, other):
4459
result.num_events = self.num_events + other.num_events
4560
return result
4661

47-
def __iadd__(self, other):
62+
def __iadd__(self, other: 'NumberOfEntities') -> 'NumberOfEntities':
4863
self.num_subscriptions += other.num_subscriptions
4964
self.num_guard_conditions += other.num_guard_conditions
5065
self.num_timers += other.num_timers
@@ -53,59 +68,64 @@ def __iadd__(self, other):
5368
self.num_events += other.num_events
5469
return self
5570

56-
def __repr__(self):
71+
def __repr__(self) -> str:
5772
return '<{0}({1}, {2}, {3}, {4}, {5}, {6})>'.format(
5873
self.__class__.__name__, self.num_subscriptions,
5974
self.num_guard_conditions, self.num_timers, self.num_clients,
6075
self.num_services, self.num_events)
6176

6277

63-
class Waitable:
78+
class Waitable(Generic[T]):
6479
"""
6580
Add something to a wait set and execute it.
6681
6782
This class wraps a collection of entities which can be added to a wait set.
6883
"""
6984

70-
def __init__(self, callback_group):
85+
def __init__(self, callback_group: 'CallbackGroup'):
7186
# A callback group to control when this entity can execute (used by Executor)
7287
self.callback_group = callback_group
7388
self.callback_group.add_entity(self)
7489
# Flag set by executor when a handler has been created but not executed (used by Executor)
7590
self._executor_event = False
7691
# List of Futures that have callbacks needing execution
77-
self._futures = []
92+
self._futures: List[Future[Any]] = []
7893

79-
def __enter__(self):
94+
def __enter__(self) -> 'Self':
8095
"""Implement to mark entities as in-use to prevent destruction while waiting on them."""
81-
pass
96+
raise NotImplementedError('Must be implemented by subclass')
8297

83-
def __exit__(self, t, v, tb):
98+
def __exit__(
99+
self,
100+
exc_type: Optional[Type[BaseException]],
101+
exc_val: Optional[BaseException],
102+
exc_tb: Optional[TracebackType],
103+
) -> None:
84104
"""Implement to mark entities as not-in-use to allow destruction after waiting on them."""
85-
pass
105+
raise NotImplementedError('Must be implemented by subclass')
86106

87-
def add_future(self, future):
107+
def add_future(self, future: 'Future[Any]') -> None:
88108
self._futures.append(future)
89109

90-
def remove_future(self, future):
110+
def remove_future(self, future: 'Future[Any]') -> None:
91111
self._futures.remove(future)
92112

93-
def is_ready(self, wait_set):
113+
def is_ready(self, wait_set: _rclpy.WaitSet) -> bool:
94114
"""Return True if entities are ready in the wait set."""
95115
raise NotImplementedError('Must be implemented by subclass')
96116

97-
def take_data(self):
117+
def take_data(self) -> T:
98118
"""Take stuff from lower level so the wait set doesn't immediately wake again."""
99119
raise NotImplementedError('Must be implemented by subclass')
100120

101-
async def execute(self, taken_data):
121+
async def execute(self, taken_data: T) -> None:
102122
"""Execute work after data has been taken from a ready wait set."""
103123
raise NotImplementedError('Must be implemented by subclass')
104124

105-
def get_num_entities(self):
125+
def get_num_entities(self) -> NumberOfEntities:
106126
"""Return number of each type of entity used."""
107127
raise NotImplementedError('Must be implemented by subclass')
108128

109-
def add_to_wait_set(self, wait_set):
129+
def add_to_wait_set(self, wait_set: _rclpy.WaitSet) -> None:
110130
"""Add entities to wait set."""
111131
raise NotImplementedError('Must be implemented by subclass')

rclpy/test/test_create_while_spinning.py

+6
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,12 @@ class DummyWaitable(Waitable):
9494
def __init__(self):
9595
super().__init__(ReentrantCallbackGroup())
9696

97+
def __enter__(self):
98+
return self
99+
100+
def __exit__(self, exc_type, exc_val, exc_tb) -> None:
101+
pass
102+
97103
def is_ready(self, wait_set):
98104
return False
99105

0 commit comments

Comments
 (0)