From e92e96193fc3fb263a3379b685ab365123e23ebc Mon Sep 17 00:00:00 2001 From: Christophe Bornet Date: Mon, 19 Feb 2024 19:11:49 +0100 Subject: [PATCH] community[minor]: Add async methods to the AstraDB BaseStore (#16872) --------- Co-authored-by: Bagatur <22008038+baskaryan@users.noreply.github.com> --- .../langchain_community/storage/astradb.py | 59 ++++++++++-- .../integration_tests/storage/test_astradb.py | 95 ++++++++++++++++--- 2 files changed, 136 insertions(+), 18 deletions(-) diff --git a/libs/community/langchain_community/storage/astradb.py b/libs/community/langchain_community/storage/astradb.py index 0cb2ea310aa..959ef374124 100644 --- a/libs/community/langchain_community/storage/astradb.py +++ b/libs/community/langchain_community/storage/astradb.py @@ -5,6 +5,7 @@ from abc import ABC, abstractmethod from typing import ( TYPE_CHECKING, Any, + AsyncIterator, Generic, Iterator, List, @@ -16,10 +17,13 @@ from typing import ( from langchain_core.stores import BaseStore, ByteStore -from langchain_community.utilities.astradb import _AstraDBEnvironment +from langchain_community.utilities.astradb import ( + SetupMode, + _AstraDBCollectionEnvironment, +) if TYPE_CHECKING: - from astrapy.db import AstraDB + from astrapy.db import AstraDB, AsyncAstraDB V = TypeVar("V") @@ -34,17 +38,23 @@ class AstraDBBaseStore(Generic[V], BaseStore[str, V], ABC): api_endpoint: Optional[str] = None, astra_db_client: Optional[AstraDB] = None, namespace: Optional[str] = None, + *, + async_astra_db_client: Optional[AsyncAstraDB] = None, + pre_delete_collection: bool = False, + setup_mode: SetupMode = SetupMode.SYNC, ) -> None: - astra_env = _AstraDBEnvironment( + self.astra_env = _AstraDBCollectionEnvironment( + collection_name=collection_name, token=token, api_endpoint=api_endpoint, astra_db_client=astra_db_client, + async_astra_db_client=async_astra_db_client, namespace=namespace, + setup_mode=setup_mode, + pre_delete_collection=pre_delete_collection, ) - self.astra_db = astra_env.astra_db - self.collection = self.astra_db.create_collection( - collection_name=collection_name, - ) + self.collection = self.astra_env.collection + self.async_collection = self.astra_env.async_collection @abstractmethod def decode_value(self, value: Any) -> Optional[V]: @@ -56,28 +66,63 @@ class AstraDBBaseStore(Generic[V], BaseStore[str, V], ABC): def mget(self, keys: Sequence[str]) -> List[Optional[V]]: """Get the values associated with the given keys.""" + self.astra_env.ensure_db_setup() docs_dict = {} for doc in self.collection.paginated_find(filter={"_id": {"$in": list(keys)}}): docs_dict[doc["_id"]] = doc.get("value") return [self.decode_value(docs_dict.get(key)) for key in keys] + async def amget(self, keys: Sequence[str]) -> List[Optional[V]]: + """Get the values associated with the given keys.""" + await self.astra_env.aensure_db_setup() + docs_dict = {} + async for doc in self.async_collection.paginated_find( + filter={"_id": {"$in": list(keys)}} + ): + docs_dict[doc["_id"]] = doc.get("value") + return [self.decode_value(docs_dict.get(key)) for key in keys] + def mset(self, key_value_pairs: Sequence[Tuple[str, V]]) -> None: """Set the given key-value pairs.""" + self.astra_env.ensure_db_setup() for k, v in key_value_pairs: self.collection.upsert({"_id": k, "value": self.encode_value(v)}) + async def amset(self, key_value_pairs: Sequence[Tuple[str, V]]) -> None: + """Set the given key-value pairs.""" + await self.astra_env.aensure_db_setup() + for k, v in key_value_pairs: + await self.async_collection.upsert( + {"_id": k, "value": self.encode_value(v)} + ) + def mdelete(self, keys: Sequence[str]) -> None: """Delete the given keys.""" + self.astra_env.ensure_db_setup() self.collection.delete_many(filter={"_id": {"$in": list(keys)}}) + async def amdelete(self, keys: Sequence[str]) -> None: + """Delete the given keys.""" + await self.astra_env.aensure_db_setup() + await self.async_collection.delete_many(filter={"_id": {"$in": list(keys)}}) + def yield_keys(self, *, prefix: Optional[str] = None) -> Iterator[str]: """Yield keys in the store.""" + self.astra_env.ensure_db_setup() docs = self.collection.paginated_find() for doc in docs: key = doc["_id"] if not prefix or key.startswith(prefix): yield key + async def ayield_keys(self, *, prefix: Optional[str] = None) -> AsyncIterator[str]: + """Yield keys in the store.""" + await self.astra_env.aensure_db_setup() + async for doc in self.async_collection.paginated_find(): + key = doc["_id"] + if not prefix or key.startswith(prefix): + yield key + class AstraDBStore(AstraDBBaseStore[Any]): """BaseStore implementation using DataStax AstraDB as the underlying store. diff --git a/libs/community/tests/integration_tests/storage/test_astradb.py b/libs/community/tests/integration_tests/storage/test_astradb.py index 643b4e93a31..63108ef0c84 100644 --- a/libs/community/tests/integration_tests/storage/test_astradb.py +++ b/libs/community/tests/integration_tests/storage/test_astradb.py @@ -1,9 +1,16 @@ """Implement integration tests for AstraDB storage.""" +from __future__ import annotations + import os +from typing import TYPE_CHECKING import pytest from langchain_community.storage.astradb import AstraDBByteStore, AstraDBStore +from langchain_community.utilities.astradb import SetupMode + +if TYPE_CHECKING: + from astrapy.db import AstraDB, AsyncAstraDB def _has_env_vars() -> bool: @@ -16,7 +23,7 @@ def _has_env_vars() -> bool: @pytest.fixture -def astra_db(): # type: ignore[no-untyped-def] +def astra_db() -> AstraDB: from astrapy.db import AstraDB return AstraDB( @@ -26,24 +33,45 @@ def astra_db(): # type: ignore[no-untyped-def] ) -def init_store(astra_db, collection_name: str): # type: ignore[no-untyped-def, no-untyped-def] - astra_db.create_collection(collection_name) +@pytest.fixture +def async_astra_db() -> AsyncAstraDB: + from astrapy.db import AsyncAstraDB + + return AsyncAstraDB( + token=os.environ["ASTRA_DB_APPLICATION_TOKEN"], + api_endpoint=os.environ["ASTRA_DB_API_ENDPOINT"], + namespace=os.environ.get("ASTRA_DB_KEYSPACE"), + ) + + +def init_store(astra_db: AstraDB, collection_name: str) -> AstraDBStore: store = AstraDBStore(collection_name=collection_name, astra_db_client=astra_db) store.mset([("key1", [0.1, 0.2]), ("key2", "value2")]) return store -def init_bytestore(astra_db, collection_name: str): # type: ignore[no-untyped-def, no-untyped-def] - astra_db.create_collection(collection_name) +def init_bytestore(astra_db: AstraDB, collection_name: str) -> AstraDBByteStore: store = AstraDBByteStore(collection_name=collection_name, astra_db_client=astra_db) store.mset([("key1", b"value1"), ("key2", b"value2")]) return store +async def init_async_store( + async_astra_db: AsyncAstraDB, collection_name: str +) -> AstraDBStore: + store = AstraDBStore( + collection_name=collection_name, + async_astra_db_client=async_astra_db, + setup_mode=SetupMode.ASYNC, + ) + await store.amset([("key1", [0.1, 0.2]), ("key2", "value2")]) + return store + + @pytest.mark.requires("astrapy") @pytest.mark.skipif(not _has_env_vars(), reason="Missing Astra DB env. vars") class TestAstraDBStore: - def test_mget(self, astra_db) -> None: # type: ignore[no-untyped-def] + def test_mget(self, astra_db: AstraDB) -> None: """Test AstraDBStore mget method.""" collection_name = "lc_test_store_mget" try: @@ -52,7 +80,16 @@ class TestAstraDBStore: finally: astra_db.delete_collection(collection_name) - def test_mset(self, astra_db) -> None: # type: ignore[no-untyped-def] + async def test_amget(self, async_astra_db: AsyncAstraDB) -> None: + """Test AstraDBStore amget method.""" + collection_name = "lc_test_store_mget" + try: + store = await init_async_store(async_astra_db, collection_name) + assert await store.amget(["key1", "key2"]) == [[0.1, 0.2], "value2"] + finally: + await async_astra_db.delete_collection(collection_name) + + def test_mset(self, astra_db: AstraDB) -> None: """Test that multiple keys can be set with AstraDBStore.""" collection_name = "lc_test_store_mset" try: @@ -64,7 +101,19 @@ class TestAstraDBStore: finally: astra_db.delete_collection(collection_name) - def test_mdelete(self, astra_db) -> None: # type: ignore[no-untyped-def] + async def test_amset(self, async_astra_db: AsyncAstraDB) -> None: + """Test that multiple keys can be set with AstraDBStore.""" + collection_name = "lc_test_store_mset" + try: + store = await init_async_store(async_astra_db, collection_name) + result = await store.async_collection.find_one({"_id": "key1"}) + assert result["data"]["document"]["value"] == [0.1, 0.2] + result = await store.async_collection.find_one({"_id": "key2"}) + assert result["data"]["document"]["value"] == "value2" + finally: + await async_astra_db.delete_collection(collection_name) + + def test_mdelete(self, astra_db: AstraDB) -> None: """Test that deletion works as expected.""" collection_name = "lc_test_store_mdelete" try: @@ -75,7 +124,18 @@ class TestAstraDBStore: finally: astra_db.delete_collection(collection_name) - def test_yield_keys(self, astra_db) -> None: # type: ignore[no-untyped-def] + async def test_amdelete(self, async_astra_db: AsyncAstraDB) -> None: + """Test that deletion works as expected.""" + collection_name = "lc_test_store_mdelete" + try: + store = await init_async_store(async_astra_db, collection_name) + await store.amdelete(["key1", "key2"]) + result = await store.amget(["key1", "key2"]) + assert result == [None, None] + finally: + await async_astra_db.delete_collection(collection_name) + + def test_yield_keys(self, astra_db: AstraDB) -> None: collection_name = "lc_test_store_yield_keys" try: store = init_store(astra_db, collection_name) @@ -85,7 +145,20 @@ class TestAstraDBStore: finally: astra_db.delete_collection(collection_name) - def test_bytestore_mget(self, astra_db) -> None: # type: ignore[no-untyped-def] + async def test_ayield_keys(self, async_astra_db: AsyncAstraDB) -> None: + collection_name = "lc_test_store_yield_keys" + try: + store = await init_async_store(async_astra_db, collection_name) + assert {key async for key in store.ayield_keys()} == {"key1", "key2"} + assert {key async for key in store.ayield_keys(prefix="key")} == { + "key1", + "key2", + } + assert {key async for key in store.ayield_keys(prefix="lang")} == set() + finally: + await async_astra_db.delete_collection(collection_name) + + def test_bytestore_mget(self, astra_db: AstraDB) -> None: """Test AstraDBByteStore mget method.""" collection_name = "lc_test_bytestore_mget" try: @@ -94,7 +167,7 @@ class TestAstraDBStore: finally: astra_db.delete_collection(collection_name) - def test_bytestore_mset(self, astra_db) -> None: # type: ignore[no-untyped-def] + def test_bytestore_mset(self, astra_db: AstraDB) -> None: """Test that multiple keys can be set with AstraDBByteStore.""" collection_name = "lc_test_bytestore_mset" try: