mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-02 01:23:07 +00:00
community: Factorize AstraDB components constructors (#16779)
* Adds `AstraDBEnvironment` class and use it in `AstraDBLoader`, `AstraDBCache`, `AstraDBSemanticCache`, `AstraDBBaseStore` and `AstraDBChatMessageHistory` * Create an `AsyncAstraDB` if we only have an `AstraDB` and vice-versa so: * we always have an instance of `AstraDB` * we always have an instance of `AsyncAstraDB` for recent versions of astrapy * Create collection if not exists in `AstraDBBaseStore` * Some typing improvements Note: `AstraDB` `VectorStore` not using `AstraDBEnvironment` at the moment. This will be done after the `langchain-astradb` package is out.
This commit is contained in:
parent
93366861c7
commit
9d458d089a
@ -60,12 +60,14 @@ from langchain_core.load.load import loads
|
||||
from langchain_core.outputs import ChatGeneration, Generation
|
||||
from langchain_core.utils import get_from_env
|
||||
|
||||
from langchain_community.utilities.astradb import AstraDBEnvironment
|
||||
from langchain_community.vectorstores.redis import Redis as RedisVectorstore
|
||||
|
||||
logger = logging.getLogger(__file__)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import momento
|
||||
from astrapy.db import AstraDB
|
||||
from cassandra.cluster import Session as CassandraSession
|
||||
|
||||
|
||||
@ -1262,7 +1264,7 @@ class AstraDBCache(BaseCache):
|
||||
collection_name: str = ASTRA_DB_CACHE_DEFAULT_COLLECTION_NAME,
|
||||
token: Optional[str] = None,
|
||||
api_endpoint: Optional[str] = None,
|
||||
astra_db_client: Optional[Any] = None, # 'astrapy.db.AstraDB' if passed
|
||||
astra_db_client: Optional[AstraDB] = None,
|
||||
namespace: Optional[str] = None,
|
||||
):
|
||||
"""
|
||||
@ -1278,39 +1280,17 @@ class AstraDBCache(BaseCache):
|
||||
namespace (Optional[str]): namespace (aka keyspace) where the
|
||||
collection is created. Defaults to the database's "default namespace".
|
||||
"""
|
||||
try:
|
||||
from astrapy.db import (
|
||||
AstraDB as LibAstraDB,
|
||||
)
|
||||
except (ImportError, ModuleNotFoundError):
|
||||
raise ImportError(
|
||||
"Could not import a recent astrapy python package. "
|
||||
"Please install it with `pip install --upgrade astrapy`."
|
||||
)
|
||||
# Conflicting-arg checks:
|
||||
if astra_db_client is not None:
|
||||
if token is not None or api_endpoint is not None:
|
||||
raise ValueError(
|
||||
"You cannot pass 'astra_db_client' to AstraDB if passing "
|
||||
"'token' and 'api_endpoint'."
|
||||
)
|
||||
|
||||
self.collection_name = collection_name
|
||||
self.token = token
|
||||
self.api_endpoint = api_endpoint
|
||||
self.namespace = namespace
|
||||
|
||||
if astra_db_client is not None:
|
||||
self.astra_db = astra_db_client
|
||||
else:
|
||||
self.astra_db = LibAstraDB(
|
||||
token=self.token,
|
||||
api_endpoint=self.api_endpoint,
|
||||
namespace=self.namespace,
|
||||
)
|
||||
self.collection = self.astra_db.create_collection(
|
||||
collection_name=self.collection_name,
|
||||
astra_env = AstraDBEnvironment(
|
||||
token=token,
|
||||
api_endpoint=api_endpoint,
|
||||
astra_db_client=astra_db_client,
|
||||
namespace=namespace,
|
||||
)
|
||||
self.astra_db = astra_env.astra_db
|
||||
self.collection = self.astra_db.create_collection(
|
||||
collection_name=collection_name,
|
||||
)
|
||||
self.collection_name = collection_name
|
||||
|
||||
@staticmethod
|
||||
def _make_id(prompt: str, llm_string: str) -> str:
|
||||
@ -1364,7 +1344,7 @@ class AstraDBCache(BaseCache):
|
||||
def delete(self, prompt: str, llm_string: str) -> None:
|
||||
"""Evict from cache if there's an entry."""
|
||||
doc_id = self._make_id(prompt, llm_string)
|
||||
return self.collection.delete_one(doc_id)
|
||||
self.collection.delete_one(doc_id)
|
||||
|
||||
def clear(self, **kwargs: Any) -> None:
|
||||
"""Clear cache. This is for all LLMs at once."""
|
||||
@ -1395,7 +1375,7 @@ class AstraDBSemanticCache(BaseCache):
|
||||
collection_name: str = ASTRA_DB_CACHE_DEFAULT_COLLECTION_NAME,
|
||||
token: Optional[str] = None,
|
||||
api_endpoint: Optional[str] = None,
|
||||
astra_db_client: Optional[Any] = None, # 'astrapy.db.AstraDB' if passed
|
||||
astra_db_client: Optional[AstraDB] = None,
|
||||
namespace: Optional[str] = None,
|
||||
embedding: Embeddings,
|
||||
metric: Optional[str] = None,
|
||||
@ -1423,22 +1403,13 @@ class AstraDBSemanticCache(BaseCache):
|
||||
The default score threshold is tuned to the default metric.
|
||||
Tune it carefully yourself if switching to another distance metric.
|
||||
"""
|
||||
try:
|
||||
from astrapy.db import (
|
||||
AstraDB as LibAstraDB,
|
||||
)
|
||||
except (ImportError, ModuleNotFoundError):
|
||||
raise ImportError(
|
||||
"Could not import a recent astrapy python package. "
|
||||
"Please install it with `pip install --upgrade astrapy`."
|
||||
)
|
||||
# Conflicting-arg checks:
|
||||
if astra_db_client is not None:
|
||||
if token is not None or api_endpoint is not None:
|
||||
raise ValueError(
|
||||
"You cannot pass 'astra_db_client' to AstraDB if passing "
|
||||
"'token' and 'api_endpoint'."
|
||||
)
|
||||
astra_env = AstraDBEnvironment(
|
||||
token=token,
|
||||
api_endpoint=api_endpoint,
|
||||
astra_db_client=astra_db_client,
|
||||
namespace=namespace,
|
||||
)
|
||||
self.astra_db = astra_env.astra_db
|
||||
|
||||
self.embedding = embedding
|
||||
self.metric = metric
|
||||
@ -1457,18 +1428,7 @@ class AstraDBSemanticCache(BaseCache):
|
||||
self.embedding_dimension = self._get_embedding_dimension()
|
||||
|
||||
self.collection_name = collection_name
|
||||
self.token = token
|
||||
self.api_endpoint = api_endpoint
|
||||
self.namespace = namespace
|
||||
|
||||
if astra_db_client is not None:
|
||||
self.astra_db = astra_db_client
|
||||
else:
|
||||
self.astra_db = LibAstraDB(
|
||||
token=self.token,
|
||||
api_endpoint=self.api_endpoint,
|
||||
namespace=self.namespace,
|
||||
)
|
||||
self.collection = self.astra_db.create_collection(
|
||||
collection_name=self.collection_name,
|
||||
dimension=self.embedding_dimension,
|
||||
|
@ -3,11 +3,12 @@ from __future__ import annotations
|
||||
|
||||
import json
|
||||
import time
|
||||
import typing
|
||||
from typing import List, Optional
|
||||
from typing import TYPE_CHECKING, List, Optional
|
||||
|
||||
if typing.TYPE_CHECKING:
|
||||
from astrapy.db import AstraDB as LibAstraDB
|
||||
from langchain_community.utilities.astradb import AstraDBEnvironment
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from astrapy.db import AstraDB
|
||||
|
||||
from langchain_core.chat_history import BaseChatMessageHistory
|
||||
from langchain_core.messages import (
|
||||
@ -42,40 +43,22 @@ class AstraDBChatMessageHistory(BaseChatMessageHistory):
|
||||
collection_name: str = DEFAULT_COLLECTION_NAME,
|
||||
token: Optional[str] = None,
|
||||
api_endpoint: Optional[str] = None,
|
||||
astra_db_client: Optional[LibAstraDB] = None, # type 'astrapy.db.AstraDB'
|
||||
astra_db_client: Optional[AstraDB] = None,
|
||||
namespace: Optional[str] = None,
|
||||
) -> None:
|
||||
"""Create an Astra DB chat message history."""
|
||||
try:
|
||||
from astrapy.db import AstraDB as LibAstraDB
|
||||
except (ImportError, ModuleNotFoundError):
|
||||
raise ImportError(
|
||||
"Could not import a recent astrapy python package. "
|
||||
"Please install it with `pip install --upgrade astrapy`."
|
||||
)
|
||||
astra_env = AstraDBEnvironment(
|
||||
token=token,
|
||||
api_endpoint=api_endpoint,
|
||||
astra_db_client=astra_db_client,
|
||||
namespace=namespace,
|
||||
)
|
||||
self.astra_db = astra_env.astra_db
|
||||
|
||||
# Conflicting-arg checks:
|
||||
if astra_db_client is not None:
|
||||
if token is not None or api_endpoint is not None:
|
||||
raise ValueError(
|
||||
"You cannot pass 'astra_db_client' to AstraDB if passing "
|
||||
"'token' and 'api_endpoint'."
|
||||
)
|
||||
self.collection = self.astra_db.create_collection(collection_name)
|
||||
|
||||
self.session_id = session_id
|
||||
self.collection_name = collection_name
|
||||
self.token = token
|
||||
self.api_endpoint = api_endpoint
|
||||
self.namespace = namespace
|
||||
if astra_db_client is not None:
|
||||
self.astra_db = astra_db_client
|
||||
else:
|
||||
self.astra_db = LibAstraDB(
|
||||
token=self.token,
|
||||
api_endpoint=self.api_endpoint,
|
||||
namespace=self.namespace,
|
||||
)
|
||||
self.collection = self.astra_db.create_collection(self.collection_name)
|
||||
|
||||
@property
|
||||
def messages(self) -> List[BaseMessage]: # type: ignore
|
||||
|
@ -16,8 +16,10 @@ from typing import (
|
||||
)
|
||||
|
||||
from langchain_core.documents import Document
|
||||
from langchain_core.runnables import run_in_executor
|
||||
|
||||
from langchain_community.document_loaders.base import BaseLoader
|
||||
from langchain_community.utilities.astradb import AstraDBEnvironment
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from astrapy.db import AstraDB, AsyncAstraDB
|
||||
@ -42,21 +44,15 @@ class AstraDBLoader(BaseLoader):
|
||||
nb_prefetched: int = 1000,
|
||||
extraction_function: Callable[[Dict], str] = json.dumps,
|
||||
) -> None:
|
||||
try:
|
||||
from astrapy.db import AstraDB
|
||||
except (ImportError, ModuleNotFoundError):
|
||||
raise ImportError(
|
||||
"Could not import a recent astrapy python package. "
|
||||
"Please install it with `pip install --upgrade astrapy`."
|
||||
)
|
||||
|
||||
# Conflicting-arg checks:
|
||||
if astra_db_client is not None or async_astra_db_client is not None:
|
||||
if token is not None or api_endpoint is not None:
|
||||
raise ValueError(
|
||||
"You cannot pass 'astra_db_client' or 'async_astra_db_client' to "
|
||||
"AstraDB if passing 'token' and 'api_endpoint'."
|
||||
)
|
||||
astra_env = AstraDBEnvironment(
|
||||
token=token,
|
||||
api_endpoint=api_endpoint,
|
||||
astra_db_client=astra_db_client,
|
||||
async_astra_db_client=async_astra_db_client,
|
||||
namespace=namespace,
|
||||
)
|
||||
self.astra_env = astra_env
|
||||
self.collection = astra_env.astra_db.collection(collection_name)
|
||||
self.collection_name = collection_name
|
||||
self.filter = filter_criteria
|
||||
self.projection = projection
|
||||
@ -64,47 +60,11 @@ class AstraDBLoader(BaseLoader):
|
||||
self.nb_prefetched = nb_prefetched
|
||||
self.extraction_function = extraction_function
|
||||
|
||||
astra_db = astra_db_client
|
||||
async_astra_db = async_astra_db_client
|
||||
|
||||
if token and api_endpoint:
|
||||
astra_db = AstraDB(
|
||||
token=token,
|
||||
api_endpoint=api_endpoint,
|
||||
namespace=namespace,
|
||||
)
|
||||
try:
|
||||
from astrapy.db import AsyncAstraDB
|
||||
|
||||
async_astra_db = AsyncAstraDB(
|
||||
token=token,
|
||||
api_endpoint=api_endpoint,
|
||||
namespace=namespace,
|
||||
)
|
||||
except (ImportError, ModuleNotFoundError):
|
||||
pass
|
||||
if not astra_db and not async_astra_db:
|
||||
raise ValueError(
|
||||
"Must provide 'astra_db_client' or 'async_astra_db_client' or 'token' "
|
||||
"and 'api_endpoint'"
|
||||
)
|
||||
self.collection = astra_db.collection(collection_name) if astra_db else None
|
||||
if async_astra_db:
|
||||
from astrapy.db import AsyncAstraDBCollection
|
||||
|
||||
self.async_collection = AsyncAstraDBCollection(
|
||||
astra_db=async_astra_db, collection_name=collection_name
|
||||
)
|
||||
else:
|
||||
self.async_collection = None
|
||||
|
||||
def load(self) -> List[Document]:
|
||||
"""Eagerly load the content."""
|
||||
return list(self.lazy_load())
|
||||
|
||||
def lazy_load(self) -> Iterator[Document]:
|
||||
if not self.collection:
|
||||
raise ValueError("Missing AstraDB client")
|
||||
queue = Queue(self.nb_prefetched)
|
||||
t = threading.Thread(target=self.fetch_results, args=(queue,))
|
||||
t.start()
|
||||
@ -120,9 +80,27 @@ class AstraDBLoader(BaseLoader):
|
||||
return [doc async for doc in self.alazy_load()]
|
||||
|
||||
async def alazy_load(self) -> AsyncIterator[Document]:
|
||||
if not self.async_collection:
|
||||
raise ValueError("Missing AsyncAstraDB client")
|
||||
async for doc in self.async_collection.paginated_find(
|
||||
if not self.astra_env.async_astra_db:
|
||||
iterator = run_in_executor(
|
||||
None,
|
||||
self.collection.paginated_find,
|
||||
filter=self.filter,
|
||||
options=self.find_options,
|
||||
projection=self.projection,
|
||||
sort=None,
|
||||
prefetched=True,
|
||||
)
|
||||
done = object()
|
||||
while True:
|
||||
item = await run_in_executor(None, lambda it: next(it, done), iterator)
|
||||
if item is done:
|
||||
break
|
||||
yield item
|
||||
return
|
||||
async_collection = await self.astra_env.async_astra_db.collection(
|
||||
self.collection_name
|
||||
)
|
||||
async for doc in async_collection.paginated_find(
|
||||
filter=self.filter,
|
||||
options=self.find_options,
|
||||
projection=self.projection,
|
||||
@ -132,8 +110,8 @@ class AstraDBLoader(BaseLoader):
|
||||
yield Document(
|
||||
page_content=self.extraction_function(doc),
|
||||
metadata={
|
||||
"namespace": self.async_collection.astra_db.namespace,
|
||||
"api_endpoint": self.async_collection.astra_db.base_url,
|
||||
"namespace": async_collection.astra_db.namespace,
|
||||
"api_endpoint": async_collection.astra_db.base_url,
|
||||
"collection": self.collection_name,
|
||||
},
|
||||
)
|
||||
|
@ -1,6 +1,9 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Generic,
|
||||
Iterator,
|
||||
@ -13,6 +16,11 @@ from typing import (
|
||||
|
||||
from langchain_core.stores import BaseStore, ByteStore
|
||||
|
||||
from langchain_community.utilities.astradb import AstraDBEnvironment
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from astrapy.db import AstraDB
|
||||
|
||||
V = TypeVar("V")
|
||||
|
||||
|
||||
@ -22,31 +30,19 @@ class AstraDBBaseStore(Generic[V], BaseStore[str, V], ABC):
|
||||
collection_name: str,
|
||||
token: Optional[str] = None,
|
||||
api_endpoint: Optional[str] = None,
|
||||
astra_db_client: Optional[Any] = None, # 'astrapy.db.AstraDB' if passed
|
||||
astra_db_client: Optional[AstraDB] = None,
|
||||
namespace: Optional[str] = None,
|
||||
) -> None:
|
||||
try:
|
||||
from astrapy.db import AstraDB, AstraDBCollection
|
||||
except (ImportError, ModuleNotFoundError):
|
||||
raise ImportError(
|
||||
"Could not import a recent astrapy python package. "
|
||||
"Please install it with `pip install --upgrade astrapy`."
|
||||
)
|
||||
|
||||
# Conflicting-arg checks:
|
||||
if astra_db_client is not None:
|
||||
if token is not None or api_endpoint is not None:
|
||||
raise ValueError(
|
||||
"You cannot pass 'astra_db_client' to AstraDB if passing "
|
||||
"'token' and 'api_endpoint'."
|
||||
)
|
||||
|
||||
astra_db = astra_db_client or AstraDB(
|
||||
astra_env = AstraDBEnvironment(
|
||||
token=token,
|
||||
api_endpoint=api_endpoint,
|
||||
astra_db_client=astra_db_client,
|
||||
namespace=namespace,
|
||||
)
|
||||
self.collection = AstraDBCollection(collection_name, astra_db=astra_db)
|
||||
self.astra_db = astra_env.astra_db
|
||||
self.collection = self.astra_db.create_collection(
|
||||
collection_name=collection_name,
|
||||
)
|
||||
|
||||
@abstractmethod
|
||||
def decode_value(self, value: Any) -> Optional[V]:
|
||||
|
84
libs/community/langchain_community/utilities/astradb.py
Normal file
84
libs/community/langchain_community/utilities/astradb.py
Normal file
@ -0,0 +1,84 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from astrapy.db import (
|
||||
AstraDB,
|
||||
AsyncAstraDB,
|
||||
)
|
||||
|
||||
|
||||
class AstraDBEnvironment:
|
||||
def __init__(
|
||||
self,
|
||||
token: Optional[str] = None,
|
||||
api_endpoint: Optional[str] = None,
|
||||
astra_db_client: Optional[AstraDB] = None,
|
||||
async_astra_db_client: Optional[AsyncAstraDB] = None,
|
||||
namespace: Optional[str] = None,
|
||||
) -> None:
|
||||
self.token = token
|
||||
self.api_endpoint = api_endpoint
|
||||
astra_db = astra_db_client
|
||||
self.async_astra_db = async_astra_db_client
|
||||
self.namespace = namespace
|
||||
|
||||
from astrapy import db
|
||||
|
||||
try:
|
||||
from astrapy.db import AstraDB
|
||||
except (ImportError, ModuleNotFoundError):
|
||||
raise ImportError(
|
||||
"Could not import a recent astrapy python package. "
|
||||
"Please install it with `pip install --upgrade astrapy`."
|
||||
)
|
||||
|
||||
supports_async = hasattr(db, "AsyncAstraDB")
|
||||
|
||||
# Conflicting-arg checks:
|
||||
if astra_db_client is not None or async_astra_db_client is not None:
|
||||
if token is not None or api_endpoint is not None:
|
||||
raise ValueError(
|
||||
"You cannot pass 'astra_db_client' or 'async_astra_db_client' to "
|
||||
"AstraDBEnvironment if passing 'token' and 'api_endpoint'."
|
||||
)
|
||||
|
||||
if token and api_endpoint:
|
||||
astra_db = AstraDB(
|
||||
token=self.token,
|
||||
api_endpoint=self.api_endpoint,
|
||||
namespace=self.namespace,
|
||||
)
|
||||
if supports_async:
|
||||
self.async_astra_db = db.AsyncAstraDB(
|
||||
token=self.token,
|
||||
api_endpoint=self.api_endpoint,
|
||||
namespace=self.namespace,
|
||||
)
|
||||
|
||||
if astra_db:
|
||||
self.astra_db = astra_db
|
||||
else:
|
||||
if self.async_astra_db:
|
||||
self.astra_db = AstraDB(
|
||||
token=self.async_astra_db.token,
|
||||
api_endpoint=self.async_astra_db.base_url,
|
||||
api_path=self.async_astra_db.api_path,
|
||||
api_version=self.async_astra_db.api_version,
|
||||
namespace=self.async_astra_db.namespace,
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
"Must provide 'astra_db_client' or 'async_astra_db_client' or "
|
||||
"'token' and 'api_endpoint'"
|
||||
)
|
||||
|
||||
if not self.async_astra_db and self.astra_db and supports_async:
|
||||
self.async_astra_db = db.AsyncAstraDB(
|
||||
token=self.astra_db.token,
|
||||
api_endpoint=self.astra_db.base_url,
|
||||
api_path=self.astra_db.api_path,
|
||||
api_version=self.astra_db.api_version,
|
||||
namespace=self.astra_db.namespace,
|
||||
)
|
Loading…
Reference in New Issue
Block a user