From c67d473397f49892d33ccac3713e18245e7a7ac3 Mon Sep 17 00:00:00 2001
From: Keenan Pepper
Date: Fri, 7 Feb 2025 16:51:22 -0800
Subject: [PATCH] core: Make abatch_as_completed respect max_concurrency (#29426)

- **Description:** Add tests for respecting max_concurrency and implement it for abatch_as_completed so that the test passes
- **Issue:** #29425
- **Dependencies:** none
- **Twitter handle:** keenanpepper
---
 libs/core/langchain_core/runnables/base.py    |  14 +-
 .../unit_tests/runnables/test_concurrency.py  | 144 ++++++++++++++++++
 2 files changed, 155 insertions(+), 3 deletions(-)
 create mode 100644 libs/core/tests/unit_tests/runnables/test_concurrency.py

diff --git a/libs/core/langchain_core/runnables/base.py b/libs/core/langchain_core/runnables/base.py
index bf805c3a4f0..1ad1017241a 100644
--- a/libs/core/langchain_core/runnables/base.py
+++ b/libs/core/langchain_core/runnables/base.py
@@ -71,6 +71,7 @@ from langchain_core.runnables.utils import (
     accepts_config,
     accepts_run_manager,
     asyncio_accepts_context,
+    gated_coro,
     gather_with_concurrency,
     get_function_first_arg_dict_keys,
     get_function_nonlocals,
@@ -952,8 +953,11 @@ class Runnable(Generic[Input, Output], ABC):
             return
 
         configs = get_config_list(config, len(inputs))
+        # Get max_concurrency from first config, defaulting to None (unlimited)
+        max_concurrency = configs[0].get("max_concurrency") if configs else None
+        semaphore = asyncio.Semaphore(max_concurrency) if max_concurrency else None
 
-        async def ainvoke(
+        async def ainvoke_task(
             i: int, input: Input, config: RunnableConfig
         ) -> tuple[int, Union[Output, Exception]]:
             if return_exceptions:
@@ -965,10 +969,14 @@ class Runnable(Generic[Input, Output], ABC):
                 out = e
             else:
                 out = await self.ainvoke(input, config, **kwargs)
-
             return (i, out)
 
-        coros = map(ainvoke, range(len(inputs)), inputs, configs)
+        coros = [
+            gated_coro(semaphore, ainvoke_task(i, input, config))
+            if semaphore
+            else ainvoke_task(i, input, config)
+            for i, (input, config) in enumerate(zip(inputs, configs))
+        ]
 
         for coro in asyncio.as_completed(coros):
             yield await coro
diff --git a/libs/core/tests/unit_tests/runnables/test_concurrency.py b/libs/core/tests/unit_tests/runnables/test_concurrency.py
new file mode 100644
index 00000000000..24d4fad5d23
--- /dev/null
+++ b/libs/core/tests/unit_tests/runnables/test_concurrency.py
@@ -0,0 +1,144 @@
+"""Test concurrency behavior of batch and async batch operations."""
+
+import asyncio
+import time
+from typing import Any
+
+import pytest
+
+from langchain_core.runnables import RunnableConfig, RunnableLambda
+from langchain_core.runnables.base import Runnable
+
+
+@pytest.mark.asyncio
+async def test_abatch_concurrency() -> None:
+    """Test that abatch respects max_concurrency."""
+    running_tasks = 0
+    max_running_tasks = 0
+    lock = asyncio.Lock()
+
+    async def tracked_function(x: Any) -> str:
+        nonlocal running_tasks, max_running_tasks
+        async with lock:
+            running_tasks += 1
+            max_running_tasks = max(max_running_tasks, running_tasks)
+
+        await asyncio.sleep(0.1)  # Simulate work
+
+        async with lock:
+            running_tasks -= 1
+
+        return f"Completed {x}"
+
+    runnable: Runnable = RunnableLambda(tracked_function)
+    num_tasks = 10
+    max_concurrency = 3
+
+    config = RunnableConfig(max_concurrency=max_concurrency)
+    results = await runnable.abatch(list(range(num_tasks)), config=config)
+
+    assert len(results) == num_tasks
+    assert max_running_tasks <= max_concurrency
+
+
+@pytest.mark.asyncio
+async def test_abatch_as_completed_concurrency() -> None:
+    """Test that abatch_as_completed respects max_concurrency."""
+    running_tasks = 0
+    max_running_tasks = 0
+    lock = asyncio.Lock()
+
+    async def tracked_function(x: Any) -> str:
+        nonlocal running_tasks, max_running_tasks
+        async with lock:
+            running_tasks += 1
+            max_running_tasks = max(max_running_tasks, running_tasks)
+
+        await asyncio.sleep(0.1)  # Simulate work
+
+        async with lock:
+            running_tasks -= 1
+
+        return f"Completed {x}"
+
+    runnable: Runnable = RunnableLambda(tracked_function)
+    num_tasks = 10
+    max_concurrency = 3
+
+    config = RunnableConfig(max_concurrency=max_concurrency)
+    results = []
+    async for _idx, result in runnable.abatch_as_completed(
+        list(range(num_tasks)), config=config
+    ):
+        results.append(result)
+
+    assert len(results) == num_tasks
+    assert max_running_tasks <= max_concurrency
+
+
+def test_batch_concurrency() -> None:
+    """Test that batch respects max_concurrency."""
+    running_tasks = 0
+    max_running_tasks = 0
+    from threading import Lock
+
+    lock = Lock()
+
+    def tracked_function(x: Any) -> str:
+        nonlocal running_tasks, max_running_tasks
+        with lock:
+            running_tasks += 1
+            max_running_tasks = max(max_running_tasks, running_tasks)
+
+        time.sleep(0.1)  # Simulate work
+
+        with lock:
+            running_tasks -= 1
+
+        return f"Completed {x}"
+
+    runnable: Runnable = RunnableLambda(tracked_function)
+    num_tasks = 10
+    max_concurrency = 3
+
+    config = RunnableConfig(max_concurrency=max_concurrency)
+    results = runnable.batch(list(range(num_tasks)), config=config)
+
+    assert len(results) == num_tasks
+    assert max_running_tasks <= max_concurrency
+
+
+def test_batch_as_completed_concurrency() -> None:
+    """Test that batch_as_completed respects max_concurrency."""
+    running_tasks = 0
+    max_running_tasks = 0
+    from threading import Lock
+
+    lock = Lock()
+
+    def tracked_function(x: Any) -> str:
+        nonlocal running_tasks, max_running_tasks
+        with lock:
+            running_tasks += 1
+            max_running_tasks = max(max_running_tasks, running_tasks)
+
+        time.sleep(0.1)  # Simulate work
+
+        with lock:
+            running_tasks -= 1
+
+        return f"Completed {x}"
+
+    runnable: Runnable = RunnableLambda(tracked_function)
+    num_tasks = 10
+    max_concurrency = 3
+
+    config = RunnableConfig(max_concurrency=max_concurrency)
+    results = []
+    for _idx, result in runnable.batch_as_completed(
+        list(range(num_tasks)), config=config
+    ):
+        results.append(result)
+
+    assert len(results) == num_tasks
+    assert max_running_tasks <= max_concurrency
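
Note (not part of the patch): the per-task gating used above follows the standard asyncio semaphore pattern. A minimal sketch of the idea is below; the real helper is `gated_coro` from `langchain_core.runnables.utils`, and this snippet is an illustrative equivalent rather than the library source.

    import asyncio
    from collections.abc import Coroutine
    from typing import Any


    async def gated(semaphore: asyncio.Semaphore, coro: Coroutine[Any, Any, Any]) -> Any:
        # Hold one semaphore slot for the lifetime of the wrapped coroutine, so the
        # number of wrapped coroutines running at once never exceeds the value the
        # semaphore was created with (max_concurrency in the patch above).
        async with semaphore:
            return await coro

`asyncio.as_completed` schedules every wrapped coroutine as a task, but each task must first acquire the semaphore, so only `max_concurrency` `ainvoke` calls run at any moment while results are still streamed in completion order.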
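
A caller-side usage sketch (also not part of the patch; `slow_double` is an illustrative function, and the example assumes a langchain-core build that includes this change):

    import asyncio

    from langchain_core.runnables import RunnableLambda


    async def slow_double(x: int) -> int:
        # Stand-in for I/O-bound work such as a model or API call.
        await asyncio.sleep(0.1)
        return x * 2


    async def main() -> None:
        runnable = RunnableLambda(slow_double)
        # abatch_as_completed yields (index, output) pairs in completion order;
        # with this patch, at most 3 invocations run concurrently.
        async for idx, out in runnable.abatch_as_completed(
            list(range(10)), config={"max_concurrency": 3}
        ):
            print(idx, out)


    asyncio.run(main())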