Mirror of https://github.com/hwchase17/langchain.git (synced 2025-06-20 22:03:52 +00:00)
fix caching (#658)
commit 54d7f1c933 (parent d0fdc6da11)
@@ -60,7 +60,7 @@ class SQLAlchemyCache(BaseCache):
         """Initialize by creating all tables."""
         self.engine = engine
         self.cache_schema = cache_schema
-        Base.metadata.create_all(self.engine)
+        self.cache_schema.metadata.create_all(self.engine)
 
     def lookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]:
         """Look up based on prompt and llm_string."""
@@ -1,6 +1,9 @@
 """Test base LLM functionality."""
+from sqlalchemy import Column, Integer, Sequence, String, create_engine
+from sqlalchemy.ext.declarative import declarative_base
+
 import langchain
-from langchain.cache import InMemoryCache
+from langchain.cache import InMemoryCache, SQLAlchemyCache
 from langchain.schema import Generation, LLMResult
 from tests.unit_tests.llms.fake_llm import FakeLLM
 
@@ -28,3 +31,41 @@ def test_caching() -> None:
         llm_output=None,
     )
     assert output == expected_output
+
+
+def test_custom_caching() -> None:
+    """Test custom_caching behavior."""
+    Base = declarative_base()
+
+    class FulltextLLMCache(Base):  # type: ignore
+        """Postgres table for fulltext-indexed LLM Cache."""
+
+        __tablename__ = "llm_cache_fulltext"
+        id = Column(Integer, Sequence("cache_id"), primary_key=True)
+        prompt = Column(String, nullable=False)
+        llm = Column(String, nullable=False)
+        idx = Column(Integer)
+        response = Column(String)
+
+    engine = create_engine("sqlite://")
+    langchain.llm_cache = SQLAlchemyCache(engine, FulltextLLMCache)
+    llm = FakeLLM()
+    params = llm._llm_dict()
+    params["stop"] = None
+    llm_string = str(sorted([(k, v) for k, v in params.items()]))
+    langchain.llm_cache.update("foo", llm_string, [Generation(text="fizz")])
+    output = llm.generate(["foo", "bar", "foo"])
+    expected_cache_output = [Generation(text="foo")]
+    cache_output = langchain.llm_cache.lookup("bar", llm_string)
+    assert cache_output == expected_cache_output
+    langchain.llm_cache = None
+    expected_generations = [
+        [Generation(text="fizz")],
+        [Generation(text="foo")],
+        [Generation(text="fizz")],
+    ]
+    expected_output = LLMResult(
+        expected_generations,
+        llm_output=None,
+    )
+    assert output == expected_output
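The new test pre-seeds the cache with "fizz" for the prompt "foo", so generate(["foo", "bar", "foo"]) returns the cached "fizz" for the first and third prompts, while the uncached "bar" falls through to FakeLLM (whose stub output is "foo") and is written back to the cache, which the lookup("bar", llm_string) assertion then confirms. For comparison, wiring up the SQLAlchemy cache without a custom table would look like the sketch below; this is a sketch assuming the cache_schema parameter keeps a default value, since the default table class is not shown in this diff:

    from sqlalchemy import create_engine

    import langchain
    from langchain.cache import SQLAlchemyCache

    # An in-memory SQLite engine keeps the sketch self-contained; any
    # SQLAlchemy-supported database URL works the same way.
    engine = create_engine("sqlite://")

    # With the fix, __init__ runs create_all on the schema actually in use,
    # so the cache table exists before the first lookup/update call.
    langchain.llm_cache = SQLAlchemyCache(engine)  # assumes cache_schema has a default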