diff --git a/libs/community/langchain_community/cache.py b/libs/community/langchain_community/cache.py
index 57fbf0aca78..91509c07578 100644
--- a/libs/community/langchain_community/cache.py
+++ b/libs/community/langchain_community/cache.py
@@ -53,6 +53,7 @@
 from sqlalchemy.engine import Row
 from sqlalchemy.engine.base import Engine
 from sqlalchemy.orm import Session
+from langchain_community.utilities.cassandra import SetupMode as CassandraSetupMode
 from langchain_community.vectorstores.azure_cosmos_db import (
     CosmosDBSimilarityType,
     CosmosDBVectorSearchType,
@@ -63,7 +64,7 @@ try:
 except ImportError:
     from sqlalchemy.ext.declarative import declarative_base
 
-from langchain_core._api.deprecation import deprecated
+from langchain_core._api.deprecation import deprecated, warn_deprecated
 from langchain_core.caches import RETURN_VAL_TYPE, BaseCache
 from langchain_core.embeddings import Embeddings
 from langchain_core.language_models.llms import LLM, aget_prompts, get_prompts
@@ -73,7 +74,9 @@ from langchain_core.outputs import ChatGeneration, Generation
 from langchain_core.utils import get_from_env
 
 from langchain_community.utilities.astradb import (
-    SetupMode,
+    SetupMode as AstraSetupMode,
+)
+from langchain_community.utilities.astradb import (
     _AstraDBCollectionEnvironment,
 )
 from langchain_community.vectorstores import AzureCosmosDBVectorSearch
@@ -1056,6 +1059,7 @@ class CassandraCache(BaseCache):
         table_name: str = CASSANDRA_CACHE_DEFAULT_TABLE_NAME,
         ttl_seconds: Optional[int] = CASSANDRA_CACHE_DEFAULT_TTL_SECONDS,
         skip_provisioning: bool = False,
+        setup_mode: CassandraSetupMode = CassandraSetupMode.SYNC,
     ):
         """
         Initialize with a ready session and a keyspace name.
@@ -1066,6 +1070,10 @@
             ttl_seconds (optional int): time-to-live for cache entries
                 (default: None, i.e. forever)
         """
+        if skip_provisioning:
+            warn_deprecated(
+                "0.0.33", alternative="Use setup_mode=CassandraSetupMode.OFF instead."
+            )
         try:
             from cassio.table import ElasticCassandraTable
         except (ImportError, ModuleNotFoundError):
@@ -1079,6 +1087,10 @@ class CassandraCache(BaseCache):
         self.table_name = table_name
         self.ttl_seconds = ttl_seconds
 
+        kwargs = {}
+        if setup_mode == CassandraSetupMode.ASYNC:
+            kwargs["async_setup"] = True
+
         self.kv_cache = ElasticCassandraTable(
             session=self.session,
             keyspace=self.keyspace,
@@ -1086,27 +1098,31 @@ class CassandraCache(BaseCache):
             keys=["llm_string", "prompt"],
             primary_key_type=["TEXT", "TEXT"],
             ttl_seconds=self.ttl_seconds,
-            skip_provisioning=skip_provisioning,
+            skip_provisioning=skip_provisioning or setup_mode == CassandraSetupMode.OFF,
+            **kwargs,
         )
 
     def lookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]:
-        """Look up based on prompt and llm_string."""
         item = self.kv_cache.get(
             llm_string=_hash(llm_string),
             prompt=_hash(prompt),
         )
         if item is not None:
-            generations = _loads_generations(item["body_blob"])
-            # this protects against malformed cached items:
-            if generations is not None:
-                return generations
-            else:
-                return None
+            return _loads_generations(item["body_blob"])
+        else:
+            return None
+
+    async def alookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]:
+        item = await self.kv_cache.aget(
+            llm_string=_hash(llm_string),
+            prompt=_hash(prompt),
+        )
+        if item is not None:
+            return _loads_generations(item["body_blob"])
         else:
             return None
 
     def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> None:
-        """Update cache based on prompt and llm_string."""
         blob = _dumps_generations(return_val)
         self.kv_cache.put(
             llm_string=_hash(llm_string),
@@ -1114,6 +1130,16 @@ class CassandraCache(BaseCache):
             body_blob=blob,
         )
 
+    async def aupdate(
+        self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE
+    ) -> None:
+        blob = _dumps_generations(return_val)
+        await self.kv_cache.aput(
+            llm_string=_hash(llm_string),
+            prompt=_hash(prompt),
+            body_blob=blob,
+        )
+
     def delete_through_llm(
         self, prompt: str, llm: LLM, stop: Optional[List[str]] = None
     ) -> None:
@@ -1139,6 +1165,10 @@ class CassandraCache(BaseCache):
         """Clear cache. This is for all LLMs at once."""
         self.kv_cache.clear()
 
+    async def aclear(self, **kwargs: Any) -> None:
+        """Clear cache. This is for all LLMs at once."""
+        await self.kv_cache.aclear()
+
 
 CASSANDRA_SEMANTIC_CACHE_DEFAULT_DISTANCE_METRIC = "dot"
 CASSANDRA_SEMANTIC_CACHE_DEFAULT_SCORE_THRESHOLD = 0.85
@@ -1170,6 +1200,7 @@ class CassandraSemanticCache(BaseCache):
         score_threshold: float = CASSANDRA_SEMANTIC_CACHE_DEFAULT_SCORE_THRESHOLD,
         ttl_seconds: Optional[int] = CASSANDRA_SEMANTIC_CACHE_DEFAULT_TTL_SECONDS,
         skip_provisioning: bool = False,
+        setup_mode: CassandraSetupMode = CassandraSetupMode.SYNC,
     ):
         """
         Initialize the cache with all relevant parameters.
@@ -1189,6 +1220,10 @@
             The default score threshold is tuned to the default metric.
             Tune it carefully yourself if switching to another distance metric.
         """
+        if skip_provisioning:
+            warn_deprecated(
+                "0.0.33", alternative="Use setup_mode=CassandraSetupMode.OFF instead."
+            )
         try:
             from cassio.table import MetadataVectorCassandraTable
         except (ImportError, ModuleNotFoundError):
@@ -1214,24 +1249,42 @@ class CassandraSemanticCache(BaseCache):
             return self.embedding.embed_query(text=text)
 
         self._get_embedding = _cache_embedding
-        self.embedding_dimension = self._get_embedding_dimension()
+
+        @_async_lru_cache(maxsize=CASSANDRA_SEMANTIC_CACHE_EMBEDDING_CACHE_SIZE)
+        async def _acache_embedding(text: str) -> List[float]:
+            return await self.embedding.aembed_query(text=text)
+
+        self._aget_embedding = _acache_embedding
+
+        embedding_dimension: Union[int, Awaitable[int], None] = None
+        if setup_mode == CassandraSetupMode.ASYNC:
+            embedding_dimension = self._aget_embedding_dimension()
+        elif setup_mode == CassandraSetupMode.SYNC:
+            embedding_dimension = self._get_embedding_dimension()
+
+        kwargs = {}
+        if setup_mode == CassandraSetupMode.ASYNC:
+            kwargs["async_setup"] = True
 
         self.table = MetadataVectorCassandraTable(
             session=self.session,
             keyspace=self.keyspace,
             table=self.table_name,
             primary_key_type=["TEXT"],
-            vector_dimension=self.embedding_dimension,
+            vector_dimension=embedding_dimension,
             ttl_seconds=self.ttl_seconds,
             metadata_indexing=("allow", {"_llm_string_hash"}),
-            skip_provisioning=skip_provisioning,
+            skip_provisioning=skip_provisioning or setup_mode == CassandraSetupMode.OFF,
+            **kwargs,
         )
 
     def _get_embedding_dimension(self) -> int:
         return len(self._get_embedding(text="This is a sample sentence."))
 
+    async def _aget_embedding_dimension(self) -> int:
+        return len(await self._aget_embedding(text="This is a sample sentence."))
+
     def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> None:
-        """Update cache based on prompt and llm_string."""
         embedding_vector = self._get_embedding(text=prompt)
         llm_string_hash = _hash(llm_string)
         body = _dumps_generations(return_val)
@@ -1240,7 +1293,7 @@ class CassandraSemanticCache(BaseCache):
             "_llm_string_hash": llm_string_hash,
         }
         row_id = f"{_hash(prompt)}-{llm_string_hash}"
-        #
+
         self.table.put(
             body_blob=body,
             vector=embedding_vector,
@@ -1248,14 +1301,39 @@ class CassandraSemanticCache(BaseCache):
             metadata=metadata,
         )
 
+    async def aupdate(
+        self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE
+    ) -> None:
+        embedding_vector = await self._aget_embedding(text=prompt)
+        llm_string_hash = _hash(llm_string)
+        body = _dumps_generations(return_val)
+        metadata = {
+            "_prompt": prompt,
+            "_llm_string_hash": llm_string_hash,
+        }
+        row_id = f"{_hash(prompt)}-{llm_string_hash}"
+
+        await self.table.aput(
+            body_blob=body,
+            vector=embedding_vector,
+            row_id=row_id,
+            metadata=metadata,
+        )
+
     def lookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]:
-        """Look up based on prompt and llm_string."""
         hit_with_id = self.lookup_with_id(prompt, llm_string)
         if hit_with_id is not None:
             return hit_with_id[1]
         else:
             return None
 
+    async def alookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]:
+        hit_with_id = await self.alookup_with_id(prompt, llm_string)
+        if hit_with_id is not None:
+            return hit_with_id[1]
+        else:
+            return None
+
     def lookup_with_id(
         self, prompt: str, llm_string: str
     ) -> Optional[Tuple[str, RETURN_VAL_TYPE]]:
@@ -1287,6 +1365,37 @@ class CassandraSemanticCache(BaseCache):
         else:
             return None
 
+    async def alookup_with_id(
+        self, prompt: str, llm_string: str
+    ) -> Optional[Tuple[str, RETURN_VAL_TYPE]]:
+        """
+        Look up based on prompt and llm_string.
+        If there are hits, return (document_id, cached_entry)
+        """
+        prompt_embedding: List[float] = await self._aget_embedding(text=prompt)
+        hits = list(
+            await self.table.ametric_ann_search(
+                vector=prompt_embedding,
+                metadata={"_llm_string_hash": _hash(llm_string)},
+                n=1,
+                metric=self.distance_metric,
+                metric_threshold=self.score_threshold,
+            )
+        )
+        if hits:
+            hit = hits[0]
+            generations = _loads_generations(hit["body_blob"])
+            if generations is not None:
+                # this protects against malformed cached items:
+                return (
+                    hit["row_id"],
+                    generations,
+                )
+            else:
+                return None
+        else:
+            return None
+
     def lookup_with_id_through_llm(
         self, prompt: str, llm: LLM, stop: Optional[List[str]] = None
     ) -> Optional[Tuple[str, RETURN_VAL_TYPE]]:
@@ -1296,6 +1405,17 @@ class CassandraSemanticCache(BaseCache):
         )[1]
         return self.lookup_with_id(prompt, llm_string=llm_string)
 
+    async def alookup_with_id_through_llm(
+        self, prompt: str, llm: LLM, stop: Optional[List[str]] = None
+    ) -> Optional[Tuple[str, RETURN_VAL_TYPE]]:
+        llm_string = (
+            await aget_prompts(
+                {**llm.dict(), **{"stop": stop}},
+                [],
+            )
+        )[1]
+        return await self.alookup_with_id(prompt, llm_string=llm_string)
+
     def delete_by_document_id(self, document_id: str) -> None:
         """
         Given this is a "similarity search" cache, an invalidation pattern
@@ -1304,10 +1424,22 @@ class CassandraSemanticCache(BaseCache):
         """
         self.table.delete(row_id=document_id)
 
+    async def adelete_by_document_id(self, document_id: str) -> None:
+        """
+        Given this is a "similarity search" cache, an invalidation pattern
+        that makes sense is first a lookup to get an ID, and then deleting
+        with that ID. This is for the second step.
+        """
+        await self.table.adelete(row_id=document_id)
+
     def clear(self, **kwargs: Any) -> None:
         """Clear the *whole* semantic cache."""
         self.table.clear()
 
+    async def aclear(self, **kwargs: Any) -> None:
+        """Clear the *whole* semantic cache."""
+        await self.table.aclear()
+
 
 class FullMd5LLMCache(Base):  # type: ignore
     """SQLite table for full LLM Cache (all generations)."""
@@ -1412,7 +1544,7 @@ class AstraDBCache(BaseCache):
         async_astra_db_client: Optional[AsyncAstraDB] = None,
         namespace: Optional[str] = None,
         pre_delete_collection: bool = False,
-        setup_mode: SetupMode = SetupMode.SYNC,
+        setup_mode: AstraSetupMode = AstraSetupMode.SYNC,
     ):
         """
         Cache that uses Astra DB as a backend.
@@ -1612,7 +1744,7 @@ class AstraDBSemanticCache(BaseCache):
         astra_db_client: Optional[AstraDB] = None,
         async_astra_db_client: Optional[AsyncAstraDB] = None,
         namespace: Optional[str] = None,
-        setup_mode: SetupMode = SetupMode.SYNC,
+        setup_mode: AstraSetupMode = AstraSetupMode.SYNC,
         pre_delete_collection: bool = False,
         embedding: Embeddings,
         metric: Optional[str] = None,
@@ -1675,9 +1807,9 @@ class AstraDBSemanticCache(BaseCache):
         self._aget_embedding = _acache_embedding
 
         embedding_dimension: Union[int, Awaitable[int], None] = None
-        if setup_mode == SetupMode.ASYNC:
+        if setup_mode == AstraSetupMode.ASYNC:
             embedding_dimension = self._aget_embedding_dimension()
-        elif setup_mode == SetupMode.SYNC:
+        elif setup_mode == AstraSetupMode.SYNC:
             embedding_dimension = self._get_embedding_dimension()
 
         self.astra_env = _AstraDBCollectionEnvironment(
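Usage note: a minimal sketch of the new `setup_mode` path added above. The cluster connection and keyspace name here are illustrative assumptions; the cache API (`setup_mode`, `aupdate`, `alookup`) is the one introduced in this patch.

```python
import asyncio

from cassandra.cluster import Cluster  # assumes a reachable Cassandra cluster
from langchain_community.cache import CassandraCache
from langchain_community.utilities.cassandra import SetupMode
from langchain_core.outputs import Generation


async def main() -> None:
    session = Cluster().connect()  # illustrative: any open Cassandra session works
    # SetupMode.ASYNC defers table provisioning to the event loop (cassio's
    # async_setup) instead of blocking in __init__; SetupMode.OFF skips
    # provisioning entirely and replaces the now-deprecated skip_provisioning=True.
    cache = CassandraCache(
        session=session,
        keyspace="demo_keyspace",  # hypothetical keyspace
        setup_mode=SetupMode.ASYNC,
    )
    await cache.aupdate("my prompt", "my-llm-config", [Generation(text="fizz")])
    print(await cache.alookup("my prompt", "my-llm-config"))


asyncio.run(main())
```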
diff --git a/libs/core/langchain_core/caches.py b/libs/core/langchain_core/caches.py
index 4e9e2993a04..2e37e028cd2 100644
--- a/libs/core/langchain_core/caches.py
+++ b/libs/core/langchain_core/caches.py
@@ -93,13 +93,47 @@ class BaseCache(ABC):
         """Clear cache that can take additional keyword arguments."""
 
     async def alookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]:
-        """Async version of lookup."""
+        """Look up based on prompt and llm_string.
+
+        A cache implementation is expected to generate a key from the 2-tuple
+        of prompt and llm_string (e.g., by concatenating them with a delimiter).
+
+        Args:
+            prompt: a string representation of the prompt.
+                In the case of a Chat model, the prompt is a non-trivial
+                serialization of the prompt into the language model.
+            llm_string: A string representation of the LLM configuration.
+                This is used to capture the invocation parameters of the LLM
+                (e.g., model name, temperature, stop tokens, max tokens, etc.).
+                These invocation parameters are serialized into a string
+                representation.
+
+        Returns:
+            On a cache miss, return None. On a cache hit, return the cached value.
+            The cached value is a list of Generations (or subclasses).
+        """
         return await run_in_executor(None, self.lookup, prompt, llm_string)
 
     async def aupdate(
         self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE
     ) -> None:
-        """Async version of aupdate."""
+        """Update cache based on prompt and llm_string.
+
+        The prompt and llm_string are used to generate a key for the cache.
+        The key should match that of the lookup method.
+
+        Args:
+            prompt: a string representation of the prompt.
+                In the case of a Chat model, the prompt is a non-trivial
+                serialization of the prompt into the language model.
+            llm_string: A string representation of the LLM configuration.
+                This is used to capture the invocation parameters of the LLM
+                (e.g., model name, temperature, stop tokens, max tokens, etc.).
+                These invocation parameters are serialized into a string
+                representation.
+            return_val: The value to be cached. The value is a list of Generations
+                (or subclasses).
+        """
         return await run_in_executor(None, self.update, prompt, llm_string, return_val)
 
     async def aclear(self, **kwargs: Any) -> None:
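For reference, a toy implementation of the contract these docstrings describe (the in-memory store and the delimiter are illustrative assumptions, not part of the patch):

```python
from typing import Any, Dict, Optional

from langchain_core.caches import RETURN_VAL_TYPE, BaseCache


class InMemoryDictCache(BaseCache):
    """Toy cache keyed on the (prompt, llm_string) 2-tuple."""

    def __init__(self) -> None:
        self._store: Dict[str, RETURN_VAL_TYPE] = {}

    def _key(self, prompt: str, llm_string: str) -> str:
        # Any stable derivation of the 2-tuple works; the delimiter is arbitrary.
        return prompt + "\x1d" + llm_string

    def lookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]:
        # Returning None signals a cache miss to the caller.
        return self._store.get(self._key(prompt, llm_string))

    def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> None:
        self._store[self._key(prompt, llm_string)] = return_val

    def clear(self, **kwargs: Any) -> None:
        self._store.clear()
```

Because only the sync methods are overridden, `alookup`/`aupdate`/`aclear` fall back to the `run_in_executor` defaults shown in the hunk above.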
diff --git a/libs/langchain/tests/integration_tests/cache/test_cassandra.py b/libs/langchain/tests/integration_tests/cache/test_cassandra.py
index 19a8efaf4ac..a61eb764eb7 100644
--- a/libs/langchain/tests/integration_tests/cache/test_cassandra.py
+++ b/libs/langchain/tests/integration_tests/cache/test_cassandra.py
@@ -1,10 +1,11 @@
 """Test Cassandra caches.
 Requires a running vector-capable Cassandra cluster."""
-
+import asyncio
 import os
 import time
 from typing import Any, Iterator, Tuple
 
 import pytest
+from langchain_community.utilities.cassandra import SetupMode
 from langchain_core.outputs import Generation, LLMResult
 from langchain.cache import CassandraCache, CassandraSemanticCache
@@ -47,16 +48,34 @@ def test_cassandra_cache(cassandra_connection: Tuple[Any, str]) -> None:
     llm_string = str(sorted([(k, v) for k, v in params.items()]))
     get_llm_cache().update("foo", llm_string, [Generation(text="fizz")])
     output = llm.generate(["foo"])
-    print(output)  # noqa: T201
     expected_output = LLMResult(
         generations=[[Generation(text="fizz")]],
         llm_output={},
     )
-    print(expected_output)  # noqa: T201
     assert output == expected_output
     cache.clear()
 
 
+async def test_cassandra_cache_async(cassandra_connection: Tuple[Any, str]) -> None:
+    session, keyspace = cassandra_connection
+    cache = CassandraCache(
+        session=session, keyspace=keyspace, setup_mode=SetupMode.ASYNC
+    )
+    set_llm_cache(cache)
+    llm = FakeLLM()
+    params = llm.dict()
+    params["stop"] = None
+    llm_string = str(sorted([(k, v) for k, v in params.items()]))
+    await get_llm_cache().aupdate("foo", llm_string, [Generation(text="fizz")])
+    output = await llm.agenerate(["foo"])
+    expected_output = LLMResult(
+        generations=[[Generation(text="fizz")]],
+        llm_output={},
+    )
+    assert output == expected_output
+    await cache.aclear()
+
+
 def test_cassandra_cache_ttl(cassandra_connection: Tuple[Any, str]) -> None:
     session, keyspace = cassandra_connection
     cache = CassandraCache(session=session, keyspace=keyspace, ttl_seconds=2)
@@ -79,6 +98,30 @@ def test_cassandra_cache_ttl(cassandra_connection: Tuple[Any, str]) -> None:
     cache.clear()
 
 
+async def test_cassandra_cache_ttl_async(cassandra_connection: Tuple[Any, str]) -> None:
+    session, keyspace = cassandra_connection
+    cache = CassandraCache(
+        session=session, keyspace=keyspace, ttl_seconds=2, setup_mode=SetupMode.ASYNC
+    )
+    set_llm_cache(cache)
+    llm = FakeLLM()
+    params = llm.dict()
+    params["stop"] = None
+    llm_string = str(sorted([(k, v) for k, v in params.items()]))
+    await get_llm_cache().aupdate("foo", llm_string, [Generation(text="fizz")])
+    expected_output = LLMResult(
+        generations=[[Generation(text="fizz")]],
+        llm_output={},
+    )
+    output = await llm.agenerate(["foo"])
+    assert output == expected_output
+    await asyncio.sleep(2.5)
+    # entry has expired away.
+    output = await llm.agenerate(["foo"])
+    assert output != expected_output
+    await cache.aclear()
+
+
 def test_cassandra_semantic_cache(cassandra_connection: Tuple[Any, str]) -> None:
     session, keyspace = cassandra_connection
     sem_cache = CassandraSemanticCache(
@@ -103,3 +146,32 @@
     output = llm.generate(["bar"])  # 'fizz' is erased away now
     assert output != expected_output
     sem_cache.clear()
+
+
+async def test_cassandra_semantic_cache_async(
+    cassandra_connection: Tuple[Any, str],
+) -> None:
+    session, keyspace = cassandra_connection
+    sem_cache = CassandraSemanticCache(
+        session=session,
+        keyspace=keyspace,
+        embedding=FakeEmbeddings(),
+        setup_mode=SetupMode.ASYNC,
+    )
+    set_llm_cache(sem_cache)
+    llm = FakeLLM()
+    params = llm.dict()
+    params["stop"] = None
+    llm_string = str(sorted([(k, v) for k, v in params.items()]))
+    await get_llm_cache().aupdate("foo", llm_string, [Generation(text="fizz")])
+    output = await llm.agenerate(["bar"])  # same embedding as 'foo'
+    expected_output = LLMResult(
+        generations=[[Generation(text="fizz")]],
+        llm_output={},
+    )
+    assert output == expected_output
+    # clear the cache
+    await sem_cache.aclear()
+    output = await llm.agenerate(["bar"])  # 'fizz' is erased away now
+    assert output != expected_output
+    await sem_cache.aclear()
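The new async surface also enables the lookup-then-delete invalidation pattern that `adelete_by_document_id` documents. A sketch, assuming `sem_cache` is a `CassandraSemanticCache` built with `setup_mode=SetupMode.ASYNC` as in the tests above (the helper name is hypothetical):

```python
from typing import Optional

from langchain_community.cache import CassandraSemanticCache


async def invalidate_hit(
    sem_cache: CassandraSemanticCache, prompt: str, llm_string: str
) -> Optional[str]:
    # First step: a similarity lookup that returns (document_id, cached_entry).
    hit = await sem_cache.alookup_with_id(prompt, llm_string)
    if hit is None:
        return None
    document_id, _generations = hit
    # Second step: delete exactly that cached entry by its id.
    await sem_cache.adelete_by_document_id(document_id)
    return document_id
```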