Runnables: Use a shared executor for all parallel calls (sync) (#9443)

Async equivalent coming in future PR

<!-- Thank you for contributing to LangChain!

Replace this entire comment with:
  - Description: a description of the change, 
  - Issue: the issue # it fixes (if applicable),
  - Dependencies: any dependencies required for this change,
- Tag maintainer: for a quicker response, tag the relevant maintainer
(see below),
- Twitter handle: we announce bigger features on Twitter. If your PR
gets announced and you'd like a mention, we'll gladly shout you out!

Please make sure your PR is passing linting and testing before
submitting. Run `make format`, `make lint` and `make test` to check this
locally.

See contribution guidelines for more information on how to write/run
tests, lint, etc:

https://github.com/hwchase17/langchain/blob/master/.github/CONTRIBUTING.md

If you're adding a new integration, please include:
1. a test for the integration, preferably unit tests that do not rely on
network access,
2. an example notebook showing its use. These live in the docs/extras
directory.

If no one reviews your PR within a few days, please @-mention one of
@baskaryan, @eyurtsev, @hwchase17, @rlancemartin.
 -->
This commit is contained in:
Nuno Campos 2023-08-23 19:47:35 +01:00 committed by GitHub
commit dacd5dcba8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 129 additions and 90 deletions

View File

@ -7,11 +7,12 @@ from langchain.schema.runnable.base import (
RunnableSequence,
RunnableWithFallbacks,
)
from langchain.schema.runnable.config import RunnableConfig
from langchain.schema.runnable.config import RunnableConfig, patch_config
from langchain.schema.runnable.passthrough import RunnablePassthrough
from langchain.schema.runnable.router import RouterInput, RouterRunnable
__all__ = [
"patch_config",
"GetLocalVar",
"PutLocalVar",
"RouterInput",

View File

@ -1,11 +1,9 @@
from __future__ import annotations
import asyncio
import copy
import threading
from abc import ABC, abstractmethod
from concurrent.futures import FIRST_COMPLETED, ThreadPoolExecutor, wait
from copy import deepcopy
from concurrent.futures import FIRST_COMPLETED, wait
from functools import partial
from itertools import tee
from typing import (
@ -43,6 +41,7 @@ from langchain.schema.runnable.config import (
ensure_config,
get_async_callback_manager_for_config,
get_callback_manager_for_config,
get_executor_for_config,
patch_config,
)
from langchain.schema.runnable.utils import (
@ -104,8 +103,6 @@ class Runnable(Generic[Input, Output], ABC):
self,
inputs: List[Input],
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
*,
max_concurrency: Optional[int] = None,
**kwargs: Optional[Any],
) -> List[Output]:
"""
@ -118,15 +115,19 @@ class Runnable(Generic[Input, Output], ABC):
if len(inputs) == 1:
return [self.invoke(inputs[0], configs[0], **kwargs)]
with ThreadPoolExecutor(max_workers=max_concurrency) as executor:
return list(executor.map(partial(self.invoke, **kwargs), inputs, configs))
with get_executor_for_config(configs[0]) as executor:
return list(
executor.map(
partial(self.invoke, **kwargs),
inputs,
(patch_config(c, executor=executor) for c in configs),
)
)
async def abatch(
self,
inputs: List[Input],
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
*,
max_concurrency: Optional[int] = None,
**kwargs: Optional[Any],
) -> List[Output]:
"""
@ -136,7 +137,7 @@ class Runnable(Generic[Input, Output], ABC):
configs = self._get_config_list(config, len(inputs))
coros = map(partial(self.ainvoke, **kwargs), inputs, configs)
return await gather_with_concurrency(max_concurrency, *coros)
return await gather_with_concurrency(configs[0].get("max_concurrency"), *coros)
def stream(
self,
@ -237,6 +238,8 @@ class Runnable(Generic[Input, Output], ABC):
Helper method to get a list of configs from a single config or a list of
configs, useful for subclasses overriding batch() or abatch().
"""
if length < 1:
raise ValueError(f"length must be >= 1, but got {length}")
if isinstance(config, list) and len(config) != length:
raise ValueError(
f"config must be a list of the same length as inputs, "
@ -246,7 +249,7 @@ class Runnable(Generic[Input, Output], ABC):
return (
list(map(ensure_config, config))
if isinstance(config, list)
else [deepcopy(ensure_config(config)) for _ in range(length)]
else [patch_config(config, deep_copy_locals=True) for _ in range(length)]
)
def _call_with_config(
@ -527,7 +530,7 @@ class RunnableWithFallbacks(Serializable, Runnable[Input, Output]):
try:
output = runnable.invoke(
input,
patch_config(config, run_manager.get_child()),
patch_config(config, callbacks=run_manager.get_child()),
)
except self.exceptions_to_handle as e:
if first_error is None:
@ -560,7 +563,7 @@ class RunnableWithFallbacks(Serializable, Runnable[Input, Output]):
try:
output = await runnable.ainvoke(
input,
patch_config(config, run_manager.get_child()),
patch_config(config, callbacks=run_manager.get_child()),
)
except self.exceptions_to_handle as e:
if first_error is None:
@ -580,8 +583,6 @@ class RunnableWithFallbacks(Serializable, Runnable[Input, Output]):
self,
inputs: List[Input],
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
*,
max_concurrency: Optional[int] = None,
**kwargs: Optional[Any],
) -> List[Output]:
from langchain.callbacks.manager import CallbackManager
@ -615,10 +616,9 @@ class RunnableWithFallbacks(Serializable, Runnable[Input, Output]):
inputs,
[
# each step a child run of the corresponding root run
patch_config(config, rm.get_child())
patch_config(config, callbacks=rm.get_child())
for rm, config in zip(run_managers, configs)
],
max_concurrency=max_concurrency,
)
except self.exceptions_to_handle as e:
if first_error is None:
@ -641,14 +641,9 @@ class RunnableWithFallbacks(Serializable, Runnable[Input, Output]):
self,
inputs: List[Input],
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
*,
max_concurrency: Optional[int] = None,
**kwargs: Optional[Any],
) -> List[Output]:
from langchain.callbacks.manager import (
AsyncCallbackManager,
AsyncCallbackManagerForChainRun,
)
from langchain.callbacks.manager import AsyncCallbackManager
# setup callbacks
configs = self._get_config_list(config, len(inputs))
@ -679,10 +674,9 @@ class RunnableWithFallbacks(Serializable, Runnable[Input, Output]):
inputs,
[
# each step a child run of the corresponding root run
patch_config(config, rm.get_child())
patch_config(config, callbacks=rm.get_child())
for rm, config in zip(run_managers, configs)
],
max_concurrency=max_concurrency,
)
except self.exceptions_to_handle as e:
if first_error is None:
@ -782,7 +776,7 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
input = step.invoke(
input,
# mark each step as a child run
patch_config(config, run_manager.get_child()),
patch_config(config, callbacks=run_manager.get_child()),
)
# finish the root run
except (KeyboardInterrupt, Exception) as e:
@ -810,7 +804,7 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
input = await step.ainvoke(
input,
# mark each step as a child run
patch_config(config, run_manager.get_child()),
patch_config(config, callbacks=run_manager.get_child()),
)
# finish the root run
except (KeyboardInterrupt, Exception) as e:
@ -824,8 +818,6 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
self,
inputs: List[Input],
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
*,
max_concurrency: Optional[int] = None,
**kwargs: Optional[Any],
) -> List[Output]:
from langchain.callbacks.manager import CallbackManager
@ -852,15 +844,17 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
# invoke
try:
with get_executor_for_config(configs[0]) as executor:
for step in self.steps:
inputs = step.batch(
inputs,
[
# each step a child run of the corresponding root run
patch_config(config, rm.get_child())
patch_config(
config, callbacks=rm.get_child(), executor=executor
)
for rm, config in zip(run_managers, configs)
],
max_concurrency=max_concurrency,
)
# finish the root runs
except (KeyboardInterrupt, Exception) as e:
@ -876,8 +870,6 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
self,
inputs: List[Input],
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
*,
max_concurrency: Optional[int] = None,
**kwargs: Optional[Any],
) -> List[Output]:
from langchain.callbacks.manager import (
@ -914,10 +906,9 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
inputs,
[
# each step a child run of the corresponding root run
patch_config(config, rm.get_child())
patch_config(config, callbacks=rm.get_child())
for rm, config in zip(run_managers, configs)
],
max_concurrency=max_concurrency,
)
# finish the root runs
except (KeyboardInterrupt, Exception) as e:
@ -956,7 +947,7 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
input = step.invoke(
input,
# mark each step as a child run
patch_config(config, run_manager.get_child()),
patch_config(config, callbacks=run_manager.get_child()),
)
except (KeyboardInterrupt, Exception) as e:
run_manager.on_chain_error(e)
@ -968,12 +959,13 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
try:
# stream the first of the last steps with non-streaming input
final_pipeline = steps[streaming_start_index].stream(
input, patch_config(config, run_manager.get_child())
input, patch_config(config, callbacks=run_manager.get_child())
)
# stream the rest of the last steps with streaming input
for step in steps[streaming_start_index + 1 :]:
final_pipeline = step.transform(
final_pipeline, patch_config(config, run_manager.get_child())
final_pipeline,
patch_config(config, callbacks=run_manager.get_child()),
)
for output in final_pipeline:
yield output
@ -1022,7 +1014,7 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
input = await step.ainvoke(
input,
# mark each step as a child run
patch_config(config, run_manager.get_child()),
patch_config(config, callbacks=run_manager.get_child()),
)
except (KeyboardInterrupt, Exception) as e:
await run_manager.on_chain_error(e)
@ -1034,12 +1026,13 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
try:
# stream the first of the last steps with non-streaming input
final_pipeline = steps[streaming_start_index].astream(
input, patch_config(config, run_manager.get_child())
input, patch_config(config, callbacks=run_manager.get_child())
)
# stream the rest of the last steps with streaming input
for step in steps[streaming_start_index + 1 :]:
final_pipeline = step.atransform(
final_pipeline, patch_config(config, run_manager.get_child())
final_pipeline,
patch_config(config, callbacks=run_manager.get_child()),
)
async for output in final_pipeline:
yield output
@ -1068,7 +1061,7 @@ class RunnableMapChunk(Dict[str, Any]):
"""
def __add__(self, other: RunnableMapChunk) -> RunnableMapChunk:
chunk = copy.deepcopy(self)
chunk = RunnableMapChunk(self)
for key in other:
if key not in chunk or chunk[key] is None:
chunk[key] = other[key]
@ -1076,6 +1069,15 @@ class RunnableMapChunk(Dict[str, Any]):
chunk[key] += other[key]
return chunk
def __radd__(self, other: RunnableMapChunk) -> RunnableMapChunk:
    """Reflected add: merge *self* into a copy of *other*.

    Mirrors ``__add__`` so that ``other + self`` works when *other* is a
    plain dict (or another chunk) on the left-hand side: keys missing from
    *other* (or mapped to ``None``) are taken from *self*; keys present in
    both are combined with ``+=`` on the values.
    """
    # Shallow-copy *other* into a new chunk so neither operand is mutated.
    chunk = RunnableMapChunk(other)
    for key in self:
        if key not in chunk or chunk[key] is None:
            # Left side has nothing for this key yet — take ours as-is.
            chunk[key] = self[key]
        elif self[key] is not None:
            # Both sides present and non-None: accumulate with +=.
            # NOTE(review): assumes the values support ``+`` — confirm callers
            # only store addable types (e.g. strings / message chunks).
            chunk[key] += self[key]
    return chunk
class RunnableMap(Serializable, Runnable[Input, Dict[str, Any]]):
"""
@ -1132,13 +1134,18 @@ class RunnableMap(Serializable, Runnable[Input, Dict[str, Any]]):
try:
# copy to avoid issues from the caller mutating the steps during invoke()
steps = dict(self.steps)
with ThreadPoolExecutor() as executor:
with get_executor_for_config(config) as executor:
futures = [
executor.submit(
step.invoke,
input,
# mark each step as a child run
patch_config(deepcopy(config), run_manager.get_child()),
patch_config(
config,
deep_copy_locals=True,
callbacks=run_manager.get_child(),
executor=executor,
),
)
for step in steps.values()
]
@ -1172,7 +1179,7 @@ class RunnableMap(Serializable, Runnable[Input, Dict[str, Any]]):
step.ainvoke(
input,
# mark each step as a child run
patch_config(config, run_manager.get_child()),
patch_config(config, callbacks=run_manager.get_child()),
)
for step in steps.values()
)
@ -1197,14 +1204,16 @@ class RunnableMap(Serializable, Runnable[Input, Dict[str, Any]]):
# Each step gets a copy of the input iterator,
# which is consumed in parallel in a separate thread.
input_copies = list(safetee(input, len(steps), lock=threading.Lock()))
with ThreadPoolExecutor() as executor:
with get_executor_for_config(config) as executor:
# Create the transform() generator for each step
named_generators = [
(
name,
step.transform(
input_copies.pop(),
patch_config(config, run_manager.get_child()),
patch_config(
config, callbacks=run_manager.get_child(), executor=executor
),
),
)
for name, step in steps.items()
@ -1265,7 +1274,8 @@ class RunnableMap(Serializable, Runnable[Input, Dict[str, Any]]):
(
name,
step.atransform(
input_copies.pop(), patch_config(config, run_manager.get_child())
input_copies.pop(),
patch_config(config, callbacks=run_manager.get_child()),
),
)
for name, step in steps.items()
@ -1393,25 +1403,17 @@ class RunnableBinding(Serializable, Runnable[Input, Output]):
self,
inputs: List[Input],
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
*,
max_concurrency: Optional[int] = None,
**kwargs: Optional[Any],
) -> List[Output]:
return self.bound.batch(
inputs, config, max_concurrency=max_concurrency, **{**self.kwargs, **kwargs}
)
return self.bound.batch(inputs, config, **{**self.kwargs, **kwargs})
async def abatch(
self,
inputs: List[Input],
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
*,
max_concurrency: Optional[int] = None,
**kwargs: Optional[Any],
) -> List[Output]:
return await self.bound.abatch(
inputs, config, max_concurrency=max_concurrency, **{**self.kwargs, **kwargs}
)
return await self.bound.abatch(inputs, config, **{**self.kwargs, **kwargs})
def stream(
self,

View File

@ -1,6 +1,9 @@
from __future__ import annotations
from typing import Any, Dict, List, Optional, TypedDict
from concurrent.futures import Executor, ThreadPoolExecutor
from contextlib import contextmanager
from copy import deepcopy
from typing import Any, Dict, Generator, List, Optional, TypedDict
from langchain.callbacks.base import BaseCallbackManager, Callbacks
from langchain.callbacks.manager import AsyncCallbackManager, CallbackManager
@ -32,20 +35,45 @@ class RunnableConfig(TypedDict, total=False):
Local variables
"""
max_concurrency: Optional[int]
"""
Maximum number of parallel calls to make. If not provided, defaults to
ThreadPoolExecutor's default. This is ignored if an executor is provided.
"""
executor: Executor
"""
Externally-managed executor to use for parallel calls. If not provided, a new
ThreadPoolExecutor will be created.
"""
def ensure_config(config: Optional[RunnableConfig]) -> RunnableConfig:
empty = RunnableConfig(tags=[], metadata={}, callbacks=None, _locals={})
empty = RunnableConfig(
tags=[],
metadata={},
callbacks=None,
_locals={},
)
if config is not None:
empty.update(config)
return empty
def patch_config(
config: RunnableConfig,
callbacks: BaseCallbackManager,
config: Optional[RunnableConfig],
*,
deep_copy_locals: bool = False,
callbacks: Optional[BaseCallbackManager] = None,
executor: Optional[Executor] = None,
) -> RunnableConfig:
config = config.copy()
config = ensure_config(config)
if deep_copy_locals:
config["_locals"] = deepcopy(config["_locals"])
if callbacks is not None:
config["callbacks"] = callbacks
if executor is not None:
config["executor"] = executor
return config
@ -65,3 +93,12 @@ def get_async_callback_manager_for_config(
inheritable_tags=config.get("tags"),
inheritable_metadata=config.get("metadata"),
)
@contextmanager
def get_executor_for_config(config: RunnableConfig) -> Generator[Executor, None, None]:
    """Yield the executor to use for parallel calls under *config*.

    If the config carries an externally managed ``executor``, yield it
    unchanged (its lifecycle is the caller's responsibility). Otherwise
    create a fresh ``ThreadPoolExecutor`` bounded by the config's
    ``max_concurrency`` (``None`` means the pool's default) and shut it
    down when the context exits.
    """
    existing = config.get("executor")
    if existing:
        yield existing
    else:
        with ThreadPoolExecutor(max_workers=config.get("max_concurrency")) as pool:
            yield pool

View File

@ -131,26 +131,25 @@ async def test_default_method_implementations(mocker: MockerFixture) -> None:
assert fake.batch(
["hello", "wooorld"], [dict(tags=["a-tag"]), dict(metadata={"key": "value"})]
) == [5, 7]
assert spy.call_args_list == [
mocker.call(
"hello", dict(tags=["a-tag"], metadata={}, callbacks=None, _locals={})
),
mocker.call(
"wooorld",
dict(metadata={"key": "value"}, tags=[], callbacks=None, _locals={}),
),
]
assert len(spy.call_args_list) == 2
for i, call in enumerate(spy.call_args_list):
assert call.args[0] == ("hello" if i == 0 else "wooorld")
if i == 0:
assert call.args[1].get("tags") == ["a-tag"]
assert call.args[1].get("metadata") == {}
else:
assert call.args[1].get("tags") == []
assert call.args[1].get("metadata") == {"key": "value"}
spy.reset_mock()
assert fake.batch(["hello", "wooorld"], dict(tags=["a-tag"])) == [5, 7]
assert spy.call_args_list == [
mocker.call(
"hello", dict(tags=["a-tag"], metadata={}, callbacks=None, _locals={})
),
mocker.call(
"wooorld", dict(tags=["a-tag"], metadata={}, callbacks=None, _locals={})
),
]
assert len(spy.call_args_list) == 2
for i, call in enumerate(spy.call_args_list):
assert call.args[0] == ("hello" if i == 0 else "wooorld")
assert call.args[1].get("tags") == ["a-tag"]
assert call.args[1].get("metadata") == {}
spy.reset_mock()
assert await fake.ainvoke("hello", config={"callbacks": []}) == 5