+ "answer": "Sure, here's the implementation for the `BaseGraph` class in ScrapeGraphAI:\n\n```python\n\"\"\"\nBaseGraph Module\n\"\"\"\n\nimport time\nimport warnings\nfrom langchain_community.callbacks import get_openai_callback\nfrom typing import Tuple\n\n\nclass BaseGraph:\n \"\"\"\n BaseGraph manages the execution flow of a graph composed of interconnected nodes.\n\n Attributes:\n nodes (list): A dictionary mapping each node's name to its corresponding node instance.\n edges (list): A dictionary representing the directed edges of the graph where each\n key-value pair corresponds to the from-node and to-node relationship.\n entry_point (str): The name of the entry point node from which the graph execution begins.\n\n Args:\n nodes (iterable): An iterable of node instances that will be part of the graph.\n edges (iterable): An iterable of tuples where each tuple represents a directed edge\n in the graph, defined by a pair of nodes (from_node, to_node).\n entry_point (BaseNode): The node instance that represents the entry point of the graph.\n\n Raises:\n Warning: If the entry point node is not the first node in the list.\n\n Example:\n >>> BaseGraph(\n ... nodes=[\n ... fetch_node,\n ... parse_node,\n ... rag_node,\n ... generate_answer_node,\n ... ],\n ... edges=[\n ... (fetch_node, parse_node),\n ... (parse_node, rag_node),\n ... (rag_node, generate_answer_node)\n ... ],\n ... entry_point=fetch_node,\n ... use_burr=True,\n ... burr_config={\"app_instance_id\": \"example-instance\"}\n ... )\n \"\"\"\n\n def __init__(self, nodes: list, edges: list, entry_point: str, use_burr: bool = False, burr_config: dict = None):\n\n self.nodes = nodes\n self.edges = self._create_edges({e for e in edges})\n self.entry_point = entry_point.node_name\n self.initial_state = {}\n\n if nodes[0].node_name != entry_point.node_name:\n # raise a warning if the entry point is not the first node in the list\n warnings.warn(\n \"Careful! The entry point node is different from the first node if the graph.\")\n \n # Burr configuration\n self.use_burr = use_burr\n self.burr_config = burr_config or {}\n\n def _create_edges(self, edges: list) -> dict:\n \"\"\"\n Helper method to create a dictionary of edges from the given iterable of tuples.\n\n Args:\n edges (iterable): An iterable of tuples representing the directed edges.\n\n Returns:\n dict: A dictionary of edges with the from-node as keys and to-node as values.\n \"\"\"\n\n edge_dict = {}\n for from_node, to_node in edges:\n edge_dict[from_node.node_name] = to_node.node_name\n return edge_dict\n\n def _execute_standard(self, initial_state: dict) -> Tuple[dict, list]:\n \"\"\"\n Executes the graph by traversing nodes starting from the entry point using the standard method.\n\n Args:\n initial_state (dict): The initial state to pass to the entry point node.\n\n Returns:\n Tuple[dict, list]: A tuple containing the final state and a list of execution info.\n \"\"\"\n current_node_name = self.entry_point\n state = initial_state\n\n # variables for tracking execution info\n total_exec_time = 0.0\n exec_info = []\n cb_total = {\n \"total_tokens\": 0,\n \"prompt_tokens\": 0,\n \"completion_tokens\": 0,\n \"successful_requests\": 0,\n \"total_cost_USD\": 0.0,\n }\n\n while current_node_name:\n curr_time = time.time()\n current_node = next(node for node in self.nodes if node.node_name == current_node_name)\n\n with get_openai_callback() as cb:\n result = current_node.execute(state)\n node_exec_time = time.time() - curr_time\n total_exec_time += node_exec_time\n\n cb_data = {\n \"node_name\": current_node.node_name,\n \"total_tokens\": cb.total_tokens,\n \"prompt_tokens\": cb.prompt_tokens,\n \"completion_tokens\": cb.completion_tokens,\n \"successful_requests\": cb.successful_requests,\n \"total_cost_USD\": cb.total_cost,\n \"exec_time\": node_exec_time,\n }\n\n exec_info.append(cb_data)\n\n cb_total[\"total_tokens\"] += cb_data[\"total_tokens\"]\n cb_total[\"prompt_tokens\"] += cb_data[\"prompt_tokens\"]\n cb_total[\"completion_tokens\"] += cb_data[\"completion_tokens\"]\n cb_total[\"successful_requests\"] += cb_data[\"successful_requests\"]\n cb_total[\"total_cost_USD\"] += cb_data[\"total_cost_USD\"]\n\n if current_node.node_type == \"conditional_node\":\n current_node_name = result\n elif current_node_name in self.edges:\n current_node_name = self.edges[current_node_name]\n else:\n current_node_name = None\n\n exec_info.append({\n \"node_name\": \"TOTAL RESULT\",\n \"total_tokens\": cb_total[\"total_tokens\"],\n \"prompt_tokens\": cb_total[\"prompt_tokens\"],\n \"completion_tokens\": cb_total[\"completion_tokens\"],\n \"successful_requests\": cb_total[\"successful_requests\"],\n \"total_cost_USD\": cb_total[\"total_cost_USD\"],\n \"exec_time\": total_exec_time,\n })\n\n return state, exec_info\n\n def execute(self, initial_state: dict) -> Tuple[dict, list]:\n \"\"\"\n Executes the graph by either using BurrBridge or the standard method.\n\n Args:\n initial_state (dict): The initial state to pass to the entry point node.\n\n Returns:\n Tuple[dict, list]: A tuple containing the final state and a list of execution info.\n \"\"\"\n\n self.initial_state = initial_state\n if self.use_burr:\n\n from ..integrations import BurrBridge\n \n bridge = BurrBridge(self, self.burr_config)\n result = bridge.execute(initial_state)\n return (result[\"_state\"], [])\n else:\n return self._execute_standard(initial_state)\n```"
0 commit comments