Skip to content

Commit 89ec76c

Browse files
committed
refactor: registering on source in agent
- rollback registering using __call__ method - add subscribe_source method
1 parent 0de696d commit 89ec76c

File tree

7 files changed

+26
-23
lines changed

7 files changed

+26
-23
lines changed

examples/agents/react.py

+7-5
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
# See the License for the specific language goveself.rning permissions and
1313
# limitations under the License.
1414

15+
1516
from rai.agents import AgentRunner
1617
from rai.agents.langchain.react_agent import ReActAgent
1718
from rai.communication.ros2 import ROS2Connector, ROS2Context
@@ -23,14 +24,15 @@
2324
def main():
2425
ros2_connector = ROS2Connector()
2526
hri_connector = ROS2HRIConnector()
26-
target_connectors = {"/to_human": hri_connector}
27+
2728
agent = ReActAgent(
28-
target_connectors=target_connectors,
29+
target_connectors={
30+
"/to_human": hri_connector,
31+
},
2932
tools=ROS2Toolkit(connector=ros2_connector).get_tools(),
3033
)
31-
hri_connector.register_callback(
32-
"/from_human", agent, msg_type="rai_interfaces/msg/HRIMessage"
33-
)
34+
agent.subscribe_source("/from_human", hri_connector)
35+
3436
runner = AgentRunner([agent])
3537
runner.run_and_wait_for_shutdown()
3638

src/rai_core/rai/agents/langchain/agent.py

+13-12
Original file line numberDiff line numberDiff line change
@@ -21,29 +21,22 @@
2121

2222
from langchain_core.messages import BaseMessage
2323
from langchain_core.runnables import Runnable
24-
from pydantic import BaseModel
2524

2625
from rai.agents.base import BaseAgent
2726
from rai.agents.langchain import HRICallbackHandler
2827
from rai.agents.langchain.runnables import ReActAgentState
29-
from rai.communication.base_connector import BaseConnector
30-
from rai.communication.hri_connector import HRIMessage
28+
from rai.communication.hri_connector import HRIConnector, HRIMessage
3129
from rai.initialization import get_tracing_callbacks
3230

3331

3432
class BaseState(TypedDict):
3533
messages: List[BaseMessage]
3634

3735

38-
class HRIConfig(BaseModel):
39-
source: str
40-
targets: List[str]
41-
42-
4336
class LangChainAgent(BaseAgent):
4437
def __init__(
4538
self,
46-
target_connectors: Dict[str, BaseConnector],
39+
target_connectors: Dict[str, HRIConnector[HRIMessage]],
4740
runnable: Runnable,
4841
state: BaseState | None = None,
4942
new_message_behavior: Literal[
@@ -61,7 +54,7 @@ def __init__(
6154
self.new_message_behavior = new_message_behavior
6255
self.tracing_callbacks = get_tracing_callbacks()
6356
self.state = state or ReActAgentState(messages=[])
64-
self.callback = HRICallbackHandler(
57+
self._langchain_callback = HRICallbackHandler(
6558
connectors=target_connectors,
6659
aggregate_chunks=True,
6760
logger=self.logger,
@@ -76,7 +69,13 @@ def __init__(
7669
self._interupt_event = threading.Event()
7770
self._agent_ready_event = threading.Event()
7871

79-
def __call__(self, msg: HRIMessage):
72+
def subscribe_source(self, source: str, connector: HRIConnector[HRIMessage]):
73+
connector.register_callback(
74+
source,
75+
self.source_callback,
76+
)
77+
78+
def source_callback(self, msg: HRIMessage):
8079
if self.max_size is not None and len(self._received_messages) >= self.max_size:
8180
self.logger.warning("Buffer overflow. Dropping olders message")
8281
self._received_messages.popleft()
@@ -117,7 +116,9 @@ def run_agent(self):
117116
self.state["messages"].append(langchain_message)
118117
for _ in self.agent.stream(
119118
self.state,
120-
config={"callbacks": [self.callback, *self.tracing_callbacks]},
119+
config={
120+
"callbacks": [self._langchain_callback, *self.tracing_callbacks]
121+
},
121122
):
122123
if self._interupt_event.is_set():
123124
break

src/rai_core/rai/agents/langchain/callback.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ def _send_all_targets(self, tokens: str, done: bool = False):
5050
for target, connector in self.connectors.items():
5151
self.logger.info(f"Sending {len(tokens)} tokens to targer: {target}")
5252
try:
53-
to_send = connector.T_class.from_langchain(
53+
to_send: HRIMessage = connector.build_message(
5454
AIMessage(content=tokens),
5555
self.current_conversation_id,
5656
self.current_chunk_id,

src/rai_core/rai/agents/langchain/react_agent.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -20,13 +20,13 @@
2020
from rai.agents.langchain import create_react_runnable
2121
from rai.agents.langchain.agent import LangChainAgent
2222
from rai.agents.langchain.runnables import ReActAgentState
23-
from rai.communication.hri_connector import HRIConnector
23+
from rai.communication.hri_connector import HRIConnector, HRIMessage
2424

2525

2626
class ReActAgent(LangChainAgent):
2727
def __init__(
2828
self,
29-
target_connectors: Dict[str, HRIConnector],
29+
target_connectors: Dict[str, HRIConnector[HRIMessage]],
3030
llm: Optional[BaseChatModel] = None,
3131
tools: Optional[List[BaseTool]] = None,
3232
state: Optional[ReActAgentState] = None,

src/rai_core/rai/communication/hri_connector.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -143,7 +143,7 @@ class HRIConnector(Generic[T], BaseConnector[T]):
143143
Used for sending and receiving messages between human and robot from various sources.
144144
"""
145145

146-
def _build_message(
146+
def build_message(
147147
self,
148148
message: LangchainBaseMessage | RAIMultimodalMessage,
149149
communication_id: Optional[str] = None,

src/rai_core/rai/communication/ros2/connectors/hri_connector.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919

2020
from rclpy.qos import QoSProfile
2121

22-
from rai.communication.hri_connector import HRIConnector
22+
from rai.communication import HRIConnector
2323
from rai.communication.ros2.connectors.base import ROS2BaseConnector
2424
from rai.communication.ros2.messages import ROS2HRIMessage
2525

src/rai_core/rai/communication/ros2/messages.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ def from_ros2(
6565
for audio_msg in cast(List[ROS2HRIMessage__Audio], msg.audios)
6666
]
6767
communication_id = msg.communication_id if msg.communication_id != "" else None
68-
return ROS2HRIMessage(
68+
return cls(
6969
text=msg.text,
7070
images=pil_images,
7171
audios=audio_segments,

0 commit comments

Comments
 (0)