community: Factorize AstraDB components constructors (#16779)

* Adds `AstraDBEnvironment` class and use it in `AstraDBLoader`,
`AstraDBCache`, `AstraDBSemanticCache`, `AstraDBBaseStore` and
`AstraDBChatMessageHistory`
* Create an `AsyncAstraDB` if we only have an `AstraDB` and vice-versa
so:
  * we always have an instance of `AstraDB`
* we always have an instance of `AsyncAstraDB` for recent versions of
astrapy
* Create collection if not exists in `AstraDBBaseStore`
* Some typing improvements

Note: `AstraDB` `VectorStore` not using `AstraDBEnvironment` at the
moment. This will be done after the `langchain-astradb` package is out.
This commit is contained in:
Christophe Bornet
2024-02-01 19:51:07 +01:00
committed by GitHub
parent 93366861c7
commit 9d458d089a
5 changed files with 169 additions and 168 deletions

View File

@@ -1,6 +1,9 @@
from __future__ import annotations
import base64
from abc import ABC, abstractmethod
from typing import (
TYPE_CHECKING,
Any,
Generic,
Iterator,
@@ -13,6 +16,11 @@ from typing import (
from langchain_core.stores import BaseStore, ByteStore
from langchain_community.utilities.astradb import AstraDBEnvironment
if TYPE_CHECKING:
from astrapy.db import AstraDB
V = TypeVar("V")
@@ -22,31 +30,19 @@ class AstraDBBaseStore(Generic[V], BaseStore[str, V], ABC):
collection_name: str,
token: Optional[str] = None,
api_endpoint: Optional[str] = None,
astra_db_client: Optional[Any] = None, # 'astrapy.db.AstraDB' if passed
astra_db_client: Optional[AstraDB] = None,
namespace: Optional[str] = None,
) -> None:
try:
from astrapy.db import AstraDB, AstraDBCollection
except (ImportError, ModuleNotFoundError):
raise ImportError(
"Could not import a recent astrapy python package. "
"Please install it with `pip install --upgrade astrapy`."
)
# Conflicting-arg checks:
if astra_db_client is not None:
if token is not None or api_endpoint is not None:
raise ValueError(
"You cannot pass 'astra_db_client' to AstraDB if passing "
"'token' and 'api_endpoint'."
)
astra_db = astra_db_client or AstraDB(
astra_env = AstraDBEnvironment(
token=token,
api_endpoint=api_endpoint,
astra_db_client=astra_db_client,
namespace=namespace,
)
self.collection = AstraDBCollection(collection_name, astra_db=astra_db)
self.astra_db = astra_env.astra_db
self.collection = self.astra_db.create_collection(
collection_name=collection_name,
)
@abstractmethod
def decode_value(self, value: Any) -> Optional[V]: