mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-20 22:03:52 +00:00
fix caching (#658)
This commit is contained in:
parent
d0fdc6da11
commit
54d7f1c933
@ -60,7 +60,7 @@ class SQLAlchemyCache(BaseCache):
|
|||||||
"""Initialize by creating all tables."""
|
"""Initialize by creating all tables."""
|
||||||
self.engine = engine
|
self.engine = engine
|
||||||
self.cache_schema = cache_schema
|
self.cache_schema = cache_schema
|
||||||
Base.metadata.create_all(self.engine)
|
self.cache_schema.metadata.create_all(self.engine)
|
||||||
|
|
||||||
def lookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]:
|
def lookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]:
|
||||||
"""Look up based on prompt and llm_string."""
|
"""Look up based on prompt and llm_string."""
|
||||||
|
@ -1,6 +1,9 @@
|
|||||||
"""Test base LLM functionality."""
|
"""Test base LLM functionality."""
|
||||||
|
from sqlalchemy import Column, Integer, Sequence, String, create_engine
|
||||||
|
from sqlalchemy.ext.declarative import declarative_base
|
||||||
|
|
||||||
import langchain
|
import langchain
|
||||||
from langchain.cache import InMemoryCache
|
from langchain.cache import InMemoryCache, SQLAlchemyCache
|
||||||
from langchain.schema import Generation, LLMResult
|
from langchain.schema import Generation, LLMResult
|
||||||
from tests.unit_tests.llms.fake_llm import FakeLLM
|
from tests.unit_tests.llms.fake_llm import FakeLLM
|
||||||
|
|
||||||
@ -28,3 +31,41 @@ def test_caching() -> None:
|
|||||||
llm_output=None,
|
llm_output=None,
|
||||||
)
|
)
|
||||||
assert output == expected_output
|
assert output == expected_output
|
||||||
|
|
||||||
|
|
||||||
|
def test_custom_caching() -> None:
|
||||||
|
"""Test custom_caching behavior."""
|
||||||
|
Base = declarative_base()
|
||||||
|
|
||||||
|
class FulltextLLMCache(Base): # type: ignore
|
||||||
|
"""Postgres table for fulltext-indexed LLM Cache."""
|
||||||
|
|
||||||
|
__tablename__ = "llm_cache_fulltext"
|
||||||
|
id = Column(Integer, Sequence("cache_id"), primary_key=True)
|
||||||
|
prompt = Column(String, nullable=False)
|
||||||
|
llm = Column(String, nullable=False)
|
||||||
|
idx = Column(Integer)
|
||||||
|
response = Column(String)
|
||||||
|
|
||||||
|
engine = create_engine("sqlite://")
|
||||||
|
langchain.llm_cache = SQLAlchemyCache(engine, FulltextLLMCache)
|
||||||
|
llm = FakeLLM()
|
||||||
|
params = llm._llm_dict()
|
||||||
|
params["stop"] = None
|
||||||
|
llm_string = str(sorted([(k, v) for k, v in params.items()]))
|
||||||
|
langchain.llm_cache.update("foo", llm_string, [Generation(text="fizz")])
|
||||||
|
output = llm.generate(["foo", "bar", "foo"])
|
||||||
|
expected_cache_output = [Generation(text="foo")]
|
||||||
|
cache_output = langchain.llm_cache.lookup("bar", llm_string)
|
||||||
|
assert cache_output == expected_cache_output
|
||||||
|
langchain.llm_cache = None
|
||||||
|
expected_generations = [
|
||||||
|
[Generation(text="fizz")],
|
||||||
|
[Generation(text="foo")],
|
||||||
|
[Generation(text="fizz")],
|
||||||
|
]
|
||||||
|
expected_output = LLMResult(
|
||||||
|
expected_generations,
|
||||||
|
llm_output=None,
|
||||||
|
)
|
||||||
|
assert output == expected_output
|
||||||
|
Loading…
Reference in New Issue
Block a user