Skip to content

Commit 1595c52

Browse files
authored
feat: make HRIConnector and ARIConnector classes generic (#375)
1 parent c7d1f2b commit 1595c52

File tree

4 files changed

+40
-13
lines changed

4 files changed

+40
-13
lines changed

src/rai/rai/communication/ari_connector.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
from typing import Optional
15+
from typing import Generic, Optional, TypeVar
1616

1717
from pydantic import Field
1818

@@ -37,7 +37,10 @@ class ROS2RRIMessage(ARIMessage):
3737
)
3838

3939

40-
class ARIConnector(BaseConnector[ARIMessage]):
40+
T = TypeVar("T", bound=ARIMessage)
41+
42+
43+
class ARIConnector(Generic[T], BaseConnector[T]):
4144
"""
4245
Base class for Agent-Robot Interface (ARI) connectors.
4346

src/rai/rai/communication/base_connector.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,12 +12,14 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
from abc import ABC, abstractmethod
16-
from typing import Any, Callable, Generic, Optional, TypeVar
15+
from abc import abstractmethod
16+
from typing import Any, Callable, Generic, Optional, Protocol, TypeVar
1717
from uuid import uuid4
1818

1919

20-
class BaseMessage(ABC):
20+
class BaseMessage(Protocol):
21+
payload: Any
22+
2123
def __init__(self, payload: Any, *args, **kwargs):
2224
self.payload = payload
2325

src/rai/rai/communication/hri_connector.py

Lines changed: 29 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,16 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
from typing import Annotated, List, Literal, Optional, Sequence
15+
from typing import (
16+
Annotated,
17+
Generic,
18+
List,
19+
Literal,
20+
Optional,
21+
Sequence,
22+
TypeVar,
23+
get_args,
24+
)
1625

1726
from langchain_core.messages import AIMessage
1827
from langchain_core.messages import BaseMessage as LangchainBaseMessage
@@ -25,6 +34,11 @@
2534
from .base_connector import BaseConnector, BaseMessage
2635

2736

37+
class HRIException(Exception):
38+
def __init__(self, msg):
39+
super().__init__(msg)
40+
41+
2842
class HRIPayload(BaseModel):
2943
text: str
3044
images: Optional[Annotated[List[str], "base64 encoded png images"]] = None
@@ -42,8 +56,6 @@ def __init__(
4256
self.images = payload.images
4357
self.audios = payload.audios
4458

45-
# type: Literal["ai", "human"]
46-
4759
def __repr__(self):
4860
return f"HRIMessage(type={self.message_author}, text={self.text}, images={self.images}, audios={self.audios})"
4961

@@ -91,7 +103,10 @@ def from_langchain(
91103
)
92104

93105

94-
class HRIConnector(BaseConnector[HRIMessage]):
106+
T = TypeVar("T", bound=HRIMessage)
107+
108+
109+
class HRIConnector(Generic[T], BaseConnector[T]):
95110
"""
96111
Base class for Human-Robot Interaction (HRI) connectors.
97112
Used for sending and receiving messages between human and robot from various sources.
@@ -105,19 +120,26 @@ def __init__(
105120
):
106121
self.configured_targets = configured_targets
107122
self.configured_sources = configured_sources
123+
if not hasattr(self, "__orig_bases__"):
124+
self.__orig_bases__ = {}
125+
raise HRIException(
126+
f"Error while instantiating {str(self.__class__)}: Message type T derived from HRIMessage needs to be provided e.g. Connector[MessageType]()"
127+
)
128+
self.T_class = get_args(self.__orig_bases__[0])[0]
108129

109130
def _build_message(
110131
self,
111132
message: LangchainBaseMessage | RAIMultimodalMessage,
112-
) -> HRIMessage:
113-
return HRIMessage.from_langchain(message)
133+
) -> T:
134+
135+
return self.T_class.from_langchain(message)
114136

115137
def send_all_targets(self, message: LangchainBaseMessage | RAIMultimodalMessage):
116138
for target in self.configured_targets:
117139
to_send = self._build_message(message)
118140
self.send_message(to_send, target)
119141

120-
def receive_all_sources(self, timeout_sec: float = 1.0) -> dict[str, HRIMessage]:
142+
def receive_all_sources(self, timeout_sec: float = 1.0) -> dict[str, T]:
121143
ret = {}
122144
for source in self.configured_sources:
123145
received = self.receive_message(source, timeout_sec)

src/rai/rai/communication/sound_device_connector.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ def __init__(self, config: AudioInputDeviceConfig):
6262
self.dtype = config["dtype"]
6363

6464

65-
class StreamingAudioInputDevice(HRIConnector):
65+
class StreamingAudioInputDevice(HRIConnector[HRIMessage]):
6666
"""Audio input device connector implementing the Human-Robot Interface.
6767
6868
This class provides audio streaming capabilities while conforming to the

0 commit comments

Comments
 (0)