mirror of https://github.com/hwchase17/langchain.git, synced 2025-09-05 21:12:48 +00:00
core: Make abatch_as_completed respect max_concurrency (#29426)
- **Description:** Add tests for respecting max_concurrency, and implement it for abatch_as_completed so that the test passes.
- **Issue:** #29425
- **Dependencies:** none
- **Twitter handle:** keenanpepper
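The implementation change itself is not part of the diff shown below (only the new tests are). As a rough illustration of the general technique a fix like this relies on, the sketch below bounds concurrency for an "as completed" async batch by wrapping each coroutine in an asyncio.Semaphore. The helper name run_as_completed_bounded and its signature are hypothetical and are not LangChain's actual internals.

    import asyncio
    from typing import Any, AsyncIterator, Awaitable, Callable, Iterable

    async def run_as_completed_bounded(
        func: Callable[[Any], Awaitable[Any]],
        inputs: Iterable[Any],
        max_concurrency: int,
    ) -> AsyncIterator[tuple[int, Any]]:
        """Yield (index, result) pairs as they finish, with at most
        ``max_concurrency`` coroutines running at any one time."""
        semaphore = asyncio.Semaphore(max_concurrency)

        async def gated(idx: int, value: Any) -> tuple[int, Any]:
            # Acquire a slot before doing any work, so no more than
            # max_concurrency calls to ``func`` overlap.
            async with semaphore:
                return idx, await func(value)

        tasks = [asyncio.ensure_future(gated(i, x)) for i, x in enumerate(inputs)]
        for fut in asyncio.as_completed(tasks):
            yield await fut

Without the semaphore, asyncio.as_completed starts every task immediately, which is exactly the unbounded behavior the new tests below are designed to catch.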
libs/core/tests/unit_tests/runnables/test_concurrency.py: 144 lines added (new file)
@@ -0,0 +1,144 @@
"""Test concurrency behavior of batch and async batch operations."""

import asyncio
import time
from typing import Any

import pytest

from langchain_core.runnables import RunnableConfig, RunnableLambda
from langchain_core.runnables.base import Runnable


@pytest.mark.asyncio
async def test_abatch_concurrency() -> None:
    """Test that abatch respects max_concurrency."""
    running_tasks = 0
    max_running_tasks = 0
    lock = asyncio.Lock()

    async def tracked_function(x: Any) -> str:
        nonlocal running_tasks, max_running_tasks
        async with lock:
            running_tasks += 1
            max_running_tasks = max(max_running_tasks, running_tasks)

        await asyncio.sleep(0.1)  # Simulate work

        async with lock:
            running_tasks -= 1

        return f"Completed {x}"

    runnable: Runnable = RunnableLambda(tracked_function)
    num_tasks = 10
    max_concurrency = 3

    config = RunnableConfig(max_concurrency=max_concurrency)
    results = await runnable.abatch(list(range(num_tasks)), config=config)

    assert len(results) == num_tasks
    assert max_running_tasks <= max_concurrency


@pytest.mark.asyncio
async def test_abatch_as_completed_concurrency() -> None:
    """Test that abatch_as_completed respects max_concurrency."""
    running_tasks = 0
    max_running_tasks = 0
    lock = asyncio.Lock()

    async def tracked_function(x: Any) -> str:
        nonlocal running_tasks, max_running_tasks
        async with lock:
            running_tasks += 1
            max_running_tasks = max(max_running_tasks, running_tasks)

        await asyncio.sleep(0.1)  # Simulate work

        async with lock:
            running_tasks -= 1

        return f"Completed {x}"

    runnable: Runnable = RunnableLambda(tracked_function)
    num_tasks = 10
    max_concurrency = 3

    config = RunnableConfig(max_concurrency=max_concurrency)
    results = []
    async for _idx, result in runnable.abatch_as_completed(
        list(range(num_tasks)), config=config
    ):
        results.append(result)

    assert len(results) == num_tasks
    assert max_running_tasks <= max_concurrency


def test_batch_concurrency() -> None:
    """Test that batch respects max_concurrency."""
    running_tasks = 0
    max_running_tasks = 0
    from threading import Lock

    lock = Lock()

    def tracked_function(x: Any) -> str:
        nonlocal running_tasks, max_running_tasks
        with lock:
            running_tasks += 1
            max_running_tasks = max(max_running_tasks, running_tasks)

        time.sleep(0.1)  # Simulate work

        with lock:
            running_tasks -= 1

        return f"Completed {x}"

    runnable: Runnable = RunnableLambda(tracked_function)
    num_tasks = 10
    max_concurrency = 3

    config = RunnableConfig(max_concurrency=max_concurrency)
    results = runnable.batch(list(range(num_tasks)), config=config)

    assert len(results) == num_tasks
    assert max_running_tasks <= max_concurrency


def test_batch_as_completed_concurrency() -> None:
    """Test that batch_as_completed respects max_concurrency."""
    running_tasks = 0
    max_running_tasks = 0
    from threading import Lock

    lock = Lock()

    def tracked_function(x: Any) -> str:
        nonlocal running_tasks, max_running_tasks
        with lock:
            running_tasks += 1
            max_running_tasks = max(max_running_tasks, running_tasks)

        time.sleep(0.1)  # Simulate work

        with lock:
            running_tasks -= 1

        return f"Completed {x}"

    runnable: Runnable = RunnableLambda(tracked_function)
    num_tasks = 10
    max_concurrency = 3

    config = RunnableConfig(max_concurrency=max_concurrency)
    results = []
    for _idx, result in runnable.batch_as_completed(
        list(range(num_tasks)), config=config
    ):
        results.append(result)

    assert len(results) == num_tasks
    assert max_running_tasks <= max_concurrency
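For context, here is a minimal usage example of the behavior these tests exercise: calling abatch_as_completed with max_concurrency set in the config. The slow_double function is made up for illustration; the method name and the max_concurrency config key match what the tests above use.

    import asyncio

    from langchain_core.runnables import RunnableLambda


    async def slow_double(x: int) -> int:
        await asyncio.sleep(0.1)  # simulate work
        return x * 2


    async def main() -> None:
        runnable = RunnableLambda(slow_double)
        # At most 3 inputs are processed at once; results arrive as they finish,
        # paired with the index of the input that produced them.
        async for idx, result in runnable.abatch_as_completed(
            list(range(10)), config={"max_concurrency": 3}
        ):
            print(idx, result)


    asyncio.run(main())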