21
21
22
22
from langchain_core .messages import BaseMessage
23
23
from langchain_core .runnables import Runnable
24
- from pydantic import BaseModel
25
24
26
25
from rai .agents .base import BaseAgent
27
26
from rai .agents .langchain import HRICallbackHandler
28
27
from rai .agents .langchain .runnables import ReActAgentState
29
- from rai .communication .base_connector import BaseConnector
30
- from rai .communication .hri_connector import HRIMessage
28
+ from rai .communication .hri_connector import HRIConnector , HRIMessage
31
29
from rai .initialization import get_tracing_callbacks
32
30
33
31
34
32
class BaseState (TypedDict ):
35
33
messages : List [BaseMessage ]
36
34
37
35
38
- class HRIConfig (BaseModel ):
39
- source : str
40
- targets : List [str ]
41
-
42
-
43
36
class LangChainAgent (BaseAgent ):
44
37
def __init__ (
45
38
self ,
46
- target_connectors : Dict [str , BaseConnector ],
39
+ target_connectors : Dict [str , HRIConnector [ HRIMessage ] ],
47
40
runnable : Runnable ,
48
41
state : BaseState | None = None ,
49
42
new_message_behavior : Literal [
@@ -61,7 +54,7 @@ def __init__(
61
54
self .new_message_behavior = new_message_behavior
62
55
self .tracing_callbacks = get_tracing_callbacks ()
63
56
self .state = state or ReActAgentState (messages = [])
64
- self .callback = HRICallbackHandler (
57
+ self ._langchain_callback = HRICallbackHandler (
65
58
connectors = target_connectors ,
66
59
aggregate_chunks = True ,
67
60
logger = self .logger ,
@@ -76,7 +69,13 @@ def __init__(
76
69
self ._interupt_event = threading .Event ()
77
70
self ._agent_ready_event = threading .Event ()
78
71
79
- def __call__ (self , msg : HRIMessage ):
72
+ def subscribe_source (self , source : str , connector : HRIConnector [HRIMessage ]):
73
+ connector .register_callback (
74
+ source ,
75
+ self .source_callback ,
76
+ )
77
+
78
+ def source_callback (self , msg : HRIMessage ):
80
79
if self .max_size is not None and len (self ._received_messages ) >= self .max_size :
81
80
self .logger .warning ("Buffer overflow. Dropping olders message" )
82
81
self ._received_messages .popleft ()
@@ -117,7 +116,9 @@ def run_agent(self):
117
116
self .state ["messages" ].append (langchain_message )
118
117
for _ in self .agent .stream (
119
118
self .state ,
120
- config = {"callbacks" : [self .callback , * self .tracing_callbacks ]},
119
+ config = {
120
+ "callbacks" : [self ._langchain_callback , * self .tracing_callbacks ]
121
+ },
121
122
):
122
123
if self ._interupt_event .is_set ():
123
124
break
0 commit comments