From 815ec742980fcf289ae6ee34249407e5aaf642a0 Mon Sep 17 00:00:00 2001 From: Christophe Bornet Date: Thu, 22 Feb 2024 01:41:47 +0100 Subject: [PATCH] docs: Add docstring to AstraDBStore (#17793) --- .../langchain_community/storage/astradb.py | 153 ++++++++++++------ 1 file changed, 106 insertions(+), 47 deletions(-) diff --git a/libs/community/langchain_community/storage/astradb.py b/libs/community/langchain_community/storage/astradb.py index a486f8851b5..86be2ab9b83 100644 --- a/libs/community/langchain_community/storage/astradb.py +++ b/libs/community/langchain_community/storage/astradb.py @@ -32,28 +32,8 @@ V = TypeVar("V") class AstraDBBaseStore(Generic[V], BaseStore[str, V], ABC): """Base class for the DataStax AstraDB data store.""" - def __init__( - self, - collection_name: str, - token: Optional[str] = None, - api_endpoint: Optional[str] = None, - astra_db_client: Optional[AstraDB] = None, - namespace: Optional[str] = None, - *, - async_astra_db_client: Optional[AsyncAstraDB] = None, - pre_delete_collection: bool = False, - setup_mode: SetupMode = SetupMode.SYNC, - ) -> None: - self.astra_env = _AstraDBCollectionEnvironment( - collection_name=collection_name, - token=token, - api_endpoint=api_endpoint, - astra_db_client=astra_db_client, - async_astra_db_client=async_astra_db_client, - namespace=namespace, - setup_mode=setup_mode, - pre_delete_collection=pre_delete_collection, - ) + def __init__(self, *args: Any, **kwargs: Any) -> None: + self.astra_env = _AstraDBCollectionEnvironment(*args, **kwargs) self.collection = self.astra_env.collection self.async_collection = self.astra_env.async_collection @@ -66,7 +46,6 @@ class AstraDBBaseStore(Generic[V], BaseStore[str, V], ABC): """Encodes value for Astra DB""" def mget(self, keys: Sequence[str]) -> List[Optional[V]]: - """Get the values associated with the given keys.""" self.astra_env.ensure_db_setup() docs_dict = {} for doc in self.collection.paginated_find(filter={"_id": {"$in": list(keys)}}): @@ -74,7 
+53,6 @@ class AstraDBBaseStore(Generic[V], BaseStore[str, V], ABC): return [self.decode_value(docs_dict.get(key)) for key in keys] async def amget(self, keys: Sequence[str]) -> List[Optional[V]]: - """Get the values associated with the given keys.""" await self.astra_env.aensure_db_setup() docs_dict = {} async for doc in self.async_collection.paginated_find( @@ -84,13 +62,11 @@ class AstraDBBaseStore(Generic[V], BaseStore[str, V], ABC): return [self.decode_value(docs_dict.get(key)) for key in keys] def mset(self, key_value_pairs: Sequence[Tuple[str, V]]) -> None: - """Set the given key-value pairs.""" self.astra_env.ensure_db_setup() for k, v in key_value_pairs: self.collection.upsert({"_id": k, "value": self.encode_value(v)}) async def amset(self, key_value_pairs: Sequence[Tuple[str, V]]) -> None: - """Set the given key-value pairs.""" await self.astra_env.aensure_db_setup() for k, v in key_value_pairs: await self.async_collection.upsert( @@ -98,17 +74,14 @@ class AstraDBBaseStore(Generic[V], BaseStore[str, V], ABC): ) def mdelete(self, keys: Sequence[str]) -> None: - """Delete the given keys.""" self.astra_env.ensure_db_setup() self.collection.delete_many(filter={"_id": {"$in": list(keys)}}) async def amdelete(self, keys: Sequence[str]) -> None: - """Delete the given keys.""" await self.astra_env.aensure_db_setup() await self.async_collection.delete_many(filter={"_id": {"$in": list(keys)}}) def yield_keys(self, *, prefix: Optional[str] = None) -> Iterator[str]: - """Yield keys in the store.""" self.astra_env.ensure_db_setup() docs = self.collection.paginated_find() for doc in docs: @@ -117,7 +90,6 @@ class AstraDBBaseStore(Generic[V], BaseStore[str, V], ABC): yield key async def ayield_keys(self, *, prefix: Optional[str] = None) -> AsyncIterator[str]: - """Yield keys in the store.""" await self.astra_env.aensure_db_setup() async for doc in self.async_collection.paginated_find(): key = doc["_id"] @@ -131,16 +103,60 @@ class AstraDBBaseStore(Generic[V], 
BaseStore[str, V], ABC): alternative_import="langchain_astradb.AstraDBStore", ) class AstraDBStore(AstraDBBaseStore[Any]): - """BaseStore implementation using DataStax AstraDB as the underlying store. + def __init__( + self, + collection_name: str, + token: Optional[str] = None, + api_endpoint: Optional[str] = None, + astra_db_client: Optional[AstraDB] = None, + namespace: Optional[str] = None, + *, + async_astra_db_client: Optional[AsyncAstraDB] = None, + pre_delete_collection: bool = False, + setup_mode: SetupMode = SetupMode.SYNC, + ) -> None: + """BaseStore implementation using DataStax AstraDB as the underlying store. - The value type can be any type serializable by json.dumps. - Can be used to store embeddings with the CacheBackedEmbeddings. - Documents in the AstraDB collection will have the format - { - "_id": "<key>", - "value": <value> - } - """ + The value type can be any type serializable by json.dumps. + Can be used to store embeddings with the CacheBackedEmbeddings. + + Documents in the AstraDB collection will have the format + + .. code-block:: json + + { + "_id": "<key>", + "value": <value> + } + + Args: + collection_name: name of the Astra DB collection to create/use. + token: API token for Astra DB usage. + api_endpoint: full URL to the API endpoint, + such as `https://<DB-ID>-us-east1.apps.astra.datastax.com`. + astra_db_client: *alternative to token+api_endpoint*, + you can pass an already-created 'astrapy.db.AstraDB' instance. + async_astra_db_client: *alternative to token+api_endpoint*, + you can pass an already-created 'astrapy.db.AsyncAstraDB' instance. + namespace: namespace (aka keyspace) where the + collection is created. Defaults to the database's "default namespace". + setup_mode: mode used to create the Astra DB collection (SYNC, ASYNC or + OFF). + pre_delete_collection: whether to delete the collection + before creating it. If False and the collection already exists, + the collection will be used as is. 
+ """ + # Constructor doc is not inherited so we have to override it. + super().__init__( + collection_name=collection_name, + token=token, + api_endpoint=api_endpoint, + astra_db_client=astra_db_client, + async_astra_db_client=async_astra_db_client, + namespace=namespace, + setup_mode=setup_mode, + pre_delete_collection=pre_delete_collection, + ) def decode_value(self, value: Any) -> Any: return value @@ -155,15 +171,58 @@ class AstraDBStore(AstraDBBaseStore[Any]): alternative_import="langchain_astradb.AstraDBByteStore", ) class AstraDBByteStore(AstraDBBaseStore[bytes], ByteStore): - """ByteStore implementation using DataStax AstraDB as the underlying store. + def __init__( + self, + collection_name: str, + token: Optional[str] = None, + api_endpoint: Optional[str] = None, + astra_db_client: Optional[AstraDB] = None, + namespace: Optional[str] = None, + *, + async_astra_db_client: Optional[AsyncAstraDB] = None, + pre_delete_collection: bool = False, + setup_mode: SetupMode = SetupMode.SYNC, + ) -> None: + """ByteStore implementation using DataStax AstraDB as the underlying store. - The bytes values are converted to base64 encoded strings - Documents in the AstraDB collection will have the format - { - "_id": "<key>", - "value": "<byte64 string>" - } - """ + The bytes values are converted to base64 encoded strings + Documents in the AstraDB collection will have the format + + .. code-block:: json + + { + "_id": "<key>", + "value": "<byte64 string>" + } + + Args: + collection_name: name of the Astra DB collection to create/use. + token: API token for Astra DB usage. + api_endpoint: full URL to the API endpoint, + such as `https://<DB-ID>-us-east1.apps.astra.datastax.com`. + astra_db_client: *alternative to token+api_endpoint*, + you can pass an already-created 'astrapy.db.AstraDB' instance. + async_astra_db_client: *alternative to token+api_endpoint*, + you can pass an already-created 'astrapy.db.AsyncAstraDB' instance. + namespace: namespace (aka keyspace) where the + collection is created. 
Defaults to the database's "default namespace". + setup_mode: mode used to create the Astra DB collection (SYNC, ASYNC or + OFF). + pre_delete_collection: whether to delete the collection + before creating it. If False and the collection already exists, + the collection will be used as is. + """ + # Constructor doc is not inherited so we have to override it. + super().__init__( + collection_name=collection_name, + token=token, + api_endpoint=api_endpoint, + astra_db_client=astra_db_client, + async_astra_db_client=async_astra_db_client, + namespace=namespace, + setup_mode=setup_mode, + pre_delete_collection=pre_delete_collection, + ) def decode_value(self, value: Any) -> Optional[bytes]: if value is None: