# Source code for langchain_core.runnables.graph

from __future__ import annotations

from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Dict, List, NamedTuple, Optional, Type, Union
from uuid import uuid4

from langchain_core.pydantic_v1 import BaseModel
from langchain_core.runnables.graph_draw import draw

if TYPE_CHECKING:
    from langchain_core.runnables.base import Runnable as RunnableType


class Edge(NamedTuple):
    """A directed connection inside a `Graph`, referring to its endpoints by node id."""

    # Id of the node the edge leaves from.
    source: str
    # Id of the node the edge points to.
    target: str
class Node(NamedTuple):
    """A vertex of a `Graph`: an id plus the object the vertex stands for."""

    # Key under which the node is stored in `Graph.nodes`.
    id: str
    # Either a pydantic model class or a Runnable instance.
    data: Union[Type[BaseModel], RunnableType]
@dataclass
class Graph:
    """A directed graph whose vertices are `Node`s and whose arcs are `Edge`s.

    Nodes are stored in `nodes` keyed by their id; edges reference nodes by id.
    """

    nodes: Dict[str, Node] = field(default_factory=dict)
    edges: List[Edge] = field(default_factory=list)

    def __bool__(self) -> bool:
        """Return True when the graph contains at least one node."""
        return bool(self.nodes)

    def next_id(self) -> str:
        """Return a fresh random hex identifier for a new node."""
        return uuid4().hex

    def add_node(self, data: Union[Type[BaseModel], RunnableType]) -> Node:
        """Add a node to the graph and return it."""
        node = Node(id=self.next_id(), data=data)
        self.nodes[node.id] = node
        return node

    def remove_node(self, node: Node) -> None:
        """Remove a node from the graph and all edges connected to it.

        Raises:
            KeyError: if the node is not in the graph.
        """
        self.nodes.pop(node.id)
        self.edges = [
            edge
            for edge in self.edges
            if edge.source != node.id and edge.target != node.id
        ]

    def add_edge(self, source: Node, target: Node) -> Edge:
        """Add an edge to the graph and return it.

        Raises:
            ValueError: if either endpoint is not a node of this graph.
        """
        if source.id not in self.nodes:
            raise ValueError(f"Source node {source.id} not in graph")
        if target.id not in self.nodes:
            raise ValueError(f"Target node {target.id} not in graph")
        edge = Edge(source=source.id, target=target.id)
        self.edges.append(edge)
        return edge

    def extend(self, graph: Graph) -> None:
        """Add all nodes and edges from another graph.

        Note this doesn't check for duplicates, nor does it connect the graphs.
        """
        self.nodes.update(graph.nodes)
        self.edges.extend(graph.edges)

    def first_node(self) -> Optional[Node]:
        """Find the single node that is not a target of any edge.

        If there is no such node, or there are multiple, return None.
        When drawing the graph this node would be the origin.
        """
        return self._first_node_excluding(None)

    def last_node(self) -> Optional[Node]:
        """Find the single node that is not a source of any edge.

        If there is no such node, or there are multiple, return None.
        When drawing the graph this node would be the destination.
        """
        return self._last_node_excluding(None)

    def _first_node_excluding(self, exclude: Optional[str]) -> Optional[Node]:
        """Like `first_node`, but as if node `exclude` and its edges were absent."""
        # Edges leaving `exclude` vanish with it; edges entering it don't matter
        # because `exclude` itself is filtered out of the candidates below.
        targets = {edge.target for edge in self.edges if edge.source != exclude}
        found = [
            node
            for node in self.nodes.values()
            if node.id != exclude and node.id not in targets
        ]
        return found[0] if len(found) == 1 else None

    def _last_node_excluding(self, exclude: Optional[str]) -> Optional[Node]:
        """Like `last_node`, but as if node `exclude` and its edges were absent."""
        sources = {edge.source for edge in self.edges if edge.target != exclude}
        found = [
            node
            for node in self.nodes.values()
            if node.id != exclude and node.id not in sources
        ]
        return found[0] if len(found) == 1 else None

    def trim_first_node(self) -> None:
        """Remove the first node if it exists and has a single outgoing edge,
        ie. if removing it would not leave the graph without a "first" node."""
        first_node = self.first_node()
        if first_node is None:
            return
        outgoing = [edge for edge in self.edges if edge.source == first_node.id]
        # Also require that a unique origin remains afterwards: this keeps a
        # lone node (which would empty the graph) and the case where the
        # successor sits on a cycle (no origin left) from being trimmed.
        if (
            len(outgoing) == 1
            and self._first_node_excluding(first_node.id) is not None
        ):
            self.remove_node(first_node)

    def trim_last_node(self) -> None:
        """Remove the last node if it exists and has a single incoming edge,
        ie. if removing it would not leave the graph without a "last" node."""
        last_node = self.last_node()
        if last_node is None:
            return
        incoming = [edge for edge in self.edges if edge.target == last_node.id]
        # Mirror of trim_first_node: only trim if a unique sink remains.
        if (
            len(incoming) == 1
            and self._last_node_excluding(last_node.id) is not None
        ):
            self.remove_node(last_node)

    def draw_ascii(self) -> str:
        """Render the graph with `draw` and return the ASCII drawing."""
        # Deferred import; this module imports `base` only under TYPE_CHECKING,
        # presumably to avoid an import cycle — NOTE(review): confirm.
        from langchain_core.runnables.base import Runnable

        def node_data(node: Node) -> str:
            """Return the text label used for `node` in the drawing."""
            if isinstance(node.data, Runnable):
                try:
                    data = str(node.data)
                except Exception:
                    # str() of an arbitrary runnable is best-effort only;
                    # fall back to the class name below.
                    data = ""
                if (
                    not data
                    or data.startswith("<")
                    or data[0] != data[0].upper()
                    or len(data.splitlines()) > 1
                ):
                    data = node.data.__class__.__name__
                elif len(data) > 42:
                    data = data[:42] + "..."
            else:
                # Otherwise data is a class (Type[BaseModel] per Node's annotation).
                data = node.data.__name__
            prefix = "Runnable"
            # Strip the common "Runnable" prefix, but never down to an empty label.
            if data.startswith(prefix) and len(data) > len(prefix):
                data = data[len(prefix) :]
            return data

        return draw(
            {node.id: node_data(node) for node in self.nodes.values()},
            [(edge.source, edge.target) for edge in self.edges],
        )

    def print_ascii(self) -> None:
        """Print the ASCII drawing of the graph to stdout."""
        print(self.draw_ascii())