Fix executor

This commit is contained in:
Nuno Campos 2023-12-29 15:50:45 -08:00
parent 9bb1fbcadf
commit 4e4b119614

View File

@ -1,9 +1,9 @@
from __future__ import annotations from __future__ import annotations
import asyncio import asyncio
from concurrent.futures import Executor, ThreadPoolExecutor from concurrent.futures import Executor, Future, ThreadPoolExecutor
from contextlib import contextmanager from contextlib import contextmanager
from contextvars import Context, copy_context from contextvars import ContextVar, copy_context
from functools import partial from functools import partial
from typing import ( from typing import (
TYPE_CHECKING, TYPE_CHECKING,
@ -12,6 +12,8 @@ from typing import (
Callable, Callable,
Dict, Dict,
Generator, Generator,
Iterable,
Iterator,
List, List,
Optional, Optional,
TypeVar, TypeVar,
@ -94,6 +96,11 @@ class RunnableConfig(TypedDict, total=False):
""" """
var_child_runnable_config = ContextVar(
"child_runnable_config", default=RunnableConfig()
)
def ensure_config(config: Optional[RunnableConfig] = None) -> RunnableConfig: def ensure_config(config: Optional[RunnableConfig] = None) -> RunnableConfig:
"""Ensure that a config is a dict with all keys present. """Ensure that a config is a dict with all keys present.
@ -110,6 +117,10 @@ def ensure_config(config: Optional[RunnableConfig] = None) -> RunnableConfig:
callbacks=None, callbacks=None,
recursion_limit=25, recursion_limit=25,
) )
if var_config := var_child_runnable_config.get():
empty.update(
cast(RunnableConfig, {k: v for k, v in var_config.items() if v is not None})
)
if config is not None: if config is not None:
empty.update( empty.update(
cast(RunnableConfig, {k: v for k, v in config.items() if v is not None}) cast(RunnableConfig, {k: v for k, v in config.items() if v is not None})
@ -391,9 +402,51 @@ def get_async_callback_manager_for_config(
) )
def _set_context(context: Context) -> None: P = ParamSpec("P")
for var, value in context.items(): T = TypeVar("T")
var.set(value)
class ContextThreadPoolExecutor(ThreadPoolExecutor):
"""ThreadPoolExecutor that copies the context to the child thread."""
def submit( # type: ignore[override]
self,
func: Callable[P, T],
*args: P.args,
**kwargs: P.kwargs,
) -> Future[T]:
"""Submit a function to the executor.
Args:
func (Callable[..., T]): The function to submit.
*args (Any): The positional arguments to the function.
**kwargs (Any): The keyword arguments to the function.
Returns:
Future[T]: The future for the function.
"""
return super().submit(
cast(Callable[..., T], partial(copy_context().run, func, *args, **kwargs))
)
def map(
self,
fn: Callable[..., T],
*iterables: Iterable[Any],
timeout: float | None = None,
chunksize: int = 1,
) -> Iterator[T]:
contexts = [copy_context() for _ in range(len(iterables[0]))] # type: ignore[arg-type]
def _wrapped_fn(*args: Any) -> T:
return contexts.pop().run(fn, *args)
return super().map(
_wrapped_fn,
*iterables,
timeout=timeout,
chunksize=chunksize,
)
@contextmanager @contextmanager
@ -409,18 +462,12 @@ def get_executor_for_config(
Generator[Executor, None, None]: The executor. Generator[Executor, None, None]: The executor.
""" """
config = config or {} config = config or {}
with ThreadPoolExecutor( with ContextThreadPoolExecutor(
max_workers=config.get("max_concurrency"), max_workers=config.get("max_concurrency")
initializer=_set_context,
initargs=(copy_context(),),
) as executor: ) as executor:
yield executor yield executor
P = ParamSpec("P")
T = TypeVar("T")
async def run_in_executor( async def run_in_executor(
executor_or_config: Optional[Union[Executor, RunnableConfig]], executor_or_config: Optional[Union[Executor, RunnableConfig]],
func: Callable[P, T], func: Callable[P, T],