langchain/tests/integration_tests/cache/test_redis_cache.py
Bagatur 7717c24fc4 fix redis cache chat model (#8041)
Redis cache currently stores model outputs as strings. Chat generations
carry Messages, which contain more information than a plain string. Until
the Redis cache can fully store messages, the cache should not interact
with chat generations.
2023-07-20 19:00:05 -07:00
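
The change the commit message describes amounts to a guard in the cache's update path: if a generation is a ChatGeneration, emit a warning and skip the write instead of flattening the message to a string. A minimal sketch of such a guard follows; the exact wording and structure in langchain/cache.py may differ, so treat this as an illustration rather than the actual implementation.

    import warnings

    from langchain.schema import ChatGeneration

    def update(self, prompt, llm_string, return_val):
        # Skip caching chat generations that a string-based cache cannot represent.
        for gen in return_val:
            if isinstance(gen, ChatGeneration):
                warnings.warn(
                    "NOTE: Generation has not been cached. RedisCache does not "
                    "support caching ChatModel outputs."
                )
                return
        # ...otherwise fall through to the normal string-based write.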


"""Test Redis cache functionality."""
import pytest
import redis
import langchain
from langchain.cache import RedisCache, RedisSemanticCache
from langchain.schema import Generation, LLMResult
from tests.integration_tests.vectorstores.fake_embeddings import FakeEmbeddings
from tests.unit_tests.llms.fake_chat_model import FakeChatModel
from tests.unit_tests.llms.fake_llm import FakeLLM
REDIS_TEST_URL = "redis://localhost:6379"
def test_redis_cache() -> None:
    langchain.llm_cache = RedisCache(redis_=redis.Redis.from_url(REDIS_TEST_URL))
    llm = FakeLLM()
    params = llm.dict()
    params["stop"] = None
    llm_string = str(sorted([(k, v) for k, v in params.items()]))
    langchain.llm_cache.update("foo", llm_string, [Generation(text="fizz")])
    output = llm.generate(["foo"])
    print(output)
    expected_output = LLMResult(
        generations=[[Generation(text="fizz")]],
        llm_output={},
    )
    print(expected_output)
    assert output == expected_output
    langchain.llm_cache.redis.flushall()
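

# Illustrative aside: the test rebuilds llm_string from FakeLLM's sorted params
# because that is the same string the LLM hands to the cache, so the manual
# update("foo", llm_string, ...) above targets the entry that llm.generate(["foo"])
# later reads. A plausible key scheme for such a string-keyed cache is sketched
# below; it is an assumption for illustration, not necessarily RedisCache's
# exact internals.
def _example_cache_key(prompt: str, llm_string: str) -> str:
    import hashlib

    # Hash the prompt together with the serialized model params to get a stable key.
    return hashlib.md5((prompt + llm_string).encode()).hexdigest()
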

def test_redis_cache_chat() -> None:
    langchain.llm_cache = RedisCache(redis_=redis.Redis.from_url(REDIS_TEST_URL))
    llm = FakeChatModel()
    params = llm.dict()
    params["stop"] = None
    with pytest.warns():
        llm.predict("foo")
        llm.predict("foo")
    langchain.llm_cache.redis.flushall()


def test_redis_semantic_cache() -> None:
    langchain.llm_cache = RedisSemanticCache(
        embedding=FakeEmbeddings(), redis_url=REDIS_TEST_URL, score_threshold=0.1
    )
    llm = FakeLLM()
    params = llm.dict()
    params["stop"] = None
    llm_string = str(sorted([(k, v) for k, v in params.items()]))
    langchain.llm_cache.update("foo", llm_string, [Generation(text="fizz")])
    output = llm.generate(
        ["bar"]
    )  # foo and bar will have the same embedding produced by FakeEmbeddings
    expected_output = LLMResult(
        generations=[[Generation(text="fizz")]],
        llm_output={},
    )
    assert output == expected_output
    # clear the cache
    langchain.llm_cache.clear(llm_string=llm_string)
    output = llm.generate(
        ["bar"]
    )  # foo and bar will have the same embedding produced by FakeEmbeddings
    # expect different output now without cached result
    assert output != expected_output
    langchain.llm_cache.clear(llm_string=llm_string)
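

# The semantic test above passes because the FakeEmbeddings integration-test
# helper embeds the single stored prompt and any query to the same constant
# vector, so "foo" and "bar" collide on one semantic-cache entry. A minimal
# stand-in with that property is shown below; it is a sketch of the assumed
# behavior, not the helper's actual code.
class _ConstantQueryEmbeddings:
    def embed_documents(self, texts: list) -> list:
        # One distinct vector per stored text; the first document matches the query vector.
        return [[1.0] * 9 + [float(i)] for i in range(len(texts))]

    def embed_query(self, text: str) -> list:
        # Same vector regardless of the query text, so every lookup hits the first entry.
        return [1.0] * 9 + [0.0]
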

def test_redis_semantic_cache_chat() -> None:
    langchain.llm_cache = RedisSemanticCache(
        embedding=FakeEmbeddings(), redis_url=REDIS_TEST_URL, score_threshold=0.1
    )
    llm = FakeChatModel()
    params = llm.dict()
    params["stop"] = None
    llm_string = str(sorted([(k, v) for k, v in params.items()]))
    with pytest.warns():
        llm.predict("foo")
        llm.predict("foo")
    langchain.llm_cache.clear(llm_string=llm_string)