mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-07 14:03:26 +00:00
core: Fix implementation of trim_first_node/trim_last_node to use exact same definition of first/last node as in the getter methods (#24802)
This commit is contained in:
@@ -1097,3 +1097,65 @@
|
||||
|
||||
'''
|
||||
# ---
|
||||
# name: test_trim
|
||||
dict({
|
||||
'edges': list([
|
||||
dict({
|
||||
'source': '__start__',
|
||||
'target': 'ask_question',
|
||||
}),
|
||||
dict({
|
||||
'source': 'ask_question',
|
||||
'target': 'answer_question',
|
||||
}),
|
||||
dict({
|
||||
'conditional': True,
|
||||
'source': 'answer_question',
|
||||
'target': 'ask_question',
|
||||
}),
|
||||
dict({
|
||||
'conditional': True,
|
||||
'source': 'answer_question',
|
||||
'target': '__end__',
|
||||
}),
|
||||
]),
|
||||
'nodes': list([
|
||||
dict({
|
||||
'data': '__start__',
|
||||
'id': '__start__',
|
||||
'type': 'schema',
|
||||
}),
|
||||
dict({
|
||||
'data': dict({
|
||||
'id': list([
|
||||
'langchain',
|
||||
'schema',
|
||||
'output_parser',
|
||||
'StrOutputParser',
|
||||
]),
|
||||
'name': 'ask_question',
|
||||
}),
|
||||
'id': 'ask_question',
|
||||
'type': 'runnable',
|
||||
}),
|
||||
dict({
|
||||
'data': dict({
|
||||
'id': list([
|
||||
'langchain',
|
||||
'schema',
|
||||
'output_parser',
|
||||
'StrOutputParser',
|
||||
]),
|
||||
'name': 'answer_question',
|
||||
}),
|
||||
'id': 'answer_question',
|
||||
'type': 'runnable',
|
||||
}),
|
||||
dict({
|
||||
'data': '__end__',
|
||||
'id': '__end__',
|
||||
'type': 'schema',
|
||||
}),
|
||||
]),
|
||||
})
|
||||
# ---
|
||||
|
@@ -7,7 +7,9 @@ from langchain_core.output_parsers.list import CommaSeparatedListOutputParser
|
||||
from langchain_core.output_parsers.string import StrOutputParser
|
||||
from langchain_core.output_parsers.xml import XMLOutputParser
|
||||
from langchain_core.prompts.prompt import PromptTemplate
|
||||
from langchain_core.pydantic_v1 import BaseModel
|
||||
from langchain_core.runnables.base import Runnable, RunnableConfig
|
||||
from langchain_core.runnables.graph import Graph
|
||||
from langchain_core.runnables.graph_mermaid import _escape_node_label
|
||||
|
||||
|
||||
@@ -27,6 +29,42 @@ def test_graph_single_runnable(snapshot: SnapshotAssertion) -> None:
|
||||
assert graph.draw_ascii() == snapshot(name="ascii")
|
||||
assert graph.draw_mermaid() == snapshot(name="mermaid")
|
||||
|
||||
graph.trim_first_node()
|
||||
first_node = graph.first_node()
|
||||
assert first_node is not None
|
||||
assert first_node.data == runnable
|
||||
|
||||
graph.trim_last_node()
|
||||
last_node = graph.last_node()
|
||||
assert last_node is not None
|
||||
assert last_node.data == runnable
|
||||
|
||||
|
||||
def test_trim(snapshot: SnapshotAssertion) -> None:
|
||||
runnable = StrOutputParser()
|
||||
|
||||
class Schema(BaseModel):
|
||||
a: int
|
||||
|
||||
graph = Graph()
|
||||
start = graph.add_node(Schema, id="__start__")
|
||||
ask = graph.add_node(runnable, id="ask_question")
|
||||
answer = graph.add_node(runnable, id="answer_question")
|
||||
end = graph.add_node(Schema, id="__end__")
|
||||
graph.add_edge(start, ask)
|
||||
graph.add_edge(ask, answer)
|
||||
graph.add_edge(answer, ask, conditional=True)
|
||||
graph.add_edge(answer, end, conditional=True)
|
||||
|
||||
assert graph.to_json() == snapshot
|
||||
assert graph.first_node() is start
|
||||
assert graph.last_node() is end
|
||||
# can't trim start or end node
|
||||
graph.trim_first_node()
|
||||
assert graph.first_node() is start
|
||||
graph.trim_last_node()
|
||||
assert graph.last_node() is end
|
||||
|
||||
|
||||
def test_graph_sequence(snapshot: SnapshotAssertion) -> None:
|
||||
fake_llm = FakeListLLM(responses=["a"])
|
||||
|
Reference in New Issue
Block a user