mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-20 05:43:55 +00:00
core: Make abatch_as_completed respect max_concurrency (#29426)
- **Description:** Add tests for respecting max_concurrency and implement it for abatch_as_completed so that test passes - **Issue:** #29425 - **Dependencies:** none - **Twitter handle:** keenanpepper
This commit is contained in:
parent
dcfaae85d2
commit
c67d473397
@ -71,6 +71,7 @@ from langchain_core.runnables.utils import (
|
||||
accepts_config,
|
||||
accepts_run_manager,
|
||||
asyncio_accepts_context,
|
||||
gated_coro,
|
||||
gather_with_concurrency,
|
||||
get_function_first_arg_dict_keys,
|
||||
get_function_nonlocals,
|
||||
@ -952,8 +953,11 @@ class Runnable(Generic[Input, Output], ABC):
|
||||
return
|
||||
|
||||
configs = get_config_list(config, len(inputs))
|
||||
# Get max_concurrency from first config, defaulting to None (unlimited)
|
||||
max_concurrency = configs[0].get("max_concurrency") if configs else None
|
||||
semaphore = asyncio.Semaphore(max_concurrency) if max_concurrency else None
|
||||
|
||||
async def ainvoke(
|
||||
async def ainvoke_task(
|
||||
i: int, input: Input, config: RunnableConfig
|
||||
) -> tuple[int, Union[Output, Exception]]:
|
||||
if return_exceptions:
|
||||
@ -965,10 +969,14 @@ class Runnable(Generic[Input, Output], ABC):
|
||||
out = e
|
||||
else:
|
||||
out = await self.ainvoke(input, config, **kwargs)
|
||||
|
||||
return (i, out)
|
||||
|
||||
coros = map(ainvoke, range(len(inputs)), inputs, configs)
|
||||
coros = [
|
||||
gated_coro(semaphore, ainvoke_task(i, input, config))
|
||||
if semaphore
|
||||
else ainvoke_task(i, input, config)
|
||||
for i, (input, config) in enumerate(zip(inputs, configs))
|
||||
]
|
||||
|
||||
for coro in asyncio.as_completed(coros):
|
||||
yield await coro
|
||||
|
144
libs/core/tests/unit_tests/runnables/test_concurrency.py
Normal file
144
libs/core/tests/unit_tests/runnables/test_concurrency.py
Normal file
@ -0,0 +1,144 @@
|
||||
"""Test concurrency behavior of batch and async batch operations."""
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
|
||||
from langchain_core.runnables import RunnableConfig, RunnableLambda
|
||||
from langchain_core.runnables.base import Runnable
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_abatch_concurrency() -> None:
|
||||
"""Test that abatch respects max_concurrency."""
|
||||
running_tasks = 0
|
||||
max_running_tasks = 0
|
||||
lock = asyncio.Lock()
|
||||
|
||||
async def tracked_function(x: Any) -> str:
|
||||
nonlocal running_tasks, max_running_tasks
|
||||
async with lock:
|
||||
running_tasks += 1
|
||||
max_running_tasks = max(max_running_tasks, running_tasks)
|
||||
|
||||
await asyncio.sleep(0.1) # Simulate work
|
||||
|
||||
async with lock:
|
||||
running_tasks -= 1
|
||||
|
||||
return f"Completed {x}"
|
||||
|
||||
runnable: Runnable = RunnableLambda(tracked_function)
|
||||
num_tasks = 10
|
||||
max_concurrency = 3
|
||||
|
||||
config = RunnableConfig(max_concurrency=max_concurrency)
|
||||
results = await runnable.abatch(list(range(num_tasks)), config=config)
|
||||
|
||||
assert len(results) == num_tasks
|
||||
assert max_running_tasks <= max_concurrency
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_abatch_as_completed_concurrency() -> None:
|
||||
"""Test that abatch_as_completed respects max_concurrency."""
|
||||
running_tasks = 0
|
||||
max_running_tasks = 0
|
||||
lock = asyncio.Lock()
|
||||
|
||||
async def tracked_function(x: Any) -> str:
|
||||
nonlocal running_tasks, max_running_tasks
|
||||
async with lock:
|
||||
running_tasks += 1
|
||||
max_running_tasks = max(max_running_tasks, running_tasks)
|
||||
|
||||
await asyncio.sleep(0.1) # Simulate work
|
||||
|
||||
async with lock:
|
||||
running_tasks -= 1
|
||||
|
||||
return f"Completed {x}"
|
||||
|
||||
runnable: Runnable = RunnableLambda(tracked_function)
|
||||
num_tasks = 10
|
||||
max_concurrency = 3
|
||||
|
||||
config = RunnableConfig(max_concurrency=max_concurrency)
|
||||
results = []
|
||||
async for _idx, result in runnable.abatch_as_completed(
|
||||
list(range(num_tasks)), config=config
|
||||
):
|
||||
results.append(result)
|
||||
|
||||
assert len(results) == num_tasks
|
||||
assert max_running_tasks <= max_concurrency
|
||||
|
||||
|
||||
def test_batch_concurrency() -> None:
|
||||
"""Test that batch respects max_concurrency."""
|
||||
running_tasks = 0
|
||||
max_running_tasks = 0
|
||||
from threading import Lock
|
||||
|
||||
lock = Lock()
|
||||
|
||||
def tracked_function(x: Any) -> str:
|
||||
nonlocal running_tasks, max_running_tasks
|
||||
with lock:
|
||||
running_tasks += 1
|
||||
max_running_tasks = max(max_running_tasks, running_tasks)
|
||||
|
||||
time.sleep(0.1) # Simulate work
|
||||
|
||||
with lock:
|
||||
running_tasks -= 1
|
||||
|
||||
return f"Completed {x}"
|
||||
|
||||
runnable: Runnable = RunnableLambda(tracked_function)
|
||||
num_tasks = 10
|
||||
max_concurrency = 3
|
||||
|
||||
config = RunnableConfig(max_concurrency=max_concurrency)
|
||||
results = runnable.batch(list(range(num_tasks)), config=config)
|
||||
|
||||
assert len(results) == num_tasks
|
||||
assert max_running_tasks <= max_concurrency
|
||||
|
||||
|
||||
def test_batch_as_completed_concurrency() -> None:
|
||||
"""Test that batch_as_completed respects max_concurrency."""
|
||||
running_tasks = 0
|
||||
max_running_tasks = 0
|
||||
from threading import Lock
|
||||
|
||||
lock = Lock()
|
||||
|
||||
def tracked_function(x: Any) -> str:
|
||||
nonlocal running_tasks, max_running_tasks
|
||||
with lock:
|
||||
running_tasks += 1
|
||||
max_running_tasks = max(max_running_tasks, running_tasks)
|
||||
|
||||
time.sleep(0.1) # Simulate work
|
||||
|
||||
with lock:
|
||||
running_tasks -= 1
|
||||
|
||||
return f"Completed {x}"
|
||||
|
||||
runnable: Runnable = RunnableLambda(tracked_function)
|
||||
num_tasks = 10
|
||||
max_concurrency = 3
|
||||
|
||||
config = RunnableConfig(max_concurrency=max_concurrency)
|
||||
results = []
|
||||
for _idx, result in runnable.batch_as_completed(
|
||||
list(range(num_tasks)), config=config
|
||||
):
|
||||
results.append(result)
|
||||
|
||||
assert len(results) == num_tasks
|
||||
assert max_running_tasks <= max_concurrency
|
Loading…
Reference in New Issue
Block a user