From d414d47c7891064be3ceefe7ea537602b410ffa5 Mon Sep 17 00:00:00 2001 From: Nuno Campos Date: Fri, 18 Aug 2023 11:34:18 +0100 Subject: [PATCH] Use a shared executor for all parallel calls --- .../langchain/schema/runnable/__init__.py | 3 +- .../langchain/schema/runnable/base.py | 132 +++++++++--------- .../langchain/schema/runnable/config.py | 48 ++++++- .../schema/runnable/test_runnable.py | 47 ++++--- 4 files changed, 138 insertions(+), 92 deletions(-) diff --git a/libs/langchain/langchain/schema/runnable/__init__.py b/libs/langchain/langchain/schema/runnable/__init__.py index 16f99324b0f..24b235d4d87 100644 --- a/libs/langchain/langchain/schema/runnable/__init__.py +++ b/libs/langchain/langchain/schema/runnable/__init__.py @@ -7,11 +7,12 @@ from langchain.schema.runnable.base import ( RunnableSequence, RunnableWithFallbacks, ) -from langchain.schema.runnable.config import RunnableConfig +from langchain.schema.runnable.config import RunnableConfig, patch_config from langchain.schema.runnable.passthrough import RunnablePassthrough from langchain.schema.runnable.router import RouterInput, RouterRunnable __all__ = [ + "patch_config", "GetLocalVar", "PutLocalVar", "RouterInput", diff --git a/libs/langchain/langchain/schema/runnable/base.py b/libs/langchain/langchain/schema/runnable/base.py index aab395c46c8..fd0666d93d1 100644 --- a/libs/langchain/langchain/schema/runnable/base.py +++ b/libs/langchain/langchain/schema/runnable/base.py @@ -1,11 +1,9 @@ from __future__ import annotations import asyncio -import copy import threading from abc import ABC, abstractmethod -from concurrent.futures import FIRST_COMPLETED, ThreadPoolExecutor, wait -from copy import deepcopy +from concurrent.futures import FIRST_COMPLETED, wait from functools import partial from itertools import tee from typing import ( @@ -43,6 +41,7 @@ from langchain.schema.runnable.config import ( ensure_config, get_async_callback_manager_for_config, get_callback_manager_for_config, + get_executor_for_config, patch_config, ) from langchain.schema.runnable.utils import ( @@ -104,8 +103,6 @@ class Runnable(Generic[Input, Output], ABC): self, inputs: List[Input], config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None, - *, - max_concurrency: Optional[int] = None, **kwargs: Optional[Any], ) -> List[Output]: """ @@ -118,15 +115,19 @@ class Runnable(Generic[Input, Output], ABC): if len(inputs) == 1: return [self.invoke(inputs[0], configs[0], **kwargs)] - with ThreadPoolExecutor(max_workers=max_concurrency) as executor: - return list(executor.map(partial(self.invoke, **kwargs), inputs, configs)) + with get_executor_for_config(configs[0]) as executor: + return list( + executor.map( + partial(self.invoke, **kwargs), + inputs, + (patch_config(c, executor=executor) for c in configs), + ) + ) async def abatch( self, inputs: List[Input], config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None, - *, - max_concurrency: Optional[int] = None, **kwargs: Optional[Any], ) -> List[Output]: """ @@ -136,7 +137,7 @@ class Runnable(Generic[Input, Output], ABC): configs = self._get_config_list(config, len(inputs)) coros = map(partial(self.ainvoke, **kwargs), inputs, configs) - return await gather_with_concurrency(max_concurrency, *coros) + return await gather_with_concurrency(configs[0].get("max_concurrency"), *coros) def stream( self, @@ -246,7 +247,7 @@ class Runnable(Generic[Input, Output], ABC): return ( list(map(ensure_config, config)) if isinstance(config, list) - else [deepcopy(ensure_config(config)) for _ in range(length)] + else [patch_config(config, deep_copy_locals=True) for _ in range(length)] ) def _call_with_config( @@ -527,7 +528,7 @@ class RunnableWithFallbacks(Serializable, Runnable[Input, Output]): try: output = runnable.invoke( input, - patch_config(config, run_manager.get_child()), + patch_config(config, callbacks=run_manager.get_child()), ) except self.exceptions_to_handle as e: if first_error is None: @@ -560,7 +561,7 @@ class RunnableWithFallbacks(Serializable, Runnable[Input, Output]): try: output = await runnable.ainvoke( input, - patch_config(config, run_manager.get_child()), + patch_config(config, callbacks=run_manager.get_child()), ) except self.exceptions_to_handle as e: if first_error is None: @@ -580,8 +581,6 @@ class RunnableWithFallbacks(Serializable, Runnable[Input, Output]): self, inputs: List[Input], config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None, - *, - max_concurrency: Optional[int] = None, **kwargs: Optional[Any], ) -> List[Output]: from langchain.callbacks.manager import CallbackManager @@ -615,10 +614,9 @@ class RunnableWithFallbacks(Serializable, Runnable[Input, Output]): inputs, [ # each step a child run of the corresponding root run - patch_config(config, rm.get_child()) + patch_config(config, callbacks=rm.get_child()) for rm, config in zip(run_managers, configs) ], - max_concurrency=max_concurrency, ) except self.exceptions_to_handle as e: if first_error is None: @@ -641,14 +639,9 @@ class RunnableWithFallbacks(Serializable, Runnable[Input, Output]): self, inputs: List[Input], config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None, - *, - max_concurrency: Optional[int] = None, **kwargs: Optional[Any], ) -> List[Output]: - from langchain.callbacks.manager import ( - AsyncCallbackManager, - AsyncCallbackManagerForChainRun, - ) + from langchain.callbacks.manager import AsyncCallbackManager # setup callbacks configs = self._get_config_list(config, len(inputs)) @@ -679,10 +672,9 @@ class RunnableWithFallbacks(Serializable, Runnable[Input, Output]): inputs, [ # each step a child run of the corresponding root run - patch_config(config, rm.get_child()) + patch_config(config, callbacks=rm.get_child()) for rm, config in zip(run_managers, configs) ], - max_concurrency=max_concurrency, ) except self.exceptions_to_handle as e: if first_error is None: @@ -782,7 +774,7 @@ class RunnableSequence(Serializable, Runnable[Input, Output]): input = step.invoke( input, # mark each step as a child run - patch_config(config, run_manager.get_child()), + patch_config(config, callbacks=run_manager.get_child()), ) # finish the root run except (KeyboardInterrupt, Exception) as e: @@ -810,7 +802,7 @@ class RunnableSequence(Serializable, Runnable[Input, Output]): input = await step.ainvoke( input, # mark each step as a child run - patch_config(config, run_manager.get_child()), + patch_config(config, callbacks=run_manager.get_child()), ) # finish the root run except (KeyboardInterrupt, Exception) as e: @@ -824,8 +816,6 @@ class RunnableSequence(Serializable, Runnable[Input, Output]): self, inputs: List[Input], config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None, - *, - max_concurrency: Optional[int] = None, **kwargs: Optional[Any], ) -> List[Output]: from langchain.callbacks.manager import CallbackManager @@ -852,16 +842,18 @@ class RunnableSequence(Serializable, Runnable[Input, Output]): # invoke try: - for step in self.steps: - inputs = step.batch( - inputs, - [ - # each step a child run of the corresponding root run - patch_config(config, rm.get_child()) - for rm, config in zip(run_managers, configs) - ], - max_concurrency=max_concurrency, - ) + with get_executor_for_config(configs[0]) as executor: + for step in self.steps: + inputs = step.batch( + inputs, + [ + # each step a child run of the corresponding root run + patch_config( + config, callbacks=rm.get_child(), executor=executor + ) + for rm, config in zip(run_managers, configs) + ], + ) # finish the root runs except (KeyboardInterrupt, Exception) as e: for rm in run_managers: @@ -876,8 +868,6 @@ class RunnableSequence(Serializable, Runnable[Input, Output]): self, inputs: List[Input], config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None, - *, - max_concurrency: Optional[int] = None, **kwargs: Optional[Any], ) -> List[Output]: from langchain.callbacks.manager import ( @@ -914,10 +904,9 @@ class RunnableSequence(Serializable, Runnable[Input, Output]): inputs, [ # each step a child run of the corresponding root run - patch_config(config, rm.get_child()) + patch_config(config, callbacks=rm.get_child()) for rm, config in zip(run_managers, configs) ], - max_concurrency=max_concurrency, ) # finish the root runs except (KeyboardInterrupt, Exception) as e: @@ -956,7 +945,7 @@ class RunnableSequence(Serializable, Runnable[Input, Output]): input = step.invoke( input, # mark each step as a child run - patch_config(config, run_manager.get_child()), + patch_config(config, callbacks=run_manager.get_child()), ) except (KeyboardInterrupt, Exception) as e: run_manager.on_chain_error(e) @@ -968,12 +957,13 @@ class RunnableSequence(Serializable, Runnable[Input, Output]): try: # stream the first of the last steps with non-streaming input final_pipeline = steps[streaming_start_index].stream( - input, patch_config(config, run_manager.get_child()) + input, patch_config(config, callbacks=run_manager.get_child()) ) # stream the rest of the last steps with streaming input for step in steps[streaming_start_index + 1 :]: final_pipeline = step.transform( - final_pipeline, patch_config(config, run_manager.get_child()) + final_pipeline, + patch_config(config, callbacks=run_manager.get_child()), ) for output in final_pipeline: yield output @@ -1022,7 +1012,7 @@ class RunnableSequence(Serializable, Runnable[Input, Output]): input = await step.ainvoke( input, # mark each step as a child run - patch_config(config, run_manager.get_child()), + patch_config(config, callbacks=run_manager.get_child()), ) except (KeyboardInterrupt, Exception) as e: await run_manager.on_chain_error(e) @@ -1034,12 +1024,13 @@ class RunnableSequence(Serializable, Runnable[Input, Output]): try: # stream the first of the last steps with non-streaming input final_pipeline = steps[streaming_start_index].astream( - input, patch_config(config, run_manager.get_child()) + input, patch_config(config, callbacks=run_manager.get_child()) ) # stream the rest of the last steps with streaming input for step in steps[streaming_start_index + 1 :]: final_pipeline = step.atransform( - final_pipeline, patch_config(config, run_manager.get_child()) + final_pipeline, + patch_config(config, callbacks=run_manager.get_child()), ) async for output in final_pipeline: yield output @@ -1068,7 +1059,7 @@ class RunnableMapChunk(Dict[str, Any]): """ def __add__(self, other: RunnableMapChunk) -> RunnableMapChunk: - chunk = copy.deepcopy(self) + chunk = self.copy() for key in other: if key not in chunk or chunk[key] is None: chunk[key] = other[key] @@ -1076,6 +1067,15 @@ class RunnableMapChunk(Dict[str, Any]): chunk[key] += other[key] return chunk + def __radd__(self, other: RunnableMapChunk) -> RunnableMapChunk: + chunk = RunnableMapChunk(other) + for key in self: + if key not in chunk or chunk[key] is None: + chunk[key] = self[key] + elif self[key] is not None: + chunk[key] += self[key] + return chunk + class RunnableMap(Serializable, Runnable[Input, Dict[str, Any]]): """ @@ -1132,13 +1132,18 @@ class RunnableMap(Serializable, Runnable[Input, Dict[str, Any]]): try: # copy to avoid issues from the caller mutating the steps during invoke() steps = dict(self.steps) - with ThreadPoolExecutor() as executor: + with get_executor_for_config(config) as executor: futures = [ executor.submit( step.invoke, input, # mark each step as a child run - patch_config(deepcopy(config), run_manager.get_child()), + patch_config( + config, + deep_copy_locals=True, + callbacks=run_manager.get_child(), + executor=executor, + ), ) for step in steps.values() ] @@ -1172,7 +1177,7 @@ class RunnableMap(Serializable, Runnable[Input, Dict[str, Any]]): step.ainvoke( input, # mark each step as a child run - patch_config(config, run_manager.get_child()), + patch_config(config, callbacks=run_manager.get_child()), ) for step in steps.values() ) @@ -1197,14 +1202,16 @@ class RunnableMap(Serializable, Runnable[Input, Dict[str, Any]]): # Each step gets a copy of the input iterator, # which is consumed in parallel in a separate thread. input_copies = list(safetee(input, len(steps), lock=threading.Lock())) - with ThreadPoolExecutor() as executor: + with get_executor_for_config(config) as executor: # Create the transform() generator for each step named_generators = [ ( name, step.transform( input_copies.pop(), - patch_config(config, run_manager.get_child()), + patch_config( + config, callbacks=run_manager.get_child(), executor=executor + ), ), ) for name, step in steps.items() @@ -1265,7 +1272,8 @@ class RunnableMap(Serializable, Runnable[Input, Dict[str, Any]]): ( name, step.atransform( - input_copies.pop(), patch_config(config, run_manager.get_child()) + input_copies.pop(), + patch_config(config, callbacks=run_manager.get_child()), ), ) for name, step in steps.items() @@ -1393,25 +1401,17 @@ class RunnableBinding(Serializable, Runnable[Input, Output]): self, inputs: List[Input], config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None, - *, - max_concurrency: Optional[int] = None, **kwargs: Optional[Any], ) -> List[Output]: - return self.bound.batch( - inputs, config, max_concurrency=max_concurrency, **{**self.kwargs, **kwargs} - ) + return self.bound.batch(inputs, config, **{**self.kwargs, **kwargs}) async def abatch( self, inputs: List[Input], config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None, - *, - max_concurrency: Optional[int] = None, **kwargs: Optional[Any], ) -> List[Output]: - return await self.bound.abatch( - inputs, config, max_concurrency=max_concurrency, **{**self.kwargs, **kwargs} - ) + return await self.bound.abatch(inputs, config, **{**self.kwargs, **kwargs}) def stream( self, diff --git a/libs/langchain/langchain/schema/runnable/config.py b/libs/langchain/langchain/schema/runnable/config.py index 00408b7ee6c..9eff2ffaa94 100644 --- a/libs/langchain/langchain/schema/runnable/config.py +++ b/libs/langchain/langchain/schema/runnable/config.py @@ -1,6 +1,9 @@ from __future__ import annotations +from concurrent.futures import Executor, ThreadPoolExecutor +from contextlib import contextmanager +from copy import deepcopy -from typing import Any, Dict, List, Optional, TypedDict +from typing import Any, Dict, Generator, List, Optional, TypedDict from langchain.callbacks.base import BaseCallbackManager, Callbacks from langchain.callbacks.manager import AsyncCallbackManager, CallbackManager @@ -32,20 +35,44 @@ class RunnableConfig(TypedDict, total=False): Local variables """ + max_concurrency: Optional[int] + """ + Maximum number of parallel calls to make. If not provided, defaults to ThreadPoolExecutor's default. This is ignored if an executor is provided. + """ + + executor: Executor + """ + Externally-managed executor to use for parallel calls. If not provided, a new + ThreadPoolExecutor will be created. + """ + def ensure_config(config: Optional[RunnableConfig]) -> RunnableConfig: - empty = RunnableConfig(tags=[], metadata={}, callbacks=None, _locals={}) + empty = RunnableConfig( + tags=[], + metadata={}, + callbacks=None, + _locals={}, + ) if config is not None: empty.update(config) return empty def patch_config( - config: RunnableConfig, - callbacks: BaseCallbackManager, + config: Optional[RunnableConfig], + *, + deep_copy_locals: bool = False, + callbacks: Optional[BaseCallbackManager] = None, + executor: Optional[Executor] = None, ) -> RunnableConfig: - config = config.copy() - config["callbacks"] = callbacks + config = ensure_config(config) + if deep_copy_locals: + config["_locals"] = deepcopy(config["_locals"]) + if callbacks is not None: + config["callbacks"] = callbacks + if executor is not None: + config["executor"] = executor return config @@ -65,3 +92,12 @@ def get_async_callback_manager_for_config( inheritable_tags=config.get("tags"), inheritable_metadata=config.get("metadata"), ) + + +@contextmanager +def get_executor_for_config(config: RunnableConfig) -> Generator[Executor, None, None]: + if config.get("executor"): + yield config["executor"] + else: + with ThreadPoolExecutor(max_workers=config.get("max_concurrency")) as executor: + yield executor diff --git a/libs/langchain/tests/unit_tests/schema/runnable/test_runnable.py b/libs/langchain/tests/unit_tests/schema/runnable/test_runnable.py index 80d63c69123..98cab172a56 100644 --- a/libs/langchain/tests/unit_tests/schema/runnable/test_runnable.py +++ b/libs/langchain/tests/unit_tests/schema/runnable/test_runnable.py @@ -131,26 +131,25 @@ async def test_default_method_implementations(mocker: MockerFixture) -> None: assert fake.batch( ["hello", "wooorld"], [dict(tags=["a-tag"]), dict(metadata={"key": "value"})] ) == [5, 7] - assert spy.call_args_list == [ - mocker.call( - "hello", dict(tags=["a-tag"], metadata={}, callbacks=None, _locals={}) - ), - mocker.call( - "wooorld", - dict(metadata={"key": "value"}, tags=[], callbacks=None, _locals={}), - ), - ] + + assert len(spy.call_args_list) == 2 + for i, call in enumerate(spy.call_args_list): + assert call.args[0] == ("hello" if i == 0 else "wooorld") + if i == 0: + assert call.args[1].get("tags") == ["a-tag"] + assert call.args[1].get("metadata") == {} + else: + assert call.args[1].get("tags") == [] + assert call.args[1].get("metadata") == {"key": "value"} + spy.reset_mock() assert fake.batch(["hello", "wooorld"], dict(tags=["a-tag"])) == [5, 7] - assert spy.call_args_list == [ - mocker.call( - "hello", dict(tags=["a-tag"], metadata={}, callbacks=None, _locals={}) - ), - mocker.call( - "wooorld", dict(tags=["a-tag"], metadata={}, callbacks=None, _locals={}) - ), - ] + assert len(spy.call_args_list) == 2 + for i, call in enumerate(spy.call_args_list): + assert call.args[0] == ("hello" if i == 0 else "wooorld") + assert call.args[1].get("tags") == ["a-tag"] + assert call.args[1].get("metadata") == {} spy.reset_mock() assert await fake.ainvoke("hello", config={"callbacks": []}) == 5 @@ -172,11 +171,21 @@ async def test_default_method_implementations(mocker: MockerFixture) -> None: assert spy.call_args_list == [ mocker.call( "hello", - dict(metadata={"key": "value"}, tags=[], callbacks=None, _locals={}), + dict( + metadata={"key": "value"}, + tags=[], + callbacks=None, + _locals={}, + ), ), mocker.call( "wooorld", - dict(metadata={"key": "value"}, tags=[], callbacks=None, _locals={}), + dict( + metadata={"key": "value"}, + tags=[], + callbacks=None, + _locals={}, + ), ), ]