diff --git a/libs/langchain/langchain/llms/base.py b/libs/langchain/langchain/llms/base.py
index 2411afbce88..04fe74b9db6 100644
--- a/libs/langchain/langchain/llms/base.py
+++ b/libs/langchain/langchain/llms/base.py
@@ -297,9 +297,12 @@ class BaseLLM(BaseLanguageModel[str], ABC):
             config = [{**c, "max_concurrency": None} for c in config]  # type: ignore[misc]
             return [
                 output
-                for batch in batches
+                for i, batch in enumerate(batches)
                 for output in self.batch(
-                    batch, config=config, return_exceptions=return_exceptions, **kwargs
+                    batch,
+                    config=config[i * max_concurrency : (i + 1) * max_concurrency],
+                    return_exceptions=return_exceptions,
+                    **kwargs,
                 )
             ]
 
@@ -340,9 +343,12 @@ class BaseLLM(BaseLanguageModel[str], ABC):
             config = [{**c, "max_concurrency": None} for c in config]  # type: ignore[misc]
             return [
                 output
-                for batch in batches
+                for i, batch in enumerate(batches)
                 for output in await self.abatch(
-                    batch, config=config, return_exceptions=return_exceptions, **kwargs
+                    batch,
+                    config=config[i * max_concurrency : (i + 1) * max_concurrency],
+                    return_exceptions=return_exceptions,
+                    **kwargs,
                 )
             ]
 
diff --git a/libs/langchain/tests/unit_tests/llms/test_base.py b/libs/langchain/tests/unit_tests/llms/test_base.py
index 9cde78c7062..56d21b40c8f 100644
--- a/libs/langchain/tests/unit_tests/llms/test_base.py
+++ b/libs/langchain/tests/unit_tests/llms/test_base.py
@@ -6,6 +6,8 @@ try:
 except ImportError:
     from sqlalchemy.ext.declarative import declarative_base
 
+import pytest
+
 from langchain.cache import InMemoryCache, SQLAlchemyCache
 from langchain.globals import get_llm_cache, set_llm_cache
 from langchain.schema import Generation, LLMResult
@@ -73,3 +75,22 @@ def test_custom_caching() -> None:
         llm_output=None,
     )
     assert output == expected_output
+
+
+def test_batch() -> None:
+    llm = FakeLLM()
+    output = llm.batch(["foo", "bar", "foo"])
+    assert output == ["foo"] * 3
+
+    output = llm.batch(["foo", "bar", "foo"], config={"max_concurrency": 2})
+    assert output == ["foo"] * 3
+
+
+@pytest.mark.asyncio
+async def test_abatch() -> None:
+    llm = FakeLLM()
+    output = await llm.abatch(["foo", "bar", "foo"])
+    assert output == ["foo"] * 3
+
+    output = await llm.abatch(["foo", "bar", "foo"], config={"max_concurrency": 2})
+    assert output == ["foo"] * 3
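
For reviewers, a standalone sketch of what the config slicing fixes (not part of the diff; `inputs`, `configs`, and `max_concurrency` below are hypothetical stand-in values, not names taken verbatim from `base.py`). Before this change, the recursive `self.batch(batch, config=config, ...)` call passed the *full* config list alongside a batch of at most `max_concurrency` inputs, so configs no longer lined up one-per-input. Slicing `config[i * max_concurrency : (i + 1) * max_concurrency]` restores that alignment:

```python
# Hypothetical stand-ins mirroring the batching logic in BaseLLM.batch.
inputs = ["foo", "bar", "foo"]
configs = [{"max_concurrency": None, "run_name": f"input-{i}"} for i in range(len(inputs))]
max_concurrency = 2

# Split inputs into chunks of at most `max_concurrency`, as base.py does.
batches = [
    inputs[i : i + max_concurrency]
    for i in range(0, len(inputs), max_concurrency)
]

for i, batch in enumerate(batches):
    # The fix: each sub-call receives exactly one config per input in its
    # batch, rather than the full config list.
    batch_configs = configs[i * max_concurrency : (i + 1) * max_concurrency]
    assert len(batch_configs) == len(batch)
```

Note that the last batch may be shorter than `max_concurrency` (here `["foo"]`), and the slice naturally shortens with it, which is why the per-batch lengths stay matched.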