langchain/libs/partners/couchbase/tests/utils.py
"""Utilities for testing purposes."""
import hashlib
from datetime import datetime
from typing import Any, Dict, List, Mapping, Optional, cast
from couchbase.cluster import Cluster
from couchbase.options import GetOptions
from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.embeddings import Embeddings
from langchain_core.language_models.llms import LLM


class FakeEmbeddings(Embeddings):
    """Fake embeddings functionality for testing."""

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Return simple embeddings.
        Embeddings encode each text as its index."""
        return [[float(1.0)] * 9 + [float(i)] for i in range(len(texts))]

    async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
        return self.embed_documents(texts)

    def embed_query(self, text: str) -> List[float]:
        """Return constant query embeddings.
        Embeddings are identical to embed_documents(texts)[0].
        Distance to each text will be that text's index,
        as it was passed to embed_documents."""
        return [float(1.0)] * 9 + [float(0.0)]

    async def aembed_query(self, text: str) -> List[float]:
        return self.embed_query(text)
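

# Illustrative sketch (not part of the original test utilities): how FakeEmbeddings
# behaves. Each document vector is nine 1.0s followed by the text's position in the
# batch, and every query maps to the first document's vector, so similarity ordering
# is deterministic in tests.
def _example_fake_embeddings() -> None:
    embeddings = FakeEmbeddings()
    docs = embeddings.embed_documents(["alpha", "beta"])
    assert docs[0][-1] == 0.0 and docs[1][-1] == 1.0
    assert embeddings.embed_query("anything") == docs[0]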


class ConsistentFakeEmbeddings(FakeEmbeddings):
    """Fake embeddings which remember all the texts seen so far to return consistent
    vectors for the same texts."""

    def __init__(self, dimensionality: int = 10) -> None:
        self.known_texts: List[str] = []
        self.dimensionality = dimensionality

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Return consistent embeddings for each text seen so far."""
        out_vectors = []
        for text in texts:
            if text not in self.known_texts:
                self.known_texts.append(text)
            vector = [float(1.0)] * (self.dimensionality - 1) + [
                float(self.known_texts.index(text))
            ]
            out_vectors.append(vector)
        return out_vectors

    def embed_query(self, text: str) -> List[float]:
        """Return consistent embeddings for the text; an unseen text is added to
        the known texts and assigned the next index."""
        return self.embed_documents([text])[0]
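

# Illustrative sketch (not part of the original test utilities): unlike FakeEmbeddings,
# ConsistentFakeEmbeddings keys each vector off the order in which a text was first
# seen, so re-embedding the same text always yields the same vector.
def _example_consistent_fake_embeddings() -> None:
    embeddings = ConsistentFakeEmbeddings()
    first = embeddings.embed_documents(["alpha", "beta"])
    again = embeddings.embed_documents(["beta"])
    assert again[0] == first[1]
    assert embeddings.embed_query("alpha") == first[0]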


class FakeLLM(LLM):
    """Fake LLM wrapper for testing purposes."""

    queries: Optional[Mapping] = None
    sequential_responses: Optional[bool] = False
    response_index: int = 0

    def get_num_tokens(self, text: str) -> int:
        """Return number of tokens."""
        return len(text.split())

    @property
    def _llm_type(self) -> str:
        """Return type of llm."""
        return "fake"

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        """Return a canned response: the next response in the sequence, the
        response mapped to the prompt, or a constant fallback."""
        if self.sequential_responses:
            return self._get_next_response_in_sequence
        if self.queries is not None:
            return self.queries[prompt]
        if stop is None:
            return "foo"
        else:
            return "bar"

    @property
    def _identifying_params(self) -> Dict[str, Any]:
        """Return identifying parameters (none for the fake LLM)."""
        return {}

    @property
    def _get_next_response_in_sequence(self) -> str:
        """Return the next value from ``queries`` and advance the index."""
        queries = cast(Mapping, self.queries)
        response = queries[list(queries.keys())[self.response_index]]
        self.response_index += 1
        return response
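

# Illustrative sketch (not part of the original test utilities): with
# ``sequential_responses`` enabled, FakeLLM ignores the prompt and replays the
# values of ``queries`` in insertion order, which makes it easy to script
# distinct responses in cache tests.
def _example_fake_llm() -> None:
    llm = FakeLLM(
        queries={"q1": "first response", "q2": "second response"},
        sequential_responses=True,
    )
    assert llm.invoke("any prompt") == "first response"
    assert llm.invoke("another prompt") == "second response"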


def cache_key_hash_function(_input: str) -> str:
    """Hash the input deterministically (MD5 hex digest) for use as a cache key."""
    return hashlib.md5(_input.encode()).hexdigest()
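

# Illustrative sketch (not part of the original test utilities): the helper above
# is pure and deterministic, so the same input always maps to the same
# 32-character hex key; how the integration composes its actual cache document
# IDs is not shown here.
def _example_cache_key() -> None:
    key = cache_key_hash_function("prompt text")
    assert key == cache_key_hash_function("prompt text")
    assert len(key) == 32  # MD5 hex digest length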


def fetch_document_expiry_time(
    cluster: Cluster,
    bucket_name: str,
    scope_name: str,
    collection_name: str,
    document_key: str,
) -> datetime:
    """Fetch the document's expiry time from the database."""
    collection = (
        cluster.bucket(bucket_name).scope(scope_name).collection(collection_name)
    )
    result = collection.get(document_key, GetOptions(with_expiry=True))
    return result.expiryTime


def get_document_keys(
    cluster: Cluster, bucket_name: str, scope_name: str, query: str
) -> List[str]:
    """Get the document keys matching the query, which should project ``meta().id``."""
    scope = cluster.bucket(bucket_name).scope(scope_name)
    result = scope.query(query).execute()
    document_keys = [row["id"] for row in result]
    return document_keys
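

# Illustrative sketch (not part of the original test utilities): how a TTL test
# might use the two helpers above. The connection string, credentials,
# bucket/scope/collection names, and the SQL++ query are placeholder
# assumptions; a real test would take them from its own fixtures.
def _example_ttl_check() -> None:
    from couchbase.auth import PasswordAuthenticator
    from couchbase.options import ClusterOptions

    auth = PasswordAuthenticator("Administrator", "password")  # placeholder creds
    cluster = Cluster("couchbase://localhost", ClusterOptions(auth))

    # Find the keys of cached documents via a meta().id projection.
    keys = get_document_keys(
        cluster,
        bucket_name="langchain_bucket",
        scope_name="_default",
        query="SELECT meta().id AS id FROM `langchain_bucket`._default.cache",
    )

    # A document written with a TTL should report an expiry in the future.
    for key in keys:
        expiry = fetch_document_expiry_time(
            cluster, "langchain_bucket", "_default", "cache", key
        )
        assert expiry > datetime.now(expiry.tzinfo)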