mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-02 03:26:17 +00:00
Add Runnable.get_graph() to get a graph representation of a Runnable (#15040)
It can be drawn in ascii with Runnable.get_graph().draw()
This commit is contained in:
@@ -0,0 +1,88 @@
|
||||
# serializer version: 1
|
||||
# name: test_graph_sequence
|
||||
'''
|
||||
+-------------+
|
||||
| PromptInput |
|
||||
+-------------+
|
||||
*
|
||||
*
|
||||
*
|
||||
+----------------+
|
||||
| PromptTemplate |
|
||||
+----------------+
|
||||
*
|
||||
*
|
||||
*
|
||||
+-------------+
|
||||
| FakeListLLM |
|
||||
+-------------+
|
||||
*
|
||||
*
|
||||
*
|
||||
+--------------------------------+
|
||||
| CommaSeparatedListOutputParser |
|
||||
+--------------------------------+
|
||||
*
|
||||
*
|
||||
*
|
||||
+--------------------------------------+
|
||||
| CommaSeparatedListOutputParserOutput |
|
||||
+--------------------------------------+
|
||||
'''
|
||||
# ---
|
||||
# name: test_graph_sequence_map
|
||||
'''
|
||||
+-------------+
|
||||
| PromptInput |
|
||||
+-------------+
|
||||
*
|
||||
*
|
||||
*
|
||||
+----------------+
|
||||
| PromptTemplate |
|
||||
+----------------+
|
||||
*
|
||||
*
|
||||
*
|
||||
+-------------+
|
||||
| FakeListLLM |
|
||||
+-------------+
|
||||
*
|
||||
*
|
||||
*
|
||||
+-----------------------+
|
||||
| RunnableParallelInput |
|
||||
+-----------------------+
|
||||
**** ***
|
||||
**** ****
|
||||
** **
|
||||
+---------------------+ +--------------------------------+
|
||||
| RunnablePassthrough | | CommaSeparatedListOutputParser |
|
||||
+---------------------+ +--------------------------------+
|
||||
**** ***
|
||||
**** ****
|
||||
** **
|
||||
+------------------------+
|
||||
| RunnableParallelOutput |
|
||||
+------------------------+
|
||||
'''
|
||||
# ---
|
||||
# name: test_graph_single_runnable
|
||||
'''
|
||||
+----------------------+
|
||||
| StrOutputParserInput |
|
||||
+----------------------+
|
||||
*
|
||||
*
|
||||
*
|
||||
+-----------------+
|
||||
| StrOutputParser |
|
||||
+-----------------+
|
||||
*
|
||||
*
|
||||
*
|
||||
+-----------------------+
|
||||
| StrOutputParserOutput |
|
||||
+-----------------------+
|
||||
'''
|
||||
# ---
|
File diff suppressed because one or more lines are too long
51
libs/core/tests/unit_tests/runnables/test_graph.py
Normal file
51
libs/core/tests/unit_tests/runnables/test_graph.py
Normal file
@@ -0,0 +1,51 @@
|
||||
from syrupy import SnapshotAssertion
|
||||
|
||||
from langchain_core.output_parsers.list import CommaSeparatedListOutputParser
|
||||
from langchain_core.output_parsers.string import StrOutputParser
|
||||
from langchain_core.prompts.prompt import PromptTemplate
|
||||
from langchain_core.runnables.base import Runnable
|
||||
from langchain_core.runnables.passthrough import RunnablePassthrough
|
||||
from tests.unit_tests.fake.llm import FakeListLLM
|
||||
|
||||
|
||||
def test_graph_single_runnable(snapshot: SnapshotAssertion) -> None:
|
||||
runnable = StrOutputParser()
|
||||
graph = StrOutputParser().get_graph()
|
||||
first_node = graph.first_node()
|
||||
assert first_node is not None
|
||||
assert first_node.data.schema() == runnable.input_schema.schema() # type: ignore[union-attr]
|
||||
last_node = graph.last_node()
|
||||
assert last_node is not None
|
||||
assert last_node.data.schema() == runnable.output_schema.schema() # type: ignore[union-attr]
|
||||
assert len(graph.nodes) == 3
|
||||
assert len(graph.edges) == 2
|
||||
assert graph.edges[0].source == first_node.id
|
||||
assert graph.edges[1].target == last_node.id
|
||||
assert graph.draw_ascii() == snapshot
|
||||
|
||||
|
||||
def test_graph_sequence(snapshot: SnapshotAssertion) -> None:
|
||||
fake_llm = FakeListLLM(responses=["a"])
|
||||
prompt = PromptTemplate.from_template("Hello, {name}!")
|
||||
list_parser = CommaSeparatedListOutputParser()
|
||||
|
||||
sequence = prompt | fake_llm | list_parser
|
||||
graph = sequence.get_graph()
|
||||
assert graph.draw_ascii() == snapshot
|
||||
|
||||
|
||||
def test_graph_sequence_map(snapshot: SnapshotAssertion) -> None:
|
||||
fake_llm = FakeListLLM(responses=["a"])
|
||||
prompt = PromptTemplate.from_template("Hello, {name}!")
|
||||
list_parser = CommaSeparatedListOutputParser()
|
||||
|
||||
sequence: Runnable = (
|
||||
prompt
|
||||
| fake_llm
|
||||
| {
|
||||
"original": RunnablePassthrough(input_type=str),
|
||||
"as_list": list_parser,
|
||||
}
|
||||
)
|
||||
graph = sequence.get_graph()
|
||||
assert graph.draw_ascii() == snapshot
|
@@ -4127,13 +4127,13 @@ def test_representation_of_runnables() -> None:
|
||||
"""Return 2."""
|
||||
return 2
|
||||
|
||||
assert repr(RunnableLambda(func=f)) == "RunnableLambda(...)"
|
||||
assert repr(RunnableLambda(func=f)) == "RunnableLambda(f)"
|
||||
|
||||
async def af(x: int) -> int:
|
||||
"""Return 2."""
|
||||
return 2
|
||||
|
||||
assert repr(RunnableLambda(func=f, afunc=af)) == "RunnableLambda(...)"
|
||||
assert repr(RunnableLambda(func=f, afunc=af)) == "RunnableLambda(f)"
|
||||
|
||||
assert repr(
|
||||
RunnableLambda(lambda x: x + 2)
|
||||
|
Reference in New Issue
Block a user