Use a shared executor for all parallel calls

This commit is contained in:
Nuno Campos 2023-08-18 11:34:18 +01:00
parent a40c12bb88
commit d414d47c78
4 changed files with 138 additions and 92 deletions

View File

@ -7,11 +7,12 @@ from langchain.schema.runnable.base import (
RunnableSequence, RunnableSequence,
RunnableWithFallbacks, RunnableWithFallbacks,
) )
from langchain.schema.runnable.config import RunnableConfig from langchain.schema.runnable.config import RunnableConfig, patch_config
from langchain.schema.runnable.passthrough import RunnablePassthrough from langchain.schema.runnable.passthrough import RunnablePassthrough
from langchain.schema.runnable.router import RouterInput, RouterRunnable from langchain.schema.runnable.router import RouterInput, RouterRunnable
__all__ = [ __all__ = [
"patch_config",
"GetLocalVar", "GetLocalVar",
"PutLocalVar", "PutLocalVar",
"RouterInput", "RouterInput",

View File

@ -1,11 +1,9 @@
from __future__ import annotations from __future__ import annotations
import asyncio import asyncio
import copy
import threading import threading
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from concurrent.futures import FIRST_COMPLETED, ThreadPoolExecutor, wait from concurrent.futures import FIRST_COMPLETED, wait
from copy import deepcopy
from functools import partial from functools import partial
from itertools import tee from itertools import tee
from typing import ( from typing import (
@ -43,6 +41,7 @@ from langchain.schema.runnable.config import (
ensure_config, ensure_config,
get_async_callback_manager_for_config, get_async_callback_manager_for_config,
get_callback_manager_for_config, get_callback_manager_for_config,
get_executor_for_config,
patch_config, patch_config,
) )
from langchain.schema.runnable.utils import ( from langchain.schema.runnable.utils import (
@ -104,8 +103,6 @@ class Runnable(Generic[Input, Output], ABC):
self, self,
inputs: List[Input], inputs: List[Input],
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None, config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
*,
max_concurrency: Optional[int] = None,
**kwargs: Optional[Any], **kwargs: Optional[Any],
) -> List[Output]: ) -> List[Output]:
""" """
@ -118,15 +115,19 @@ class Runnable(Generic[Input, Output], ABC):
if len(inputs) == 1: if len(inputs) == 1:
return [self.invoke(inputs[0], configs[0], **kwargs)] return [self.invoke(inputs[0], configs[0], **kwargs)]
with ThreadPoolExecutor(max_workers=max_concurrency) as executor: with get_executor_for_config(configs[0]) as executor:
return list(executor.map(partial(self.invoke, **kwargs), inputs, configs)) return list(
executor.map(
partial(self.invoke, **kwargs),
inputs,
(patch_config(c, executor=executor) for c in configs),
)
)
async def abatch( async def abatch(
self, self,
inputs: List[Input], inputs: List[Input],
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None, config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
*,
max_concurrency: Optional[int] = None,
**kwargs: Optional[Any], **kwargs: Optional[Any],
) -> List[Output]: ) -> List[Output]:
""" """
@ -136,7 +137,7 @@ class Runnable(Generic[Input, Output], ABC):
configs = self._get_config_list(config, len(inputs)) configs = self._get_config_list(config, len(inputs))
coros = map(partial(self.ainvoke, **kwargs), inputs, configs) coros = map(partial(self.ainvoke, **kwargs), inputs, configs)
return await gather_with_concurrency(max_concurrency, *coros) return await gather_with_concurrency(configs[0].get("max_concurrency"), *coros)
def stream( def stream(
self, self,
@ -246,7 +247,7 @@ class Runnable(Generic[Input, Output], ABC):
return ( return (
list(map(ensure_config, config)) list(map(ensure_config, config))
if isinstance(config, list) if isinstance(config, list)
else [deepcopy(ensure_config(config)) for _ in range(length)] else [patch_config(config, deep_copy_locals=True) for _ in range(length)]
) )
def _call_with_config( def _call_with_config(
@ -527,7 +528,7 @@ class RunnableWithFallbacks(Serializable, Runnable[Input, Output]):
try: try:
output = runnable.invoke( output = runnable.invoke(
input, input,
patch_config(config, run_manager.get_child()), patch_config(config, callbacks=run_manager.get_child()),
) )
except self.exceptions_to_handle as e: except self.exceptions_to_handle as e:
if first_error is None: if first_error is None:
@ -560,7 +561,7 @@ class RunnableWithFallbacks(Serializable, Runnable[Input, Output]):
try: try:
output = await runnable.ainvoke( output = await runnable.ainvoke(
input, input,
patch_config(config, run_manager.get_child()), patch_config(config, callbacks=run_manager.get_child()),
) )
except self.exceptions_to_handle as e: except self.exceptions_to_handle as e:
if first_error is None: if first_error is None:
@ -580,8 +581,6 @@ class RunnableWithFallbacks(Serializable, Runnable[Input, Output]):
self, self,
inputs: List[Input], inputs: List[Input],
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None, config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
*,
max_concurrency: Optional[int] = None,
**kwargs: Optional[Any], **kwargs: Optional[Any],
) -> List[Output]: ) -> List[Output]:
from langchain.callbacks.manager import CallbackManager from langchain.callbacks.manager import CallbackManager
@ -615,10 +614,9 @@ class RunnableWithFallbacks(Serializable, Runnable[Input, Output]):
inputs, inputs,
[ [
# each step a child run of the corresponding root run # each step a child run of the corresponding root run
patch_config(config, rm.get_child()) patch_config(config, callbacks=rm.get_child())
for rm, config in zip(run_managers, configs) for rm, config in zip(run_managers, configs)
], ],
max_concurrency=max_concurrency,
) )
except self.exceptions_to_handle as e: except self.exceptions_to_handle as e:
if first_error is None: if first_error is None:
@ -641,14 +639,9 @@ class RunnableWithFallbacks(Serializable, Runnable[Input, Output]):
self, self,
inputs: List[Input], inputs: List[Input],
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None, config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
*,
max_concurrency: Optional[int] = None,
**kwargs: Optional[Any], **kwargs: Optional[Any],
) -> List[Output]: ) -> List[Output]:
from langchain.callbacks.manager import ( from langchain.callbacks.manager import AsyncCallbackManager
AsyncCallbackManager,
AsyncCallbackManagerForChainRun,
)
# setup callbacks # setup callbacks
configs = self._get_config_list(config, len(inputs)) configs = self._get_config_list(config, len(inputs))
@ -679,10 +672,9 @@ class RunnableWithFallbacks(Serializable, Runnable[Input, Output]):
inputs, inputs,
[ [
# each step a child run of the corresponding root run # each step a child run of the corresponding root run
patch_config(config, rm.get_child()) patch_config(config, callbacks=rm.get_child())
for rm, config in zip(run_managers, configs) for rm, config in zip(run_managers, configs)
], ],
max_concurrency=max_concurrency,
) )
except self.exceptions_to_handle as e: except self.exceptions_to_handle as e:
if first_error is None: if first_error is None:
@ -782,7 +774,7 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
input = step.invoke( input = step.invoke(
input, input,
# mark each step as a child run # mark each step as a child run
patch_config(config, run_manager.get_child()), patch_config(config, callbacks=run_manager.get_child()),
) )
# finish the root run # finish the root run
except (KeyboardInterrupt, Exception) as e: except (KeyboardInterrupt, Exception) as e:
@ -810,7 +802,7 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
input = await step.ainvoke( input = await step.ainvoke(
input, input,
# mark each step as a child run # mark each step as a child run
patch_config(config, run_manager.get_child()), patch_config(config, callbacks=run_manager.get_child()),
) )
# finish the root run # finish the root run
except (KeyboardInterrupt, Exception) as e: except (KeyboardInterrupt, Exception) as e:
@ -824,8 +816,6 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
self, self,
inputs: List[Input], inputs: List[Input],
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None, config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
*,
max_concurrency: Optional[int] = None,
**kwargs: Optional[Any], **kwargs: Optional[Any],
) -> List[Output]: ) -> List[Output]:
from langchain.callbacks.manager import CallbackManager from langchain.callbacks.manager import CallbackManager
@ -852,16 +842,18 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
# invoke # invoke
try: try:
for step in self.steps: with get_executor_for_config(configs[0]) as executor:
inputs = step.batch( for step in self.steps:
inputs, inputs = step.batch(
[ inputs,
# each step a child run of the corresponding root run [
patch_config(config, rm.get_child()) # each step a child run of the corresponding root run
for rm, config in zip(run_managers, configs) patch_config(
], config, callbacks=rm.get_child(), executor=executor
max_concurrency=max_concurrency, )
) for rm, config in zip(run_managers, configs)
],
)
# finish the root runs # finish the root runs
except (KeyboardInterrupt, Exception) as e: except (KeyboardInterrupt, Exception) as e:
for rm in run_managers: for rm in run_managers:
@ -876,8 +868,6 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
self, self,
inputs: List[Input], inputs: List[Input],
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None, config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
*,
max_concurrency: Optional[int] = None,
**kwargs: Optional[Any], **kwargs: Optional[Any],
) -> List[Output]: ) -> List[Output]:
from langchain.callbacks.manager import ( from langchain.callbacks.manager import (
@ -914,10 +904,9 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
inputs, inputs,
[ [
# each step a child run of the corresponding root run # each step a child run of the corresponding root run
patch_config(config, rm.get_child()) patch_config(config, callbacks=rm.get_child())
for rm, config in zip(run_managers, configs) for rm, config in zip(run_managers, configs)
], ],
max_concurrency=max_concurrency,
) )
# finish the root runs # finish the root runs
except (KeyboardInterrupt, Exception) as e: except (KeyboardInterrupt, Exception) as e:
@ -956,7 +945,7 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
input = step.invoke( input = step.invoke(
input, input,
# mark each step as a child run # mark each step as a child run
patch_config(config, run_manager.get_child()), patch_config(config, callbacks=run_manager.get_child()),
) )
except (KeyboardInterrupt, Exception) as e: except (KeyboardInterrupt, Exception) as e:
run_manager.on_chain_error(e) run_manager.on_chain_error(e)
@ -968,12 +957,13 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
try: try:
# stream the first of the last steps with non-streaming input # stream the first of the last steps with non-streaming input
final_pipeline = steps[streaming_start_index].stream( final_pipeline = steps[streaming_start_index].stream(
input, patch_config(config, run_manager.get_child()) input, patch_config(config, callbacks=run_manager.get_child())
) )
# stream the rest of the last steps with streaming input # stream the rest of the last steps with streaming input
for step in steps[streaming_start_index + 1 :]: for step in steps[streaming_start_index + 1 :]:
final_pipeline = step.transform( final_pipeline = step.transform(
final_pipeline, patch_config(config, run_manager.get_child()) final_pipeline,
patch_config(config, callbacks=run_manager.get_child()),
) )
for output in final_pipeline: for output in final_pipeline:
yield output yield output
@ -1022,7 +1012,7 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
input = await step.ainvoke( input = await step.ainvoke(
input, input,
# mark each step as a child run # mark each step as a child run
patch_config(config, run_manager.get_child()), patch_config(config, callbacks=run_manager.get_child()),
) )
except (KeyboardInterrupt, Exception) as e: except (KeyboardInterrupt, Exception) as e:
await run_manager.on_chain_error(e) await run_manager.on_chain_error(e)
@ -1034,12 +1024,13 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
try: try:
# stream the first of the last steps with non-streaming input # stream the first of the last steps with non-streaming input
final_pipeline = steps[streaming_start_index].astream( final_pipeline = steps[streaming_start_index].astream(
input, patch_config(config, run_manager.get_child()) input, patch_config(config, callbacks=run_manager.get_child())
) )
# stream the rest of the last steps with streaming input # stream the rest of the last steps with streaming input
for step in steps[streaming_start_index + 1 :]: for step in steps[streaming_start_index + 1 :]:
final_pipeline = step.atransform( final_pipeline = step.atransform(
final_pipeline, patch_config(config, run_manager.get_child()) final_pipeline,
patch_config(config, callbacks=run_manager.get_child()),
) )
async for output in final_pipeline: async for output in final_pipeline:
yield output yield output
@ -1068,7 +1059,7 @@ class RunnableMapChunk(Dict[str, Any]):
""" """
def __add__(self, other: RunnableMapChunk) -> RunnableMapChunk: def __add__(self, other: RunnableMapChunk) -> RunnableMapChunk:
chunk = copy.deepcopy(self) chunk = self.copy()
for key in other: for key in other:
if key not in chunk or chunk[key] is None: if key not in chunk or chunk[key] is None:
chunk[key] = other[key] chunk[key] = other[key]
@ -1076,6 +1067,15 @@ class RunnableMapChunk(Dict[str, Any]):
chunk[key] += other[key] chunk[key] += other[key]
return chunk return chunk
def __radd__(self, other: RunnableMapChunk) -> RunnableMapChunk:
    """Return ``other + self`` as a new RunnableMapChunk.

    Invoked when the left operand does not handle the addition itself.
    Keys are merged as follows: a key that is missing from ``other`` (or
    whose value there is None) takes this chunk's value; a key present in
    both with non-None values is combined as ``other[key] + self[key]``.
    Neither operand is modified.
    """
    # Shallow copy: the values are still the very objects held by `other`.
    chunk = RunnableMapChunk(other)
    for key, value in self.items():
        if key not in chunk or chunk[key] is None:
            chunk[key] = value
        elif value is not None:
            # Use `+` rather than `+=`: in-place addition on a mutable value
            # (e.g. a list) would extend the object shared with `other` and
            # silently change the left operand.
            chunk[key] = chunk[key] + value
    return chunk
