core: In RunnableSequence pass kwargs to the first step (#22393)

- This is a pattern that shows up occasionally in langgraph questions,
people chain a graph to something else after, and want to pass the graph
some kwargs (eg. stream_mode)
This commit is contained in:
Nuno Campos 2024-06-03 07:18:10 -07:00 committed by GitHub
parent eabcfaa3d6
commit ed8e9c437a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 165 additions and 43 deletions

View File

@ -2379,7 +2379,9 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
name=self.name,
)
def invoke(self, input: Input, config: Optional[RunnableConfig] = None) -> Output:
def invoke(
self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any
) -> Output:
from langchain_core.beta.runnables.context import config_with_context
# setup callbacks and context
@ -2396,13 +2398,14 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
# invoke all steps in sequence
try:
for i, step in enumerate(self.steps):
input = step.invoke(
input,
# mark each step as a child run
patch_config(
config, callbacks=run_manager.get_child(f"seq:step:{i+1}")
),
# mark each step as a child run
config = patch_config(
config, callbacks=run_manager.get_child(f"seq:step:{i+1}")
)
if i == 0:
input = step.invoke(input, config, **kwargs)
else:
input = step.invoke(input, config)
# finish the root run
except BaseException as e:
run_manager.on_chain_error(e)
@ -2433,13 +2436,14 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
# invoke all steps in sequence
try:
for i, step in enumerate(self.steps):
input = await step.ainvoke(
input,
# mark each step as a child run
patch_config(
config, callbacks=run_manager.get_child(f"seq:step:{i+1}")
),
# mark each step as a child run
config = patch_config(
config, callbacks=run_manager.get_child(f"seq:step:{i+1}")
)
if i == 0:
input = await step.ainvoke(input, config, **kwargs)
else:
input = await step.ainvoke(input, config)
# finish the root run
except BaseException as e:
await run_manager.on_chain_error(e)
@ -2519,7 +2523,7 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
if i not in failed_inputs_map
],
return_exceptions=return_exceptions,
**kwargs,
**(kwargs if stepidx == 0 else {}),
)
# If an input failed, add it to the map
for i, inp in zip(remaining_idxs, inputs):
@ -2549,6 +2553,8 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
)
for rm, config in zip(run_managers, configs)
],
return_exceptions=return_exceptions,
**(kwargs if i == 0 else {}),
)
# finish the root runs
@ -2646,7 +2652,7 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
if i not in failed_inputs_map
],
return_exceptions=return_exceptions,
**kwargs,
**(kwargs if stepidx == 0 else {}),
)
# If an input failed, add it to the map
for i, inp in zip(remaining_idxs, inputs):
@ -2676,6 +2682,8 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
)
for rm, config in zip(run_managers, configs)
],
return_exceptions=return_exceptions,
**(kwargs if i == 0 else {}),
)
# finish the root runs
except BaseException as e:
@ -2704,6 +2712,7 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
input: Iterator[Input],
run_manager: CallbackManagerForChainRun,
config: RunnableConfig,
**kwargs: Any,
) -> Iterator[Output]:
from langchain_core.beta.runnables.context import config_with_context
@ -2714,14 +2723,14 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
# steps that don't natively support transforming an input stream will
# buffer input in memory until all available, and then start emitting output
final_pipeline = cast(Iterator[Output], input)
for step in steps:
final_pipeline = step.transform(
final_pipeline,
patch_config(
config,
callbacks=run_manager.get_child(f"seq:step:{steps.index(step)+1}"),
),
for idx, step in enumerate(steps):
config = patch_config(
config, callbacks=run_manager.get_child(f"seq:step:{idx+1}")
)
if idx == 0:
final_pipeline = step.transform(final_pipeline, config, **kwargs)
else:
final_pipeline = step.transform(final_pipeline, config)
for output in final_pipeline:
yield output
@ -2731,6 +2740,7 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
input: AsyncIterator[Input],
run_manager: AsyncCallbackManagerForChainRun,
config: RunnableConfig,
**kwargs: Any,
) -> AsyncIterator[Output]:
from langchain_core.beta.runnables.context import aconfig_with_context
@ -2742,14 +2752,15 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
# steps that don't natively support transforming an input stream will
# buffer input in memory until all available, and then start emitting output
final_pipeline = cast(AsyncIterator[Output], input)
for step in steps:
final_pipeline = step.atransform(
final_pipeline,
patch_config(
config,
callbacks=run_manager.get_child(f"seq:step:{steps.index(step)+1}"),
),
for idx, step in enumerate(steps):
config = patch_config(
config,
callbacks=run_manager.get_child(f"seq:step:{idx+1}"),
)
if idx == 0:
final_pipeline = step.atransform(final_pipeline, config, **kwargs)
else:
final_pipeline = step.atransform(final_pipeline, config)
async for output in final_pipeline:
yield output

View File

@ -182,6 +182,7 @@ class FakeRunnable(Runnable[str, int]):
self,
input: str,
config: Optional[RunnableConfig] = None,
**kwargs: Any,
) -> int:
return len(input)
@ -1409,26 +1410,136 @@ async def test_passthrough_tap_async(mocker: MockerFixture) -> None:
fake = FakeRunnable()
mock = mocker.Mock()
seq: Runnable = fake | RunnablePassthrough(mock)
seq: Runnable = RunnablePassthrough(mock) | fake | RunnablePassthrough(mock)
assert await seq.ainvoke("hello") == 5
assert mock.call_args_list == [mocker.call(5)]
assert await seq.ainvoke("hello", my_kwarg="value") == 5
assert mock.call_args_list == [
mocker.call("hello", my_kwarg="value"),
mocker.call(5),
]
mock.reset_mock()
assert await seq.abatch(["hello", "byebye"], my_kwarg="value") == [5, 6]
assert len(mock.call_args_list) == 4
for call in [
mocker.call("hello", my_kwarg="value"),
mocker.call("byebye", my_kwarg="value"),
mocker.call(5),
mocker.call(6),
]:
assert call in mock.call_args_list
mock.reset_mock()
assert await seq.abatch(
["hello", "byebye"], my_kwarg="value", return_exceptions=True
) == [
5,
6,
]
assert len(mock.call_args_list) == 4
for call in [
mocker.call("hello", my_kwarg="value"),
mocker.call("byebye", my_kwarg="value"),
mocker.call(5),
mocker.call(6),
]:
assert call in mock.call_args_list
mock.reset_mock()
assert sorted(
[
a
async for a in seq.abatch_as_completed(
["hello", "byebye"], my_kwarg="value", return_exceptions=True
)
]
) == [
(0, 5),
(1, 6),
]
assert len(mock.call_args_list) == 4
for call in [
mocker.call("hello", my_kwarg="value"),
mocker.call("byebye", my_kwarg="value"),
mocker.call(5),
mocker.call(6),
]:
assert call in mock.call_args_list
mock.reset_mock()
assert [
part async for part in seq.astream("hello", dict(metadata={"key": "value"}))
part
async for part in seq.astream(
"hello", dict(metadata={"key": "value"}), my_kwarg="value"
)
] == [5]
assert mock.call_args_list == [mocker.call(5)]
mock.reset_mock()
assert seq.invoke("hello") == 5
assert mock.call_args_list == [mocker.call(5)]
mock.reset_mock()
assert [part for part in seq.stream("hello", dict(metadata={"key": "value"}))] == [
5
assert mock.call_args_list == [
mocker.call("hello", my_kwarg="value"),
mocker.call(5),
]
mock.reset_mock()
assert seq.invoke("hello", my_kwarg="value") == 5 # type: ignore[call-arg]
assert mock.call_args_list == [
mocker.call("hello", my_kwarg="value"),
mocker.call(5),
]
mock.reset_mock()
assert seq.batch(["hello", "byebye"], my_kwarg="value") == [5, 6]
assert len(mock.call_args_list) == 4
for call in [
mocker.call("hello", my_kwarg="value"),
mocker.call("byebye", my_kwarg="value"),
mocker.call(5),
mocker.call(6),
]:
assert call in mock.call_args_list
mock.reset_mock()
assert seq.batch(["hello", "byebye"], my_kwarg="value", return_exceptions=True) == [
5,
6,
]
assert len(mock.call_args_list) == 4
for call in [
mocker.call("hello", my_kwarg="value"),
mocker.call("byebye", my_kwarg="value"),
mocker.call(5),
mocker.call(6),
]:
assert call in mock.call_args_list
mock.reset_mock()
assert sorted(
a
for a in seq.batch_as_completed(
["hello", "byebye"], my_kwarg="value", return_exceptions=True
)
) == [
(0, 5),
(1, 6),
]
assert len(mock.call_args_list) == 4
for call in [
mocker.call("hello", my_kwarg="value"),
mocker.call("byebye", my_kwarg="value"),
mocker.call(5),
mocker.call(6),
]:
assert call in mock.call_args_list
mock.reset_mock()
assert [
part
for part in seq.stream(
"hello", dict(metadata={"key": "value"}), my_kwarg="value"
)
] == [5]
assert mock.call_args_list == [
mocker.call("hello", my_kwarg="value"),
mocker.call(5),
]
assert mock.call_args_list == [mocker.call(5)]
mock.reset_mock()