mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-19 13:23:35 +00:00
core: mermaid: Render metadata key-value pairs when drawing mermaid graph (#24103)
- if node is runnable binding with metadata attached
This commit is contained in:
parent
1e7d8ba9a6
commit
ee3fe20af4
@ -398,7 +398,9 @@ class Runnable(Generic[Input, Output], ABC):
|
|||||||
input_node = graph.add_node(self.get_input_schema(config))
|
input_node = graph.add_node(self.get_input_schema(config))
|
||||||
except TypeError:
|
except TypeError:
|
||||||
input_node = graph.add_node(create_model(self.get_name("Input")))
|
input_node = graph.add_node(create_model(self.get_name("Input")))
|
||||||
runnable_node = graph.add_node(self)
|
runnable_node = graph.add_node(
|
||||||
|
self, metadata=config.get("metadata") if config else None
|
||||||
|
)
|
||||||
try:
|
try:
|
||||||
output_node = graph.add_node(self.get_output_schema(config))
|
output_node = graph.add_node(self.get_output_schema(config))
|
||||||
except TypeError:
|
except TypeError:
|
||||||
@ -4629,7 +4631,7 @@ class RunnableBindingBase(RunnableSerializable[Input, Output]):
|
|||||||
return self.bound.config_specs
|
return self.bound.config_specs
|
||||||
|
|
||||||
def get_graph(self, config: Optional[RunnableConfig] = None) -> Graph:
|
def get_graph(self, config: Optional[RunnableConfig] = None) -> Graph:
|
||||||
return self.bound.get_graph(config)
|
return self.bound.get_graph(self._merge_configs(config))
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def is_lc_serializable(cls) -> bool:
|
def is_lc_serializable(cls) -> bool:
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import inspect
|
import inspect
|
||||||
|
from collections import Counter
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import (
|
from typing import (
|
||||||
@ -63,12 +64,32 @@ class Edge(NamedTuple):
|
|||||||
data: Optional[Stringifiable] = None
|
data: Optional[Stringifiable] = None
|
||||||
conditional: bool = False
|
conditional: bool = False
|
||||||
|
|
||||||
|
def copy(
|
||||||
|
self, *, source: Optional[str] = None, target: Optional[str] = None
|
||||||
|
) -> Edge:
|
||||||
|
return Edge(
|
||||||
|
source=source or self.source,
|
||||||
|
target=target or self.target,
|
||||||
|
data=self.data,
|
||||||
|
conditional=self.conditional,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class Node(NamedTuple):
|
class Node(NamedTuple):
|
||||||
"""Node in a graph."""
|
"""Node in a graph."""
|
||||||
|
|
||||||
id: str
|
id: str
|
||||||
|
name: str
|
||||||
data: Union[Type[BaseModel], RunnableType]
|
data: Union[Type[BaseModel], RunnableType]
|
||||||
|
metadata: Optional[Dict[str, Any]]
|
||||||
|
|
||||||
|
def copy(self, *, id: Optional[str] = None, name: Optional[str] = None) -> Node:
|
||||||
|
return Node(
|
||||||
|
id=id or self.id,
|
||||||
|
name=name or self.name,
|
||||||
|
data=self.data,
|
||||||
|
metadata=self.metadata,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class Branch(NamedTuple):
|
class Branch(NamedTuple):
|
||||||
@ -111,35 +132,25 @@ class MermaidDrawMethod(Enum):
|
|||||||
API = "api" # Uses Mermaid.INK API to render the graph
|
API = "api" # Uses Mermaid.INK API to render the graph
|
||||||
|
|
||||||
|
|
||||||
def node_data_str(node: Node) -> str:
|
def node_data_str(id: str, data: Union[Type[BaseModel], RunnableType]) -> str:
|
||||||
"""Convert the data of a node to a string.
|
"""Convert the data of a node to a string.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
node: The node to convert.
|
node: The node to convert.
|
||||||
|
html: Whether to format the data as HTML rich text.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
A string representation of the data.
|
A string representation of the data.
|
||||||
"""
|
"""
|
||||||
from langchain_core.runnables.base import Runnable
|
from langchain_core.runnables.base import Runnable
|
||||||
|
|
||||||
if not is_uuid(node.id):
|
if not is_uuid(id):
|
||||||
return node.id
|
return id
|
||||||
elif isinstance(node.data, Runnable):
|
elif isinstance(data, Runnable):
|
||||||
try:
|
data_str = data.get_name()
|
||||||
data = str(node.data)
|
|
||||||
if (
|
|
||||||
data.startswith("<")
|
|
||||||
or data[0] != data[0].upper()
|
|
||||||
or len(data.splitlines()) > 1
|
|
||||||
):
|
|
||||||
data = node.data.__class__.__name__
|
|
||||||
elif len(data) > 42:
|
|
||||||
data = data[:42] + "..."
|
|
||||||
except Exception:
|
|
||||||
data = node.data.__class__.__name__
|
|
||||||
else:
|
else:
|
||||||
data = node.data.__name__
|
data_str = data.__name__
|
||||||
return data if not data.startswith("Runnable") else data[8:]
|
return data_str if not data_str.startswith("Runnable") else data_str[8:]
|
||||||
|
|
||||||
|
|
||||||
def node_data_json(
|
def node_data_json(
|
||||||
@ -163,7 +174,7 @@ def node_data_json(
|
|||||||
"type": "runnable",
|
"type": "runnable",
|
||||||
"data": {
|
"data": {
|
||||||
"id": node.data.lc_id(),
|
"id": node.data.lc_id(),
|
||||||
"name": node.data.get_name(),
|
"name": node_data_str(node.id, node.data),
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
elif isinstance(node.data, Runnable):
|
elif isinstance(node.data, Runnable):
|
||||||
@ -171,7 +182,7 @@ def node_data_json(
|
|||||||
"type": "runnable",
|
"type": "runnable",
|
||||||
"data": {
|
"data": {
|
||||||
"id": to_json_not_implemented(node.data)["id"],
|
"id": to_json_not_implemented(node.data)["id"],
|
||||||
"name": node.data.get_name(),
|
"name": node_data_str(node.id, node.data),
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
elif inspect.isclass(node.data) and issubclass(node.data, BaseModel):
|
elif inspect.isclass(node.data) and issubclass(node.data, BaseModel):
|
||||||
@ -183,13 +194,13 @@ def node_data_json(
|
|||||||
if with_schemas
|
if with_schemas
|
||||||
else {
|
else {
|
||||||
"type": "schema",
|
"type": "schema",
|
||||||
"data": node_data_str(node),
|
"data": node_data_str(node.id, node.data),
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
return {
|
return {
|
||||||
"type": "unknown",
|
"type": "unknown",
|
||||||
"data": node_data_str(node),
|
"data": node_data_str(node.id, node.data),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -236,12 +247,17 @@ class Graph:
|
|||||||
return uuid4().hex
|
return uuid4().hex
|
||||||
|
|
||||||
def add_node(
|
def add_node(
|
||||||
self, data: Union[Type[BaseModel], RunnableType], id: Optional[str] = None
|
self,
|
||||||
|
data: Union[Type[BaseModel], RunnableType],
|
||||||
|
id: Optional[str] = None,
|
||||||
|
*,
|
||||||
|
metadata: Optional[Dict[str, Any]] = None,
|
||||||
) -> Node:
|
) -> Node:
|
||||||
"""Add a node to the graph and return it."""
|
"""Add a node to the graph and return it."""
|
||||||
if id is not None and id in self.nodes:
|
if id is not None and id in self.nodes:
|
||||||
raise ValueError(f"Node with id {id} already exists")
|
raise ValueError(f"Node with id {id} already exists")
|
||||||
node = Node(id=id or self.next_id(), data=data)
|
id = id or self.next_id()
|
||||||
|
node = Node(id=id, data=data, metadata=metadata, name=node_data_str(id, data))
|
||||||
self.nodes[node.id] = node
|
self.nodes[node.id] = node
|
||||||
return node
|
return node
|
||||||
|
|
||||||
@ -285,25 +301,47 @@ class Graph:
|
|||||||
|
|
||||||
# prefix each node
|
# prefix each node
|
||||||
self.nodes.update(
|
self.nodes.update(
|
||||||
{prefixed(k): Node(prefixed(k), v.data) for k, v in graph.nodes.items()}
|
{prefixed(k): v.copy(id=prefixed(k)) for k, v in graph.nodes.items()}
|
||||||
)
|
)
|
||||||
# prefix each edge's source and target
|
# prefix each edge's source and target
|
||||||
self.edges.extend(
|
self.edges.extend(
|
||||||
[
|
[
|
||||||
Edge(
|
edge.copy(source=prefixed(edge.source), target=prefixed(edge.target))
|
||||||
prefixed(edge.source),
|
|
||||||
prefixed(edge.target),
|
|
||||||
edge.data,
|
|
||||||
edge.conditional,
|
|
||||||
)
|
|
||||||
for edge in graph.edges
|
for edge in graph.edges
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
# return (prefixed) first and last nodes of the subgraph
|
# return (prefixed) first and last nodes of the subgraph
|
||||||
first, last = graph.first_node(), graph.last_node()
|
first, last = graph.first_node(), graph.last_node()
|
||||||
return (
|
return (
|
||||||
Node(prefixed(first.id), first.data) if first else None,
|
first.copy(id=prefixed(first.id)) if first else None,
|
||||||
Node(prefixed(last.id), last.data) if last else None,
|
last.copy(id=prefixed(last.id)) if last else None,
|
||||||
|
)
|
||||||
|
|
||||||
|
def reid(self) -> Graph:
|
||||||
|
"""Return a new graph with all nodes re-identified,
|
||||||
|
using their unique, readable names where possible."""
|
||||||
|
node_labels = {node.id: node.name for node in self.nodes.values()}
|
||||||
|
node_label_counts = Counter(node_labels.values())
|
||||||
|
|
||||||
|
def _get_node_id(node_id: str) -> str:
|
||||||
|
label = node_labels[node_id]
|
||||||
|
if is_uuid(node_id) and node_label_counts[label] == 1:
|
||||||
|
return label
|
||||||
|
else:
|
||||||
|
return node_id
|
||||||
|
|
||||||
|
return Graph(
|
||||||
|
nodes={
|
||||||
|
_get_node_id(id): node.copy(id=_get_node_id(id))
|
||||||
|
for id, node in self.nodes.items()
|
||||||
|
},
|
||||||
|
edges=[
|
||||||
|
edge.copy(
|
||||||
|
source=_get_node_id(edge.source),
|
||||||
|
target=_get_node_id(edge.target),
|
||||||
|
)
|
||||||
|
for edge in self.edges
|
||||||
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
def first_node(self) -> Optional[Node]:
|
def first_node(self) -> Optional[Node]:
|
||||||
@ -357,7 +395,7 @@ class Graph:
|
|||||||
from langchain_core.runnables.graph_ascii import draw_ascii
|
from langchain_core.runnables.graph_ascii import draw_ascii
|
||||||
|
|
||||||
return draw_ascii(
|
return draw_ascii(
|
||||||
{node.id: node_data_str(node) for node in self.nodes.values()},
|
{node.id: node.name for node in self.nodes.values()},
|
||||||
self.edges,
|
self.edges,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -388,9 +426,7 @@ class Graph:
|
|||||||
) -> Union[bytes, None]:
|
) -> Union[bytes, None]:
|
||||||
from langchain_core.runnables.graph_png import PngDrawer
|
from langchain_core.runnables.graph_png import PngDrawer
|
||||||
|
|
||||||
default_node_labels = {
|
default_node_labels = {node.id: node.name for node in self.nodes.values()}
|
||||||
node.id: node_data_str(node) for node in self.nodes.values()
|
|
||||||
}
|
|
||||||
|
|
||||||
return PngDrawer(
|
return PngDrawer(
|
||||||
fontname,
|
fontname,
|
||||||
@ -415,19 +451,15 @@ class Graph:
|
|||||||
) -> str:
|
) -> str:
|
||||||
from langchain_core.runnables.graph_mermaid import draw_mermaid
|
from langchain_core.runnables.graph_mermaid import draw_mermaid
|
||||||
|
|
||||||
nodes = {node.id: node_data_str(node) for node in self.nodes.values()}
|
graph = self.reid()
|
||||||
|
first_node = graph.first_node()
|
||||||
first_node = self.first_node()
|
last_node = graph.last_node()
|
||||||
first_label = node_data_str(first_node) if first_node is not None else None
|
|
||||||
|
|
||||||
last_node = self.last_node()
|
|
||||||
last_label = node_data_str(last_node) if last_node is not None else None
|
|
||||||
|
|
||||||
return draw_mermaid(
|
return draw_mermaid(
|
||||||
nodes=nodes,
|
nodes=graph.nodes,
|
||||||
edges=self.edges,
|
edges=graph.edges,
|
||||||
first_node_label=first_label,
|
first_node=first_node.id if first_node else None,
|
||||||
last_node_label=last_label,
|
last_node=last_node.id if last_node else None,
|
||||||
with_styles=with_styles,
|
with_styles=with_styles,
|
||||||
curve_style=curve_style,
|
curve_style=curve_style,
|
||||||
node_colors=node_colors,
|
node_colors=node_colors,
|
||||||
|
@ -1,22 +1,23 @@
|
|||||||
import base64
|
import base64
|
||||||
import re
|
import re
|
||||||
from dataclasses import asdict
|
from dataclasses import asdict
|
||||||
from typing import Dict, List, Optional, Tuple
|
from typing import Dict, List, Optional
|
||||||
|
|
||||||
from langchain_core.runnables.graph import (
|
from langchain_core.runnables.graph import (
|
||||||
CurveStyle,
|
CurveStyle,
|
||||||
Edge,
|
Edge,
|
||||||
MermaidDrawMethod,
|
MermaidDrawMethod,
|
||||||
|
Node,
|
||||||
NodeColors,
|
NodeColors,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def draw_mermaid(
|
def draw_mermaid(
|
||||||
nodes: Dict[str, str],
|
nodes: Dict[str, Node],
|
||||||
edges: List[Edge],
|
edges: List[Edge],
|
||||||
*,
|
*,
|
||||||
first_node_label: Optional[str] = None,
|
first_node: Optional[str] = None,
|
||||||
last_node_label: Optional[str] = None,
|
last_node: Optional[str] = None,
|
||||||
with_styles: bool = True,
|
with_styles: bool = True,
|
||||||
curve_style: CurveStyle = CurveStyle.LINEAR,
|
curve_style: CurveStyle = CurveStyle.LINEAR,
|
||||||
node_colors: NodeColors = NodeColors(),
|
node_colors: NodeColors = NodeColors(),
|
||||||
@ -49,15 +50,20 @@ def draw_mermaid(
|
|||||||
# Node formatting templates
|
# Node formatting templates
|
||||||
default_class_label = "default"
|
default_class_label = "default"
|
||||||
format_dict = {default_class_label: "{0}([{1}]):::otherclass"}
|
format_dict = {default_class_label: "{0}([{1}]):::otherclass"}
|
||||||
if first_node_label is not None:
|
if first_node is not None:
|
||||||
format_dict[first_node_label] = "{0}[{0}]:::startclass"
|
format_dict[first_node] = "{0}[{0}]:::startclass"
|
||||||
if last_node_label is not None:
|
if last_node is not None:
|
||||||
format_dict[last_node_label] = "{0}[{0}]:::endclass"
|
format_dict[last_node] = "{0}[{0}]:::endclass"
|
||||||
|
|
||||||
# Add nodes to the graph
|
# Add nodes to the graph
|
||||||
for node in nodes.values():
|
for key, node in nodes.items():
|
||||||
node_label = format_dict.get(node, format_dict[default_class_label]).format(
|
label = node.name.split(":")[-1]
|
||||||
_escape_node_label(node), node.split(":", 1)[-1]
|
if node.metadata:
|
||||||
|
label = f"<strong>{label}</strong>\n" + "\n".join(
|
||||||
|
f"{key} = {value}" for key, value in node.metadata.items()
|
||||||
|
)
|
||||||
|
node_label = format_dict.get(key, format_dict[default_class_label]).format(
|
||||||
|
_escape_node_label(key), label
|
||||||
)
|
)
|
||||||
mermaid_graph += f"\t{node_label};\n"
|
mermaid_graph += f"\t{node_label};\n"
|
||||||
|
|
||||||
@ -74,9 +80,8 @@ def draw_mermaid(
|
|||||||
if not subgraph and src_prefix and src_prefix == tgt_prefix:
|
if not subgraph and src_prefix and src_prefix == tgt_prefix:
|
||||||
mermaid_graph += f"\tsubgraph {src_prefix}\n"
|
mermaid_graph += f"\tsubgraph {src_prefix}\n"
|
||||||
subgraph = src_prefix
|
subgraph = src_prefix
|
||||||
adjusted_edge = _adjust_mermaid_edge(edge=edge, nodes=nodes)
|
|
||||||
|
|
||||||
source, target = adjusted_edge
|
source, target = edge.source, edge.target
|
||||||
|
|
||||||
# Add BR every wrap_label_n_words words
|
# Add BR every wrap_label_n_words words
|
||||||
if edge.data is not None:
|
if edge.data is not None:
|
||||||
@ -117,17 +122,6 @@ def _escape_node_label(node_label: str) -> str:
|
|||||||
return re.sub(r"[^a-zA-Z-_0-9]", "_", node_label)
|
return re.sub(r"[^a-zA-Z-_0-9]", "_", node_label)
|
||||||
|
|
||||||
|
|
||||||
def _adjust_mermaid_edge(
|
|
||||||
edge: Edge,
|
|
||||||
nodes: Dict[str, str],
|
|
||||||
) -> Tuple[str, str]:
|
|
||||||
"""Adjusts Mermaid edge to map conditional nodes to pure nodes."""
|
|
||||||
source_node_label = nodes.get(edge.source, edge.source)
|
|
||||||
target_node_label = nodes.get(edge.target, edge.target)
|
|
||||||
|
|
||||||
return source_node_label, target_node_label
|
|
||||||
|
|
||||||
|
|
||||||
def _generate_mermaid_graph_styles(node_colors: NodeColors) -> str:
|
def _generate_mermaid_graph_styles(node_colors: NodeColors) -> str:
|
||||||
"""Generates Mermaid graph styles for different node types."""
|
"""Generates Mermaid graph styles for different node types."""
|
||||||
styles = ""
|
styles = ""
|
||||||
|
@ -58,7 +58,7 @@
|
|||||||
"base",
|
"base",
|
||||||
"RunnableLambda"
|
"RunnableLambda"
|
||||||
],
|
],
|
||||||
"name": "RunnableLambda"
|
"name": "Lambda"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
@ -458,7 +458,7 @@
|
|||||||
"runnable",
|
"runnable",
|
||||||
"RunnableWithFallbacks"
|
"RunnableWithFallbacks"
|
||||||
],
|
],
|
||||||
"name": "RunnableWithFallbacks"
|
"name": "WithFallbacks"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@ -498,7 +498,7 @@
|
|||||||
"base",
|
"base",
|
||||||
"RunnableLambda"
|
"RunnableLambda"
|
||||||
],
|
],
|
||||||
"name": "RunnableLambda"
|
"name": "Lambda"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@ -511,7 +511,7 @@
|
|||||||
"runnable",
|
"runnable",
|
||||||
"RunnableWithFallbacks"
|
"RunnableWithFallbacks"
|
||||||
],
|
],
|
||||||
"name": "RunnableWithFallbacks"
|
"name": "WithFallbacks"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@ -589,7 +589,7 @@
|
|||||||
"runnable",
|
"runnable",
|
||||||
"RunnablePassthrough"
|
"RunnablePassthrough"
|
||||||
],
|
],
|
||||||
"name": "RunnablePassthrough"
|
"name": "Passthrough"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@ -635,7 +635,7 @@
|
|||||||
"runnable",
|
"runnable",
|
||||||
"RunnablePassthrough"
|
"RunnablePassthrough"
|
||||||
],
|
],
|
||||||
"name": "RunnablePassthrough"
|
"name": "Passthrough"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
@ -716,7 +716,7 @@
|
|||||||
"runnable",
|
"runnable",
|
||||||
"RunnableWithFallbacks"
|
"RunnableWithFallbacks"
|
||||||
],
|
],
|
||||||
"name": "RunnableWithFallbacks"
|
"name": "WithFallbacks"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@ -756,7 +756,7 @@
|
|||||||
"runnable",
|
"runnable",
|
||||||
"RunnablePassthrough"
|
"RunnablePassthrough"
|
||||||
],
|
],
|
||||||
"name": "RunnablePassthrough"
|
"name": "Passthrough"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@ -769,7 +769,7 @@
|
|||||||
"runnable",
|
"runnable",
|
||||||
"RunnableWithFallbacks"
|
"RunnableWithFallbacks"
|
||||||
],
|
],
|
||||||
"name": "RunnableWithFallbacks"
|
"name": "WithFallbacks"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@ -938,7 +938,7 @@
|
|||||||
"runnable",
|
"runnable",
|
||||||
"RunnableWithFallbacks"
|
"RunnableWithFallbacks"
|
||||||
],
|
],
|
||||||
"name": "RunnableWithFallbacks"
|
"name": "WithFallbacks"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@ -1152,7 +1152,7 @@
|
|||||||
"runnable",
|
"runnable",
|
||||||
"RunnableWithFallbacks"
|
"RunnableWithFallbacks"
|
||||||
],
|
],
|
||||||
"name": "RunnableWithFallbacks"
|
"name": "WithFallbacks"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
@ -36,7 +36,8 @@
|
|||||||
graph TD;
|
graph TD;
|
||||||
PromptInput[PromptInput]:::startclass;
|
PromptInput[PromptInput]:::startclass;
|
||||||
PromptTemplate([PromptTemplate]):::otherclass;
|
PromptTemplate([PromptTemplate]):::otherclass;
|
||||||
FakeListLLM([FakeListLLM]):::otherclass;
|
FakeListLLM([<strong>FakeListLLM</strong>
|
||||||
|
key = 2]):::otherclass;
|
||||||
CommaSeparatedListOutputParser([CommaSeparatedListOutputParser]):::otherclass;
|
CommaSeparatedListOutputParser([CommaSeparatedListOutputParser]):::otherclass;
|
||||||
CommaSeparatedListOutputParserOutput[CommaSeparatedListOutputParserOutput]:::endclass;
|
CommaSeparatedListOutputParserOutput[CommaSeparatedListOutputParserOutput]:::endclass;
|
||||||
PromptInput --> PromptTemplate;
|
PromptInput --> PromptTemplate;
|
||||||
|
File diff suppressed because one or more lines are too long
@ -33,7 +33,7 @@ def test_graph_sequence(snapshot: SnapshotAssertion) -> None:
|
|||||||
prompt = PromptTemplate.from_template("Hello, {name}!")
|
prompt = PromptTemplate.from_template("Hello, {name}!")
|
||||||
list_parser = CommaSeparatedListOutputParser()
|
list_parser = CommaSeparatedListOutputParser()
|
||||||
|
|
||||||
sequence = prompt | fake_llm | list_parser
|
sequence = prompt | fake_llm.with_config(metadata={"key": 2}) | list_parser
|
||||||
graph = sequence.get_graph()
|
graph = sequence.get_graph()
|
||||||
assert graph.to_json() == {
|
assert graph.to_json() == {
|
||||||
"nodes": [
|
"nodes": [
|
||||||
|
@ -6,6 +6,7 @@ authors = []
|
|||||||
license = "MIT"
|
license = "MIT"
|
||||||
readme = "README.md"
|
readme = "README.md"
|
||||||
repository = "https://www.github.com/langchain-ai/langchain"
|
repository = "https://www.github.com/langchain-ai/langchain"
|
||||||
|
package-mode = false
|
||||||
|
|
||||||
|
|
||||||
[tool.poetry.dependencies]
|
[tool.poetry.dependencies]
|
||||||
|
Loading…
Reference in New Issue
Block a user