mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-12 21:11:43 +00:00
Add .pick and .assign methods to Runnable (#15229)
<!-- Thank you for contributing to LangChain! Please title your PR "<package>: <description>", where <package> is whichever of langchain, community, core, experimental, etc. is being modified. Replace this entire comment with: - **Description:** a description of the change, - **Issue:** the issue # it fixes if applicable, - **Dependencies:** any dependencies required for this change, - **Twitter handle:** we announce bigger features on Twitter. If your PR gets announced, and you'd like a mention, we'll gladly shout you out! Please make sure your PR is passing linting and testing before submitting. Run `make format`, `make lint` and `make test` from the root of the package you've modified to check this locally. See contribution guidelines for more information on how to write/run tests, lint, etc: https://python.langchain.com/docs/contributing/ If you're adding a new integration, please include: 1. a test for the integration, preferably unit tests that do not rely on network access, 2. an example notebook showing its use. It lives in `docs/docs/integrations` directory. If no one reviews your PR within a few days, please @-mention one of @baskaryan, @eyurtsev, @hwchase17. -->
This commit is contained in:
@@ -32,51 +32,51 @@
|
||||
# ---
|
||||
# name: test_graph_sequence_map
|
||||
'''
|
||||
+-------------+
|
||||
| PromptInput |
|
||||
+-------------+
|
||||
*
|
||||
*
|
||||
*
|
||||
+----------------+
|
||||
| PromptTemplate |
|
||||
+----------------+
|
||||
*
|
||||
*
|
||||
*
|
||||
+-------------+
|
||||
| FakeListLLM |
|
||||
+-------------+
|
||||
*
|
||||
*
|
||||
*
|
||||
+---------------+
|
||||
| ParallelInput |
|
||||
+---------------+******
|
||||
***** ******
|
||||
*** ******
|
||||
*** ******
|
||||
+------------------------------+ ***
|
||||
| conditional_str_parser_input | *
|
||||
+------------------------------+ *
|
||||
*** *** *
|
||||
*** *** *
|
||||
** ** *
|
||||
+-----------------+ +-----------------+ *
|
||||
| StrOutputParser | | XMLOutputParser | *
|
||||
+-----------------+ +-----------------+ *
|
||||
*** *** *
|
||||
*** *** *
|
||||
** ** *
|
||||
+-------------------------------+ +--------------------------------+
|
||||
| conditional_str_parser_output | | CommaSeparatedListOutputParser |
|
||||
+-------------------------------+ +--------------------------------+
|
||||
***** ******
|
||||
*** ******
|
||||
*** ***
|
||||
+----------------+
|
||||
| ParallelOutput |
|
||||
+----------------+
|
||||
+-------------+
|
||||
| PromptInput |
|
||||
+-------------+
|
||||
*
|
||||
*
|
||||
*
|
||||
+----------------+
|
||||
| PromptTemplate |
|
||||
+----------------+
|
||||
*
|
||||
*
|
||||
*
|
||||
+-------------+
|
||||
| FakeListLLM |
|
||||
+-------------+
|
||||
*
|
||||
*
|
||||
*
|
||||
+-------------------------------+
|
||||
| Parallel<as_list,as_str>Input |
|
||||
+-------------------------------+
|
||||
***** ******
|
||||
*** ******
|
||||
*** ******
|
||||
+------------------------------+ ****
|
||||
| conditional_str_parser_input | *
|
||||
+------------------------------+ *
|
||||
*** *** *
|
||||
*** *** *
|
||||
** ** *
|
||||
+-----------------+ +-----------------+ *
|
||||
| StrOutputParser | | XMLOutputParser | *
|
||||
+-----------------+ +-----------------+ *
|
||||
*** *** *
|
||||
*** *** *
|
||||
** ** *
|
||||
+-------------------------------+ +--------------------------------+
|
||||
| conditional_str_parser_output | | CommaSeparatedListOutputParser |
|
||||
+-------------------------------+ +--------------------------------+
|
||||
***** ******
|
||||
*** ******
|
||||
*** ****
|
||||
+--------------------------------+
|
||||
| Parallel<as_list,as_str>Output |
|
||||
+--------------------------------+
|
||||
'''
|
||||
# ---
|
||||
# name: test_graph_single_runnable
|
||||
|
@@ -4012,7 +4012,7 @@
|
||||
'items': dict({
|
||||
'$ref': '#/definitions/PromptTemplateOutput',
|
||||
}),
|
||||
'title': 'RunnableEachOutput',
|
||||
'title': 'RunnableEach<PromptTemplate>Output',
|
||||
'type': 'array',
|
||||
})
|
||||
# ---
|
||||
|
@@ -18,6 +18,8 @@ EXPECTED_ALL = [
|
||||
"RunnableMap",
|
||||
"RunnableParallel",
|
||||
"RunnablePassthrough",
|
||||
"RunnableAssign",
|
||||
"RunnablePick",
|
||||
"RunnableSequence",
|
||||
"RunnableWithFallbacks",
|
||||
"get_config_list",
|
||||
|
@@ -64,6 +64,7 @@ from langchain_core.runnables import (
|
||||
RunnableLambda,
|
||||
RunnableParallel,
|
||||
RunnablePassthrough,
|
||||
RunnablePick,
|
||||
RunnableSequence,
|
||||
RunnableWithFallbacks,
|
||||
add,
|
||||
@@ -510,7 +511,7 @@ def test_schemas(snapshot: SnapshotAssertion) -> None:
|
||||
},
|
||||
"items": {"$ref": "#/definitions/PromptInput"},
|
||||
"type": "array",
|
||||
"title": "RunnableEachInput",
|
||||
"title": "RunnableEach<PromptTemplate>Input",
|
||||
}
|
||||
assert prompt_mapper.output_schema.schema() == snapshot
|
||||
|
||||
@@ -571,7 +572,7 @@ def test_schemas(snapshot: SnapshotAssertion) -> None:
|
||||
"properties": {"name": {"title": "Name", "type": "string"}},
|
||||
}
|
||||
assert seq_w_map.output_schema.schema() == {
|
||||
"title": "RunnableParallelOutput",
|
||||
"title": "RunnableParallel<original,as_list,length>Output",
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"original": {"title": "Original", "type": "string"},
|
||||
@@ -615,7 +616,7 @@ def test_passthrough_assign_schema() -> None:
|
||||
# expected dict input_schema
|
||||
assert invalid_seq_w_assign.input_schema.schema() == {
|
||||
"properties": {"question": {"title": "Question"}},
|
||||
"title": "RunnableParallelInput",
|
||||
"title": "RunnableParallel<context>Input",
|
||||
"type": "object",
|
||||
}
|
||||
|
||||
@@ -774,7 +775,7 @@ def test_schema_complex_seq() -> None:
|
||||
)
|
||||
|
||||
assert chain2.input_schema.schema() == {
|
||||
"title": "RunnableParallelInput",
|
||||
"title": "RunnableParallel<city,language>Input",
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"person": {"title": "Person", "type": "string"},
|
||||
@@ -2160,8 +2161,8 @@ async def test_stream_log_retriever() -> None:
|
||||
"FakeListLLM:2",
|
||||
"Retriever",
|
||||
"RunnableLambda",
|
||||
"RunnableParallel",
|
||||
"RunnableParallel:2",
|
||||
"RunnableParallel<documents,question>",
|
||||
"RunnableParallel<one,two>",
|
||||
]
|
||||
|
||||
|
||||
@@ -2444,7 +2445,7 @@ What is your name?"""
|
||||
parent_run = next(r for r in tracer.runs if r.parent_run_id is None)
|
||||
assert len(parent_run.child_runs) == 4
|
||||
map_run = parent_run.child_runs[0]
|
||||
assert map_run.name == "RunnableParallel"
|
||||
assert map_run.name == "RunnableParallel<question,documents,just_to_test_lambda>"
|
||||
assert len(map_run.child_runs) == 3
|
||||
|
||||
|
||||
@@ -2505,7 +2506,7 @@ def test_seq_prompt_dict(mocker: MockerFixture, snapshot: SnapshotAssertion) ->
|
||||
parent_run = next(r for r in tracer.runs if r.parent_run_id is None)
|
||||
assert len(parent_run.child_runs) == 3
|
||||
map_run = parent_run.child_runs[2]
|
||||
assert map_run.name == "RunnableParallel"
|
||||
assert map_run.name == "RunnableParallel<chat,llm>"
|
||||
assert len(map_run.child_runs) == 2
|
||||
|
||||
|
||||
@@ -2721,7 +2722,7 @@ def test_seq_prompt_map(mocker: MockerFixture, snapshot: SnapshotAssertion) -> N
|
||||
parent_run = next(r for r in tracer.runs if r.parent_run_id is None)
|
||||
assert len(parent_run.child_runs) == 3
|
||||
map_run = parent_run.child_runs[2]
|
||||
assert map_run.name == "RunnableParallel"
|
||||
assert map_run.name == "RunnableParallel<chat,llm,passthrough>"
|
||||
assert len(map_run.child_runs) == 3
|
||||
|
||||
|
||||
@@ -2770,7 +2771,7 @@ def test_map_stream() -> None:
|
||||
{"question": "What is your name?"}
|
||||
)
|
||||
|
||||
chain_pick_one = chain | RunnablePassthrough.pick("llm")
|
||||
chain_pick_one = chain.pick("llm")
|
||||
|
||||
assert chain_pick_one.output_schema.schema() == {
|
||||
"title": "RunnableSequenceOutput",
|
||||
@@ -2791,10 +2792,8 @@ def test_map_stream() -> None:
|
||||
assert streamed_chunks[0] == "i"
|
||||
assert len(streamed_chunks) == len(llm_res)
|
||||
|
||||
chain_pick_two = (
|
||||
chain
|
||||
| RunnablePassthrough.assign(hello=RunnablePassthrough.pick("llm") | llm)
|
||||
| RunnablePassthrough.pick(["llm", "hello"])
|
||||
chain_pick_two = chain.assign(hello=RunnablePick("llm").pipe(llm)).pick(
|
||||
["llm", "hello"]
|
||||
)
|
||||
|
||||
assert chain_pick_two.output_schema.schema() == {
|
||||
@@ -2940,12 +2939,15 @@ async def test_map_astream() -> None:
|
||||
assert final_state.state["logs"]["ChatPromptTemplate"][
|
||||
"final_output"
|
||||
] == prompt.invoke({"question": "What is your name?"})
|
||||
assert final_state.state["logs"]["RunnableParallel"]["name"] == "RunnableParallel"
|
||||
assert (
|
||||
final_state.state["logs"]["RunnableParallel<chat,llm,passthrough>"]["name"]
|
||||
== "RunnableParallel<chat,llm,passthrough>"
|
||||
)
|
||||
assert sorted(final_state.state["logs"]) == [
|
||||
"ChatPromptTemplate",
|
||||
"FakeListChatModel",
|
||||
"FakeStreamingListLLM",
|
||||
"RunnableParallel",
|
||||
"RunnableParallel<chat,llm,passthrough>",
|
||||
"RunnablePassthrough",
|
||||
]
|
||||
|
||||
@@ -2985,11 +2987,14 @@ async def test_map_astream() -> None:
|
||||
assert final_state.state["logs"]["ChatPromptTemplate"]["final_output"] == (
|
||||
prompt.invoke({"question": "What is your name?"})
|
||||
)
|
||||
assert final_state.state["logs"]["RunnableParallel"]["name"] == "RunnableParallel"
|
||||
assert (
|
||||
final_state.state["logs"]["RunnableParallel<chat,llm,passthrough>"]["name"]
|
||||
== "RunnableParallel<chat,llm,passthrough>"
|
||||
)
|
||||
assert sorted(final_state.state["logs"]) == [
|
||||
"ChatPromptTemplate",
|
||||
"FakeStreamingListLLM",
|
||||
"RunnableParallel",
|
||||
"RunnableParallel<chat,llm,passthrough>",
|
||||
"RunnablePassthrough",
|
||||
]
|
||||
|
||||
@@ -3130,9 +3135,7 @@ def test_deep_stream_assign() -> None:
|
||||
assert len(chunks) == len("foo-lish")
|
||||
assert add(chunks) == {"str": "foo-lish"}
|
||||
|
||||
chain_with_assign = chain | RunnablePassthrough.assign(
|
||||
hello=itemgetter("str") | llm
|
||||
)
|
||||
chain_with_assign = chain.assign(hello=itemgetter("str") | llm)
|
||||
|
||||
assert chain_with_assign.input_schema.schema() == {
|
||||
"title": "PromptInput",
|
||||
@@ -3179,7 +3182,7 @@ def test_deep_stream_assign() -> None:
|
||||
"hello": "foo-lish",
|
||||
}
|
||||
|
||||
chain_with_assign_shadow = chain | RunnablePassthrough.assign(
|
||||
chain_with_assign_shadow = chain.assign(
|
||||
str=lambda _: "shadow",
|
||||
hello=itemgetter("str") | llm,
|
||||
)
|
||||
@@ -3254,7 +3257,7 @@ async def test_deep_astream_assign() -> None:
|
||||
assert len(chunks) == len("foo-lish")
|
||||
assert add(chunks) == {"str": "foo-lish"}
|
||||
|
||||
chain_with_assign = chain | RunnablePassthrough.assign(
|
||||
chain_with_assign = chain.assign(
|
||||
hello=itemgetter("str") | llm,
|
||||
)
|
||||
|
||||
@@ -4473,15 +4476,15 @@ def test_invoke_stream_passthrough_assign_trace() -> None:
|
||||
tracer = FakeTracer()
|
||||
chain.invoke({"example": [1, 2, 3]}, dict(callbacks=[tracer]))
|
||||
|
||||
assert tracer.runs[0].name == "RunnableAssign"
|
||||
assert tracer.runs[0].child_runs[0].name == "RunnableParallel"
|
||||
assert tracer.runs[0].name == "RunnableAssign<urls>"
|
||||
assert tracer.runs[0].child_runs[0].name == "RunnableParallel<urls>"
|
||||
|
||||
tracer = FakeTracer()
|
||||
for item in chain.stream({"example": [1, 2, 3]}, dict(callbacks=[tracer])):
|
||||
pass
|
||||
|
||||
assert tracer.runs[0].name == "RunnableAssign"
|
||||
assert tracer.runs[0].child_runs[0].name == "RunnableParallel"
|
||||
assert tracer.runs[0].name == "RunnableAssign<urls>"
|
||||
assert tracer.runs[0].child_runs[0].name == "RunnableParallel<urls>"
|
||||
|
||||
|
||||
async def test_ainvoke_astream_passthrough_assign_trace() -> None:
|
||||
@@ -4493,15 +4496,15 @@ async def test_ainvoke_astream_passthrough_assign_trace() -> None:
|
||||
tracer = FakeTracer()
|
||||
await chain.ainvoke({"example": [1, 2, 3]}, dict(callbacks=[tracer]))
|
||||
|
||||
assert tracer.runs[0].name == "RunnableAssign"
|
||||
assert tracer.runs[0].child_runs[0].name == "RunnableParallel"
|
||||
assert tracer.runs[0].name == "RunnableAssign<urls>"
|
||||
assert tracer.runs[0].child_runs[0].name == "RunnableParallel<urls>"
|
||||
|
||||
tracer = FakeTracer()
|
||||
async for item in chain.astream({"example": [1, 2, 3]}, dict(callbacks=[tracer])):
|
||||
pass
|
||||
|
||||
assert tracer.runs[0].name == "RunnableAssign"
|
||||
assert tracer.runs[0].child_runs[0].name == "RunnableParallel"
|
||||
assert tracer.runs[0].name == "RunnableAssign<urls>"
|
||||
assert tracer.runs[0].child_runs[0].name == "RunnableParallel<urls>"
|
||||
|
||||
|
||||
async def test_astream_log_deep_copies() -> None:
|
||||
|
Reference in New Issue
Block a user