mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-04 04:07:54 +00:00
core: use friendlier names for duplicated nodes in mermaid output (#27747)
Thank you for contributing to LangChain! - [x] **PR title**: "core: use friendlier names for duplicated nodes in mermaid output" - **Description:** When generating the Mermaid visualization of a chain, if the chain had multiple nodes of the same type, the reid function would replace their names with the UUID node_id. This made the generated graph difficult to understand. This change deduplicates the nodes in a chain by appending an index to their names. - **Issue:** None - **Discussion:** https://github.com/langchain-ai/langchain/discussions/27714 - **Dependencies:** None - [ ] **Add tests and docs**: - Currently this functionality is not covered by unit tests, happy to add tests if you'd like - [x] **Lint and test**: Run `make format`, `make lint` and `make test` from the root of the package(s) you've modified. See contribution guidelines for more: https://python.langchain.com/docs/contributing/ Additional guidelines: - Make sure optional dependencies are imported within a function. - Please do not add dependencies to pyproject.toml files (even optional ones) unless they are required for unit tests. - Most PRs should not touch more than one package. - Changes should be backwards compatible. - If you are adding something to community, do not re-import it in langchain. If no one reviews your PR within a few days, please @-mention one of baskaryan, efriis, eyurtsev, ccurme, vbarda, hwchase17. # Example Code: ```python from langchain_core.runnables import RunnablePassthrough def fake_llm(prompt: str) -> str: # Fake LLM for the example return "completion" runnable = { 'llm1': fake_llm, 'llm2': fake_llm, } | RunnablePassthrough.assign( total_chars=lambda inputs: len(inputs['llm1'] + inputs['llm2']) ) print(runnable.get_graph().draw_mermaid(with_styles=False)) ``` # Before ```mermaid graph TD; Parallel_llm1_llm2_Input --> 0b01139db5ed4587ad37964e3a40c0ec; 0b01139db5ed4587ad37964e3a40c0ec --> Parallel_llm1_llm2_Output; Parallel_llm1_llm2_Input --> a98d4b56bd294156a651230b9293347f; a98d4b56bd294156a651230b9293347f --> Parallel_llm1_llm2_Output; Parallel_total_chars_Input --> Lambda; Lambda --> Parallel_total_chars_Output; Parallel_total_chars_Input --> Passthrough; Passthrough --> Parallel_total_chars_Output; Parallel_llm1_llm2_Output --> Parallel_total_chars_Input; ``` # After ```mermaid graph TD; Parallel_llm1_llm2_Input --> fake_llm_1; fake_llm_1 --> Parallel_llm1_llm2_Output; Parallel_llm1_llm2_Input --> fake_llm_2; fake_llm_2 --> Parallel_llm1_llm2_Output; Parallel_total_chars_Input --> Lambda; Lambda --> Parallel_total_chars_Output; Parallel_total_chars_Input --> Passthrough; Passthrough --> Parallel_total_chars_Output; Parallel_llm1_llm2_Output --> Parallel_total_chars_Input; ```
This commit is contained in:
parent
71f590de50
commit
e3ea365725
@ -1,7 +1,7 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import inspect
|
import inspect
|
||||||
from collections import Counter
|
from collections import defaultdict
|
||||||
from collections.abc import Sequence
|
from collections.abc import Sequence
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
@ -423,12 +423,19 @@ class Graph:
|
|||||||
def reid(self) -> Graph:
|
def reid(self) -> Graph:
|
||||||
"""Return a new graph with all nodes re-identified,
|
"""Return a new graph with all nodes re-identified,
|
||||||
using their unique, readable names where possible."""
|
using their unique, readable names where possible."""
|
||||||
node_labels = {node.id: node.name for node in self.nodes.values()}
|
node_name_to_ids = defaultdict(list)
|
||||||
node_label_counts = Counter(node_labels.values())
|
for node in self.nodes.values():
|
||||||
|
node_name_to_ids[node.name].append(node.id)
|
||||||
|
|
||||||
|
unique_labels = {
|
||||||
|
node_id: node_name if len(node_ids) == 1 else f"{node_name}_{i + 1}"
|
||||||
|
for node_name, node_ids in node_name_to_ids.items()
|
||||||
|
for i, node_id in enumerate(node_ids)
|
||||||
|
}
|
||||||
|
|
||||||
def _get_node_id(node_id: str) -> str:
|
def _get_node_id(node_id: str) -> str:
|
||||||
label = node_labels[node_id]
|
label = unique_labels[node_id]
|
||||||
if is_uuid(node_id) and node_label_counts[label] == 1:
|
if is_uuid(node_id):
|
||||||
return label
|
return label
|
||||||
else:
|
else:
|
||||||
return node_id
|
return node_id
|
||||||
|
@ -26,6 +26,20 @@
|
|||||||
|
|
||||||
'''
|
'''
|
||||||
# ---
|
# ---
|
||||||
|
# name: test_graph_mermaid_duplicate_nodes[mermaid]
|
||||||
|
'''
|
||||||
|
graph TD;
|
||||||
|
PromptInput --> PromptTemplate_1;
|
||||||
|
Parallel_llm1_llm2_Input --> FakeListLLM_1;
|
||||||
|
FakeListLLM_1 --> Parallel_llm1_llm2_Output;
|
||||||
|
Parallel_llm1_llm2_Input --> FakeListLLM_2;
|
||||||
|
FakeListLLM_2 --> Parallel_llm1_llm2_Output;
|
||||||
|
PromptTemplate_1 --> Parallel_llm1_llm2_Input;
|
||||||
|
PromptTemplate_2 --> PromptTemplateOutput;
|
||||||
|
Parallel_llm1_llm2_Output --> PromptTemplate_2;
|
||||||
|
|
||||||
|
'''
|
||||||
|
# ---
|
||||||
# name: test_graph_sequence[ascii]
|
# name: test_graph_sequence[ascii]
|
||||||
'''
|
'''
|
||||||
+-------------+
|
+-------------+
|
||||||
|
@ -405,3 +405,17 @@ def test_graph_mermaid_escape_node_label() -> None:
|
|||||||
assert _escape_node_label("foo-bar") == "foo-bar"
|
assert _escape_node_label("foo-bar") == "foo-bar"
|
||||||
assert _escape_node_label("foo_1") == "foo_1"
|
assert _escape_node_label("foo_1") == "foo_1"
|
||||||
assert _escape_node_label("#foo*&!") == "_foo___"
|
assert _escape_node_label("#foo*&!") == "_foo___"
|
||||||
|
|
||||||
|
|
||||||
|
def test_graph_mermaid_duplicate_nodes(snapshot: SnapshotAssertion) -> None:
|
||||||
|
fake_llm = FakeListLLM(responses=["foo", "bar"])
|
||||||
|
sequence: Runnable = (
|
||||||
|
PromptTemplate.from_template("Hello, {input}")
|
||||||
|
| {
|
||||||
|
"llm1": fake_llm,
|
||||||
|
"llm2": fake_llm,
|
||||||
|
}
|
||||||
|
| PromptTemplate.from_template("{llm1} {llm2}")
|
||||||
|
)
|
||||||
|
graph = sequence.get_graph()
|
||||||
|
assert graph.draw_mermaid(with_styles=False) == snapshot(name="mermaid")
|
||||||
|
Loading…
Reference in New Issue
Block a user