mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-26 08:33:49 +00:00
Fix executor
This commit is contained in:
parent
9bb1fbcadf
commit
4e4b119614
@ -1,9 +1,9 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from concurrent.futures import Executor, ThreadPoolExecutor
|
||||
from concurrent.futures import Executor, Future, ThreadPoolExecutor
|
||||
from contextlib import contextmanager
|
||||
from contextvars import Context, copy_context
|
||||
from contextvars import ContextVar, copy_context
|
||||
from functools import partial
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
@ -12,6 +12,8 @@ from typing import (
|
||||
Callable,
|
||||
Dict,
|
||||
Generator,
|
||||
Iterable,
|
||||
Iterator,
|
||||
List,
|
||||
Optional,
|
||||
TypeVar,
|
||||
@ -94,6 +96,11 @@ class RunnableConfig(TypedDict, total=False):
|
||||
"""
|
||||
|
||||
|
||||
var_child_runnable_config = ContextVar(
|
||||
"child_runnable_config", default=RunnableConfig()
|
||||
)
|
||||
|
||||
|
||||
def ensure_config(config: Optional[RunnableConfig] = None) -> RunnableConfig:
|
||||
"""Ensure that a config is a dict with all keys present.
|
||||
|
||||
@ -110,6 +117,10 @@ def ensure_config(config: Optional[RunnableConfig] = None) -> RunnableConfig:
|
||||
callbacks=None,
|
||||
recursion_limit=25,
|
||||
)
|
||||
if var_config := var_child_runnable_config.get():
|
||||
empty.update(
|
||||
cast(RunnableConfig, {k: v for k, v in var_config.items() if v is not None})
|
||||
)
|
||||
if config is not None:
|
||||
empty.update(
|
||||
cast(RunnableConfig, {k: v for k, v in config.items() if v is not None})
|
||||
@ -391,9 +402,51 @@ def get_async_callback_manager_for_config(
|
||||
)
|
||||
|
||||
|
||||
def _set_context(context: Context) -> None:
|
||||
for var, value in context.items():
|
||||
var.set(value)
|
||||
P = ParamSpec("P")
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
class ContextThreadPoolExecutor(ThreadPoolExecutor):
|
||||
"""ThreadPoolExecutor that copies the context to the child thread."""
|
||||
|
||||
def submit( # type: ignore[override]
|
||||
self,
|
||||
func: Callable[P, T],
|
||||
*args: P.args,
|
||||
**kwargs: P.kwargs,
|
||||
) -> Future[T]:
|
||||
"""Submit a function to the executor.
|
||||
|
||||
Args:
|
||||
func (Callable[..., T]): The function to submit.
|
||||
*args (Any): The positional arguments to the function.
|
||||
**kwargs (Any): The keyword arguments to the function.
|
||||
|
||||
Returns:
|
||||
Future[T]: The future for the function.
|
||||
"""
|
||||
return super().submit(
|
||||
cast(Callable[..., T], partial(copy_context().run, func, *args, **kwargs))
|
||||
)
|
||||
|
||||
def map(
|
||||
self,
|
||||
fn: Callable[..., T],
|
||||
*iterables: Iterable[Any],
|
||||
timeout: float | None = None,
|
||||
chunksize: int = 1,
|
||||
) -> Iterator[T]:
|
||||
contexts = [copy_context() for _ in range(len(iterables[0]))] # type: ignore[arg-type]
|
||||
|
||||
def _wrapped_fn(*args: Any) -> T:
|
||||
return contexts.pop().run(fn, *args)
|
||||
|
||||
return super().map(
|
||||
_wrapped_fn,
|
||||
*iterables,
|
||||
timeout=timeout,
|
||||
chunksize=chunksize,
|
||||
)
|
||||
|
||||
|
||||
@contextmanager
|
||||
@ -409,18 +462,12 @@ def get_executor_for_config(
|
||||
Generator[Executor, None, None]: The executor.
|
||||
"""
|
||||
config = config or {}
|
||||
with ThreadPoolExecutor(
|
||||
max_workers=config.get("max_concurrency"),
|
||||
initializer=_set_context,
|
||||
initargs=(copy_context(),),
|
||||
with ContextThreadPoolExecutor(
|
||||
max_workers=config.get("max_concurrency")
|
||||
) as executor:
|
||||
yield executor
|
||||
|
||||
|
||||
P = ParamSpec("P")
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
async def run_in_executor(
|
||||
executor_or_config: Optional[Union[Executor, RunnableConfig]],
|
||||
func: Callable[P, T],
|
||||
|
Loading…
Reference in New Issue
Block a user