community[minor]: Add async methods to CassandraCache and CassandraSemanticCache (#20654)

Commit 5c77f45b06 (parent d6e9bd3011), committed via GitHub.
Source: mirror of https://github.com/hwchase17/langchain.git, synced 2025-09-04 12:39:32 +00:00.
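In short, the change gives CassandraCache and CassandraSemanticCache native async counterparts (alookup, aupdate, aclear, alookup_with_id, adelete_by_document_id, ...) and a setup_mode parameter (SYNC, ASYNC, or OFF) that supersedes the now-deprecated skip_provisioning flag. A minimal usage sketch, modeled on the integration tests at the end of this diff — it assumes an already-open cassio-compatible Cassandra Session and keyspace, and the demo function plus the placeholder prompt/LLM-config strings are illustrative, not part of the commit:

    from langchain_community.cache import CassandraCache
    from langchain_community.utilities.cassandra import SetupMode
    from langchain_core.outputs import Generation


    async def demo(session, keyspace: str) -> None:
        # setup_mode=ASYNC provisions the backing table without blocking the event loop.
        cache = CassandraCache(
            session=session, keyspace=keyspace, setup_mode=SetupMode.ASYNC
        )
        # Write, read back, and clear entries via the new async methods.
        await cache.aupdate("my prompt", "my-llm-config", [Generation(text="fizz")])
        hit = await cache.alookup("my prompt", "my-llm-config")
        assert hit is not None and hit[0].text == "fizz"
        await cache.aclear()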
--- a/libs/community/langchain_community/cache.py
+++ b/libs/community/langchain_community/cache.py
@@ -53,6 +53,7 @@ from sqlalchemy.engine import Row
 from sqlalchemy.engine.base import Engine
 from sqlalchemy.orm import Session

+from langchain_community.utilities.cassandra import SetupMode as CassandraSetupMode
 from langchain_community.vectorstores.azure_cosmos_db import (
     CosmosDBSimilarityType,
     CosmosDBVectorSearchType,
@@ -63,7 +64,7 @@ try:
 except ImportError:
     from sqlalchemy.ext.declarative import declarative_base

-from langchain_core._api.deprecation import deprecated
+from langchain_core._api.deprecation import deprecated, warn_deprecated
 from langchain_core.caches import RETURN_VAL_TYPE, BaseCache
 from langchain_core.embeddings import Embeddings
 from langchain_core.language_models.llms import LLM, aget_prompts, get_prompts
@@ -73,7 +74,9 @@ from langchain_core.outputs import ChatGeneration, Generation
 from langchain_core.utils import get_from_env

 from langchain_community.utilities.astradb import (
-    SetupMode,
+    SetupMode as AstraSetupMode,
+)
+from langchain_community.utilities.astradb import (
     _AstraDBCollectionEnvironment,
 )
 from langchain_community.vectorstores import AzureCosmosDBVectorSearch
@@ -1056,6 +1059,7 @@ class CassandraCache(BaseCache):
         table_name: str = CASSANDRA_CACHE_DEFAULT_TABLE_NAME,
         ttl_seconds: Optional[int] = CASSANDRA_CACHE_DEFAULT_TTL_SECONDS,
         skip_provisioning: bool = False,
+        setup_mode: CassandraSetupMode = CassandraSetupMode.SYNC,
     ):
         """
         Initialize with a ready session and a keyspace name.
@@ -1066,6 +1070,10 @@ class CassandraCache(BaseCache):
             ttl_seconds (optional int): time-to-live for cache entries
                 (default: None, i.e. forever)
         """
+        if skip_provisioning:
+            warn_deprecated(
+                "0.0.33", alternative="Use setup_mode=CassandraSetupMode.OFF instead."
+            )
         try:
             from cassio.table import ElasticCassandraTable
         except (ImportError, ModuleNotFoundError):
@@ -1079,6 +1087,10 @@ class CassandraCache(BaseCache):
         self.table_name = table_name
         self.ttl_seconds = ttl_seconds

+        kwargs = {}
+        if setup_mode == CassandraSetupMode.ASYNC:
+            kwargs["async_setup"] = True
+
         self.kv_cache = ElasticCassandraTable(
             session=self.session,
             keyspace=self.keyspace,
@@ -1086,27 +1098,31 @@ class CassandraCache(BaseCache):
             keys=["llm_string", "prompt"],
             primary_key_type=["TEXT", "TEXT"],
             ttl_seconds=self.ttl_seconds,
-            skip_provisioning=skip_provisioning,
+            skip_provisioning=skip_provisioning or setup_mode == CassandraSetupMode.OFF,
+            **kwargs,
         )

     def lookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]:
-        """Look up based on prompt and llm_string."""
         item = self.kv_cache.get(
             llm_string=_hash(llm_string),
             prompt=_hash(prompt),
         )
         if item is not None:
-            generations = _loads_generations(item["body_blob"])
-            # this protects against malformed cached items:
-            if generations is not None:
-                return generations
-            else:
-                return None
+            return _loads_generations(item["body_blob"])
+        else:
+            return None
+
+    async def alookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]:
+        item = await self.kv_cache.aget(
+            llm_string=_hash(llm_string),
+            prompt=_hash(prompt),
+        )
+        if item is not None:
+            return _loads_generations(item["body_blob"])
         else:
             return None

     def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> None:
-        """Update cache based on prompt and llm_string."""
         blob = _dumps_generations(return_val)
         self.kv_cache.put(
             llm_string=_hash(llm_string),
@@ -1114,6 +1130,16 @@ class CassandraCache(BaseCache):
             body_blob=blob,
         )

+    async def aupdate(
+        self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE
+    ) -> None:
+        blob = _dumps_generations(return_val)
+        await self.kv_cache.aput(
+            llm_string=_hash(llm_string),
+            prompt=_hash(prompt),
+            body_blob=blob,
+        )
+
     def delete_through_llm(
         self, prompt: str, llm: LLM, stop: Optional[List[str]] = None
     ) -> None:
@@ -1139,6 +1165,10 @@ class CassandraCache(BaseCache):
         """Clear cache. This is for all LLMs at once."""
         self.kv_cache.clear()

+    async def aclear(self, **kwargs: Any) -> None:
+        """Clear cache. This is for all LLMs at once."""
+        await self.kv_cache.aclear()
+

 CASSANDRA_SEMANTIC_CACHE_DEFAULT_DISTANCE_METRIC = "dot"
 CASSANDRA_SEMANTIC_CACHE_DEFAULT_SCORE_THRESHOLD = 0.85
@@ -1170,6 +1200,7 @@ class CassandraSemanticCache(BaseCache):
         score_threshold: float = CASSANDRA_SEMANTIC_CACHE_DEFAULT_SCORE_THRESHOLD,
         ttl_seconds: Optional[int] = CASSANDRA_SEMANTIC_CACHE_DEFAULT_TTL_SECONDS,
         skip_provisioning: bool = False,
+        setup_mode: CassandraSetupMode = CassandraSetupMode.SYNC,
     ):
         """
         Initialize the cache with all relevant parameters.
@@ -1189,6 +1220,10 @@ class CassandraSemanticCache(BaseCache):
                 The default score threshold is tuned to the default metric.
                 Tune it carefully yourself if switching to another distance metric.
         """
+        if skip_provisioning:
+            warn_deprecated(
+                "0.0.33", alternative="Use setup_mode=CassandraSetupMode.OFF instead."
+            )
         try:
             from cassio.table import MetadataVectorCassandraTable
         except (ImportError, ModuleNotFoundError):
@@ -1214,24 +1249,42 @@ class CassandraSemanticCache(BaseCache):
             return self.embedding.embed_query(text=text)

         self._get_embedding = _cache_embedding
-        self.embedding_dimension = self._get_embedding_dimension()
+
+        @_async_lru_cache(maxsize=CASSANDRA_SEMANTIC_CACHE_EMBEDDING_CACHE_SIZE)
+        async def _acache_embedding(text: str) -> List[float]:
+            return await self.embedding.aembed_query(text=text)
+
+        self._aget_embedding = _acache_embedding
+
+        embedding_dimension: Union[int, Awaitable[int], None] = None
+        if setup_mode == CassandraSetupMode.ASYNC:
+            embedding_dimension = self._aget_embedding_dimension()
+        elif setup_mode == CassandraSetupMode.SYNC:
+            embedding_dimension = self._get_embedding_dimension()
+
+        kwargs = {}
+        if setup_mode == CassandraSetupMode.ASYNC:
+            kwargs["async_setup"] = True
+
         self.table = MetadataVectorCassandraTable(
             session=self.session,
             keyspace=self.keyspace,
             table=self.table_name,
             primary_key_type=["TEXT"],
-            vector_dimension=self.embedding_dimension,
+            vector_dimension=embedding_dimension,
             ttl_seconds=self.ttl_seconds,
             metadata_indexing=("allow", {"_llm_string_hash"}),
-            skip_provisioning=skip_provisioning,
+            skip_provisioning=skip_provisioning or setup_mode == CassandraSetupMode.OFF,
+            **kwargs,
         )

     def _get_embedding_dimension(self) -> int:
         return len(self._get_embedding(text="This is a sample sentence."))

+    async def _aget_embedding_dimension(self) -> int:
+        return len(await self._aget_embedding(text="This is a sample sentence."))
+
     def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> None:
-        """Update cache based on prompt and llm_string."""
         embedding_vector = self._get_embedding(text=prompt)
         llm_string_hash = _hash(llm_string)
         body = _dumps_generations(return_val)
@@ -1240,7 +1293,7 @@ class CassandraSemanticCache(BaseCache):
             "_llm_string_hash": llm_string_hash,
         }
         row_id = f"{_hash(prompt)}-{llm_string_hash}"
-        #
+
         self.table.put(
             body_blob=body,
             vector=embedding_vector,
@@ -1248,14 +1301,39 @@ class CassandraSemanticCache(BaseCache):
             metadata=metadata,
         )

+    async def aupdate(
+        self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE
+    ) -> None:
+        embedding_vector = await self._aget_embedding(text=prompt)
+        llm_string_hash = _hash(llm_string)
+        body = _dumps_generations(return_val)
+        metadata = {
+            "_prompt": prompt,
+            "_llm_string_hash": llm_string_hash,
+        }
+        row_id = f"{_hash(prompt)}-{llm_string_hash}"
+
+        await self.table.aput(
+            body_blob=body,
+            vector=embedding_vector,
+            row_id=row_id,
+            metadata=metadata,
+        )
+
     def lookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]:
-        """Look up based on prompt and llm_string."""
         hit_with_id = self.lookup_with_id(prompt, llm_string)
         if hit_with_id is not None:
             return hit_with_id[1]
         else:
             return None

+    async def alookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]:
+        hit_with_id = await self.alookup_with_id(prompt, llm_string)
+        if hit_with_id is not None:
+            return hit_with_id[1]
+        else:
+            return None
+
     def lookup_with_id(
         self, prompt: str, llm_string: str
     ) -> Optional[Tuple[str, RETURN_VAL_TYPE]]:
@@ -1287,6 +1365,37 @@ class CassandraSemanticCache(BaseCache):
         else:
             return None

+    async def alookup_with_id(
+        self, prompt: str, llm_string: str
+    ) -> Optional[Tuple[str, RETURN_VAL_TYPE]]:
+        """
+        Look up based on prompt and llm_string.
+        If there are hits, return (document_id, cached_entry)
+        """
+        prompt_embedding: List[float] = await self._aget_embedding(text=prompt)
+        hits = list(
+            await self.table.ametric_ann_search(
+                vector=prompt_embedding,
+                metadata={"_llm_string_hash": _hash(llm_string)},
+                n=1,
+                metric=self.distance_metric,
+                metric_threshold=self.score_threshold,
+            )
+        )
+        if hits:
+            hit = hits[0]
+            generations = _loads_generations(hit["body_blob"])
+            if generations is not None:
+                # this protects against malformed cached items:
+                return (
+                    hit["row_id"],
+                    generations,
+                )
+            else:
+                return None
+        else:
+            return None
+
     def lookup_with_id_through_llm(
         self, prompt: str, llm: LLM, stop: Optional[List[str]] = None
     ) -> Optional[Tuple[str, RETURN_VAL_TYPE]]:
@@ -1296,6 +1405,17 @@ class CassandraSemanticCache(BaseCache):
         )[1]
         return self.lookup_with_id(prompt, llm_string=llm_string)

+    async def alookup_with_id_through_llm(
+        self, prompt: str, llm: LLM, stop: Optional[List[str]] = None
+    ) -> Optional[Tuple[str, RETURN_VAL_TYPE]]:
+        llm_string = (
+            await aget_prompts(
+                {**llm.dict(), **{"stop": stop}},
+                [],
+            )
+        )[1]
+        return await self.alookup_with_id(prompt, llm_string=llm_string)
+
     def delete_by_document_id(self, document_id: str) -> None:
         """
         Given this is a "similarity search" cache, an invalidation pattern
@@ -1304,10 +1424,22 @@ class CassandraSemanticCache(BaseCache):
         """
         self.table.delete(row_id=document_id)

+    async def adelete_by_document_id(self, document_id: str) -> None:
+        """
+        Given this is a "similarity search" cache, an invalidation pattern
+        that makes sense is first a lookup to get an ID, and then deleting
+        with that ID. This is for the second step.
+        """
+        await self.table.adelete(row_id=document_id)
+
     def clear(self, **kwargs: Any) -> None:
         """Clear the *whole* semantic cache."""
         self.table.clear()

+    async def aclear(self, **kwargs: Any) -> None:
+        """Clear the *whole* semantic cache."""
+        await self.table.aclear()
+

 class FullMd5LLMCache(Base):  # type: ignore
     """SQLite table for full LLM Cache (all generations)."""
@@ -1412,7 +1544,7 @@ class AstraDBCache(BaseCache):
         async_astra_db_client: Optional[AsyncAstraDB] = None,
         namespace: Optional[str] = None,
         pre_delete_collection: bool = False,
-        setup_mode: SetupMode = SetupMode.SYNC,
+        setup_mode: AstraSetupMode = AstraSetupMode.SYNC,
     ):
         """
         Cache that uses Astra DB as a backend.
@@ -1612,7 +1744,7 @@ class AstraDBSemanticCache(BaseCache):
         astra_db_client: Optional[AstraDB] = None,
         async_astra_db_client: Optional[AsyncAstraDB] = None,
         namespace: Optional[str] = None,
-        setup_mode: SetupMode = SetupMode.SYNC,
+        setup_mode: AstraSetupMode = AstraSetupMode.SYNC,
         pre_delete_collection: bool = False,
         embedding: Embeddings,
         metric: Optional[str] = None,
@@ -1675,9 +1807,9 @@ class AstraDBSemanticCache(BaseCache):
         self._aget_embedding = _acache_embedding

         embedding_dimension: Union[int, Awaitable[int], None] = None
-        if setup_mode == SetupMode.ASYNC:
+        if setup_mode == AstraSetupMode.ASYNC:
             embedding_dimension = self._aget_embedding_dimension()
-        elif setup_mode == SetupMode.SYNC:
+        elif setup_mode == AstraSetupMode.SYNC:
             embedding_dimension = self._get_embedding_dimension()

         self.astra_env = _AstraDBCollectionEnvironment(
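One helper the CassandraSemanticCache hunk relies on that is not visible in the hunks above is the _async_lru_cache decorator applied to _acache_embedding; its definition lives elsewhere in cache.py. A plausible minimal implementation — an illustrative sketch only, not necessarily the exact code in this commit — caches the scheduled future per argument tuple using functools.lru_cache:

    import asyncio
    from functools import lru_cache, wraps
    from typing import Any, Callable


    def _async_lru_cache(maxsize: int = 128) -> Callable:
        # Hypothetical sketch of an LRU cache for coroutine functions.
        def decorator(coro_fn: Callable) -> Callable:
            @lru_cache(maxsize=maxsize)
            def _cached(*args: Any, **kwargs: Any) -> "asyncio.Future":
                # Cache the future, not the result, so concurrent callers with
                # the same arguments await the same underlying computation.
                return asyncio.ensure_future(coro_fn(*args, **kwargs))

            @wraps(coro_fn)
            async def wrapper(*args: Any, **kwargs: Any) -> Any:
                return await _cached(*args, **kwargs)

            return wrapper

        return decorator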
--- a/libs/core/langchain_core/caches.py
+++ b/libs/core/langchain_core/caches.py
@@ -93,13 +93,47 @@ class BaseCache(ABC):
         """Clear cache that can take additional keyword arguments."""

     async def alookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]:
-        """Async version of lookup."""
+        """Look up based on prompt and llm_string.
+
+        A cache implementation is expected to generate a key from the 2-tuple
+        of prompt and llm_string (e.g., by concatenating them with a delimiter).
+
+        Args:
+            prompt: a string representation of the prompt.
+                In the case of a Chat model, the prompt is a non-trivial
+                serialization of the prompt into the language model.
+            llm_string: A string representation of the LLM configuration.
+                This is used to capture the invocation parameters of the LLM
+                (e.g., model name, temperature, stop tokens, max tokens, etc.).
+                These invocation parameters are serialized into a string
+                representation.
+
+        Returns:
+            On a cache miss, return None. On a cache hit, return the cached value.
+            The cached value is a list of Generations (or subclasses).
+        """
         return await run_in_executor(None, self.lookup, prompt, llm_string)

     async def aupdate(
         self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE
     ) -> None:
-        """Async version of aupdate."""
+        """Update cache based on prompt and llm_string.
+
+        The prompt and llm_string are used to generate a key for the cache.
+        The key should match that of the look up method.
+
+        Args:
+            prompt: a string representation of the prompt.
+                In the case of a Chat model, the prompt is a non-trivial
+                serialization of the prompt into the language model.
+            llm_string: A string representation of the LLM configuration.
+                This is used to capture the invocation parameters of the LLM
+                (e.g., model name, temperature, stop tokens, max tokens, etc.).
+                These invocation parameters are serialized into a string
+                representation.
+            return_val: The value to be cached. The value is a list of Generations
+                (or subclasses).
+        """
         return await run_in_executor(None, self.update, prompt, llm_string, return_val)

     async def aclear(self, **kwargs: Any) -> None:
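The expanded docstrings above pin down the BaseCache contract: entries are keyed by the (prompt, llm_string) pair, values are lists of Generations, and the default async methods simply run their sync counterparts through run_in_executor. A toy in-memory implementation of that contract — hypothetical, for illustration; this class is not part of the commit:

    from typing import Any, Dict, Optional, Tuple

    from langchain_core.caches import RETURN_VAL_TYPE, BaseCache


    class DictCache(BaseCache):
        """Minimal in-memory cache keyed by the (prompt, llm_string) 2-tuple."""

        def __init__(self) -> None:
            self._store: Dict[Tuple[str, str], RETURN_VAL_TYPE] = {}

        def lookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]:
            # Cache miss -> None; cache hit -> the stored list of Generations.
            return self._store.get((prompt, llm_string))

        def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> None:
            self._store[(prompt, llm_string)] = return_val

        def clear(self, **kwargs: Any) -> None:
            self._store.clear()

With no overrides, the alookup/aupdate/aclear methods inherited from BaseCache delegate to these sync methods via run_in_executor, exactly as the diff above documents.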
--- a/libs/langchain/tests/integration_tests/cache/test_cassandra.py
+++ b/libs/langchain/tests/integration_tests/cache/test_cassandra.py
@@ -1,10 +1,11 @@
 """Test Cassandra caches. Requires a running vector-capable Cassandra cluster."""
+import asyncio
 import os
 import time
 from typing import Any, Iterator, Tuple

 import pytest
+from langchain_community.utilities.cassandra import SetupMode
 from langchain_core.outputs import Generation, LLMResult

 from langchain.cache import CassandraCache, CassandraSemanticCache
@@ -47,16 +48,34 @@ def test_cassandra_cache(cassandra_connection: Tuple[Any, str]) -> None:
     llm_string = str(sorted([(k, v) for k, v in params.items()]))
     get_llm_cache().update("foo", llm_string, [Generation(text="fizz")])
     output = llm.generate(["foo"])
-    print(output)  # noqa: T201
     expected_output = LLMResult(
         generations=[[Generation(text="fizz")]],
         llm_output={},
     )
-    print(expected_output)  # noqa: T201
     assert output == expected_output
     cache.clear()


+async def test_cassandra_cache_async(cassandra_connection: Tuple[Any, str]) -> None:
+    session, keyspace = cassandra_connection
+    cache = CassandraCache(
+        session=session, keyspace=keyspace, setup_mode=SetupMode.ASYNC
+    )
+    set_llm_cache(cache)
+    llm = FakeLLM()
+    params = llm.dict()
+    params["stop"] = None
+    llm_string = str(sorted([(k, v) for k, v in params.items()]))
+    await get_llm_cache().aupdate("foo", llm_string, [Generation(text="fizz")])
+    output = await llm.agenerate(["foo"])
+    expected_output = LLMResult(
+        generations=[[Generation(text="fizz")]],
+        llm_output={},
+    )
+    assert output == expected_output
+    await cache.aclear()
+
+
 def test_cassandra_cache_ttl(cassandra_connection: Tuple[Any, str]) -> None:
     session, keyspace = cassandra_connection
     cache = CassandraCache(session=session, keyspace=keyspace, ttl_seconds=2)
@@ -79,6 +98,30 @@ def test_cassandra_cache_ttl(cassandra_connection: Tuple[Any, str]) -> None:
     cache.clear()


+async def test_cassandra_cache_ttl_async(cassandra_connection: Tuple[Any, str]) -> None:
+    session, keyspace = cassandra_connection
+    cache = CassandraCache(
+        session=session, keyspace=keyspace, ttl_seconds=2, setup_mode=SetupMode.ASYNC
+    )
+    set_llm_cache(cache)
+    llm = FakeLLM()
+    params = llm.dict()
+    params["stop"] = None
+    llm_string = str(sorted([(k, v) for k, v in params.items()]))
+    await get_llm_cache().aupdate("foo", llm_string, [Generation(text="fizz")])
+    expected_output = LLMResult(
+        generations=[[Generation(text="fizz")]],
+        llm_output={},
+    )
+    output = await llm.agenerate(["foo"])
+    assert output == expected_output
+    await asyncio.sleep(2.5)
+    # entry has expired away.
+    output = await llm.agenerate(["foo"])
+    assert output != expected_output
+    await cache.aclear()
+
+
 def test_cassandra_semantic_cache(cassandra_connection: Tuple[Any, str]) -> None:
     session, keyspace = cassandra_connection
     sem_cache = CassandraSemanticCache(
@@ -103,3 +146,32 @@ def test_cassandra_semantic_cache(cassandra_connection: Tuple[Any, str]) -> None
     output = llm.generate(["bar"])  # 'fizz' is erased away now
     assert output != expected_output
     sem_cache.clear()
+
+
+async def test_cassandra_semantic_cache_async(
+    cassandra_connection: Tuple[Any, str],
+) -> None:
+    session, keyspace = cassandra_connection
+    sem_cache = CassandraSemanticCache(
+        session=session,
+        keyspace=keyspace,
+        embedding=FakeEmbeddings(),
+        setup_mode=SetupMode.ASYNC,
+    )
+    set_llm_cache(sem_cache)
+    llm = FakeLLM()
+    params = llm.dict()
+    params["stop"] = None
+    llm_string = str(sorted([(k, v) for k, v in params.items()]))
+    await get_llm_cache().aupdate("foo", llm_string, [Generation(text="fizz")])
+    output = await llm.agenerate(["bar"])  # same embedding as 'foo'
+    expected_output = LLMResult(
+        generations=[[Generation(text="fizz")]],
+        llm_output={},
+    )
+    assert output == expected_output
+    # clear the cache
+    await sem_cache.aclear()
+    output = await llm.agenerate(["bar"])  # 'fizz' is erased away now
+    assert output != expected_output
+    await sem_cache.aclear()
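Finally, the adelete_by_document_id docstring earlier in the diff spells out the intended invalidation pattern for the semantic cache: an async lookup to obtain the entry's document ID, then a delete by that ID. A sketch of that two-step flow — assuming a CassandraSemanticCache constructed with setup_mode=SetupMode.ASYNC as in the tests above; the invalidate helper is illustrative, not part of the commit:

    from langchain_community.cache import CassandraSemanticCache


    async def invalidate(
        sem_cache: CassandraSemanticCache, prompt: str, llm_string: str
    ) -> None:
        # Step 1: similarity lookup; returns (document_id, cached_generations) on a hit.
        hit = await sem_cache.alookup_with_id(prompt, llm_string)
        if hit is not None:
            document_id, _generations = hit
            # Step 2: delete that specific cached entry by its ID.
            await sem_cache.adelete_by_document_id(document_id)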