community[patch]: Fix AstraDBCache docstrings (#17802)

This commit is contained in:
Christophe Bornet 2024-02-20 17:39:30 +01:00 committed by GitHub
parent 865cabff05
commit b13e52b6ac
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -1366,16 +1366,6 @@ ASTRA_DB_CACHE_DEFAULT_COLLECTION_NAME = "langchain_astradb_cache"
class AstraDBCache(BaseCache):
"""
Cache that uses Astra DB as a backend.
It uses a single collection as a kv store
The lookup keys, combined in the _id of the documents, are:
- prompt, a string
- llm_string, a deterministic str representation of the model parameters.
(needed to prevent same-prompt-different-model collisions)
"""
@staticmethod
def _make_id(prompt: str, llm_string: str) -> str:
return f"{_hash(prompt)}#{_hash(llm_string)}"
@ -1393,25 +1383,30 @@ class AstraDBCache(BaseCache):
setup_mode: SetupMode = SetupMode.SYNC,
):
"""
Create an AstraDB cache using a collection for storage.
Cache that uses Astra DB as a backend.
Args (only keyword-arguments accepted):
collection_name (str): name of the Astra DB collection to create/use.
token (Optional[str]): API token for Astra DB usage.
api_endpoint (Optional[str]): full URL to the API endpoint,
such as "https://<DB-ID>-us-east1.apps.astra.datastax.com".
astra_db_client (Optional[AstraDB]):
*alternative to token+api_endpoint*,
It uses a single collection as a kv store
The lookup keys, combined in the _id of the documents, are:
- prompt, a string
- llm_string, a deterministic str representation of the model parameters.
(needed to prevent same-prompt-different-model collisions)
Args:
collection_name: name of the Astra DB collection to create/use.
token: API token for Astra DB usage.
api_endpoint: full URL to the API endpoint,
such as `https://<DB-ID>-us-east1.apps.astra.datastax.com`.
astra_db_client: *alternative to token+api_endpoint*,
you can pass an already-created 'astrapy.db.AstraDB' instance.
async_astra_db_client (Optional[AsyncAstraDB]):
*alternative to token+api_endpoint*,
async_astra_db_client: *alternative to token+api_endpoint*,
you can pass an already-created 'astrapy.db.AsyncAstraDB' instance.
namespace (Optional[str]): namespace (aka keyspace) where the
namespace: namespace (aka keyspace) where the
collection is created. Defaults to the database's "default namespace".
pre_delete_collection (bool): whether to delete and re-create the
collection. Defaults to False.
async_setup (bool): whether to create the collection asynchronously.
Enable only if there is a running asyncio event loop. Defaults to False.
setup_mode: mode used to create the Astra DB collection (SYNC, ASYNC or
OFF).
pre_delete_collection: whether to delete the collection
before creating it. If False and the collection already exists,
the collection will be used as is.
"""
self.astra_env = _AstraDBCollectionEnvironment(
collection_name=collection_name,
@ -1427,7 +1422,6 @@ class AstraDBCache(BaseCache):
self.async_collection = self.astra_env.async_collection
def lookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]:
"""Look up based on prompt and llm_string."""
self.astra_env.ensure_db_setup()
doc_id = self._make_id(prompt, llm_string)
item = self.collection.find_one(
@ -1441,7 +1435,6 @@ class AstraDBCache(BaseCache):
return _loads_generations(item["body_blob"]) if item is not None else None
async def alookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]:
"""Look up based on prompt and llm_string."""
await self.astra_env.aensure_db_setup()
doc_id = self._make_id(prompt, llm_string)
item = (
@ -1457,7 +1450,6 @@ class AstraDBCache(BaseCache):
return _loads_generations(item["body_blob"]) if item is not None else None
def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> None:
"""Update cache based on prompt and llm_string."""
self.astra_env.ensure_db_setup()
doc_id = self._make_id(prompt, llm_string)
blob = _dumps_generations(return_val)
@ -1471,7 +1463,6 @@ class AstraDBCache(BaseCache):
async def aupdate(
self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE
) -> None:
"""Update cache based on prompt and llm_string."""
await self.astra_env.aensure_db_setup()
doc_id = self._make_id(prompt, llm_string)
blob = _dumps_generations(return_val)
@ -1523,12 +1514,10 @@ class AstraDBCache(BaseCache):
await self.async_collection.delete_one(doc_id)
def clear(self, **kwargs: Any) -> None:
"""Clear cache. This is for all LLMs at once."""
self.astra_env.ensure_db_setup()
self.collection.clear()
async def aclear(self, **kwargs: Any) -> None:
"""Clear cache. This is for all LLMs at once."""
await self.astra_env.aensure_db_setup()
await self.async_collection.clear()
@ -1575,18 +1564,6 @@ def _async_lru_cache(maxsize: int = 128, typed: bool = False) -> Callable:
class AstraDBSemanticCache(BaseCache):
"""
Cache that uses Astra DB as a vector-store backend for semantic
(i.e. similarity-based) lookup.
It uses a single (vector) collection and can store
cached values from several LLMs, so the LLM's 'llm_string' is stored
in the document metadata.
You can choose the preferred similarity (or use the API default) --
remember the threshold might require metric-dependent tuning.
"""
def __init__(
self,
*,
@ -1603,33 +1580,38 @@ class AstraDBSemanticCache(BaseCache):
similarity_threshold: float = ASTRA_DB_SEMANTIC_CACHE_DEFAULT_THRESHOLD,
):
"""
Initialize the cache with all relevant parameters.
Args:
Cache that uses Astra DB as a vector-store backend for semantic
(i.e. similarity-based) lookup.
collection_name (str): name of the Astra DB collection to create/use.
token (Optional[str]): API token for Astra DB usage.
api_endpoint (Optional[str]): full URL to the API endpoint,
such as "https://<DB-ID>-us-east1.apps.astra.datastax.com".
astra_db_client (Optional[AstraDB]): *alternative to token+api_endpoint*,
you can pass an already-created 'astrapy.db.AstraDB' instance.
async_astra_db_client (Optional[AsyncAstraDB]):
*alternative to token+api_endpoint*,
you can pass an already-created 'astrapy.db.AsyncAstraDB' instance.
namespace (Optional[str]): namespace (aka keyspace) where the
collection is created. Defaults to the database's "default namespace".
setup_mode (SetupMode): mode used to create the collection in the DB
(SYNC, ASYNC or OFF). Defaults to SYNC.
pre_delete_collection (bool): whether to delete and re-create the
collection. Defaults to False.
embedding (Embedding): Embedding provider for semantic
encoding and search.
metric: the function to use for evaluating similarity of text embeddings.
Defaults to 'cosine' (alternatives: 'euclidean', 'dot_product')
similarity_threshold (float, optional): the minimum similarity
for accepting a (semantic-search) match.
It uses a single (vector) collection and can store
cached values from several LLMs, so the LLM's 'llm_string' is stored
in the document metadata.
You can choose the preferred similarity (or use the API default).
The default score threshold is tuned to the default metric.
Tune it carefully yourself if switching to another distance metric.
Args:
collection_name: name of the Astra DB collection to create/use.
token: API token for Astra DB usage.
api_endpoint: full URL to the API endpoint,
such as `https://<DB-ID>-us-east1.apps.astra.datastax.com`.
astra_db_client: *alternative to token+api_endpoint*,
you can pass an already-created 'astrapy.db.AstraDB' instance.
async_astra_db_client: *alternative to token+api_endpoint*,
you can pass an already-created 'astrapy.db.AsyncAstraDB' instance.
namespace: namespace (aka keyspace) where the
collection is created. Defaults to the database's "default namespace".
setup_mode: mode used to create the Astra DB collection (SYNC, ASYNC or
OFF).
pre_delete_collection: whether to delete the collection
before creating it. If False and the collection already exists,
the collection will be used as is.
embedding: Embedding provider for semantic encoding and search.
metric: the function to use for evaluating similarity of text embeddings.
Defaults to 'cosine' (alternatives: 'euclidean', 'dot_product')
similarity_threshold: the minimum similarity for accepting a
(semantic-search) match.
"""
self.embedding = embedding
self.metric = metric
@ -1685,7 +1667,6 @@ class AstraDBSemanticCache(BaseCache):
return f"{_hash(prompt)}#{_hash(llm_string)}"
def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> None:
"""Update cache based on prompt and llm_string."""
self.astra_env.ensure_db_setup()
doc_id = self._make_id(prompt, llm_string)
llm_string_hash = _hash(llm_string)
@ -1704,7 +1685,6 @@ class AstraDBSemanticCache(BaseCache):
async def aupdate(
self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE
) -> None:
"""Update cache based on prompt and llm_string."""
await self.astra_env.aensure_db_setup()
doc_id = self._make_id(prompt, llm_string)
llm_string_hash = _hash(llm_string)
@ -1721,7 +1701,6 @@ class AstraDBSemanticCache(BaseCache):
)
def lookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]:
"""Look up based on prompt and llm_string."""
hit_with_id = self.lookup_with_id(prompt, llm_string)
if hit_with_id is not None:
return hit_with_id[1]
@ -1729,7 +1708,6 @@ class AstraDBSemanticCache(BaseCache):
return None
async def alookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]:
"""Look up based on prompt and llm_string."""
hit_with_id = await self.alookup_with_id(prompt, llm_string)
if hit_with_id is not None:
return hit_with_id[1]
@ -1835,11 +1813,9 @@ class AstraDBSemanticCache(BaseCache):
await self.async_collection.delete_one(document_id)
def clear(self, **kwargs: Any) -> None:
"""Clear the *whole* semantic cache."""
self.astra_env.ensure_db_setup()
self.collection.clear()
async def aclear(self, **kwargs: Any) -> None:
"""Clear the *whole* semantic cache."""
await self.astra_env.aensure_db_setup()
await self.async_collection.clear()