**Description:** Add support to delete documents automatically from the caches & chat message history by adding a new optional parameter, `ttl`.

- [x] **Add tests and docs**: If you're adding a new integration, please include
  1. a test for the integration, preferably unit tests that do not rely on network access,
  2. an example notebook showing its use. It lives in the `docs/docs/integrations` directory.
- [x] **Lint and test**: Run `make format`, `make lint` and `make test` from the root of the package(s) you've modified. See contribution guidelines for more: https://python.langchain.com/docs/contributing/

---------

Co-authored-by: Nithish Raghunandanan <nithishr@users.noreply.github.com>
Co-authored-by: Erick Friis <erick@langchain.dev>
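For context, a minimal sketch of the new `ttl` parameter in use, assuming a reachable Couchbase cluster (the connection string, credentials, and bucket/scope/collection names below are illustrative placeholders, not values from this change):

```python
from datetime import timedelta

from couchbase.auth import PasswordAuthenticator
from couchbase.cluster import Cluster
from couchbase.options import ClusterOptions
from langchain_core.globals import set_llm_cache

from langchain_couchbase.cache import CouchbaseCache

# Connect to the cluster and wait for it to be ready.
auth = PasswordAuthenticator("Administrator", "password")
cluster = Cluster("couchbase://localhost", ClusterOptions(auth))
cluster.wait_until_ready(timedelta(seconds=5))

# With ttl set, cached entries are written with a document expiry and are
# removed automatically by Couchbase once the TTL elapses.
set_llm_cache(
    CouchbaseCache(
        cluster=cluster,
        bucket_name="langchain-bucket",
        scope_name="_default",
        collection_name="llm-cache",
        ttl=timedelta(days=1),  # new optional parameter
    )
)
```

Because the expiry is enforced server-side as document TTL, no client-side cleanup job is required.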
"""Test Couchbase Cache functionality"""
|
|
|
|
import os
|
|
from datetime import datetime, timedelta
|
|
from typing import Any
|
|
|
|
import pytest
|
|
from couchbase.auth import PasswordAuthenticator
|
|
from couchbase.cluster import Cluster
|
|
from couchbase.options import ClusterOptions
|
|
from langchain_core.globals import get_llm_cache, set_llm_cache
|
|
from langchain_core.outputs import Generation
|
|
|
|
from langchain_couchbase.cache import CouchbaseCache, CouchbaseSemanticCache
|
|
from tests.utils import (
|
|
FakeEmbeddings,
|
|
FakeLLM,
|
|
cache_key_hash_function,
|
|
fetch_document_expiry_time,
|
|
get_document_keys,
|
|
)
|
|
|
|
CONNECTION_STRING = os.getenv("COUCHBASE_CONNECTION_STRING", "")
BUCKET_NAME = os.getenv("COUCHBASE_BUCKET_NAME", "")
SCOPE_NAME = os.getenv("COUCHBASE_SCOPE_NAME", "")
CACHE_COLLECTION_NAME = os.getenv("COUCHBASE_CACHE_COLLECTION_NAME", "")
SEMANTIC_CACHE_COLLECTION_NAME = os.getenv(
    "COUCHBASE_SEMANTIC_CACHE_COLLECTION_NAME", ""
)
USERNAME = os.getenv("COUCHBASE_USERNAME", "")
PASSWORD = os.getenv("COUCHBASE_PASSWORD", "")
INDEX_NAME = os.getenv("COUCHBASE_SEMANTIC_CACHE_INDEX_NAME", "")


def set_all_env_vars() -> bool:
    """Check if all environment variables are set"""
    return all(
        [
            CONNECTION_STRING,
            BUCKET_NAME,
            SCOPE_NAME,
            CACHE_COLLECTION_NAME,
            USERNAME,
            PASSWORD,
            INDEX_NAME,
        ]
    )


def get_cluster() -> Any:
    """Get a couchbase cluster object"""
    auth = PasswordAuthenticator(USERNAME, PASSWORD)
    options = ClusterOptions(auth)
    connect_string = CONNECTION_STRING
    cluster = Cluster(connect_string, options)

    # Wait until the cluster is ready for use.
    cluster.wait_until_ready(timedelta(seconds=5))

    return cluster


@pytest.fixture()
def cluster() -> Any:
    """Get a couchbase cluster object"""
    return get_cluster()


@pytest.mark.skipif(
    not set_all_env_vars(), reason="Missing Couchbase environment variables"
)
class TestCouchbaseCache:
    def test_cache(self, cluster: Any) -> None:
        """Test standard LLM cache functionality"""
        set_llm_cache(
            CouchbaseCache(
                cluster=cluster,
                bucket_name=BUCKET_NAME,
                scope_name=SCOPE_NAME,
                collection_name=CACHE_COLLECTION_NAME,
            )
        )

        llm = FakeLLM()

        params = llm.dict()
        params["stop"] = None
        llm_string = str(sorted([(k, v) for k, v in params.items()]))
        get_llm_cache().update("foo", llm_string, [Generation(text="fizz")])
        cache_output = get_llm_cache().lookup("foo", llm_string)
        assert cache_output == [Generation(text="fizz")]

        get_llm_cache().clear()
        output = get_llm_cache().lookup("bar", llm_string)
        assert output != [Generation(text="fizz")]

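    # The TTL tests below do not wait for entries to expire; they read the
    # document's expiration metadata back from Couchbase and check that it
    # falls within the configured TTL window.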
    def test_cache_with_ttl(self, cluster: Any) -> None:
        """Test standard LLM cache functionality with TTL"""
        ttl = timedelta(minutes=10)
        set_llm_cache(
            CouchbaseCache(
                cluster=cluster,
                bucket_name=BUCKET_NAME,
                scope_name=SCOPE_NAME,
                collection_name=CACHE_COLLECTION_NAME,
                ttl=ttl,
            )
        )

        llm = FakeLLM()

        params = llm.dict()
        params["stop"] = None
        llm_string = str(sorted([(k, v) for k, v in params.items()]))
        get_llm_cache().update("foo", llm_string, [Generation(text="fizz")])
        cache_output = get_llm_cache().lookup("foo", llm_string)
        assert cache_output == [Generation(text="fizz")]

        # Check the document's expiry time by fetching it from the database
        document_key = cache_key_hash_function("foo" + llm_string)
        document_expiry_time = fetch_document_expiry_time(
            cluster, BUCKET_NAME, SCOPE_NAME, CACHE_COLLECTION_NAME, document_key
        )
        current_time = datetime.now()

        time_to_expiry = document_expiry_time - current_time

        assert time_to_expiry < ttl

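    # The semantic cache matches prompts by embedding similarity against the
    # search index named by INDEX_NAME, so a lookup for "bar" can hit an entry
    # stored for "foo" when both map to the same vector.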
    def test_semantic_cache(self, cluster: Any) -> None:
        """Test semantic LLM cache functionality"""
        set_llm_cache(
            CouchbaseSemanticCache(
                cluster=cluster,
                embedding=FakeEmbeddings(),
                index_name=INDEX_NAME,
                bucket_name=BUCKET_NAME,
                scope_name=SCOPE_NAME,
                collection_name=SEMANTIC_CACHE_COLLECTION_NAME,
            )
        )

        llm = FakeLLM()

        params = llm.dict()
        params["stop"] = None
        llm_string = str(sorted([(k, v) for k, v in params.items()]))
        get_llm_cache().update(
            "foo", llm_string, [Generation(text="fizz"), Generation(text="Buzz")]
        )

        # foo and bar will have the same embedding produced by FakeEmbeddings
        cache_output = get_llm_cache().lookup("bar", llm_string)
        assert cache_output == [Generation(text="fizz"), Generation(text="Buzz")]

        # clear the cache
        get_llm_cache().clear()
        output = get_llm_cache().lookup("bar", llm_string)
        assert output != [Generation(text="fizz"), Generation(text="Buzz")]

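    # The TTL variant below first locates the cached document with a SQL++
    # query on the stored prompt text, then checks its expiration metadata.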
    def test_semantic_cache_with_ttl(self, cluster: Any) -> None:
        """Test semantic LLM cache functionality with TTL"""
        ttl = timedelta(minutes=10)

        set_llm_cache(
            CouchbaseSemanticCache(
                cluster=cluster,
                embedding=FakeEmbeddings(),
                index_name=INDEX_NAME,
                bucket_name=BUCKET_NAME,
                scope_name=SCOPE_NAME,
                collection_name=SEMANTIC_CACHE_COLLECTION_NAME,
                ttl=ttl,
            )
        )

        llm = FakeLLM()

        params = llm.dict()
        params["stop"] = None
        llm_string = str(sorted([(k, v) for k, v in params.items()]))

        # Add a document to the cache
        seed_prompt = "foo"
        get_llm_cache().update(
            seed_prompt, llm_string, [Generation(text="fizz"), Generation(text="Buzz")]
        )

        # foo and bar will have the same embedding produced by FakeEmbeddings
        cache_output = get_llm_cache().lookup("bar", llm_string)
        assert cache_output == [Generation(text="fizz"), Generation(text="Buzz")]

        # Check the document's expiry time by fetching it from the database
        fetch_document_query = (
            f"SELECT meta().id, * from `{SEMANTIC_CACHE_COLLECTION_NAME}` doc "
            f"WHERE doc.text = '{seed_prompt}'"
        )

        document_keys = get_document_keys(
            cluster=cluster,
            bucket_name=BUCKET_NAME,
            scope_name=SCOPE_NAME,
            query=fetch_document_query,
        )
        assert len(document_keys) == 1

        document_expiry_time = fetch_document_expiry_time(
            cluster=cluster,
            bucket_name=BUCKET_NAME,
            scope_name=SCOPE_NAME,
            collection_name=SEMANTIC_CACHE_COLLECTION_NAME,
            document_key=document_keys[0],
        )
        current_time = datetime.now()

        time_to_expiry = document_expiry_time - current_time

        assert time_to_expiry < ttl