mirror of https://github.com/hwchase17/langchain.git
synced 2025-07-08 14:05:16 +00:00
fix(cache): use dumps for RedisCache (#10408)
# Description
Fixes RedisCache for ChatGeneration outputs by reusing the `loads` and `dumps` helpers already used in the SQLAlchemy cache by @hwchase17. This is better than a pickle dump because it will not execute arbitrary code during de-serialisation.

# Issues
#7722 & #8666

# Dependencies
None, but removes the warning introduced in #8041 by @baskaryan.

Handle: @jaikanthjay46
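As a quick illustration of why the `dumps`/`loads` pair is the right tool here, the sketch below round-trips a `ChatGeneration` the same way the patched cache does. It is not part of the commit, and the message content is invented.

```python
from langchain.load.dump import dumps
from langchain.load.load import loads
from langchain.schema import ChatGeneration
from langchain.schema.messages import AIMessage

# A chat generation, as produced by a chat model.
chat_gen = ChatGeneration(message=AIMessage(content="fizz"))

# dumps() renders the object as a JSON string, which is safe to store as a
# Redis hash value; unlike pickle, loading it back cannot execute arbitrary code.
blob = dumps(chat_gen)

# loads() reconstructs the ChatGeneration, message and all, which is what the
# new lookup() path relies on instead of collapsing everything to plain text.
restored = loads(blob)
assert isinstance(restored, ChatGeneration)
assert restored.message.content == "fizz"
```

This is the same round trip the commit's integration tests exercise against a live Redis instance.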
This commit is contained in:
parent 5944c1851b
commit 9f85f7c543
@@ -25,7 +25,6 @@ import hashlib
 import inspect
 import json
 import logging
-import warnings
 from datetime import timedelta
 from functools import lru_cache
 from typing import (
@@ -54,7 +53,7 @@ except ImportError:
 from langchain.llms.base import LLM, get_prompts
 from langchain.load.dump import dumps
 from langchain.load.load import loads
-from langchain.schema import ChatGeneration, Generation
+from langchain.schema import Generation
 from langchain.schema.cache import RETURN_VAL_TYPE, BaseCache
 from langchain.schema.embeddings import Embeddings
 from langchain.utils import get_from_env
@@ -306,7 +305,18 @@ class RedisCache(BaseCache):
         results = self.redis.hgetall(self._key(prompt, llm_string))
         if results:
             for _, text in results.items():
-                generations.append(Generation(text=text))
+                try:
+                    generations.append(loads(text))
+                except Exception:
+                    logger.warning(
+                        "Retrieving a cache value that could not be deserialized "
+                        "properly. This is likely due to the cache being in an "
+                        "older format. Please recreate your cache to avoid this "
+                        "error."
+                    )
+                    # In a previous life we stored the raw text directly
+                    # in the table, so assume it's in that format.
+                    generations.append(Generation(text=text))
         return generations if generations else None
 
     def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> None:
@@ -317,12 +327,6 @@ class RedisCache(BaseCache):
                     "RedisCache only supports caching of normal LLM generations, "
                     f"got {type(gen)}"
                 )
-            if isinstance(gen, ChatGeneration):
-                warnings.warn(
-                    "NOTE: Generation has not been cached. RedisCache does not"
-                    " support caching ChatModel outputs."
-                )
-                return
         # Write to a Redis HASH
         key = self._key(prompt, llm_string)
 
@@ -330,7 +334,7 @@ class RedisCache(BaseCache):
             pipe.hset(
                 key,
                 mapping={
-                    str(idx): generation.text
+                    str(idx): dumps(generation)
                     for idx, generation in enumerate(return_val)
                 },
             )
@@ -441,9 +445,20 @@ class RedisSemanticCache(BaseCache):
         )
         if results:
             for document in results:
-                generations.extend(
-                    _load_generations_from_json(document.metadata["return_val"])
-                )
+                try:
+                    generations.extend(loads(document.metadata["return_val"]))
+                except Exception:
+                    logger.warning(
+                        "Retrieving a cache value that could not be deserialized "
+                        "properly. This is likely due to the cache being in an "
+                        "older format. Please recreate your cache to avoid this "
+                        "error."
+                    )
+                    # In a previous life we stored the raw text directly
+                    # in the table, so assume it's in that format.
+                    generations.extend(
+                        _load_generations_from_json(document.metadata["return_val"])
+                    )
         return generations if generations else None
 
     def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> None:
@@ -454,18 +469,12 @@ class RedisSemanticCache(BaseCache):
                     "RedisSemanticCache only supports caching of "
                     f"normal LLM generations, got {type(gen)}"
                 )
-            if isinstance(gen, ChatGeneration):
-                warnings.warn(
-                    "NOTE: Generation has not been cached. RedisSentimentCache does not"
-                    " support caching ChatModel outputs."
-                )
-                return
         llm_cache = self._get_llm_cache(llm_string)
-        _dump_generations_to_json([g for g in return_val])
         metadata = {
             "llm_string": llm_string,
             "prompt": prompt,
-            "return_val": _dump_generations_to_json([g for g in return_val]),
+            "return_val": dumps([g for g in return_val]),
         }
         llm_cache.add_texts(texts=[prompt], metadatas=[metadata])
 
@@ -6,8 +6,11 @@ import pytest
 
 import langchain
 from langchain.cache import RedisCache, RedisSemanticCache
+from langchain.load.dump import dumps
 from langchain.schema import Generation, LLMResult
 from langchain.schema.embeddings import Embeddings
+from langchain.schema.messages import AIMessage, BaseMessage, HumanMessage
+from langchain.schema.output import ChatGeneration
 from tests.integration_tests.vectorstores.fake_embeddings import (
     ConsistentFakeEmbeddings,
     FakeEmbeddings,
@@ -56,9 +59,17 @@ def test_redis_cache_chat() -> None:
     llm = FakeChatModel()
     params = llm.dict()
     params["stop"] = None
-    with pytest.warns():
-        llm.predict("foo")
-        llm.predict("foo")
+    llm_string = str(sorted([(k, v) for k, v in params.items()]))
+    prompt: List[BaseMessage] = [HumanMessage(content="foo")]
+    langchain.llm_cache.update(
+        dumps(prompt), llm_string, [ChatGeneration(message=AIMessage(content="fizz"))]
+    )
+    output = llm.generate([prompt])
+    expected_output = LLMResult(
+        generations=[[ChatGeneration(message=AIMessage(content="fizz"))]],
+        llm_output={},
+    )
+    assert output == expected_output
     langchain.llm_cache.redis.flushall()
 
 
@@ -120,9 +131,16 @@ def test_redis_semantic_cache_chat() -> None:
     params = llm.dict()
     params["stop"] = None
     llm_string = str(sorted([(k, v) for k, v in params.items()]))
-    with pytest.warns():
-        llm.predict("foo")
-        llm.predict("foo")
+    prompt: List[BaseMessage] = [HumanMessage(content="foo")]
+    langchain.llm_cache.update(
+        dumps(prompt), llm_string, [ChatGeneration(message=AIMessage(content="fizz"))]
+    )
+    output = llm.generate([prompt])
+    expected_output = LLMResult(
+        generations=[[ChatGeneration(message=AIMessage(content="fizz"))]],
+        llm_output={},
+    )
+    assert output == expected_output
     langchain.llm_cache.clear(llm_string=llm_string)
 
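For completeness, here is a rough end-to-end sketch of how the patched `RedisCache` behaves. It is not part of the commit; the Redis URL, prompt, and `llm_string` values are placeholders.

```python
import langchain
from redis import Redis

from langchain.cache import RedisCache
from langchain.schema import ChatGeneration
from langchain.schema.messages import AIMessage

# Point the global LLM cache at a Redis instance (placeholder URL).
langchain.llm_cache = RedisCache(Redis.from_url("redis://localhost:6379"))

prompt = "tell me a joke"              # placeholder prompt
llm_string = "fake-chat-model-params"  # placeholder for the serialized model config

# With this commit, chat generations are written via dumps() ...
langchain.llm_cache.update(
    prompt, llm_string, [ChatGeneration(message=AIMessage(content="fizz"))]
)

# ... and read back via loads(), instead of being skipped with a warning.
cached = langchain.llm_cache.lookup(prompt, llm_string)
assert cached is not None and isinstance(cached[0], ChatGeneration)
```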