mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-19 05:13:46 +00:00
core: mermaid: Render metadata key-value pairs when drawing mermaid graph (#24103)
- if node is runnable binding with metadata attached
This commit is contained in:
parent
1e7d8ba9a6
commit
ee3fe20af4
@ -398,7 +398,9 @@ class Runnable(Generic[Input, Output], ABC):
|
||||
input_node = graph.add_node(self.get_input_schema(config))
|
||||
except TypeError:
|
||||
input_node = graph.add_node(create_model(self.get_name("Input")))
|
||||
runnable_node = graph.add_node(self)
|
||||
runnable_node = graph.add_node(
|
||||
self, metadata=config.get("metadata") if config else None
|
||||
)
|
||||
try:
|
||||
output_node = graph.add_node(self.get_output_schema(config))
|
||||
except TypeError:
|
||||
@ -4629,7 +4631,7 @@ class RunnableBindingBase(RunnableSerializable[Input, Output]):
|
||||
return self.bound.config_specs
|
||||
|
||||
def get_graph(self, config: Optional[RunnableConfig] = None) -> Graph:
|
||||
return self.bound.get_graph(config)
|
||||
return self.bound.get_graph(self._merge_configs(config))
|
||||
|
||||
@classmethod
|
||||
def is_lc_serializable(cls) -> bool:
|
||||
|
@ -1,6 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import inspect
|
||||
from collections import Counter
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from typing import (
|
||||
@ -63,12 +64,32 @@ class Edge(NamedTuple):
|
||||
data: Optional[Stringifiable] = None
|
||||
conditional: bool = False
|
||||
|
||||
def copy(
|
||||
self, *, source: Optional[str] = None, target: Optional[str] = None
|
||||
) -> Edge:
|
||||
return Edge(
|
||||
source=source or self.source,
|
||||
target=target or self.target,
|
||||
data=self.data,
|
||||
conditional=self.conditional,
|
||||
)
|
||||
|
||||
|
||||
class Node(NamedTuple):
|
||||
"""Node in a graph."""
|
||||
|
||||
id: str
|
||||
name: str
|
||||
data: Union[Type[BaseModel], RunnableType]
|
||||
metadata: Optional[Dict[str, Any]]
|
||||
|
||||
def copy(self, *, id: Optional[str] = None, name: Optional[str] = None) -> Node:
|
||||
return Node(
|
||||
id=id or self.id,
|
||||
name=name or self.name,
|
||||
data=self.data,
|
||||
metadata=self.metadata,
|
||||
)
|
||||
|
||||
|
||||
class Branch(NamedTuple):
|
||||
@ -111,35 +132,25 @@ class MermaidDrawMethod(Enum):
|
||||
API = "api" # Uses Mermaid.INK API to render the graph
|
||||
|
||||
|
||||
def node_data_str(node: Node) -> str:
|
||||
def node_data_str(id: str, data: Union[Type[BaseModel], RunnableType]) -> str:
|
||||
"""Convert the data of a node to a string.
|
||||
|
||||
Args:
|
||||
node: The node to convert.
|
||||
html: Whether to format the data as HTML rich text.
|
||||
|
||||
Returns:
|
||||
A string representation of the data.
|
||||
"""
|
||||
from langchain_core.runnables.base import Runnable
|
||||
|
||||
if not is_uuid(node.id):
|
||||
return node.id
|
||||
elif isinstance(node.data, Runnable):
|
||||
try:
|
||||
data = str(node.data)
|
||||
if (
|
||||
data.startswith("<")
|
||||
or data[0] != data[0].upper()
|
||||
or len(data.splitlines()) > 1
|
||||
):
|
||||
data = node.data.__class__.__name__
|
||||
elif len(data) > 42:
|
||||
data = data[:42] + "..."
|
||||
except Exception:
|
||||
data = node.data.__class__.__name__
|
||||
if not is_uuid(id):
|
||||
return id
|
||||
elif isinstance(data, Runnable):
|
||||
data_str = data.get_name()
|
||||
else:
|
||||
data = node.data.__name__
|
||||
return data if not data.startswith("Runnable") else data[8:]
|
||||
data_str = data.__name__
|
||||
return data_str if not data_str.startswith("Runnable") else data_str[8:]
|
||||
|
||||
|
||||
def node_data_json(
|
||||
@ -163,7 +174,7 @@ def node_data_json(
|
||||
"type": "runnable",
|
||||
"data": {
|
||||
"id": node.data.lc_id(),
|
||||
"name": node.data.get_name(),
|
||||
"name": node_data_str(node.id, node.data),
|
||||
},
|
||||
}
|
||||
elif isinstance(node.data, Runnable):
|
||||
@ -171,7 +182,7 @@ def node_data_json(
|
||||
"type": "runnable",
|
||||
"data": {
|
||||
"id": to_json_not_implemented(node.data)["id"],
|
||||
"name": node.data.get_name(),
|
||||
"name": node_data_str(node.id, node.data),
|
||||
},
|
||||
}
|
||||
elif inspect.isclass(node.data) and issubclass(node.data, BaseModel):
|
||||
@ -183,13 +194,13 @@ def node_data_json(
|
||||
if with_schemas
|
||||
else {
|
||||
"type": "schema",
|
||||
"data": node_data_str(node),
|
||||
"data": node_data_str(node.id, node.data),
|
||||
}
|
||||
)
|
||||
else:
|
||||
return {
|
||||
"type": "unknown",
|
||||
"data": node_data_str(node),
|
||||
"data": node_data_str(node.id, node.data),
|
||||
}
|
||||
|
||||
|
||||
@ -236,12 +247,17 @@ class Graph:
|
||||
return uuid4().hex
|
||||
|
||||
def add_node(
|
||||
self, data: Union[Type[BaseModel], RunnableType], id: Optional[str] = None
|
||||
self,
|
||||
data: Union[Type[BaseModel], RunnableType],
|
||||
id: Optional[str] = None,
|
||||
*,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
) -> Node:
|
||||
"""Add a node to the graph and return it."""
|
||||
if id is not None and id in self.nodes:
|
||||
raise ValueError(f"Node with id {id} already exists")
|
||||
node = Node(id=id or self.next_id(), data=data)
|
||||
id = id or self.next_id()
|
||||
node = Node(id=id, data=data, metadata=metadata, name=node_data_str(id, data))
|
||||
self.nodes[node.id] = node
|
||||
return node
|
||||
|
||||
@ -285,25 +301,47 @@ class Graph:
|
||||
|
||||
# prefix each node
|
||||
self.nodes.update(
|
||||
{prefixed(k): Node(prefixed(k), v.data) for k, v in graph.nodes.items()}
|
||||
{prefixed(k): v.copy(id=prefixed(k)) for k, v in graph.nodes.items()}
|
||||
)
|
||||
# prefix each edge's source and target
|
||||
self.edges.extend(
|
||||
[
|
||||
Edge(
|
||||
prefixed(edge.source),
|
||||
prefixed(edge.target),
|
||||
edge.data,
|
||||
edge.conditional,
|
||||
)
|
||||
edge.copy(source=prefixed(edge.source), target=prefixed(edge.target))
|
||||
for edge in graph.edges
|
||||
]
|
||||
)
|
||||
# return (prefixed) first and last nodes of the subgraph
|
||||
first, last = graph.first_node(), graph.last_node()
|
||||
return (
|
||||
Node(prefixed(first.id), first.data) if first else None,
|
||||
Node(prefixed(last.id), last.data) if last else None,
|
||||
first.copy(id=prefixed(first.id)) if first else None,
|
||||
last.copy(id=prefixed(last.id)) if last else None,
|
||||
)
|
||||
|
||||
def reid(self) -> Graph:
|
||||
"""Return a new graph with all nodes re-identified,
|
||||
using their unique, readable names where possible."""
|
||||
node_labels = {node.id: node.name for node in self.nodes.values()}
|
||||
node_label_counts = Counter(node_labels.values())
|
||||
|
||||
def _get_node_id(node_id: str) -> str:
|
||||
label = node_labels[node_id]
|
||||
if is_uuid(node_id) and node_label_counts[label] == 1:
|
||||
return label
|
||||
else:
|
||||
return node_id
|
||||
|
||||
return Graph(
|
||||
nodes={
|
||||
_get_node_id(id): node.copy(id=_get_node_id(id))
|
||||
for id, node in self.nodes.items()
|
||||
},
|
||||
edges=[
|
||||
edge.copy(
|
||||
source=_get_node_id(edge.source),
|
||||
target=_get_node_id(edge.target),
|
||||
)
|
||||
for edge in self.edges
|
||||
],
|
||||
)
|
||||
|
||||
def first_node(self) -> Optional[Node]:
|
||||
@ -357,7 +395,7 @@ class Graph:
|
||||
from langchain_core.runnables.graph_ascii import draw_ascii
|
||||
|
||||
return draw_ascii(
|
||||
{node.id: node_data_str(node) for node in self.nodes.values()},
|
||||
{node.id: node.name for node in self.nodes.values()},
|
||||
self.edges,
|
||||
)
|
||||
|
||||
@ -388,9 +426,7 @@ class Graph:
|
||||
) -> Union[bytes, None]:
|
||||
from langchain_core.runnables.graph_png import PngDrawer
|
||||
|
||||
default_node_labels = {
|
||||
node.id: node_data_str(node) for node in self.nodes.values()
|
||||
}
|
||||
default_node_labels = {node.id: node.name for node in self.nodes.values()}
|
||||
|
||||
return PngDrawer(
|
||||
fontname,
|
||||
@ -415,19 +451,15 @@ class Graph:
|
||||
) -> str:
|
||||
from langchain_core.runnables.graph_mermaid import draw_mermaid
|
||||
|
||||
nodes = {node.id: node_data_str(node) for node in self.nodes.values()}
|
||||
|
||||
first_node = self.first_node()
|
||||
first_label = node_data_str(first_node) if first_node is not None else None
|
||||
|
||||
last_node = self.last_node()
|
||||
last_label = node_data_str(last_node) if last_node is not None else None
|
||||
graph = self.reid()
|
||||
first_node = graph.first_node()
|
||||
last_node = graph.last_node()
|
||||
|
||||
return draw_mermaid(
|
||||
nodes=nodes,
|
||||
edges=self.edges,
|
||||
first_node_label=first_label,
|
||||
last_node_label=last_label,
|
||||
nodes=graph.nodes,
|
||||
edges=graph.edges,
|
||||
first_node=first_node.id if first_node else None,
|
||||
last_node=last_node.id if last_node else None,
|
||||
with_styles=with_styles,
|
||||
curve_style=curve_style,
|
||||
node_colors=node_colors,
|
||||
|
@ -1,22 +1,23 @@
|
||||
import base64
|
||||
import re
|
||||
from dataclasses import asdict
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from langchain_core.runnables.graph import (
|
||||
CurveStyle,
|
||||
Edge,
|
||||
MermaidDrawMethod,
|
||||
Node,
|
||||
NodeColors,
|
||||
)
|
||||
|
||||
|
||||
def draw_mermaid(
|
||||
nodes: Dict[str, str],
|
||||
nodes: Dict[str, Node],
|
||||
edges: List[Edge],
|
||||
*,
|
||||
first_node_label: Optional[str] = None,
|
||||
last_node_label: Optional[str] = None,
|
||||
first_node: Optional[str] = None,
|
||||
last_node: Optional[str] = None,
|
||||
with_styles: bool = True,
|
||||
curve_style: CurveStyle = CurveStyle.LINEAR,
|
||||
node_colors: NodeColors = NodeColors(),
|
||||
@ -49,15 +50,20 @@ def draw_mermaid(
|
||||
# Node formatting templates
|
||||
default_class_label = "default"
|
||||
format_dict = {default_class_label: "{0}([{1}]):::otherclass"}
|
||||
if first_node_label is not None:
|
||||
format_dict[first_node_label] = "{0}[{0}]:::startclass"
|
||||
if last_node_label is not None:
|
||||
format_dict[last_node_label] = "{0}[{0}]:::endclass"
|
||||
if first_node is not None:
|
||||
format_dict[first_node] = "{0}[{0}]:::startclass"
|
||||
if last_node is not None:
|
||||
format_dict[last_node] = "{0}[{0}]:::endclass"
|
||||
|
||||
# Add nodes to the graph
|
||||
for node in nodes.values():
|
||||
node_label = format_dict.get(node, format_dict[default_class_label]).format(
|
||||
_escape_node_label(node), node.split(":", 1)[-1]
|
||||
for key, node in nodes.items():
|
||||
label = node.name.split(":")[-1]
|
||||
if node.metadata:
|
||||
label = f"<strong>{label}</strong>\n" + "\n".join(
|
||||
f"{key} = {value}" for key, value in node.metadata.items()
|
||||
)
|
||||
node_label = format_dict.get(key, format_dict[default_class_label]).format(
|
||||
_escape_node_label(key), label
|
||||
)
|
||||
mermaid_graph += f"\t{node_label};\n"
|
||||
|
||||
@ -74,9 +80,8 @@ def draw_mermaid(
|
||||
if not subgraph and src_prefix and src_prefix == tgt_prefix:
|
||||
mermaid_graph += f"\tsubgraph {src_prefix}\n"
|
||||
subgraph = src_prefix
|
||||
adjusted_edge = _adjust_mermaid_edge(edge=edge, nodes=nodes)
|
||||
|
||||
source, target = adjusted_edge
|
||||
source, target = edge.source, edge.target
|
||||
|
||||
# Add BR every wrap_label_n_words words
|
||||
if edge.data is not None:
|
||||
@ -117,17 +122,6 @@ def _escape_node_label(node_label: str) -> str:
|
||||
return re.sub(r"[^a-zA-Z-_0-9]", "_", node_label)
|
||||
|
||||
|
||||
def _adjust_mermaid_edge(
|
||||
edge: Edge,
|
||||
nodes: Dict[str, str],
|
||||
) -> Tuple[str, str]:
|
||||
"""Adjusts Mermaid edge to map conditional nodes to pure nodes."""
|
||||
source_node_label = nodes.get(edge.source, edge.source)
|
||||
target_node_label = nodes.get(edge.target, edge.target)
|
||||
|
||||
return source_node_label, target_node_label
|
||||
|
||||
|
||||
def _generate_mermaid_graph_styles(node_colors: NodeColors) -> str:
|
||||
"""Generates Mermaid graph styles for different node types."""
|
||||
styles = ""
|
||||
|
@ -58,7 +58,7 @@
|
||||
"base",
|
||||
"RunnableLambda"
|
||||
],
|
||||
"name": "RunnableLambda"
|
||||
"name": "Lambda"
|
||||
}
|
||||
}
|
||||
],
|
||||
@ -458,7 +458,7 @@
|
||||
"runnable",
|
||||
"RunnableWithFallbacks"
|
||||
],
|
||||
"name": "RunnableWithFallbacks"
|
||||
"name": "WithFallbacks"
|
||||
}
|
||||
},
|
||||
{
|
||||
@ -498,7 +498,7 @@
|
||||
"base",
|
||||
"RunnableLambda"
|
||||
],
|
||||
"name": "RunnableLambda"
|
||||
"name": "Lambda"
|
||||
}
|
||||
},
|
||||
{
|
||||
@ -511,7 +511,7 @@
|
||||
"runnable",
|
||||
"RunnableWithFallbacks"
|
||||
],
|
||||
"name": "RunnableWithFallbacks"
|
||||
"name": "WithFallbacks"
|
||||
}
|
||||
},
|
||||
{
|
||||
@ -589,7 +589,7 @@
|
||||
"runnable",
|
||||
"RunnablePassthrough"
|
||||
],
|
||||
"name": "RunnablePassthrough"
|
||||
"name": "Passthrough"
|
||||
}
|
||||
},
|
||||
{
|
||||
@ -635,7 +635,7 @@
|
||||
"runnable",
|
||||
"RunnablePassthrough"
|
||||
],
|
||||
"name": "RunnablePassthrough"
|
||||
"name": "Passthrough"
|
||||
}
|
||||
}
|
||||
],
|
||||
@ -716,7 +716,7 @@
|
||||
"runnable",
|
||||
"RunnableWithFallbacks"
|
||||
],
|
||||
"name": "RunnableWithFallbacks"
|
||||
"name": "WithFallbacks"
|
||||
}
|
||||
},
|
||||
{
|
||||
@ -756,7 +756,7 @@
|
||||
"runnable",
|
||||
"RunnablePassthrough"
|
||||
],
|
||||
"name": "RunnablePassthrough"
|
||||
"name": "Passthrough"
|
||||
}
|
||||
},
|
||||
{
|
||||
@ -769,7 +769,7 @@
|
||||
"runnable",
|
||||
"RunnableWithFallbacks"
|
||||
],
|
||||
"name": "RunnableWithFallbacks"
|
||||
"name": "WithFallbacks"
|
||||
}
|
||||
},
|
||||
{
|
||||
@ -938,7 +938,7 @@
|
||||
"runnable",
|
||||
"RunnableWithFallbacks"
|
||||
],
|
||||
"name": "RunnableWithFallbacks"
|
||||
"name": "WithFallbacks"
|
||||
}
|
||||
},
|
||||
{
|
||||
@ -1152,7 +1152,7 @@
|
||||
"runnable",
|
||||
"RunnableWithFallbacks"
|
||||
],
|
||||
"name": "RunnableWithFallbacks"
|
||||
"name": "WithFallbacks"
|
||||
}
|
||||
},
|
||||
{
|
||||
|
@ -36,7 +36,8 @@
|
||||
graph TD;
|
||||
PromptInput[PromptInput]:::startclass;
|
||||
PromptTemplate([PromptTemplate]):::otherclass;
|
||||
FakeListLLM([FakeListLLM]):::otherclass;
|
||||
FakeListLLM([<strong>FakeListLLM</strong>
|
||||
key = 2]):::otherclass;
|
||||
CommaSeparatedListOutputParser([CommaSeparatedListOutputParser]):::otherclass;
|
||||
CommaSeparatedListOutputParserOutput[CommaSeparatedListOutputParserOutput]:::endclass;
|
||||
PromptInput --> PromptTemplate;
|
||||
|
File diff suppressed because one or more lines are too long
@ -33,7 +33,7 @@ def test_graph_sequence(snapshot: SnapshotAssertion) -> None:
|
||||
prompt = PromptTemplate.from_template("Hello, {name}!")
|
||||
list_parser = CommaSeparatedListOutputParser()
|
||||
|
||||
sequence = prompt | fake_llm | list_parser
|
||||
sequence = prompt | fake_llm.with_config(metadata={"key": 2}) | list_parser
|
||||
graph = sequence.get_graph()
|
||||
assert graph.to_json() == {
|
||||
"nodes": [
|
||||
|
@ -6,6 +6,7 @@ authors = []
|
||||
license = "MIT"
|
||||
readme = "README.md"
|
||||
repository = "https://www.github.com/langchain-ai/langchain"
|
||||
package-mode = false
|
||||
|
||||
|
||||
[tool.poetry.dependencies]
|
||||
|
Loading…
Reference in New Issue
Block a user