langchain[patch],community[patch]: Move unit tests that depend on community to community (#21685)

This commit is contained in:
Eugene Yurtsev
2024-05-16 17:24:27 -04:00
committed by GitHub
parent 97a4ae50d2
commit 8607735b80
22 changed files with 1248 additions and 262 deletions

View File

@@ -4,20 +4,26 @@ from typing import Dict, Generator, List, Union
import pytest
from _pytest.fixtures import FixtureRequest
from langchain.globals import get_llm_cache, set_llm_cache
from langchain_core.caches import InMemoryCache
from langchain_core.language_models import FakeListChatModel, FakeListLLM
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.language_models.llms import BaseLLM
from langchain_core.load import dumps
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage
from langchain_core.outputs import ChatGeneration, Generation
from sqlalchemy import create_engine
from langchain_core.outputs import ChatGeneration
from sqlalchemy import Column, Integer, Sequence, String, create_engine
from sqlalchemy.orm import Session
pytest.importorskip("langchain_community")
try:
from sqlalchemy.orm import declarative_base
except ImportError:
from sqlalchemy.ext.declarative import declarative_base
from langchain.globals import get_llm_cache, set_llm_cache
from langchain_core.outputs import Generation, LLMResult
from langchain_community.cache import SQLAlchemyCache # noqa: E402
from tests.unit_tests.llms.fake_llm import FakeLLM
def get_sqlite_cache() -> SQLAlchemyCache:
@@ -210,3 +216,44 @@ def create_llm_string(llm: Union[BaseLLM, BaseChatModel]) -> str:
_dict: Dict = llm.dict()
_dict["stop"] = None
return str(sorted([(k, v) for k, v in _dict.items()]))
def test_sql_alchemy_cache() -> None:
"""Test custom_caching behavior."""
Base = declarative_base()
class FulltextLLMCache(Base): # type: ignore
"""Postgres table for fulltext-indexed LLM Cache."""
__tablename__ = "llm_cache_fulltext"
id = Column(Integer, Sequence("cache_id"), primary_key=True)
prompt = Column(String, nullable=False)
llm = Column(String, nullable=False)
idx = Column(Integer)
response = Column(String)
engine = create_engine("sqlite://")
from langchain_community.cache import SQLAlchemyCache
set_llm_cache(SQLAlchemyCache(engine, FulltextLLMCache))
llm = FakeLLM()
params = llm.dict()
params["stop"] = None
llm_string = str(sorted([(k, v) for k, v in params.items()]))
get_llm_cache().update("foo", llm_string, [Generation(text="fizz")])
output = llm.generate(["foo", "bar", "foo"])
expected_cache_output = [Generation(text="foo")]
cache_output = get_llm_cache().lookup("bar", llm_string)
assert cache_output == expected_cache_output
set_llm_cache(None)
expected_generations = [
[Generation(text="fizz")],
[Generation(text="foo")],
[Generation(text="fizz")],
]
expected_output = LLMResult(
generations=expected_generations,
llm_output=None,
)
assert output == expected_output