diff --git a/libs/core/langchain_core/runnables/config.py b/libs/core/langchain_core/runnables/config.py index 5a580697f95..3b9dc50c8bf 100644 --- a/libs/core/langchain_core/runnables/config.py +++ b/libs/core/langchain_core/runnables/config.py @@ -1,9 +1,9 @@ from __future__ import annotations import asyncio -from concurrent.futures import Executor, ThreadPoolExecutor +from concurrent.futures import Executor, Future, ThreadPoolExecutor from contextlib import contextmanager -from contextvars import Context, copy_context +from contextvars import ContextVar, copy_context from functools import partial from typing import ( TYPE_CHECKING, @@ -12,6 +12,8 @@ from typing import ( Callable, Dict, Generator, + Iterable, + Iterator, List, Optional, TypeVar, @@ -94,6 +96,11 @@ class RunnableConfig(TypedDict, total=False): """ +var_child_runnable_config = ContextVar( + "child_runnable_config", default=RunnableConfig() +) + + def ensure_config(config: Optional[RunnableConfig] = None) -> RunnableConfig: """Ensure that a config is a dict with all keys present. @@ -110,6 +117,10 @@ def ensure_config(config: Optional[RunnableConfig] = None) -> RunnableConfig: callbacks=None, recursion_limit=25, ) + if var_config := var_child_runnable_config.get(): + empty.update( + cast(RunnableConfig, {k: v for k, v in var_config.items() if v is not None}) + ) if config is not None: empty.update( cast(RunnableConfig, {k: v for k, v in config.items() if v is not None}) @@ -391,9 +402,51 @@ def get_async_callback_manager_for_config( ) -def _set_context(context: Context) -> None: - for var, value in context.items(): - var.set(value) +P = ParamSpec("P") +T = TypeVar("T") + + +class ContextThreadPoolExecutor(ThreadPoolExecutor): + """ThreadPoolExecutor that copies the context to the child thread.""" + + def submit( # type: ignore[override] + self, + func: Callable[P, T], + *args: P.args, + **kwargs: P.kwargs, + ) -> Future[T]: + """Submit a function to the executor. + + Args: + func (Callable[..., T]): The function to submit. + *args (Any): The positional arguments to the function. + **kwargs (Any): The keyword arguments to the function. + + Returns: + Future[T]: The future for the function. + """ + return super().submit( + cast(Callable[..., T], partial(copy_context().run, func, *args, **kwargs)) + ) + + def map( + self, + fn: Callable[..., T], + *iterables: Iterable[Any], + timeout: float | None = None, + chunksize: int = 1, + ) -> Iterator[T]: + contexts = [copy_context() for _ in range(len(iterables[0]))] # type: ignore[arg-type] + + def _wrapped_fn(*args: Any) -> T: + return contexts.pop().run(fn, *args) + + return super().map( + _wrapped_fn, + *iterables, + timeout=timeout, + chunksize=chunksize, + ) @contextmanager @@ -409,18 +462,12 @@ def get_executor_for_config( Generator[Executor, None, None]: The executor. """ config = config or {} - with ThreadPoolExecutor( - max_workers=config.get("max_concurrency"), - initializer=_set_context, - initargs=(copy_context(),), + with ContextThreadPoolExecutor( + max_workers=config.get("max_concurrency") ) as executor: yield executor -P = ParamSpec("P") -T = TypeVar("T") - - async def run_in_executor( executor_or_config: Optional[Union[Executor, RunnableConfig]], func: Callable[P, T],