Fix executor

This commit is contained in:
Nuno Campos 2023-12-29 15:50:45 -08:00
parent 9bb1fbcadf
commit 4e4b119614

View File

@ -1,9 +1,9 @@
from __future__ import annotations
import asyncio
from concurrent.futures import Executor, ThreadPoolExecutor
from concurrent.futures import Executor, Future, ThreadPoolExecutor
from contextlib import contextmanager
from contextvars import Context, copy_context
from contextvars import ContextVar, copy_context
from functools import partial
from typing import (
TYPE_CHECKING,
@ -12,6 +12,8 @@ from typing import (
Callable,
Dict,
Generator,
Iterable,
Iterator,
List,
Optional,
TypeVar,
@ -94,6 +96,11 @@ class RunnableConfig(TypedDict, total=False):
"""
var_child_runnable_config = ContextVar(
"child_runnable_config", default=RunnableConfig()
)
def ensure_config(config: Optional[RunnableConfig] = None) -> RunnableConfig:
"""Ensure that a config is a dict with all keys present.
@ -110,6 +117,10 @@ def ensure_config(config: Optional[RunnableConfig] = None) -> RunnableConfig:
callbacks=None,
recursion_limit=25,
)
if var_config := var_child_runnable_config.get():
empty.update(
cast(RunnableConfig, {k: v for k, v in var_config.items() if v is not None})
)
if config is not None:
empty.update(
cast(RunnableConfig, {k: v for k, v in config.items() if v is not None})
@ -391,9 +402,51 @@ def get_async_callback_manager_for_config(
)
def _set_context(context: Context) -> None:
for var, value in context.items():
var.set(value)
P = ParamSpec("P")
T = TypeVar("T")
class ContextThreadPoolExecutor(ThreadPoolExecutor):
"""ThreadPoolExecutor that copies the context to the child thread."""
def submit( # type: ignore[override]
self,
func: Callable[P, T],
*args: P.args,
**kwargs: P.kwargs,
) -> Future[T]:
"""Submit a function to the executor.
Args:
func (Callable[..., T]): The function to submit.
*args (Any): The positional arguments to the function.
**kwargs (Any): The keyword arguments to the function.
Returns:
Future[T]: The future for the function.
"""
return super().submit(
cast(Callable[..., T], partial(copy_context().run, func, *args, **kwargs))
)
def map(
self,
fn: Callable[..., T],
*iterables: Iterable[Any],
timeout: float | None = None,
chunksize: int = 1,
) -> Iterator[T]:
contexts = [copy_context() for _ in range(len(iterables[0]))] # type: ignore[arg-type]
def _wrapped_fn(*args: Any) -> T:
return contexts.pop().run(fn, *args)
return super().map(
_wrapped_fn,
*iterables,
timeout=timeout,
chunksize=chunksize,
)
@contextmanager
@ -409,18 +462,12 @@ def get_executor_for_config(
Generator[Executor, None, None]: The executor.
"""
config = config or {}
with ThreadPoolExecutor(
max_workers=config.get("max_concurrency"),
initializer=_set_context,
initargs=(copy_context(),),
with ContextThreadPoolExecutor(
max_workers=config.get("max_concurrency")
) as executor:
yield executor
P = ParamSpec("P")
T = TypeVar("T")
async def run_in_executor(
executor_or_config: Optional[Union[Executor, RunnableConfig]],
func: Callable[P, T],