Mirror of https://github.com/hwchase17/langchain.git
Correct number of elements in config list in batch() and abatch() of BaseLLM (#12713)
- **Description:** Correct the number of elements in the config list in `batch()` and `abatch()` of `BaseLLM` when `max_concurrency` is not None.
- **Issue:** #12643
- **Twitter handle:** @akionux

Co-authored-by: Bagatur <baskaryan@gmail.com>
Commit c04647bb4e (parent 88b506b321)
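For context on the bug: when `max_concurrency` is set, `batch()` splits the inputs into chunks of that size and recurses on each chunk, but it passed the full config list (one entry per original input) to every recursive call, so the number of config elements no longer matched the chunk length. The fix slices the config list down to the entries belonging to each chunk. A minimal standalone sketch of that arithmetic (illustrative names only, no LangChain imports):

```python
from typing import Any, Dict, List, Tuple

def split_with_configs(
    prompts: List[str], configs: List[Dict[str, Any]], max_concurrency: int
) -> List[Tuple[List[str], List[Dict[str, Any]]]]:
    """Pair each chunk of prompts with the matching slice of configs."""
    # Same chunking as BaseLLM.batch: ceil(len(prompts) / max_concurrency) chunks.
    batches = [
        prompts[i : i + max_concurrency]
        for i in range(0, len(prompts), max_concurrency)
    ]
    return [
        # Pre-fix behavior handed the whole `configs` list to every chunk;
        # post-fix, each chunk gets exactly the configs for its own prompts.
        (batch, configs[i * max_concurrency : (i + 1) * max_concurrency])
        for i, batch in enumerate(batches)
    ]

prompts = ["a", "b", "c"]
configs = [{"tags": [p]} for p in prompts]
for batch, cfgs in split_with_configs(prompts, configs, max_concurrency=2):
    assert len(batch) == len(cfgs)  # the invariant this commit restores
    print(batch, cfgs)  # ['a', 'b'] ... then ['c'] ...
```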
@@ -297,9 +297,12 @@ class BaseLLM(BaseLanguageModel[str], ABC):
             config = [{**c, "max_concurrency": None} for c in config]  # type: ignore[misc]
             return [
                 output
-                for batch in batches
+                for i, batch in enumerate(batches)
                 for output in self.batch(
-                    batch, config=config, return_exceptions=return_exceptions, **kwargs
+                    batch,
+                    config=config[i * max_concurrency : (i + 1) * max_concurrency],
+                    return_exceptions=return_exceptions,
+                    **kwargs,
                 )
             ]
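One detail of the new bounds worth noting: on the last chunk, `(i + 1) * max_concurrency` can point past the end of `config`. That is safe because Python clamps slice endpoints rather than raising, so the final slice is simply shorter, matching the shorter final batch. A quick plain-Python check (illustrative values):

```python
config = [{"n": i} for i in range(5)]   # pretend 5 inputs
max_concurrency = 2                     # -> 3 chunks of sizes 2, 2, 1
slices = [
    config[i * max_concurrency : (i + 1) * max_concurrency]
    for i in range(3)
]
# The last upper bound (6) exceeds len(config) (5) but is clamped silently.
assert [len(s) for s in slices] == [2, 2, 1]
```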
@@ -340,9 +343,12 @@ class BaseLLM(BaseLanguageModel[str], ABC):
             config = [{**c, "max_concurrency": None} for c in config]  # type: ignore[misc]
             return [
                 output
-                for batch in batches
+                for i, batch in enumerate(batches)
                 for output in await self.abatch(
-                    batch, config=config, return_exceptions=return_exceptions, **kwargs
+                    batch,
+                    config=config[i * max_concurrency : (i + 1) * max_concurrency],
+                    return_exceptions=return_exceptions,
+                    **kwargs,
                 )
             ]
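A usage sketch against the fixed code path. `FakeListLLM` serves here as a stand-in model; treat the import path as an assumption about the LangChain version at the time of this commit:

```python
from langchain.llms.fake import FakeListLLM

llm = FakeListLLM(responses=["a", "b", "c"])

# A single dict config is broadcast internally to one entry per input,
# so this exercises the max_concurrency branch patched above:
# 3 inputs -> chunks of sizes [2, 1], each with a matching config slice.
out = llm.batch(["p1", "p2", "p3"], config={"max_concurrency": 2})
print(out)  # three responses, one per prompt
```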
@@ -6,6 +6,8 @@ try:
 except ImportError:
     from sqlalchemy.ext.declarative import declarative_base
 
+import pytest
+
 from langchain.cache import InMemoryCache, SQLAlchemyCache
 from langchain.globals import get_llm_cache, set_llm_cache
 from langchain.schema import Generation, LLMResult
@@ -73,3 +75,22 @@ def test_custom_caching() -> None:
         llm_output=None,
     )
     assert output == expected_output
+
+
+def test_batch() -> None:
+    llm = FakeLLM()
+    output = llm.batch(["foo", "bar", "foo"])
+    assert output == ["foo"] * 3
+
+    output = llm.batch(["foo", "bar", "foo"], config={"max_concurrency": 2})
+    assert output == ["foo"] * 3
+
+
+@pytest.mark.asyncio
+async def test_abatch() -> None:
+    llm = FakeLLM()
+    output = await llm.abatch(["foo", "bar", "foo"])
+    assert output == ["foo"] * 3
+
+    output = await llm.abatch(["foo", "bar", "foo"], config={"max_concurrency": 2})
+    assert output == ["foo"] * 3
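Outside pytest, the async path can be driven the same way; a minimal runnable sketch mirroring `test_abatch`, under the same `FakeListLLM` assumption as above:

```python
import asyncio

from langchain.llms.fake import FakeListLLM

async def main() -> None:
    llm = FakeListLLM(responses=["a", "b", "c"])
    # abatch() applies the same per-chunk config slicing as batch().
    out = await llm.abatch(["p1", "p2", "p3"], config={"max_concurrency": 2})
    print(out)

if __name__ == "__main__":
    asyncio.run(main())
```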