core: Make abatch_as_completed respect max_concurrency (#29426)

- **Description:** Add tests for respecting max_concurrency and
implement it for abatch_as_completed so that test passes
- **Issue:** #29425
- **Dependencies:** none
- **Twitter handle:** keenanpepper
This commit is contained in:
Keenan Pepper
2025-02-07 16:51:22 -08:00
committed by GitHub
parent dcfaae85d2
commit c67d473397
2 changed files with 155 additions and 3 deletions

View File

@@ -0,0 +1,144 @@
"""Test concurrency behavior of batch and async batch operations."""
import asyncio
import time
from typing import Any
import pytest
from langchain_core.runnables import RunnableConfig, RunnableLambda
from langchain_core.runnables.base import Runnable
@pytest.mark.asyncio
async def test_abatch_concurrency() -> None:
    """Verify that ``abatch`` never runs more coroutines than max_concurrency."""
    # High-water-mark counter: `active` tracks in-flight calls, `peak` the
    # largest value `active` ever reached. Both are guarded by an asyncio lock.
    active = 0
    peak = 0
    guard = asyncio.Lock()

    async def tracked_function(x: Any) -> str:
        nonlocal active, peak
        async with guard:
            active += 1
            peak = max(peak, active)
        await asyncio.sleep(0.1)  # Simulate work
        async with guard:
            active -= 1
        return f"Completed {x}"

    runnable: Runnable = RunnableLambda(tracked_function)
    total = 10
    limit = 3
    cfg = RunnableConfig(max_concurrency=limit)

    outputs = await runnable.abatch(list(range(total)), config=cfg)

    # Every input must produce a result, and the observed parallelism must
    # never have exceeded the configured ceiling.
    assert len(outputs) == total
    assert peak <= limit
@pytest.mark.asyncio
async def test_abatch_as_completed_concurrency() -> None:
    """Verify that ``abatch_as_completed`` honors the max_concurrency setting."""
    # Track concurrent executions with a lock-protected high-water mark.
    in_flight = 0
    high_water = 0
    mutex = asyncio.Lock()

    async def tracked_function(x: Any) -> str:
        nonlocal in_flight, high_water
        async with mutex:
            in_flight += 1
            high_water = max(high_water, in_flight)
        await asyncio.sleep(0.1)  # Simulate work
        async with mutex:
            in_flight -= 1
        return f"Completed {x}"

    runnable: Runnable = RunnableLambda(tracked_function)
    total = 10
    limit = 3
    cfg = RunnableConfig(max_concurrency=limit)

    # Drain the async generator; results arrive in completion order, which
    # is irrelevant here — only the count and the peak parallelism matter.
    collected = [
        result
        async for _idx, result in runnable.abatch_as_completed(
            list(range(total)), config=cfg
        )
    ]

    assert len(collected) == total
    assert high_water <= limit
def test_batch_concurrency() -> None:
    """Verify that the synchronous ``batch`` honors max_concurrency."""
    from threading import Lock

    # Thread-safe high-water-mark counter for concurrent invocations.
    active = 0
    peak = 0
    guard = Lock()

    def tracked_function(x: Any) -> str:
        nonlocal active, peak
        with guard:
            active += 1
            peak = max(peak, active)
        time.sleep(0.1)  # Simulate work
        with guard:
            active -= 1
        return f"Completed {x}"

    runnable: Runnable = RunnableLambda(tracked_function)
    total, limit = 10, 3
    cfg = RunnableConfig(max_concurrency=limit)

    outputs = runnable.batch(list(range(total)), config=cfg)

    # All inputs completed, and no more than `limit` threads overlapped.
    assert len(outputs) == total
    assert peak <= limit
def test_batch_as_completed_concurrency() -> None:
    """Verify that ``batch_as_completed`` honors the max_concurrency setting."""
    from threading import Lock

    # Lock-protected counters: current in-flight calls and their peak.
    in_flight = 0
    high_water = 0
    mutex = Lock()

    def tracked_function(x: Any) -> str:
        nonlocal in_flight, high_water
        with mutex:
            in_flight += 1
            high_water = max(high_water, in_flight)
        time.sleep(0.1)  # Simulate work
        with mutex:
            in_flight -= 1
        return f"Completed {x}"

    runnable: Runnable = RunnableLambda(tracked_function)
    total, limit = 10, 3
    cfg = RunnableConfig(max_concurrency=limit)

    # Consume the (index, result) pairs; completion order is irrelevant.
    gathered = [
        res
        for _idx, res in runnable.batch_as_completed(
            list(range(total)), config=cfg
        )
    ]

    assert len(gathered) == total
    assert high_water <= limit