Add Runnable.get_graph() to get a graph representation of a Runnable (#15040)

It can be drawn in ascii with Runnable.get_graph().draw()
This commit is contained in:
Nuno Campos
2023-12-22 11:40:45 -08:00
committed by GitHub
parent aad3d8bd47
commit 7d5800ee51
12 changed files with 739 additions and 27 deletions

View File

@@ -0,0 +1,88 @@
# serializer version: 1
# name: test_graph_sequence
'''
+-------------+
| PromptInput |
+-------------+
*
*
*
+----------------+
| PromptTemplate |
+----------------+
*
*
*
+-------------+
| FakeListLLM |
+-------------+
*
*
*
+--------------------------------+
| CommaSeparatedListOutputParser |
+--------------------------------+
*
*
*
+--------------------------------------+
| CommaSeparatedListOutputParserOutput |
+--------------------------------------+
'''
# ---
# name: test_graph_sequence_map
'''
+-------------+
| PromptInput |
+-------------+
*
*
*
+----------------+
| PromptTemplate |
+----------------+
*
*
*
+-------------+
| FakeListLLM |
+-------------+
*
*
*
+-----------------------+
| RunnableParallelInput |
+-----------------------+
**** ***
**** ****
** **
+---------------------+ +--------------------------------+
| RunnablePassthrough | | CommaSeparatedListOutputParser |
+---------------------+ +--------------------------------+
**** ***
**** ****
** **
+------------------------+
| RunnableParallelOutput |
+------------------------+
'''
# ---
# name: test_graph_single_runnable
'''
+----------------------+
| StrOutputParserInput |
+----------------------+
*
*
*
+-----------------+
| StrOutputParser |
+-----------------+
*
*
*
+-----------------------+
| StrOutputParserOutput |
+-----------------------+
'''
# ---

File diff suppressed because one or more lines are too long

View File

@@ -0,0 +1,51 @@
from syrupy import SnapshotAssertion
from langchain_core.output_parsers.list import CommaSeparatedListOutputParser
from langchain_core.output_parsers.string import StrOutputParser
from langchain_core.prompts.prompt import PromptTemplate
from langchain_core.runnables.base import Runnable
from langchain_core.runnables.passthrough import RunnablePassthrough
from tests.unit_tests.fake.llm import FakeListLLM
def test_graph_single_runnable(snapshot: SnapshotAssertion) -> None:
runnable = StrOutputParser()
graph = StrOutputParser().get_graph()
first_node = graph.first_node()
assert first_node is not None
assert first_node.data.schema() == runnable.input_schema.schema() # type: ignore[union-attr]
last_node = graph.last_node()
assert last_node is not None
assert last_node.data.schema() == runnable.output_schema.schema() # type: ignore[union-attr]
assert len(graph.nodes) == 3
assert len(graph.edges) == 2
assert graph.edges[0].source == first_node.id
assert graph.edges[1].target == last_node.id
assert graph.draw_ascii() == snapshot
def test_graph_sequence(snapshot: SnapshotAssertion) -> None:
fake_llm = FakeListLLM(responses=["a"])
prompt = PromptTemplate.from_template("Hello, {name}!")
list_parser = CommaSeparatedListOutputParser()
sequence = prompt | fake_llm | list_parser
graph = sequence.get_graph()
assert graph.draw_ascii() == snapshot
def test_graph_sequence_map(snapshot: SnapshotAssertion) -> None:
fake_llm = FakeListLLM(responses=["a"])
prompt = PromptTemplate.from_template("Hello, {name}!")
list_parser = CommaSeparatedListOutputParser()
sequence: Runnable = (
prompt
| fake_llm
| {
"original": RunnablePassthrough(input_type=str),
"as_list": list_parser,
}
)
graph = sequence.get_graph()
assert graph.draw_ascii() == snapshot

View File

@@ -4127,13 +4127,13 @@ def test_representation_of_runnables() -> None:
"""Return 2."""
return 2
assert repr(RunnableLambda(func=f)) == "RunnableLambda(...)"
assert repr(RunnableLambda(func=f)) == "RunnableLambda(f)"
async def af(x: int) -> int:
"""Return 2."""
return 2
assert repr(RunnableLambda(func=f, afunc=af)) == "RunnableLambda(...)"
assert repr(RunnableLambda(func=f, afunc=af)) == "RunnableLambda(f)"
assert repr(
RunnableLambda(lambda x: x + 2)