community: Factorize AstraDB components constructors (#16779)

* Adds `AstraDBEnvironment` class and use it in `AstraDBLoader`,
`AstraDBCache`, `AstraDBSemanticCache`, `AstraDBBaseStore` and
`AstraDBChatMessageHistory`
* Create an `AsyncAstraDB` if we only have an `AstraDB` and vice-versa
so:
  * we always have an instance of `AstraDB`
* we always have an instance of `AsyncAstraDB` for recent versions of
astrapy
* Create collection if not exists in `AstraDBBaseStore`
* Some typing improvements

Note: `AstraDB` `VectorStore` not using `AstraDBEnvironment` at the
moment. This will be done after the `langchain-astradb` package is out.
This commit is contained in:
Christophe Bornet 2024-02-01 19:51:07 +01:00 committed by GitHub
parent 93366861c7
commit 9d458d089a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 169 additions and 168 deletions

View File

@ -60,12 +60,14 @@ from langchain_core.load.load import loads
from langchain_core.outputs import ChatGeneration, Generation from langchain_core.outputs import ChatGeneration, Generation
from langchain_core.utils import get_from_env from langchain_core.utils import get_from_env
from langchain_community.utilities.astradb import AstraDBEnvironment
from langchain_community.vectorstores.redis import Redis as RedisVectorstore from langchain_community.vectorstores.redis import Redis as RedisVectorstore
logger = logging.getLogger(__file__) logger = logging.getLogger(__file__)
if TYPE_CHECKING: if TYPE_CHECKING:
import momento import momento
from astrapy.db import AstraDB
from cassandra.cluster import Session as CassandraSession from cassandra.cluster import Session as CassandraSession
@ -1262,7 +1264,7 @@ class AstraDBCache(BaseCache):
collection_name: str = ASTRA_DB_CACHE_DEFAULT_COLLECTION_NAME, collection_name: str = ASTRA_DB_CACHE_DEFAULT_COLLECTION_NAME,
token: Optional[str] = None, token: Optional[str] = None,
api_endpoint: Optional[str] = None, api_endpoint: Optional[str] = None,
astra_db_client: Optional[Any] = None, # 'astrapy.db.AstraDB' if passed astra_db_client: Optional[AstraDB] = None,
namespace: Optional[str] = None, namespace: Optional[str] = None,
): ):
""" """
@ -1278,39 +1280,17 @@ class AstraDBCache(BaseCache):
namespace (Optional[str]): namespace (aka keyspace) where the namespace (Optional[str]): namespace (aka keyspace) where the
collection is created. Defaults to the database's "default namespace". collection is created. Defaults to the database's "default namespace".
""" """
try: astra_env = AstraDBEnvironment(
from astrapy.db import ( token=token,
AstraDB as LibAstraDB, api_endpoint=api_endpoint,
) astra_db_client=astra_db_client,
except (ImportError, ModuleNotFoundError): namespace=namespace,
raise ImportError(
"Could not import a recent astrapy python package. "
"Please install it with `pip install --upgrade astrapy`."
)
# Conflicting-arg checks:
if astra_db_client is not None:
if token is not None or api_endpoint is not None:
raise ValueError(
"You cannot pass 'astra_db_client' to AstraDB if passing "
"'token' and 'api_endpoint'."
)
self.collection_name = collection_name
self.token = token
self.api_endpoint = api_endpoint
self.namespace = namespace
if astra_db_client is not None:
self.astra_db = astra_db_client
else:
self.astra_db = LibAstraDB(
token=self.token,
api_endpoint=self.api_endpoint,
namespace=self.namespace,
) )
self.astra_db = astra_env.astra_db
self.collection = self.astra_db.create_collection( self.collection = self.astra_db.create_collection(
collection_name=self.collection_name, collection_name=collection_name,
) )
self.collection_name = collection_name
@staticmethod @staticmethod
def _make_id(prompt: str, llm_string: str) -> str: def _make_id(prompt: str, llm_string: str) -> str:
@ -1364,7 +1344,7 @@ class AstraDBCache(BaseCache):
def delete(self, prompt: str, llm_string: str) -> None: def delete(self, prompt: str, llm_string: str) -> None:
"""Evict from cache if there's an entry.""" """Evict from cache if there's an entry."""
doc_id = self._make_id(prompt, llm_string) doc_id = self._make_id(prompt, llm_string)
return self.collection.delete_one(doc_id) self.collection.delete_one(doc_id)
def clear(self, **kwargs: Any) -> None: def clear(self, **kwargs: Any) -> None:
"""Clear cache. This is for all LLMs at once.""" """Clear cache. This is for all LLMs at once."""
@ -1395,7 +1375,7 @@ class AstraDBSemanticCache(BaseCache):
collection_name: str = ASTRA_DB_CACHE_DEFAULT_COLLECTION_NAME, collection_name: str = ASTRA_DB_CACHE_DEFAULT_COLLECTION_NAME,
token: Optional[str] = None, token: Optional[str] = None,
api_endpoint: Optional[str] = None, api_endpoint: Optional[str] = None,
astra_db_client: Optional[Any] = None, # 'astrapy.db.AstraDB' if passed astra_db_client: Optional[AstraDB] = None,
namespace: Optional[str] = None, namespace: Optional[str] = None,
embedding: Embeddings, embedding: Embeddings,
metric: Optional[str] = None, metric: Optional[str] = None,
@ -1423,22 +1403,13 @@ class AstraDBSemanticCache(BaseCache):
The default score threshold is tuned to the default metric. The default score threshold is tuned to the default metric.
Tune it carefully yourself if switching to another distance metric. Tune it carefully yourself if switching to another distance metric.
""" """
try: astra_env = AstraDBEnvironment(
from astrapy.db import ( token=token,
AstraDB as LibAstraDB, api_endpoint=api_endpoint,
) astra_db_client=astra_db_client,
except (ImportError, ModuleNotFoundError): namespace=namespace,
raise ImportError(
"Could not import a recent astrapy python package. "
"Please install it with `pip install --upgrade astrapy`."
)
# Conflicting-arg checks:
if astra_db_client is not None:
if token is not None or api_endpoint is not None:
raise ValueError(
"You cannot pass 'astra_db_client' to AstraDB if passing "
"'token' and 'api_endpoint'."
) )
self.astra_db = astra_env.astra_db
self.embedding = embedding self.embedding = embedding
self.metric = metric self.metric = metric
@ -1457,18 +1428,7 @@ class AstraDBSemanticCache(BaseCache):
self.embedding_dimension = self._get_embedding_dimension() self.embedding_dimension = self._get_embedding_dimension()
self.collection_name = collection_name self.collection_name = collection_name
self.token = token
self.api_endpoint = api_endpoint
self.namespace = namespace
if astra_db_client is not None:
self.astra_db = astra_db_client
else:
self.astra_db = LibAstraDB(
token=self.token,
api_endpoint=self.api_endpoint,
namespace=self.namespace,
)
self.collection = self.astra_db.create_collection( self.collection = self.astra_db.create_collection(
collection_name=self.collection_name, collection_name=self.collection_name,
dimension=self.embedding_dimension, dimension=self.embedding_dimension,

View File

@ -3,11 +3,12 @@ from __future__ import annotations
import json import json
import time import time
import typing from typing import TYPE_CHECKING, List, Optional
from typing import List, Optional
if typing.TYPE_CHECKING: from langchain_community.utilities.astradb import AstraDBEnvironment
from astrapy.db import AstraDB as LibAstraDB
if TYPE_CHECKING:
from astrapy.db import AstraDB
from langchain_core.chat_history import BaseChatMessageHistory from langchain_core.chat_history import BaseChatMessageHistory
from langchain_core.messages import ( from langchain_core.messages import (
@ -42,40 +43,22 @@ class AstraDBChatMessageHistory(BaseChatMessageHistory):
collection_name: str = DEFAULT_COLLECTION_NAME, collection_name: str = DEFAULT_COLLECTION_NAME,
token: Optional[str] = None, token: Optional[str] = None,
api_endpoint: Optional[str] = None, api_endpoint: Optional[str] = None,
astra_db_client: Optional[LibAstraDB] = None, # type 'astrapy.db.AstraDB' astra_db_client: Optional[AstraDB] = None,
namespace: Optional[str] = None, namespace: Optional[str] = None,
) -> None: ) -> None:
"""Create an Astra DB chat message history.""" """Create an Astra DB chat message history."""
try: astra_env = AstraDBEnvironment(
from astrapy.db import AstraDB as LibAstraDB token=token,
except (ImportError, ModuleNotFoundError): api_endpoint=api_endpoint,
raise ImportError( astra_db_client=astra_db_client,
"Could not import a recent astrapy python package. " namespace=namespace,
"Please install it with `pip install --upgrade astrapy`."
) )
self.astra_db = astra_env.astra_db
# Conflicting-arg checks: self.collection = self.astra_db.create_collection(collection_name)
if astra_db_client is not None:
if token is not None or api_endpoint is not None:
raise ValueError(
"You cannot pass 'astra_db_client' to AstraDB if passing "
"'token' and 'api_endpoint'."
)
self.session_id = session_id self.session_id = session_id
self.collection_name = collection_name self.collection_name = collection_name
self.token = token
self.api_endpoint = api_endpoint
self.namespace = namespace
if astra_db_client is not None:
self.astra_db = astra_db_client
else:
self.astra_db = LibAstraDB(
token=self.token,
api_endpoint=self.api_endpoint,
namespace=self.namespace,
)
self.collection = self.astra_db.create_collection(self.collection_name)
@property @property
def messages(self) -> List[BaseMessage]: # type: ignore def messages(self) -> List[BaseMessage]: # type: ignore

View File

@ -16,8 +16,10 @@ from typing import (
) )
from langchain_core.documents import Document from langchain_core.documents import Document
from langchain_core.runnables import run_in_executor
from langchain_community.document_loaders.base import BaseLoader from langchain_community.document_loaders.base import BaseLoader
from langchain_community.utilities.astradb import AstraDBEnvironment
if TYPE_CHECKING: if TYPE_CHECKING:
from astrapy.db import AstraDB, AsyncAstraDB from astrapy.db import AstraDB, AsyncAstraDB
@ -42,21 +44,15 @@ class AstraDBLoader(BaseLoader):
nb_prefetched: int = 1000, nb_prefetched: int = 1000,
extraction_function: Callable[[Dict], str] = json.dumps, extraction_function: Callable[[Dict], str] = json.dumps,
) -> None: ) -> None:
try: astra_env = AstraDBEnvironment(
from astrapy.db import AstraDB token=token,
except (ImportError, ModuleNotFoundError): api_endpoint=api_endpoint,
raise ImportError( astra_db_client=astra_db_client,
"Could not import a recent astrapy python package. " async_astra_db_client=async_astra_db_client,
"Please install it with `pip install --upgrade astrapy`." namespace=namespace,
)
# Conflicting-arg checks:
if astra_db_client is not None or async_astra_db_client is not None:
if token is not None or api_endpoint is not None:
raise ValueError(
"You cannot pass 'astra_db_client' or 'async_astra_db_client' to "
"AstraDB if passing 'token' and 'api_endpoint'."
) )
self.astra_env = astra_env
self.collection = astra_env.astra_db.collection(collection_name)
self.collection_name = collection_name self.collection_name = collection_name
self.filter = filter_criteria self.filter = filter_criteria
self.projection = projection self.projection = projection
@ -64,47 +60,11 @@ class AstraDBLoader(BaseLoader):
self.nb_prefetched = nb_prefetched self.nb_prefetched = nb_prefetched
self.extraction_function = extraction_function self.extraction_function = extraction_function
astra_db = astra_db_client
async_astra_db = async_astra_db_client
if token and api_endpoint:
astra_db = AstraDB(
token=token,
api_endpoint=api_endpoint,
namespace=namespace,
)
try:
from astrapy.db import AsyncAstraDB
async_astra_db = AsyncAstraDB(
token=token,
api_endpoint=api_endpoint,
namespace=namespace,
)
except (ImportError, ModuleNotFoundError):
pass
if not astra_db and not async_astra_db:
raise ValueError(
"Must provide 'astra_db_client' or 'async_astra_db_client' or 'token' "
"and 'api_endpoint'"
)
self.collection = astra_db.collection(collection_name) if astra_db else None
if async_astra_db:
from astrapy.db import AsyncAstraDBCollection
self.async_collection = AsyncAstraDBCollection(
astra_db=async_astra_db, collection_name=collection_name
)
else:
self.async_collection = None
def load(self) -> List[Document]: def load(self) -> List[Document]:
"""Eagerly load the content.""" """Eagerly load the content."""
return list(self.lazy_load()) return list(self.lazy_load())
def lazy_load(self) -> Iterator[Document]: def lazy_load(self) -> Iterator[Document]:
if not self.collection:
raise ValueError("Missing AstraDB client")
queue = Queue(self.nb_prefetched) queue = Queue(self.nb_prefetched)
t = threading.Thread(target=self.fetch_results, args=(queue,)) t = threading.Thread(target=self.fetch_results, args=(queue,))
t.start() t.start()
@ -120,9 +80,27 @@ class AstraDBLoader(BaseLoader):
return [doc async for doc in self.alazy_load()] return [doc async for doc in self.alazy_load()]
async def alazy_load(self) -> AsyncIterator[Document]: async def alazy_load(self) -> AsyncIterator[Document]:
if not self.async_collection: if not self.astra_env.async_astra_db:
raise ValueError("Missing AsyncAstraDB client") iterator = run_in_executor(
async for doc in self.async_collection.paginated_find( None,
self.collection.paginated_find,
filter=self.filter,
options=self.find_options,
projection=self.projection,
sort=None,
prefetched=True,
)
done = object()
while True:
item = await run_in_executor(None, lambda it: next(it, done), iterator)
if item is done:
break
yield item
return
async_collection = await self.astra_env.async_astra_db.collection(
self.collection_name
)
async for doc in async_collection.paginated_find(
filter=self.filter, filter=self.filter,
options=self.find_options, options=self.find_options,
projection=self.projection, projection=self.projection,
@ -132,8 +110,8 @@ class AstraDBLoader(BaseLoader):
yield Document( yield Document(
page_content=self.extraction_function(doc), page_content=self.extraction_function(doc),
metadata={ metadata={
"namespace": self.async_collection.astra_db.namespace, "namespace": async_collection.astra_db.namespace,
"api_endpoint": self.async_collection.astra_db.base_url, "api_endpoint": async_collection.astra_db.base_url,
"collection": self.collection_name, "collection": self.collection_name,
}, },
) )

View File

@ -1,6 +1,9 @@
from __future__ import annotations
import base64 import base64
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import ( from typing import (
TYPE_CHECKING,
Any, Any,
Generic, Generic,
Iterator, Iterator,
@ -13,6 +16,11 @@ from typing import (
from langchain_core.stores import BaseStore, ByteStore from langchain_core.stores import BaseStore, ByteStore
from langchain_community.utilities.astradb import AstraDBEnvironment
if TYPE_CHECKING:
from astrapy.db import AstraDB
V = TypeVar("V") V = TypeVar("V")
@ -22,31 +30,19 @@ class AstraDBBaseStore(Generic[V], BaseStore[str, V], ABC):
collection_name: str, collection_name: str,
token: Optional[str] = None, token: Optional[str] = None,
api_endpoint: Optional[str] = None, api_endpoint: Optional[str] = None,
astra_db_client: Optional[Any] = None, # 'astrapy.db.AstraDB' if passed astra_db_client: Optional[AstraDB] = None,
namespace: Optional[str] = None, namespace: Optional[str] = None,
) -> None: ) -> None:
try: astra_env = AstraDBEnvironment(
from astrapy.db import AstraDB, AstraDBCollection
except (ImportError, ModuleNotFoundError):
raise ImportError(
"Could not import a recent astrapy python package. "
"Please install it with `pip install --upgrade astrapy`."
)
# Conflicting-arg checks:
if astra_db_client is not None:
if token is not None or api_endpoint is not None:
raise ValueError(
"You cannot pass 'astra_db_client' to AstraDB if passing "
"'token' and 'api_endpoint'."
)
astra_db = astra_db_client or AstraDB(
token=token, token=token,
api_endpoint=api_endpoint, api_endpoint=api_endpoint,
astra_db_client=astra_db_client,
namespace=namespace, namespace=namespace,
) )
self.collection = AstraDBCollection(collection_name, astra_db=astra_db) self.astra_db = astra_env.astra_db
self.collection = self.astra_db.create_collection(
collection_name=collection_name,
)
@abstractmethod @abstractmethod
def decode_value(self, value: Any) -> Optional[V]: def decode_value(self, value: Any) -> Optional[V]:

View File

@ -0,0 +1,84 @@
from __future__ import annotations
from typing import TYPE_CHECKING, Optional
if TYPE_CHECKING:
from astrapy.db import (
AstraDB,
AsyncAstraDB,
)
class AstraDBEnvironment:
def __init__(
self,
token: Optional[str] = None,
api_endpoint: Optional[str] = None,
astra_db_client: Optional[AstraDB] = None,
async_astra_db_client: Optional[AsyncAstraDB] = None,
namespace: Optional[str] = None,
) -> None:
self.token = token
self.api_endpoint = api_endpoint
astra_db = astra_db_client
self.async_astra_db = async_astra_db_client
self.namespace = namespace
from astrapy import db
try:
from astrapy.db import AstraDB
except (ImportError, ModuleNotFoundError):
raise ImportError(
"Could not import a recent astrapy python package. "
"Please install it with `pip install --upgrade astrapy`."
)
supports_async = hasattr(db, "AsyncAstraDB")
# Conflicting-arg checks:
if astra_db_client is not None or async_astra_db_client is not None:
if token is not None or api_endpoint is not None:
raise ValueError(
"You cannot pass 'astra_db_client' or 'async_astra_db_client' to "
"AstraDBEnvironment if passing 'token' and 'api_endpoint'."
)
if token and api_endpoint:
astra_db = AstraDB(
token=self.token,
api_endpoint=self.api_endpoint,
namespace=self.namespace,
)
if supports_async:
self.async_astra_db = db.AsyncAstraDB(
token=self.token,
api_endpoint=self.api_endpoint,
namespace=self.namespace,
)
if astra_db:
self.astra_db = astra_db
else:
if self.async_astra_db:
self.astra_db = AstraDB(
token=self.async_astra_db.token,
api_endpoint=self.async_astra_db.base_url,
api_path=self.async_astra_db.api_path,
api_version=self.async_astra_db.api_version,
namespace=self.async_astra_db.namespace,
)
else:
raise ValueError(
"Must provide 'astra_db_client' or 'async_astra_db_client' or "
"'token' and 'api_endpoint'"
)
if not self.async_astra_db and self.astra_db and supports_async:
self.async_astra_db = db.AsyncAstraDB(
token=self.astra_db.token,
api_endpoint=self.astra_db.base_url,
api_path=self.astra_db.api_path,
api_version=self.astra_db.api_version,
namespace=self.astra_db.namespace,
)