mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-07 14:03:26 +00:00
community: Add async methods to AstraDBCache (#17415)
Adds async methods to AstraDBCache
This commit is contained in:
committed by
GitHub
parent
e438fe6be9
commit
ca2d4078f3
@@ -1,6 +1,10 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
import asyncio
|
||||
import inspect
|
||||
from asyncio import InvalidStateError, Task
|
||||
from enum import Enum
|
||||
from typing import TYPE_CHECKING, Awaitable, Optional, Union
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from astrapy.db import (
|
||||
@@ -9,7 +13,13 @@ if TYPE_CHECKING:
|
||||
)
|
||||
|
||||
|
||||
class AstraDBEnvironment:
|
||||
class SetupMode(Enum):
|
||||
SYNC = 1
|
||||
ASYNC = 2
|
||||
OFF = 3
|
||||
|
||||
|
||||
class _AstraDBEnvironment:
|
||||
def __init__(
|
||||
self,
|
||||
token: Optional[str] = None,
|
||||
@@ -21,21 +31,20 @@ class AstraDBEnvironment:
|
||||
self.token = token
|
||||
self.api_endpoint = api_endpoint
|
||||
astra_db = astra_db_client
|
||||
self.async_astra_db = async_astra_db_client
|
||||
async_astra_db = async_astra_db_client
|
||||
self.namespace = namespace
|
||||
|
||||
from astrapy import db
|
||||
|
||||
try:
|
||||
from astrapy.db import AstraDB
|
||||
from astrapy.db import (
|
||||
AstraDB,
|
||||
AsyncAstraDB,
|
||||
)
|
||||
except (ImportError, ModuleNotFoundError):
|
||||
raise ImportError(
|
||||
"Could not import a recent astrapy python package. "
|
||||
"Please install it with `pip install --upgrade astrapy`."
|
||||
)
|
||||
|
||||
supports_async = hasattr(db, "AsyncAstraDB")
|
||||
|
||||
# Conflicting-arg checks:
|
||||
if astra_db_client is not None or async_astra_db_client is not None:
|
||||
if token is not None or api_endpoint is not None:
|
||||
@@ -46,39 +55,115 @@ class AstraDBEnvironment:
|
||||
|
||||
if token and api_endpoint:
|
||||
astra_db = AstraDB(
|
||||
token=self.token,
|
||||
api_endpoint=self.api_endpoint,
|
||||
token=token,
|
||||
api_endpoint=api_endpoint,
|
||||
namespace=self.namespace,
|
||||
)
|
||||
async_astra_db = AsyncAstraDB(
|
||||
token=token,
|
||||
api_endpoint=api_endpoint,
|
||||
namespace=self.namespace,
|
||||
)
|
||||
if supports_async:
|
||||
self.async_astra_db = db.AsyncAstraDB(
|
||||
token=self.token,
|
||||
api_endpoint=self.api_endpoint,
|
||||
namespace=self.namespace,
|
||||
)
|
||||
|
||||
if astra_db:
|
||||
self.astra_db = astra_db
|
||||
else:
|
||||
if self.async_astra_db:
|
||||
self.astra_db = AstraDB(
|
||||
token=self.async_astra_db.token,
|
||||
api_endpoint=self.async_astra_db.base_url,
|
||||
api_path=self.async_astra_db.api_path,
|
||||
api_version=self.async_astra_db.api_version,
|
||||
namespace=self.async_astra_db.namespace,
|
||||
)
|
||||
if async_astra_db:
|
||||
self.async_astra_db = async_astra_db
|
||||
else:
|
||||
raise ValueError(
|
||||
"Must provide 'astra_db_client' or 'async_astra_db_client' or "
|
||||
"'token' and 'api_endpoint'"
|
||||
self.async_astra_db = AsyncAstraDB(
|
||||
token=self.astra_db.token,
|
||||
api_endpoint=self.astra_db.base_url,
|
||||
api_path=self.astra_db.api_path,
|
||||
api_version=self.astra_db.api_version,
|
||||
namespace=self.astra_db.namespace,
|
||||
)
|
||||
elif async_astra_db:
|
||||
self.async_astra_db = async_astra_db
|
||||
self.astra_db = AstraDB(
|
||||
token=self.async_astra_db.token,
|
||||
api_endpoint=self.async_astra_db.base_url,
|
||||
api_path=self.async_astra_db.api_path,
|
||||
api_version=self.async_astra_db.api_version,
|
||||
namespace=self.async_astra_db.namespace,
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
"Must provide 'astra_db_client' or 'async_astra_db_client' or "
|
||||
"'token' and 'api_endpoint'"
|
||||
)
|
||||
|
||||
|
||||
class _AstraDBCollectionEnvironment(_AstraDBEnvironment):
|
||||
def __init__(
|
||||
self,
|
||||
collection_name: str,
|
||||
token: Optional[str] = None,
|
||||
api_endpoint: Optional[str] = None,
|
||||
astra_db_client: Optional[AstraDB] = None,
|
||||
async_astra_db_client: Optional[AsyncAstraDB] = None,
|
||||
namespace: Optional[str] = None,
|
||||
setup_mode: SetupMode = SetupMode.SYNC,
|
||||
pre_delete_collection: bool = False,
|
||||
embedding_dimension: Union[int, Awaitable[int], None] = None,
|
||||
metric: Optional[str] = None,
|
||||
) -> None:
|
||||
from astrapy.db import AstraDBCollection, AsyncAstraDBCollection
|
||||
|
||||
super().__init__(
|
||||
token, api_endpoint, astra_db_client, async_astra_db_client, namespace
|
||||
)
|
||||
self.collection_name = collection_name
|
||||
self.collection = AstraDBCollection(
|
||||
collection_name=collection_name,
|
||||
astra_db=self.astra_db,
|
||||
)
|
||||
|
||||
self.async_collection = AsyncAstraDBCollection(
|
||||
collection_name=collection_name,
|
||||
astra_db=self.async_astra_db,
|
||||
)
|
||||
|
||||
self.async_setup_db_task: Optional[Task] = None
|
||||
if setup_mode == SetupMode.ASYNC:
|
||||
async_astra_db = self.async_astra_db
|
||||
|
||||
async def _setup_db() -> None:
|
||||
if pre_delete_collection:
|
||||
await async_astra_db.delete_collection(collection_name)
|
||||
if inspect.isawaitable(embedding_dimension):
|
||||
dimension = await embedding_dimension
|
||||
else:
|
||||
dimension = embedding_dimension
|
||||
await async_astra_db.create_collection(
|
||||
collection_name, dimension=dimension, metric=metric
|
||||
)
|
||||
|
||||
if not self.async_astra_db and self.astra_db and supports_async:
|
||||
self.async_astra_db = db.AsyncAstraDB(
|
||||
token=self.astra_db.token,
|
||||
api_endpoint=self.astra_db.base_url,
|
||||
api_path=self.astra_db.api_path,
|
||||
api_version=self.astra_db.api_version,
|
||||
namespace=self.astra_db.namespace,
|
||||
self.async_setup_db_task = asyncio.create_task(_setup_db())
|
||||
elif setup_mode == SetupMode.SYNC:
|
||||
if pre_delete_collection:
|
||||
self.astra_db.delete_collection(collection_name)
|
||||
if inspect.isawaitable(embedding_dimension):
|
||||
raise ValueError(
|
||||
"Cannot use an awaitable embedding_dimension with async_setup "
|
||||
"set to False"
|
||||
)
|
||||
self.astra_db.create_collection(
|
||||
collection_name,
|
||||
dimension=embedding_dimension, # type: ignore[arg-type]
|
||||
metric=metric,
|
||||
)
|
||||
|
||||
def ensure_db_setup(self) -> None:
|
||||
if self.async_setup_db_task:
|
||||
try:
|
||||
self.async_setup_db_task.result()
|
||||
except InvalidStateError:
|
||||
raise ValueError(
|
||||
"Asynchronous setup of the DB not finished. "
|
||||
"NB: AstraDB components sync methods shouldn't be called from the "
|
||||
"event loop. Consider using their async equivalents."
|
||||
)
|
||||
|
||||
async def aensure_db_setup(self) -> None:
|
||||
if self.async_setup_db_task:
|
||||
await self.async_setup_db_task
|
||||
|
Reference in New Issue
Block a user