Mirror of https://github.com/hwchase17/langchain.git, synced 2025-07-06 13:18:12 +00:00
Use a shared executor for all parallel calls
This commit is contained in:
parent a40c12bb88
commit d414d47c78
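The change threads a single executor and the max_concurrency limit through RunnableConfig, so batch(), RunnableMap and the streaming paths reuse one shared pool instead of each opening their own ThreadPoolExecutor. A minimal standalone sketch of that pattern; the names mirror the diff below, but this is an illustration, not the library code:

# --- illustrative sketch (not part of the diff) ---
from concurrent.futures import Executor, ThreadPoolExecutor
from contextlib import contextmanager
from typing import Any, Callable, Generator, List, Optional, TypedDict


class Config(TypedDict, total=False):
    # simplified stand-in for RunnableConfig
    max_concurrency: Optional[int]
    executor: Executor


@contextmanager
def get_executor_for_config(config: Config) -> Generator[Executor, None, None]:
    if config.get("executor"):
        # Reuse the externally managed executor; its lifetime is owned elsewhere.
        yield config["executor"]
    else:
        # Otherwise create one pool, bounded by max_concurrency, and own it here.
        with ThreadPoolExecutor(max_workers=config.get("max_concurrency")) as executor:
            yield executor


def batch(
    fn: Callable[[Any, Config], Any], inputs: List[Any], config: Config
) -> List[Any]:
    # The same executor is threaded into every child config, so nested
    # parallel calls share one pool instead of spawning their own.
    with get_executor_for_config(config) as executor:
        child_config: Config = {**config, "executor": executor}
        return list(executor.map(lambda x: fn(x, child_config), inputs))


if __name__ == "__main__":
    print(batch(lambda x, cfg: x * 2, [1, 2, 3], {"max_concurrency": 2}))
# --- end sketch ---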
@@ -7,11 +7,12 @@ from langchain.schema.runnable.base import (
     RunnableSequence,
     RunnableWithFallbacks,
 )
-from langchain.schema.runnable.config import RunnableConfig
+from langchain.schema.runnable.config import RunnableConfig, patch_config
 from langchain.schema.runnable.passthrough import RunnablePassthrough
 from langchain.schema.runnable.router import RouterInput, RouterRunnable
 
 __all__ = [
+    "patch_config",
     "GetLocalVar",
     "PutLocalVar",
     "RouterInput",
@@ -1,11 +1,9 @@
 from __future__ import annotations
 
 import asyncio
-import copy
 import threading
 from abc import ABC, abstractmethod
-from concurrent.futures import FIRST_COMPLETED, ThreadPoolExecutor, wait
-from copy import deepcopy
+from concurrent.futures import FIRST_COMPLETED, wait
 from functools import partial
 from itertools import tee
 from typing import (
@@ -43,6 +41,7 @@ from langchain.schema.runnable.config import (
     ensure_config,
     get_async_callback_manager_for_config,
     get_callback_manager_for_config,
+    get_executor_for_config,
     patch_config,
 )
 from langchain.schema.runnable.utils import (
@@ -104,8 +103,6 @@ class Runnable(Generic[Input, Output], ABC):
         self,
         inputs: List[Input],
         config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
-        *,
-        max_concurrency: Optional[int] = None,
         **kwargs: Optional[Any],
     ) -> List[Output]:
         """
@@ -118,15 +115,19 @@ class Runnable(Generic[Input, Output], ABC):
         if len(inputs) == 1:
             return [self.invoke(inputs[0], configs[0], **kwargs)]
 
-        with ThreadPoolExecutor(max_workers=max_concurrency) as executor:
-            return list(executor.map(partial(self.invoke, **kwargs), inputs, configs))
+        with get_executor_for_config(configs[0]) as executor:
+            return list(
+                executor.map(
+                    partial(self.invoke, **kwargs),
+                    inputs,
+                    (patch_config(c, executor=executor) for c in configs),
+                )
+            )
 
     async def abatch(
         self,
         inputs: List[Input],
         config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
-        *,
-        max_concurrency: Optional[int] = None,
         **kwargs: Optional[Any],
     ) -> List[Output]:
         """
@@ -136,7 +137,7 @@ class Runnable(Generic[Input, Output], ABC):
         configs = self._get_config_list(config, len(inputs))
         coros = map(partial(self.ainvoke, **kwargs), inputs, configs)
 
-        return await gather_with_concurrency(max_concurrency, *coros)
+        return await gather_with_concurrency(configs[0].get("max_concurrency"), *coros)
 
     def stream(
         self,
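On the async side, abatch() still delegates to gather_with_concurrency but now reads the limit from configs[0].get("max_concurrency") instead of a max_concurrency keyword. gather_with_concurrency is an existing helper that this diff does not touch; the stand-in below is an assumed semaphore-based sketch that only matches the call shape used here (limit first, then the coroutines):

# --- illustrative sketch (not part of the diff) ---
import asyncio
from typing import Any, Coroutine, List, Optional


async def gather_with_concurrency(
    n: Optional[int], *coros: Coroutine[Any, Any, Any]
) -> List[Any]:
    # With no limit, behave like a plain asyncio.gather().
    if n is None:
        return await asyncio.gather(*coros)
    # Otherwise cap the number of coroutines running at once with a semaphore.
    semaphore = asyncio.Semaphore(n)

    async def sem_coro(coro: Coroutine[Any, Any, Any]) -> Any:
        async with semaphore:
            return await coro

    return await asyncio.gather(*(sem_coro(c) for c in coros))


async def main() -> None:
    async def work(i: int) -> int:
        await asyncio.sleep(0.01)
        return i * 2

    print(await gather_with_concurrency(2, *(work(i) for i in range(5))))


asyncio.run(main())
# --- end sketch ---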
@@ -246,7 +247,7 @@ class Runnable(Generic[Input, Output], ABC):
         return (
             list(map(ensure_config, config))
             if isinstance(config, list)
-            else [deepcopy(ensure_config(config)) for _ in range(length)]
+            else [patch_config(config, deep_copy_locals=True) for _ in range(length)]
         )
 
     def _call_with_config(
@@ -527,7 +528,7 @@ class RunnableWithFallbacks(Serializable, Runnable[Input, Output]):
             try:
                 output = runnable.invoke(
                     input,
-                    patch_config(config, run_manager.get_child()),
+                    patch_config(config, callbacks=run_manager.get_child()),
                 )
             except self.exceptions_to_handle as e:
                 if first_error is None:
@@ -560,7 +561,7 @@ class RunnableWithFallbacks(Serializable, Runnable[Input, Output]):
             try:
                 output = await runnable.ainvoke(
                     input,
-                    patch_config(config, run_manager.get_child()),
+                    patch_config(config, callbacks=run_manager.get_child()),
                 )
             except self.exceptions_to_handle as e:
                 if first_error is None:
@@ -580,8 +581,6 @@ class RunnableWithFallbacks(Serializable, Runnable[Input, Output]):
         self,
         inputs: List[Input],
         config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
-        *,
-        max_concurrency: Optional[int] = None,
         **kwargs: Optional[Any],
     ) -> List[Output]:
         from langchain.callbacks.manager import CallbackManager
@@ -615,10 +614,9 @@ class RunnableWithFallbacks(Serializable, Runnable[Input, Output]):
                     inputs,
                     [
                         # each step a child run of the corresponding root run
-                        patch_config(config, rm.get_child())
+                        patch_config(config, callbacks=rm.get_child())
                         for rm, config in zip(run_managers, configs)
                     ],
-                    max_concurrency=max_concurrency,
                 )
             except self.exceptions_to_handle as e:
                 if first_error is None:
@@ -641,14 +639,9 @@ class RunnableWithFallbacks(Serializable, Runnable[Input, Output]):
         self,
         inputs: List[Input],
         config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
-        *,
-        max_concurrency: Optional[int] = None,
         **kwargs: Optional[Any],
     ) -> List[Output]:
-        from langchain.callbacks.manager import (
-            AsyncCallbackManager,
-            AsyncCallbackManagerForChainRun,
-        )
+        from langchain.callbacks.manager import AsyncCallbackManager
 
         # setup callbacks
         configs = self._get_config_list(config, len(inputs))
@@ -679,10 +672,9 @@ class RunnableWithFallbacks(Serializable, Runnable[Input, Output]):
                     inputs,
                     [
                         # each step a child run of the corresponding root run
-                        patch_config(config, rm.get_child())
+                        patch_config(config, callbacks=rm.get_child())
                         for rm, config in zip(run_managers, configs)
                     ],
-                    max_concurrency=max_concurrency,
                 )
             except self.exceptions_to_handle as e:
                 if first_error is None:
@@ -782,7 +774,7 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
                 input = step.invoke(
                     input,
                     # mark each step as a child run
-                    patch_config(config, run_manager.get_child()),
+                    patch_config(config, callbacks=run_manager.get_child()),
                 )
         # finish the root run
         except (KeyboardInterrupt, Exception) as e:
@@ -810,7 +802,7 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
                 input = await step.ainvoke(
                     input,
                     # mark each step as a child run
-                    patch_config(config, run_manager.get_child()),
+                    patch_config(config, callbacks=run_manager.get_child()),
                 )
         # finish the root run
         except (KeyboardInterrupt, Exception) as e:
@@ -824,8 +816,6 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
         self,
         inputs: List[Input],
         config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
-        *,
-        max_concurrency: Optional[int] = None,
         **kwargs: Optional[Any],
     ) -> List[Output]:
         from langchain.callbacks.manager import CallbackManager
@@ -852,15 +842,17 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
 
         # invoke
         try:
+            with get_executor_for_config(configs[0]) as executor:
                 for step in self.steps:
                     inputs = step.batch(
                         inputs,
                         [
                             # each step a child run of the corresponding root run
-                            patch_config(config, rm.get_child())
+                            patch_config(
+                                config, callbacks=rm.get_child(), executor=executor
+                            )
                             for rm, config in zip(run_managers, configs)
                         ],
-                        max_concurrency=max_concurrency,
                     )
         # finish the root runs
         except (KeyboardInterrupt, Exception) as e:
@@ -876,8 +868,6 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
         self,
         inputs: List[Input],
         config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
-        *,
-        max_concurrency: Optional[int] = None,
         **kwargs: Optional[Any],
     ) -> List[Output]:
         from langchain.callbacks.manager import (
@@ -914,10 +904,9 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
                     inputs,
                     [
                         # each step a child run of the corresponding root run
-                        patch_config(config, rm.get_child())
+                        patch_config(config, callbacks=rm.get_child())
                         for rm, config in zip(run_managers, configs)
                     ],
-                    max_concurrency=max_concurrency,
                 )
         # finish the root runs
         except (KeyboardInterrupt, Exception) as e:
@@ -956,7 +945,7 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
                 input = step.invoke(
                     input,
                     # mark each step as a child run
-                    patch_config(config, run_manager.get_child()),
+                    patch_config(config, callbacks=run_manager.get_child()),
                 )
         except (KeyboardInterrupt, Exception) as e:
             run_manager.on_chain_error(e)
@@ -968,12 +957,13 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
         try:
             # stream the first of the last steps with non-streaming input
             final_pipeline = steps[streaming_start_index].stream(
-                input, patch_config(config, run_manager.get_child())
+                input, patch_config(config, callbacks=run_manager.get_child())
             )
             # stream the rest of the last steps with streaming input
             for step in steps[streaming_start_index + 1 :]:
                 final_pipeline = step.transform(
-                    final_pipeline, patch_config(config, run_manager.get_child())
+                    final_pipeline,
+                    patch_config(config, callbacks=run_manager.get_child()),
                 )
             for output in final_pipeline:
                 yield output
@@ -1022,7 +1012,7 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
                 input = await step.ainvoke(
                     input,
                     # mark each step as a child run
-                    patch_config(config, run_manager.get_child()),
+                    patch_config(config, callbacks=run_manager.get_child()),
                 )
         except (KeyboardInterrupt, Exception) as e:
             await run_manager.on_chain_error(e)
@@ -1034,12 +1024,13 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
         try:
             # stream the first of the last steps with non-streaming input
             final_pipeline = steps[streaming_start_index].astream(
-                input, patch_config(config, run_manager.get_child())
+                input, patch_config(config, callbacks=run_manager.get_child())
             )
             # stream the rest of the last steps with streaming input
             for step in steps[streaming_start_index + 1 :]:
                 final_pipeline = step.atransform(
-                    final_pipeline, patch_config(config, run_manager.get_child())
+                    final_pipeline,
+                    patch_config(config, callbacks=run_manager.get_child()),
                 )
             async for output in final_pipeline:
                 yield output
@@ -1068,7 +1059,7 @@ class RunnableMapChunk(Dict[str, Any]):
     """
 
     def __add__(self, other: RunnableMapChunk) -> RunnableMapChunk:
-        chunk = copy.deepcopy(self)
+        chunk = self.copy()
         for key in other:
             if key not in chunk or chunk[key] is None:
                 chunk[key] = other[key]
@@ -1076,6 +1067,15 @@ class RunnableMapChunk(Dict[str, Any]):
                 chunk[key] += other[key]
         return chunk
 
+    def __radd__(self, other: RunnableMapChunk) -> RunnableMapChunk:
+        chunk = RunnableMapChunk(other)
+        for key in self:
+            if key not in chunk or chunk[key] is None:
+                chunk[key] = self[key]
+            elif self[key] is not None:
+                chunk[key] += self[key]
+        return chunk
+
 
 class RunnableMap(Serializable, Runnable[Input, Dict[str, Any]]):
     """
@@ -1132,13 +1132,18 @@ class RunnableMap(Serializable, Runnable[Input, Dict[str, Any]]):
         try:
             # copy to avoid issues from the caller mutating the steps during invoke()
             steps = dict(self.steps)
-            with ThreadPoolExecutor() as executor:
+            with get_executor_for_config(config) as executor:
                 futures = [
                     executor.submit(
                         step.invoke,
                         input,
                         # mark each step as a child run
-                        patch_config(deepcopy(config), run_manager.get_child()),
+                        patch_config(
+                            config,
+                            deep_copy_locals=True,
+                            callbacks=run_manager.get_child(),
+                            executor=executor,
+                        ),
                     )
                     for step in steps.values()
                 ]
@@ -1172,7 +1177,7 @@ class RunnableMap(Serializable, Runnable[Input, Dict[str, Any]]):
                     step.ainvoke(
                         input,
                         # mark each step as a child run
-                        patch_config(config, run_manager.get_child()),
+                        patch_config(config, callbacks=run_manager.get_child()),
                     )
                     for step in steps.values()
                 )
@@ -1197,14 +1202,16 @@ class RunnableMap(Serializable, Runnable[Input, Dict[str, Any]]):
         # Each step gets a copy of the input iterator,
         # which is consumed in parallel in a separate thread.
         input_copies = list(safetee(input, len(steps), lock=threading.Lock()))
-        with ThreadPoolExecutor() as executor:
+        with get_executor_for_config(config) as executor:
             # Create the transform() generator for each step
             named_generators = [
                 (
                     name,
                     step.transform(
                         input_copies.pop(),
-                        patch_config(config, run_manager.get_child()),
+                        patch_config(
+                            config, callbacks=run_manager.get_child(), executor=executor
+                        ),
                     ),
                 )
                 for name, step in steps.items()
@@ -1265,7 +1272,8 @@ class RunnableMap(Serializable, Runnable[Input, Dict[str, Any]]):
                 (
                     name,
                     step.atransform(
-                        input_copies.pop(), patch_config(config, run_manager.get_child())
+                        input_copies.pop(),
+                        patch_config(config, callbacks=run_manager.get_child()),
                     ),
                 )
                 for name, step in steps.items()
@@ -1393,25 +1401,17 @@ class RunnableBinding(Serializable, Runnable[Input, Output]):
         self,
         inputs: List[Input],
         config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
-        *,
-        max_concurrency: Optional[int] = None,
         **kwargs: Optional[Any],
     ) -> List[Output]:
-        return self.bound.batch(
-            inputs, config, max_concurrency=max_concurrency, **{**self.kwargs, **kwargs}
-        )
+        return self.bound.batch(inputs, config, **{**self.kwargs, **kwargs})
 
     async def abatch(
         self,
         inputs: List[Input],
         config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
-        *,
-        max_concurrency: Optional[int] = None,
         **kwargs: Optional[Any],
     ) -> List[Output]:
-        return await self.bound.abatch(
-            inputs, config, max_concurrency=max_concurrency, **{**self.kwargs, **kwargs}
-        )
+        return await self.bound.abatch(inputs, config, **{**self.kwargs, **kwargs})
 
     def stream(
         self,
@@ -1,6 +1,9 @@
 from __future__ import annotations
 
+from concurrent.futures import Executor, ThreadPoolExecutor
+from contextlib import contextmanager
+from copy import deepcopy
-from typing import Any, Dict, List, Optional, TypedDict
+from typing import Any, Dict, Generator, List, Optional, TypedDict
 
 from langchain.callbacks.base import BaseCallbackManager, Callbacks
 from langchain.callbacks.manager import AsyncCallbackManager, CallbackManager
@@ -32,20 +35,44 @@ class RunnableConfig(TypedDict, total=False):
     Local variables
     """
 
+    max_concurrency: Optional[int]
+    """
+    Maximum number of parallel calls to make. If not provided, defaults to ThreadPoolExecutor's default. This is ignored if an executor is provided.
+    """
+
+    executor: Executor
+    """
+    Externally-managed executor to use for parallel calls. If not provided, a new
+    ThreadPoolExecutor will be created.
+    """
+
 
 def ensure_config(config: Optional[RunnableConfig]) -> RunnableConfig:
-    empty = RunnableConfig(tags=[], metadata={}, callbacks=None, _locals={})
+    empty = RunnableConfig(
+        tags=[],
+        metadata={},
+        callbacks=None,
+        _locals={},
+    )
     if config is not None:
         empty.update(config)
     return empty
 
 
 def patch_config(
-    config: RunnableConfig,
-    callbacks: BaseCallbackManager,
+    config: Optional[RunnableConfig],
+    *,
+    deep_copy_locals: bool = False,
+    callbacks: Optional[BaseCallbackManager] = None,
+    executor: Optional[Executor] = None,
 ) -> RunnableConfig:
-    config = config.copy()
-    config["callbacks"] = callbacks
+    config = ensure_config(config)
+    if deep_copy_locals:
+        config["_locals"] = deepcopy(config["_locals"])
+    if callbacks is not None:
+        config["callbacks"] = callbacks
+    if executor is not None:
+        config["executor"] = executor
     return config
 
 
@@ -65,3 +92,12 @@ def get_async_callback_manager_for_config(
         inheritable_tags=config.get("tags"),
         inheritable_metadata=config.get("metadata"),
     )
+
+
+@contextmanager
+def get_executor_for_config(config: RunnableConfig) -> Generator[Executor, None, None]:
+    if config.get("executor"):
+        yield config["executor"]
+    else:
+        with ThreadPoolExecutor(max_workers=config.get("max_concurrency")) as executor:
+            yield executor
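Taken together, these config.py changes let a parent runnable derive per-child configs that share one executor: get_executor_for_config() reuses config["executor"] when present and otherwise owns a ThreadPoolExecutor bounded by max_concurrency, while patch_config() optionally deep-copies _locals and attaches callbacks and the executor. A hedged usage sketch against the signatures above, assuming the langchain version from this commit is installed:

# --- illustrative sketch (not part of the diff) ---
from concurrent.futures import ThreadPoolExecutor

from langchain.schema.runnable.config import (
    RunnableConfig,
    ensure_config,
    get_executor_for_config,
    patch_config,
)

parent: RunnableConfig = ensure_config({"max_concurrency": 4})

# Case 1: no executor in the config, so get_executor_for_config creates a
# ThreadPoolExecutor bounded by max_concurrency and owns its lifetime.
with get_executor_for_config(parent) as executor:
    # Child configs inherit the same executor (and get fresh _locals copies),
    # so nested parallel calls do not spawn their own pools.
    children = [
        patch_config(parent, deep_copy_locals=True, executor=executor)
        for _ in range(3)
    ]
    print(all(child["executor"] is executor for child in children))

# Case 2: an externally managed executor is simply reused, not shut down here.
with ThreadPoolExecutor(max_workers=2) as pool:
    cfg = patch_config(parent, executor=pool)
    with get_executor_for_config(cfg) as executor:
        print(executor is pool)
# --- end sketch ---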
@@ -131,26 +131,25 @@ async def test_default_method_implementations(mocker: MockerFixture) -> None:
     assert fake.batch(
         ["hello", "wooorld"], [dict(tags=["a-tag"]), dict(metadata={"key": "value"})]
     ) == [5, 7]
-    assert spy.call_args_list == [
-        mocker.call(
-            "hello", dict(tags=["a-tag"], metadata={}, callbacks=None, _locals={})
-        ),
-        mocker.call(
-            "wooorld",
-            dict(metadata={"key": "value"}, tags=[], callbacks=None, _locals={}),
-        ),
-    ]
+    assert len(spy.call_args_list) == 2
+    for i, call in enumerate(spy.call_args_list):
+        assert call.args[0] == ("hello" if i == 0 else "wooorld")
+        if i == 0:
+            assert call.args[1].get("tags") == ["a-tag"]
+            assert call.args[1].get("metadata") == {}
+        else:
+            assert call.args[1].get("tags") == []
+            assert call.args[1].get("metadata") == {"key": "value"}
+
     spy.reset_mock()
 
     assert fake.batch(["hello", "wooorld"], dict(tags=["a-tag"])) == [5, 7]
-    assert spy.call_args_list == [
-        mocker.call(
-            "hello", dict(tags=["a-tag"], metadata={}, callbacks=None, _locals={})
-        ),
-        mocker.call(
-            "wooorld", dict(tags=["a-tag"], metadata={}, callbacks=None, _locals={})
-        ),
-    ]
+    assert len(spy.call_args_list) == 2
+    for i, call in enumerate(spy.call_args_list):
+        assert call.args[0] == ("hello" if i == 0 else "wooorld")
+        assert call.args[1].get("tags") == ["a-tag"]
+        assert call.args[1].get("metadata") == {}
     spy.reset_mock()
 
     assert await fake.ainvoke("hello", config={"callbacks": []}) == 5
@@ -172,11 +171,21 @@ async def test_default_method_implementations(mocker: MockerFixture) -> None:
     assert spy.call_args_list == [
         mocker.call(
             "hello",
-            dict(metadata={"key": "value"}, tags=[], callbacks=None, _locals={}),
+            dict(
+                metadata={"key": "value"},
+                tags=[],
+                callbacks=None,
+                _locals={},
+            ),
         ),
         mocker.call(
             "wooorld",
-            dict(metadata={"key": "value"}, tags=[], callbacks=None, _locals={}),
+            dict(
+                metadata={"key": "value"},
+                tags=[],
+                callbacks=None,
+                _locals={},
+            ),
         ),
     ]
 