mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-27 00:48:45 +00:00
community: Use _AstraDBCollectionEnvironment in AstraDB VectorStore (community) (#17635)
Another PR will be done for the langchain-astradb package. Note: for future PRs, devs will be done in the partner package only. This one is just to align with the rest of the components in the community package and it fixes a bunch of issues.
This commit is contained in:
parent
0b33abc8b1
commit
19ebc7418e
@ -1,13 +1,12 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import asyncio
|
|
||||||
import uuid
|
import uuid
|
||||||
import warnings
|
import warnings
|
||||||
from asyncio import Task
|
|
||||||
from concurrent.futures import ThreadPoolExecutor
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
from typing import (
|
from typing import (
|
||||||
TYPE_CHECKING,
|
TYPE_CHECKING,
|
||||||
Any,
|
Any,
|
||||||
|
Awaitable,
|
||||||
Callable,
|
Callable,
|
||||||
Dict,
|
Dict,
|
||||||
Iterable,
|
Iterable,
|
||||||
@ -17,17 +16,21 @@ from typing import (
|
|||||||
Tuple,
|
Tuple,
|
||||||
Type,
|
Type,
|
||||||
TypeVar,
|
TypeVar,
|
||||||
|
Union,
|
||||||
)
|
)
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from langchain_core._api.deprecation import deprecated
|
from langchain_core._api.deprecation import deprecated
|
||||||
from langchain_core.documents import Document
|
from langchain_core.documents import Document
|
||||||
from langchain_core.embeddings import Embeddings
|
from langchain_core.embeddings import Embeddings
|
||||||
from langchain_core.runnables import run_in_executor
|
|
||||||
from langchain_core.runnables.utils import gather_with_concurrency
|
from langchain_core.runnables.utils import gather_with_concurrency
|
||||||
from langchain_core.utils.iter import batch_iterate
|
from langchain_core.utils.iter import batch_iterate
|
||||||
from langchain_core.vectorstores import VectorStore
|
from langchain_core.vectorstores import VectorStore
|
||||||
|
|
||||||
|
from langchain_community.utilities.astradb import (
|
||||||
|
SetupMode,
|
||||||
|
_AstraDBCollectionEnvironment,
|
||||||
|
)
|
||||||
from langchain_community.vectorstores.utils import maximal_marginal_relevance
|
from langchain_community.vectorstores.utils import maximal_marginal_relevance
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
@ -167,28 +170,12 @@ class AstraDB(VectorStore):
|
|||||||
bulk_insert_batch_concurrency: Optional[int] = None,
|
bulk_insert_batch_concurrency: Optional[int] = None,
|
||||||
bulk_insert_overwrite_concurrency: Optional[int] = None,
|
bulk_insert_overwrite_concurrency: Optional[int] = None,
|
||||||
bulk_delete_concurrency: Optional[int] = None,
|
bulk_delete_concurrency: Optional[int] = None,
|
||||||
|
setup_mode: SetupMode = SetupMode.SYNC,
|
||||||
pre_delete_collection: bool = False,
|
pre_delete_collection: bool = False,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Create an AstraDB vector store object. See class docstring for help.
|
Create an AstraDB vector store object. See class docstring for help.
|
||||||
"""
|
"""
|
||||||
try:
|
|
||||||
from astrapy.db import AstraDB as LibAstraDB
|
|
||||||
from astrapy.db import AstraDBCollection
|
|
||||||
except (ImportError, ModuleNotFoundError):
|
|
||||||
raise ImportError(
|
|
||||||
"Could not import a recent astrapy python package. "
|
|
||||||
"Please install it with `pip install --upgrade astrapy`."
|
|
||||||
)
|
|
||||||
|
|
||||||
# Conflicting-arg checks:
|
|
||||||
if astra_db_client is not None or async_astra_db_client is not None:
|
|
||||||
if token is not None or api_endpoint is not None:
|
|
||||||
raise ValueError(
|
|
||||||
"You cannot pass 'astra_db_client' or 'async_astra_db_client' to "
|
|
||||||
"AstraDB if passing 'token' and 'api_endpoint'."
|
|
||||||
)
|
|
||||||
|
|
||||||
self.embedding = embedding
|
self.embedding = embedding
|
||||||
self.collection_name = collection_name
|
self.collection_name = collection_name
|
||||||
self.token = token
|
self.token = token
|
||||||
@ -207,105 +194,35 @@ class AstraDB(VectorStore):
|
|||||||
bulk_delete_concurrency or DEFAULT_BULK_DELETE_CONCURRENCY
|
bulk_delete_concurrency or DEFAULT_BULK_DELETE_CONCURRENCY
|
||||||
)
|
)
|
||||||
# "vector-related" settings
|
# "vector-related" settings
|
||||||
self._embedding_dimension: Optional[int] = None
|
|
||||||
self.metric = metric
|
self.metric = metric
|
||||||
|
embedding_dimension: Union[int, Awaitable[int], None] = None
|
||||||
|
if setup_mode == SetupMode.ASYNC:
|
||||||
|
embedding_dimension = self._aget_embedding_dimension()
|
||||||
|
elif setup_mode == SetupMode.SYNC:
|
||||||
|
embedding_dimension = self._get_embedding_dimension()
|
||||||
|
|
||||||
self.astra_db = astra_db_client
|
self.astra_env = _AstraDBCollectionEnvironment(
|
||||||
self.async_astra_db = async_astra_db_client
|
collection_name=collection_name,
|
||||||
self.collection = None
|
token=token,
|
||||||
self.async_collection = None
|
api_endpoint=api_endpoint,
|
||||||
|
astra_db_client=astra_db_client,
|
||||||
if token and api_endpoint:
|
async_astra_db_client=async_astra_db_client,
|
||||||
self.astra_db = LibAstraDB(
|
namespace=namespace,
|
||||||
token=self.token,
|
setup_mode=setup_mode,
|
||||||
api_endpoint=self.api_endpoint,
|
pre_delete_collection=pre_delete_collection,
|
||||||
namespace=self.namespace,
|
embedding_dimension=embedding_dimension,
|
||||||
|
metric=metric,
|
||||||
)
|
)
|
||||||
try:
|
self.astra_db = self.astra_env.astra_db
|
||||||
from astrapy.db import AsyncAstraDB
|
self.async_astra_db = self.astra_env.async_astra_db
|
||||||
|
self.collection = self.astra_env.collection
|
||||||
self.async_astra_db = AsyncAstraDB(
|
self.async_collection = self.astra_env.async_collection
|
||||||
token=self.token,
|
|
||||||
api_endpoint=self.api_endpoint,
|
|
||||||
namespace=self.namespace,
|
|
||||||
)
|
|
||||||
except (ImportError, ModuleNotFoundError):
|
|
||||||
pass
|
|
||||||
|
|
||||||
if self.astra_db is not None:
|
|
||||||
self.collection = AstraDBCollection(
|
|
||||||
collection_name=self.collection_name,
|
|
||||||
astra_db=self.astra_db,
|
|
||||||
)
|
|
||||||
|
|
||||||
self.async_setup_db_task: Optional[Task] = None
|
|
||||||
if self.async_astra_db is not None:
|
|
||||||
from astrapy.db import AsyncAstraDBCollection
|
|
||||||
|
|
||||||
self.async_collection = AsyncAstraDBCollection(
|
|
||||||
collection_name=self.collection_name,
|
|
||||||
astra_db=self.async_astra_db,
|
|
||||||
)
|
|
||||||
try:
|
|
||||||
self.async_setup_db_task = asyncio.create_task(
|
|
||||||
self._setup_db(pre_delete_collection)
|
|
||||||
)
|
|
||||||
except RuntimeError:
|
|
||||||
pass
|
|
||||||
|
|
||||||
if self.async_setup_db_task is None:
|
|
||||||
if not pre_delete_collection:
|
|
||||||
self._provision_collection()
|
|
||||||
else:
|
|
||||||
self.clear()
|
|
||||||
|
|
||||||
def _ensure_astra_db_client(self): # type: ignore[no-untyped-def]
|
|
||||||
if not self.astra_db:
|
|
||||||
raise ValueError("Missing AstraDB client")
|
|
||||||
|
|
||||||
async def _setup_db(self, pre_delete_collection: bool) -> None:
|
|
||||||
if pre_delete_collection:
|
|
||||||
await self.async_astra_db.delete_collection( # type: ignore[union-attr]
|
|
||||||
collection_name=self.collection_name,
|
|
||||||
)
|
|
||||||
await self._aprovision_collection()
|
|
||||||
|
|
||||||
async def _ensure_db_setup(self) -> None:
|
|
||||||
if self.async_setup_db_task:
|
|
||||||
await self.async_setup_db_task
|
|
||||||
|
|
||||||
def _get_embedding_dimension(self) -> int:
|
def _get_embedding_dimension(self) -> int:
|
||||||
if self._embedding_dimension is None:
|
return len(self.embedding.embed_query(text="This is a sample sentence."))
|
||||||
self._embedding_dimension = len(
|
|
||||||
self.embedding.embed_query("This is a sample sentence.")
|
|
||||||
)
|
|
||||||
return self._embedding_dimension
|
|
||||||
|
|
||||||
def _provision_collection(self) -> None:
|
async def _aget_embedding_dimension(self) -> int:
|
||||||
"""
|
return len(await self.embedding.aembed_query(text="This is a sample sentence."))
|
||||||
Run the API invocation to create the collection on the backend.
|
|
||||||
|
|
||||||
Internal-usage method, no object members are set,
|
|
||||||
other than working on the underlying actual storage.
|
|
||||||
"""
|
|
||||||
self.astra_db.create_collection( # type: ignore[union-attr]
|
|
||||||
dimension=self._get_embedding_dimension(),
|
|
||||||
collection_name=self.collection_name,
|
|
||||||
metric=self.metric,
|
|
||||||
)
|
|
||||||
|
|
||||||
async def _aprovision_collection(self) -> None:
|
|
||||||
"""
|
|
||||||
Run the API invocation to create the collection on the backend.
|
|
||||||
|
|
||||||
Internal-usage method, no object members are set,
|
|
||||||
other than working on the underlying actual storage.
|
|
||||||
"""
|
|
||||||
await self.async_astra_db.create_collection( # type: ignore[union-attr]
|
|
||||||
dimension=self._get_embedding_dimension(),
|
|
||||||
collection_name=self.collection_name,
|
|
||||||
metric=self.metric,
|
|
||||||
)
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def embeddings(self) -> Embeddings:
|
def embeddings(self) -> Embeddings:
|
||||||
@ -326,14 +243,12 @@ class AstraDB(VectorStore):
|
|||||||
|
|
||||||
def clear(self) -> None:
|
def clear(self) -> None:
|
||||||
"""Empty the collection of all its stored entries."""
|
"""Empty the collection of all its stored entries."""
|
||||||
self.delete_collection()
|
self.astra_env.ensure_db_setup()
|
||||||
self._provision_collection()
|
self.collection.delete_many({})
|
||||||
|
|
||||||
async def aclear(self) -> None:
|
async def aclear(self) -> None:
|
||||||
"""Empty the collection of all its stored entries."""
|
"""Empty the collection of all its stored entries."""
|
||||||
await self._ensure_db_setup()
|
await self.astra_env.aensure_db_setup()
|
||||||
if not self.async_astra_db:
|
|
||||||
await run_in_executor(None, self.clear)
|
|
||||||
await self.async_collection.delete_many({}) # type: ignore[union-attr]
|
await self.async_collection.delete_many({}) # type: ignore[union-attr]
|
||||||
|
|
||||||
def delete_by_document_id(self, document_id: str) -> bool:
|
def delete_by_document_id(self, document_id: str) -> bool:
|
||||||
@ -341,7 +256,7 @@ class AstraDB(VectorStore):
|
|||||||
Remove a single document from the store, given its document_id (str).
|
Remove a single document from the store, given its document_id (str).
|
||||||
Return True if a document has indeed been deleted, False if ID not found.
|
Return True if a document has indeed been deleted, False if ID not found.
|
||||||
"""
|
"""
|
||||||
self._ensure_astra_db_client()
|
self.astra_env.ensure_db_setup()
|
||||||
deletion_response = self.collection.delete_one(document_id) # type: ignore[union-attr]
|
deletion_response = self.collection.delete_one(document_id) # type: ignore[union-attr]
|
||||||
return ((deletion_response or {}).get("status") or {}).get(
|
return ((deletion_response or {}).get("status") or {}).get(
|
||||||
"deletedCount", 0
|
"deletedCount", 0
|
||||||
@ -352,9 +267,7 @@ class AstraDB(VectorStore):
|
|||||||
Remove a single document from the store, given its document_id (str).
|
Remove a single document from the store, given its document_id (str).
|
||||||
Return True if a document has indeed been deleted, False if ID not found.
|
Return True if a document has indeed been deleted, False if ID not found.
|
||||||
"""
|
"""
|
||||||
await self._ensure_db_setup()
|
await self.astra_env.aensure_db_setup()
|
||||||
if not self.async_collection:
|
|
||||||
return await run_in_executor(None, self.delete_by_document_id, document_id)
|
|
||||||
deletion_response = await self.async_collection.delete_one(document_id)
|
deletion_response = await self.async_collection.delete_one(document_id)
|
||||||
return ((deletion_response or {}).get("status") or {}).get(
|
return ((deletion_response or {}).get("status") or {}).get(
|
||||||
"deletedCount", 0
|
"deletedCount", 0
|
||||||
@ -439,8 +352,8 @@ class AstraDB(VectorStore):
|
|||||||
Stored data is lost and unrecoverable, resources are freed.
|
Stored data is lost and unrecoverable, resources are freed.
|
||||||
Use with caution.
|
Use with caution.
|
||||||
"""
|
"""
|
||||||
self._ensure_astra_db_client()
|
self.astra_env.ensure_db_setup()
|
||||||
self.astra_db.delete_collection( # type: ignore[union-attr]
|
self.astra_db.delete_collection(
|
||||||
collection_name=self.collection_name,
|
collection_name=self.collection_name,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -451,10 +364,8 @@ class AstraDB(VectorStore):
|
|||||||
Stored data is lost and unrecoverable, resources are freed.
|
Stored data is lost and unrecoverable, resources are freed.
|
||||||
Use with caution.
|
Use with caution.
|
||||||
"""
|
"""
|
||||||
await self._ensure_db_setup()
|
await self.astra_env.aensure_db_setup()
|
||||||
if not self.async_astra_db:
|
await self.async_astra_db.delete_collection(
|
||||||
await run_in_executor(None, self.delete_collection)
|
|
||||||
await self.async_astra_db.delete_collection( # type: ignore[union-attr]
|
|
||||||
collection_name=self.collection_name,
|
collection_name=self.collection_name,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -569,7 +480,7 @@ class AstraDB(VectorStore):
|
|||||||
f"unsupported arguments ({', '.join(sorted(kwargs.keys()))}), "
|
f"unsupported arguments ({', '.join(sorted(kwargs.keys()))}), "
|
||||||
"which will be ignored."
|
"which will be ignored."
|
||||||
)
|
)
|
||||||
self._ensure_astra_db_client()
|
self.astra_env.ensure_db_setup()
|
||||||
|
|
||||||
embedding_vectors = self.embedding.embed_documents(list(texts))
|
embedding_vectors = self.embedding.embed_documents(list(texts))
|
||||||
documents_to_insert = self._get_documents_to_insert(
|
documents_to_insert = self._get_documents_to_insert(
|
||||||
@ -655,22 +566,13 @@ class AstraDB(VectorStore):
|
|||||||
Returns:
|
Returns:
|
||||||
List[str]: List of ids of the added texts.
|
List[str]: List of ids of the added texts.
|
||||||
"""
|
"""
|
||||||
await self._ensure_db_setup()
|
|
||||||
if not self.async_collection:
|
|
||||||
await super().aadd_texts(
|
|
||||||
texts,
|
|
||||||
metadatas,
|
|
||||||
ids=ids,
|
|
||||||
batch_size=batch_size,
|
|
||||||
batch_concurrency=batch_concurrency,
|
|
||||||
overwrite_concurrency=overwrite_concurrency,
|
|
||||||
)
|
|
||||||
if kwargs:
|
if kwargs:
|
||||||
warnings.warn(
|
warnings.warn(
|
||||||
"Method 'aadd_texts' of AstraDB vector store invoked with "
|
"Method 'aadd_texts' of AstraDB vector store invoked with "
|
||||||
f"unsupported arguments ({', '.join(sorted(kwargs.keys()))}), "
|
f"unsupported arguments ({', '.join(sorted(kwargs.keys()))}), "
|
||||||
"which will be ignored."
|
"which will be ignored."
|
||||||
)
|
)
|
||||||
|
await self.astra_env.aensure_db_setup()
|
||||||
|
|
||||||
embedding_vectors = await self.embedding.aembed_documents(list(texts))
|
embedding_vectors = await self.embedding.aembed_documents(list(texts))
|
||||||
documents_to_insert = self._get_documents_to_insert(
|
documents_to_insert = self._get_documents_to_insert(
|
||||||
@ -731,7 +633,7 @@ class AstraDB(VectorStore):
|
|||||||
Returns:
|
Returns:
|
||||||
List of (Document, score, id), the most similar to the query vector.
|
List of (Document, score, id), the most similar to the query vector.
|
||||||
"""
|
"""
|
||||||
self._ensure_astra_db_client()
|
self.astra_env.ensure_db_setup()
|
||||||
metadata_parameter = self._filter_to_metadata(filter)
|
metadata_parameter = self._filter_to_metadata(filter)
|
||||||
#
|
#
|
||||||
hits = list(
|
hits = list(
|
||||||
@ -773,15 +675,7 @@ class AstraDB(VectorStore):
|
|||||||
Returns:
|
Returns:
|
||||||
List of (Document, score, id), the most similar to the query vector.
|
List of (Document, score, id), the most similar to the query vector.
|
||||||
"""
|
"""
|
||||||
await self._ensure_db_setup()
|
await self.astra_env.aensure_db_setup()
|
||||||
if not self.async_collection:
|
|
||||||
return await run_in_executor(
|
|
||||||
None,
|
|
||||||
self.asimilarity_search_with_score_id_by_vector, # type: ignore[arg-type]
|
|
||||||
embedding,
|
|
||||||
k,
|
|
||||||
filter,
|
|
||||||
)
|
|
||||||
metadata_parameter = self._filter_to_metadata(filter)
|
metadata_parameter = self._filter_to_metadata(filter)
|
||||||
#
|
#
|
||||||
return [
|
return [
|
||||||
@ -1010,7 +904,7 @@ class AstraDB(VectorStore):
|
|||||||
Returns:
|
Returns:
|
||||||
List of Documents selected by maximal marginal relevance.
|
List of Documents selected by maximal marginal relevance.
|
||||||
"""
|
"""
|
||||||
self._ensure_astra_db_client()
|
self.astra_env.ensure_db_setup()
|
||||||
metadata_parameter = self._filter_to_metadata(filter)
|
metadata_parameter = self._filter_to_metadata(filter)
|
||||||
|
|
||||||
prefetch_hits = list(
|
prefetch_hits = list(
|
||||||
@ -1051,18 +945,7 @@ class AstraDB(VectorStore):
|
|||||||
Returns:
|
Returns:
|
||||||
List of Documents selected by maximal marginal relevance.
|
List of Documents selected by maximal marginal relevance.
|
||||||
"""
|
"""
|
||||||
await self._ensure_db_setup()
|
await self.astra_env.aensure_db_setup()
|
||||||
if not self.async_collection:
|
|
||||||
return await run_in_executor(
|
|
||||||
None,
|
|
||||||
self.max_marginal_relevance_search_by_vector,
|
|
||||||
embedding,
|
|
||||||
k,
|
|
||||||
fetch_k,
|
|
||||||
lambda_mult,
|
|
||||||
filter,
|
|
||||||
**kwargs,
|
|
||||||
)
|
|
||||||
metadata_parameter = self._filter_to_metadata(filter)
|
metadata_parameter = self._filter_to_metadata(filter)
|
||||||
|
|
||||||
prefetch_hits = [
|
prefetch_hits = [
|
||||||
|
Loading…
Reference in New Issue
Block a user