Correct number of elements in config list in batch() and abatch() of BaseLLM (#12713)

- **Description:** Pass the correct number of config entries to the
recursive `batch()` and `abatch()` calls in `BaseLLM` when
`max_concurrency` is not `None`: each chunk of inputs now receives the
matching slice of the expanded config list instead of the full list
(see the sketch after this list).
- **Issue:** #12643
- **Twitter handle:** @akionux
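
For context, a minimal sketch of the chunked code path. This is hedged: the standalone function, the `run_one` helper, and the plain-dict configs are illustrative stand-ins, and only the slicing expression mirrors the actual change. When `max_concurrency` is set, `batch()` splits the inputs into chunks of that size and recurses, and `config` has already been expanded to one entry per input, so each recursive call must receive only its chunk's slice:

```python
from typing import Any, Dict, List


def run_one(inp: str, cfg: Dict[str, Any]) -> str:
    """Stand-in for a single LLM call (illustrative only)."""
    return inp


def batch_chunked(
    inputs: List[str], configs: List[Dict[str, Any]], max_concurrency: int
) -> List[str]:
    # Split inputs into chunks of at most `max_concurrency` items.
    batches = [
        inputs[i : i + max_concurrency]
        for i in range(0, len(inputs), max_concurrency)
    ]
    outputs: List[str] = []
    for i, batch in enumerate(batches):
        # The fix: forward only this chunk's slice of configs. Before the
        # fix, the full list (len == len(inputs)) was forwarded to every
        # recursive call, tripping the downstream length check whenever a
        # chunk was shorter than the whole input list.
        chunk_configs = configs[i * max_concurrency : (i + 1) * max_concurrency]
        assert len(chunk_configs) == len(batch)
        outputs.extend(run_one(x, c) for x, c in zip(batch, chunk_configs))
    return outputs


print(batch_chunked(["a", "b", "c"], [{}] * 3, max_concurrency=2))
# ['a', 'b', 'c']
```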

---------

Co-authored-by: Bagatur <baskaryan@gmail.com>
Akio Nishimura, 2023-11-03 09:28:48 +09:00, committed by GitHub
commit c04647bb4e (parent 88b506b321)
2 changed files with 31 additions and 4 deletions

@@ -297,9 +297,12 @@ class BaseLLM(BaseLanguageModel[str], ABC):
         config = [{**c, "max_concurrency": None} for c in config]  # type: ignore[misc]
         return [
             output
-            for batch in batches
+            for i, batch in enumerate(batches)
             for output in self.batch(
-                batch, config=config, return_exceptions=return_exceptions, **kwargs
+                batch,
+                config=config[i * max_concurrency : (i + 1) * max_concurrency],
+                return_exceptions=return_exceptions,
+                **kwargs,
             )
         ]
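
To make the slice bounds concrete, here is the arithmetic with illustrative data (not from the commit); the variable names mirror the diff above. Note that a slice past the end of the list clips safely, which handles the short final batch:

```python
inputs = ["a", "b", "c", "d", "e"]
configs = [{"tags": [f"cfg-{n}"]} for n in range(len(inputs))]
max_concurrency = 2
batches = [
    inputs[i : i + max_concurrency]
    for i in range(0, len(inputs), max_concurrency)
]
for i, batch in enumerate(batches):
    print(batch, configs[i * max_concurrency : (i + 1) * max_concurrency])
# ['a', 'b'] [{'tags': ['cfg-0']}, {'tags': ['cfg-1']}]
# ['c', 'd'] [{'tags': ['cfg-2']}, {'tags': ['cfg-3']}]
# ['e'] [{'tags': ['cfg-4']}]   <- slice 4:6 clips at the list end
```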
@@ -340,9 +343,12 @@ class BaseLLM(BaseLanguageModel[str], ABC):
         config = [{**c, "max_concurrency": None} for c in config]  # type: ignore[misc]
         return [
             output
-            for batch in batches
+            for i, batch in enumerate(batches)
             for output in await self.abatch(
-                batch, config=config, return_exceptions=return_exceptions, **kwargs
+                batch,
+                config=config[i * max_concurrency : (i + 1) * max_concurrency],
+                return_exceptions=return_exceptions,
+                **kwargs,
             )
         ]
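
The `abatch()` hunk is the mirror image of the sync one. A quick way to exercise the fixed path end to end, sketched with `FakeListLLM` (which shipped under `langchain.llms.fake` around this release; adjust the import for newer versions):

```python
import asyncio

from langchain.llms.fake import FakeListLLM


async def main() -> None:
    llm = FakeListLLM(responses=["ok"])  # a single response, reused for every call
    # Three inputs with max_concurrency=2 force the chunked path:
    # chunks ["x", "y"] and ["z"], each recursive abatch() call now
    # receiving a matching slice of the expanded config list. Before
    # this fix, the call raised a config-length mismatch (see #12643).
    out = await llm.abatch(["x", "y", "z"], config={"max_concurrency": 2})
    print(out)  # ['ok', 'ok', 'ok']


asyncio.run(main())
```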

@@ -6,6 +6,8 @@ try:
 except ImportError:
     from sqlalchemy.ext.declarative import declarative_base
+import pytest
+
 from langchain.cache import InMemoryCache, SQLAlchemyCache
 from langchain.globals import get_llm_cache, set_llm_cache
 from langchain.schema import Generation, LLMResult
@@ -73,3 +75,22 @@ def test_custom_caching() -> None:
         llm_output=None,
     )
     assert output == expected_output
+
+
+def test_batch() -> None:
+    llm = FakeLLM()
+    output = llm.batch(["foo", "bar", "foo"])
+    assert output == ["foo"] * 3
+
+    output = llm.batch(["foo", "bar", "foo"], config={"max_concurrency": 2})
+    assert output == ["foo"] * 3
+
+
+@pytest.mark.asyncio
+async def test_abatch() -> None:
+    llm = FakeLLM()
+    output = await llm.abatch(["foo", "bar", "foo"])
+    assert output == ["foo"] * 3
+
+    output = await llm.abatch(["foo", "bar", "foo"], config={"max_concurrency": 2})
+    assert output == ["foo"] * 3
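
A natural follow-up case, not part of this commit, would pass an explicit per-input config list together with `max_concurrency`, the exact shape that hits the new slicing. A hypothetical sketch in the same style as the tests above:

```python
def test_batch_config_list() -> None:
    llm = FakeLLM()
    # One config per input; the slice must hand two configs to the
    # first chunk and one to the second.
    configs = [{"max_concurrency": 2, "tags": [f"input-{i}"]} for i in range(3)]
    output = llm.batch(["foo", "bar", "foo"], config=configs)
    assert output == ["foo"] * 3
```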