from __future__ import annotations
import inspect
from dataclasses import dataclass, field
from typing import (
TYPE_CHECKING,
Any,
Dict,
List,
NamedTuple,
Optional,
Type,
TypedDict,
Union,
overload,
)
from uuid import UUID, uuid4
from langchain_core.pydantic_v1 import BaseModel
from langchain_core.runnables.graph_ascii import draw_ascii
if TYPE_CHECKING:
from langchain_core.runnables.base import Runnable as RunnableType
class LabelsDict(TypedDict):
    """Label overrides accepted by ``Graph.draw_png``.

    Both keys are required by this TypedDict definition.
    """

    # Node label overrides; ``Graph.draw_png`` merges these on top of its
    # default labels, which are keyed by node id.
    nodes: dict[str, str]
    # Edge label overrides; passed through to the PNG drawer unchanged.
    edges: dict[str, str]
def is_uuid(value: str) -> bool:
    """Tell whether ``value`` can be parsed as a UUID.

    Args:
        value: Candidate string.

    Returns:
        True if ``uuid.UUID(value)`` accepts the string, False if it raises
        ``ValueError``.
    """
    try:
        UUID(value)
    except ValueError:
        return False
    else:
        return True
class Edge(NamedTuple):
    """Directed edge between two nodes of a graph, referenced by node id."""

    # Id of the node the edge starts from.
    source: str
    # Id of the node the edge points to.
    target: str
    # Optional string attached to the edge; defaults to None.
    data: Optional[str] = None
class Node(NamedTuple):
    """Single node of a graph: an id plus the object it represents."""

    # Identifier of the node; ``Graph.nodes`` is keyed by this value.
    id: str
    # What the node stands for: a pydantic model class or a runnable.
    data: Union[Type[BaseModel], RunnableType]
def node_data_str(node: Node) -> str:
    """Build a short human-readable label for ``node``.

    Rules, in order:

    * A node whose id is not a UUID is labelled with that id, verbatim.
    * A runnable is labelled with ``str(node.data)``, truncated to 42
      characters plus ``"..."``. If that string starts with ``"<"``, does not
      start with an uppercase character, spans several lines, or cannot be
      computed, the class name is used instead.
    * Anything else is labelled with its ``__name__``; objects without a
      ``__name__`` (e.g. arbitrary instances routed here by
      ``node_data_json``'s "unknown" branch) fall back to their class name
      instead of raising ``AttributeError``.

    For the last two cases a leading ``"Runnable"`` is stripped.

    Args:
        node: The node to describe.

    Returns:
        The label string.
    """
    # Imported lazily; at module level this name is only imported under
    # TYPE_CHECKING.
    from langchain_core.runnables.base import Runnable

    if not is_uuid(node.id):
        return node.id
    elif isinstance(node.data, Runnable):
        try:
            data = str(node.data)
            if (
                data.startswith("<")
                # An empty string raises IndexError here, which the handler
                # below turns into the class-name fallback.
                or data[0] != data[0].upper()
                or len(data.splitlines()) > 1
            ):
                data = node.data.__class__.__name__
            elif len(data) > 42:
                data = data[:42] + "..."
        except Exception:
            # Best effort: any failure while stringifying falls back to the
            # class name rather than breaking graph rendering.
            data = node.data.__class__.__name__
    else:
        # Classes always have __name__; other objects may not.
        data = getattr(node.data, "__name__", node.data.__class__.__name__)
    prefix = "Runnable"
    return data if not data.startswith(prefix) else data[len(prefix) :]
def node_data_json(node: Node) -> Dict[str, Union[str, Dict[str, Any]]]:
    """Describe ``node.data`` as a dict with a ``"type"`` and a ``"data"`` key.

    The result depends on what the node holds:

    * ``RunnableSerializable`` -> type ``"runnable"``, id from ``lc_id()``.
    * other ``Runnable`` -> type ``"runnable"``, id taken from
      ``to_json_not_implemented``.
    * ``BaseModel`` subclass -> type ``"schema"``, data is the model schema.
    * anything else -> type ``"unknown"``, data is ``node_data_str(node)``.

    Args:
        node: The node whose data should be described.

    Returns:
        The dict described above.
    """
    from langchain_core.load.serializable import to_json_not_implemented
    from langchain_core.runnables.base import Runnable, RunnableSerializable

    payload = node.data

    if isinstance(payload, RunnableSerializable):
        details = {"id": payload.lc_id(), "name": payload.get_name()}
        return {"type": "runnable", "data": details}

    if isinstance(payload, Runnable):
        details = {
            "id": to_json_not_implemented(payload)["id"],
            "name": payload.get_name(),
        }
        return {"type": "runnable", "data": details}

    # isclass() guards issubclass(), which rejects non-class arguments.
    if inspect.isclass(payload) and issubclass(payload, BaseModel):
        return {"type": "schema", "data": payload.schema()}

    return {"type": "unknown", "data": node_data_str(node)}
@dataclass
class Graph:
    """Graph of nodes and edges."""

    # Nodes keyed by node id.
    nodes: Dict[str, Node] = field(default_factory=dict)
    # Edges in insertion order; each refers to its endpoints by node id.
    edges: List[Edge] = field(default_factory=list)

    def to_json(self) -> Dict[str, List[Dict[str, Any]]]:
        """Convert the graph to a JSON-serializable format.

        Node ids that are UUIDs are replaced by the node's index in
        ``self.nodes``; any other id is kept as-is. An edge's ``"data"`` key
        is only emitted when the edge carries data.

        Returns:
            A dict with a ``"nodes"`` list and an ``"edges"`` list.
        """
        stable_node_ids = {
            node.id: i if is_uuid(node.id) else node.id
            for i, node in enumerate(self.nodes.values())
        }

        nodes_json: List[Dict[str, Any]] = [
            {"id": stable_node_ids[node.id], **node_data_json(node)}
            for node in self.nodes.values()
        ]

        edges_json: List[Dict[str, Any]] = []
        for edge in self.edges:
            edge_json: Dict[str, Any] = {
                "source": stable_node_ids[edge.source],
                "target": stable_node_ids[edge.target],
            }
            if edge.data is not None:
                edge_json["data"] = edge.data
            edges_json.append(edge_json)

        return {"nodes": nodes_json, "edges": edges_json}

    def __bool__(self) -> bool:
        """A graph is truthy when it has at least one node."""
        return bool(self.nodes)

    def next_id(self) -> str:
        """Return a fresh random node id (hex form of a UUID4)."""
        return uuid4().hex

    def add_node(
        self, data: Union[Type[BaseModel], RunnableType], id: Optional[str] = None
    ) -> Node:
        """Add a node to the graph and return it.

        Args:
            data: The object the node represents.
            id: Explicit node id. When omitted (or falsy), an id from
                ``next_id()`` is used.

        Raises:
            ValueError: If a node with the given id already exists.
        """
        if id is not None and id in self.nodes:
            raise ValueError(f"Node with id {id} already exists")
        node = Node(id=id or self.next_id(), data=data)
        self.nodes[node.id] = node
        return node

    def remove_node(self, node: Node) -> None:
        """Remove a node from the graph and all edges connected to it.

        Raises:
            KeyError: If the node's id is not in the graph.
        """
        self.nodes.pop(node.id)
        self.edges = [
            edge
            for edge in self.edges
            if edge.source != node.id and edge.target != node.id
        ]

    def add_edge(self, source: Node, target: Node, data: Optional[str] = None) -> Edge:
        """Add an edge to the graph and return it.

        Raises:
            ValueError: If either endpoint is not a node of this graph.
        """
        if source.id not in self.nodes:
            raise ValueError(f"Source node {source.id} not in graph")
        if target.id not in self.nodes:
            raise ValueError(f"Target node {target.id} not in graph")
        edge = Edge(source=source.id, target=target.id, data=data)
        self.edges.append(edge)
        return edge

    def extend(self, graph: Graph) -> None:
        """Add all nodes and edges from another graph.

        Note this doesn't check for duplicates, nor does it connect the graphs.
        """
        self.nodes.update(graph.nodes)
        self.edges.extend(graph.edges)

    def first_node(self) -> Optional[Node]:
        """Find the single node that is not a target of any edge.

        If there is no such node, or there are multiple, return None.
        When drawing the graph this node would be the origin.
        """
        targets = {edge.target for edge in self.edges}
        found = [node for node in self.nodes.values() if node.id not in targets]
        return found[0] if len(found) == 1 else None

    def last_node(self) -> Optional[Node]:
        """Find the single node that is not a source of any edge.

        If there is no such node, or there are multiple, return None.
        When drawing the graph this node would be the destination.
        """
        sources = {edge.source for edge in self.edges}
        found = [node for node in self.nodes.values() if node.id not in sources]
        return found[0] if len(found) == 1 else None

    def trim_first_node(self) -> None:
        """Remove the first node if it exists and has a single outgoing edge,
        ie. if removing it would not leave the graph without a "first" node.

        A graph consisting of only that one node is emptied as well.
        """
        first_node = self.first_node()
        if first_node is None:
            return
        outgoing = sum(1 for edge in self.edges if edge.source == first_node.id)
        if len(self.nodes) == 1 or outgoing == 1:
            self.remove_node(first_node)

    def trim_last_node(self) -> None:
        """Remove the last node if it exists and has a single incoming edge,
        ie. if removing it would not leave the graph without a "last" node.

        A graph consisting of only that one node is emptied as well.
        """
        last_node = self.last_node()
        if last_node is None:
            return
        incoming = sum(1 for edge in self.edges if edge.target == last_node.id)
        if len(self.nodes) == 1 or incoming == 1:
            self.remove_node(last_node)

    def draw_ascii(self) -> str:
        """Render the graph as text.

        Delegates to ``graph_ascii.draw_ascii`` with a mapping of node id to
        label (from ``node_data_str``) and a list of ``(source, target)``
        id pairs.
        """
        return draw_ascii(
            {node.id: node_data_str(node) for node in self.nodes.values()},
            [(edge.source, edge.target) for edge in self.edges],
        )

    def print_ascii(self) -> None:
        """Print the result of ``draw_ascii()`` to standard output."""
        print(self.draw_ascii())  # noqa: T201

    @overload
    def draw_png(
        self,
        output_file_path: str,
        fontname: Optional[str] = None,
        labels: Optional[LabelsDict] = None,
    ) -> None:
        ...

    # The default lets a zero-argument ``draw_png()`` call match this
    # overload, mirroring the implementation's own default of None.
    @overload
    def draw_png(
        self,
        output_file_path: None = None,
        fontname: Optional[str] = None,
        labels: Optional[LabelsDict] = None,
    ) -> bytes:
        ...

    def draw_png(
        self,
        output_file_path: Optional[str] = None,
        fontname: Optional[str] = None,
        labels: Optional[LabelsDict] = None,
    ) -> Union[bytes, None]:
        """Draw the graph as a PNG image using ``PngDrawer``.

        Args:
            output_file_path: Forwarded to ``PngDrawer.draw``. Per the
                overloads above, a path yields ``None`` and ``None`` yields
                the image bytes.
            fontname: Forwarded to ``PngDrawer``.
            labels: Optional label overrides. Node overrides are merged on
                top of the default labels (``node_data_str`` per node id);
                edge overrides are passed through, defaulting to ``{}``.

        Returns:
            Whatever ``PngDrawer.draw`` returns.
        """
        # Imported lazily so the PNG drawer is only loaded when it is used.
        from langchain_core.runnables.graph_png import PngDrawer

        default_node_labels = {
            node.id: node_data_str(node) for node in self.nodes.values()
        }

        return PngDrawer(
            fontname,
            LabelsDict(
                nodes={
                    **default_node_labels,
                    **(labels["nodes"] if labels is not None else {}),
                },
                edges=labels["edges"] if labels is not None else {},
            ),
        ).draw(self, output_file_path)