mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-27 08:58:48 +00:00
core: In RunnableSequence pass kwargs to the first step (#22393)
- This is a pattern that shows up occasionally in langgraph questions, people chain a graph to something else after, and want to pass the graph some kwargs (eg. stream_mode)
This commit is contained in:
parent
eabcfaa3d6
commit
ed8e9c437a
@ -2379,7 +2379,9 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
|
|||||||
name=self.name,
|
name=self.name,
|
||||||
)
|
)
|
||||||
|
|
||||||
def invoke(self, input: Input, config: Optional[RunnableConfig] = None) -> Output:
|
def invoke(
|
||||||
|
self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any
|
||||||
|
) -> Output:
|
||||||
from langchain_core.beta.runnables.context import config_with_context
|
from langchain_core.beta.runnables.context import config_with_context
|
||||||
|
|
||||||
# setup callbacks and context
|
# setup callbacks and context
|
||||||
@ -2396,13 +2398,14 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
|
|||||||
# invoke all steps in sequence
|
# invoke all steps in sequence
|
||||||
try:
|
try:
|
||||||
for i, step in enumerate(self.steps):
|
for i, step in enumerate(self.steps):
|
||||||
input = step.invoke(
|
# mark each step as a child run
|
||||||
input,
|
config = patch_config(
|
||||||
# mark each step as a child run
|
config, callbacks=run_manager.get_child(f"seq:step:{i+1}")
|
||||||
patch_config(
|
|
||||||
config, callbacks=run_manager.get_child(f"seq:step:{i+1}")
|
|
||||||
),
|
|
||||||
)
|
)
|
||||||
|
if i == 0:
|
||||||
|
input = step.invoke(input, config, **kwargs)
|
||||||
|
else:
|
||||||
|
input = step.invoke(input, config)
|
||||||
# finish the root run
|
# finish the root run
|
||||||
except BaseException as e:
|
except BaseException as e:
|
||||||
run_manager.on_chain_error(e)
|
run_manager.on_chain_error(e)
|
||||||
@ -2433,13 +2436,14 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
|
|||||||
# invoke all steps in sequence
|
# invoke all steps in sequence
|
||||||
try:
|
try:
|
||||||
for i, step in enumerate(self.steps):
|
for i, step in enumerate(self.steps):
|
||||||
input = await step.ainvoke(
|
# mark each step as a child run
|
||||||
input,
|
config = patch_config(
|
||||||
# mark each step as a child run
|
config, callbacks=run_manager.get_child(f"seq:step:{i+1}")
|
||||||
patch_config(
|
|
||||||
config, callbacks=run_manager.get_child(f"seq:step:{i+1}")
|
|
||||||
),
|
|
||||||
)
|
)
|
||||||
|
if i == 0:
|
||||||
|
input = await step.ainvoke(input, config, **kwargs)
|
||||||
|
else:
|
||||||
|
input = await step.ainvoke(input, config)
|
||||||
# finish the root run
|
# finish the root run
|
||||||
except BaseException as e:
|
except BaseException as e:
|
||||||
await run_manager.on_chain_error(e)
|
await run_manager.on_chain_error(e)
|
||||||
@ -2519,7 +2523,7 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
|
|||||||
if i not in failed_inputs_map
|
if i not in failed_inputs_map
|
||||||
],
|
],
|
||||||
return_exceptions=return_exceptions,
|
return_exceptions=return_exceptions,
|
||||||
**kwargs,
|
**(kwargs if stepidx == 0 else {}),
|
||||||
)
|
)
|
||||||
# If an input failed, add it to the map
|
# If an input failed, add it to the map
|
||||||
for i, inp in zip(remaining_idxs, inputs):
|
for i, inp in zip(remaining_idxs, inputs):
|
||||||
@ -2549,6 +2553,8 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
|
|||||||
)
|
)
|
||||||
for rm, config in zip(run_managers, configs)
|
for rm, config in zip(run_managers, configs)
|
||||||
],
|
],
|
||||||
|
return_exceptions=return_exceptions,
|
||||||
|
**(kwargs if i == 0 else {}),
|
||||||
)
|
)
|
||||||
|
|
||||||
# finish the root runs
|
# finish the root runs
|
||||||
@ -2646,7 +2652,7 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
|
|||||||
if i not in failed_inputs_map
|
if i not in failed_inputs_map
|
||||||
],
|
],
|
||||||
return_exceptions=return_exceptions,
|
return_exceptions=return_exceptions,
|
||||||
**kwargs,
|
**(kwargs if stepidx == 0 else {}),
|
||||||
)
|
)
|
||||||
# If an input failed, add it to the map
|
# If an input failed, add it to the map
|
||||||
for i, inp in zip(remaining_idxs, inputs):
|
for i, inp in zip(remaining_idxs, inputs):
|
||||||
@ -2676,6 +2682,8 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
|
|||||||
)
|
)
|
||||||
for rm, config in zip(run_managers, configs)
|
for rm, config in zip(run_managers, configs)
|
||||||
],
|
],
|
||||||
|
return_exceptions=return_exceptions,
|
||||||
|
**(kwargs if i == 0 else {}),
|
||||||
)
|
)
|
||||||
# finish the root runs
|
# finish the root runs
|
||||||
except BaseException as e:
|
except BaseException as e:
|
||||||
@ -2704,6 +2712,7 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
|
|||||||
input: Iterator[Input],
|
input: Iterator[Input],
|
||||||
run_manager: CallbackManagerForChainRun,
|
run_manager: CallbackManagerForChainRun,
|
||||||
config: RunnableConfig,
|
config: RunnableConfig,
|
||||||
|
**kwargs: Any,
|
||||||
) -> Iterator[Output]:
|
) -> Iterator[Output]:
|
||||||
from langchain_core.beta.runnables.context import config_with_context
|
from langchain_core.beta.runnables.context import config_with_context
|
||||||
|
|
||||||
@ -2714,14 +2723,14 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
|
|||||||
# steps that don't natively support transforming an input stream will
|
# steps that don't natively support transforming an input stream will
|
||||||
# buffer input in memory until all available, and then start emitting output
|
# buffer input in memory until all available, and then start emitting output
|
||||||
final_pipeline = cast(Iterator[Output], input)
|
final_pipeline = cast(Iterator[Output], input)
|
||||||
for step in steps:
|
for idx, step in enumerate(steps):
|
||||||
final_pipeline = step.transform(
|
config = patch_config(
|
||||||
final_pipeline,
|
config, callbacks=run_manager.get_child(f"seq:step:{idx+1}")
|
||||||
patch_config(
|
|
||||||
config,
|
|
||||||
callbacks=run_manager.get_child(f"seq:step:{steps.index(step)+1}"),
|
|
||||||
),
|
|
||||||
)
|
)
|
||||||
|
if idx == 0:
|
||||||
|
final_pipeline = step.transform(final_pipeline, config, **kwargs)
|
||||||
|
else:
|
||||||
|
final_pipeline = step.transform(final_pipeline, config)
|
||||||
|
|
||||||
for output in final_pipeline:
|
for output in final_pipeline:
|
||||||
yield output
|
yield output
|
||||||
@ -2731,6 +2740,7 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
|
|||||||
input: AsyncIterator[Input],
|
input: AsyncIterator[Input],
|
||||||
run_manager: AsyncCallbackManagerForChainRun,
|
run_manager: AsyncCallbackManagerForChainRun,
|
||||||
config: RunnableConfig,
|
config: RunnableConfig,
|
||||||
|
**kwargs: Any,
|
||||||
) -> AsyncIterator[Output]:
|
) -> AsyncIterator[Output]:
|
||||||
from langchain_core.beta.runnables.context import aconfig_with_context
|
from langchain_core.beta.runnables.context import aconfig_with_context
|
||||||
|
|
||||||
@ -2742,14 +2752,15 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
|
|||||||
# steps that don't natively support transforming an input stream will
|
# steps that don't natively support transforming an input stream will
|
||||||
# buffer input in memory until all available, and then start emitting output
|
# buffer input in memory until all available, and then start emitting output
|
||||||
final_pipeline = cast(AsyncIterator[Output], input)
|
final_pipeline = cast(AsyncIterator[Output], input)
|
||||||
for step in steps:
|
for idx, step in enumerate(steps):
|
||||||
final_pipeline = step.atransform(
|
config = patch_config(
|
||||||
final_pipeline,
|
config,
|
||||||
patch_config(
|
callbacks=run_manager.get_child(f"seq:step:{idx+1}"),
|
||||||
config,
|
|
||||||
callbacks=run_manager.get_child(f"seq:step:{steps.index(step)+1}"),
|
|
||||||
),
|
|
||||||
)
|
)
|
||||||
|
if idx == 0:
|
||||||
|
final_pipeline = step.atransform(final_pipeline, config, **kwargs)
|
||||||
|
else:
|
||||||
|
final_pipeline = step.atransform(final_pipeline, config)
|
||||||
async for output in final_pipeline:
|
async for output in final_pipeline:
|
||||||
yield output
|
yield output
|
||||||
|
|
||||||
|
@ -182,6 +182,7 @@ class FakeRunnable(Runnable[str, int]):
|
|||||||
self,
|
self,
|
||||||
input: str,
|
input: str,
|
||||||
config: Optional[RunnableConfig] = None,
|
config: Optional[RunnableConfig] = None,
|
||||||
|
**kwargs: Any,
|
||||||
) -> int:
|
) -> int:
|
||||||
return len(input)
|
return len(input)
|
||||||
|
|
||||||
@ -1409,26 +1410,136 @@ async def test_passthrough_tap_async(mocker: MockerFixture) -> None:
|
|||||||
fake = FakeRunnable()
|
fake = FakeRunnable()
|
||||||
mock = mocker.Mock()
|
mock = mocker.Mock()
|
||||||
|
|
||||||
seq: Runnable = fake | RunnablePassthrough(mock)
|
seq: Runnable = RunnablePassthrough(mock) | fake | RunnablePassthrough(mock)
|
||||||
|
|
||||||
assert await seq.ainvoke("hello") == 5
|
assert await seq.ainvoke("hello", my_kwarg="value") == 5
|
||||||
assert mock.call_args_list == [mocker.call(5)]
|
assert mock.call_args_list == [
|
||||||
|
mocker.call("hello", my_kwarg="value"),
|
||||||
|
mocker.call(5),
|
||||||
|
]
|
||||||
|
mock.reset_mock()
|
||||||
|
|
||||||
|
assert await seq.abatch(["hello", "byebye"], my_kwarg="value") == [5, 6]
|
||||||
|
assert len(mock.call_args_list) == 4
|
||||||
|
for call in [
|
||||||
|
mocker.call("hello", my_kwarg="value"),
|
||||||
|
mocker.call("byebye", my_kwarg="value"),
|
||||||
|
mocker.call(5),
|
||||||
|
mocker.call(6),
|
||||||
|
]:
|
||||||
|
assert call in mock.call_args_list
|
||||||
|
mock.reset_mock()
|
||||||
|
|
||||||
|
assert await seq.abatch(
|
||||||
|
["hello", "byebye"], my_kwarg="value", return_exceptions=True
|
||||||
|
) == [
|
||||||
|
5,
|
||||||
|
6,
|
||||||
|
]
|
||||||
|
assert len(mock.call_args_list) == 4
|
||||||
|
for call in [
|
||||||
|
mocker.call("hello", my_kwarg="value"),
|
||||||
|
mocker.call("byebye", my_kwarg="value"),
|
||||||
|
mocker.call(5),
|
||||||
|
mocker.call(6),
|
||||||
|
]:
|
||||||
|
assert call in mock.call_args_list
|
||||||
|
mock.reset_mock()
|
||||||
|
|
||||||
|
assert sorted(
|
||||||
|
[
|
||||||
|
a
|
||||||
|
async for a in seq.abatch_as_completed(
|
||||||
|
["hello", "byebye"], my_kwarg="value", return_exceptions=True
|
||||||
|
)
|
||||||
|
]
|
||||||
|
) == [
|
||||||
|
(0, 5),
|
||||||
|
(1, 6),
|
||||||
|
]
|
||||||
|
assert len(mock.call_args_list) == 4
|
||||||
|
for call in [
|
||||||
|
mocker.call("hello", my_kwarg="value"),
|
||||||
|
mocker.call("byebye", my_kwarg="value"),
|
||||||
|
mocker.call(5),
|
||||||
|
mocker.call(6),
|
||||||
|
]:
|
||||||
|
assert call in mock.call_args_list
|
||||||
mock.reset_mock()
|
mock.reset_mock()
|
||||||
|
|
||||||
assert [
|
assert [
|
||||||
part async for part in seq.astream("hello", dict(metadata={"key": "value"}))
|
part
|
||||||
|
async for part in seq.astream(
|
||||||
|
"hello", dict(metadata={"key": "value"}), my_kwarg="value"
|
||||||
|
)
|
||||||
] == [5]
|
] == [5]
|
||||||
assert mock.call_args_list == [mocker.call(5)]
|
assert mock.call_args_list == [
|
||||||
mock.reset_mock()
|
mocker.call("hello", my_kwarg="value"),
|
||||||
|
mocker.call(5),
|
||||||
assert seq.invoke("hello") == 5
|
]
|
||||||
assert mock.call_args_list == [mocker.call(5)]
|
mock.reset_mock()
|
||||||
mock.reset_mock()
|
|
||||||
|
assert seq.invoke("hello", my_kwarg="value") == 5 # type: ignore[call-arg]
|
||||||
assert [part for part in seq.stream("hello", dict(metadata={"key": "value"}))] == [
|
assert mock.call_args_list == [
|
||||||
5
|
mocker.call("hello", my_kwarg="value"),
|
||||||
|
mocker.call(5),
|
||||||
|
]
|
||||||
|
mock.reset_mock()
|
||||||
|
|
||||||
|
assert seq.batch(["hello", "byebye"], my_kwarg="value") == [5, 6]
|
||||||
|
assert len(mock.call_args_list) == 4
|
||||||
|
for call in [
|
||||||
|
mocker.call("hello", my_kwarg="value"),
|
||||||
|
mocker.call("byebye", my_kwarg="value"),
|
||||||
|
mocker.call(5),
|
||||||
|
mocker.call(6),
|
||||||
|
]:
|
||||||
|
assert call in mock.call_args_list
|
||||||
|
mock.reset_mock()
|
||||||
|
|
||||||
|
assert seq.batch(["hello", "byebye"], my_kwarg="value", return_exceptions=True) == [
|
||||||
|
5,
|
||||||
|
6,
|
||||||
|
]
|
||||||
|
assert len(mock.call_args_list) == 4
|
||||||
|
for call in [
|
||||||
|
mocker.call("hello", my_kwarg="value"),
|
||||||
|
mocker.call("byebye", my_kwarg="value"),
|
||||||
|
mocker.call(5),
|
||||||
|
mocker.call(6),
|
||||||
|
]:
|
||||||
|
assert call in mock.call_args_list
|
||||||
|
mock.reset_mock()
|
||||||
|
|
||||||
|
assert sorted(
|
||||||
|
a
|
||||||
|
for a in seq.batch_as_completed(
|
||||||
|
["hello", "byebye"], my_kwarg="value", return_exceptions=True
|
||||||
|
)
|
||||||
|
) == [
|
||||||
|
(0, 5),
|
||||||
|
(1, 6),
|
||||||
|
]
|
||||||
|
assert len(mock.call_args_list) == 4
|
||||||
|
for call in [
|
||||||
|
mocker.call("hello", my_kwarg="value"),
|
||||||
|
mocker.call("byebye", my_kwarg="value"),
|
||||||
|
mocker.call(5),
|
||||||
|
mocker.call(6),
|
||||||
|
]:
|
||||||
|
assert call in mock.call_args_list
|
||||||
|
mock.reset_mock()
|
||||||
|
|
||||||
|
assert [
|
||||||
|
part
|
||||||
|
for part in seq.stream(
|
||||||
|
"hello", dict(metadata={"key": "value"}), my_kwarg="value"
|
||||||
|
)
|
||||||
|
] == [5]
|
||||||
|
assert mock.call_args_list == [
|
||||||
|
mocker.call("hello", my_kwarg="value"),
|
||||||
|
mocker.call(5),
|
||||||
]
|
]
|
||||||
assert mock.call_args_list == [mocker.call(5)]
|
|
||||||
mock.reset_mock()
|
mock.reset_mock()
|
||||||
|
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user