core: mermaid: Render metadata key-value pairs when drawing mermaid graph (#24103)

- if node is runnable binding with metadata attached
This commit is contained in:
Nuno Campos 2024-07-11 09:22:23 -07:00 committed by GitHub
parent 1e7d8ba9a6
commit ee3fe20af4
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 146 additions and 116 deletions

View File

@ -398,7 +398,9 @@ class Runnable(Generic[Input, Output], ABC):
input_node = graph.add_node(self.get_input_schema(config)) input_node = graph.add_node(self.get_input_schema(config))
except TypeError: except TypeError:
input_node = graph.add_node(create_model(self.get_name("Input"))) input_node = graph.add_node(create_model(self.get_name("Input")))
runnable_node = graph.add_node(self) runnable_node = graph.add_node(
self, metadata=config.get("metadata") if config else None
)
try: try:
output_node = graph.add_node(self.get_output_schema(config)) output_node = graph.add_node(self.get_output_schema(config))
except TypeError: except TypeError:
@ -4629,7 +4631,7 @@ class RunnableBindingBase(RunnableSerializable[Input, Output]):
return self.bound.config_specs return self.bound.config_specs
def get_graph(self, config: Optional[RunnableConfig] = None) -> Graph: def get_graph(self, config: Optional[RunnableConfig] = None) -> Graph:
return self.bound.get_graph(config) return self.bound.get_graph(self._merge_configs(config))
@classmethod @classmethod
def is_lc_serializable(cls) -> bool: def is_lc_serializable(cls) -> bool:

View File

@ -1,6 +1,7 @@
from __future__ import annotations from __future__ import annotations
import inspect import inspect
from collections import Counter
from dataclasses import dataclass, field from dataclasses import dataclass, field
from enum import Enum from enum import Enum
from typing import ( from typing import (
@ -63,12 +64,32 @@ class Edge(NamedTuple):
data: Optional[Stringifiable] = None data: Optional[Stringifiable] = None
conditional: bool = False conditional: bool = False
def copy(
self, *, source: Optional[str] = None, target: Optional[str] = None
) -> Edge:
return Edge(
source=source or self.source,
target=target or self.target,
data=self.data,
conditional=self.conditional,
)
class Node(NamedTuple): class Node(NamedTuple):
"""Node in a graph.""" """Node in a graph."""
id: str id: str
name: str
data: Union[Type[BaseModel], RunnableType] data: Union[Type[BaseModel], RunnableType]
metadata: Optional[Dict[str, Any]]
def copy(self, *, id: Optional[str] = None, name: Optional[str] = None) -> Node:
return Node(
id=id or self.id,
name=name or self.name,
data=self.data,
metadata=self.metadata,
)
class Branch(NamedTuple): class Branch(NamedTuple):
@ -111,35 +132,25 @@ class MermaidDrawMethod(Enum):
API = "api" # Uses Mermaid.INK API to render the graph API = "api" # Uses Mermaid.INK API to render the graph
def node_data_str(node: Node) -> str: def node_data_str(id: str, data: Union[Type[BaseModel], RunnableType]) -> str:
"""Convert the data of a node to a string. """Convert the data of a node to a string.
Args: Args:
node: The node to convert. node: The node to convert.
html: Whether to format the data as HTML rich text.
Returns: Returns:
A string representation of the data. A string representation of the data.
""" """
from langchain_core.runnables.base import Runnable from langchain_core.runnables.base import Runnable
if not is_uuid(node.id): if not is_uuid(id):
return node.id return id
elif isinstance(node.data, Runnable): elif isinstance(data, Runnable):
try: data_str = data.get_name()
data = str(node.data)
if (
data.startswith("<")
or data[0] != data[0].upper()
or len(data.splitlines()) > 1
):
data = node.data.__class__.__name__
elif len(data) > 42:
data = data[:42] + "..."
except Exception:
data = node.data.__class__.__name__
else: else:
data = node.data.__name__ data_str = data.__name__
return data if not data.startswith("Runnable") else data[8:] return data_str if not data_str.startswith("Runnable") else data_str[8:]
def node_data_json( def node_data_json(
@ -163,7 +174,7 @@ def node_data_json(
"type": "runnable", "type": "runnable",
"data": { "data": {
"id": node.data.lc_id(), "id": node.data.lc_id(),
"name": node.data.get_name(), "name": node_data_str(node.id, node.data),
}, },
} }
elif isinstance(node.data, Runnable): elif isinstance(node.data, Runnable):
@ -171,7 +182,7 @@ def node_data_json(
"type": "runnable", "type": "runnable",
"data": { "data": {
"id": to_json_not_implemented(node.data)["id"], "id": to_json_not_implemented(node.data)["id"],
"name": node.data.get_name(), "name": node_data_str(node.id, node.data),
}, },
} }
elif inspect.isclass(node.data) and issubclass(node.data, BaseModel): elif inspect.isclass(node.data) and issubclass(node.data, BaseModel):
@ -183,13 +194,13 @@ def node_data_json(
if with_schemas if with_schemas
else { else {
"type": "schema", "type": "schema",
"data": node_data_str(node), "data": node_data_str(node.id, node.data),
} }
) )
else: else:
return { return {
"type": "unknown", "type": "unknown",
"data": node_data_str(node), "data": node_data_str(node.id, node.data),
} }
@ -236,12 +247,17 @@ class Graph:
return uuid4().hex return uuid4().hex
def add_node( def add_node(
self, data: Union[Type[BaseModel], RunnableType], id: Optional[str] = None self,
data: Union[Type[BaseModel], RunnableType],
id: Optional[str] = None,
*,
metadata: Optional[Dict[str, Any]] = None,
) -> Node: ) -> Node:
"""Add a node to the graph and return it.""" """Add a node to the graph and return it."""
if id is not None and id in self.nodes: if id is not None and id in self.nodes:
raise ValueError(f"Node with id {id} already exists") raise ValueError(f"Node with id {id} already exists")
node = Node(id=id or self.next_id(), data=data) id = id or self.next_id()
node = Node(id=id, data=data, metadata=metadata, name=node_data_str(id, data))
self.nodes[node.id] = node self.nodes[node.id] = node
return node return node
@ -285,25 +301,47 @@ class Graph:
# prefix each node # prefix each node
self.nodes.update( self.nodes.update(
{prefixed(k): Node(prefixed(k), v.data) for k, v in graph.nodes.items()} {prefixed(k): v.copy(id=prefixed(k)) for k, v in graph.nodes.items()}
) )
# prefix each edge's source and target # prefix each edge's source and target
self.edges.extend( self.edges.extend(
[ [
Edge( edge.copy(source=prefixed(edge.source), target=prefixed(edge.target))
prefixed(edge.source),
prefixed(edge.target),
edge.data,
edge.conditional,
)
for edge in graph.edges for edge in graph.edges
] ]
) )
# return (prefixed) first and last nodes of the subgraph # return (prefixed) first and last nodes of the subgraph
first, last = graph.first_node(), graph.last_node() first, last = graph.first_node(), graph.last_node()
return ( return (
Node(prefixed(first.id), first.data) if first else None, first.copy(id=prefixed(first.id)) if first else None,
Node(prefixed(last.id), last.data) if last else None, last.copy(id=prefixed(last.id)) if last else None,
)
def reid(self) -> Graph:
"""Return a new graph with all nodes re-identified,
using their unique, readable names where possible."""
node_labels = {node.id: node.name for node in self.nodes.values()}
node_label_counts = Counter(node_labels.values())
def _get_node_id(node_id: str) -> str:
label = node_labels[node_id]
if is_uuid(node_id) and node_label_counts[label] == 1:
return label
else:
return node_id
return Graph(
nodes={
_get_node_id(id): node.copy(id=_get_node_id(id))
for id, node in self.nodes.items()
},
edges=[
edge.copy(
source=_get_node_id(edge.source),
target=_get_node_id(edge.target),
)
for edge in self.edges
],
) )
def first_node(self) -> Optional[Node]: def first_node(self) -> Optional[Node]:
@ -357,7 +395,7 @@ class Graph:
from langchain_core.runnables.graph_ascii import draw_ascii from langchain_core.runnables.graph_ascii import draw_ascii
return draw_ascii( return draw_ascii(
{node.id: node_data_str(node) for node in self.nodes.values()}, {node.id: node.name for node in self.nodes.values()},
self.edges, self.edges,
) )
@ -388,9 +426,7 @@ class Graph:
) -> Union[bytes, None]: ) -> Union[bytes, None]:
from langchain_core.runnables.graph_png import PngDrawer from langchain_core.runnables.graph_png import PngDrawer
default_node_labels = { default_node_labels = {node.id: node.name for node in self.nodes.values()}
node.id: node_data_str(node) for node in self.nodes.values()
}
return PngDrawer( return PngDrawer(
fontname, fontname,
@ -415,19 +451,15 @@ class Graph:
) -> str: ) -> str:
from langchain_core.runnables.graph_mermaid import draw_mermaid from langchain_core.runnables.graph_mermaid import draw_mermaid
nodes = {node.id: node_data_str(node) for node in self.nodes.values()} graph = self.reid()
first_node = graph.first_node()
first_node = self.first_node() last_node = graph.last_node()
first_label = node_data_str(first_node) if first_node is not None else None
last_node = self.last_node()
last_label = node_data_str(last_node) if last_node is not None else None
return draw_mermaid( return draw_mermaid(
nodes=nodes, nodes=graph.nodes,
edges=self.edges, edges=graph.edges,
first_node_label=first_label, first_node=first_node.id if first_node else None,
last_node_label=last_label, last_node=last_node.id if last_node else None,
with_styles=with_styles, with_styles=with_styles,
curve_style=curve_style, curve_style=curve_style,
node_colors=node_colors, node_colors=node_colors,

View File

@ -1,22 +1,23 @@
import base64 import base64
import re import re
from dataclasses import asdict from dataclasses import asdict
from typing import Dict, List, Optional, Tuple from typing import Dict, List, Optional
from langchain_core.runnables.graph import ( from langchain_core.runnables.graph import (
CurveStyle, CurveStyle,
Edge, Edge,
MermaidDrawMethod, MermaidDrawMethod,
Node,
NodeColors, NodeColors,
) )
def draw_mermaid( def draw_mermaid(
nodes: Dict[str, str], nodes: Dict[str, Node],
edges: List[Edge], edges: List[Edge],
*, *,
first_node_label: Optional[str] = None, first_node: Optional[str] = None,
last_node_label: Optional[str] = None, last_node: Optional[str] = None,
with_styles: bool = True, with_styles: bool = True,
curve_style: CurveStyle = CurveStyle.LINEAR, curve_style: CurveStyle = CurveStyle.LINEAR,
node_colors: NodeColors = NodeColors(), node_colors: NodeColors = NodeColors(),
@ -49,15 +50,20 @@ def draw_mermaid(
# Node formatting templates # Node formatting templates
default_class_label = "default" default_class_label = "default"
format_dict = {default_class_label: "{0}([{1}]):::otherclass"} format_dict = {default_class_label: "{0}([{1}]):::otherclass"}
if first_node_label is not None: if first_node is not None:
format_dict[first_node_label] = "{0}[{0}]:::startclass" format_dict[first_node] = "{0}[{0}]:::startclass"
if last_node_label is not None: if last_node is not None:
format_dict[last_node_label] = "{0}[{0}]:::endclass" format_dict[last_node] = "{0}[{0}]:::endclass"
# Add nodes to the graph # Add nodes to the graph
for node in nodes.values(): for key, node in nodes.items():
node_label = format_dict.get(node, format_dict[default_class_label]).format( label = node.name.split(":")[-1]
_escape_node_label(node), node.split(":", 1)[-1] if node.metadata:
label = f"<strong>{label}</strong>\n" + "\n".join(
f"{key} = {value}" for key, value in node.metadata.items()
)
node_label = format_dict.get(key, format_dict[default_class_label]).format(
_escape_node_label(key), label
) )
mermaid_graph += f"\t{node_label};\n" mermaid_graph += f"\t{node_label};\n"
@ -74,9 +80,8 @@ def draw_mermaid(
if not subgraph and src_prefix and src_prefix == tgt_prefix: if not subgraph and src_prefix and src_prefix == tgt_prefix:
mermaid_graph += f"\tsubgraph {src_prefix}\n" mermaid_graph += f"\tsubgraph {src_prefix}\n"
subgraph = src_prefix subgraph = src_prefix
adjusted_edge = _adjust_mermaid_edge(edge=edge, nodes=nodes)
source, target = adjusted_edge source, target = edge.source, edge.target
# Add BR every wrap_label_n_words words # Add BR every wrap_label_n_words words
if edge.data is not None: if edge.data is not None:
@ -117,17 +122,6 @@ def _escape_node_label(node_label: str) -> str:
return re.sub(r"[^a-zA-Z-_0-9]", "_", node_label) return re.sub(r"[^a-zA-Z-_0-9]", "_", node_label)
def _adjust_mermaid_edge(
edge: Edge,
nodes: Dict[str, str],
) -> Tuple[str, str]:
"""Adjusts Mermaid edge to map conditional nodes to pure nodes."""
source_node_label = nodes.get(edge.source, edge.source)
target_node_label = nodes.get(edge.target, edge.target)
return source_node_label, target_node_label
def _generate_mermaid_graph_styles(node_colors: NodeColors) -> str: def _generate_mermaid_graph_styles(node_colors: NodeColors) -> str:
"""Generates Mermaid graph styles for different node types.""" """Generates Mermaid graph styles for different node types."""
styles = "" styles = ""

View File

@ -58,7 +58,7 @@
"base", "base",
"RunnableLambda" "RunnableLambda"
], ],
"name": "RunnableLambda" "name": "Lambda"
} }
} }
], ],
@ -458,7 +458,7 @@
"runnable", "runnable",
"RunnableWithFallbacks" "RunnableWithFallbacks"
], ],
"name": "RunnableWithFallbacks" "name": "WithFallbacks"
} }
}, },
{ {
@ -498,7 +498,7 @@
"base", "base",
"RunnableLambda" "RunnableLambda"
], ],
"name": "RunnableLambda" "name": "Lambda"
} }
}, },
{ {
@ -511,7 +511,7 @@
"runnable", "runnable",
"RunnableWithFallbacks" "RunnableWithFallbacks"
], ],
"name": "RunnableWithFallbacks" "name": "WithFallbacks"
} }
}, },
{ {
@ -589,7 +589,7 @@
"runnable", "runnable",
"RunnablePassthrough" "RunnablePassthrough"
], ],
"name": "RunnablePassthrough" "name": "Passthrough"
} }
}, },
{ {
@ -635,7 +635,7 @@
"runnable", "runnable",
"RunnablePassthrough" "RunnablePassthrough"
], ],
"name": "RunnablePassthrough" "name": "Passthrough"
} }
} }
], ],
@ -716,7 +716,7 @@
"runnable", "runnable",
"RunnableWithFallbacks" "RunnableWithFallbacks"
], ],
"name": "RunnableWithFallbacks" "name": "WithFallbacks"
} }
}, },
{ {
@ -756,7 +756,7 @@
"runnable", "runnable",
"RunnablePassthrough" "RunnablePassthrough"
], ],
"name": "RunnablePassthrough" "name": "Passthrough"
} }
}, },
{ {
@ -769,7 +769,7 @@
"runnable", "runnable",
"RunnableWithFallbacks" "RunnableWithFallbacks"
], ],
"name": "RunnableWithFallbacks" "name": "WithFallbacks"
} }
}, },
{ {
@ -938,7 +938,7 @@
"runnable", "runnable",
"RunnableWithFallbacks" "RunnableWithFallbacks"
], ],
"name": "RunnableWithFallbacks" "name": "WithFallbacks"
} }
}, },
{ {
@ -1152,7 +1152,7 @@
"runnable", "runnable",
"RunnableWithFallbacks" "RunnableWithFallbacks"
], ],
"name": "RunnableWithFallbacks" "name": "WithFallbacks"
} }
}, },
{ {

View File

@ -36,7 +36,8 @@
graph TD; graph TD;
PromptInput[PromptInput]:::startclass; PromptInput[PromptInput]:::startclass;
PromptTemplate([PromptTemplate]):::otherclass; PromptTemplate([PromptTemplate]):::otherclass;
FakeListLLM([FakeListLLM]):::otherclass; FakeListLLM([<strong>FakeListLLM</strong>
key = 2]):::otherclass;
CommaSeparatedListOutputParser([CommaSeparatedListOutputParser]):::otherclass; CommaSeparatedListOutputParser([CommaSeparatedListOutputParser]):::otherclass;
CommaSeparatedListOutputParserOutput[CommaSeparatedListOutputParserOutput]:::endclass; CommaSeparatedListOutputParserOutput[CommaSeparatedListOutputParserOutput]:::endclass;
PromptInput --> PromptTemplate; PromptInput --> PromptTemplate;

File diff suppressed because one or more lines are too long

View File

@ -33,7 +33,7 @@ def test_graph_sequence(snapshot: SnapshotAssertion) -> None:
prompt = PromptTemplate.from_template("Hello, {name}!") prompt = PromptTemplate.from_template("Hello, {name}!")
list_parser = CommaSeparatedListOutputParser() list_parser = CommaSeparatedListOutputParser()
sequence = prompt | fake_llm | list_parser sequence = prompt | fake_llm.with_config(metadata={"key": 2}) | list_parser
graph = sequence.get_graph() graph = sequence.get_graph()
assert graph.to_json() == { assert graph.to_json() == {
"nodes": [ "nodes": [

View File

@ -6,6 +6,7 @@ authors = []
license = "MIT" license = "MIT"
readme = "README.md" readme = "README.md"
repository = "https://www.github.com/langchain-ai/langchain" repository = "https://www.github.com/langchain-ai/langchain"
package-mode = false
[tool.poetry.dependencies] [tool.poetry.dependencies]