langchain/libs/partners/couchbase/tests/integration_tests/test_cache.py

"""Test Couchbase Cache functionality"""
import os
from datetime import timedelta
from typing import Any
import pytest
from couchbase.auth import PasswordAuthenticator
from couchbase.cluster import Cluster
from couchbase.options import ClusterOptions
from langchain_core.globals import get_llm_cache, set_llm_cache
from langchain_core.outputs import Generation
from langchain_couchbase.cache import CouchbaseCache, CouchbaseSemanticCache
from tests.utils import FakeEmbeddings, FakeLLM
CONNECTION_STRING = os.getenv("COUCHBASE_CONNECTION_STRING", "")
BUCKET_NAME = os.getenv("COUCHBASE_BUCKET_NAME", "")
SCOPE_NAME = os.getenv("COUCHBASE_SCOPE_NAME", "")
CACHE_COLLECTION_NAME = os.getenv("COUCHBASE_CACHE_COLLECTION_NAME", "")
SEMANTIC_CACHE_COLLECTION_NAME = os.getenv(
"COUCHBASE_SEMANTIC_CACHE_COLLECTION_NAME", ""
)
USERNAME = os.getenv("COUCHBASE_USERNAME", "")
PASSWORD = os.getenv("COUCHBASE_PASSWORD", "")
INDEX_NAME = os.getenv("COUCHBASE_SEMANTIC_CACHE_INDEX_NAME", "")
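
# These are live integration tests: they talk to a real Couchbase cluster and
# are skipped unless every variable above is set. COUCHBASE_SEMANTIC_CACHE_INDEX_NAME
# should name the search index backing CouchbaseSemanticCache's similarity lookups.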


def set_all_env_vars() -> bool:
    """Check if all environment variables are set"""
    return all(
        [
            CONNECTION_STRING,
            BUCKET_NAME,
            SCOPE_NAME,
            CACHE_COLLECTION_NAME,
            SEMANTIC_CACHE_COLLECTION_NAME,
            USERNAME,
            PASSWORD,
            INDEX_NAME,
        ]
    )


def get_cluster() -> Any:
    """Get a couchbase cluster object"""
    auth = PasswordAuthenticator(USERNAME, PASSWORD)
    options = ClusterOptions(auth)
    connect_string = CONNECTION_STRING
    cluster = Cluster(connect_string, options)
    # Wait until the cluster is ready for use.
    cluster.wait_until_ready(timedelta(seconds=5))
    return cluster


@pytest.fixture()
def cluster() -> Any:
    """Get a couchbase cluster object"""
    return get_cluster()


@pytest.mark.skipif(
    not set_all_env_vars(), reason="Missing Couchbase environment variables"
)
class TestCouchbaseCache:
    def test_cache(self, cluster: Any) -> None:
        """Test standard LLM cache functionality"""
        set_llm_cache(
            CouchbaseCache(
                cluster=cluster,
                bucket_name=BUCKET_NAME,
                scope_name=SCOPE_NAME,
                collection_name=CACHE_COLLECTION_NAME,
            )
        )

        llm = FakeLLM()
        params = llm.dict()
        params["stop"] = None
        llm_string = str(sorted([(k, v) for k, v in params.items()]))

        get_llm_cache().update("foo", llm_string, [Generation(text="fizz")])
        cache_output = get_llm_cache().lookup("foo", llm_string)
        assert cache_output == [Generation(text="fizz")]

        get_llm_cache().clear()
        output = get_llm_cache().lookup("bar", llm_string)
        assert output != [Generation(text="fizz")]

    def test_semantic_cache(self, cluster: Any) -> None:
        """Test semantic LLM cache functionality"""
        set_llm_cache(
            CouchbaseSemanticCache(
                cluster=cluster,
                embedding=FakeEmbeddings(),
                index_name=INDEX_NAME,
                bucket_name=BUCKET_NAME,
                scope_name=SCOPE_NAME,
                collection_name=SEMANTIC_CACHE_COLLECTION_NAME,
            )
        )

        llm = FakeLLM()
        params = llm.dict()
        params["stop"] = None
        llm_string = str(sorted([(k, v) for k, v in params.items()]))

        get_llm_cache().update(
            "foo", llm_string, [Generation(text="fizz"), Generation(text="Buzz")]
        )

        # "foo" and "bar" produce the same embedding with FakeEmbeddings, so a
        # lookup for "bar" is a semantic hit for the entry stored under "foo".
        cache_output = get_llm_cache().lookup("bar", llm_string)
        assert cache_output == [Generation(text="fizz"), Generation(text="Buzz")]

        # After clearing the cache, the lookup should miss.
        get_llm_cache().clear()
        output = get_llm_cache().lookup("bar", llm_string)
        assert output != [Generation(text="fizz"), Generation(text="Buzz")]
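

# A minimal usage sketch (not executed by pytest), assuming the same
# environment variables as above; a real application would swap FakeLLM for an
# actual model. `get_cluster` and the constants are the helpers defined in
# this module.
#
#     from langchain_core.globals import set_llm_cache
#     from langchain_couchbase.cache import CouchbaseCache
#
#     set_llm_cache(
#         CouchbaseCache(
#             cluster=get_cluster(),
#             bucket_name=BUCKET_NAME,
#             scope_name=SCOPE_NAME,
#             collection_name=CACHE_COLLECTION_NAME,
#         )
#     )
#     # From here on, repeated identical LLM calls are served from Couchbase.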