mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-20 03:21:33 +00:00
Correct number of elements in config list in batch()
and abatch()
of BaseLLM
(#12713)
- **Description:** Correct number of elements in config list in `batch()` and `abatch()` of `BaseLLM` in case `max_concurrency` is not None. - **Issue:** #12643 - **Twitter handle:** @akionux --------- Co-authored-by: Bagatur <baskaryan@gmail.com>
This commit is contained in:
parent
88b506b321
commit
c04647bb4e
@ -297,9 +297,12 @@ class BaseLLM(BaseLanguageModel[str], ABC):
|
|||||||
config = [{**c, "max_concurrency": None} for c in config] # type: ignore[misc]
|
config = [{**c, "max_concurrency": None} for c in config] # type: ignore[misc]
|
||||||
return [
|
return [
|
||||||
output
|
output
|
||||||
for batch in batches
|
for i, batch in enumerate(batches)
|
||||||
for output in self.batch(
|
for output in self.batch(
|
||||||
batch, config=config, return_exceptions=return_exceptions, **kwargs
|
batch,
|
||||||
|
config=config[i * max_concurrency : (i + 1) * max_concurrency],
|
||||||
|
return_exceptions=return_exceptions,
|
||||||
|
**kwargs,
|
||||||
)
|
)
|
||||||
]
|
]
|
||||||
|
|
||||||
@ -340,9 +343,12 @@ class BaseLLM(BaseLanguageModel[str], ABC):
|
|||||||
config = [{**c, "max_concurrency": None} for c in config] # type: ignore[misc]
|
config = [{**c, "max_concurrency": None} for c in config] # type: ignore[misc]
|
||||||
return [
|
return [
|
||||||
output
|
output
|
||||||
for batch in batches
|
for i, batch in enumerate(batches)
|
||||||
for output in await self.abatch(
|
for output in await self.abatch(
|
||||||
batch, config=config, return_exceptions=return_exceptions, **kwargs
|
batch,
|
||||||
|
config=config[i * max_concurrency : (i + 1) * max_concurrency],
|
||||||
|
return_exceptions=return_exceptions,
|
||||||
|
**kwargs,
|
||||||
)
|
)
|
||||||
]
|
]
|
||||||
|
|
||||||
|
@ -6,6 +6,8 @@ try:
|
|||||||
except ImportError:
|
except ImportError:
|
||||||
from sqlalchemy.ext.declarative import declarative_base
|
from sqlalchemy.ext.declarative import declarative_base
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
from langchain.cache import InMemoryCache, SQLAlchemyCache
|
from langchain.cache import InMemoryCache, SQLAlchemyCache
|
||||||
from langchain.globals import get_llm_cache, set_llm_cache
|
from langchain.globals import get_llm_cache, set_llm_cache
|
||||||
from langchain.schema import Generation, LLMResult
|
from langchain.schema import Generation, LLMResult
|
||||||
@ -73,3 +75,22 @@ def test_custom_caching() -> None:
|
|||||||
llm_output=None,
|
llm_output=None,
|
||||||
)
|
)
|
||||||
assert output == expected_output
|
assert output == expected_output
|
||||||
|
|
||||||
|
|
||||||
|
def test_batch() -> None:
|
||||||
|
llm = FakeLLM()
|
||||||
|
output = llm.batch(["foo", "bar", "foo"])
|
||||||
|
assert output == ["foo"] * 3
|
||||||
|
|
||||||
|
output = llm.batch(["foo", "bar", "foo"], config={"max_concurrency": 2})
|
||||||
|
assert output == ["foo"] * 3
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_abatch() -> None:
|
||||||
|
llm = FakeLLM()
|
||||||
|
output = await llm.abatch(["foo", "bar", "foo"])
|
||||||
|
assert output == ["foo"] * 3
|
||||||
|
|
||||||
|
output = await llm.abatch(["foo", "bar", "foo"], config={"max_concurrency": 2})
|
||||||
|
assert output == ["foo"] * 3
|
||||||
|
Loading…
Reference in New Issue
Block a user