core: Fix implementation of trim_first_node/trim_last_node to use exact same definition of first/last node as in the getter methods (#24802)

This commit is contained in:
Nuno Campos
2024-07-30 08:44:27 -07:00
committed by GitHub
parent c2706cfb9e
commit 68ecebf1ec
3 changed files with 134 additions and 28 deletions

View File

@@ -1097,3 +1097,65 @@
'''
# ---
# name: test_trim
dict({
'edges': list([
dict({
'source': '__start__',
'target': 'ask_question',
}),
dict({
'source': 'ask_question',
'target': 'answer_question',
}),
dict({
'conditional': True,
'source': 'answer_question',
'target': 'ask_question',
}),
dict({
'conditional': True,
'source': 'answer_question',
'target': '__end__',
}),
]),
'nodes': list([
dict({
'data': '__start__',
'id': '__start__',
'type': 'schema',
}),
dict({
'data': dict({
'id': list([
'langchain',
'schema',
'output_parser',
'StrOutputParser',
]),
'name': 'ask_question',
}),
'id': 'ask_question',
'type': 'runnable',
}),
dict({
'data': dict({
'id': list([
'langchain',
'schema',
'output_parser',
'StrOutputParser',
]),
'name': 'answer_question',
}),
'id': 'answer_question',
'type': 'runnable',
}),
dict({
'data': '__end__',
'id': '__end__',
'type': 'schema',
}),
]),
})
# ---

View File

@@ -7,7 +7,9 @@ from langchain_core.output_parsers.list import CommaSeparatedListOutputParser
from langchain_core.output_parsers.string import StrOutputParser
from langchain_core.output_parsers.xml import XMLOutputParser
from langchain_core.prompts.prompt import PromptTemplate
from langchain_core.pydantic_v1 import BaseModel
from langchain_core.runnables.base import Runnable, RunnableConfig
from langchain_core.runnables.graph import Graph
from langchain_core.runnables.graph_mermaid import _escape_node_label
@@ -27,6 +29,42 @@ def test_graph_single_runnable(snapshot: SnapshotAssertion) -> None:
assert graph.draw_ascii() == snapshot(name="ascii")
assert graph.draw_mermaid() == snapshot(name="mermaid")
graph.trim_first_node()
first_node = graph.first_node()
assert first_node is not None
assert first_node.data == runnable
graph.trim_last_node()
last_node = graph.last_node()
assert last_node is not None
assert last_node.data == runnable
def test_trim(snapshot: SnapshotAssertion) -> None:
runnable = StrOutputParser()
class Schema(BaseModel):
a: int
graph = Graph()
start = graph.add_node(Schema, id="__start__")
ask = graph.add_node(runnable, id="ask_question")
answer = graph.add_node(runnable, id="answer_question")
end = graph.add_node(Schema, id="__end__")
graph.add_edge(start, ask)
graph.add_edge(ask, answer)
graph.add_edge(answer, ask, conditional=True)
graph.add_edge(answer, end, conditional=True)
assert graph.to_json() == snapshot
assert graph.first_node() is start
assert graph.last_node() is end
# can't trim start or end node
graph.trim_first_node()
assert graph.first_node() is start
graph.trim_last_node()
assert graph.last_node() is end
def test_graph_sequence(snapshot: SnapshotAssertion) -> None:
fake_llm = FakeListLLM(responses=["a"])