|
12 | 12 | # See the License for the specific language governing permissions and |
13 | 13 | # limitations under the License. |
14 | 14 |
|
| 15 | +import logging |
15 | 16 | from functools import partial |
16 | | -from typing import List, Optional, TypedDict, cast |
| 17 | +from typing import ( |
| 18 | + Any, |
| 19 | + Callable, |
| 20 | + Dict, |
| 21 | + List, |
| 22 | + Optional, |
| 23 | + TypedDict, |
| 24 | + cast, |
| 25 | +) |
17 | 26 |
|
18 | 27 | from langchain_core.language_models import BaseChatModel |
19 | | -from langchain_core.messages import BaseMessage, SystemMessage |
| 28 | +from langchain_core.messages import BaseMessage, HumanMessage, SystemMessage |
20 | 29 | from langchain_core.runnables import Runnable |
21 | 30 | from langchain_core.tools import BaseTool |
22 | 31 | from langgraph.graph import START, StateGraph |
23 | 32 | from langgraph.prebuilt.tool_node import tools_condition |
24 | 33 |
|
25 | | -from rai.agents.tool_runner import ToolRunner |
26 | 34 | from rai.initialization import get_llm_model |
| 35 | +from rai.messages import HumanMultimodalMessage |
| 36 | + |
| 37 | +from ..tool_runner import ToolRunner |
27 | 38 |
|
28 | 39 |
|
29 | 40 | class ReActAgentState(TypedDict): |
@@ -112,3 +123,50 @@ def create_react_runnable( |
112 | 123 |
|
113 | 124 | # Compile the graph |
114 | 125 | return graph.compile() |
| 126 | + |
| 127 | + |
def retriever_wrapper(
    state_retriever: "Callable[[], Dict[str, HumanMessage | HumanMultimodalMessage]]",
    state: "ReActAgentState",
) -> "ReActAgentState":
    """Inject externally retrieved state messages into the LLM context.

    Each message returned by ``state_retriever`` is prefixed with its source
    name and appended to ``state["messages"]``, so the LLM sees the most
    recent external state on every pass through the graph.

    Args:
        state_retriever: Zero-argument callable returning a mapping of
            source name -> message describing that source's current state.
        state: Agent state whose ``messages`` list is extended in place.

    Returns:
        The same ``state`` object, with the retrieved messages appended.
    """
    # Hoisted: fetch the logger once instead of on every loop iteration.
    logger = logging.getLogger("state_retriever")
    for source, message in state_retriever().items():
        # NOTE(review): mutates the retrieved message in place. If the
        # retriever ever returns cached/shared message objects, the prefix
        # would accumulate across graph iterations — verify the retriever
        # constructs fresh messages on each call.
        message.content = f"{source}: {message.content}"
        # Lazy %-style args so pretty_repr formatting stays out of the
        # message-building hot path when DEBUG is disabled.
        logger.debug("Adding state message:\n%s", message.pretty_repr())
        state["messages"].append(message)
    return state
| 140 | + |
| 141 | + |
def create_state_based_runnable(
    llm: Optional[BaseChatModel] = None,
    tools: Optional[List[BaseTool]] = None,
    system_prompt: Optional[str] = None,
    state_retriever: Optional[Callable[[], Dict[str, Any]]] = None,
) -> Runnable[ReActAgentState, ReActAgentState]:
    """Build a compiled LangGraph agent that refreshes external state each turn.

    Graph topology::

        START -> state_retriever -> llm --(tools_condition)--> tools -> state_retriever
                                        \\--(no tool calls)---> END

    Args:
        llm: Chat model to drive the agent. Defaults to the configured
            ``"complex_model"`` with streaming enabled.
        tools: Tools bound to the LLM and executed by the tool runner.
            Defaults to no tools.
        system_prompt: Optional system prompt injected by the llm node.
        state_retriever: Zero-argument callable returning source->message
            state to prepend to the context on every cycle. Defaults to a
            retriever that returns nothing.

    Returns:
        A compiled runnable operating on ``ReActAgentState``.
    """
    # Resolve all defaults up front so graph wiring below is uniform.
    if llm is None:
        llm = get_llm_model("complex_model", streaming=True)
    if tools is None:
        tools = []
    if state_retriever is None:
        # PEP 8 (E731): a named def instead of assigning a lambda.
        def state_retriever() -> Dict[str, Any]:
            return {}

    graph = StateGraph(ReActAgentState)

    # Nodes first, then edges, so every edge references a declared node.
    graph.add_node("state_retriever", partial(retriever_wrapper, state_retriever))
    bound_llm = cast(BaseChatModel, llm.bind_tools(tools))
    graph.add_node("llm", partial(llm_node, bound_llm, system_prompt))
    graph.add_node("tools", ToolRunner(tools))

    graph.add_edge(START, "state_retriever")
    graph.add_edge("state_retriever", "llm")
    # Routes to "tools" when the LLM requested tool calls, otherwise to END.
    graph.add_conditional_edges("llm", tools_condition)
    graph.add_edge("tools", "state_retriever")

    return graph.compile()
0 commit comments