fix redis cache chat model (#8041)
The Redis cache currently stores model outputs as plain strings. Chat generations contain Messages, which carry more information than a bare string. Until the Redis cache supports storing messages in full, the cache should not interact with chat generations.
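For context, a rough sketch (not part of the commit, assuming a langchain version from around this change) of what is lost when a chat generation is reduced to its text; the message content and additional_kwargs below are made-up values:

from langchain.schema import AIMessage, ChatGeneration, Generation

# A ChatGeneration wraps a full Message, not just a string.
chat_gen = ChatGeneration(
    message=AIMessage(
        content="Hello!",
        additional_kwargs={"provider_metadata": "hypothetical extra payload"},
    )
)

# Rebuilding from the cached string alone yields a plain Generation:
# the message type and its additional_kwargs are gone.
cached = Generation(text=chat_gen.text)
assert cached.text == chat_gen.text
assert not isinstance(cached, ChatGeneration)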
parent 973593c5c7
commit 7717c24fc4
langchain/cache.py:

@@ -5,6 +5,7 @@ import hashlib
 import inspect
 import json
 import logging
+import warnings
 from abc import ABC, abstractmethod
 from datetime import timedelta
 from typing import (
@@ -34,7 +35,7 @@ except ImportError:
 from langchain.embeddings.base import Embeddings
 from langchain.load.dump import dumps
 from langchain.load.load import loads
-from langchain.schema import Generation
+from langchain.schema import ChatGeneration, Generation
 from langchain.vectorstores.redis import Redis as RedisVectorstore

 logger = logging.getLogger(__file__)
@@ -232,6 +233,12 @@ class RedisCache(BaseCache):
                     "RedisCache only supports caching of normal LLM generations, "
                     f"got {type(gen)}"
                 )
+            if isinstance(gen, ChatGeneration):
+                warnings.warn(
+                    "NOTE: Generation has not been cached. RedisCache does not"
+                    " support caching ChatModel outputs."
+                )
+                return
         # Write to a Redis HASH
         key = self._key(prompt, llm_string)
         self.redis.hset(
@@ -345,6 +352,12 @@ class RedisSemanticCache(BaseCache):
                     "RedisSemanticCache only supports caching of "
                     f"normal LLM generations, got {type(gen)}"
                 )
+            if isinstance(gen, ChatGeneration):
+                warnings.warn(
+                    "NOTE: Generation has not been cached. RedisSentimentCache does not"
+                    " support caching ChatModel outputs."
+                )
+                return
         llm_cache = self._get_llm_cache(llm_string)
         # Write to vectorstore
         metadata = {
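To see the new guard in action, here is a minimal sketch (not part of the commit) that feeds a ChatGeneration straight into RedisCache.update; it assumes a Redis server reachable at redis://localhost:6379 and a langchain version from around this change, and the llm_string value is an arbitrary placeholder:

import warnings

import redis

import langchain
from langchain.cache import RedisCache
from langchain.schema import AIMessage, ChatGeneration

langchain.llm_cache = RedisCache(redis_=redis.Redis.from_url("redis://localhost:6379"))

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    # Simulate what a chat model would hand to the cache after a call.
    langchain.llm_cache.update(
        prompt="foo",
        llm_string="fake-chat-model-llm-string",
        return_val=[ChatGeneration(message=AIMessage(content="bar"))],
    )

print(caught[0].message)  # NOTE: Generation has not been cached. ...
print(langchain.llm_cache.lookup("foo", "fake-chat-model-llm-string"))  # None: nothing was stored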
Redis cache integration tests (under tests/integration_tests/):

@@ -1,10 +1,12 @@
 """Test Redis cache functionality."""
+import pytest
 import redis

 import langchain
 from langchain.cache import RedisCache, RedisSemanticCache
 from langchain.schema import Generation, LLMResult
 from tests.integration_tests.vectorstores.fake_embeddings import FakeEmbeddings
+from tests.unit_tests.llms.fake_chat_model import FakeChatModel
 from tests.unit_tests.llms.fake_llm import FakeLLM

 REDIS_TEST_URL = "redis://localhost:6379"
@@ -28,6 +30,17 @@ def test_redis_cache() -> None:
     langchain.llm_cache.redis.flushall()


+def test_redis_cache_chat() -> None:
+    langchain.llm_cache = RedisCache(redis_=redis.Redis.from_url(REDIS_TEST_URL))
+    llm = FakeChatModel()
+    params = llm.dict()
+    params["stop"] = None
+    with pytest.warns():
+        llm.predict("foo")
+    llm.predict("foo")
+    langchain.llm_cache.redis.flushall()
+
+
 def test_redis_semantic_cache() -> None:
     langchain.llm_cache = RedisSemanticCache(
         embedding=FakeEmbeddings(), redis_url=REDIS_TEST_URL, score_threshold=0.1
@@ -53,3 +66,14 @@ def test_redis_semantic_cache() -> None:
     # expect different output now without cached result
     assert output != expected_output
     langchain.llm_cache.clear(llm_string=llm_string)
+
+
+def test_redis_semantic_cache_chat() -> None:
+    langchain.llm_cache = RedisCache(redis_=redis.Redis.from_url(REDIS_TEST_URL))
+    llm = FakeChatModel()
+    params = llm.dict()
+    params["stop"] = None
+    with pytest.warns():
+        llm.predict("foo")
+    llm.predict("foo")
+    langchain.llm_cache.redis.flushall()
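The new tests rely on pytest.warns() as a context manager: the block fails unless at least one warning is raised inside it, and calling it with no arguments matches any Warning in recent pytest releases. A tiny standalone illustration of that mechanism, independent of Redis and langchain:

import warnings

import pytest


def test_warns_on_uncached_chat_generation() -> None:
    # warnings.warn defaults to UserWarning, which is what the cache emits.
    with pytest.warns(UserWarning):
        warnings.warn("NOTE: Generation has not been cached.")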