class RunnableMap(Serializable, Runnable[Input, Dict[str, Any]]): class RunnableMap(Serializable, Runnable[Input, Dict[str, Any]]):
""" """
@ -1132,13 +1132,18 @@ class RunnableMap(Serializable, Runnable[Input, Dict[str, Any]]):
try: try:
# copy to avoid issues from the caller mutating the steps during invoke() # copy to avoid issues from the caller mutating the steps during invoke()
steps = dict(self.steps) steps = dict(self.steps)
with ThreadPoolExecutor() as executor: with get_executor_for_config(config) as executor:
futures = [ futures = [
executor.submit( executor.submit(
step.invoke, step.invoke,
input, input,
# mark each step as a child run # mark each step as a child run
patch_config(deepcopy(config), run_manager.get_child()), patch_config(
config,
deep_copy_locals=True,
callbacks=run_manager.get_child(),
executor=executor,
),
) )
for step in steps.values() for step in steps.values()
] ]
@ -1172,7 +1177,7 @@ class RunnableMap(Serializable, Runnable[Input, Dict[str, Any]]):
step.ainvoke( step.ainvoke(
input, input,
# mark each step as a child run # mark each step as a child run
patch_config(config, run_manager.get_child()), patch_config(config, callbacks=run_manager.get_child()),
) )
for step in steps.values() for step in steps.values()
) )
@ -1197,14 +1202,16 @@ class RunnableMap(Serializable, Runnable[Input, Dict[str, Any]]):
# Each step gets a copy of the input iterator, # Each step gets a copy of the input iterator,
# which is consumed in parallel in a separate thread. # which is consumed in parallel in a separate thread.
input_copies = list(safetee(input, len(steps), lock=threading.Lock())) input_copies = list(safetee(input, len(steps), lock=threading.Lock()))
with ThreadPoolExecutor() as executor: with get_executor_for_config(config) as executor:
# Create the transform() generator for each step # Create the transform() generator for each step
named_generators = [ named_generators = [
( (
name, name,
step.transform( step.transform(
input_copies.pop(), input_copies.pop(),
patch_config(config, run_manager.get_child()), patch_config(
config, callbacks=run_manager.get_child(), executor=executor
),
), ),
) )
for name, step in steps.items() for name, step in steps.items()
@ -1265,7 +1272,8 @@ class RunnableMap(Serializable, Runnable[Input, Dict[str, Any]]):
( (
name, name,
step.atransform( step.atransform(
input_copies.pop(), patch_config(config, run_manager.get_child()) input_copies.pop(),
patch_config(config, callbacks=run_manager.get_child()),
), ),
) )
for name, step in steps.items() for name, step in steps.items()
@ -1393,25 +1401,17 @@ class RunnableBinding(Serializable, Runnable[Input, Output]):
self, self,
inputs: List[Input], inputs: List[Input],
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None, config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
*,
max_concurrency: Optional[int] = None,
**kwargs: Optional[Any], **kwargs: Optional[Any],
) -> List[Output]: ) -> List[Output]:
return self.bound.batch( return self.bound.batch(inputs, config, **{**self.kwargs, **kwargs})
inputs, config, max_concurrency=max_concurrency, **{**self.kwargs, **kwargs}
)
async def abatch( async def abatch(
self, self,
inputs: List[Input], inputs: List[Input],
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None, config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
*,
max_concurrency: Optional[int] = None,
**kwargs: Optional[Any], **kwargs: Optional[Any],
) -> List[Output]: ) -> List[Output]:
return await self.bound.abatch( return await self.bound.abatch(inputs, config, **{**self.kwargs, **kwargs})
inputs, config, max_concurrency=max_concurrency, **{**self.kwargs, **kwargs}
)
def stream( def stream(
self, self,

View File

@ -1,6 +1,9 @@
from __future__ import annotations from __future__ import annotations
from concurrent.futures import Executor, ThreadPoolExecutor
from contextlib import contextmanager
from copy import deepcopy
from typing import Any, Dict, List, Optional, TypedDict from typing import Any, Dict, Generator, List, Optional, TypedDict
from langchain.callbacks.base import BaseCallbackManager, Callbacks from langchain.callbacks.base import BaseCallbackManager, Callbacks
from langchain.callbacks.manager import AsyncCallbackManager, CallbackManager from langchain.callbacks.manager import AsyncCallbackManager, CallbackManager
@ -32,20 +35,44 @@ class RunnableConfig(TypedDict, total=False):
Local variables Local variables
""" """
max_concurrency: Optional[int]
"""
Maximum number of parallel calls to make. If not provided, defaults to ThreadPoolExecutor's default. This is ignored if an executor is provided.
"""
executor: Executor
"""
Externally-managed executor to use for parallel calls. If not provided, a new
ThreadPoolExecutor will be created.
"""
def ensure_config(config: Optional[RunnableConfig]) -> RunnableConfig: def ensure_config(config: Optional[RunnableConfig]) -> RunnableConfig:
empty = RunnableConfig(tags=[], metadata={}, callbacks=None, _locals={}) empty = RunnableConfig(
tags=[],
metadata={},
callbacks=None,
_locals={},
)
if config is not None: if config is not None:
empty.update(config) empty.update(config)
return empty return empty
def patch_config( def patch_config(
config: RunnableConfig, config: Optional[RunnableConfig],
callbacks: BaseCallbackManager, *,
deep_copy_locals: bool = False,
callbacks: Optional[BaseCallbackManager] = None,
executor: Optional[Executor] = None,
) -> RunnableConfig: ) -> RunnableConfig:
config = config.copy() config = ensure_config(config)
config["callbacks"] = callbacks if deep_copy_locals:
config["_locals"] = deepcopy(config["_locals"])
if callbacks is not None:
config["callbacks"] = callbacks
if executor is not None:
config["executor"] = executor
return config return config
@ -65,3 +92,12 @@ def get_async_callback_manager_for_config(
inheritable_tags=config.get("tags"), inheritable_tags=config.get("tags"),
inheritable_metadata=config.get("metadata"), inheritable_metadata=config.get("metadata"),
) )
@contextmanager
def get_executor_for_config(config: RunnableConfig) -> Generator[Executor, None, None]:
    """Yield the executor to use for parallel calls made under ``config``.

    If ``config`` has an ``executor`` entry, that executor is yielded as-is
    and is not shut down on exit, so the caller that supplied it keeps control
    of its lifetime. Otherwise a new ThreadPoolExecutor is created with
    ``max_workers`` taken from the ``max_concurrency`` entry (None when the
    entry is absent) and is shut down when the ``with`` block exits.
    """
    executor = config.get("executor")
    # Compare against None instead of relying on truthiness: an Executor
    # subclass that defines __len__ or __bool__ can be falsy yet still be a
    # perfectly valid, caller-provided executor.
    if executor is not None:
        yield executor
        return
    with ThreadPoolExecutor(max_workers=config.get("max_concurrency")) as pool:
        yield pool

View File

@ -131,26 +131,25 @@ async def test_default_method_implementations(mocker: MockerFixture) -> None:
assert fake.batch( assert fake.batch(
["hello", "wooorld"], [dict(tags=["a-tag"]), dict(metadata={"key": "value"})] ["hello", "wooorld"], [dict(tags=["a-tag"]), dict(metadata={"key": "value"})]
) == [5, 7] ) == [5, 7]
assert spy.call_args_list == [
mocker.call( assert len(spy.call_args_list) == 2
"hello", dict(tags=["a-tag"], metadata={}, callbacks=None, _locals={}) for i, call in enumerate(spy.call_args_list):
), assert call.args[0] == ("hello" if i == 0 else "wooorld")
mocker.call( if i == 0:
"wooorld", assert call.args[1].get("tags") == ["a-tag"]
dict(metadata={"key": "value"}, tags=[], callbacks=None, _locals={}), assert call.args[1].get("metadata") == {}
), else:
] assert call.args[1].get("tags") == []
assert call.args[1].get("metadata") == {"key": "value"}
spy.reset_mock() spy.reset_mock()
assert fake.batch(["hello", "wooorld"], dict(tags=["a-tag"])) == [5, 7] assert fake.batch(["hello", "wooorld"], dict(tags=["a-tag"])) == [5, 7]
assert spy.call_args_list == [ assert len(spy.call_args_list) == 2
mocker.call( for i, call in enumerate(spy.call_args_list):
"hello", dict(tags=["a-tag"], metadata={}, callbacks=None, _locals={}) assert call.args[0] == ("hello" if i == 0 else "wooorld")
), assert call.args[1].get("tags") == ["a-tag"]
mocker.call( assert call.args[1].get("metadata") == {}
"wooorld", dict(tags=["a-tag"], metadata={}, callbacks=None, _locals={})
),
]
spy.reset_mock() spy.reset_mock()
assert await fake.ainvoke("hello", config={"callbacks": []}) == 5 assert await fake.ainvoke("hello", config={"callbacks": []}) == 5
@ -172,11 +171,21 @@ async def test_default_method_implementations(mocker: MockerFixture) -> None:
assert spy.call_args_list == [ assert spy.call_args_list == [
mocker.call( mocker.call(
"hello", "hello",
dict(metadata={"key": "value"}, tags=[], callbacks=None, _locals={}), dict(
metadata={"key": "value"},
tags=[],
callbacks=None,
_locals={},
),
), ),
mocker.call( mocker.call(
"wooorld", "wooorld",
dict(metadata={"key": "value"}, tags=[], callbacks=None, _locals={}), dict(
metadata={"key": "value"},
tags=[],
callbacks=None,
_locals={},
),
), ),
] ]