fix redis cache chat model (#8041)
The Redis cache currently stores model outputs as strings. Chat generations contain Messages, which carry more information than just a string. Until the Redis cache fully supports storing messages, the cache should not interact with chat generations.
parent 973593c5c7
commit 7717c24fc4
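
To make the commit message concrete, here is a small illustrative sketch (not part of the commit; it assumes a langchain version from around this change, where ChatGeneration subclasses Generation and wraps an AIMessage): caching only the generation's text and reloading it as a plain Generation drops the message entirely.

from langchain.schema import AIMessage, ChatGeneration, Generation

# A chat generation wraps a full message; additional_kwargs here is arbitrary example data.
chat_gen = ChatGeneration(
    message=AIMessage(content="Hello!", additional_kwargs={"note": "arbitrary metadata"})
)

# ChatGeneration passes an isinstance(gen, Generation) check...
assert isinstance(chat_gen, Generation)

# ...but a string-only cache can persist only `.text`, so the round trip is lossy:
reloaded = Generation(text=chat_gen.text)
assert reloaded != chat_gen
assert not hasattr(reloaded, "message")  # the AIMessage and its metadata are gone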
@@ -5,6 +5,7 @@ import hashlib
 import inspect
 import json
 import logging
+import warnings
 from abc import ABC, abstractmethod
 from datetime import timedelta
 from typing import (
@@ -34,7 +35,7 @@ except ImportError:
 from langchain.embeddings.base import Embeddings
 from langchain.load.dump import dumps
 from langchain.load.load import loads
-from langchain.schema import Generation
+from langchain.schema import ChatGeneration, Generation
 from langchain.vectorstores.redis import Redis as RedisVectorstore
 
 logger = logging.getLogger(__file__)
@@ -232,6 +233,12 @@ class RedisCache(BaseCache):
                     "RedisCache only supports caching of normal LLM generations, "
                     f"got {type(gen)}"
                 )
+            if isinstance(gen, ChatGeneration):
+                warnings.warn(
+                    "NOTE: Generation has not been cached. RedisCache does not"
+                    " support caching ChatModel outputs."
+                )
+                return
         # Write to a Redis HASH
         key = self._key(prompt, llm_string)
         self.redis.hset(
@@ -345,6 +352,12 @@ class RedisSemanticCache(BaseCache):
                     "RedisSemanticCache only supports caching of "
                     f"normal LLM generations, got {type(gen)}"
                 )
+            if isinstance(gen, ChatGeneration):
+                warnings.warn(
+                    "NOTE: Generation has not been cached. RedisSentimentCache does not"
+                    " support caching ChatModel outputs."
+                )
+                return
         llm_cache = self._get_llm_cache(llm_string)
         # Write to vectorstore
         metadata = {
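
A minimal behavioral sketch of the two guards added above (not from the commit; it assumes a Redis server reachable at redis://localhost:6379 and a langchain build containing this fix): RedisCache.update now warns and skips writing for chat generations, while plain LLM generations are still cached.

import warnings

import redis
from langchain.cache import RedisCache
from langchain.schema import AIMessage, ChatGeneration, Generation

cache = RedisCache(redis_=redis.Redis.from_url("redis://localhost:6379"))
cache.redis.flushall()  # start from an empty cache, as the integration tests do

# A ChatGeneration is no longer written to Redis; a warning is emitted instead.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    cache.update(
        "prompt", "llm-config", [ChatGeneration(message=AIMessage(content="hi"))]
    )
assert "has not been cached" in str(caught[0].message)
assert cache.lookup("prompt", "llm-config") is None  # nothing was stored

# Plain Generations still round-trip through the cache as before.
cache.update("prompt", "llm-config", [Generation(text="hi")])
assert cache.lookup("prompt", "llm-config") == [Generation(text="hi")]

The test changes below exercise the same behavior end to end through FakeChatModel and pytest.warns().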
@@ -1,10 +1,12 @@
 """Test Redis cache functionality."""
+import pytest
 import redis
 
 import langchain
 from langchain.cache import RedisCache, RedisSemanticCache
 from langchain.schema import Generation, LLMResult
 from tests.integration_tests.vectorstores.fake_embeddings import FakeEmbeddings
+from tests.unit_tests.llms.fake_chat_model import FakeChatModel
 from tests.unit_tests.llms.fake_llm import FakeLLM
 
 REDIS_TEST_URL = "redis://localhost:6379"
@@ -28,6 +30,17 @@ def test_redis_cache() -> None:
     langchain.llm_cache.redis.flushall()
 
 
+def test_redis_cache_chat() -> None:
+    langchain.llm_cache = RedisCache(redis_=redis.Redis.from_url(REDIS_TEST_URL))
+    llm = FakeChatModel()
+    params = llm.dict()
+    params["stop"] = None
+    with pytest.warns():
+        llm.predict("foo")
+    llm.predict("foo")
+    langchain.llm_cache.redis.flushall()
+
+
 def test_redis_semantic_cache() -> None:
     langchain.llm_cache = RedisSemanticCache(
         embedding=FakeEmbeddings(), redis_url=REDIS_TEST_URL, score_threshold=0.1
@@ -53,3 +66,14 @@ def test_redis_semantic_cache() -> None:
     # expect different output now without cached result
     assert output != expected_output
     langchain.llm_cache.clear(llm_string=llm_string)
+
+
+def test_redis_semantic_cache_chat() -> None:
+    langchain.llm_cache = RedisCache(redis_=redis.Redis.from_url(REDIS_TEST_URL))
+    llm = FakeChatModel()
+    params = llm.dict()
+    params["stop"] = None
+    with pytest.warns():
+        llm.predict("foo")
+    llm.predict("foo")
+    langchain.llm_cache.redis.flushall()