fix redis cache chat model (#8041)

Redis cache currently stores model outputs as strings. Chat generations
have Messages, which carry more information than just a string. Until the
Redis cache fully supports storing messages, the cache should not interact
with chat generations.
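
For context, a plain LLM generation is only text and round-trips through a Redis string value losslessly, while a chat generation wraps a Message whose structured fields would be dropped. A minimal sketch of the difference (Generation and ChatGeneration are the classes imported in the diff below; AIMessage also lives in langchain.schema; the function_call payload is purely illustrative):

from langchain.schema import AIMessage, ChatGeneration, Generation

# A plain LLM generation is just text, so storing it as a Redis string loses nothing.
llm_gen = Generation(text="Hello!")

# A chat generation wraps a Message, which can carry structured data
# (e.g. additional_kwargs) that cannot be reconstructed from the text alone.
chat_gen = ChatGeneration(
    message=AIMessage(
        content="",
        additional_kwargs={"function_call": {"name": "search", "arguments": "{}"}},
    )
)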
Bagatur authored 2023-07-20 19:00:05 -07:00, committed by GitHub
parent 973593c5c7
commit 7717c24fc4
2 changed files with 38 additions and 1 deletion


@@ -5,6 +5,7 @@ import hashlib
 import inspect
 import json
 import logging
+import warnings
 from abc import ABC, abstractmethod
 from datetime import timedelta
 from typing import (
@ -34,7 +35,7 @@ except ImportError:
from langchain.embeddings.base import Embeddings from langchain.embeddings.base import Embeddings
from langchain.load.dump import dumps from langchain.load.dump import dumps
from langchain.load.load import loads from langchain.load.load import loads
from langchain.schema import Generation from langchain.schema import ChatGeneration, Generation
from langchain.vectorstores.redis import Redis as RedisVectorstore from langchain.vectorstores.redis import Redis as RedisVectorstore
logger = logging.getLogger(__file__) logger = logging.getLogger(__file__)
@@ -232,6 +233,12 @@ class RedisCache(BaseCache):
                     "RedisCache only supports caching of normal LLM generations, "
                     f"got {type(gen)}"
                 )
+            if isinstance(gen, ChatGeneration):
+                warnings.warn(
+                    "NOTE: Generation has not been cached. RedisCache does not"
+                    " support caching ChatModel outputs."
+                )
+                return
         # Write to a Redis HASH
         key = self._key(prompt, llm_string)
         self.redis.hset(
@@ -345,6 +352,12 @@ class RedisSemanticCache(BaseCache):
                     "RedisSemanticCache only supports caching of "
                     f"normal LLM generations, got {type(gen)}"
                 )
+            if isinstance(gen, ChatGeneration):
+                warnings.warn(
+                    "NOTE: Generation has not been cached. RedisSemanticCache does not"
+                    " support caching ChatModel outputs."
+                )
+                return
         llm_cache = self._get_llm_cache(llm_string)
         # Write to vectorstore
         metadata = {


@@ -1,10 +1,12 @@
 """Test Redis cache functionality."""
+import pytest
 import redis

 import langchain
 from langchain.cache import RedisCache, RedisSemanticCache
 from langchain.schema import Generation, LLMResult
 from tests.integration_tests.vectorstores.fake_embeddings import FakeEmbeddings
+from tests.unit_tests.llms.fake_chat_model import FakeChatModel
 from tests.unit_tests.llms.fake_llm import FakeLLM

 REDIS_TEST_URL = "redis://localhost:6379"
@@ -28,6 +30,17 @@ def test_redis_cache() -> None:
     langchain.llm_cache.redis.flushall()


+def test_redis_cache_chat() -> None:
+    langchain.llm_cache = RedisCache(redis_=redis.Redis.from_url(REDIS_TEST_URL))
+    llm = FakeChatModel()
+    params = llm.dict()
+    params["stop"] = None
+    with pytest.warns():
+        llm.predict("foo")
+        llm.predict("foo")
+    langchain.llm_cache.redis.flushall()
+
+
 def test_redis_semantic_cache() -> None:
     langchain.llm_cache = RedisSemanticCache(
         embedding=FakeEmbeddings(), redis_url=REDIS_TEST_URL, score_threshold=0.1
@ -53,3 +66,14 @@ def test_redis_semantic_cache() -> None:
# expect different output now without cached result # expect different output now without cached result
assert output != expected_output assert output != expected_output
langchain.llm_cache.clear(llm_string=llm_string) langchain.llm_cache.clear(llm_string=llm_string)
def test_redis_semantic_cache_chat() -> None:
langchain.llm_cache = RedisCache(redis_=redis.Redis.from_url(REDIS_TEST_URL))
llm = FakeChatModel()
params = llm.dict()
params["stop"] = None
with pytest.warns():
llm.predict("foo")
llm.predict("foo")
langchain.llm_cache.redis.flushall()
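
With this change, configuring RedisCache alongside a chat model still works: the call simply emits the warning and skips the cache write instead of raising. A rough usage sketch, assuming a local Redis at the same URL the tests use and picking ChatOpenAI purely as an example chat model (it needs an OpenAI API key; any chat model would behave the same way):

import warnings

import redis

import langchain
from langchain.cache import RedisCache
from langchain.chat_models import ChatOpenAI

langchain.llm_cache = RedisCache(redis_=redis.Redis.from_url("redis://localhost:6379"))

chat = ChatOpenAI()
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    chat.predict("hello")  # the model runs normally, but its ChatGeneration is not cached
print([str(w.message) for w in caught])  # expect the "Generation has not been cached" note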