mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-05 11:12:47 +00:00
community: Factorize AstraDB components constructors (#16779)
* Adds `AstraDBEnvironment` class and use it in `AstraDBLoader`, `AstraDBCache`, `AstraDBSemanticCache`, `AstraDBBaseStore` and `AstraDBChatMessageHistory` * Create an `AsyncAstraDB` if we only have an `AstraDB` and vice-versa so: * we always have an instance of `AstraDB` * we always have an instance of `AsyncAstraDB` for recent versions of astrapy * Create collection if not exists in `AstraDBBaseStore` * Some typing improvements Note: `AstraDB` `VectorStore` not using `AstraDBEnvironment` at the moment. This will be done after the `langchain-astradb` package is out.
This commit is contained in:
parent
93366861c7
commit
9d458d089a
@ -60,12 +60,14 @@ from langchain_core.load.load import loads
|
|||||||
from langchain_core.outputs import ChatGeneration, Generation
|
from langchain_core.outputs import ChatGeneration, Generation
|
||||||
from langchain_core.utils import get_from_env
|
from langchain_core.utils import get_from_env
|
||||||
|
|
||||||
|
from langchain_community.utilities.astradb import AstraDBEnvironment
|
||||||
from langchain_community.vectorstores.redis import Redis as RedisVectorstore
|
from langchain_community.vectorstores.redis import Redis as RedisVectorstore
|
||||||
|
|
||||||
logger = logging.getLogger(__file__)
|
logger = logging.getLogger(__file__)
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
import momento
|
import momento
|
||||||
|
from astrapy.db import AstraDB
|
||||||
from cassandra.cluster import Session as CassandraSession
|
from cassandra.cluster import Session as CassandraSession
|
||||||
|
|
||||||
|
|
||||||
@ -1262,7 +1264,7 @@ class AstraDBCache(BaseCache):
|
|||||||
collection_name: str = ASTRA_DB_CACHE_DEFAULT_COLLECTION_NAME,
|
collection_name: str = ASTRA_DB_CACHE_DEFAULT_COLLECTION_NAME,
|
||||||
token: Optional[str] = None,
|
token: Optional[str] = None,
|
||||||
api_endpoint: Optional[str] = None,
|
api_endpoint: Optional[str] = None,
|
||||||
astra_db_client: Optional[Any] = None, # 'astrapy.db.AstraDB' if passed
|
astra_db_client: Optional[AstraDB] = None,
|
||||||
namespace: Optional[str] = None,
|
namespace: Optional[str] = None,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
@ -1278,39 +1280,17 @@ class AstraDBCache(BaseCache):
|
|||||||
namespace (Optional[str]): namespace (aka keyspace) where the
|
namespace (Optional[str]): namespace (aka keyspace) where the
|
||||||
collection is created. Defaults to the database's "default namespace".
|
collection is created. Defaults to the database's "default namespace".
|
||||||
"""
|
"""
|
||||||
try:
|
astra_env = AstraDBEnvironment(
|
||||||
from astrapy.db import (
|
token=token,
|
||||||
AstraDB as LibAstraDB,
|
api_endpoint=api_endpoint,
|
||||||
)
|
astra_db_client=astra_db_client,
|
||||||
except (ImportError, ModuleNotFoundError):
|
namespace=namespace,
|
||||||
raise ImportError(
|
|
||||||
"Could not import a recent astrapy python package. "
|
|
||||||
"Please install it with `pip install --upgrade astrapy`."
|
|
||||||
)
|
|
||||||
# Conflicting-arg checks:
|
|
||||||
if astra_db_client is not None:
|
|
||||||
if token is not None or api_endpoint is not None:
|
|
||||||
raise ValueError(
|
|
||||||
"You cannot pass 'astra_db_client' to AstraDB if passing "
|
|
||||||
"'token' and 'api_endpoint'."
|
|
||||||
)
|
|
||||||
|
|
||||||
self.collection_name = collection_name
|
|
||||||
self.token = token
|
|
||||||
self.api_endpoint = api_endpoint
|
|
||||||
self.namespace = namespace
|
|
||||||
|
|
||||||
if astra_db_client is not None:
|
|
||||||
self.astra_db = astra_db_client
|
|
||||||
else:
|
|
||||||
self.astra_db = LibAstraDB(
|
|
||||||
token=self.token,
|
|
||||||
api_endpoint=self.api_endpoint,
|
|
||||||
namespace=self.namespace,
|
|
||||||
)
|
|
||||||
self.collection = self.astra_db.create_collection(
|
|
||||||
collection_name=self.collection_name,
|
|
||||||
)
|
)
|
||||||
|
self.astra_db = astra_env.astra_db
|
||||||
|
self.collection = self.astra_db.create_collection(
|
||||||
|
collection_name=collection_name,
|
||||||
|
)
|
||||||
|
self.collection_name = collection_name
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _make_id(prompt: str, llm_string: str) -> str:
|
def _make_id(prompt: str, llm_string: str) -> str:
|
||||||
@ -1364,7 +1344,7 @@ class AstraDBCache(BaseCache):
|
|||||||
def delete(self, prompt: str, llm_string: str) -> None:
|
def delete(self, prompt: str, llm_string: str) -> None:
|
||||||
"""Evict from cache if there's an entry."""
|
"""Evict from cache if there's an entry."""
|
||||||
doc_id = self._make_id(prompt, llm_string)
|
doc_id = self._make_id(prompt, llm_string)
|
||||||
return self.collection.delete_one(doc_id)
|
self.collection.delete_one(doc_id)
|
||||||
|
|
||||||
def clear(self, **kwargs: Any) -> None:
|
def clear(self, **kwargs: Any) -> None:
|
||||||
"""Clear cache. This is for all LLMs at once."""
|
"""Clear cache. This is for all LLMs at once."""
|
||||||
@ -1395,7 +1375,7 @@ class AstraDBSemanticCache(BaseCache):
|
|||||||
collection_name: str = ASTRA_DB_CACHE_DEFAULT_COLLECTION_NAME,
|
collection_name: str = ASTRA_DB_CACHE_DEFAULT_COLLECTION_NAME,
|
||||||
token: Optional[str] = None,
|
token: Optional[str] = None,
|
||||||
api_endpoint: Optional[str] = None,
|
api_endpoint: Optional[str] = None,
|
||||||
astra_db_client: Optional[Any] = None, # 'astrapy.db.AstraDB' if passed
|
astra_db_client: Optional[AstraDB] = None,
|
||||||
namespace: Optional[str] = None,
|
namespace: Optional[str] = None,
|
||||||
embedding: Embeddings,
|
embedding: Embeddings,
|
||||||
metric: Optional[str] = None,
|
metric: Optional[str] = None,
|
||||||
@ -1423,22 +1403,13 @@ class AstraDBSemanticCache(BaseCache):
|
|||||||
The default score threshold is tuned to the default metric.
|
The default score threshold is tuned to the default metric.
|
||||||
Tune it carefully yourself if switching to another distance metric.
|
Tune it carefully yourself if switching to another distance metric.
|
||||||
"""
|
"""
|
||||||
try:
|
astra_env = AstraDBEnvironment(
|
||||||
from astrapy.db import (
|
token=token,
|
||||||
AstraDB as LibAstraDB,
|
api_endpoint=api_endpoint,
|
||||||
)
|
astra_db_client=astra_db_client,
|
||||||
except (ImportError, ModuleNotFoundError):
|
namespace=namespace,
|
||||||
raise ImportError(
|
)
|
||||||
"Could not import a recent astrapy python package. "
|
self.astra_db = astra_env.astra_db
|
||||||
"Please install it with `pip install --upgrade astrapy`."
|
|
||||||
)
|
|
||||||
# Conflicting-arg checks:
|
|
||||||
if astra_db_client is not None:
|
|
||||||
if token is not None or api_endpoint is not None:
|
|
||||||
raise ValueError(
|
|
||||||
"You cannot pass 'astra_db_client' to AstraDB if passing "
|
|
||||||
"'token' and 'api_endpoint'."
|
|
||||||
)
|
|
||||||
|
|
||||||
self.embedding = embedding
|
self.embedding = embedding
|
||||||
self.metric = metric
|
self.metric = metric
|
||||||
@ -1457,18 +1428,7 @@ class AstraDBSemanticCache(BaseCache):
|
|||||||
self.embedding_dimension = self._get_embedding_dimension()
|
self.embedding_dimension = self._get_embedding_dimension()
|
||||||
|
|
||||||
self.collection_name = collection_name
|
self.collection_name = collection_name
|
||||||
self.token = token
|
|
||||||
self.api_endpoint = api_endpoint
|
|
||||||
self.namespace = namespace
|
|
||||||
|
|
||||||
if astra_db_client is not None:
|
|
||||||
self.astra_db = astra_db_client
|
|
||||||
else:
|
|
||||||
self.astra_db = LibAstraDB(
|
|
||||||
token=self.token,
|
|
||||||
api_endpoint=self.api_endpoint,
|
|
||||||
namespace=self.namespace,
|
|
||||||
)
|
|
||||||
self.collection = self.astra_db.create_collection(
|
self.collection = self.astra_db.create_collection(
|
||||||
collection_name=self.collection_name,
|
collection_name=self.collection_name,
|
||||||
dimension=self.embedding_dimension,
|
dimension=self.embedding_dimension,
|
||||||
|
@ -3,11 +3,12 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import json
|
import json
|
||||||
import time
|
import time
|
||||||
import typing
|
from typing import TYPE_CHECKING, List, Optional
|
||||||
from typing import List, Optional
|
|
||||||
|
|
||||||
if typing.TYPE_CHECKING:
|
from langchain_community.utilities.astradb import AstraDBEnvironment
|
||||||
from astrapy.db import AstraDB as LibAstraDB
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from astrapy.db import AstraDB
|
||||||
|
|
||||||
from langchain_core.chat_history import BaseChatMessageHistory
|
from langchain_core.chat_history import BaseChatMessageHistory
|
||||||
from langchain_core.messages import (
|
from langchain_core.messages import (
|
||||||
@ -42,40 +43,22 @@ class AstraDBChatMessageHistory(BaseChatMessageHistory):
|
|||||||
collection_name: str = DEFAULT_COLLECTION_NAME,
|
collection_name: str = DEFAULT_COLLECTION_NAME,
|
||||||
token: Optional[str] = None,
|
token: Optional[str] = None,
|
||||||
api_endpoint: Optional[str] = None,
|
api_endpoint: Optional[str] = None,
|
||||||
astra_db_client: Optional[LibAstraDB] = None, # type 'astrapy.db.AstraDB'
|
astra_db_client: Optional[AstraDB] = None,
|
||||||
namespace: Optional[str] = None,
|
namespace: Optional[str] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Create an Astra DB chat message history."""
|
"""Create an Astra DB chat message history."""
|
||||||
try:
|
astra_env = AstraDBEnvironment(
|
||||||
from astrapy.db import AstraDB as LibAstraDB
|
token=token,
|
||||||
except (ImportError, ModuleNotFoundError):
|
api_endpoint=api_endpoint,
|
||||||
raise ImportError(
|
astra_db_client=astra_db_client,
|
||||||
"Could not import a recent astrapy python package. "
|
namespace=namespace,
|
||||||
"Please install it with `pip install --upgrade astrapy`."
|
)
|
||||||
)
|
self.astra_db = astra_env.astra_db
|
||||||
|
|
||||||
# Conflicting-arg checks:
|
self.collection = self.astra_db.create_collection(collection_name)
|
||||||
if astra_db_client is not None:
|
|
||||||
if token is not None or api_endpoint is not None:
|
|
||||||
raise ValueError(
|
|
||||||
"You cannot pass 'astra_db_client' to AstraDB if passing "
|
|
||||||
"'token' and 'api_endpoint'."
|
|
||||||
)
|
|
||||||
|
|
||||||
self.session_id = session_id
|
self.session_id = session_id
|
||||||
self.collection_name = collection_name
|
self.collection_name = collection_name
|
||||||
self.token = token
|
|
||||||
self.api_endpoint = api_endpoint
|
|
||||||
self.namespace = namespace
|
|
||||||
if astra_db_client is not None:
|
|
||||||
self.astra_db = astra_db_client
|
|
||||||
else:
|
|
||||||
self.astra_db = LibAstraDB(
|
|
||||||
token=self.token,
|
|
||||||
api_endpoint=self.api_endpoint,
|
|
||||||
namespace=self.namespace,
|
|
||||||
)
|
|
||||||
self.collection = self.astra_db.create_collection(self.collection_name)
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def messages(self) -> List[BaseMessage]: # type: ignore
|
def messages(self) -> List[BaseMessage]: # type: ignore
|
||||||
|
@ -16,8 +16,10 @@ from typing import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
from langchain_core.documents import Document
|
from langchain_core.documents import Document
|
||||||
|
from langchain_core.runnables import run_in_executor
|
||||||
|
|
||||||
from langchain_community.document_loaders.base import BaseLoader
|
from langchain_community.document_loaders.base import BaseLoader
|
||||||
|
from langchain_community.utilities.astradb import AstraDBEnvironment
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from astrapy.db import AstraDB, AsyncAstraDB
|
from astrapy.db import AstraDB, AsyncAstraDB
|
||||||
@ -42,21 +44,15 @@ class AstraDBLoader(BaseLoader):
|
|||||||
nb_prefetched: int = 1000,
|
nb_prefetched: int = 1000,
|
||||||
extraction_function: Callable[[Dict], str] = json.dumps,
|
extraction_function: Callable[[Dict], str] = json.dumps,
|
||||||
) -> None:
|
) -> None:
|
||||||
try:
|
astra_env = AstraDBEnvironment(
|
||||||
from astrapy.db import AstraDB
|
token=token,
|
||||||
except (ImportError, ModuleNotFoundError):
|
api_endpoint=api_endpoint,
|
||||||
raise ImportError(
|
astra_db_client=astra_db_client,
|
||||||
"Could not import a recent astrapy python package. "
|
async_astra_db_client=async_astra_db_client,
|
||||||
"Please install it with `pip install --upgrade astrapy`."
|
namespace=namespace,
|
||||||
)
|
)
|
||||||
|
self.astra_env = astra_env
|
||||||
# Conflicting-arg checks:
|
self.collection = astra_env.astra_db.collection(collection_name)
|
||||||
if astra_db_client is not None or async_astra_db_client is not None:
|
|
||||||
if token is not None or api_endpoint is not None:
|
|
||||||
raise ValueError(
|
|
||||||
"You cannot pass 'astra_db_client' or 'async_astra_db_client' to "
|
|
||||||
"AstraDB if passing 'token' and 'api_endpoint'."
|
|
||||||
)
|
|
||||||
self.collection_name = collection_name
|
self.collection_name = collection_name
|
||||||
self.filter = filter_criteria
|
self.filter = filter_criteria
|
||||||
self.projection = projection
|
self.projection = projection
|
||||||
@ -64,47 +60,11 @@ class AstraDBLoader(BaseLoader):
|
|||||||
self.nb_prefetched = nb_prefetched
|
self.nb_prefetched = nb_prefetched
|
||||||
self.extraction_function = extraction_function
|
self.extraction_function = extraction_function
|
||||||
|
|
||||||
astra_db = astra_db_client
|
|
||||||
async_astra_db = async_astra_db_client
|
|
||||||
|
|
||||||
if token and api_endpoint:
|
|
||||||
astra_db = AstraDB(
|
|
||||||
token=token,
|
|
||||||
api_endpoint=api_endpoint,
|
|
||||||
namespace=namespace,
|
|
||||||
)
|
|
||||||
try:
|
|
||||||
from astrapy.db import AsyncAstraDB
|
|
||||||
|
|
||||||
async_astra_db = AsyncAstraDB(
|
|
||||||
token=token,
|
|
||||||
api_endpoint=api_endpoint,
|
|
||||||
namespace=namespace,
|
|
||||||
)
|
|
||||||
except (ImportError, ModuleNotFoundError):
|
|
||||||
pass
|
|
||||||
if not astra_db and not async_astra_db:
|
|
||||||
raise ValueError(
|
|
||||||
"Must provide 'astra_db_client' or 'async_astra_db_client' or 'token' "
|
|
||||||
"and 'api_endpoint'"
|
|
||||||
)
|
|
||||||
self.collection = astra_db.collection(collection_name) if astra_db else None
|
|
||||||
if async_astra_db:
|
|
||||||
from astrapy.db import AsyncAstraDBCollection
|
|
||||||
|
|
||||||
self.async_collection = AsyncAstraDBCollection(
|
|
||||||
astra_db=async_astra_db, collection_name=collection_name
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
self.async_collection = None
|
|
||||||
|
|
||||||
def load(self) -> List[Document]:
|
def load(self) -> List[Document]:
|
||||||
"""Eagerly load the content."""
|
"""Eagerly load the content."""
|
||||||
return list(self.lazy_load())
|
return list(self.lazy_load())
|
||||||
|
|
||||||
def lazy_load(self) -> Iterator[Document]:
|
def lazy_load(self) -> Iterator[Document]:
|
||||||
if not self.collection:
|
|
||||||
raise ValueError("Missing AstraDB client")
|
|
||||||
queue = Queue(self.nb_prefetched)
|
queue = Queue(self.nb_prefetched)
|
||||||
t = threading.Thread(target=self.fetch_results, args=(queue,))
|
t = threading.Thread(target=self.fetch_results, args=(queue,))
|
||||||
t.start()
|
t.start()
|
||||||
@ -120,9 +80,27 @@ class AstraDBLoader(BaseLoader):
|
|||||||
return [doc async for doc in self.alazy_load()]
|
return [doc async for doc in self.alazy_load()]
|
||||||
|
|
||||||
async def alazy_load(self) -> AsyncIterator[Document]:
|
async def alazy_load(self) -> AsyncIterator[Document]:
|
||||||
if not self.async_collection:
|
if not self.astra_env.async_astra_db:
|
||||||
raise ValueError("Missing AsyncAstraDB client")
|
iterator = run_in_executor(
|
||||||
async for doc in self.async_collection.paginated_find(
|
None,
|
||||||
|
self.collection.paginated_find,
|
||||||
|
filter=self.filter,
|
||||||
|
options=self.find_options,
|
||||||
|
projection=self.projection,
|
||||||
|
sort=None,
|
||||||
|
prefetched=True,
|
||||||
|
)
|
||||||
|
done = object()
|
||||||
|
while True:
|
||||||
|
item = await run_in_executor(None, lambda it: next(it, done), iterator)
|
||||||
|
if item is done:
|
||||||
|
break
|
||||||
|
yield item
|
||||||
|
return
|
||||||
|
async_collection = await self.astra_env.async_astra_db.collection(
|
||||||
|
self.collection_name
|
||||||
|
)
|
||||||
|
async for doc in async_collection.paginated_find(
|
||||||
filter=self.filter,
|
filter=self.filter,
|
||||||
options=self.find_options,
|
options=self.find_options,
|
||||||
projection=self.projection,
|
projection=self.projection,
|
||||||
@ -132,8 +110,8 @@ class AstraDBLoader(BaseLoader):
|
|||||||
yield Document(
|
yield Document(
|
||||||
page_content=self.extraction_function(doc),
|
page_content=self.extraction_function(doc),
|
||||||
metadata={
|
metadata={
|
||||||
"namespace": self.async_collection.astra_db.namespace,
|
"namespace": async_collection.astra_db.namespace,
|
||||||
"api_endpoint": self.async_collection.astra_db.base_url,
|
"api_endpoint": async_collection.astra_db.base_url,
|
||||||
"collection": self.collection_name,
|
"collection": self.collection_name,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
@ -1,6 +1,9 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
import base64
|
import base64
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from typing import (
|
from typing import (
|
||||||
|
TYPE_CHECKING,
|
||||||
Any,
|
Any,
|
||||||
Generic,
|
Generic,
|
||||||
Iterator,
|
Iterator,
|
||||||
@ -13,6 +16,11 @@ from typing import (
|
|||||||
|
|
||||||
from langchain_core.stores import BaseStore, ByteStore
|
from langchain_core.stores import BaseStore, ByteStore
|
||||||
|
|
||||||
|
from langchain_community.utilities.astradb import AstraDBEnvironment
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from astrapy.db import AstraDB
|
||||||
|
|
||||||
V = TypeVar("V")
|
V = TypeVar("V")
|
||||||
|
|
||||||
|
|
||||||
@ -22,31 +30,19 @@ class AstraDBBaseStore(Generic[V], BaseStore[str, V], ABC):
|
|||||||
collection_name: str,
|
collection_name: str,
|
||||||
token: Optional[str] = None,
|
token: Optional[str] = None,
|
||||||
api_endpoint: Optional[str] = None,
|
api_endpoint: Optional[str] = None,
|
||||||
astra_db_client: Optional[Any] = None, # 'astrapy.db.AstraDB' if passed
|
astra_db_client: Optional[AstraDB] = None,
|
||||||
namespace: Optional[str] = None,
|
namespace: Optional[str] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
try:
|
astra_env = AstraDBEnvironment(
|
||||||
from astrapy.db import AstraDB, AstraDBCollection
|
|
||||||
except (ImportError, ModuleNotFoundError):
|
|
||||||
raise ImportError(
|
|
||||||
"Could not import a recent astrapy python package. "
|
|
||||||
"Please install it with `pip install --upgrade astrapy`."
|
|
||||||
)
|
|
||||||
|
|
||||||
# Conflicting-arg checks:
|
|
||||||
if astra_db_client is not None:
|
|
||||||
if token is not None or api_endpoint is not None:
|
|
||||||
raise ValueError(
|
|
||||||
"You cannot pass 'astra_db_client' to AstraDB if passing "
|
|
||||||
"'token' and 'api_endpoint'."
|
|
||||||
)
|
|
||||||
|
|
||||||
astra_db = astra_db_client or AstraDB(
|
|
||||||
token=token,
|
token=token,
|
||||||
api_endpoint=api_endpoint,
|
api_endpoint=api_endpoint,
|
||||||
|
astra_db_client=astra_db_client,
|
||||||
namespace=namespace,
|
namespace=namespace,
|
||||||
)
|
)
|
||||||
self.collection = AstraDBCollection(collection_name, astra_db=astra_db)
|
self.astra_db = astra_env.astra_db
|
||||||
|
self.collection = self.astra_db.create_collection(
|
||||||
|
collection_name=collection_name,
|
||||||
|
)
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def decode_value(self, value: Any) -> Optional[V]:
|
def decode_value(self, value: Any) -> Optional[V]:
|
||||||
|
84
libs/community/langchain_community/utilities/astradb.py
Normal file
84
libs/community/langchain_community/utilities/astradb.py
Normal file
@ -0,0 +1,84 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import TYPE_CHECKING, Optional
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from astrapy.db import (
|
||||||
|
AstraDB,
|
||||||
|
AsyncAstraDB,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class AstraDBEnvironment:
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
token: Optional[str] = None,
|
||||||
|
api_endpoint: Optional[str] = None,
|
||||||
|
astra_db_client: Optional[AstraDB] = None,
|
||||||
|
async_astra_db_client: Optional[AsyncAstraDB] = None,
|
||||||
|
namespace: Optional[str] = None,
|
||||||
|
) -> None:
|
||||||
|
self.token = token
|
||||||
|
self.api_endpoint = api_endpoint
|
||||||
|
astra_db = astra_db_client
|
||||||
|
self.async_astra_db = async_astra_db_client
|
||||||
|
self.namespace = namespace
|
||||||
|
|
||||||
|
from astrapy import db
|
||||||
|
|
||||||
|
try:
|
||||||
|
from astrapy.db import AstraDB
|
||||||
|
except (ImportError, ModuleNotFoundError):
|
||||||
|
raise ImportError(
|
||||||
|
"Could not import a recent astrapy python package. "
|
||||||
|
"Please install it with `pip install --upgrade astrapy`."
|
||||||
|
)
|
||||||
|
|
||||||
|
supports_async = hasattr(db, "AsyncAstraDB")
|
||||||
|
|
||||||
|
# Conflicting-arg checks:
|
||||||
|
if astra_db_client is not None or async_astra_db_client is not None:
|
||||||
|
if token is not None or api_endpoint is not None:
|
||||||
|
raise ValueError(
|
||||||
|
"You cannot pass 'astra_db_client' or 'async_astra_db_client' to "
|
||||||
|
"AstraDBEnvironment if passing 'token' and 'api_endpoint'."
|
||||||
|
)
|
||||||
|
|
||||||
|
if token and api_endpoint:
|
||||||
|
astra_db = AstraDB(
|
||||||
|
token=self.token,
|
||||||
|
api_endpoint=self.api_endpoint,
|
||||||
|
namespace=self.namespace,
|
||||||
|
)
|
||||||
|
if supports_async:
|
||||||
|
self.async_astra_db = db.AsyncAstraDB(
|
||||||
|
token=self.token,
|
||||||
|
api_endpoint=self.api_endpoint,
|
||||||
|
namespace=self.namespace,
|
||||||
|
)
|
||||||
|
|
||||||
|
if astra_db:
|
||||||
|
self.astra_db = astra_db
|
||||||
|
else:
|
||||||
|
if self.async_astra_db:
|
||||||
|
self.astra_db = AstraDB(
|
||||||
|
token=self.async_astra_db.token,
|
||||||
|
api_endpoint=self.async_astra_db.base_url,
|
||||||
|
api_path=self.async_astra_db.api_path,
|
||||||
|
api_version=self.async_astra_db.api_version,
|
||||||
|
namespace=self.async_astra_db.namespace,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
"Must provide 'astra_db_client' or 'async_astra_db_client' or "
|
||||||
|
"'token' and 'api_endpoint'"
|
||||||
|
)
|
||||||
|
|
||||||
|
if not self.async_astra_db and self.astra_db and supports_async:
|
||||||
|
self.async_astra_db = db.AsyncAstraDB(
|
||||||
|
token=self.astra_db.token,
|
||||||
|
api_endpoint=self.astra_db.base_url,
|
||||||
|
api_path=self.astra_db.api_path,
|
||||||
|
api_version=self.astra_db.api_version,
|
||||||
|
namespace=self.astra_db.namespace,
|
||||||
|
)
|
Loading…
Reference in New Issue
Block a user