core: mermaid: Render metadata key-value pairs when drawing mermaid graph (#24103)

- if node is runnable binding with metadata attached
This commit is contained in:
Nuno Campos 2024-07-11 09:22:23 -07:00 committed by GitHub
parent 1e7d8ba9a6
commit ee3fe20af4
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 146 additions and 116 deletions

View File

@ -398,7 +398,9 @@ class Runnable(Generic[Input, Output], ABC):
input_node = graph.add_node(self.get_input_schema(config))
except TypeError:
input_node = graph.add_node(create_model(self.get_name("Input")))
runnable_node = graph.add_node(self)
runnable_node = graph.add_node(
self, metadata=config.get("metadata") if config else None
)
try:
output_node = graph.add_node(self.get_output_schema(config))
except TypeError:
@ -4629,7 +4631,7 @@ class RunnableBindingBase(RunnableSerializable[Input, Output]):
return self.bound.config_specs
def get_graph(self, config: Optional[RunnableConfig] = None) -> Graph:
return self.bound.get_graph(config)
return self.bound.get_graph(self._merge_configs(config))
@classmethod
def is_lc_serializable(cls) -> bool:

View File

@ -1,6 +1,7 @@
from __future__ import annotations
import inspect
from collections import Counter
from dataclasses import dataclass, field
from enum import Enum
from typing import (
@ -63,12 +64,32 @@ class Edge(NamedTuple):
data: Optional[Stringifiable] = None
conditional: bool = False
def copy(
    self, *, source: Optional[str] = None, target: Optional[str] = None
) -> Edge:
    """Return a new edge, optionally attached to a different source/target.

    Args:
        source: Replacement source; a falsy value (None or "") keeps the
            current source.
        target: Replacement target; a falsy value (None or "") keeps the
            current target.

    Returns:
        A new Edge carrying this edge's ``data`` and ``conditional`` values.
    """
    new_source = source if source else self.source
    new_target = target if target else self.target
    return Edge(
        source=new_source,
        target=new_target,
        data=self.data,
        conditional=self.conditional,
    )
class Node(NamedTuple):
    """Node in a graph."""

    # Key under which the node is stored in a graph.
    id: str
    # Name used as the node's label when the graph is drawn.
    name: str
    # Object the node stands for: a schema class or a runnable.
    data: Union[Type[BaseModel], RunnableType]
    # Optional key-value pairs attached to the node. Defaults to None so
    # callers that have no metadata do not need to pass it explicitly.
    metadata: Optional[Dict[str, Any]] = None

    def copy(self, *, id: Optional[str] = None, name: Optional[str] = None) -> Node:
        """Return a new node, optionally with a different id and/or name.

        Args:
            id: Replacement id; a falsy value (None or "") keeps the
                current id.
            name: Replacement name; a falsy value (None or "") keeps the
                current name.

        Returns:
            A new Node that shares this node's ``data`` and ``metadata``
            objects (they are not copied).
        """
        return Node(
            id=id or self.id,
            name=name or self.name,
            data=self.data,
            metadata=self.metadata,
        )
class Branch(NamedTuple):
@ -111,35 +132,25 @@ class MermaidDrawMethod(Enum):
API = "api" # Uses Mermaid.INK API to render the graph
def node_data_str(node: Node) -> str:
def node_data_str(id: str, data: Union[Type[BaseModel], RunnableType]) -> str:
"""Convert the data of a node to a string.
Args:
node: The node to convert.
html: Whether to format the data as HTML rich text.
Returns:
A string representation of the data.
"""
from langchain_core.runnables.base import Runnable
if not is_uuid(node.id):
return node.id
elif isinstance(node.data, Runnable):
try:
data = str(node.data)
if (
data.startswith("<")
or data[0] != data[0].upper()
or len(data.splitlines()) > 1
):
data = node.data.__class__.__name__
elif len(data) > 42:
data = data[:42] + "..."
except Exception:
data = node.data.__class__.__name__
if not is_uuid(id):
return id
elif isinstance(data, Runnable):
data_str = data.get_name()
else:
data = node.data.__name__
return data if not data.startswith("Runnable") else data[8:]
data_str = data.__name__
return data_str if not data_str.startswith("Runnable") else data_str[8:]
def node_data_json(
@ -163,7 +174,7 @@ def node_data_json(
"type": "runnable",
"data": {
"id": node.data.lc_id(),
"name": node.data.get_name(),
"name": node_data_str(node.id, node.data),
},
}
elif isinstance(node.data, Runnable):
@ -171,7 +182,7 @@ def node_data_json(
"type": "runnable",
"data": {
"id": to_json_not_implemented(node.data)["id"],
"name": node.data.get_name(),
"name": node_data_str(node.id, node.data),
},
}
elif inspect.isclass(node.data) and issubclass(node.data, BaseModel):
@ -183,13 +194,13 @@ def node_data_json(
if with_schemas
else {
"type": "schema",
"data": node_data_str(node),
"data": node_data_str(node.id, node.data),
}
)
else:
return {
"type": "unknown",
"data": node_data_str(node),
"data": node_data_str(node.id, node.data),
}
@ -236,12 +247,17 @@ class Graph:
return uuid4().hex
def add_node(
self, data: Union[Type[BaseModel], RunnableType], id: Optional[str] = None
self,
data: Union[Type[BaseModel], RunnableType],
id: Optional[str] = None,
*,
metadata: Optional[Dict[str, Any]] = None,
) -> Node:
"""Add a node to the graph and return it."""
if id is not None and id in self.nodes:
raise ValueError(f"Node with id {id} already exists")
node = Node(id=id or self.next_id(), data=data)
id = id or self.next_id()
node = Node(id=id, data=data, metadata=metadata, name=node_data_str(id, data))
self.nodes[node.id] = node
return node
@ -285,25 +301,47 @@ class Graph:
# prefix each node
self.nodes.update(
{prefixed(k): Node(prefixed(k), v.data) for k, v in graph.nodes.items()}
{prefixed(k): v.copy(id=prefixed(k)) for k, v in graph.nodes.items()}
)
# prefix each edge's source and target
self.edges.extend(
[
Edge(
prefixed(edge.source),
prefixed(edge.target),
edge.data,
edge.conditional,
)
edge.copy(source=prefixed(edge.source), target=prefixed(edge.target))
for edge in graph.edges
]
)
# return (prefixed) first and last nodes of the subgraph
first, last = graph.first_node(), graph.last_node()
return (
Node(prefixed(first.id), first.data) if first else None,
Node(prefixed(last.id), last.data) if last else None,
first.copy(id=prefixed(first.id)) if first else None,
last.copy(id=prefixed(last.id)) if last else None,
)
def reid(self) -> Graph:
    """Return a new graph whose nodes use readable ids where possible.

    A node's id is replaced by its ``name`` only when the current id is
    recognised by ``is_uuid`` and no other node in the graph has the same
    name; every other node keeps its id. Edge sources and targets are
    rewritten with the same rule so they keep pointing at the same nodes.
    """
    names_by_id = {n.id: n.name for n in self.nodes.values()}
    name_usage = Counter(names_by_id.values())

    def _readable_id(current: str) -> str:
        candidate = names_by_id[current]
        # Keep ids that are already readable, and ids whose name is
        # ambiguous (shared by several nodes).
        if not is_uuid(current) or name_usage[candidate] != 1:
            return current
        return candidate

    renamed_nodes = {}
    for key, node in self.nodes.items():
        new_key = _readable_id(key)
        renamed_nodes[new_key] = node.copy(id=new_key)

    renamed_edges = []
    for edge in self.edges:
        renamed_edges.append(
            edge.copy(
                source=_readable_id(edge.source),
                target=_readable_id(edge.target),
            )
        )

    return Graph(nodes=renamed_nodes, edges=renamed_edges)
def first_node(self) -> Optional[Node]:
@ -357,7 +395,7 @@ class Graph:
from langchain_core.runnables.graph_ascii import draw_ascii
return draw_ascii(
{node.id: node_data_str(node) for node in self.nodes.values()},
{node.id: node.name for node in self.nodes.values()},
self.edges,
)
@ -388,9 +426,7 @@ class Graph:
) -> Union[bytes, None]:
from langchain_core.runnables.graph_png import PngDrawer
default_node_labels = {
node.id: node_data_str(node) for node in self.nodes.values()
}
default_node_labels = {node.id: node.name for node in self.nodes.values()}
return PngDrawer(
fontname,
@ -415,19 +451,15 @@ class Graph:
) -> str:
from langchain_core.runnables.graph_mermaid import draw_mermaid
nodes = {node.id: node_data_str(node) for node in self.nodes.values()}
first_node = self.first_node()
first_label = node_data_str(first_node) if first_node is not None else None
last_node = self.last_node()
last_label = node_data_str(last_node) if last_node is not None else None
graph = self.reid()
first_node = graph.first_node()
last_node = graph.last_node()
return draw_mermaid(
nodes=nodes,
edges=self.edges,
first_node_label=first_label,
last_node_label=last_label,
nodes=graph.nodes,
edges=graph.edges,
first_node=first_node.id if first_node else None,
last_node=last_node.id if last_node else None,
with_styles=with_styles,
curve_style=curve_style,
node_colors=node_colors,

View File

@ -1,22 +1,23 @@
import base64
import re
from dataclasses import asdict
from typing import Dict, List, Optional, Tuple
from typing import Dict, List, Optional
from langchain_core.runnables.graph import (
CurveStyle,
Edge,
MermaidDrawMethod,
Node,
NodeColors,
)
def draw_mermaid(
nodes: Dict[str, str],
nodes: Dict[str, Node],
edges: List[Edge],
*,
first_node_label: Optional[str] = None,
last_node_label: Optional[str] = None,
first_node: Optional[str] = None,
last_node: Optional[str] = None,
with_styles: bool = True,
curve_style: CurveStyle = CurveStyle.LINEAR,
node_colors: NodeColors = NodeColors(),
@ -49,15 +50,20 @@ def draw_mermaid(
# Node formatting templates
default_class_label = "default"
format_dict = {default_class_label: "{0}([{1}]):::otherclass"}
if first_node_label is not None:
format_dict[first_node_label] = "{0}[{0}]:::startclass"
if last_node_label is not None:
format_dict[last_node_label] = "{0}[{0}]:::endclass"
if first_node is not None:
format_dict[first_node] = "{0}[{0}]:::startclass"
if last_node is not None:
format_dict[last_node] = "{0}[{0}]:::endclass"
# Add nodes to the graph
for node in nodes.values():
node_label = format_dict.get(node, format_dict[default_class_label]).format(
_escape_node_label(node), node.split(":", 1)[-1]
for key, node in nodes.items():
label = node.name.split(":")[-1]
if node.metadata:
label = f"<strong>{label}</strong>\n" + "\n".join(
f"{key} = {value}" for key, value in node.metadata.items()
)
node_label = format_dict.get(key, format_dict[default_class_label]).format(
_escape_node_label(key), label
)
mermaid_graph += f"\t{node_label};\n"
@ -74,9 +80,8 @@ def draw_mermaid(
if not subgraph and src_prefix and src_prefix == tgt_prefix:
mermaid_graph += f"\tsubgraph {src_prefix}\n"
subgraph = src_prefix
adjusted_edge = _adjust_mermaid_edge(edge=edge, nodes=nodes)
source, target = adjusted_edge
source, target = edge.source, edge.target
# Add BR every wrap_label_n_words words
if edge.data is not None:
@ -117,17 +122,6 @@ def _escape_node_label(node_label: str) -> str:
return re.sub(r"[^a-zA-Z-_0-9]", "_", node_label)
def _adjust_mermaid_edge(
    edge: Edge,
    nodes: Dict[str, str],
) -> Tuple[str, str]:
    """Return the (source, target) labels for an edge.

    Each endpoint of ``edge`` is looked up in ``nodes``; an endpoint with
    no entry in the mapping is returned unchanged.
    """
    endpoints = (edge.source, edge.target)
    source_label, target_label = (nodes.get(end, end) for end in endpoints)
    return source_label, target_label
def _generate_mermaid_graph_styles(node_colors: NodeColors) -> str:
"""Generates Mermaid graph styles for different node types."""
styles = ""

View File

@ -58,7 +58,7 @@
"base",
"RunnableLambda"
],
"name": "RunnableLambda"
"name": "Lambda"
}
}
],
@ -458,7 +458,7 @@
"runnable",
"RunnableWithFallbacks"
],
"name": "RunnableWithFallbacks"
"name": "WithFallbacks"
}
},
{
@ -498,7 +498,7 @@
"base",
"RunnableLambda"
],
"name": "RunnableLambda"
"name": "Lambda"
}
},
{
@ -511,7 +511,7 @@
"runnable",
"RunnableWithFallbacks"
],
"name": "RunnableWithFallbacks"
"name": "WithFallbacks"
}
},
{
@ -589,7 +589,7 @@
"runnable",
"RunnablePassthrough"
],
"name": "RunnablePassthrough"
"name": "Passthrough"
}
},
{
@ -635,7 +635,7 @@
"runnable",
"RunnablePassthrough"
],
"name": "RunnablePassthrough"
"name": "Passthrough"
}
}
],
@ -716,7 +716,7 @@
"runnable",
"RunnableWithFallbacks"
],
"name": "RunnableWithFallbacks"
"name": "WithFallbacks"
}
},
{
@ -756,7 +756,7 @@
"runnable",
"RunnablePassthrough"
],
"name": "RunnablePassthrough"
"name": "Passthrough"
}
},
{
@ -769,7 +769,7 @@
"runnable",
"RunnableWithFallbacks"
],
"name": "RunnableWithFallbacks"
"name": "WithFallbacks"
}
},
{
@ -938,7 +938,7 @@
"runnable",
"RunnableWithFallbacks"
],
"name": "RunnableWithFallbacks"
"name": "WithFallbacks"
}
},
{
@ -1152,7 +1152,7 @@
"runnable",
"RunnableWithFallbacks"
],
"name": "RunnableWithFallbacks"
"name": "WithFallbacks"
}
},
{

View File

@ -36,7 +36,8 @@
graph TD;
PromptInput[PromptInput]:::startclass;
PromptTemplate([PromptTemplate]):::otherclass;
FakeListLLM([FakeListLLM]):::otherclass;
FakeListLLM([<strong>FakeListLLM</strong>
key = 2]):::otherclass;
CommaSeparatedListOutputParser([CommaSeparatedListOutputParser]):::otherclass;
CommaSeparatedListOutputParserOutput[CommaSeparatedListOutputParserOutput]:::endclass;
PromptInput --> PromptTemplate;

File diff suppressed because one or more lines are too long

View File

@ -33,7 +33,7 @@ def test_graph_sequence(snapshot: SnapshotAssertion) -> None:
prompt = PromptTemplate.from_template("Hello, {name}!")
list_parser = CommaSeparatedListOutputParser()
sequence = prompt | fake_llm | list_parser
sequence = prompt | fake_llm.with_config(metadata={"key": 2}) | list_parser
graph = sequence.get_graph()
assert graph.to_json() == {
"nodes": [

View File

@ -6,6 +6,7 @@ authors = []
license = "MIT"
readme = "README.md"
repository = "https://www.github.com/langchain-ai/langchain"
package-mode = false
[tool.poetry.dependencies]