mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-28 01:19:31 +00:00
Fix executor
This commit is contained in:
parent
9bb1fbcadf
commit
4e4b119614
@ -1,9 +1,9 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
from concurrent.futures import Executor, ThreadPoolExecutor
|
from concurrent.futures import Executor, Future, ThreadPoolExecutor
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
from contextvars import Context, copy_context
|
from contextvars import ContextVar, copy_context
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from typing import (
|
from typing import (
|
||||||
TYPE_CHECKING,
|
TYPE_CHECKING,
|
||||||
@ -12,6 +12,8 @@ from typing import (
|
|||||||
Callable,
|
Callable,
|
||||||
Dict,
|
Dict,
|
||||||
Generator,
|
Generator,
|
||||||
|
Iterable,
|
||||||
|
Iterator,
|
||||||
List,
|
List,
|
||||||
Optional,
|
Optional,
|
||||||
TypeVar,
|
TypeVar,
|
||||||
@ -94,6 +96,11 @@ class RunnableConfig(TypedDict, total=False):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
var_child_runnable_config = ContextVar(
|
||||||
|
"child_runnable_config", default=RunnableConfig()
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def ensure_config(config: Optional[RunnableConfig] = None) -> RunnableConfig:
|
def ensure_config(config: Optional[RunnableConfig] = None) -> RunnableConfig:
|
||||||
"""Ensure that a config is a dict with all keys present.
|
"""Ensure that a config is a dict with all keys present.
|
||||||
|
|
||||||
@ -110,6 +117,10 @@ def ensure_config(config: Optional[RunnableConfig] = None) -> RunnableConfig:
|
|||||||
callbacks=None,
|
callbacks=None,
|
||||||
recursion_limit=25,
|
recursion_limit=25,
|
||||||
)
|
)
|
||||||
|
if var_config := var_child_runnable_config.get():
|
||||||
|
empty.update(
|
||||||
|
cast(RunnableConfig, {k: v for k, v in var_config.items() if v is not None})
|
||||||
|
)
|
||||||
if config is not None:
|
if config is not None:
|
||||||
empty.update(
|
empty.update(
|
||||||
cast(RunnableConfig, {k: v for k, v in config.items() if v is not None})
|
cast(RunnableConfig, {k: v for k, v in config.items() if v is not None})
|
||||||
@ -391,9 +402,51 @@ def get_async_callback_manager_for_config(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def _set_context(context: Context) -> None:
|
P = ParamSpec("P")
|
||||||
for var, value in context.items():
|
T = TypeVar("T")
|
||||||
var.set(value)
|
|
||||||
|
|
||||||
|
class ContextThreadPoolExecutor(ThreadPoolExecutor):
|
||||||
|
"""ThreadPoolExecutor that copies the context to the child thread."""
|
||||||
|
|
||||||
|
def submit( # type: ignore[override]
|
||||||
|
self,
|
||||||
|
func: Callable[P, T],
|
||||||
|
*args: P.args,
|
||||||
|
**kwargs: P.kwargs,
|
||||||
|
) -> Future[T]:
|
||||||
|
"""Submit a function to the executor.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
func (Callable[..., T]): The function to submit.
|
||||||
|
*args (Any): The positional arguments to the function.
|
||||||
|
**kwargs (Any): The keyword arguments to the function.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Future[T]: The future for the function.
|
||||||
|
"""
|
||||||
|
return super().submit(
|
||||||
|
cast(Callable[..., T], partial(copy_context().run, func, *args, **kwargs))
|
||||||
|
)
|
||||||
|
|
||||||
|
def map(
|
||||||
|
self,
|
||||||
|
fn: Callable[..., T],
|
||||||
|
*iterables: Iterable[Any],
|
||||||
|
timeout: float | None = None,
|
||||||
|
chunksize: int = 1,
|
||||||
|
) -> Iterator[T]:
|
||||||
|
contexts = [copy_context() for _ in range(len(iterables[0]))] # type: ignore[arg-type]
|
||||||
|
|
||||||
|
def _wrapped_fn(*args: Any) -> T:
|
||||||
|
return contexts.pop().run(fn, *args)
|
||||||
|
|
||||||
|
return super().map(
|
||||||
|
_wrapped_fn,
|
||||||
|
*iterables,
|
||||||
|
timeout=timeout,
|
||||||
|
chunksize=chunksize,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@contextmanager
|
@contextmanager
|
||||||
@ -409,18 +462,12 @@ def get_executor_for_config(
|
|||||||
Generator[Executor, None, None]: The executor.
|
Generator[Executor, None, None]: The executor.
|
||||||
"""
|
"""
|
||||||
config = config or {}
|
config = config or {}
|
||||||
with ThreadPoolExecutor(
|
with ContextThreadPoolExecutor(
|
||||||
max_workers=config.get("max_concurrency"),
|
max_workers=config.get("max_concurrency")
|
||||||
initializer=_set_context,
|
|
||||||
initargs=(copy_context(),),
|
|
||||||
) as executor:
|
) as executor:
|
||||||
yield executor
|
yield executor
|
||||||
|
|
||||||
|
|
||||||
P = ParamSpec("P")
|
|
||||||
T = TypeVar("T")
|
|
||||||
|
|
||||||
|
|
||||||
async def run_in_executor(
|
async def run_in_executor(
|
||||||
executor_or_config: Optional[Union[Executor, RunnableConfig]],
|
executor_or_config: Optional[Union[Executor, RunnableConfig]],
|
||||||
func: Callable[P, T],
|
func: Callable[P, T],
|
||||||
|
Loading…
Reference in New Issue
Block a user