core: Make abatch_as_completed respect max_concurrency (#29426)

- **Description:** Add tests for respecting max_concurrency and
implement it for abatch_as_completed so that test passes
- **Issue:** #29425
- **Dependencies:** none
- **Twitter handle:** keenanpepper
This commit is contained in:
Keenan Pepper 2025-02-07 16:51:22 -08:00 committed by GitHub
parent dcfaae85d2
commit c67d473397
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 155 additions and 3 deletions

View File

@@ -71,6 +71,7 @@ from langchain_core.runnables.utils import (
accepts_config,
accepts_run_manager,
asyncio_accepts_context,
gated_coro,
gather_with_concurrency,
get_function_first_arg_dict_keys,
get_function_nonlocals,
@@ -952,8 +953,11 @@ class Runnable(Generic[Input, Output], ABC):
return
configs = get_config_list(config, len(inputs))
# Get max_concurrency from first config, defaulting to None (unlimited)
max_concurrency = configs[0].get("max_concurrency") if configs else None
semaphore = asyncio.Semaphore(max_concurrency) if max_concurrency else None
async def ainvoke(
async def ainvoke_task(
i: int, input: Input, config: RunnableConfig
) -> tuple[int, Union[Output, Exception]]:
if return_exceptions:
@@ -965,10 +969,14 @@
out = e
else:
out = await self.ainvoke(input, config, **kwargs)
return (i, out)
coros = map(ainvoke, range(len(inputs)), inputs, configs)
coros = [
gated_coro(semaphore, ainvoke_task(i, input, config))
if semaphore
else ainvoke_task(i, input, config)
for i, (input, config) in enumerate(zip(inputs, configs))
]
for coro in asyncio.as_completed(coros):
yield await coro

View File

@@ -0,0 +1,144 @@
"""Test concurrency behavior of batch and async batch operations."""
import asyncio
import time
from threading import Lock
from typing import Any

import pytest

from langchain_core.runnables import RunnableConfig, RunnableLambda
from langchain_core.runnables.base import Runnable
@pytest.mark.asyncio
async def test_abatch_concurrency() -> None:
    """Test that abatch respects max_concurrency.

    Tracks how many invocations are in flight at once and asserts that the
    peak never exceeds the configured limit.  Also asserts the peak is
    greater than 1: without that lower bound the test would pass vacuously
    under a fully serial implementation, proving nothing about concurrency.
    """
    running_tasks = 0
    max_running_tasks = 0
    lock = asyncio.Lock()

    async def tracked_function(x: Any) -> str:
        """Record entry/exit so the peak concurrent count can be measured."""
        nonlocal running_tasks, max_running_tasks
        async with lock:
            running_tasks += 1
            max_running_tasks = max(max_running_tasks, running_tasks)
        # Hold the "running" state long enough for other tasks to overlap.
        await asyncio.sleep(0.1)
        async with lock:
            running_tasks -= 1
        return f"Completed {x}"

    runnable: Runnable = RunnableLambda(tracked_function)
    num_tasks = 10
    max_concurrency = 3
    config = RunnableConfig(max_concurrency=max_concurrency)

    results = await runnable.abatch(list(range(num_tasks)), config=config)

    assert len(results) == num_tasks
    # Upper bound: the configured limit was respected.
    assert max_running_tasks <= max_concurrency
    # Lower bound: tasks actually overlapped (guards against a serial
    # implementation trivially satisfying the upper bound).
    assert max_running_tasks > 1
@pytest.mark.asyncio
async def test_abatch_as_completed_concurrency() -> None:
    """Test that abatch_as_completed respects max_concurrency.

    Tracks how many invocations are in flight at once and asserts that the
    peak never exceeds the configured limit.  Also asserts the peak is
    greater than 1: without that lower bound the test would pass vacuously
    under a fully serial implementation, proving nothing about concurrency.
    """
    running_tasks = 0
    max_running_tasks = 0
    lock = asyncio.Lock()

    async def tracked_function(x: Any) -> str:
        """Record entry/exit so the peak concurrent count can be measured."""
        nonlocal running_tasks, max_running_tasks
        async with lock:
            running_tasks += 1
            max_running_tasks = max(max_running_tasks, running_tasks)
        # Hold the "running" state long enough for other tasks to overlap.
        await asyncio.sleep(0.1)
        async with lock:
            running_tasks -= 1
        return f"Completed {x}"

    runnable: Runnable = RunnableLambda(tracked_function)
    num_tasks = 10
    max_concurrency = 3
    config = RunnableConfig(max_concurrency=max_concurrency)

    results = []
    async for _idx, result in runnable.abatch_as_completed(
        list(range(num_tasks)), config=config
    ):
        results.append(result)

    assert len(results) == num_tasks
    # Upper bound: the configured limit was respected.
    assert max_running_tasks <= max_concurrency
    # Lower bound: tasks actually overlapped (guards against a serial
    # implementation trivially satisfying the upper bound).
    assert max_running_tasks > 1
def test_batch_concurrency() -> None:
    """Test that batch respects max_concurrency.

    Tracks the number of in-flight invocations from worker threads and
    asserts the peak stays within the configured limit, while also
    requiring some overlap so a fully serial implementation cannot pass
    vacuously.
    """
    running_tasks = 0
    max_running_tasks = 0
    lock = Lock()

    def tracked_function(x: Any) -> str:
        """Record entry/exit so the peak concurrent count can be measured."""
        nonlocal running_tasks, max_running_tasks
        with lock:
            running_tasks += 1
            max_running_tasks = max(max_running_tasks, running_tasks)
        # Hold the "running" state long enough for other threads to overlap.
        time.sleep(0.1)
        with lock:
            running_tasks -= 1
        return f"Completed {x}"

    runnable: Runnable = RunnableLambda(tracked_function)
    num_tasks = 10
    max_concurrency = 3
    config = RunnableConfig(max_concurrency=max_concurrency)

    results = runnable.batch(list(range(num_tasks)), config=config)

    assert len(results) == num_tasks
    # Upper bound: the configured limit was respected.
    assert max_running_tasks <= max_concurrency
    # Lower bound: work actually ran concurrently (with a 0.1s hold per
    # task, overlap across the thread pool is effectively guaranteed).
    assert max_running_tasks > 1
def test_batch_as_completed_concurrency() -> None:
    """Test that batch_as_completed respects max_concurrency.

    Tracks the number of in-flight invocations from worker threads and
    asserts the peak stays within the configured limit, while also
    requiring some overlap so a fully serial implementation cannot pass
    vacuously.
    """
    running_tasks = 0
    max_running_tasks = 0
    lock = Lock()

    def tracked_function(x: Any) -> str:
        """Record entry/exit so the peak concurrent count can be measured."""
        nonlocal running_tasks, max_running_tasks
        with lock:
            running_tasks += 1
            max_running_tasks = max(max_running_tasks, running_tasks)
        # Hold the "running" state long enough for other threads to overlap.
        time.sleep(0.1)
        with lock:
            running_tasks -= 1
        return f"Completed {x}"

    runnable: Runnable = RunnableLambda(tracked_function)
    num_tasks = 10
    max_concurrency = 3
    config = RunnableConfig(max_concurrency=max_concurrency)

    results = []
    for _idx, result in runnable.batch_as_completed(
        list(range(num_tasks)), config=config
    ):
        results.append(result)

    assert len(results) == num_tasks
    # Upper bound: the configured limit was respected.
    assert max_running_tasks <= max_concurrency
    # Lower bound: work actually ran concurrently (with a 0.1s hold per
    # task, overlap across the thread pool is effectively guaranteed).
    assert max_running_tasks > 1