Mirror of https://github.com/hwchase17/langchain.git, synced 2025-07-08 06:00:41 +00:00
fix(cache): use dumps for RedisCache (#10408)
# Description
Attempts to fix RedisCache for ChatGenerations by using the `loads` and `dumps` serializers already used in the SQLAlchemy cache by @hwchase17. This is better than a pickle dump because it cannot execute arbitrary code during deserialization.

# Issues
#7722 & #8666

# Dependencies
None, but this removes the warning introduced in #8041 by @baskaryan.

Handle: @jaikanthjay46
This commit is contained in:
parent 5944c1851b
commit 9f85f7c543
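For context on why `dumps`/`loads` is the safer choice: it round-trips generations through a JSON representation instead of pickling them, so deserializing a cached value cannot execute arbitrary code. A minimal sketch of that round trip (illustrative only, not part of this diff):

```python
from langchain.load.dump import dumps
from langchain.load.load import loads
from langchain.schema import ChatGeneration
from langchain.schema.messages import AIMessage

gen = ChatGeneration(message=AIMessage(content="fizz"))
blob = dumps(gen)       # a plain JSON string, suitable for a Redis hash field
restored = loads(blob)  # rebuilt from JSON, no pickle involved
assert restored == gen
```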
@@ -25,7 +25,6 @@ import hashlib
 import inspect
 import json
 import logging
-import warnings
 from datetime import timedelta
 from functools import lru_cache
 from typing import (
@@ -54,7 +53,7 @@ except ImportError:
 from langchain.llms.base import LLM, get_prompts
 from langchain.load.dump import dumps
 from langchain.load.load import loads
-from langchain.schema import Generation
+from langchain.schema import ChatGeneration, Generation
 from langchain.schema.cache import RETURN_VAL_TYPE, BaseCache
 from langchain.schema.embeddings import Embeddings
 from langchain.utils import get_from_env
@@ -306,7 +305,18 @@ class RedisCache(BaseCache):
         results = self.redis.hgetall(self._key(prompt, llm_string))
         if results:
             for _, text in results.items():
-                generations.append(Generation(text=text))
+                try:
+                    generations.append(loads(text))
+                except Exception:
+                    logger.warning(
+                        "Retrieving a cache value that could not be deserialized "
+                        "properly. This is likely due to the cache being in an "
+                        "older format. Please recreate your cache to avoid this "
+                        "error."
+                    )
+                    # In a previous life we stored the raw text directly
+                    # in the table, so assume it's in that format.
+                    generations.append(Generation(text=text))
         return generations if generations else None

     def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> None:
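The lookup change above stays backward compatible: values written by older versions of `RedisCache` are bare generation text, while new values are JSON produced by `dumps()`. A standalone sketch of that fallback (the helper name below is hypothetical, not part of the diff):

```python
from langchain.load.load import loads
from langchain.schema import Generation


def deserialize_cached_value(text: str) -> Generation:
    """Hypothetical helper mirroring the try/except added to RedisCache.lookup."""
    try:
        # New format: a JSON string from dumps(); may reconstruct a ChatGeneration.
        return loads(text)
    except Exception:
        # Old format: the raw completion text stored by earlier cache versions.
        return Generation(text=text)


assert deserialize_cached_value("raw cached text").text == "raw cached text"
```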
@@ -317,12 +327,6 @@ class RedisCache(BaseCache):
                     "RedisCache only supports caching of normal LLM generations, "
                     f"got {type(gen)}"
                 )
-            if isinstance(gen, ChatGeneration):
-                warnings.warn(
-                    "NOTE: Generation has not been cached. RedisCache does not"
-                    " support caching ChatModel outputs."
-                )
-                return
         # Write to a Redis HASH
         key = self._key(prompt, llm_string)

@@ -330,7 +334,7 @@ class RedisCache(BaseCache):
             pipe.hset(
                 key,
                 mapping={
-                    str(idx): generation.text
+                    str(idx): dumps(generation)
                     for idx, generation in enumerate(return_val)
                 },
             )
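For illustration, the hash written by `update()` now maps each generation's index to its `dumps()` JSON string instead of the bare `.text`. A sketch of the mapping alone (the real code sends it through `redis.pipeline().hset`):

```python
from langchain.load.dump import dumps
from langchain.schema import Generation

return_val = [Generation(text="hello"), Generation(text="world")]

# Same comprehension as in the diff: field "0" holds the JSON for the first
# generation, field "1" the second, and so on.
mapping = {str(idx): dumps(generation) for idx, generation in enumerate(return_val)}
assert set(mapping) == {"0", "1"}
```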
@@ -441,9 +445,20 @@ class RedisSemanticCache(BaseCache):
         )
         if results:
             for document in results:
-                generations.extend(
-                    _load_generations_from_json(document.metadata["return_val"])
-                )
+                try:
+                    generations.extend(loads(document.metadata["return_val"]))
+                except Exception:
+                    logger.warning(
+                        "Retrieving a cache value that could not be deserialized "
+                        "properly. This is likely due to the cache being in an "
+                        "older format. Please recreate your cache to avoid this "
+                        "error."
+                    )
+                    # In a previous life we stored the raw text directly
+                    # in the table, so assume it's in that format.
+                    generations.extend(
+                        _load_generations_from_json(document.metadata["return_val"])
+                    )
         return generations if generations else None

     def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> None:
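Note that the fallback here differs from the plain `RedisCache` one: older `RedisSemanticCache` entries were written with `_dump_generations_to_json` (see the removed line in the next hunk), so on a deserialization failure the stored metadata is re-parsed with `_load_generations_from_json` rather than wrapped in a raw `Generation`.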
@@ -454,18 +469,12 @@ class RedisSemanticCache(BaseCache):
                     "RedisSemanticCache only supports caching of "
                     f"normal LLM generations, got {type(gen)}"
                 )
-            if isinstance(gen, ChatGeneration):
-                warnings.warn(
-                    "NOTE: Generation has not been cached. RedisSentimentCache does not"
-                    " support caching ChatModel outputs."
-                )
-                return
         llm_cache = self._get_llm_cache(llm_string)
         _dump_generations_to_json([g for g in return_val])

         metadata = {
             "llm_string": llm_string,
             "prompt": prompt,
-            "return_val": _dump_generations_to_json([g for g in return_val]),
+            "return_val": dumps([g for g in return_val]),
         }
         llm_cache.add_texts(texts=[prompt], metadatas=[metadata])
@@ -6,8 +6,11 @@ import pytest

 import langchain
 from langchain.cache import RedisCache, RedisSemanticCache
+from langchain.load.dump import dumps
 from langchain.schema import Generation, LLMResult
 from langchain.schema.embeddings import Embeddings
+from langchain.schema.messages import AIMessage, BaseMessage, HumanMessage
+from langchain.schema.output import ChatGeneration
 from tests.integration_tests.vectorstores.fake_embeddings import (
     ConsistentFakeEmbeddings,
     FakeEmbeddings,
@@ -56,9 +59,17 @@ def test_redis_cache_chat() -> None:
     llm = FakeChatModel()
     params = llm.dict()
     params["stop"] = None
-    with pytest.warns():
-        llm.predict("foo")
-    llm.predict("foo")
+    llm_string = str(sorted([(k, v) for k, v in params.items()]))
+    prompt: List[BaseMessage] = [HumanMessage(content="foo")]
+    langchain.llm_cache.update(
+        dumps(prompt), llm_string, [ChatGeneration(message=AIMessage(content="fizz"))]
+    )
+    output = llm.generate([prompt])
+    expected_output = LLMResult(
+        generations=[[ChatGeneration(message=AIMessage(content="fizz"))]],
+        llm_output={},
+    )
+    assert output == expected_output
     langchain.llm_cache.redis.flushall()


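Usage-wise, the updated test reflects that a chat model pointed at `RedisCache` now gets real cache hits instead of the warning that was removed above. A hedged setup sketch (the Redis URL is a placeholder; adjust to your environment):

```python
import langchain
import redis
from langchain.cache import RedisCache

# Placeholder connection; point this at your own Redis instance.
langchain.llm_cache = RedisCache(redis_=redis.Redis.from_url("redis://localhost:6379"))
# Chat-model generations are now stored via dumps() and restored via loads(),
# so repeated identical calls can be served from the cache.
```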
@@ -120,9 +131,16 @@ def test_redis_semantic_cache_chat() -> None:
     params = llm.dict()
     params["stop"] = None
     llm_string = str(sorted([(k, v) for k, v in params.items()]))
-    with pytest.warns():
-        llm.predict("foo")
-    llm.predict("foo")
+    prompt: List[BaseMessage] = [HumanMessage(content="foo")]
+    langchain.llm_cache.update(
+        dumps(prompt), llm_string, [ChatGeneration(message=AIMessage(content="fizz"))]
+    )
+    output = llm.generate([prompt])
+    expected_output = LLMResult(
+        generations=[[ChatGeneration(message=AIMessage(content="fizz"))]],
+        llm_output={},
+    )
+    assert output == expected_output
     langchain.llm_cache.clear(llm_string=llm_string)


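A similar hedged sketch for the semantic cache, assuming the `RedisSemanticCache(redis_url=..., embedding=...)` constructor (the URL and embedding model are placeholders):

```python
import langchain
from langchain.cache import RedisSemanticCache
from langchain.embeddings import OpenAIEmbeddings  # any Embeddings implementation

langchain.llm_cache = RedisSemanticCache(
    redis_url="redis://localhost:6379",  # placeholder
    embedding=OpenAIEmbeddings(),
)
# Lookups match prompts by embedding similarity; matched entries now carry a
# dumps()-serialized list of generations in their "return_val" metadata.
```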