mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-27 00:48:45 +00:00
core: In RunnableSequence pass kwargs to the first step (#22393)
- This is a pattern that shows up occasionally in langgraph questions, people chain a graph to something else after, and want to pass the graph some kwargs (eg. stream_mode)
This commit is contained in:
parent
eabcfaa3d6
commit
ed8e9c437a
@ -2379,7 +2379,9 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
|
||||
name=self.name,
|
||||
)
|
||||
|
||||
def invoke(self, input: Input, config: Optional[RunnableConfig] = None) -> Output:
|
||||
def invoke(
|
||||
self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any
|
||||
) -> Output:
|
||||
from langchain_core.beta.runnables.context import config_with_context
|
||||
|
||||
# setup callbacks and context
|
||||
@ -2396,13 +2398,14 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
|
||||
# invoke all steps in sequence
|
||||
try:
|
||||
for i, step in enumerate(self.steps):
|
||||
input = step.invoke(
|
||||
input,
|
||||
# mark each step as a child run
|
||||
patch_config(
|
||||
config, callbacks=run_manager.get_child(f"seq:step:{i+1}")
|
||||
),
|
||||
# mark each step as a child run
|
||||
config = patch_config(
|
||||
config, callbacks=run_manager.get_child(f"seq:step:{i+1}")
|
||||
)
|
||||
if i == 0:
|
||||
input = step.invoke(input, config, **kwargs)
|
||||
else:
|
||||
input = step.invoke(input, config)
|
||||
# finish the root run
|
||||
except BaseException as e:
|
||||
run_manager.on_chain_error(e)
|
||||
@ -2433,13 +2436,14 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
|
||||
# invoke all steps in sequence
|
||||
try:
|
||||
for i, step in enumerate(self.steps):
|
||||
input = await step.ainvoke(
|
||||
input,
|
||||
# mark each step as a child run
|
||||
patch_config(
|
||||
config, callbacks=run_manager.get_child(f"seq:step:{i+1}")
|
||||
),
|
||||
# mark each step as a child run
|
||||
config = patch_config(
|
||||
config, callbacks=run_manager.get_child(f"seq:step:{i+1}")
|
||||
)
|
||||
if i == 0:
|
||||
input = await step.ainvoke(input, config, **kwargs)
|
||||
else:
|
||||
input = await step.ainvoke(input, config)
|
||||
# finish the root run
|
||||
except BaseException as e:
|
||||
await run_manager.on_chain_error(e)
|
||||
@ -2519,7 +2523,7 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
|
||||
if i not in failed_inputs_map
|
||||
],
|
||||
return_exceptions=return_exceptions,
|
||||
**kwargs,
|
||||
**(kwargs if stepidx == 0 else {}),
|
||||
)
|
||||
# If an input failed, add it to the map
|
||||
for i, inp in zip(remaining_idxs, inputs):
|
||||
@ -2549,6 +2553,8 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
|
||||
)
|
||||
for rm, config in zip(run_managers, configs)
|
||||
],
|
||||
return_exceptions=return_exceptions,
|
||||
**(kwargs if i == 0 else {}),
|
||||
)
|
||||
|
||||
# finish the root runs
|
||||
@ -2646,7 +2652,7 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
|
||||
if i not in failed_inputs_map
|
||||
],
|
||||
return_exceptions=return_exceptions,
|
||||
**kwargs,
|
||||
**(kwargs if stepidx == 0 else {}),
|
||||
)
|
||||
# If an input failed, add it to the map
|
||||
for i, inp in zip(remaining_idxs, inputs):
|
||||
@ -2676,6 +2682,8 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
|
||||
)
|
||||
for rm, config in zip(run_managers, configs)
|
||||
],
|
||||
return_exceptions=return_exceptions,
|
||||
**(kwargs if i == 0 else {}),
|
||||
)
|
||||
# finish the root runs
|
||||
except BaseException as e:
|
||||
@ -2704,6 +2712,7 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
|
||||
input: Iterator[Input],
|
||||
run_manager: CallbackManagerForChainRun,
|
||||
config: RunnableConfig,
|
||||
**kwargs: Any,
|
||||
) -> Iterator[Output]:
|
||||
from langchain_core.beta.runnables.context import config_with_context
|
||||
|
||||
@ -2714,14 +2723,14 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
|
||||
# steps that don't natively support transforming an input stream will
|
||||
# buffer input in memory until all available, and then start emitting output
|
||||
final_pipeline = cast(Iterator[Output], input)
|
||||
for step in steps:
|
||||
final_pipeline = step.transform(
|
||||
final_pipeline,
|
||||
patch_config(
|
||||
config,
|
||||
callbacks=run_manager.get_child(f"seq:step:{steps.index(step)+1}"),
|
||||
),
|
||||
for idx, step in enumerate(steps):
|
||||
config = patch_config(
|
||||
config, callbacks=run_manager.get_child(f"seq:step:{idx+1}")
|
||||
)
|
||||
if idx == 0:
|
||||
final_pipeline = step.transform(final_pipeline, config, **kwargs)
|
||||
else:
|
||||
final_pipeline = step.transform(final_pipeline, config)
|
||||
|
||||
for output in final_pipeline:
|
||||
yield output
|
||||
@ -2731,6 +2740,7 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
|
||||
input: AsyncIterator[Input],
|
||||
run_manager: AsyncCallbackManagerForChainRun,
|
||||
config: RunnableConfig,
|
||||
**kwargs: Any,
|
||||
) -> AsyncIterator[Output]:
|
||||
from langchain_core.beta.runnables.context import aconfig_with_context
|
||||
|
||||
@ -2742,14 +2752,15 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
|
||||
# steps that don't natively support transforming an input stream will
|
||||
# buffer input in memory until all available, and then start emitting output
|
||||
final_pipeline = cast(AsyncIterator[Output], input)
|
||||
for step in steps:
|
||||
final_pipeline = step.atransform(
|
||||
final_pipeline,
|
||||
patch_config(
|
||||
config,
|
||||
callbacks=run_manager.get_child(f"seq:step:{steps.index(step)+1}"),
|
||||
),
|
||||
for idx, step in enumerate(steps):
|
||||
config = patch_config(
|
||||
config,
|
||||
callbacks=run_manager.get_child(f"seq:step:{idx+1}"),
|
||||
)
|
||||
if idx == 0:
|
||||
final_pipeline = step.atransform(final_pipeline, config, **kwargs)
|
||||
else:
|
||||
final_pipeline = step.atransform(final_pipeline, config)
|
||||
async for output in final_pipeline:
|
||||
yield output
|
||||
|
||||
|
@ -182,6 +182,7 @@ class FakeRunnable(Runnable[str, int]):
|
||||
self,
|
||||
input: str,
|
||||
config: Optional[RunnableConfig] = None,
|
||||
**kwargs: Any,
|
||||
) -> int:
|
||||
return len(input)
|
||||
|
||||
@ -1409,26 +1410,136 @@ async def test_passthrough_tap_async(mocker: MockerFixture) -> None:
|
||||
fake = FakeRunnable()
|
||||
mock = mocker.Mock()
|
||||
|
||||
seq: Runnable = fake | RunnablePassthrough(mock)
|
||||
seq: Runnable = RunnablePassthrough(mock) | fake | RunnablePassthrough(mock)
|
||||
|
||||
assert await seq.ainvoke("hello") == 5
|
||||
assert mock.call_args_list == [mocker.call(5)]
|
||||
assert await seq.ainvoke("hello", my_kwarg="value") == 5
|
||||
assert mock.call_args_list == [
|
||||
mocker.call("hello", my_kwarg="value"),
|
||||
mocker.call(5),
|
||||
]
|
||||
mock.reset_mock()
|
||||
|
||||
assert await seq.abatch(["hello", "byebye"], my_kwarg="value") == [5, 6]
|
||||
assert len(mock.call_args_list) == 4
|
||||
for call in [
|
||||
mocker.call("hello", my_kwarg="value"),
|
||||
mocker.call("byebye", my_kwarg="value"),
|
||||
mocker.call(5),
|
||||
mocker.call(6),
|
||||
]:
|
||||
assert call in mock.call_args_list
|
||||
mock.reset_mock()
|
||||
|
||||
assert await seq.abatch(
|
||||
["hello", "byebye"], my_kwarg="value", return_exceptions=True
|
||||
) == [
|
||||
5,
|
||||
6,
|
||||
]
|
||||
assert len(mock.call_args_list) == 4
|
||||
for call in [
|
||||
mocker.call("hello", my_kwarg="value"),
|
||||
mocker.call("byebye", my_kwarg="value"),
|
||||
mocker.call(5),
|
||||
mocker.call(6),
|
||||
]:
|
||||
assert call in mock.call_args_list
|
||||
mock.reset_mock()
|
||||
|
||||
assert sorted(
|
||||
[
|
||||
a
|
||||
async for a in seq.abatch_as_completed(
|
||||
["hello", "byebye"], my_kwarg="value", return_exceptions=True
|
||||
)
|
||||
]
|
||||
) == [
|
||||
(0, 5),
|
||||
(1, 6),
|
||||
]
|
||||
assert len(mock.call_args_list) == 4
|
||||
for call in [
|
||||
mocker.call("hello", my_kwarg="value"),
|
||||
mocker.call("byebye", my_kwarg="value"),
|
||||
mocker.call(5),
|
||||
mocker.call(6),
|
||||
]:
|
||||
assert call in mock.call_args_list
|
||||
mock.reset_mock()
|
||||
|
||||
assert [
|
||||
part async for part in seq.astream("hello", dict(metadata={"key": "value"}))
|
||||
part
|
||||
async for part in seq.astream(
|
||||
"hello", dict(metadata={"key": "value"}), my_kwarg="value"
|
||||
)
|
||||
] == [5]
|
||||
assert mock.call_args_list == [mocker.call(5)]
|
||||
mock.reset_mock()
|
||||
|
||||
assert seq.invoke("hello") == 5
|
||||
assert mock.call_args_list == [mocker.call(5)]
|
||||
mock.reset_mock()
|
||||
|
||||
assert [part for part in seq.stream("hello", dict(metadata={"key": "value"}))] == [
|
||||
5
|
||||
assert mock.call_args_list == [
|
||||
mocker.call("hello", my_kwarg="value"),
|
||||
mocker.call(5),
|
||||
]
|
||||
mock.reset_mock()
|
||||
|
||||
assert seq.invoke("hello", my_kwarg="value") == 5 # type: ignore[call-arg]
|
||||
assert mock.call_args_list == [
|
||||
mocker.call("hello", my_kwarg="value"),
|
||||
mocker.call(5),
|
||||
]
|
||||
mock.reset_mock()
|
||||
|
||||
assert seq.batch(["hello", "byebye"], my_kwarg="value") == [5, 6]
|
||||
assert len(mock.call_args_list) == 4
|
||||
for call in [
|
||||
mocker.call("hello", my_kwarg="value"),
|
||||
mocker.call("byebye", my_kwarg="value"),
|
||||
mocker.call(5),
|
||||
mocker.call(6),
|
||||
]:
|
||||
assert call in mock.call_args_list
|
||||
mock.reset_mock()
|
||||
|
||||
assert seq.batch(["hello", "byebye"], my_kwarg="value", return_exceptions=True) == [
|
||||
5,
|
||||
6,
|
||||
]
|
||||
assert len(mock.call_args_list) == 4
|
||||
for call in [
|
||||
mocker.call("hello", my_kwarg="value"),
|
||||
mocker.call("byebye", my_kwarg="value"),
|
||||
mocker.call(5),
|
||||
mocker.call(6),
|
||||
]:
|
||||
assert call in mock.call_args_list
|
||||
mock.reset_mock()
|
||||
|
||||
assert sorted(
|
||||
a
|
||||
for a in seq.batch_as_completed(
|
||||
["hello", "byebye"], my_kwarg="value", return_exceptions=True
|
||||
)
|
||||
) == [
|
||||
(0, 5),
|
||||
(1, 6),
|
||||
]
|
||||
assert len(mock.call_args_list) == 4
|
||||
for call in [
|
||||
mocker.call("hello", my_kwarg="value"),
|
||||
mocker.call("byebye", my_kwarg="value"),
|
||||
mocker.call(5),
|
||||
mocker.call(6),
|
||||
]:
|
||||
assert call in mock.call_args_list
|
||||
mock.reset_mock()
|
||||
|
||||
assert [
|
||||
part
|
||||
for part in seq.stream(
|
||||
"hello", dict(metadata={"key": "value"}), my_kwarg="value"
|
||||
)
|
||||
] == [5]
|
||||
assert mock.call_args_list == [
|
||||
mocker.call("hello", my_kwarg="value"),
|
||||
mocker.call(5),
|
||||
]
|
||||
assert mock.call_args_list == [mocker.call(5)]
|
||||
mock.reset_mock()
|
||||
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user