couchbase: Add ttl support to caches & chat_message_history (#26214)

**Description:** Add support to delete documents automatically from the
caches & chat message history by adding a new optional parameter, `ttl`.


- [x] **Add tests and docs**: If you're adding a new integration, please
include
1. a test for the integration, preferably unit tests that do not rely on
network access,
2. an example notebook showing its use. It lives in
`docs/docs/integrations` directory.


- [x] **Lint and test**: Run `make format`, `make lint` and `make test`
from the root of the package(s) you've modified. See contribution
guidelines for more: https://python.langchain.com/docs/contributing/

---------

Co-authored-by: Nithish Raghunandanan <nithishr@users.noreply.github.com>
Co-authored-by: Erick Friis <erick@langchain.dev>
This commit is contained in:
Nithish Raghunandanan
2024-09-21 01:44:29 +02:00
committed by GitHub
parent c6c508ee96
commit 2d21274bf6
10 changed files with 901 additions and 417 deletions

View File

@@ -9,6 +9,7 @@ are duplicated in this utility from modules:
import hashlib
import json
import logging
from datetime import timedelta
from typing import Any, Dict, Optional, Union
from couchbase.cluster import Cluster
@@ -87,6 +88,16 @@ def _loads_generations(generations_str: str) -> Union[RETURN_VAL_TYPE, None]:
return None
def _validate_ttl(ttl: Optional[timedelta]) -> None:
"""Validate the time to live"""
if not isinstance(ttl, timedelta):
raise ValueError(f"ttl should be of type timedelta but was {type(ttl)}.")
if ttl <= timedelta(seconds=0):
raise ValueError(
f"ttl must be greater than 0 but was {ttl.total_seconds()} seconds."
)
class CouchbaseCache(BaseCache):
"""Couchbase LLM Cache
LLM Cache that uses Couchbase as the backend
@@ -140,6 +151,7 @@ class CouchbaseCache(BaseCache):
bucket_name: str,
scope_name: str,
collection_name: str,
ttl: Optional[timedelta] = None,
**kwargs: Dict[str, Any],
) -> None:
"""Initialize the Couchbase LLM Cache
@@ -149,6 +161,8 @@ class CouchbaseCache(BaseCache):
scope_name (str): name of the scope in bucket to store documents in.
collection_name (str): name of the collection in the scope to store
documents in.
ttl (Optional[timedelta]): TTL or time for the document to live in the cache
After this time, the document will get deleted from the cache.
"""
if not isinstance(cluster, Cluster):
raise ValueError(
@@ -162,6 +176,8 @@ class CouchbaseCache(BaseCache):
self._scope_name = scope_name
self._collection_name = collection_name
self._ttl = None
# Check if the bucket exists
if not self._check_bucket_exists():
raise ValueError(
@@ -185,6 +201,11 @@ class CouchbaseCache(BaseCache):
except Exception as e:
raise e
# Check if the time to live is provided and valid
if ttl is not None:
_validate_ttl(ttl)
self._ttl = ttl
def lookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]:
"""Look up from cache based on prompt and llm_string."""
try:
@@ -206,10 +227,16 @@ class CouchbaseCache(BaseCache):
self.LLM: llm_string,
self.RETURN_VAL: _dumps_generations(return_val),
}
document_key = self._generate_key(prompt, llm_string)
try:
self._collection.upsert(
key=self._generate_key(prompt, llm_string), value=doc
)
if self._ttl:
self._collection.upsert(
key=document_key,
value=doc,
expiry=self._ttl,
)
else:
self._collection.upsert(key=document_key, value=doc)
except Exception:
logger.error("Error updating cache")
@@ -242,6 +269,7 @@ class CouchbaseSemanticCache(BaseCache, CouchbaseVectorStore):
collection_name: str,
index_name: str,
score_threshold: Optional[float] = None,
ttl: Optional[timedelta] = None,
) -> None:
"""Initialize the Couchbase LLM Cache
Args:
@@ -253,6 +281,8 @@ class CouchbaseSemanticCache(BaseCache, CouchbaseVectorStore):
documents in.
index_name (str): name of the Search index to use.
score_threshold (float): score threshold to use for filtering results.
ttl (Optional[timedelta]): TTL or time for the document to live in the cache
After this time, the document will get deleted from the cache.
"""
if not isinstance(cluster, Cluster):
raise ValueError(
@@ -265,6 +295,7 @@ class CouchbaseSemanticCache(BaseCache, CouchbaseVectorStore):
self._bucket_name = bucket_name
self._scope_name = scope_name
self._collection_name = collection_name
self._ttl = None
# Check if the bucket exists
if not self._check_bucket_exists():
@@ -291,6 +322,10 @@ class CouchbaseSemanticCache(BaseCache, CouchbaseVectorStore):
self.score_threshold = score_threshold
if ttl is not None:
_validate_ttl(ttl)
self._ttl = ttl
# Initialize the vector store
super().__init__(
cluster=cluster,
@@ -334,6 +369,7 @@ class CouchbaseSemanticCache(BaseCache, CouchbaseVectorStore):
self.RETURN_VAL: _dumps_generations(return_val),
}
],
ttl=self._ttl,
)
except Exception:
logger.error("Error updating cache")

View File

@@ -3,7 +3,8 @@
import logging
import time
import uuid
from typing import Any, Dict, List, Sequence
from datetime import timedelta
from typing import Any, Dict, List, Optional, Sequence
from couchbase.cluster import Cluster
from langchain_core.chat_history import BaseChatMessageHistory
@@ -22,6 +23,16 @@ DEFAULT_INDEX_NAME = "LANGCHAIN_CHAT_HISTORY"
DEFAULT_BATCH_SIZE = 100
def _validate_ttl(ttl: Optional[timedelta]) -> None:
"""Validate the time to live"""
if not isinstance(ttl, timedelta):
raise ValueError(f"ttl should be of type timedelta but was {type(ttl)}.")
if ttl <= timedelta(seconds=0):
raise ValueError(
f"ttl must be greater than 0 but was {ttl.total_seconds()} seconds."
)
class CouchbaseChatMessageHistory(BaseChatMessageHistory):
"""Couchbase Chat Message History
Chat message history that uses Couchbase as the storage
@@ -77,6 +88,7 @@ class CouchbaseChatMessageHistory(BaseChatMessageHistory):
session_id_key: str = DEFAULT_SESSION_ID_KEY,
message_key: str = DEFAULT_MESSAGE_KEY,
create_index: bool = True,
ttl: Optional[timedelta] = None,
) -> None:
"""Initialize the Couchbase Chat Message History
Args:
@@ -92,6 +104,8 @@ class CouchbaseChatMessageHistory(BaseChatMessageHistory):
message_key (str): name of the field to use for the messages
Set to "message" by default.
create_index (bool): create an index if True. Set to True by default.
ttl (timedelta): time to live for the documents in the collection.
When set, the documents are automatically deleted after the ttl expires.
"""
if not isinstance(cluster, Cluster):
raise ValueError(
@@ -104,6 +118,7 @@ class CouchbaseChatMessageHistory(BaseChatMessageHistory):
self._bucket_name = bucket_name
self._scope_name = scope_name
self._collection_name = collection_name
self._ttl = None
# Check if the bucket exists
if not self._check_bucket_exists():
@@ -134,6 +149,10 @@ class CouchbaseChatMessageHistory(BaseChatMessageHistory):
self._session_id = session_id
self._ts_key = DEFAULT_TS_KEY
if ttl is not None:
_validate_ttl(ttl)
self._ttl = ttl
# Create an index if it does not exist if requested
if create_index:
index_fields = (
@@ -156,15 +175,27 @@ class CouchbaseChatMessageHistory(BaseChatMessageHistory):
# get utc timestamp for ordering the messages
timestamp = time.time()
message_content = message_to_dict(message)
try:
self._collection.insert(
document_key,
value={
self._message_key: message_content,
self._session_id_key: self._session_id,
self._ts_key: timestamp,
},
)
if self._ttl:
self._collection.insert(
document_key,
value={
self._message_key: message_content,
self._session_id_key: self._session_id,
self._ts_key: timestamp,
},
expiry=self._ttl,
)
else:
self._collection.insert(
document_key,
value={
self._message_key: message_content,
self._session_id_key: self._session_id,
self._ts_key: timestamp,
},
)
except Exception as e:
logger.error("Error adding message: ", e)
@@ -192,7 +223,10 @@ class CouchbaseChatMessageHistory(BaseChatMessageHistory):
batch = messages_to_insert[i : i + batch_size]
# Convert list of dictionaries to a single dictionary to insert
insert_batch = {list(d.keys())[0]: list(d.values())[0] for d in batch}
self._collection.insert_multi(insert_batch)
if self._ttl:
self._collection.insert_multi(insert_batch, expiry=self._ttl)
else:
self._collection.insert_multi(insert_batch)
except Exception as e:
logger.error("Error adding messages: ", e)

View File

@@ -377,6 +377,9 @@ class CouchbaseVectorStore(VectorStore):
if metadatas is None:
metadatas = [{} for _ in texts]
# Check if TTL is provided
ttl = kwargs.get("ttl", None)
embedded_texts = self._embedding_function.embed_documents(list(texts))
documents_to_insert = [
@@ -396,7 +399,11 @@ class CouchbaseVectorStore(VectorStore):
for i in range(0, len(documents_to_insert), batch_size):
batch = documents_to_insert[i : i + batch_size]
try:
result = self._collection.upsert_multi(batch[0])
# Insert with TTL if provided
if ttl:
result = self._collection.upsert_multi(batch[0], expiry=ttl)
else:
result = self._collection.upsert_multi(batch[0])
if result.all_ok:
doc_ids.extend(batch[0].keys())
except DocumentExistsException as e: