BUGFIX: llm backwards compat imports (#13698)

Bagatur 2023-11-21 20:12:35 -08:00 committed by GitHub
parent ace9e64d62
commit a21e84faf7
6 changed files with 54 additions and 18 deletions


@@ -0,0 +1,19 @@
from tests.unit_tests.fake.llm import FakeListLLM


def test_batch() -> None:
    llm = FakeListLLM(responses=["foo"] * 3)
    output = llm.batch(["foo", "bar", "foo"])
    assert output == ["foo"] * 3

    output = llm.batch(["foo", "bar", "foo"], config={"max_concurrency": 2})
    assert output == ["foo"] * 3


async def test_abatch() -> None:
    llm = FakeListLLM(responses=["foo"] * 3)
    output = await llm.abatch(["foo", "bar", "foo"])
    assert output == ["foo"] * 3

    output = await llm.abatch(["foo", "bar", "foo"], config={"max_concurrency": 2})
    assert output == ["foo"] * 3
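
The config={"max_concurrency": 2} argument above comes from the Runnable batch API: batch() maps invoke() over the inputs, and max_concurrency caps how many run in parallel. A minimal standalone sketch, assuming langchain-core is installed (the RunnableLambda here is illustrative, not part of this diff):

from langchain_core.runnables import RunnableLambda

# batch() applies the runnable to each input; max_concurrency limits
# how many invocations run at once.
echo = RunnableLambda(lambda x: x.upper())
print(echo.batch(["a", "b", "c"], config={"max_concurrency": 2}))
# -> ['A', 'B', 'C']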


@@ -3,6 +3,7 @@ from langchain_core.language_models.chat_models import (
    SimpleChatModel,
    _agenerate_from_stream,
    _generate_from_stream,
    _get_verbosity,
)

__all__ = [
@@ -10,4 +11,5 @@ __all__ = [
    "SimpleChatModel",
    "_generate_from_stream",
    "_agenerate_from_stream",
    "_get_verbosity",
]


@@ -1,6 +1,9 @@
# Backwards compatibility.
from langchain_core.language_models import BaseLanguageModel
from langchain_core.language_models.llms import (
    LLM,
    BaseLLM,
    _get_verbosity,
    create_base_retry_decorator,
    get_prompts,
    update_cache,
@@ -10,6 +13,8 @@ __all__ = [
    "create_base_retry_decorator",
    "get_prompts",
    "update_cache",
    "BaseLanguageModel",
    "_get_verbosity",
    "BaseLLM",
    "LLM",
]
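
With these re-exports restored, code written against the old langchain.llms.base path imports cleanly again. A quick smoke check, assuming a langchain build that includes this fix:

# Old-style imports this commit makes resolve again.
from langchain.llms.base import LLM, BaseLLM, BaseLanguageModel, _get_verbosity

assert issubclass(BaseLLM, BaseLanguageModel)  # BaseLLM extends BaseLanguageModel
assert issubclass(LLM, BaseLLM)                # LLM is the simple-interface subclass
assert callable(_get_verbosity)                # helper re-exported for compatibility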


@@ -0,0 +1,13 @@
from langchain.chat_models.base import __all__

EXPECTED_ALL = [
    "BaseChatModel",
    "SimpleChatModel",
    "_agenerate_from_stream",
    "_generate_from_stream",
    "_get_verbosity",
]


def test_all_imports() -> None:
    assert set(__all__) == set(EXPECTED_ALL)
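
Pinning __all__ this way fails loudly whenever a re-export is dropped. A self-contained sketch of the failure mode, with hypothetical lists standing in for the module:

# Hypothetical: "_get_verbosity" was dropped from the module's __all__.
EXPECTED_ALL = ["BaseChatModel", "SimpleChatModel", "_get_verbosity"]
actual_all = ["BaseChatModel", "SimpleChatModel"]

missing = set(EXPECTED_ALL) - set(actual_all)
assert not missing, f"missing re-exports: {missing}"  # AssertionError names the gap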


@@ -10,8 +10,23 @@ from langchain_core.outputs import Generation, LLMResult
from langchain.cache import InMemoryCache, SQLAlchemyCache
from langchain.globals import get_llm_cache, set_llm_cache
from langchain.llms.base import __all__
from tests.unit_tests.llms.fake_llm import FakeLLM

EXPECTED_ALL = [
    "BaseLLM",
    "LLM",
    "_get_verbosity",
    "create_base_retry_decorator",
    "get_prompts",
    "update_cache",
    "BaseLanguageModel",
]


def test_all_imports() -> None:
    assert set(__all__) == set(EXPECTED_ALL)


def test_caching() -> None:
    """Test caching behavior."""
@@ -74,21 +89,3 @@ def test_custom_caching() -> None:
        llm_output=None,
    )
    assert output == expected_output


def test_batch() -> None:
    llm = FakeLLM()
    output = llm.batch(["foo", "bar", "foo"])
    assert output == ["foo"] * 3

    output = llm.batch(["foo", "bar", "foo"], config={"max_concurrency": 2})
    assert output == ["foo"] * 3


async def test_abatch() -> None:
    llm = FakeLLM()
    output = await llm.abatch(["foo", "bar", "foo"])
    assert output == ["foo"] * 3

    output = await llm.abatch(["foo", "bar", "foo"], config={"max_concurrency": 2})
    assert output == ["foo"] * 3