12
12
# See the License for the specific language governing permissions and
13
13
# limitations under the License.
14
14
15
- from typing import Annotated , List , Literal , Optional , Sequence
15
+ from typing import (
16
+ Annotated ,
17
+ Generic ,
18
+ List ,
19
+ Literal ,
20
+ Optional ,
21
+ Sequence ,
22
+ TypeVar ,
23
+ get_args ,
24
+ )
16
25
17
26
from langchain_core .messages import AIMessage
18
27
from langchain_core .messages import BaseMessage as LangchainBaseMessage
25
34
from .base_connector import BaseConnector , BaseMessage
26
35
27
36
37
+ class HRIException (Exception ):
38
+ def __init__ (self , msg ):
39
+ super ().__init__ (msg )
40
+
41
+
28
42
class HRIPayload (BaseModel ):
29
43
text : str
30
44
images : Optional [Annotated [List [str ], "base64 encoded png images" ]] = None
@@ -42,8 +56,6 @@ def __init__(
42
56
self .images = payload .images
43
57
self .audios = payload .audios
44
58
45
- # type: Literal["ai", "human"]
46
-
47
59
def __repr__ (self ):
48
60
return f"HRIMessage(type={ self .message_author } , text={ self .text } , images={ self .images } , audios={ self .audios } )"
49
61
@@ -91,7 +103,10 @@ def from_langchain(
91
103
)
92
104
93
105
94
- class HRIConnector (BaseConnector [HRIMessage ]):
106
+ T = TypeVar ("T" , bound = HRIMessage )
107
+
108
+
109
+ class HRIConnector (Generic [T ], BaseConnector [T ]):
95
110
"""
96
111
Base class for Human-Robot Interaction (HRI) connectors.
97
112
Used for sending and receiving messages between human and robot from various sources.
@@ -105,19 +120,26 @@ def __init__(
105
120
):
106
121
self .configured_targets = configured_targets
107
122
self .configured_sources = configured_sources
123
+ if not hasattr (self , "__orig_bases__" ):
124
+ self .__orig_bases__ = {}
125
+ raise HRIException (
126
+ f"Error while instantiating { str (self .__class__ )} : Message type T derived from HRIMessage needs to be provided e.g. Connector[MessageType]()"
127
+ )
128
+ self .T_class = get_args (self .__orig_bases__ [0 ])[0 ]
108
129
109
130
def _build_message (
110
131
self ,
111
132
message : LangchainBaseMessage | RAIMultimodalMessage ,
112
- ) -> HRIMessage :
113
- return HRIMessage .from_langchain (message )
133
+ ) -> T :
134
+
135
+ return self .T_class .from_langchain (message )
114
136
115
137
def send_all_targets (self , message : LangchainBaseMessage | RAIMultimodalMessage ):
116
138
for target in self .configured_targets :
117
139
to_send = self ._build_message (message )
118
140
self .send_message (to_send , target )
119
141
120
- def receive_all_sources (self , timeout_sec : float = 1.0 ) -> dict [str , HRIMessage ]:
142
+ def receive_all_sources (self , timeout_sec : float = 1.0 ) -> dict [str , T ]:
121
143
ret = {}
122
144
for source in self .configured_sources :
123
145
received = self .receive_message (source , timeout_sec )
0 commit comments