mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-05 14:43:08 +00:00
Bug fixes for runnables (#10738)
- tools invoked in async methods would not work due to missing await - RunnableSequence.stream() was creating an extra root run by mistake, and it can simplified due to existence of default implementation for .transform() <!-- Thank you for contributing to LangChain! Replace this entire comment with: - **Description:** a description of the change, - **Issue:** the issue # it fixes (if applicable), - **Dependencies:** any dependencies required for this change, - **Tag maintainer:** for a quicker response, tag the relevant maintainer (see below), - **Twitter handle:** we announce bigger features on Twitter. If your PR gets announced, and you'd like a mention, we'll gladly shout you out! Please make sure your PR is passing linting and testing before submitting. Run `make format`, `make lint` and `make test` to check this locally. See contribution guidelines for more information on how to write/run tests, lint, etc: https://github.com/hwchase17/langchain/blob/master/.github/CONTRIBUTING.md If you're adding a new integration, please include: 1. a test for the integration, preferably unit tests that do not rely on network access, 2. an example notebook showing its use. It lives in `docs/extras` directory. If no one reviews your PR within a few days, please @-mention one of @baskaryan, @eyurtsev, @hwchase17. -->
This commit is contained in:
parent
6e48092746
commit
8201cae770
@ -1269,98 +1269,23 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
|
|||||||
run_manager: CallbackManagerForChainRun,
|
run_manager: CallbackManagerForChainRun,
|
||||||
config: RunnableConfig,
|
config: RunnableConfig,
|
||||||
) -> Iterator[Output]:
|
) -> Iterator[Output]:
|
||||||
# setup callbacks
|
|
||||||
config = ensure_config(config)
|
|
||||||
callback_manager = get_callback_manager_for_config(config)
|
|
||||||
# start the root run
|
|
||||||
run_manager = callback_manager.on_chain_start(
|
|
||||||
dumpd(self), input, name=config.get("run_name")
|
|
||||||
)
|
|
||||||
|
|
||||||
steps = [self.first] + self.middle + [self.last]
|
steps = [self.first] + self.middle + [self.last]
|
||||||
streaming_start_index = 0
|
|
||||||
|
|
||||||
for i in range(len(steps) - 1, 0, -1):
|
# transform the input stream of each step with the next
|
||||||
if type(steps[i]).transform != Runnable.transform:
|
# steps that don't natively support transforming an input stream will
|
||||||
streaming_start_index = i - 1
|
# buffer input in memory until all available, and then start emitting output
|
||||||
else:
|
final_pipeline = cast(Iterator[Output], input)
|
||||||
break
|
for step in steps:
|
||||||
|
|
||||||
final_pipeline = None
|
|
||||||
gathered_input = None
|
|
||||||
if streaming_start_index == 0:
|
|
||||||
final_pipeline = steps[streaming_start_index].transform(
|
|
||||||
input,
|
|
||||||
patch_config(config, callbacks=run_manager.get_child("seq:step:1")),
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
try:
|
|
||||||
for input_chunk in input:
|
|
||||||
if gathered_input is None:
|
|
||||||
gathered_input = input_chunk
|
|
||||||
else:
|
|
||||||
gathered_input += input_chunk
|
|
||||||
# invoke the first steps
|
|
||||||
for step in steps[0:streaming_start_index]:
|
|
||||||
gathered_input = step.invoke(
|
|
||||||
gathered_input,
|
|
||||||
# mark each step as a child run
|
|
||||||
patch_config(
|
|
||||||
config,
|
|
||||||
callbacks=run_manager.get_child(
|
|
||||||
f"seq:step:{steps.index(step)+1}"
|
|
||||||
),
|
|
||||||
),
|
|
||||||
)
|
|
||||||
# stream the first of the last steps with the final non-streaming input
|
|
||||||
final_pipeline = steps[streaming_start_index].stream(
|
|
||||||
gathered_input,
|
|
||||||
patch_config(
|
|
||||||
config,
|
|
||||||
callbacks=run_manager.get_child(
|
|
||||||
f"seq:step:{streaming_start_index+1}"
|
|
||||||
),
|
|
||||||
),
|
|
||||||
)
|
|
||||||
except (KeyboardInterrupt, Exception) as e:
|
|
||||||
run_manager.on_chain_error(e)
|
|
||||||
raise
|
|
||||||
|
|
||||||
# stream the last steps
|
|
||||||
final: Union[Output, None] = None
|
|
||||||
final_supported = True
|
|
||||||
try:
|
|
||||||
# stream the rest of the last steps with streaming input
|
|
||||||
for step in steps[streaming_start_index + 1 :]:
|
|
||||||
final_pipeline = step.transform(
|
final_pipeline = step.transform(
|
||||||
final_pipeline,
|
final_pipeline,
|
||||||
patch_config(
|
patch_config(
|
||||||
config,
|
config,
|
||||||
callbacks=run_manager.get_child(
|
callbacks=run_manager.get_child(f"seq:step:{steps.index(step)+1}"),
|
||||||
f"seq:step:{steps.index(step)+1}"
|
|
||||||
),
|
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
for output in final_pipeline:
|
for output in final_pipeline:
|
||||||
yield output
|
yield output
|
||||||
# Accumulate output if possible, otherwise disable accumulation
|
|
||||||
if final_supported:
|
|
||||||
if final is None:
|
|
||||||
final = output
|
|
||||||
else:
|
|
||||||
try:
|
|
||||||
final += output # type: ignore[operator]
|
|
||||||
except TypeError:
|
|
||||||
final = None
|
|
||||||
final_supported = False
|
|
||||||
pass
|
|
||||||
# finish the root run
|
|
||||||
except (KeyboardInterrupt, Exception) as e:
|
|
||||||
run_manager.on_chain_error(e)
|
|
||||||
raise
|
|
||||||
else:
|
|
||||||
run_manager.on_chain_end(final)
|
|
||||||
|
|
||||||
async def _atransform(
|
async def _atransform(
|
||||||
self,
|
self,
|
||||||
@ -1368,97 +1293,23 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
|
|||||||
run_manager: AsyncCallbackManagerForChainRun,
|
run_manager: AsyncCallbackManagerForChainRun,
|
||||||
config: RunnableConfig,
|
config: RunnableConfig,
|
||||||
) -> AsyncIterator[Output]:
|
) -> AsyncIterator[Output]:
|
||||||
# setup callbacks
|
|
||||||
config = ensure_config(config)
|
|
||||||
callback_manager = get_async_callback_manager_for_config(config)
|
|
||||||
# start the root run
|
|
||||||
run_manager = await callback_manager.on_chain_start(
|
|
||||||
dumpd(self), input, name=config.get("run_name")
|
|
||||||
)
|
|
||||||
|
|
||||||
steps = [self.first] + self.middle + [self.last]
|
steps = [self.first] + self.middle + [self.last]
|
||||||
streaming_start_index = len(steps) - 1
|
|
||||||
|
|
||||||
for i in range(len(steps) - 1, 0, -1):
|
|
||||||
if type(steps[i]).atransform != Runnable.atransform:
|
|
||||||
streaming_start_index = i - 1
|
|
||||||
else:
|
|
||||||
break
|
|
||||||
|
|
||||||
final_pipeline = None
|
|
||||||
gathered_input = None
|
|
||||||
if streaming_start_index == 0:
|
|
||||||
final_pipeline = steps[0].atransform(
|
|
||||||
input,
|
|
||||||
patch_config(config, callbacks=run_manager.get_child("seq:step:1")),
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
try:
|
|
||||||
async for input_chunk in input:
|
|
||||||
if gathered_input is None:
|
|
||||||
gathered_input = input_chunk
|
|
||||||
else:
|
|
||||||
gathered_input += input_chunk
|
|
||||||
# invoke the first steps
|
|
||||||
for step in steps[0:streaming_start_index]:
|
|
||||||
gathered_input = await step.ainvoke(
|
|
||||||
gathered_input,
|
|
||||||
# mark each step as a child run
|
|
||||||
patch_config(
|
|
||||||
config,
|
|
||||||
callbacks=run_manager.get_child(
|
|
||||||
f"seq:step:{steps.index(step)+1}"
|
|
||||||
),
|
|
||||||
),
|
|
||||||
)
|
|
||||||
# stream the first of the last steps with the final non-streaming input
|
|
||||||
final_pipeline = steps[streaming_start_index].astream(
|
|
||||||
gathered_input,
|
|
||||||
patch_config(
|
|
||||||
config,
|
|
||||||
callbacks=run_manager.get_child(
|
|
||||||
f"seq:step:{streaming_start_index+1}"
|
|
||||||
),
|
|
||||||
),
|
|
||||||
)
|
|
||||||
except (KeyboardInterrupt, Exception) as e:
|
|
||||||
await run_manager.on_chain_error(e)
|
|
||||||
raise
|
|
||||||
|
|
||||||
# stream the last steps
|
# stream the last steps
|
||||||
final: Union[Output, None] = None
|
# transform the input stream of each step with the next
|
||||||
final_supported = True
|
# steps that don't natively support transforming an input stream will
|
||||||
try:
|
# buffer input in memory until all available, and then start emitting output
|
||||||
# stream the rest of the last steps with streaming input
|
final_pipeline = cast(AsyncIterator[Output], input)
|
||||||
for step in steps[streaming_start_index + 1 :]:
|
for step in steps:
|
||||||
final_pipeline = step.atransform(
|
final_pipeline = step.atransform(
|
||||||
final_pipeline,
|
final_pipeline,
|
||||||
patch_config(
|
patch_config(
|
||||||
config,
|
config,
|
||||||
callbacks=run_manager.get_child(
|
callbacks=run_manager.get_child(f"seq:step:{steps.index(step)+1}"),
|
||||||
f"seq:step:{steps.index(step)+1}"
|
|
||||||
),
|
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
async for output in final_pipeline:
|
async for output in final_pipeline:
|
||||||
yield output
|
yield output
|
||||||
# Accumulate output if possible, otherwise disable accumulation
|
|
||||||
if final_supported:
|
|
||||||
if final is None:
|
|
||||||
final = output
|
|
||||||
else:
|
|
||||||
try:
|
|
||||||
final += output # type: ignore[operator]
|
|
||||||
except TypeError:
|
|
||||||
final = None
|
|
||||||
final_supported = False
|
|
||||||
pass
|
|
||||||
# finish the root run
|
|
||||||
except (KeyboardInterrupt, Exception) as e:
|
|
||||||
await run_manager.on_chain_error(e)
|
|
||||||
raise
|
|
||||||
else:
|
|
||||||
await run_manager.on_chain_end(final)
|
|
||||||
|
|
||||||
def transform(
|
def transform(
|
||||||
self,
|
self,
|
||||||
|
@ -210,7 +210,7 @@ class ChildTool(BaseTool):
|
|||||||
) -> Any:
|
) -> Any:
|
||||||
if type(self)._arun == BaseTool._arun:
|
if type(self)._arun == BaseTool._arun:
|
||||||
# If the tool does not implement async, fall back to default implementation
|
# If the tool does not implement async, fall back to default implementation
|
||||||
return super().ainvoke(input, config, **kwargs)
|
return await super().ainvoke(input, config, **kwargs)
|
||||||
|
|
||||||
config = config or {}
|
config = config or {}
|
||||||
return await self.arun(
|
return await self.arun(
|
||||||
@ -461,7 +461,7 @@ class Tool(BaseTool):
|
|||||||
None, partial(self.invoke, input, config, **kwargs)
|
None, partial(self.invoke, input, config, **kwargs)
|
||||||
)
|
)
|
||||||
|
|
||||||
return super().ainvoke(input, config, **kwargs)
|
return await super().ainvoke(input, config, **kwargs)
|
||||||
|
|
||||||
# --- Tool ---
|
# --- Tool ---
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user