core: use friendlier names for duplicated nodes in mermaid output (#27747)

Thank you for contributing to LangChain!

- [x] **PR title**: "core: use friendlier names for duplicated nodes in
mermaid output"

- **Description:** When generating the Mermaid visualization of a chain,
if the chain had multiple nodes of the same type, the reid function
would replace their names with the UUID node_id. This made the generated
graph difficult to understand. This change deduplicates the nodes in a
chain by appending an index to their names.
- **Issue:** None
- **Discussion:**
https://github.com/langchain-ai/langchain/discussions/27714
- **Dependencies:** None

- [ ] **Add tests and docs**:  
- Currently this functionality is not covered by unit tests, happy to
add tests if you'd like


- [x] **Lint and test**: Run `make format`, `make lint` and `make test`
from the root of the package(s) you've modified. See contribution
guidelines for more: https://python.langchain.com/docs/contributing/

Additional guidelines:
- Make sure optional dependencies are imported within a function.
- Please do not add dependencies to pyproject.toml files (even optional
ones) unless they are required for unit tests.
- Most PRs should not touch more than one package.
- Changes should be backwards compatible.
- If you are adding something to community, do not re-import it in
langchain.

If no one reviews your PR within a few days, please @-mention one of
baskaryan, efriis, eyurtsev, ccurme, vbarda, hwchase17.

# Example Code:
```python
from langchain_core.runnables import RunnablePassthrough

def fake_llm(prompt: str) -> str: # Fake LLM for the example
    return "completion"

runnable = {
    'llm1':  fake_llm,
    'llm2':  fake_llm,
} | RunnablePassthrough.assign(
    total_chars=lambda inputs: len(inputs['llm1'] + inputs['llm2'])
)

print(runnable.get_graph().draw_mermaid(with_styles=False))
```

# Before
```mermaid
graph TD;
	Parallel_llm1_llm2_Input --> 0b01139db5ed4587ad37964e3a40c0ec;
	0b01139db5ed4587ad37964e3a40c0ec --> Parallel_llm1_llm2_Output;
	Parallel_llm1_llm2_Input --> a98d4b56bd294156a651230b9293347f;
	a98d4b56bd294156a651230b9293347f --> Parallel_llm1_llm2_Output;
	Parallel_total_chars_Input --> Lambda;
	Lambda --> Parallel_total_chars_Output;
	Parallel_total_chars_Input --> Passthrough;
	Passthrough --> Parallel_total_chars_Output;
	Parallel_llm1_llm2_Output --> Parallel_total_chars_Input;
```

# After
```mermaid
graph TD;
	Parallel_llm1_llm2_Input --> fake_llm_1;
	fake_llm_1 --> Parallel_llm1_llm2_Output;
	Parallel_llm1_llm2_Input --> fake_llm_2;
	fake_llm_2 --> Parallel_llm1_llm2_Output;
	Parallel_total_chars_Input --> Lambda;
	Lambda --> Parallel_total_chars_Output;
	Parallel_total_chars_Input --> Passthrough;
	Passthrough --> Parallel_total_chars_Output;
	Parallel_llm1_llm2_Output --> Parallel_total_chars_Input;
```
This commit is contained in:
Ant White 2024-10-31 20:52:00 +00:00 committed by GitHub
parent 71f590de50
commit e3ea365725
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 40 additions and 5 deletions

View File

@ -1,7 +1,7 @@
from __future__ import annotations from __future__ import annotations
import inspect import inspect
from collections import Counter from collections import defaultdict
from collections.abc import Sequence from collections.abc import Sequence
from dataclasses import dataclass, field from dataclasses import dataclass, field
from enum import Enum from enum import Enum
@ -423,12 +423,19 @@ class Graph:
def reid(self) -> Graph: def reid(self) -> Graph:
"""Return a new graph with all nodes re-identified, """Return a new graph with all nodes re-identified,
using their unique, readable names where possible.""" using their unique, readable names where possible."""
node_labels = {node.id: node.name for node in self.nodes.values()} node_name_to_ids = defaultdict(list)
node_label_counts = Counter(node_labels.values()) for node in self.nodes.values():
node_name_to_ids[node.name].append(node.id)
unique_labels = {
node_id: node_name if len(node_ids) == 1 else f"{node_name}_{i + 1}"
for node_name, node_ids in node_name_to_ids.items()
for i, node_id in enumerate(node_ids)
}
def _get_node_id(node_id: str) -> str: def _get_node_id(node_id: str) -> str:
label = node_labels[node_id] label = unique_labels[node_id]
if is_uuid(node_id) and node_label_counts[label] == 1: if is_uuid(node_id):
return label return label
else: else:
return node_id return node_id

View File

@ -26,6 +26,20 @@
''' '''
# --- # ---
# name: test_graph_mermaid_duplicate_nodes[mermaid]
'''
graph TD;
PromptInput --> PromptTemplate_1;
Parallel_llm1_llm2_Input --> FakeListLLM_1;
FakeListLLM_1 --> Parallel_llm1_llm2_Output;
Parallel_llm1_llm2_Input --> FakeListLLM_2;
FakeListLLM_2 --> Parallel_llm1_llm2_Output;
PromptTemplate_1 --> Parallel_llm1_llm2_Input;
PromptTemplate_2 --> PromptTemplateOutput;
Parallel_llm1_llm2_Output --> PromptTemplate_2;
'''
# ---
# name: test_graph_sequence[ascii] # name: test_graph_sequence[ascii]
''' '''
+-------------+ +-------------+

View File

@ -405,3 +405,17 @@ def test_graph_mermaid_escape_node_label() -> None:
assert _escape_node_label("foo-bar") == "foo-bar" assert _escape_node_label("foo-bar") == "foo-bar"
assert _escape_node_label("foo_1") == "foo_1" assert _escape_node_label("foo_1") == "foo_1"
assert _escape_node_label("#foo*&!") == "_foo___" assert _escape_node_label("#foo*&!") == "_foo___"
def test_graph_mermaid_duplicate_nodes(snapshot: SnapshotAssertion) -> None:
fake_llm = FakeListLLM(responses=["foo", "bar"])
sequence: Runnable = (
PromptTemplate.from_template("Hello, {input}")
| {
"llm1": fake_llm,
"llm2": fake_llm,
}
| PromptTemplate.from_template("{llm1} {llm2}")
)
graph = sequence.get_graph()
assert graph.draw_mermaid(with_styles=False) == snapshot(name="mermaid")