community[patch]: Fix AstraDBCache docstrings (#17802)

This commit is contained in:
Christophe Bornet 2024-02-20 17:39:30 +01:00 committed by GitHub
parent 865cabff05
commit b13e52b6ac
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@@ -1366,16 +1366,6 @@ ASTRA_DB_CACHE_DEFAULT_COLLECTION_NAME = "langchain_astradb_cache"
class AstraDBCache(BaseCache): class AstraDBCache(BaseCache):
"""
Cache that uses Astra DB as a backend.
It uses a single collection as a kv store
The lookup keys, combined in the _id of the documents, are:
- prompt, a string
- llm_string, a deterministic str representation of the model parameters.
(needed to prevent same-prompt-different-model collisions)
"""
@staticmethod @staticmethod
def _make_id(prompt: str, llm_string: str) -> str: def _make_id(prompt: str, llm_string: str) -> str:
return f"{_hash(prompt)}#{_hash(llm_string)}" return f"{_hash(prompt)}#{_hash(llm_string)}"
@@ -1393,25 +1383,30 @@ class AstraDBCache(BaseCache):
setup_mode: SetupMode = SetupMode.SYNC, setup_mode: SetupMode = SetupMode.SYNC,
): ):
""" """
Create an AstraDB cache using a collection for storage. Cache that uses Astra DB as a backend.
Args (only keyword-arguments accepted): It uses a single collection as a kv store
collection_name (str): name of the Astra DB collection to create/use. The lookup keys, combined in the _id of the documents, are:
token (Optional[str]): API token for Astra DB usage. - prompt, a string
api_endpoint (Optional[str]): full URL to the API endpoint, - llm_string, a deterministic str representation of the model parameters.
such as "https://<DB-ID>-us-east1.apps.astra.datastax.com". (needed to prevent same-prompt-different-model collisions)
astra_db_client (Optional[AstraDB]):
*alternative to token+api_endpoint*, Args:
collection_name: name of the Astra DB collection to create/use.
token: API token for Astra DB usage.
api_endpoint: full URL to the API endpoint,
such as `https://<DB-ID>-us-east1.apps.astra.datastax.com`.
astra_db_client: *alternative to token+api_endpoint*,
you can pass an already-created 'astrapy.db.AstraDB' instance. you can pass an already-created 'astrapy.db.AstraDB' instance.
async_astra_db_client (Optional[AsyncAstraDB]): async_astra_db_client: *alternative to token+api_endpoint*,
*alternative to token+api_endpoint*,
you can pass an already-created 'astrapy.db.AsyncAstraDB' instance. you can pass an already-created 'astrapy.db.AsyncAstraDB' instance.
namespace (Optional[str]): namespace (aka keyspace) where the namespace: namespace (aka keyspace) where the
collection is created. Defaults to the database's "default namespace". collection is created. Defaults to the database's "default namespace".
pre_delete_collection (bool): whether to delete and re-create the setup_mode: mode used to create the Astra DB collection (SYNC, ASYNC or
collection. Defaults to False. OFF).
async_setup (bool): whether to create the collection asynchronously. pre_delete_collection: whether to delete the collection
Enable only if there is a running asyncio event loop. Defaults to False. before creating it. If False and the collection already exists,
the collection will be used as is.
""" """
self.astra_env = _AstraDBCollectionEnvironment( self.astra_env = _AstraDBCollectionEnvironment(
collection_name=collection_name, collection_name=collection_name,
@@ -1427,7 +1422,6 @@ class AstraDBCache(BaseCache):
self.async_collection = self.astra_env.async_collection self.async_collection = self.astra_env.async_collection
def lookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]: def lookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]:
"""Look up based on prompt and llm_string."""
self.astra_env.ensure_db_setup() self.astra_env.ensure_db_setup()
doc_id = self._make_id(prompt, llm_string) doc_id = self._make_id(prompt, llm_string)
item = self.collection.find_one( item = self.collection.find_one(
@@ -1441,7 +1435,6 @@ class AstraDBCache(BaseCache):
return _loads_generations(item["body_blob"]) if item is not None else None return _loads_generations(item["body_blob"]) if item is not None else None
async def alookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]: async def alookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]:
"""Look up based on prompt and llm_string."""
await self.astra_env.aensure_db_setup() await self.astra_env.aensure_db_setup()
doc_id = self._make_id(prompt, llm_string) doc_id = self._make_id(prompt, llm_string)
item = ( item = (
@@ -1457,7 +1450,6 @@ class AstraDBCache(BaseCache):
return _loads_generations(item["body_blob"]) if item is not None else None return _loads_generations(item["body_blob"]) if item is not None else None
def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> None: def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> None:
"""Update cache based on prompt and llm_string."""
self.astra_env.ensure_db_setup() self.astra_env.ensure_db_setup()
doc_id = self._make_id(prompt, llm_string) doc_id = self._make_id(prompt, llm_string)
blob = _dumps_generations(return_val) blob = _dumps_generations(return_val)
@@ -1471,7 +1463,6 @@ class AstraDBCache(BaseCache):
async def aupdate( async def aupdate(
self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE
) -> None: ) -> None:
"""Update cache based on prompt and llm_string."""
await self.astra_env.aensure_db_setup() await self.astra_env.aensure_db_setup()
doc_id = self._make_id(prompt, llm_string) doc_id = self._make_id(prompt, llm_string)
blob = _dumps_generations(return_val) blob = _dumps_generations(return_val)
@@ -1523,12 +1514,10 @@ class AstraDBCache(BaseCache):
await self.async_collection.delete_one(doc_id) await self.async_collection.delete_one(doc_id)
def clear(self, **kwargs: Any) -> None: def clear(self, **kwargs: Any) -> None:
"""Clear cache. This is for all LLMs at once."""
self.astra_env.ensure_db_setup() self.astra_env.ensure_db_setup()
self.collection.clear() self.collection.clear()
async def aclear(self, **kwargs: Any) -> None: async def aclear(self, **kwargs: Any) -> None:
"""Clear cache. This is for all LLMs at once."""
await self.astra_env.aensure_db_setup() await self.astra_env.aensure_db_setup()
await self.async_collection.clear() await self.async_collection.clear()
@@ -1575,18 +1564,6 @@ def _async_lru_cache(maxsize: int = 128, typed: bool = False) -> Callable:
class AstraDBSemanticCache(BaseCache): class AstraDBSemanticCache(BaseCache):
"""
Cache that uses Astra DB as a vector-store backend for semantic
(i.e. similarity-based) lookup.
It uses a single (vector) collection and can store
cached values from several LLMs, so the LLM's 'llm_string' is stored
in the document metadata.
You can choose the preferred similarity (or use the API default) --
remember the threshold might require metric-dependent tuning.
"""
def __init__( def __init__(
self, self,
*, *,
@@ -1603,33 +1580,38 @@ class AstraDBSemanticCache(BaseCache):
similarity_threshold: float = ASTRA_DB_SEMANTIC_CACHE_DEFAULT_THRESHOLD, similarity_threshold: float = ASTRA_DB_SEMANTIC_CACHE_DEFAULT_THRESHOLD,
): ):
""" """
Initialize the cache with all relevant parameters. Cache that uses Astra DB as a vector-store backend for semantic
Args: (i.e. similarity-based) lookup.
collection_name (str): name of the Astra DB collection to create/use. It uses a single (vector) collection and can store
token (Optional[str]): API token for Astra DB usage. cached values from several LLMs, so the LLM's 'llm_string' is stored
api_endpoint (Optional[str]): full URL to the API endpoint, in the document metadata.
such as "https://<DB-ID>-us-east1.apps.astra.datastax.com".
astra_db_client (Optional[AstraDB]): *alternative to token+api_endpoint*,
you can pass an already-created 'astrapy.db.AstraDB' instance.
async_astra_db_client (Optional[AsyncAstraDB]):
*alternative to token+api_endpoint*,
you can pass an already-created 'astrapy.db.AsyncAstraDB' instance.
namespace (Optional[str]): namespace (aka keyspace) where the
collection is created. Defaults to the database's "default namespace".
setup_mode (SetupMode): mode used to create the collection in the DB
(SYNC, ASYNC or OFF). Defaults to SYNC.
pre_delete_collection (bool): whether to delete and re-create the
collection. Defaults to False.
embedding (Embedding): Embedding provider for semantic
encoding and search.
metric: the function to use for evaluating similarity of text embeddings.
Defaults to 'cosine' (alternatives: 'euclidean', 'dot_product')
similarity_threshold (float, optional): the minimum similarity
for accepting a (semantic-search) match.
You can choose the preferred similarity (or use the API default).
The default score threshold is tuned to the default metric. The default score threshold is tuned to the default metric.
Tune it carefully yourself if switching to another distance metric. Tune it carefully yourself if switching to another distance metric.
Args:
collection_name: name of the Astra DB collection to create/use.
token: API token for Astra DB usage.
api_endpoint: full URL to the API endpoint,
such as `https://<DB-ID>-us-east1.apps.astra.datastax.com`.
astra_db_client: *alternative to token+api_endpoint*,
you can pass an already-created 'astrapy.db.AstraDB' instance.
async_astra_db_client: *alternative to token+api_endpoint*,
you can pass an already-created 'astrapy.db.AsyncAstraDB' instance.
namespace: namespace (aka keyspace) where the
collection is created. Defaults to the database's "default namespace".
setup_mode: mode used to create the Astra DB collection (SYNC, ASYNC or
OFF).
pre_delete_collection: whether to delete the collection
before creating it. If False and the collection already exists,
the collection will be used as is.
embedding: Embedding provider for semantic encoding and search.
metric: the function to use for evaluating similarity of text embeddings.
Defaults to 'cosine' (alternatives: 'euclidean', 'dot_product')
similarity_threshold: the minimum similarity for accepting a
(semantic-search) match.
""" """
self.embedding = embedding self.embedding = embedding
self.metric = metric self.metric = metric
@@ -1685,7 +1667,6 @@ class AstraDBSemanticCache(BaseCache):
return f"{_hash(prompt)}#{_hash(llm_string)}" return f"{_hash(prompt)}#{_hash(llm_string)}"
def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> None: def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> None:
"""Update cache based on prompt and llm_string."""
self.astra_env.ensure_db_setup() self.astra_env.ensure_db_setup()
doc_id = self._make_id(prompt, llm_string) doc_id = self._make_id(prompt, llm_string)
llm_string_hash = _hash(llm_string) llm_string_hash = _hash(llm_string)
@@ -1704,7 +1685,6 @@ class AstraDBSemanticCache(BaseCache):
async def aupdate( async def aupdate(
self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE
) -> None: ) -> None:
"""Update cache based on prompt and llm_string."""
await self.astra_env.aensure_db_setup() await self.astra_env.aensure_db_setup()
doc_id = self._make_id(prompt, llm_string) doc_id = self._make_id(prompt, llm_string)
llm_string_hash = _hash(llm_string) llm_string_hash = _hash(llm_string)
@@ -1721,7 +1701,6 @@ class AstraDBSemanticCache(BaseCache):
) )
def lookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]: def lookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]:
"""Look up based on prompt and llm_string."""
hit_with_id = self.lookup_with_id(prompt, llm_string) hit_with_id = self.lookup_with_id(prompt, llm_string)
if hit_with_id is not None: if hit_with_id is not None:
return hit_with_id[1] return hit_with_id[1]
@@ -1729,7 +1708,6 @@ class AstraDBSemanticCache(BaseCache):
return None return None
async def alookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]: async def alookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]:
"""Look up based on prompt and llm_string."""
hit_with_id = await self.alookup_with_id(prompt, llm_string) hit_with_id = await self.alookup_with_id(prompt, llm_string)
if hit_with_id is not None: if hit_with_id is not None:
return hit_with_id[1] return hit_with_id[1]
@@ -1835,11 +1813,9 @@ class AstraDBSemanticCache(BaseCache):
await self.async_collection.delete_one(document_id) await self.async_collection.delete_one(document_id)
def clear(self, **kwargs: Any) -> None: def clear(self, **kwargs: Any) -> None:
"""Clear the *whole* semantic cache."""
self.astra_env.ensure_db_setup() self.astra_env.ensure_db_setup()
self.collection.clear() self.collection.clear()
async def aclear(self, **kwargs: Any) -> None: async def aclear(self, **kwargs: Any) -> None:
"""Clear the *whole* semantic cache."""
await self.astra_env.aensure_db_setup() await self.astra_env.aensure_db_setup()
await self.async_collection.clear() await self.async_collection.clear()