mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-13 13:36:15 +00:00
Upstash redis integration (#10871)
- **Description:** Introduced Upstash provider with following wrappers: UpstashRedisCache, UpstashRedisEntityStore, UpstashRedisChatMessageHistory, UpstashRedisStore - **Issue:** -, - **Dependencies:** upstash-redis python package is needed, - **Tag maintainer:** @baskaryan - **Twitter handle:** @BurakY744 --------- Co-authored-by: Burak Yılmaz <burakyilmaz@Buraks-MacBook-Pro.local> Co-authored-by: Bagatur <baskaryan@gmail.com>
This commit is contained in:
@@ -26,6 +26,7 @@ import inspect
|
||||
import json
|
||||
import logging
|
||||
import uuid
|
||||
import warnings
|
||||
from datetime import timedelta
|
||||
from functools import lru_cache
|
||||
from typing import (
|
||||
@@ -53,7 +54,7 @@ except ImportError:
|
||||
from langchain.llms.base import LLM, get_prompts
|
||||
from langchain.load.dump import dumps
|
||||
from langchain.load.load import loads
|
||||
from langchain.schema import Generation
|
||||
from langchain.schema import ChatGeneration, Generation
|
||||
from langchain.schema.cache import RETURN_VAL_TYPE, BaseCache
|
||||
from langchain.schema.embeddings import Embeddings
|
||||
from langchain.utils import get_from_env
|
||||
@@ -260,6 +261,92 @@ class SQLiteCache(SQLAlchemyCache):
|
||||
super().__init__(engine)
|
||||
|
||||
|
||||
class UpstashRedisCache(BaseCache):
    """Cache that uses Upstash Redis as a backend."""

    def __init__(self, redis_: Any, *, ttl: Optional[int] = None):
        """
        Initialize an instance of UpstashRedisCache.

        This method initializes an object with Upstash Redis caching capabilities.
        It takes a `redis_` parameter, which should be an instance of an Upstash Redis
        client class, allowing the object to interact with Upstash Redis
        server for caching purposes.

        Parameters:
            redis_: An instance of Upstash Redis client class
                (e.g., Redis) used for caching.
                This allows the object to communicate with
                Redis server for caching operations on.
            ttl (int, optional): Time-to-live (TTL) for cached items in seconds.
                If provided, it sets the time duration for how long cached
                items will remain valid. If not provided, cached items will not
                have an automatic expiration.
        """
        try:
            from upstash_redis import Redis
        except ImportError:
            # NOTE(review): ImportError would be more idiomatic here, but
            # ValueError is kept so existing callers that catch it still work.
            raise ValueError(
                "Could not import upstash_redis python package. "
                "Please install it with `pip install upstash_redis`."
            )
        if not isinstance(redis_, Redis):
            raise ValueError("Please pass in Upstash Redis object.")
        self.redis = redis_
        self.ttl = ttl

    def _key(self, prompt: str, llm_string: str) -> str:
        """Compute key from prompt and llm_string"""
        return _hash(prompt + llm_string)

    def lookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]:
        """Look up based on prompt and llm_string."""
        # Read from a HASH. Fields are written with stringified indices as
        # names (see ``update``); a Redis hash does not guarantee field order,
        # so sort numerically to restore the original generation order.
        results = self.redis.hgetall(self._key(prompt, llm_string))
        if not results:
            return None
        generations = [
            Generation(text=text)
            for _, text in sorted(results.items(), key=lambda item: int(item[0]))
        ]
        return generations if generations else None

    def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> None:
        """Update cache based on prompt and llm_string."""
        for gen in return_val:
            if not isinstance(gen, Generation):
                raise ValueError(
                    "UpstashRedisCache supports caching of normal LLM generations, "
                    f"got {type(gen)}"
                )
            if isinstance(gen, ChatGeneration):
                # ChatGeneration subclasses Generation, so this check must come
                # second; chat outputs are unsupported and skip caching whole.
                warnings.warn(
                    "NOTE: Generation has not been cached. UpstashRedisCache does not"
                    " support caching ChatModel outputs."
                )
                return
        # Write to a HASH
        key = self._key(prompt, llm_string)

        mapping = {
            str(idx): generation.text for idx, generation in enumerate(return_val)
        }
        self.redis.hset(key=key, values=mapping)

        if self.ttl is not None:
            self.redis.expire(key, self.ttl)

    def clear(self, **kwargs: Any) -> None:
        """
        Clear cache. If `asynchronous` is True, flush asynchronously.
        This flushes the *whole* db.
        """
        asynchronous = kwargs.get("asynchronous", False)
        if asynchronous:
            asynchronous = "ASYNC"
        else:
            asynchronous = "SYNC"
        self.redis.flushdb(flush_type=asynchronous)
|
||||
|
||||
|
||||
class RedisCache(BaseCache):
|
||||
"""Cache that uses Redis as a backend."""
|
||||
|
||||
|
@@ -44,6 +44,7 @@ from langchain.memory.chat_message_histories import (
|
||||
RedisChatMessageHistory,
|
||||
SQLChatMessageHistory,
|
||||
StreamlitChatMessageHistory,
|
||||
UpstashRedisChatMessageHistory,
|
||||
XataChatMessageHistory,
|
||||
ZepChatMessageHistory,
|
||||
)
|
||||
@@ -53,6 +54,7 @@ from langchain.memory.entity import (
|
||||
InMemoryEntityStore,
|
||||
RedisEntityStore,
|
||||
SQLiteEntityStore,
|
||||
UpstashRedisEntityStore,
|
||||
)
|
||||
from langchain.memory.kg import ConversationKGMemory
|
||||
from langchain.memory.motorhead_memory import MotorheadMemory
|
||||
@@ -96,4 +98,6 @@ __all__ = [
|
||||
"XataChatMessageHistory",
|
||||
"ZepChatMessageHistory",
|
||||
"ZepMemory",
|
||||
"UpstashRedisEntityStore",
|
||||
"UpstashRedisChatMessageHistory",
|
||||
]
|
||||
|
@@ -20,6 +20,9 @@ from langchain.memory.chat_message_histories.sql import SQLChatMessageHistory
|
||||
from langchain.memory.chat_message_histories.streamlit import (
|
||||
StreamlitChatMessageHistory,
|
||||
)
|
||||
from langchain.memory.chat_message_histories.upstash_redis import (
|
||||
UpstashRedisChatMessageHistory,
|
||||
)
|
||||
from langchain.memory.chat_message_histories.xata import XataChatMessageHistory
|
||||
from langchain.memory.chat_message_histories.zep import ZepChatMessageHistory
|
||||
|
||||
@@ -40,4 +43,5 @@ __all__ = [
|
||||
"StreamlitChatMessageHistory",
|
||||
"XataChatMessageHistory",
|
||||
"ZepChatMessageHistory",
|
||||
"UpstashRedisChatMessageHistory",
|
||||
]
|
||||
|
@@ -0,0 +1,67 @@
|
||||
import json
|
||||
import logging
|
||||
from typing import List, Optional
|
||||
|
||||
from langchain.schema import (
|
||||
BaseChatMessageHistory,
|
||||
)
|
||||
from langchain.schema.messages import BaseMessage, _message_to_dict, messages_from_dict
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class UpstashRedisChatMessageHistory(BaseChatMessageHistory):
    """Chat message history stored in an Upstash Redis database."""

    def __init__(
        self,
        session_id: str,
        url: str = "",
        token: str = "",
        key_prefix: str = "message_store:",
        ttl: Optional[int] = None,
    ):
        """Create a history backed by an Upstash Redis list.

        Args:
            session_id: Identifier appended to ``key_prefix`` to form the key.
            url: UPSTASH_REDIS_REST_URL of the database.
            token: UPSTASH_REDIS_REST_TOKEN of the database.
            key_prefix: Prefix for the Redis key holding this session.
            ttl: Expiry in seconds, refreshed on every write; None = no expiry.
        """
        try:
            from upstash_redis import Redis
        except ImportError:
            raise ImportError(
                "Could not import upstash redis python package. "
                "Please install it with `pip install upstash_redis`."
            )

        if url == "" or token == "":
            raise ValueError(
                "UPSTASH_REDIS_REST_URL and UPSTASH_REDIS_REST_TOKEN are needed."
            )

        try:
            self.redis_client = Redis(url=url, token=token)
        except Exception:
            logger.error("Upstash Redis instance could not be initiated.")
            # BUG FIX: the exception was previously swallowed, leaving the
            # object without ``redis_client`` and failing later with an
            # unrelated AttributeError; re-raise so callers see the real cause.
            raise

        self.session_id = session_id
        self.key_prefix = key_prefix
        self.ttl = ttl

    @property
    def key(self) -> str:
        """Construct the record key to use"""
        return self.key_prefix + self.session_id

    @property
    def messages(self) -> List[BaseMessage]:  # type: ignore
        """Retrieve the messages from Upstash Redis"""
        # Messages are LPUSHed (newest first); reverse the list so callers get
        # them in chronological order.
        _items = self.redis_client.lrange(self.key, 0, -1)
        items = [json.loads(m) for m in _items[::-1]]
        messages = messages_from_dict(items)
        return messages

    def add_message(self, message: BaseMessage) -> None:
        """Append the message to the record in Upstash Redis"""
        self.redis_client.lpush(self.key, json.dumps(_message_to_dict(message)))
        if self.ttl:
            # Refresh the expiry on every write so active sessions stay alive.
            self.redis_client.expire(self.key, self.ttl)

    def clear(self) -> None:
        """Clear session memory from Upstash Redis"""
        self.redis_client.delete(self.key)
|
@@ -69,6 +69,84 @@ class InMemoryEntityStore(BaseEntityStore):
|
||||
return self.store.clear()
|
||||
|
||||
|
||||
class UpstashRedisEntityStore(BaseEntityStore):
    """Upstash Redis backed Entity store.

    Entities get a TTL of 1 day by default, and
    that TTL is extended by 3 days every time the entity is read back.
    """

    def __init__(
        self,
        session_id: str = "default",
        url: str = "",
        token: str = "",
        key_prefix: str = "memory_store",
        ttl: Optional[int] = 60 * 60 * 24,
        recall_ttl: Optional[int] = 60 * 60 * 24 * 3,
        *args: Any,
        **kwargs: Any,
    ):
        """Create an entity store scoped to ``session_id`` in Upstash Redis.

        Args:
            session_id: Scopes all keys under ``key_prefix:session_id``.
            url: UPSTASH_REDIS_REST_URL of the database.
            token: UPSTASH_REDIS_REST_TOKEN of the database.
            key_prefix: Prefix for all keys written by this store.
            ttl: Expiry (seconds) applied on writes; defaults to 1 day.
            recall_ttl: Expiry (seconds) re-applied on reads; defaults to 3 days.
        """
        try:
            from upstash_redis import Redis
        except ImportError:
            raise ImportError(
                "Could not import upstash_redis python package. "
                "Please install it with `pip install upstash_redis`."
            )

        super().__init__(*args, **kwargs)

        try:
            self.redis_client = Redis(url=url, token=token)
        except Exception:
            logger.error("Upstash Redis instance could not be initiated.")
            # BUG FIX: swallowing the exception left the store without
            # ``redis_client``, so every later call failed with an unrelated
            # AttributeError; re-raise so callers see the real cause.
            raise

        self.session_id = session_id
        self.key_prefix = key_prefix
        self.ttl = ttl
        # Reads refresh the key with recall_ttl; fall back to ttl when unset.
        self.recall_ttl = recall_ttl or ttl

    @property
    def full_key_prefix(self) -> str:
        """Key prefix scoped to this session."""
        return f"{self.key_prefix}:{self.session_id}"

    def get(self, key: str, default: Optional[str] = None) -> Optional[str]:
        """Return the entity value, extending its TTL by ``recall_ttl``."""
        # GETEX reads the value and refreshes the expiry in a single call.
        res = (
            self.redis_client.getex(f"{self.full_key_prefix}:{key}", ex=self.recall_ttl)
            or default
            or ""
        )
        logger.debug(f"Upstash Redis MEM get '{self.full_key_prefix}:{key}': '{res}'")
        return res

    def set(self, key: str, value: Optional[str]) -> None:
        """Store the entity value with ``ttl``; empty values delete the key."""
        if not value:
            return self.delete(key)
        self.redis_client.set(f"{self.full_key_prefix}:{key}", value, ex=self.ttl)
        logger.debug(
            f"Redis MEM set '{self.full_key_prefix}:{key}': '{value}' EX {self.ttl}"
        )

    def delete(self, key: str) -> None:
        """Remove the entity key."""
        self.redis_client.delete(f"{self.full_key_prefix}:{key}")

    def exists(self, key: str) -> bool:
        """Return True if the entity key is present."""
        return self.redis_client.exists(f"{self.full_key_prefix}:{key}") == 1

    def clear(self) -> None:
        """Delete every entity key belonging to this session."""

        def scan_and_delete(cursor: int) -> int:
            # One SCAN page: delete the matching keys, return the next cursor.
            cursor, keys_to_delete = self.redis_client.scan(
                cursor, f"{self.full_key_prefix}:*"
            )
            # DEL with zero arguments is an error, so skip empty scan pages.
            if keys_to_delete:
                self.redis_client.delete(*keys_to_delete)
            return cursor

        cursor = scan_and_delete(0)
        while cursor != 0:
            # BUG FIX: the original loop discarded the returned cursor, so it
            # spun forever whenever SCAN needed more than one page.
            cursor = scan_and_delete(cursor)
|
||||
|
||||
|
||||
class RedisEntityStore(BaseEntityStore):
|
||||
"""Redis-backed Entity store.
|
||||
|
||||
|
@@ -11,6 +11,7 @@ from langchain.storage.encoder_backed import EncoderBackedStore
|
||||
from langchain.storage.file_system import LocalFileStore
|
||||
from langchain.storage.in_memory import InMemoryStore
|
||||
from langchain.storage.redis import RedisStore
|
||||
from langchain.storage.upstash_redis import UpstashRedisStore
|
||||
|
||||
__all__ = [
|
||||
"EncoderBackedStore",
|
||||
@@ -19,4 +20,5 @@ __all__ = [
|
||||
"RedisStore",
|
||||
"create_lc_store",
|
||||
"create_kv_docstore",
|
||||
"UpstashRedisStore",
|
||||
]
|
||||
|
119
libs/langchain/langchain/storage/upstash_redis.py
Normal file
119
libs/langchain/langchain/storage/upstash_redis.py
Normal file
@@ -0,0 +1,119 @@
|
||||
from typing import Any, Iterator, List, Optional, Sequence, Tuple, cast
|
||||
|
||||
from langchain.schema import BaseStore
|
||||
|
||||
|
||||
class UpstashRedisStore(BaseStore[str, str]):
    """BaseStore implementation using Upstash Redis as the underlying store."""

    def __init__(
        self,
        *,
        client: Any = None,
        url: Optional[str] = None,
        token: Optional[str] = None,
        ttl: Optional[int] = None,
        namespace: Optional[str] = None,
    ) -> None:
        """Initialize the UpstashRedisStore with HTTP API.

        Must provide either an Upstash Redis client or a url.

        Args:
            client: An Upstash Redis instance
            url: UPSTASH_REDIS_REST_URL
            token: UPSTASH_REDIS_REST_TOKEN
            ttl: time to expire keys in seconds if provided,
                 if None keys will never expire
            namespace: if provided, all keys will be prefixed with this namespace

        Raises:
            ImportError: if upstash_redis is not installed.
            ValueError: if both or neither of client and url/token are given.
            TypeError: if client is not an Upstash Redis, or ttl is not an int.
        """
        try:
            from upstash_redis import Redis
        except ImportError as e:
            raise ImportError(
                "UpstashRedisStore requires the upstash_redis library to be installed. "
                "pip install upstash_redis"
            ) from e

        if client and url:
            raise ValueError(
                "Either an Upstash Redis client or a url must be provided, not both."
            )

        if client:
            if not isinstance(client, Redis):
                raise TypeError(
                    f"Expected Upstash Redis client, got {type(client).__name__}."
                )
            _client = client
        else:
            if not url or not token:
                raise ValueError(
                    "Either an Upstash Redis client or url and token must be provided."
                )
            _client = Redis(url=url, token=token)

        self.client = _client

        if not isinstance(ttl, int) and ttl is not None:
            raise TypeError(f"Expected int or None, got {type(ttl)} instead.")

        self.ttl = ttl
        self.namespace = namespace

    def _get_prefixed_key(self, key: str) -> str:
        """Get the key with the namespace prefix.

        Args:
            key (str): The original key.

        Returns:
            str: The key with the namespace prefix.
        """
        delimiter = "/"
        if self.namespace:
            return f"{self.namespace}{delimiter}{key}"
        return key

    def mget(self, keys: Sequence[str]) -> List[Optional[str]]:
        """Get the values associated with the given keys."""
        keys = [self._get_prefixed_key(key) for key in keys]
        return cast(
            List[Optional[str]],
            self.client.mget(*keys),
        )

    def mset(self, key_value_pairs: Sequence[Tuple[str, str]]) -> None:
        """Set the given key-value pairs."""
        # Per-key SET (rather than one MSET) so the TTL can be applied.
        for key, value in key_value_pairs:
            self.client.set(self._get_prefixed_key(key), value, ex=self.ttl)

    def mdelete(self, keys: Sequence[str]) -> None:
        """Delete the given keys."""
        _keys = [self._get_prefixed_key(key) for key in keys]
        # DEL with zero arguments is an error; an empty input is a no-op.
        if _keys:
            self.client.delete(*_keys)

    def yield_keys(self, *, prefix: Optional[str] = None) -> Iterator[str]:
        """Yield keys in the store, with any namespace prefix stripped."""
        pattern = self._get_prefixed_key(prefix if prefix else "*")
        # Single SCAN loop; the original duplicated the page-handling code
        # before and inside its while loop.
        cursor = 0
        while True:
            cursor, keys = self.client.scan(cursor, match=pattern)
            for key in keys:
                if self.namespace:
                    # Drop "<namespace>/" so callers see the key they stored.
                    yield key[len(self.namespace) + 1 :]
                else:
                    yield key
            if cursor == 0:
                break
|
19
libs/langchain/poetry.lock
generated
19
libs/langchain/poetry.lock
generated
@@ -10131,6 +10131,21 @@ tzdata = {version = "*", markers = "platform_system == \"Windows\""}
|
||||
[package.extras]
|
||||
devenv = ["black", "check-manifest", "flake8", "pyroma", "pytest (>=4.3)", "pytest-cov", "pytest-mock (>=3.3)", "zest.releaser"]
|
||||
|
||||
[[package]]
|
||||
name = "upstash-redis"
|
||||
version = "0.15.0"
|
||||
description = "Serverless Redis SDK from Upstash"
|
||||
optional = true
|
||||
python-versions = ">=3.8,<4.0"
|
||||
files = [
|
||||
{file = "upstash_redis-0.15.0-py3-none-any.whl", hash = "sha256:4a89913cb2bb2422610bc2a9c8d6b9a9d75d0674c22c5ea8037d35d343ee5846"},
|
||||
{file = "upstash_redis-0.15.0.tar.gz", hash = "sha256:910f6a567142167b742c38efecfabf23f47e24fcbddb00a6b5845cb11064c3af"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
aiohttp = ">=3.8.4,<4.0.0"
|
||||
requests = ">=2.31.0,<3.0.0"
|
||||
|
||||
[[package]]
|
||||
name = "uri-template"
|
||||
version = "1.3.0"
|
||||
@@ -10883,7 +10898,7 @@ cli = ["typer"]
|
||||
cohere = ["cohere"]
|
||||
docarray = ["docarray"]
|
||||
embeddings = ["sentence-transformers"]
|
||||
extended-testing = ["aiosqlite", "amazon-textract-caller", "anthropic", "arxiv", "assemblyai", "atlassian-python-api", "beautifulsoup4", "bibtexparser", "cassio", "chardet", "dashvector", "esprima", "faiss-cpu", "feedparser", "geopandas", "gitpython", "gql", "html2text", "jinja2", "jq", "lxml", "markdownify", "motor", "mwparserfromhell", "mwxml", "newspaper3k", "numexpr", "openai", "openai", "openapi-schema-pydantic", "pandas", "pdfminer-six", "pgvector", "psychicapi", "py-trello", "pymupdf", "pypdf", "pypdfium2", "pyspark", "rank-bm25", "rapidfuzz", "rapidocr-onnxruntime", "requests-toolbelt", "rspace_client", "scikit-learn", "sqlite-vss", "streamlit", "sympy", "telethon", "timescale-vector", "tqdm", "xata", "xmltodict"]
|
||||
extended-testing = ["aiosqlite", "amazon-textract-caller", "anthropic", "arxiv", "assemblyai", "atlassian-python-api", "beautifulsoup4", "bibtexparser", "cassio", "chardet", "dashvector", "esprima", "faiss-cpu", "feedparser", "geopandas", "gitpython", "gql", "html2text", "jinja2", "jq", "lxml", "markdownify", "motor", "mwparserfromhell", "mwxml", "newspaper3k", "numexpr", "openai", "openai", "openapi-schema-pydantic", "pandas", "pdfminer-six", "pgvector", "psychicapi", "py-trello", "pymupdf", "pypdf", "pypdfium2", "pyspark", "rank-bm25", "rapidfuzz", "rapidocr-onnxruntime", "requests-toolbelt", "rspace_client", "scikit-learn", "sqlite-vss", "streamlit", "sympy", "telethon", "timescale-vector", "tqdm", "upstash-redis", "xata", "xmltodict"]
|
||||
javascript = ["esprima"]
|
||||
llms = ["clarifai", "cohere", "huggingface_hub", "manifest-ml", "nlpcloud", "openai", "openlm", "torch", "transformers"]
|
||||
openai = ["openai", "tiktoken"]
|
||||
@@ -10893,4 +10908,4 @@ text-helpers = ["chardet"]
|
||||
[metadata]
|
||||
lock-version = "2.0"
|
||||
python-versions = ">=3.8.1,<4.0"
|
||||
content-hash = "3a5bca34a60eaa9b66a4d1f9ec14de5e6a0e5ca1071a0a874499fe122cc0ee36"
|
||||
content-hash = "6205031e342d6e4640b47b0b5a37fa7d11ea1915e8a3ee05c00e2e49fdec071e"
|
||||
|
@@ -139,6 +139,7 @@ typer = {version= "^0.9.0", optional = true}
|
||||
anthropic = {version = "^0.3.11", optional = true}
|
||||
aiosqlite = {version = "^0.19.0", optional = true}
|
||||
rspace_client = {version = "^2.5.0", optional = true}
|
||||
upstash-redis = {version = "^0.15.0", optional = true}
|
||||
|
||||
|
||||
[tool.poetry.group.test.dependencies]
|
||||
@@ -367,6 +368,7 @@ extended_testing = [
|
||||
"motor",
|
||||
"timescale-vector",
|
||||
"anthropic",
|
||||
"upstash-redis",
|
||||
"rspace_client",
|
||||
]
|
||||
|
||||
|
91
libs/langchain/tests/integration_tests/cache/test_upstash_redis_cache.py
vendored
Normal file
91
libs/langchain/tests/integration_tests/cache/test_upstash_redis_cache.py
vendored
Normal file
@@ -0,0 +1,91 @@
|
||||
"""Test Upstash Redis cache functionality."""
|
||||
import uuid
|
||||
|
||||
import pytest
|
||||
|
||||
import langchain
|
||||
from langchain.cache import UpstashRedisCache
|
||||
from langchain.schema import Generation, LLMResult
|
||||
from tests.unit_tests.llms.fake_chat_model import FakeChatModel
|
||||
from tests.unit_tests.llms.fake_llm import FakeLLM
|
||||
|
||||
URL = "<UPSTASH_REDIS_REST_URL>"
|
||||
TOKEN = "<UPSTASH_REDIS_REST_TOKEN>"
|
||||
|
||||
|
||||
def random_string() -> str:
    """Return a fresh UUID4 rendered as a string, usable as a unique key."""
    return f"{uuid.uuid4()}"
|
||||
|
||||
|
||||
@pytest.mark.requires("upstash_redis")
def test_redis_cache_ttl() -> None:
    from upstash_redis import Redis

    # A 1-second TTL is long enough to observe a positive remaining lifetime.
    cache = UpstashRedisCache(redis_=Redis(url=URL, token=TOKEN), ttl=1)
    langchain.llm_cache = cache
    cache.update("foo", "bar", [Generation(text="fizz")])
    entry_key = cache._key("foo", "bar")
    # A positive PTTL proves an expiry was actually set on the entry.
    assert cache.redis.pttl(entry_key) > 0
|
||||
|
||||
|
||||
@pytest.mark.requires("upstash_redis")
def test_redis_cache() -> None:
    from upstash_redis import Redis

    langchain.llm_cache = UpstashRedisCache(redis_=Redis(url=URL, token=TOKEN), ttl=1)
    llm = FakeLLM()
    params = llm.dict()
    params["stop"] = None
    llm_string = str(sorted(params.items()))

    # Seed the cache, then verify generate() is served from it.
    langchain.llm_cache.update("foo", llm_string, [Generation(text="fizz")])
    expected_output = LLMResult(
        generations=[[Generation(text="fizz")]],
        llm_output={},
    )
    output = llm.generate(["foo"])
    assert output == expected_output

    # NOTE(review): with ttl=1 the entry may already have expired here, which
    # is presumably why the assertion is guarded rather than unconditional.
    lookup_output = langchain.llm_cache.lookup("foo", llm_string)
    if lookup_output and len(lookup_output) > 0:
        assert lookup_output == expected_output.generations[0]

    # After clearing, generate() must no longer hit the cached value.
    langchain.llm_cache.clear()
    output = llm.generate(["foo"])

    assert output != expected_output
    langchain.llm_cache.redis.flushall()
|
||||
|
||||
|
||||
@pytest.mark.requires("upstash_redis")
def test_redis_cache_multi() -> None:
    """Cache a multi-generation result and read it back through generate().

    BUG FIX: the ``requires`` marker was missing on this test alone, so it
    errored on import in environments without upstash_redis, unlike every
    sibling test in this file.
    """
    from upstash_redis import Redis

    langchain.llm_cache = UpstashRedisCache(redis_=Redis(url=URL, token=TOKEN), ttl=1)
    llm = FakeLLM()
    params = llm.dict()
    params["stop"] = None
    llm_string = str(sorted([(k, v) for k, v in params.items()]))
    langchain.llm_cache.update(
        "foo", llm_string, [Generation(text="fizz"), Generation(text="Buzz")]
    )
    output = llm.generate(
        ["foo"]
    )  # foo and bar will have the same embedding produced by FakeEmbeddings
    expected_output = LLMResult(
        generations=[[Generation(text="fizz"), Generation(text="Buzz")]],
        llm_output={},
    )
    assert output == expected_output
    # clear the cache
    langchain.llm_cache.clear()
|
||||
|
||||
|
||||
@pytest.mark.requires("upstash_redis")
def test_redis_cache_chat() -> None:
    from upstash_redis import Redis

    langchain.llm_cache = UpstashRedisCache(redis_=Redis(url=URL, token=TOKEN), ttl=1)
    chat_model = FakeChatModel()
    params = chat_model.dict()
    params["stop"] = None
    # UpstashRedisCache refuses chat outputs, so predict() must emit a warning.
    with pytest.warns():
        chat_model.predict("foo")
    langchain.llm_cache.redis.flushall()
|
@@ -0,0 +1,38 @@
|
||||
import json
|
||||
|
||||
import pytest
|
||||
|
||||
from langchain.memory import ConversationBufferMemory
|
||||
from langchain.memory.chat_message_histories.upstash_redis import (
|
||||
UpstashRedisChatMessageHistory,
|
||||
)
|
||||
from langchain.schema.messages import _message_to_dict
|
||||
|
||||
URL = "<UPSTASH_REDIS_REST_URL>"
|
||||
TOKEN = "<UPSTASH_REDIS_REST_TOKEN>"
|
||||
|
||||
|
||||
@pytest.mark.requires("upstash_redis")
def test_memory_with_message_store() -> None:
    """Test the memory with a message store."""
    # Wire Upstash Redis in as the backing message store.
    history = UpstashRedisChatMessageHistory(
        url=URL, token=TOKEN, ttl=10, session_id="my-test-session"
    )
    memory = ConversationBufferMemory(
        memory_key="baz", chat_memory=history, return_messages=True
    )

    # Record one AI message and one human message.
    memory.chat_memory.add_ai_message("This is me, the AI")
    memory.chat_memory.add_user_message("This is me, the human")

    # Round-trip through the store and serialize for easy containment checks.
    stored = memory.chat_memory.messages
    dump = json.dumps([_message_to_dict(msg) for msg in stored])

    assert "This is me, the AI" in dump
    assert "This is me, the human" in dump

    # Remove the record so the next test run starts from an empty session.
    memory.chat_memory.clear()
|
@@ -0,0 +1,95 @@
|
||||
"""Implement integration tests for Redis storage."""
|
||||
|
||||
import pytest
|
||||
from upstash_redis import Redis
|
||||
|
||||
from langchain.storage.upstash_redis import UpstashRedisStore
|
||||
|
||||
pytest.importorskip("upstash_redis")
|
||||
|
||||
URL = "<UPSTASH_REDIS_REST_URL>"
|
||||
TOKEN = "<UPSTASH_REDIS_REST_TOKEN>"
|
||||
|
||||
|
||||
@pytest.fixture
def redis_client() -> Redis:
    """Yield a verified Upstash Redis client.

    This fixture flushes the database!

    FIX: removed the redundant function-local ``from upstash_redis import
    Redis`` — the module already imports ``Redis`` at the top (it is used in
    this fixture's own return annotation), so the inner import only shadowed it.
    """
    client = Redis(url=URL, token=TOKEN)
    try:
        client.ping()
    except Exception:
        pytest.skip("Ping request failed. Verify that credentials are correct.")

    client.flushdb()
    return client
|
||||
|
||||
|
||||
def test_mget(redis_client: Redis) -> None:
    """mget returns the stored values in the order of the requested keys."""
    store = UpstashRedisStore(client=redis_client, ttl=None)
    wanted = ["key1", "key2"]
    redis_client.mset({"key1": "value1", "key2": "value2"})
    assert store.mget(wanted) == ["value1", "value2"]
|
||||
|
||||
|
||||
def test_mset(redis_client: Redis) -> None:
    """mset writes every pair so a raw MGET can read them back."""
    store = UpstashRedisStore(client=redis_client, ttl=None)
    store.mset([("key1", "value1"), ("key2", "value2")])
    assert redis_client.mget("key1", "key2") == ["value1", "value2"]
|
||||
|
||||
|
||||
def test_mdelete(redis_client: Redis) -> None:
    """Test that deletion works as expected."""
    store = UpstashRedisStore(client=redis_client, ttl=None)
    targets = ["key1", "key2"]
    redis_client.mset({"key1": "value1", "key2": "value2"})
    store.mdelete(targets)
    # Deleted keys read back as None.
    assert redis_client.mget(*targets) == [None, None]
|
||||
|
||||
|
||||
def test_yield_keys(redis_client: Redis) -> None:
    """yield_keys lists all keys, optionally filtered by a glob prefix."""
    store = UpstashRedisStore(client=redis_client, ttl=None)
    redis_client.mset({"key1": "value2", "key2": "value2"})
    assert sorted(store.yield_keys()) == ["key1", "key2"]
    assert sorted(store.yield_keys(prefix="key*")) == ["key1", "key2"]
    assert sorted(store.yield_keys(prefix="lang*")) == []
|
||||
|
||||
|
||||
def test_namespace(redis_client: Redis) -> None:
    """Keys are stored under '<namespace>/' but yielded without it."""
    store = UpstashRedisStore(client=redis_client, ttl=None, namespace="meow")
    store.mset([("key1", "value1"), ("key2", "value2")])

    # Raw SCAN must show the namespaced key names.
    cursor, all_keys = redis_client.scan(0)
    while cursor != 0:
        cursor, page = redis_client.scan(cursor)
        if len(page) != 0:
            all_keys.extend(page)

    assert sorted(all_keys) == [
        "meow/key1",
        "meow/key2",
    ]

    # Deleting through the store uses the un-namespaced name.
    store.mdelete(["key1"])

    cursor, all_keys = redis_client.scan(0, match="*")
    while cursor != 0:
        cursor, page = redis_client.scan(cursor, match="*")
        if len(page) != 0:
            all_keys.extend(page)

    assert sorted(all_keys) == [
        "meow/key2",
    ]

    # yield_keys strips the namespace before returning keys.
    assert list(store.yield_keys()) == ["key2"]
    assert list(store.yield_keys(prefix="key*")) == ["key2"]
    assert list(store.yield_keys(prefix="key1")) == []
|
@@ -0,0 +1,8 @@
|
||||
"""Light weight unit test that attempts to import UpstashRedisStore.
|
||||
"""
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.mark.requires("upstash_redis")
def test_import_storage() -> None:
    """Smoke-test that UpstashRedisStore can be imported."""
    from langchain.storage.upstash_redis import UpstashRedisStore  # noqa
|
Reference in New Issue
Block a user