From c04647bb4ed622019e9faa87699d84093ed11127 Mon Sep 17 00:00:00 2001
From: Akio Nishimura
Date: Fri, 3 Nov 2023 09:28:48 +0900
Subject: [PATCH] Correct number of elements in config list in `batch()` and
 `abatch()` of `BaseLLM` (#12713)

- **Description:** Correct the number of elements in the config list passed by
  `batch()` and `abatch()` of `BaseLLM` when `max_concurrency` is not None.
- **Issue:** #12643
- **Twitter handle:** @akionux

---------

Co-authored-by: Bagatur
---
 libs/langchain/langchain/llms/base.py          | 14 +++++++++----
 .../tests/unit_tests/llms/test_base.py         | 21 +++++++++++++++++++
 2 files changed, 31 insertions(+), 4 deletions(-)

diff --git a/libs/langchain/langchain/llms/base.py b/libs/langchain/langchain/llms/base.py
index 2411afbce88..04fe74b9db6 100644
--- a/libs/langchain/langchain/llms/base.py
+++ b/libs/langchain/langchain/llms/base.py
@@ -297,9 +297,12 @@ class BaseLLM(BaseLanguageModel[str], ABC):
             config = [{**c, "max_concurrency": None} for c in config]  # type: ignore[misc]
             return [
                 output
-                for batch in batches
+                for i, batch in enumerate(batches)
                 for output in self.batch(
-                    batch, config=config, return_exceptions=return_exceptions, **kwargs
+                    batch,
+                    config=config[i * max_concurrency : (i + 1) * max_concurrency],
+                    return_exceptions=return_exceptions,
+                    **kwargs,
                 )
             ]
 
@@ -340,9 +343,12 @@ class BaseLLM(BaseLanguageModel[str], ABC):
             config = [{**c, "max_concurrency": None} for c in config]  # type: ignore[misc]
             return [
                 output
-                for batch in batches
+                for i, batch in enumerate(batches)
                 for output in await self.abatch(
-                    batch, config=config, return_exceptions=return_exceptions, **kwargs
+                    batch,
+                    config=config[i * max_concurrency : (i + 1) * max_concurrency],
+                    return_exceptions=return_exceptions,
+                    **kwargs,
                 )
             ]
 
diff --git a/libs/langchain/tests/unit_tests/llms/test_base.py b/libs/langchain/tests/unit_tests/llms/test_base.py
index 9cde78c7062..56d21b40c8f 100644
--- a/libs/langchain/tests/unit_tests/llms/test_base.py
+++ b/libs/langchain/tests/unit_tests/llms/test_base.py
@@ -6,6 +6,8 @@ try:
 except ImportError:
     from sqlalchemy.ext.declarative import declarative_base
 
+import pytest
+
 from langchain.cache import InMemoryCache, SQLAlchemyCache
 from langchain.globals import get_llm_cache, set_llm_cache
 from langchain.schema import Generation, LLMResult
@@ -73,3 +75,22 @@ def test_custom_caching() -> None:
         llm_output=None,
     )
     assert output == expected_output
+
+
+def test_batch() -> None:
+    llm = FakeLLM()
+    output = llm.batch(["foo", "bar", "foo"])
+    assert output == ["foo"] * 3
+
+    output = llm.batch(["foo", "bar", "foo"], config={"max_concurrency": 2})
+    assert output == ["foo"] * 3
+
+
+@pytest.mark.asyncio
+async def test_abatch() -> None:
+    llm = FakeLLM()
+    output = await llm.abatch(["foo", "bar", "foo"])
+    assert output == ["foo"] * 3
+
+    output = await llm.abatch(["foo", "bar", "foo"], config={"max_concurrency": 2})
+    assert output == ["foo"] * 3
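
For readers tracing the change, here is a minimal standalone sketch of the slicing logic, assuming a toy `fake_generate` in place of the recursive `self.batch()` call. The names `fake_generate` and `batch_with_concurrency` are hypothetical stand-ins, not LangChain APIs, and this is not the patched implementation itself.

```python
# Standalone sketch (not part of the patch) of why the config list must be
# sliced per batch when `max_concurrency` is set.
from typing import Any, Dict, List


def fake_generate(prompts: List[str], configs: List[Dict[str, Any]]) -> List[str]:
    # Hypothetical stand-in for the recursive self.batch() call made with
    # max_concurrency=None: it expects exactly one config per prompt.
    if len(prompts) != len(configs):
        raise ValueError(
            f"config list has {len(configs)} entries for {len(prompts)} prompts"
        )
    return [p.upper() for p in prompts]


def batch_with_concurrency(
    prompts: List[str],
    configs: List[Dict[str, Any]],
    max_concurrency: int,
) -> List[str]:
    # Split the prompts into chunks of at most `max_concurrency` items...
    batches = [
        prompts[i : i + max_concurrency]
        for i in range(0, len(prompts), max_concurrency)
    ]
    # ...and hand each chunk only its own slice of the config list, mirroring
    # config[i * max_concurrency : (i + 1) * max_concurrency] in the patch.
    # Passing the full `configs` list to every chunk is the bug being fixed.
    return [
        output
        for i, batch in enumerate(batches)
        for output in fake_generate(
            batch, configs[i * max_concurrency : (i + 1) * max_concurrency]
        )
    ]


if __name__ == "__main__":
    prompts = ["foo", "bar", "baz"]
    configs = [{"max_concurrency": None} for _ in prompts]
    # With max_concurrency=2, the three prompts become chunks of 2 and 1;
    # the slices give each chunk a matching number of configs.
    print(batch_with_concurrency(prompts, configs, max_concurrency=2))
```

In this sketch, passing the unsliced `configs` to every chunk would hand three configs to a two-prompt chunk, the kind of length mismatch the patch description attributes to #12643; the per-chunk slice keeps prompts and configs aligned.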