Use a shared executor for all parallel calls

This commit is contained in:
Nuno Campos 2023-08-18 11:34:18 +01:00
parent a40c12bb88
commit d414d47c78
4 changed files with 138 additions and 92 deletions

View File

@@ -7,11 +7,12 @@ from langchain.schema.runnable.base import (
RunnableSequence,
RunnableWithFallbacks,
)
from langchain.schema.runnable.config import RunnableConfig
from langchain.schema.runnable.config import RunnableConfig, patch_config
from langchain.schema.runnable.passthrough import RunnablePassthrough
from langchain.schema.runnable.router import RouterInput, RouterRunnable
__all__ = [
"patch_config",
"GetLocalVar",
"PutLocalVar",
"RouterInput",

View File

@@ -1,11 +1,9 @@
from __future__ import annotations
import asyncio
import copy
import threading
from abc import ABC, abstractmethod
from concurrent.futures import FIRST_COMPLETED, ThreadPoolExecutor, wait
from copy import deepcopy
from concurrent.futures import FIRST_COMPLETED, wait
from functools import partial
from itertools import tee
from typing import (
@@ -43,6 +41,7 @@ from langchain.schema.runnable.config import (
ensure_config,
get_async_callback_manager_for_config,
get_callback_manager_for_config,
get_executor_for_config,
patch_config,
)
from langchain.schema.runnable.utils import (
@@ -104,8 +103,6 @@ class Runnable(Generic[Input, Output], ABC):
self,
inputs: List[Input],
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
*,
max_concurrency: Optional[int] = None,
**kwargs: Optional[Any],
) -> List[Output]:
"""
@@ -118,15 +115,19 @@
if len(inputs) == 1:
return [self.invoke(inputs[0], configs[0], **kwargs)]
with ThreadPoolExecutor(max_workers=max_concurrency) as executor:
return list(executor.map(partial(self.invoke, **kwargs), inputs, configs))
with get_executor_for_config(configs[0]) as executor:
return list(
executor.map(
partial(self.invoke, **kwargs),
inputs,
(patch_config(c, executor=executor) for c in configs),
)
)
async def abatch(
self,
inputs: List[Input],
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
*,
max_concurrency: Optional[int] = None,
**kwargs: Optional[Any],
) -> List[Output]:
"""
@@ -136,7 +137,7 @@
configs = self._get_config_list(config, len(inputs))
coros = map(partial(self.ainvoke, **kwargs), inputs, configs)
return await gather_with_concurrency(max_concurrency, *coros)
return await gather_with_concurrency(configs[0].get("max_concurrency"), *coros)
def stream(
self,
@@ -246,7 +247,7 @@
return (
list(map(ensure_config, config))
if isinstance(config, list)
else [deepcopy(ensure_config(config)) for _ in range(length)]
else [patch_config(config, deep_copy_locals=True) for _ in range(length)]
)
def _call_with_config(
@@ -527,7 +528,7 @@ class RunnableWithFallbacks(Serializable, Runnable[Input, Output]):
try:
output = runnable.invoke(
input,
patch_config(config, run_manager.get_child()),
patch_config(config, callbacks=run_manager.get_child()),
)
except self.exceptions_to_handle as e:
if first_error is None:
@@ -560,7 +561,7 @@ class RunnableWithFallbacks(Serializable, Runnable[Input, Output]):
try:
output = await runnable.ainvoke(
input,
patch_config(config, run_manager.get_child()),
patch_config(config, callbacks=run_manager.get_child()),
)
except self.exceptions_to_handle as e:
if first_error is None:
@@ -580,8 +581,6 @@ class RunnableWithFallbacks(Serializable, Runnable[Input, Output]):
self,
inputs: List[Input],
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
*,
max_concurrency: Optional[int] = None,
**kwargs: Optional[Any],
) -> List[Output]:
from langchain.callbacks.manager import CallbackManager
@@ -615,10 +614,9 @@ class RunnableWithFallbacks(Serializable, Runnable[Input, Output]):
inputs,
[
# each step a child run of the corresponding root run
patch_config(config, rm.get_child())
patch_config(config, callbacks=rm.get_child())
for rm, config in zip(run_managers, configs)
],
max_concurrency=max_concurrency,
)
except self.exceptions_to_handle as e:
if first_error is None:
@@ -641,14 +639,9 @@ class RunnableWithFallbacks(Serializable, Runnable[Input, Output]):
self,
inputs: List[Input],
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
*,
max_concurrency: Optional[int] = None,
**kwargs: Optional[Any],
) -> List[Output]:
from langchain.callbacks.manager import (
AsyncCallbackManager,
AsyncCallbackManagerForChainRun,
)
from langchain.callbacks.manager import AsyncCallbackManager
# setup callbacks
configs = self._get_config_list(config, len(inputs))
@@ -679,10 +672,9 @@ class RunnableWithFallbacks(Serializable, Runnable[Input, Output]):
inputs,
[
# each step a child run of the corresponding root run
patch_config(config, rm.get_child())
patch_config(config, callbacks=rm.get_child())
for rm, config in zip(run_managers, configs)
],
max_concurrency=max_concurrency,
)
except self.exceptions_to_handle as e:
if first_error is None:
@@ -782,7 +774,7 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
input = step.invoke(
input,
# mark each step as a child run
patch_config(config, run_manager.get_child()),
patch_config(config, callbacks=run_manager.get_child()),
)
# finish the root run
except (KeyboardInterrupt, Exception) as e:
@@ -810,7 +802,7 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
input = await step.ainvoke(
input,
# mark each step as a child run
patch_config(config, run_manager.get_child()),
patch_config(config, callbacks=run_manager.get_child()),
)
# finish the root run
except (KeyboardInterrupt, Exception) as e:
@@ -824,8 +816,6 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
self,
inputs: List[Input],
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
*,
max_concurrency: Optional[int] = None,
**kwargs: Optional[Any],
) -> List[Output]:
from langchain.callbacks.manager import CallbackManager
@@ -852,15 +842,17 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
# invoke
try:
with get_executor_for_config(configs[0]) as executor:
for step in self.steps:
inputs = step.batch(
inputs,
[
# each step a child run of the corresponding root run
patch_config(config, rm.get_child())
patch_config(
config, callbacks=rm.get_child(), executor=executor
)
for rm, config in zip(run_managers, configs)
],
max_concurrency=max_concurrency,
)
# finish the root runs
except (KeyboardInterrupt, Exception) as e:
@@ -876,8 +868,6 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
self,
inputs: List[Input],
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
*,
max_concurrency: Optional[int] = None,
**kwargs: Optional[Any],
) -> List[Output]:
from langchain.callbacks.manager import (
@@ -914,10 +904,9 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
inputs,
[
# each step a child run of the corresponding root run
patch_config(config, rm.get_child())
patch_config(config, callbacks=rm.get_child())
for rm, config in zip(run_managers, configs)
],
max_concurrency=max_concurrency,
)
# finish the root runs
except (KeyboardInterrupt, Exception) as e:
@@ -956,7 +945,7 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
input = step.invoke(
input,
# mark each step as a child run
patch_config(config, run_manager.get_child()),
patch_config(config, callbacks=run_manager.get_child()),
)
except (KeyboardInterrupt, Exception) as e:
run_manager.on_chain_error(e)
@@ -968,12 +957,13 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
try:
# stream the first of the last steps with non-streaming input
final_pipeline = steps[streaming_start_index].stream(
input, patch_config(config, run_manager.get_child())
input, patch_config(config, callbacks=run_manager.get_child())
)
# stream the rest of the last steps with streaming input
for step in steps[streaming_start_index + 1 :]:
final_pipeline = step.transform(
final_pipeline, patch_config(config, run_manager.get_child())
final_pipeline,
patch_config(config, callbacks=run_manager.get_child()),
)
for output in final_pipeline:
yield output
@@ -1022,7 +1012,7 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
input = await step.ainvoke(
input,
# mark each step as a child run
patch_config(config, run_manager.get_child()),
patch_config(config, callbacks=run_manager.get_child()),
)
except (KeyboardInterrupt, Exception) as e:
await run_manager.on_chain_error(e)
@@ -1034,12 +1024,13 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
try:
# stream the first of the last steps with non-streaming input
final_pipeline = steps[streaming_start_index].astream(
input, patch_config(config, run_manager.get_child())
input, patch_config(config, callbacks=run_manager.get_child())
)
# stream the rest of the last steps with streaming input
for step in steps[streaming_start_index + 1 :]:
final_pipeline = step.atransform(
final_pipeline, patch_config(config, run_manager.get_child())
final_pipeline,
patch_config(config, callbacks=run_manager.get_child()),
)
async for output in final_pipeline:
yield output
@@ -1068,7 +1059,7 @@ class RunnableMapChunk(Dict[str, Any]):
"""
def __add__(self, other: RunnableMapChunk) -> RunnableMapChunk:
chunk = copy.deepcopy(self)
chunk = self.copy()
for key in other:
if key not in chunk or chunk[key] is None:
chunk[key] = other[key]
@@ -1076,6 +1067,15 @@ class RunnableMapChunk(Dict[str, Any]):
chunk[key] += other[key]
return chunk
def __radd__(self, other: RunnableMapChunk) -> RunnableMapChunk:
chunk = RunnableMapChunk(other)
for key in self:
if key not in chunk or chunk[key] is None:
chunk[key] = self[key]
elif self[key] is not None:
chunk[key] += self[key]
return chunk
class RunnableMap(Serializable, Runnable[Input, Dict[str, Any]]):
"""
@@ -1132,13 +1132,18 @@ class RunnableMap(Serializable, Runnable[Input, Dict[str, Any]]):
try:
# copy to avoid issues from the caller mutating the steps during invoke()
steps = dict(self.steps)
with ThreadPoolExecutor() as executor:
with get_executor_for_config(config) as executor:
futures = [
executor.submit(
step.invoke,
input,
# mark each step as a child run
patch_config(deepcopy(config), run_manager.get_child()),
patch_config(
config,
deep_copy_locals=True,
callbacks=run_manager.get_child(),
executor=executor,
),
)
for step in steps.values()
]
@@ -1172,7 +1177,7 @@ class RunnableMap(Serializable, Runnable[Input, Dict[str, Any]]):
step.ainvoke(
input,
# mark each step as a child run
patch_config(config, run_manager.get_child()),
patch_config(config, callbacks=run_manager.get_child()),
)
for step in steps.values()
)
@@ -1197,14 +1202,16 @@ class RunnableMap(Serializable, Runnable[Input, Dict[str, Any]]):
# Each step gets a copy of the input iterator,
# which is consumed in parallel in a separate thread.
input_copies = list(safetee(input, len(steps), lock=threading.Lock()))
with ThreadPoolExecutor() as executor:
with get_executor_for_config(config) as executor:
# Create the transform() generator for each step
named_generators = [
(
name,
step.transform(
input_copies.pop(),
patch_config(config, run_manager.get_child()),
patch_config(
config, callbacks=run_manager.get_child(), executor=executor
),
),
)
for name, step in steps.items()
@@ -1265,7 +1272,8 @@ class RunnableMap(Serializable, Runnable[Input, Dict[str, Any]]):
(
name,
step.atransform(
input_copies.pop(), patch_config(config, run_manager.get_child())
input_copies.pop(),
patch_config(config, callbacks=run_manager.get_child()),
),
)
for name, step in steps.items()
@@ -1393,25 +1401,17 @@ class RunnableBinding(Serializable, Runnable[Input, Output]):
self,
inputs: List[Input],
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
*,
max_concurrency: Optional[int] = None,
**kwargs: Optional[Any],
) -> List[Output]:
return self.bound.batch(
inputs, config, max_concurrency=max_concurrency, **{**self.kwargs, **kwargs}
)
return self.bound.batch(inputs, config, **{**self.kwargs, **kwargs})
async def abatch(
self,
inputs: List[Input],
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
*,
max_concurrency: Optional[int] = None,
**kwargs: Optional[Any],
) -> List[Output]:
return await self.bound.abatch(
inputs, config, max_concurrency=max_concurrency, **{**self.kwargs, **kwargs}
)
return await self.bound.abatch(inputs, config, **{**self.kwargs, **kwargs})
def stream(
self,

View File

@@ -1,6 +1,9 @@
from __future__ import annotations
from concurrent.futures import Executor, ThreadPoolExecutor
from contextlib import contextmanager
from copy import deepcopy
from typing import Any, Dict, List, Optional, TypedDict
from typing import Any, Dict, Generator, List, Optional, TypedDict
from langchain.callbacks.base import BaseCallbackManager, Callbacks
from langchain.callbacks.manager import AsyncCallbackManager, CallbackManager
@@ -32,20 +35,44 @@ class RunnableConfig(TypedDict, total=False):
Local variables
"""
max_concurrency: Optional[int]
"""
Maximum number of parallel calls to make. If not provided, defaults to ThreadPoolExecutor's default. This is ignored if an executor is provided.
"""
executor: Executor
"""
Externally-managed executor to use for parallel calls. If not provided, a new
ThreadPoolExecutor will be created.
"""
def ensure_config(config: Optional[RunnableConfig]) -> RunnableConfig:
empty = RunnableConfig(tags=[], metadata={}, callbacks=None, _locals={})
empty = RunnableConfig(
tags=[],
metadata={},
callbacks=None,
_locals={},
)
if config is not None:
empty.update(config)
return empty
def patch_config(
config: RunnableConfig,
callbacks: BaseCallbackManager,
config: Optional[RunnableConfig],
*,
deep_copy_locals: bool = False,
callbacks: Optional[BaseCallbackManager] = None,
executor: Optional[Executor] = None,
) -> RunnableConfig:
config = config.copy()
config = ensure_config(config)
if deep_copy_locals:
config["_locals"] = deepcopy(config["_locals"])
if callbacks is not None:
config["callbacks"] = callbacks
if executor is not None:
config["executor"] = executor
return config
@@ -65,3 +92,12 @@ def get_async_callback_manager_for_config(
inheritable_tags=config.get("tags"),
inheritable_metadata=config.get("metadata"),
)
@contextmanager
def get_executor_for_config(config: RunnableConfig) -> Generator[Executor, None, None]:
if config.get("executor"):
yield config["executor"]
else:
with ThreadPoolExecutor(max_workers=config.get("max_concurrency")) as executor:
yield executor

View File

@@ -131,26 +131,25 @@ async def test_default_method_implementations(mocker: MockerFixture) -> None:
assert fake.batch(
["hello", "wooorld"], [dict(tags=["a-tag"]), dict(metadata={"key": "value"})]
) == [5, 7]
assert spy.call_args_list == [
mocker.call(
"hello", dict(tags=["a-tag"], metadata={}, callbacks=None, _locals={})
),
mocker.call(
"wooorld",
dict(metadata={"key": "value"}, tags=[], callbacks=None, _locals={}),
),
]
assert len(spy.call_args_list) == 2
for i, call in enumerate(spy.call_args_list):
assert call.args[0] == ("hello" if i == 0 else "wooorld")
if i == 0:
assert call.args[1].get("tags") == ["a-tag"]
assert call.args[1].get("metadata") == {}
else:
assert call.args[1].get("tags") == []
assert call.args[1].get("metadata") == {"key": "value"}
spy.reset_mock()
assert fake.batch(["hello", "wooorld"], dict(tags=["a-tag"])) == [5, 7]
assert spy.call_args_list == [
mocker.call(
"hello", dict(tags=["a-tag"], metadata={}, callbacks=None, _locals={})
),
mocker.call(
"wooorld", dict(tags=["a-tag"], metadata={}, callbacks=None, _locals={})
),
]
assert len(spy.call_args_list) == 2
for i, call in enumerate(spy.call_args_list):
assert call.args[0] == ("hello" if i == 0 else "wooorld")
assert call.args[1].get("tags") == ["a-tag"]
assert call.args[1].get("metadata") == {}
spy.reset_mock()
assert await fake.ainvoke("hello", config={"callbacks": []}) == 5
@@ -172,11 +171,21 @@ async def test_default_method_implementations(mocker: MockerFixture) -> None:
assert spy.call_args_list == [
mocker.call(
"hello",
dict(metadata={"key": "value"}, tags=[], callbacks=None, _locals={}),
dict(
metadata={"key": "value"},
tags=[],
callbacks=None,
_locals={},
),
),
mocker.call(
"wooorld",
dict(metadata={"key": "value"}, tags=[], callbacks=None, _locals={}),
dict(
metadata={"key": "value"},
tags=[],
callbacks=None,
_locals={},
),
),
]