Use a shared executor for all parallel calls

commit d414d47c78 (parent a40c12bb88)
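Note: after this change, concurrency for batch()/abatch() is configured through the RunnableConfig rather than a max_concurrency keyword argument, and nested runnables reuse one shared executor. A usage sketch, assuming RunnablePassthrough is re-exported from langchain.schema.runnable as the first hunk below suggests (behavior inferred from the diff, not authoritative):

    from langchain.schema.runnable import RunnablePassthrough

    runnable = RunnablePassthrough()
    # before this commit: runnable.batch([1, 2, 3, 4], max_concurrency=2)
    # after: the limit travels inside the config dict instead
    print(runnable.batch([1, 2, 3, 4], config={"max_concurrency": 2}))  # [1, 2, 3, 4]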
@@ -7,11 +7,12 @@ from langchain.schema.runnable.base import (
     RunnableSequence,
     RunnableWithFallbacks,
 )
-from langchain.schema.runnable.config import RunnableConfig
+from langchain.schema.runnable.config import RunnableConfig, patch_config
 from langchain.schema.runnable.passthrough import RunnablePassthrough
 from langchain.schema.runnable.router import RouterInput, RouterRunnable
 
 __all__ = [
+    "patch_config",
     "GetLocalVar",
     "PutLocalVar",
     "RouterInput",
@@ -1,11 +1,9 @@
 from __future__ import annotations
 
 import asyncio
-import copy
 import threading
 from abc import ABC, abstractmethod
-from concurrent.futures import FIRST_COMPLETED, ThreadPoolExecutor, wait
-from copy import deepcopy
+from concurrent.futures import FIRST_COMPLETED, wait
 from functools import partial
 from itertools import tee
 from typing import (
@@ -43,6 +41,7 @@ from langchain.schema.runnable.config import (
     ensure_config,
     get_async_callback_manager_for_config,
     get_callback_manager_for_config,
+    get_executor_for_config,
     patch_config,
 )
 from langchain.schema.runnable.utils import (
@@ -104,8 +103,6 @@ class Runnable(Generic[Input, Output], ABC):
         self,
         inputs: List[Input],
         config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
-        *,
-        max_concurrency: Optional[int] = None,
         **kwargs: Optional[Any],
     ) -> List[Output]:
         """
@@ -118,15 +115,19 @@ class Runnable(Generic[Input, Output], ABC):
         if len(inputs) == 1:
             return [self.invoke(inputs[0], configs[0], **kwargs)]
 
-        with ThreadPoolExecutor(max_workers=max_concurrency) as executor:
-            return list(executor.map(partial(self.invoke, **kwargs), inputs, configs))
+        with get_executor_for_config(configs[0]) as executor:
+            return list(
+                executor.map(
+                    partial(self.invoke, **kwargs),
+                    inputs,
+                    (patch_config(c, executor=executor) for c in configs),
+                )
+            )
 
     async def abatch(
         self,
         inputs: List[Input],
         config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
-        *,
-        max_concurrency: Optional[int] = None,
         **kwargs: Optional[Any],
     ) -> List[Output]:
         """
@@ -136,7 +137,7 @@ class Runnable(Generic[Input, Output], ABC):
         configs = self._get_config_list(config, len(inputs))
         coros = map(partial(self.ainvoke, **kwargs), inputs, configs)
 
-        return await gather_with_concurrency(max_concurrency, *coros)
+        return await gather_with_concurrency(configs[0].get("max_concurrency"), *coros)
 
     def stream(
         self,
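abatch() above now pulls the concurrency limit out of configs[0] and hands it to gather_with_concurrency. For reference, a minimal standalone sketch of what such a semaphore-gated gather helper typically looks like (an assumption; the library's own implementation may differ):

    import asyncio
    from typing import Any, Coroutine, List, Optional

    async def gather_with_concurrency(n: Optional[int], *coros: Coroutine) -> List[Any]:
        if n is None:
            # no limit requested: plain gather
            return await asyncio.gather(*coros)
        semaphore = asyncio.Semaphore(n)

        async def gated(coro: Coroutine) -> Any:
            async with semaphore:
                return await coro

        return await asyncio.gather(*(gated(c) for c in coros))

    async def main() -> None:
        results = await gather_with_concurrency(2, *(asyncio.sleep(0, i) for i in range(5)))
        print(results)  # [0, 1, 2, 3, 4]

    asyncio.run(main())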
@@ -246,7 +247,7 @@ class Runnable(Generic[Input, Output], ABC):
         return (
             list(map(ensure_config, config))
             if isinstance(config, list)
-            else [deepcopy(ensure_config(config)) for _ in range(length)]
+            else [patch_config(config, deep_copy_locals=True) for _ in range(length)]
         )
 
     def _call_with_config(
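_get_config_list() now hands out patch_config(config, deep_copy_locals=True) copies instead of deep-copying the whole config, so each batch item still gets its own _locals dict. A toy illustration of why that deep copy matters (illustrative only, not library code):

    from copy import deepcopy

    base = {"tags": [], "_locals": {}}
    shallow = [dict(base) for _ in range(3)]   # new outer dicts, but one shared _locals
    deep = [deepcopy(base) for _ in range(3)]  # what deep_copy_locals=True is for

    shallow[0]["_locals"]["x"] = 1
    deep[0]["_locals"]["x"] = 1
    print(shallow[1]["_locals"])  # {'x': 1} -- state leaks between batch items
    print(deep[1]["_locals"])     # {}       -- each item stays isolated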
@@ -527,7 +528,7 @@ class RunnableWithFallbacks(Serializable, Runnable[Input, Output]):
             try:
                 output = runnable.invoke(
                     input,
-                    patch_config(config, run_manager.get_child()),
+                    patch_config(config, callbacks=run_manager.get_child()),
                 )
             except self.exceptions_to_handle as e:
                 if first_error is None:
@@ -560,7 +561,7 @@ class RunnableWithFallbacks(Serializable, Runnable[Input, Output]):
             try:
                 output = await runnable.ainvoke(
                     input,
-                    patch_config(config, run_manager.get_child()),
+                    patch_config(config, callbacks=run_manager.get_child()),
                 )
             except self.exceptions_to_handle as e:
                 if first_error is None:
@@ -580,8 +581,6 @@ class RunnableWithFallbacks(Serializable, Runnable[Input, Output]):
         self,
         inputs: List[Input],
         config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
-        *,
-        max_concurrency: Optional[int] = None,
         **kwargs: Optional[Any],
     ) -> List[Output]:
         from langchain.callbacks.manager import CallbackManager
@@ -615,10 +614,9 @@ class RunnableWithFallbacks(Serializable, Runnable[Input, Output]):
                     inputs,
                     [
                         # each step a child run of the corresponding root run
-                        patch_config(config, rm.get_child())
+                        patch_config(config, callbacks=rm.get_child())
                         for rm, config in zip(run_managers, configs)
                     ],
-                    max_concurrency=max_concurrency,
                 )
             except self.exceptions_to_handle as e:
                 if first_error is None:
@@ -641,14 +639,9 @@ class RunnableWithFallbacks(Serializable, Runnable[Input, Output]):
         self,
         inputs: List[Input],
         config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
-        *,
-        max_concurrency: Optional[int] = None,
         **kwargs: Optional[Any],
     ) -> List[Output]:
-        from langchain.callbacks.manager import (
-            AsyncCallbackManager,
-            AsyncCallbackManagerForChainRun,
-        )
+        from langchain.callbacks.manager import AsyncCallbackManager
 
         # setup callbacks
         configs = self._get_config_list(config, len(inputs))
@@ -679,10 +672,9 @@ class RunnableWithFallbacks(Serializable, Runnable[Input, Output]):
                     inputs,
                     [
                         # each step a child run of the corresponding root run
-                        patch_config(config, rm.get_child())
+                        patch_config(config, callbacks=rm.get_child())
                         for rm, config in zip(run_managers, configs)
                     ],
-                    max_concurrency=max_concurrency,
                 )
             except self.exceptions_to_handle as e:
                 if first_error is None:
@@ -782,7 +774,7 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
                 input = step.invoke(
                     input,
                     # mark each step as a child run
-                    patch_config(config, run_manager.get_child()),
+                    patch_config(config, callbacks=run_manager.get_child()),
                 )
         # finish the root run
         except (KeyboardInterrupt, Exception) as e:
@@ -810,7 +802,7 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
                 input = await step.ainvoke(
                     input,
                     # mark each step as a child run
-                    patch_config(config, run_manager.get_child()),
+                    patch_config(config, callbacks=run_manager.get_child()),
                 )
         # finish the root run
         except (KeyboardInterrupt, Exception) as e:
@@ -824,8 +816,6 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
         self,
         inputs: List[Input],
         config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
-        *,
-        max_concurrency: Optional[int] = None,
         **kwargs: Optional[Any],
     ) -> List[Output]:
         from langchain.callbacks.manager import CallbackManager
@@ -852,15 +842,17 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
 
         # invoke
         try:
-            for step in self.steps:
-                inputs = step.batch(
-                    inputs,
-                    [
-                        # each step a child run of the corresponding root run
-                        patch_config(config, rm.get_child())
-                        for rm, config in zip(run_managers, configs)
-                    ],
-                    max_concurrency=max_concurrency,
-                )
+            with get_executor_for_config(configs[0]) as executor:
+                for step in self.steps:
+                    inputs = step.batch(
+                        inputs,
+                        [
+                            # each step a child run of the corresponding root run
+                            patch_config(
+                                config, callbacks=rm.get_child(), executor=executor
+                            )
+                            for rm, config in zip(run_managers, configs)
+                        ],
+                    )
         # finish the root runs
         except (KeyboardInterrupt, Exception) as e:
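The hunk above wraps the whole step loop in one get_executor_for_config() block and threads that executor into every per-item config, so RunnableSequence.batch() reuses a single pool instead of each step building its own. A standalone sketch of the same pattern (not the library's code; names are illustrative):

    from concurrent.futures import ThreadPoolExecutor
    from typing import Callable, List, Optional

    def run_steps(
        steps: List[Callable[[int], int]],
        inputs: List[int],
        max_concurrency: Optional[int] = None,
    ) -> List[int]:
        # one pool for the whole sequence: steps run in order, items run in parallel
        with ThreadPoolExecutor(max_workers=max_concurrency) as executor:
            for step in steps:
                inputs = list(executor.map(step, inputs))
        return inputs

    print(run_steps([lambda x: x + 1, lambda x: x * 2], [1, 2, 3], max_concurrency=2))
    # [4, 6, 8]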
@@ -876,8 +868,6 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
         self,
         inputs: List[Input],
         config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
-        *,
-        max_concurrency: Optional[int] = None,
         **kwargs: Optional[Any],
     ) -> List[Output]:
         from langchain.callbacks.manager import (
@@ -914,10 +904,9 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
                     inputs,
                     [
                         # each step a child run of the corresponding root run
-                        patch_config(config, rm.get_child())
+                        patch_config(config, callbacks=rm.get_child())
                         for rm, config in zip(run_managers, configs)
                     ],
-                    max_concurrency=max_concurrency,
                 )
         # finish the root runs
         except (KeyboardInterrupt, Exception) as e:
@@ -956,7 +945,7 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
                 input = step.invoke(
                     input,
                     # mark each step as a child run
-                    patch_config(config, run_manager.get_child()),
+                    patch_config(config, callbacks=run_manager.get_child()),
                 )
         except (KeyboardInterrupt, Exception) as e:
             run_manager.on_chain_error(e)
@@ -968,12 +957,13 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
         try:
             # stream the first of the last steps with non-streaming input
             final_pipeline = steps[streaming_start_index].stream(
-                input, patch_config(config, run_manager.get_child())
+                input, patch_config(config, callbacks=run_manager.get_child())
             )
             # stream the rest of the last steps with streaming input
             for step in steps[streaming_start_index + 1 :]:
                 final_pipeline = step.transform(
-                    final_pipeline, patch_config(config, run_manager.get_child())
+                    final_pipeline,
+                    patch_config(config, callbacks=run_manager.get_child()),
                 )
             for output in final_pipeline:
                 yield output
@@ -1022,7 +1012,7 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
                 input = await step.ainvoke(
                     input,
                     # mark each step as a child run
-                    patch_config(config, run_manager.get_child()),
+                    patch_config(config, callbacks=run_manager.get_child()),
                 )
         except (KeyboardInterrupt, Exception) as e:
             await run_manager.on_chain_error(e)
@@ -1034,12 +1024,13 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
         try:
             # stream the first of the last steps with non-streaming input
             final_pipeline = steps[streaming_start_index].astream(
-                input, patch_config(config, run_manager.get_child())
+                input, patch_config(config, callbacks=run_manager.get_child())
             )
             # stream the rest of the last steps with streaming input
             for step in steps[streaming_start_index + 1 :]:
                 final_pipeline = step.atransform(
-                    final_pipeline, patch_config(config, run_manager.get_child())
+                    final_pipeline,
+                    patch_config(config, callbacks=run_manager.get_child()),
                 )
             async for output in final_pipeline:
                 yield output
@@ -1068,7 +1059,7 @@ class RunnableMapChunk(Dict[str, Any]):
     """
 
     def __add__(self, other: RunnableMapChunk) -> RunnableMapChunk:
-        chunk = copy.deepcopy(self)
+        chunk = self.copy()
         for key in other:
             if key not in chunk or chunk[key] is None:
                 chunk[key] = other[key]
@@ -1076,6 +1067,15 @@ class RunnableMapChunk(Dict[str, Any]):
                 chunk[key] += other[key]
         return chunk
 
+    def __radd__(self, other: RunnableMapChunk) -> RunnableMapChunk:
+        chunk = RunnableMapChunk(other)
+        for key in self:
+            if key not in chunk or chunk[key] is None:
+                chunk[key] = self[key]
+            elif self[key] is not None:
+                chunk[key] += self[key]
+        return chunk
+
 
 class RunnableMap(Serializable, Runnable[Input, Dict[str, Any]]):
     """
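The new __radd__ above makes `plain_dict + chunk` merge the same way `chunk + chunk` already did: missing or None keys are filled in, everything else is concatenated with +=. A self-contained sketch of those semantics (an illustrative class, not the library's RunnableMapChunk):

    from typing import Any, Dict

    class Chunk(Dict[str, Any]):
        def __add__(self, other: "Chunk") -> Dict[str, Any]:
            merged = dict(self)
            for key in other:
                if key not in merged or merged[key] is None:
                    merged[key] = other[key]
                elif other[key] is not None:
                    merged[key] += other[key]
            return merged

        def __radd__(self, other: Dict[str, Any]) -> Dict[str, Any]:
            # called for `plain_dict + Chunk(...)`, since dict has no __add__
            return Chunk(other) + self

    print(Chunk({"a": "he", "b": None}) + Chunk({"a": "llo", "b": "!"}))  # {'a': 'hello', 'b': '!'}
    print({"a": "he"} + Chunk({"a": "llo"}))                              # {'a': 'hello'}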
@@ -1132,13 +1132,18 @@ class RunnableMap(Serializable, Runnable[Input, Dict[str, Any]]):
         try:
             # copy to avoid issues from the caller mutating the steps during invoke()
             steps = dict(self.steps)
-            with ThreadPoolExecutor() as executor:
+            with get_executor_for_config(config) as executor:
                 futures = [
                     executor.submit(
                         step.invoke,
                         input,
                         # mark each step as a child run
-                        patch_config(deepcopy(config), run_manager.get_child()),
+                        patch_config(
+                            config,
+                            deep_copy_locals=True,
+                            callbacks=run_manager.get_child(),
+                            executor=executor,
+                        ),
                     )
                     for step in steps.values()
                 ]
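RunnableMap.invoke() above now borrows the executor from the config (falling back to a fresh pool) and submits every step with the same input. A standalone sketch of that fan-out/fan-in shape (not library code):

    from concurrent.futures import ThreadPoolExecutor
    from typing import Any, Callable, Dict, Optional

    def invoke_map(
        steps: Dict[str, Callable[[Any], Any]],
        value: Any,
        max_concurrency: Optional[int] = None,
    ) -> Dict[str, Any]:
        with ThreadPoolExecutor(max_workers=max_concurrency) as executor:
            futures = {name: executor.submit(step, value) for name, step in steps.items()}
            # gather results back into a dict keyed by step name
            return {name: future.result() for name, future in futures.items()}

    print(invoke_map({"double": lambda x: x * 2, "square": lambda x: x ** 2}, 3))
    # {'double': 6, 'square': 9}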
@@ -1172,7 +1177,7 @@ class RunnableMap(Serializable, Runnable[Input, Dict[str, Any]]):
                     step.ainvoke(
                         input,
                         # mark each step as a child run
-                        patch_config(config, run_manager.get_child()),
+                        patch_config(config, callbacks=run_manager.get_child()),
                     )
                     for step in steps.values()
                 )
@@ -1197,14 +1202,16 @@ class RunnableMap(Serializable, Runnable[Input, Dict[str, Any]]):
         # Each step gets a copy of the input iterator,
         # which is consumed in parallel in a separate thread.
         input_copies = list(safetee(input, len(steps), lock=threading.Lock()))
-        with ThreadPoolExecutor() as executor:
+        with get_executor_for_config(config) as executor:
             # Create the transform() generator for each step
             named_generators = [
                 (
                     name,
                     step.transform(
                         input_copies.pop(),
-                        patch_config(config, run_manager.get_child()),
+                        patch_config(
+                            config, callbacks=run_manager.get_child(), executor=executor
+                        ),
                     ),
                 )
                 for name, step in steps.items()
@@ -1265,7 +1272,8 @@ class RunnableMap(Serializable, Runnable[Input, Dict[str, Any]]):
             (
                 name,
                 step.atransform(
-                    input_copies.pop(), patch_config(config, run_manager.get_child())
+                    input_copies.pop(),
+                    patch_config(config, callbacks=run_manager.get_child()),
                 ),
             )
             for name, step in steps.items()
@@ -1393,25 +1401,17 @@ class RunnableBinding(Serializable, Runnable[Input, Output]):
         self,
         inputs: List[Input],
         config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
-        *,
-        max_concurrency: Optional[int] = None,
         **kwargs: Optional[Any],
     ) -> List[Output]:
-        return self.bound.batch(
-            inputs, config, max_concurrency=max_concurrency, **{**self.kwargs, **kwargs}
-        )
+        return self.bound.batch(inputs, config, **{**self.kwargs, **kwargs})
 
     async def abatch(
         self,
         inputs: List[Input],
         config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
-        *,
-        max_concurrency: Optional[int] = None,
         **kwargs: Optional[Any],
    ) -> List[Output]:
-        return await self.bound.abatch(
-            inputs, config, max_concurrency=max_concurrency, **{**self.kwargs, **kwargs}
-        )
+        return await self.bound.abatch(inputs, config, **{**self.kwargs, **kwargs})
 
     def stream(
         self,
@@ -1,6 +1,9 @@
 from __future__ import annotations
 
-from typing import Any, Dict, List, Optional, TypedDict
+from concurrent.futures import Executor, ThreadPoolExecutor
+from contextlib import contextmanager
+from copy import deepcopy
+from typing import Any, Dict, Generator, List, Optional, TypedDict
 
 from langchain.callbacks.base import BaseCallbackManager, Callbacks
 from langchain.callbacks.manager import AsyncCallbackManager, CallbackManager
@@ -32,20 +35,44 @@ class RunnableConfig(TypedDict, total=False):
     Local variables
     """
 
+    max_concurrency: Optional[int]
+    """
+    Maximum number of parallel calls to make. If not provided, defaults to ThreadPoolExecutor's default. This is ignored if an executor is provided.
+    """
+
+    executor: Executor
+    """
+    Externally-managed executor to use for parallel calls. If not provided, a new
+    ThreadPoolExecutor will be created.
+    """
+
 
 def ensure_config(config: Optional[RunnableConfig]) -> RunnableConfig:
-    empty = RunnableConfig(tags=[], metadata={}, callbacks=None, _locals={})
+    empty = RunnableConfig(
+        tags=[],
+        metadata={},
+        callbacks=None,
+        _locals={},
+    )
     if config is not None:
         empty.update(config)
     return empty
 
 
 def patch_config(
-    config: RunnableConfig,
-    callbacks: BaseCallbackManager,
+    config: Optional[RunnableConfig],
+    *,
+    deep_copy_locals: bool = False,
+    callbacks: Optional[BaseCallbackManager] = None,
+    executor: Optional[Executor] = None,
 ) -> RunnableConfig:
-    config = config.copy()
-    config["callbacks"] = callbacks
+    config = ensure_config(config)
+    if deep_copy_locals:
+        config["_locals"] = deepcopy(config["_locals"])
+    if callbacks is not None:
+        config["callbacks"] = callbacks
+    if executor is not None:
+        config["executor"] = executor
     return config
 
 
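patch_config() is now keyword-only, accepts None (via ensure_config), and can also deep-copy _locals or attach an executor. An assumed usage sketch (import path taken from this commit's hunks; not authoritative):

    from langchain.schema.runnable.config import ensure_config, patch_config

    config = ensure_config({"tags": ["root"], "max_concurrency": 4})
    child = patch_config(config, deep_copy_locals=True)

    child["_locals"]["scratch"] = "child only"
    print(config["_locals"])                             # {} -- parent _locals untouched
    print(child["tags"], child.get("max_concurrency"))   # ['root'] 4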
@@ -65,3 +92,12 @@ def get_async_callback_manager_for_config(
         inheritable_tags=config.get("tags"),
         inheritable_metadata=config.get("metadata"),
     )
+
+
+@contextmanager
+def get_executor_for_config(config: RunnableConfig) -> Generator[Executor, None, None]:
+    if config.get("executor"):
+        yield config["executor"]
+    else:
+        with ThreadPoolExecutor(max_workers=config.get("max_concurrency")) as executor:
+            yield executor
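get_executor_for_config() is the piece that actually shares the pool: if the config already carries an executor it is yielded as-is, otherwise a ThreadPoolExecutor bounded by max_concurrency is created for the duration of the block. Assumed usage (module path taken from the hunks above; a sketch, not authoritative):

    from langchain.schema.runnable.config import get_executor_for_config

    config = {"max_concurrency": 2}  # no "executor" key, so a bounded pool is created
    with get_executor_for_config(config) as executor:
        print(sorted(executor.map(lambda x: x * x, [1, 2, 3])))  # [1, 4, 9]

    # When a caller has already opened a pool, patch_config(..., executor=executor)
    # puts it in the config, and nested batch()/invoke() calls reuse that same pool.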
@@ -131,26 +131,25 @@ async def test_default_method_implementations(mocker: MockerFixture) -> None:
     assert fake.batch(
         ["hello", "wooorld"], [dict(tags=["a-tag"]), dict(metadata={"key": "value"})]
     ) == [5, 7]
-    assert spy.call_args_list == [
-        mocker.call(
-            "hello", dict(tags=["a-tag"], metadata={}, callbacks=None, _locals={})
-        ),
-        mocker.call(
-            "wooorld",
-            dict(metadata={"key": "value"}, tags=[], callbacks=None, _locals={}),
-        ),
-    ]
+    assert len(spy.call_args_list) == 2
+    for i, call in enumerate(spy.call_args_list):
+        assert call.args[0] == ("hello" if i == 0 else "wooorld")
+        if i == 0:
+            assert call.args[1].get("tags") == ["a-tag"]
+            assert call.args[1].get("metadata") == {}
+        else:
+            assert call.args[1].get("tags") == []
+            assert call.args[1].get("metadata") == {"key": "value"}
+
     spy.reset_mock()
 
     assert fake.batch(["hello", "wooorld"], dict(tags=["a-tag"])) == [5, 7]
-    assert spy.call_args_list == [
-        mocker.call(
-            "hello", dict(tags=["a-tag"], metadata={}, callbacks=None, _locals={})
-        ),
-        mocker.call(
-            "wooorld", dict(tags=["a-tag"], metadata={}, callbacks=None, _locals={})
-        ),
-    ]
+    assert len(spy.call_args_list) == 2
+    for i, call in enumerate(spy.call_args_list):
+        assert call.args[0] == ("hello" if i == 0 else "wooorld")
+        assert call.args[1].get("tags") == ["a-tag"]
+        assert call.args[1].get("metadata") == {}
+
     spy.reset_mock()
 
     assert await fake.ainvoke("hello", config={"callbacks": []}) == 5
@@ -172,11 +171,21 @@ async def test_default_method_implementations(mocker: MockerFixture) -> None:
     assert spy.call_args_list == [
         mocker.call(
             "hello",
-            dict(metadata={"key": "value"}, tags=[], callbacks=None, _locals={}),
+            dict(
+                metadata={"key": "value"},
+                tags=[],
+                callbacks=None,
+                _locals={},
+            ),
         ),
         mocker.call(
             "wooorld",
-            dict(metadata={"key": "value"}, tags=[], callbacks=None, _locals={}),
+            dict(
+                metadata={"key": "value"},
+                tags=[],
+                callbacks=None,
+                _locals={},
+            ),
         ),
     ]
 
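The batch tests now inspect each recorded call's positional args instead of comparing the whole call list against exact config dicts, so new config keys (max_concurrency, executor) do not break them. A toy sketch of that assertion style with mock call objects (illustrative only, not the test itself):

    from unittest.mock import MagicMock

    fn = MagicMock(side_effect=len)
    fn("hello", {"tags": ["a-tag"], "metadata": {}})
    fn("wooorld", {"tags": [], "metadata": {"key": "value"}})

    assert len(fn.call_args_list) == 2
    for i, call in enumerate(fn.call_args_list):
        assert call.args[0] == ("hello" if i == 0 else "wooorld")
        assert call.args[1].get("metadata") == ({} if i == 0 else {"key": "value"})
    print("assertions passed")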