mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-22 14:49:29 +00:00
core: Make abatch_as_completed respect max_concurrency (#29426)
- **Description:** Add tests verifying that `max_concurrency` is respected, and implement it for `abatch_as_completed` so that the new tests pass.
- **Issue:** #29425
- **Dependencies:** none
- **Twitter handle:** keenanpepper
This commit is contained in:
parent
dcfaae85d2
commit
c67d473397
@ -71,6 +71,7 @@ from langchain_core.runnables.utils import (
|
|||||||
accepts_config,
|
accepts_config,
|
||||||
accepts_run_manager,
|
accepts_run_manager,
|
||||||
asyncio_accepts_context,
|
asyncio_accepts_context,
|
||||||
|
gated_coro,
|
||||||
gather_with_concurrency,
|
gather_with_concurrency,
|
||||||
get_function_first_arg_dict_keys,
|
get_function_first_arg_dict_keys,
|
||||||
get_function_nonlocals,
|
get_function_nonlocals,
|
||||||
@ -952,8 +953,11 @@ class Runnable(Generic[Input, Output], ABC):
|
|||||||
return
|
return
|
||||||
|
|
||||||
configs = get_config_list(config, len(inputs))
|
configs = get_config_list(config, len(inputs))
|
||||||
|
# Get max_concurrency from first config, defaulting to None (unlimited)
|
||||||
|
max_concurrency = configs[0].get("max_concurrency") if configs else None
|
||||||
|
semaphore = asyncio.Semaphore(max_concurrency) if max_concurrency else None
|
||||||
|
|
||||||
async def ainvoke(
|
async def ainvoke_task(
|
||||||
i: int, input: Input, config: RunnableConfig
|
i: int, input: Input, config: RunnableConfig
|
||||||
) -> tuple[int, Union[Output, Exception]]:
|
) -> tuple[int, Union[Output, Exception]]:
|
||||||
if return_exceptions:
|
if return_exceptions:
|
||||||
@ -965,10 +969,14 @@ class Runnable(Generic[Input, Output], ABC):
|
|||||||
out = e
|
out = e
|
||||||
else:
|
else:
|
||||||
out = await self.ainvoke(input, config, **kwargs)
|
out = await self.ainvoke(input, config, **kwargs)
|
||||||
|
|
||||||
return (i, out)
|
return (i, out)
|
||||||
|
|
||||||
coros = map(ainvoke, range(len(inputs)), inputs, configs)
|
coros = [
|
||||||
|
gated_coro(semaphore, ainvoke_task(i, input, config))
|
||||||
|
if semaphore
|
||||||
|
else ainvoke_task(i, input, config)
|
||||||
|
for i, (input, config) in enumerate(zip(inputs, configs))
|
||||||
|
]
|
||||||
|
|
||||||
for coro in asyncio.as_completed(coros):
|
for coro in asyncio.as_completed(coros):
|
||||||
yield await coro
|
yield await coro
|
||||||
|
144
libs/core/tests/unit_tests/runnables/test_concurrency.py
Normal file
144
libs/core/tests/unit_tests/runnables/test_concurrency.py
Normal file
@ -0,0 +1,144 @@
|
|||||||
|
"""Test concurrency behavior of batch and async batch operations."""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import time
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from langchain_core.runnables import RunnableConfig, RunnableLambda
|
||||||
|
from langchain_core.runnables.base import Runnable
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_abatch_concurrency() -> None:
    """Test that ``abatch`` never runs more than ``max_concurrency`` tasks at once.

    A counter tracks how many invocations are in flight simultaneously; the
    high-water mark must not exceed the configured limit.
    """
    running_tasks = 0
    max_running_tasks = 0
    # Lock guards the shared counters; asyncio is single-threaded but tasks
    # interleave at await points, so we keep the read-modify-write atomic.
    lock = asyncio.Lock()

    async def tracked_function(x: Any) -> str:
        nonlocal running_tasks, max_running_tasks
        async with lock:
            running_tasks += 1
            max_running_tasks = max(max_running_tasks, running_tasks)

        await asyncio.sleep(0.1)  # Simulate work

        async with lock:
            running_tasks -= 1

        return f"Completed {x}"

    runnable: Runnable = RunnableLambda(tracked_function)
    num_tasks = 10
    max_concurrency = 3

    config = RunnableConfig(max_concurrency=max_concurrency)
    results = await runnable.abatch(list(range(num_tasks)), config=config)

    assert len(results) == num_tasks
    # abatch preserves input order, so check the exact outputs too.
    assert results == [f"Completed {i}" for i in range(num_tasks)]
    assert max_running_tasks <= max_concurrency
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_abatch_as_completed_concurrency() -> None:
    """Test that ``abatch_as_completed`` respects ``max_concurrency``.

    Same high-water-mark technique as the ``abatch`` test: count in-flight
    invocations and assert the peak never exceeds the configured limit.
    """
    running_tasks = 0
    max_running_tasks = 0
    # Lock keeps the counter read-modify-write atomic across await points.
    lock = asyncio.Lock()

    async def tracked_function(x: Any) -> str:
        nonlocal running_tasks, max_running_tasks
        async with lock:
            running_tasks += 1
            max_running_tasks = max(max_running_tasks, running_tasks)

        await asyncio.sleep(0.1)  # Simulate work

        async with lock:
            running_tasks -= 1

        return f"Completed {x}"

    runnable: Runnable = RunnableLambda(tracked_function)
    num_tasks = 10
    max_concurrency = 3

    config = RunnableConfig(max_concurrency=max_concurrency)
    results = []
    async for _idx, result in runnable.abatch_as_completed(
        list(range(num_tasks)), config=config
    ):
        results.append(result)

    assert len(results) == num_tasks
    # Completion order is unspecified, so compare as a sorted multiset.
    assert sorted(results) == sorted(f"Completed {i}" for i in range(num_tasks))
    assert max_running_tasks <= max_concurrency
|
|
||||||
|
|
||||||
|
def test_batch_concurrency() -> None:
    """Test that ``batch`` respects ``max_concurrency``.

    ``batch`` runs invocations on a thread pool, so a ``threading.Lock``
    guards the shared counters; the in-flight high-water mark must not
    exceed the configured limit.
    """
    running_tasks = 0
    max_running_tasks = 0
    from threading import Lock

    lock = Lock()

    def tracked_function(x: Any) -> str:
        nonlocal running_tasks, max_running_tasks
        with lock:
            running_tasks += 1
            max_running_tasks = max(max_running_tasks, running_tasks)

        time.sleep(0.1)  # Simulate work

        with lock:
            running_tasks -= 1

        return f"Completed {x}"

    runnable: Runnable = RunnableLambda(tracked_function)
    num_tasks = 10
    max_concurrency = 3

    config = RunnableConfig(max_concurrency=max_concurrency)
    results = runnable.batch(list(range(num_tasks)), config=config)

    assert len(results) == num_tasks
    # batch preserves input order, so check the exact outputs too.
    assert results == [f"Completed {i}" for i in range(num_tasks)]
    assert max_running_tasks <= max_concurrency
|
|
||||||
|
|
||||||
|
def test_batch_as_completed_concurrency() -> None:
    """Test that ``batch_as_completed`` respects ``max_concurrency``.

    Thread-pool based, so a ``threading.Lock`` guards the shared counters;
    the in-flight high-water mark must not exceed the configured limit.
    """
    running_tasks = 0
    max_running_tasks = 0
    from threading import Lock

    lock = Lock()

    def tracked_function(x: Any) -> str:
        nonlocal running_tasks, max_running_tasks
        with lock:
            running_tasks += 1
            max_running_tasks = max(max_running_tasks, running_tasks)

        time.sleep(0.1)  # Simulate work

        with lock:
            running_tasks -= 1

        return f"Completed {x}"

    runnable: Runnable = RunnableLambda(tracked_function)
    num_tasks = 10
    max_concurrency = 3

    config = RunnableConfig(max_concurrency=max_concurrency)
    results = []
    for _idx, result in runnable.batch_as_completed(
        list(range(num_tasks)), config=config
    ):
        results.append(result)

    assert len(results) == num_tasks
    # Completion order is unspecified, so compare as a sorted multiset.
    assert sorted(results) == sorted(f"Completed {i}" for i in range(num_tasks))
    assert max_running_tasks <= max_concurrency
Loading…
Reference in New Issue
Block a user