FEATURE: Add Databricks Vector Search as a new vector store (#13621)
**Description:** This PR adds Databricks Vector Search as a new vector store in LangChain.

- [x] Add `DatabricksVectorSearch` in `langchain/vectorstores/`
- [x] Unit tests
- [x] Add [`databricks-vectorsearch`](https://pypi.org/project/databricks-vectorsearch/) as a new optional dependency

We ran the following checks:

- `make format` passed ✅
- `make lint` failed, but the failures were caused by other files; the files touched by this PR passed the linter ✅
- `make test` passed ✅
- `make coverage` failed, but the failures were caused by other files; tests added by or related to this PR all passed, and test coverage of `langchain/vectorstores/databricks_vector_search.py` is 94% ✅
- `make spell_check` passed ✅

The example notebook and updates to the [provider's documentation page](https://github.com/langchain-ai/langchain/blob/master/docs/docs/integrations/providers/databricks.md) will be added later in a separate PR.

**Dependencies:** [`databricks-vectorsearch`](https://pypi.org/project/databricks-vectorsearch/) (optional)

---------

Co-authored-by: Bagatur <baskaryan@gmail.com>
This commit is contained in:
parent
25387db432
commit
4b8e053fe8
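
For context, a minimal usage sketch of the new integration. The endpoint and index names are placeholders taken from the class docstring in the diff below, and a direct-access index is assumed so that `add_texts` is available:

```python
from databricks.vector_search.client import VectorSearchClient
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.vectorstores import DatabricksVectorSearch

# Connect to an existing Vector Search index (names are placeholders).
vs_client = VectorSearchClient()
vs_index = vs_client.get_index(
    endpoint_name="vs_endpoint",
    index_name="ml.llm.index",
)

# For a direct-access index, both the embedding model and the text column
# are required (see the class docstring in the diff below).
vectorstore = DatabricksVectorSearch(
    vs_index,
    embedding=OpenAIEmbeddings(),
    text_column="document_content",
)

# Ingest a couple of documents, then query.
vectorstore.add_texts(texts=["text1", "text2"])
docs = vectorstore.similarity_search("text1", k=1)
```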
@@ -60,7 +60,7 @@
     " * document addition by id (`add_documents` method with `ids` argument)\n",
     " * delete by id (`delete` method with `ids` argument)\n",
     "\n",
-    "Compatible Vectorstores: `AnalyticDB`, `AstraDB`, `AwaDB`, `Bagel`, `Cassandra`, `Chroma`, `DashVector`, `DeepLake`, `Dingo`, `ElasticVectorSearch`, `ElasticsearchStore`, `FAISS`, `MyScale`, `PGVector`, `Pinecone`, `Qdrant`, `Redis`, `ScaNN`, `SupabaseVectorStore`, `TimescaleVector`, `Vald`, `Vearch`, `VespaStore`, `Weaviate`, `ZepVectorStore`.\n",
+    "Compatible Vectorstores: `AnalyticDB`, `AstraDB`, `AwaDB`, `Bagel`, `Cassandra`, `Chroma`, `DashVector`, `DatabricksVectorSearch`, `DeepLake`, `Dingo`, `ElasticVectorSearch`, `ElasticsearchStore`, `FAISS`, `MyScale`, `PGVector`, `Pinecone`, `Qdrant`, `Redis`, `ScaNN`, `SupabaseVectorStore`, `TimescaleVector`, `Vald`, `Vearch`, `VespaStore`, `Weaviate`, `ZepVectorStore`.\n",
     " \n",
     "## Caution\n",
     "\n",
@@ -140,6 +140,12 @@ def _import_dashvector() -> Any:
     return DashVector
 
 
+def _import_databricks_vector_search() -> Any:
+    from langchain.vectorstores.databricks_vector_search import DatabricksVectorSearch
+
+    return DatabricksVectorSearch
+
+
 def _import_deeplake() -> Any:
     from langchain.vectorstores.deeplake import DeepLake
 
@@ -461,6 +467,8 @@ def __getattr__(name: str) -> Any:
         return _import_clickhouse()
     elif name == "DashVector":
         return _import_dashvector()
+    elif name == "DatabricksVectorSearch":
+        return _import_databricks_vector_search()
     elif name == "DeepLake":
         return _import_deeplake()
     elif name == "Dingo":
 
@@ -575,6 +583,7 @@ __all__ = [
     "Clickhouse",
     "ClickhouseSettings",
     "DashVector",
+    "DatabricksVectorSearch",
     "DeepLake",
     "Dingo",
     "DocArrayHnswSearch",
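
The `__getattr__` hook above keeps the integration's import lazy. A small sketch of what this enables: accessing the attribute triggers `_import_databricks_vector_search()`, while the `databricks-vectorsearch` package itself is only imported inside `DatabricksVectorSearch.__init__` (see the new module below), so everything up to constructing the store works without the optional dependency installed.

```python
import langchain.vectorstores as vectorstores

# This attribute lookup is resolved lazily by the module-level __getattr__;
# importing langchain.vectorstores alone never requires databricks-vectorsearch.
DatabricksVectorSearch = vectorstores.DatabricksVectorSearch
```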
@@ -0,0 +1,473 @@
from __future__ import annotations

import json
import logging
import uuid
from typing import TYPE_CHECKING, Any, Iterable, List, Optional, Tuple, Type

from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings
from langchain_core.vectorstores import VST, VectorStore

if TYPE_CHECKING:
    from databricks.vector_search.client import VectorSearchIndex

logger = logging.getLogger(__name__)


class DatabricksVectorSearch(VectorStore):
    """`Databricks Vector Search` vector store.

    To use, you should have the ``databricks-vectorsearch`` python package installed.

    Example:
        .. code-block:: python

            from langchain.vectorstores import DatabricksVectorSearch
            from databricks.vector_search.client import VectorSearchClient

            vs_client = VectorSearchClient()
            vs_index = vs_client.get_index(
                endpoint_name="vs_endpoint",
                index_name="ml.llm.index"
            )
            vectorstore = DatabricksVectorSearch(vs_index)

    Args:
        index: A Databricks Vector Search index object.
        embedding: The embedding model.
            Required for direct-access index or delta-sync index
            with self-managed embeddings.
        text_column: The name of the text column to use for the embeddings.
            Required for direct-access index or delta-sync index
            with self-managed embeddings.
            Make sure the text column specified is in the index.
        columns: The list of column names to get when doing the search.
            Defaults to ``[primary_key, text_column]``.

    A delta-sync index with Databricks-managed embeddings manages the ingestion,
    deletion, and embedding for you.
    Manual ingestion/deletion of the documents/texts is not supported for delta-sync
    index.

    If you want to use a delta-sync index with self-managed embeddings, you need to
    provide the embedding model and text column name to use for the embeddings.

    Example:
        .. code-block:: python

            from langchain.vectorstores import DatabricksVectorSearch
            from databricks.vector_search.client import VectorSearchClient
            from langchain.embeddings.openai import OpenAIEmbeddings

            vs_client = VectorSearchClient()
            vs_index = vs_client.get_index(
                endpoint_name="vs_endpoint",
                index_name="ml.llm.index"
            )
            vectorstore = DatabricksVectorSearch(
                index=vs_index,
                embedding=OpenAIEmbeddings(),
                text_column="document_content"
            )

    If you want to manage the documents ingestion/deletion yourself, you can use a
    direct-access index.

    Example:
        .. code-block:: python

            from langchain.vectorstores import DatabricksVectorSearch
            from databricks.vector_search.client import VectorSearchClient
            from langchain.embeddings.openai import OpenAIEmbeddings

            vs_client = VectorSearchClient()
            vs_index = vs_client.get_index(
                endpoint_name="vs_endpoint",
                index_name="ml.llm.index"
            )
            vectorstore = DatabricksVectorSearch(
                index=vs_index,
                embedding=OpenAIEmbeddings(),
                text_column="document_content"
            )
            vectorstore.add_texts(
                texts=["text1", "text2"]
            )

    For more information on Databricks Vector Search, see `Databricks Vector Search
    documentation <TODO: pending-link-to-documentation-page>`.

    """

    def __init__(
        self,
        index: VectorSearchIndex,
        *,
        embedding: Optional[Embeddings] = None,
        text_column: Optional[str] = None,
        columns: Optional[List[str]] = None,
    ):
        try:
            from databricks.vector_search.client import VectorSearchIndex
        except ImportError as e:
            raise ImportError(
                "Could not import databricks-vectorsearch python package. "
                "Please install it with `pip install databricks-vectorsearch`."
            ) from e
        # index
        self.index = index
        if not isinstance(index, VectorSearchIndex):
            raise TypeError("index must be of type VectorSearchIndex.")

        # index_details
        index_details = self.index.describe()
        self.primary_key = index_details["primary_key"]
        self.index_type = index_details.get("index_type")
        self._delta_sync_index_spec = index_details.get("delta_sync_index_spec", dict())
        self._direct_access_index_spec = index_details.get(
            "direct_access_index_spec", dict()
        )

        # text_column
        if self._is_databricks_managed_embeddings():
            index_source_column = self._embedding_source_column_name()
            # check if input text column matches the source column of the index
            if text_column is not None and text_column != index_source_column:
                raise ValueError(
                    f"text_column '{text_column}' does not match with the "
                    f"source column of the index: '{index_source_column}'."
                )
            self.text_column = index_source_column
        else:
            self._require_arg(text_column, "text_column")
            self.text_column = text_column

        # columns
        self.columns = columns or []
        # add primary key column and source column if not in columns
        if self.primary_key not in self.columns:
            self.columns.append(self.primary_key)
        if self.text_column and self.text_column not in self.columns:
            self.columns.append(self.text_column)

        # Validate specified columns are in the index
        if self._is_direct_access_index():
            index_schema = self._index_schema()
            if index_schema:
                for col in self.columns:
                    if col not in index_schema:
                        raise ValueError(
                            f"column '{col}' is not in the index's schema."
                        )

        # embedding model
        if not self._is_databricks_managed_embeddings():
            # embedding model is required for direct-access index
            # or delta-sync index with self-managed embedding
            self._require_arg(embedding, "embedding")
            self._embedding = embedding
            # validate dimension matches
            index_embedding_dimension = self._embedding_vector_column_dimension()
            if index_embedding_dimension is not None:
                inferred_embedding_dimension = self._infer_embedding_dimension()
                if inferred_embedding_dimension != index_embedding_dimension:
                    raise ValueError(
                        f"embedding model's dimension '{inferred_embedding_dimension}' "
                        f"does not match with the index's dimension "
                        f"'{index_embedding_dimension}'."
                    )
        else:
            if embedding is not None:
                logger.warning(
                    "embedding model is not used in delta-sync index with "
                    "Databricks-managed embeddings."
                )
            self._embedding = None

    @classmethod
    def from_texts(
        cls: Type[VST],
        texts: List[str],
        embedding: Embeddings,
        metadatas: Optional[List[dict]] = None,
        **kwargs: Any,
    ) -> VST:
        raise NotImplementedError(
            "`from_texts` is not supported. "
            "Use `add_texts` to add to existing direct-access index."
        )

    def add_texts(
        self,
        texts: Iterable[str],
        metadatas: Optional[List[dict]] = None,
        ids: Optional[List[Any]] = None,
        **kwargs: Any,
    ) -> List[str]:
        """Add texts to the index.

        Only supported for direct-access index.

        Args:
            texts: List of texts to add.
            metadatas: List of metadata for each text. Defaults to None.
            ids: List of ids for each text. Defaults to None.
                If not provided, a random uuid will be generated for each text.

        Returns:
            List of ids from adding the texts into the index.
        """
        self._op_require_direct_access_index("add_texts")
        assert self.embeddings is not None, "embedding model is required."
        # Wrap to list if input texts is a single string
        if isinstance(texts, str):
            texts = [texts]
        texts = list(texts)
        vectors = self.embeddings.embed_documents(texts)
        ids = ids or [str(uuid.uuid4()) for _ in texts]
        metadatas = metadatas or [{} for _ in texts]

        updates = [
            {
                self.primary_key: id_,
                self.text_column: text,
                self._embedding_vector_column_name(): vector,
                **metadata,
            }
            for text, vector, id_, metadata in zip(texts, vectors, ids, metadatas)
        ]

        upsert_resp = self.index.upsert(updates)
        if upsert_resp.get("status") in ("PARTIAL_SUCCESS", "FAILURE"):
            failed_ids = upsert_resp.get("result", dict()).get(
                "failed_primary_keys", []
            )
            if upsert_resp.get("status") == "FAILURE":
                logger.error("Failed to add texts to the index.")
            else:
                logger.warning("Some texts failed to be added to the index.")
            return [id_ for id_ in ids if id_ not in failed_ids]

        return ids

    @property
    def embeddings(self) -> Optional[Embeddings]:
        """Access the query embedding object if available."""
        return self._embedding

    def delete(self, ids: Optional[List[Any]] = None, **kwargs: Any) -> Optional[bool]:
        """Delete documents from the index.

        Only supported for direct-access index.

        Args:
            ids: List of ids of documents to delete.

        Returns:
            True if successful.
        """
        self._op_require_direct_access_index("delete")
        if ids is None:
            raise ValueError("ids must be provided.")
        self.index.delete(ids)
        return True

    def similarity_search(
        self, query: str, k: int = 4, filters: Optional[Any] = None, **kwargs: Any
    ) -> List[Document]:
        """Return docs most similar to query.

        Args:
            query: Text to look up documents similar to.
            k: Number of Documents to return. Defaults to 4.
            filters: Filters to apply to the query. Defaults to None.

        Returns:
            List of Documents most similar to the embedding.
        """
        docs_with_score = self.similarity_search_with_score(
            query=query, k=k, filters=filters, **kwargs
        )
        return [doc for doc, _ in docs_with_score]

    def similarity_search_with_score(
        self, query: str, k: int = 4, filters: Optional[Any] = None, **kwargs: Any
    ) -> List[Tuple[Document, float]]:
        """Return docs most similar to query, along with scores.

        Args:
            query: Text to look up documents similar to.
            k: Number of Documents to return. Defaults to 4.
            filters: Filters to apply to the query. Defaults to None.

        Returns:
            List of Documents most similar to the embedding and score for each.
        """
        if self._is_databricks_managed_embeddings():
            query_text = query
            query_vector = None
        else:
            assert self.embeddings is not None, "embedding model is required."
            query_text = None
            query_vector = self.embeddings.embed_query(query)

        search_resp = self.index.similarity_search(
            columns=self.columns,
            query_text=query_text,
            query_vector=query_vector,
            filters=filters,
            num_results=k,
        )
        return self._parse_search_response(search_resp)

    def similarity_search_by_vector(
        self,
        embedding: List[float],
        k: int = 4,
        filters: Optional[Any] = None,
        **kwargs: Any,
    ) -> List[Document]:
        """Return docs most similar to embedding vector.

        Args:
            embedding: Embedding to look up documents similar to.
            k: Number of Documents to return. Defaults to 4.
            filters: Filters to apply to the query. Defaults to None.

        Returns:
            List of Documents most similar to the embedding.
        """
        docs_with_score = self.similarity_search_by_vector_with_score(
            embedding=embedding, k=k, filters=filters, **kwargs
        )
        return [doc for doc, _ in docs_with_score]

    def similarity_search_by_vector_with_score(
        self,
        embedding: List[float],
        k: int = 4,
        filters: Optional[Any] = None,
        **kwargs: Any,
    ) -> List[Tuple[Document, float]]:
        """Return docs most similar to embedding vector, along with scores.

        Args:
            embedding: Embedding to look up documents similar to.
            k: Number of Documents to return. Defaults to 4.
            filters: Filters to apply to the query. Defaults to None.

        Returns:
            List of Documents most similar to the embedding and score for each.
        """
        if self._is_databricks_managed_embeddings():
            raise ValueError(
                "`similarity_search_by_vector` is not supported for index with "
                "Databricks-managed embeddings."
            )
        search_resp = self.index.similarity_search(
            columns=self.columns,
            query_vector=embedding,
            filters=filters,
            num_results=k,
        )
        return self._parse_search_response(search_resp)

    def _parse_search_response(
        self, search_resp: dict
    ) -> List[Tuple[Document, float]]:
        """Parse the search response into a list of Documents with score."""
        columns = [
            col["name"]
            for col in search_resp.get("manifest", dict()).get("columns", [])
        ]
        docs_with_score = []
        for result in search_resp.get("result", dict()).get("data_array", []):
            doc_id = result[columns.index(self.primary_key)]
            text_content = result[columns.index(self.text_column)]
            metadata = {
                col: value
                for col, value in zip(columns[:-1], result[:-1])
                if col not in [self.primary_key, self.text_column]
            }
            metadata[self.primary_key] = doc_id
            score = result[-1]
            doc = Document(page_content=text_content, metadata=metadata)
            docs_with_score.append((doc, score))
        return docs_with_score

    def _index_schema(self) -> Optional[dict]:
        """Return the index schema as a dictionary.
        Return None if no schema found.
        """
        if self._is_direct_access_index():
            schema_json = self._direct_access_index_spec.get("schema_json")
            if schema_json is not None:
                return json.loads(schema_json)
        return None

    def _embedding_vector_column_name(self) -> Optional[str]:
        """Return the name of the embedding vector column.
        None if the index is not a self-managed embedding index.
        """
        return self._embedding_vector_column().get("name")

    def _embedding_vector_column_dimension(self) -> Optional[int]:
        """Return the dimension of the embedding vector column.
        None if the index is not a self-managed embedding index.
        """
        return self._embedding_vector_column().get("embedding_dimension")

    def _embedding_vector_column(self) -> dict:
        """Return the embedding vector column configs as a dictionary.
        Empty if the index is not a self-managed embedding index.
        """
        index_spec = (
            self._delta_sync_index_spec
            if self._is_delta_sync_index()
            else self._direct_access_index_spec
        )
        return next(iter(index_spec.get("embedding_vector_columns") or list()), dict())

    def _embedding_source_column_name(self) -> Optional[str]:
        """Return the name of the embedding source column.
        None if the index is not a Databricks-managed embedding index.
        """
        return self._embedding_source_column().get("name")

    def _embedding_source_column(self) -> dict:
        """Return the embedding source column configs as a dictionary.
        Empty if the index is not a Databricks-managed embedding index.
        """
        index_spec = self._delta_sync_index_spec
        return next(iter(index_spec.get("embedding_source_columns") or list()), dict())

    def _is_delta_sync_index(self) -> bool:
        """Return True if the index is a delta-sync index."""
        return self.index_type == "DELTA_SYNC"

    def _is_direct_access_index(self) -> bool:
        """Return True if the index is a direct-access index."""
        return self.index_type == "DIRECT_ACCESS"

    def _is_databricks_managed_embeddings(self) -> bool:
        """Return True if the embeddings are managed by Databricks Vector Search."""
        return (
            self._is_delta_sync_index()
            and self._embedding_source_column_name() is not None
        )

    def _infer_embedding_dimension(self) -> int:
        """Infer the embedding dimension from the embedding function."""
        assert self.embeddings is not None, "embedding model is required."
        return len(self.embeddings.embed_query("test"))

    def _op_require_direct_access_index(self, op_name: str) -> None:
        """
        Raise ValueError if the operation is not supported for direct-access index."""
        if not self._is_direct_access_index():
            raise ValueError(f"`{op_name}` is only supported for direct-access index.")

    @staticmethod
    def _require_arg(arg: Any, arg_name: str) -> None:
        """Raise ValueError if the required arg with name `arg_name` is None."""
        if not arg:
            raise ValueError(f"`{arg_name}` is required for this index.")
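
To make `_parse_search_response` concrete, here is an illustrative response shape (values invented; the unit tests below build a similar `EXAMPLE_SEARCH_RESPONSE` fixture). The last manifest column carries the score, the primary key is copied into the metadata, and any other non-text column would land in `Document.metadata` as well:

```python
# Illustrative raw response as returned by index.similarity_search, assuming
# primary_key == "id" and text_column == "text".
search_resp = {
    "manifest": {"columns": [{"name": "id"}, {"name": "text"}, {"name": "score"}]},
    "result": {"data_array": [["doc-1", "foo bar", 0.87]]},
}

# _parse_search_response would turn the single row above into:
#   (Document(page_content="foo bar", metadata={"id": "doc-1"}), 0.87)
```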
libs/langchain/poetry.lock (generated, 113 lines changed)
@@ -1568,6 +1568,17 @@ click = ">=4.0"
 [package.extras]
 test = ["pytest-cov"]
 
+[[package]]
+name = "cloudpickle"
+version = "2.2.1"
+description = "Extended pickling support for Python objects"
+optional = true
+python-versions = ">=3.6"
+files = [
+    {file = "cloudpickle-2.2.1-py3-none-any.whl", hash = "sha256:61f594d1f4c295fa5cd9014ceb3a1fc4a70b0de1164b94fbc2d854ccba056f9f"},
+    {file = "cloudpickle-2.2.1.tar.gz", hash = "sha256:d89684b8de9e34a2a43b3460fbca07d09d6e25ce858df4d5a44240403b6178f5"},
+]
+
 [[package]]
 name = "codespell"
 version = "2.2.6"

@@ -1794,6 +1805,41 @@ grpcio = [
 numpy = "*"
 protobuf = ">=3.8.0,<4.0.0"
 
+[[package]]
+name = "databricks-cli"
+version = "0.18.0"
+description = "A command line interface for Databricks"
+optional = true
+python-versions = ">=3.7"
+files = [
+    {file = "databricks-cli-0.18.0.tar.gz", hash = "sha256:87569709eda9af3e9db8047b691e420b5e980c62ef01675575c0d2b9b4211eb7"},
+    {file = "databricks_cli-0.18.0-py2.py3-none-any.whl", hash = "sha256:1176a5f42d3e8af4abfc915446fb23abc44513e325c436725f5898cbb9e3384b"},
+]
+
+[package.dependencies]
+click = ">=7.0"
+oauthlib = ">=3.1.0"
+pyjwt = ">=1.7.0"
+requests = ">=2.17.3"
+six = ">=1.10.0"
+tabulate = ">=0.7.7"
+urllib3 = ">=1.26.7,<3"
+
+[[package]]
+name = "databricks-vectorsearch"
+version = "0.21"
+description = "Databricks Vector Search Client"
+optional = true
+python-versions = ">=3.7"
+files = [
+    {file = "databricks_vectorsearch-0.21-py3-none-any.whl", hash = "sha256:18265affdb38d44e7ec4cc95f8267379c5109bdb6e75bb61a729f126b2433868"},
+]
+
+[package.dependencies]
+mlflow-skinny = ">=2.4.0,<3"
+protobuf = ">=3.12.0,<5"
+requests = ">=2"
+
 [[package]]
 name = "dataclasses-json"
 version = "0.6.1"

@@ -4719,6 +4765,39 @@ files = [
     {file = "mistune-3.0.2.tar.gz", hash = "sha256:fc7f93ded930c92394ef2cb6f04a8aabab4117a91449e72dcc8dfa646a508be8"},
 ]
 
+[[package]]
+name = "mlflow-skinny"
+version = "2.8.1"
+description = "MLflow: A Platform for ML Development and Productionization"
+optional = true
+python-versions = ">=3.8"
+files = [
+    {file = "mlflow-skinny-2.8.1.tar.gz", hash = "sha256:8f46462e2df5ffd93a7f7d92ad1d3d7335adbe5e8e999543a3879963ae576d33"},
+    {file = "mlflow_skinny-2.8.1-py3-none-any.whl", hash = "sha256:8e2a1a5b8f1e2a3437c1fab972115a4df25934cd07cd83b8eb70202af8ad814a"},
+]
+
+[package.dependencies]
+click = ">=7.0,<9"
+cloudpickle = "<3"
+databricks-cli = ">=0.8.7,<1"
+entrypoints = "<1"
+gitpython = ">=2.1.0,<4"
+importlib-metadata = ">=3.7.0,<4.7.0 || >4.7.0,<7"
+packaging = "<24"
+protobuf = ">=3.12.0,<5"
+pytz = "<2024"
+pyyaml = ">=5.1,<7"
+requests = ">=2.17.3,<3"
+sqlparse = ">=0.4.0,<1"
+
+[package.extras]
+aliyun-oss = ["aliyunstoreplugin"]
+databricks = ["azure-storage-file-datalake (>12)", "boto3 (>1)", "google-cloud-storage (>=1.30.0)"]
+extras = ["azureml-core (>=1.2.0)", "boto3", "google-cloud-storage (>=1.30.0)", "kubernetes", "mlserver (>=1.2.0,!=1.3.1)", "mlserver-mlflow (>=1.2.0,!=1.3.1)", "prometheus-flask-exporter", "pyarrow", "pysftp", "requests-auth-aws-sigv4", "virtualenv"]
+gateway = ["aiohttp (<4)", "boto3 (>=1.28.56,<2)", "fastapi (<1)", "pydantic (>=1.0,<3)", "uvicorn[standard] (<1)", "watchfiles (<1)"]
+sqlserver = ["mlflow-dbstore"]
+xethub = ["mlflow-xethub"]
+
 [[package]]
 name = "mmh3"
 version = "3.1.0"

@@ -9271,6 +9350,22 @@ files = [
     {file = "sqlparams-5.1.0.tar.gz", hash = "sha256:1abe87a0684567265b2b86f5a482d5c37db237c0268d4c81774ffedce4300199"},
 ]
 
+[[package]]
+name = "sqlparse"
+version = "0.4.4"
+description = "A non-validating SQL parser."
+optional = true
+python-versions = ">=3.5"
+files = [
+    {file = "sqlparse-0.4.4-py3-none-any.whl", hash = "sha256:5430a4fe2ac7d0f93e66f1efc6e1338a41884b7ddf2a350cedd20ccc4d9d28f3"},
+    {file = "sqlparse-0.4.4.tar.gz", hash = "sha256:d446183e84b8349fa3061f0fe7f06ca94ba65b426946ffebe6e3e8295332420c"},
+]
+
+[package.extras]
+dev = ["build", "flake8"]
+doc = ["sphinx"]
+test = ["pytest", "pytest-cov"]
+
 [[package]]
 name = "stack-data"
 version = "0.6.3"

@@ -9368,6 +9463,20 @@ files = [
 [package.dependencies]
 pytest = ">=7.0.0,<8.0.0"
 
+[[package]]
+name = "tabulate"
+version = "0.9.0"
+description = "Pretty-print tabular data"
+optional = true
+python-versions = ">=3.7"
+files = [
+    {file = "tabulate-0.9.0-py3-none-any.whl", hash = "sha256:024ca478df22e9340661486f85298cff5f6dcdba14f3813e8830015b9ed1948f"},
+    {file = "tabulate-0.9.0.tar.gz", hash = "sha256:0095b12bf5966de529c0feb1fa08671671b3368eec77d7ef7ab114be2c068b3c"},
+]
+
+[package.extras]
+widechars = ["wcwidth"]
+
 [[package]]
 name = "telethon"
 version = "1.31.1"

@@ -11075,7 +11184,7 @@ cli = ["typer"]
 cohere = ["cohere"]
 docarray = ["docarray"]
 embeddings = ["sentence-transformers"]
-extended-testing = ["aiosqlite", "aleph-alpha-client", "anthropic", "arxiv", "assemblyai", "atlassian-python-api", "beautifulsoup4", "bibtexparser", "cassio", "chardet", "dashvector", "esprima", "faiss-cpu", "feedparser", "fireworks-ai", "geopandas", "gitpython", "google-cloud-documentai", "gql", "html2text", "javelin-sdk", "jinja2", "jq", "jsonschema", "lxml", "markdownify", "motor", "msal", "mwparserfromhell", "mwxml", "newspaper3k", "numexpr", "openai", "openai", "openapi-pydantic", "pandas", "pdfminer-six", "pgvector", "psychicapi", "py-trello", "pymupdf", "pypdf", "pypdfium2", "pyspark", "rank-bm25", "rapidfuzz", "rapidocr-onnxruntime", "requests-toolbelt", "rspace_client", "scikit-learn", "sqlite-vss", "streamlit", "sympy", "telethon", "timescale-vector", "tqdm", "upstash-redis", "xata", "xmltodict"]
+extended-testing = ["aiosqlite", "aleph-alpha-client", "anthropic", "arxiv", "assemblyai", "atlassian-python-api", "beautifulsoup4", "bibtexparser", "cassio", "chardet", "dashvector", "databricks-vectorsearch", "esprima", "faiss-cpu", "feedparser", "fireworks-ai", "geopandas", "gitpython", "google-cloud-documentai", "gql", "html2text", "javelin-sdk", "jinja2", "jq", "jsonschema", "lxml", "markdownify", "motor", "msal", "mwparserfromhell", "mwxml", "newspaper3k", "numexpr", "openai", "openai", "openapi-pydantic", "pandas", "pdfminer-six", "pgvector", "psychicapi", "py-trello", "pymupdf", "pypdf", "pypdfium2", "pyspark", "rank-bm25", "rapidfuzz", "rapidocr-onnxruntime", "requests-toolbelt", "rspace_client", "scikit-learn", "sqlite-vss", "streamlit", "sympy", "telethon", "timescale-vector", "tqdm", "upstash-redis", "xata", "xmltodict"]
 javascript = ["esprima"]
 llms = ["clarifai", "cohere", "huggingface_hub", "manifest-ml", "nlpcloud", "openai", "openlm", "torch", "transformers"]
 openai = ["openai", "tiktoken"]

@@ -11085,4 +11194,4 @@ text-helpers = ["chardet"]
 [metadata]
 lock-version = "2.0"
 python-versions = ">=3.8.1,<4.0"
-content-hash = "37e62f668e1acddc4e462fdac5f694af3916b6edbd1ccde0a54c9a57524d6c92"
+content-hash = "d57493dcdb7c864d71aa43463a57491f0c9cbd8fa8674d21c0b11117e8d7ea67"
@@ -144,6 +144,7 @@ google-cloud-documentai = {version = "^2.20.1", optional = true}
 fireworks-ai = {version = "^0.6.0", optional = true, python = ">=3.9,<4.0"}
 javelin-sdk = {version = "^0.1.8", optional = true}
 msal = {version = "^1.25.0", optional = true}
+databricks-vectorsearch = {version = "^0.21", optional = true}
 
 
 [tool.poetry.group.test.dependencies]

@@ -381,6 +382,7 @@ extended_testing = [
     "rspace_client",
     "fireworks-ai",
     "javelin-sdk",
+    "databricks-vectorsearch",
 ]
 
 [tool.ruff]
@@ -1123,6 +1123,7 @@ def test_compatible_vectorstore_documentation() -> None:
         "Cassandra",
         "Chroma",
         "DashVector",
+        "DatabricksVectorSearch",
         "DeepLake",
         "Dingo",
         "ElasticVectorSearch",
@@ -0,0 +1,526 @@
import random
import uuid
from typing import List
from unittest.mock import MagicMock

import pytest

from langchain.vectorstores import DatabricksVectorSearch
from tests.integration_tests.vectorstores.fake_embeddings import (
    FakeEmbeddings,
    fake_texts,
)

DEFAULT_VECTOR_DIMENSION = 4


class FakeEmbeddingsWithDimension(FakeEmbeddings):
    """Fake embeddings functionality for testing."""

    def __init__(self, dimension: int = DEFAULT_VECTOR_DIMENSION):
        super().__init__()
        self.dimension = dimension

    def embed_documents(self, embedding_texts: List[str]) -> List[List[float]]:
        """Return simple embeddings."""
        return [
            [float(1.0)] * (self.dimension - 1) + [float(i)]
            for i in range(len(embedding_texts))
        ]

    def embed_query(self, text: str) -> List[float]:
        """Return simple embeddings."""
        return [float(1.0)] * (self.dimension - 1) + [float(0.0)]


DEFAULT_EMBEDDING_MODEL = FakeEmbeddingsWithDimension()
DEFAULT_TEXT_COLUMN = "text"
DEFAULT_VECTOR_COLUMN = "text_vector"
DEFAULT_PRIMARY_KEY = "id"

DELTA_SYNC_INDEX_MANAGED_EMBEDDINGS = {
    "name": "ml.llm.index",
    "endpoint_name": "vector_search_endpoint",
    "index_type": "DELTA_SYNC",
    "primary_key": DEFAULT_PRIMARY_KEY,
    "delta_sync_index_spec": {
        "source_table": "ml.llm.source_table",
        "pipeline_type": "CONTINUOUS",
        "embedding_source_columns": [
            {
                "name": DEFAULT_TEXT_COLUMN,
                "embedding_model_endpoint_name": "openai-text-embedding",
            }
        ],
    },
}

DELTA_SYNC_INDEX_SELF_MANAGED_EMBEDDINGS = {
    "name": "ml.llm.index",
    "endpoint_name": "vector_search_endpoint",
    "index_type": "DELTA_SYNC",
    "primary_key": DEFAULT_PRIMARY_KEY,
    "delta_sync_index_spec": {
        "source_table": "ml.llm.source_table",
        "pipeline_type": "CONTINUOUS",
        "embedding_vector_columns": [
            {
                "name": DEFAULT_VECTOR_COLUMN,
                "embedding_dimension": DEFAULT_VECTOR_DIMENSION,
            }
        ],
    },
}

DIRECT_ACCESS_INDEX = {
    "name": "ml.llm.index",
    "endpoint_name": "vector_search_endpoint",
    "index_type": "DIRECT_ACCESS",
    "primary_key": DEFAULT_PRIMARY_KEY,
    "direct_access_index_spec": {
        "embedding_vector_columns": [
            {
                "name": DEFAULT_VECTOR_COLUMN,
                "embedding_dimension": DEFAULT_VECTOR_DIMENSION,
            }
        ],
        "schema_json": f"{{"
        f'"{DEFAULT_PRIMARY_KEY}": "int", '
        f'"feat1": "str", '
        f'"feat2": "float", '
        f'"text": "string", '
        f'"{DEFAULT_VECTOR_COLUMN}": "array<float>"'
        f"}}",
    },
}

ALL_INDEXES = [
    DELTA_SYNC_INDEX_MANAGED_EMBEDDINGS,
    DELTA_SYNC_INDEX_SELF_MANAGED_EMBEDDINGS,
    DIRECT_ACCESS_INDEX,
]

EXAMPLE_SEARCH_RESPONSE = {
    "manifest": {
        "column_count": 3,
        "columns": [
            {"name": DEFAULT_PRIMARY_KEY},
            {"name": DEFAULT_TEXT_COLUMN},
            {"name": "score"},
        ],
    },
    "result": {
        "row_count": len(fake_texts),
        "data_array": sorted(
            [[str(uuid.uuid4()), s, random.uniform(0, 1)] for s in fake_texts],
            key=lambda x: x[2],  # type: ignore
            reverse=True,
        ),
    },
    "next_page_token": "",
}


def mock_index(index_details: dict) -> MagicMock:
    from databricks.vector_search.client import VectorSearchIndex

    index = MagicMock(spec=VectorSearchIndex)
    index.describe.return_value = index_details
    return index


def default_databricks_vector_search(index: MagicMock) -> DatabricksVectorSearch:
    return DatabricksVectorSearch(
        index,
        embedding=DEFAULT_EMBEDDING_MODEL,
        text_column=DEFAULT_TEXT_COLUMN,
    )


@pytest.mark.requires("databricks", "databricks.vector_search")
def test_init_delta_sync_with_managed_embeddings() -> None:
    index = mock_index(DELTA_SYNC_INDEX_MANAGED_EMBEDDINGS)
    vectorsearch = DatabricksVectorSearch(index)
    assert vectorsearch.index == index


@pytest.mark.requires("databricks", "databricks.vector_search")
def test_init_delta_sync_with_self_managed_embeddings() -> None:
    index = mock_index(DELTA_SYNC_INDEX_SELF_MANAGED_EMBEDDINGS)
    vectorsearch = DatabricksVectorSearch(
        index,
        embedding=DEFAULT_EMBEDDING_MODEL,
        text_column=DEFAULT_TEXT_COLUMN,
    )
    assert vectorsearch.index == index


@pytest.mark.requires("databricks", "databricks.vector_search")
def test_init_direct_access_index() -> None:
    index = mock_index(DIRECT_ACCESS_INDEX)
    vectorsearch = DatabricksVectorSearch(
        index,
        embedding=DEFAULT_EMBEDDING_MODEL,
        text_column=DEFAULT_TEXT_COLUMN,
    )
    assert vectorsearch.index == index


@pytest.mark.requires("databricks", "databricks.vector_search")
def test_init_fail_no_index() -> None:
    with pytest.raises(TypeError):
        DatabricksVectorSearch()


@pytest.mark.requires("databricks", "databricks.vector_search")
def test_init_fail_index_none() -> None:
    with pytest.raises(TypeError) as ex:
        DatabricksVectorSearch(None)
    assert "index must be of type VectorSearchIndex." in str(ex.value)


@pytest.mark.requires("databricks", "databricks.vector_search")
def test_init_fail_text_column_mismatch() -> None:
    index = mock_index(DELTA_SYNC_INDEX_MANAGED_EMBEDDINGS)
    with pytest.raises(ValueError) as ex:
        DatabricksVectorSearch(
            index,
            text_column="some_other_column",
        )
    assert (
        f"text_column 'some_other_column' does not match with the source column of the "
        f"index: '{DEFAULT_TEXT_COLUMN}'." in str(ex.value)
    )


@pytest.mark.requires("databricks", "databricks.vector_search")
@pytest.mark.parametrize(
    "index_details", [DELTA_SYNC_INDEX_SELF_MANAGED_EMBEDDINGS, DIRECT_ACCESS_INDEX]
)
def test_init_fail_no_text_column(index_details: dict) -> None:
    index = mock_index(index_details)
    with pytest.raises(ValueError) as ex:
        DatabricksVectorSearch(
            index,
            embedding=DEFAULT_EMBEDDING_MODEL,
        )
    assert "`text_column` is required for this index." in str(ex.value)


@pytest.mark.requires("databricks", "databricks.vector_search")
@pytest.mark.parametrize("index_details", [DIRECT_ACCESS_INDEX])
def test_init_fail_columns_not_in_schema(index_details: dict) -> None:
    index = mock_index(index_details)
    with pytest.raises(ValueError) as ex:
        DatabricksVectorSearch(
            index,
            embedding=DEFAULT_EMBEDDING_MODEL,
            text_column=DEFAULT_TEXT_COLUMN,
            columns=["some_random_column"],
        )
    assert "column 'some_random_column' is not in the index's schema." in str(ex.value)


@pytest.mark.requires("databricks", "databricks.vector_search")
@pytest.mark.parametrize(
    "index_details", [DELTA_SYNC_INDEX_SELF_MANAGED_EMBEDDINGS, DIRECT_ACCESS_INDEX]
)
def test_init_fail_no_embedding(index_details: dict) -> None:
    index = mock_index(index_details)
    with pytest.raises(ValueError) as ex:
        DatabricksVectorSearch(
            index,
            text_column=DEFAULT_TEXT_COLUMN,
        )
    assert "`embedding` is required for this index." in str(ex.value)


@pytest.mark.requires("databricks", "databricks.vector_search")
@pytest.mark.parametrize(
    "index_details", [DELTA_SYNC_INDEX_SELF_MANAGED_EMBEDDINGS, DIRECT_ACCESS_INDEX]
)
def test_init_fail_embedding_dim_mismatch(index_details: dict) -> None:
    index = mock_index(index_details)
    with pytest.raises(ValueError) as ex:
        DatabricksVectorSearch(
            index,
            text_column=DEFAULT_TEXT_COLUMN,
            embedding=FakeEmbeddingsWithDimension(DEFAULT_VECTOR_DIMENSION + 1),
        )
    assert (
        f"embedding model's dimension '{DEFAULT_VECTOR_DIMENSION + 1}' does not match "
        f"with the index's dimension '{DEFAULT_VECTOR_DIMENSION}'"
    ) in str(ex.value)


@pytest.mark.requires("databricks", "databricks.vector_search")
def test_from_texts_not_supported() -> None:
    with pytest.raises(NotImplementedError) as ex:
        DatabricksVectorSearch.from_texts(fake_texts, FakeEmbeddings())
    assert (
        "`from_texts` is not supported. "
        "Use `add_texts` to add to existing direct-access index."
    ) in str(ex.value)


@pytest.mark.requires("databricks", "databricks.vector_search")
@pytest.mark.parametrize(
    "index_details",
    [DELTA_SYNC_INDEX_MANAGED_EMBEDDINGS, DELTA_SYNC_INDEX_SELF_MANAGED_EMBEDDINGS],
)
def test_add_texts_not_supported_for_delta_sync_index(index_details: dict) -> None:
    index = mock_index(index_details)
    vectorsearch = default_databricks_vector_search(index)
    with pytest.raises(ValueError) as ex:
        vectorsearch.add_texts(fake_texts)
    assert "`add_texts` is only supported for direct-access index." in str(ex.value)


def is_valid_uuid(val: str) -> bool:
    try:
        uuid.UUID(str(val))
        return True
    except ValueError:
        return False


@pytest.mark.requires("databricks", "databricks.vector_search")
def test_add_texts() -> None:
    index = mock_index(DIRECT_ACCESS_INDEX)
    vectorsearch = DatabricksVectorSearch(
        index,
        embedding=DEFAULT_EMBEDDING_MODEL,
        text_column=DEFAULT_TEXT_COLUMN,
    )
    ids = [idx for idx, i in enumerate(fake_texts)]
    vectors = DEFAULT_EMBEDDING_MODEL.embed_documents(fake_texts)

    added_ids = vectorsearch.add_texts(fake_texts, ids=ids)
    index.upsert.assert_called_once_with(
        [
            {
                DEFAULT_PRIMARY_KEY: id_,
                DEFAULT_TEXT_COLUMN: text,
                DEFAULT_VECTOR_COLUMN: vector,
            }
            for text, vector, id_ in zip(fake_texts, vectors, ids)
        ]
    )
    assert len(added_ids) == len(fake_texts)
    assert added_ids == ids


@pytest.mark.requires("databricks", "databricks.vector_search")
def test_add_texts_handle_single_text() -> None:
    index = mock_index(DIRECT_ACCESS_INDEX)
    vectorsearch = DatabricksVectorSearch(
        index,
        embedding=DEFAULT_EMBEDDING_MODEL,
        text_column=DEFAULT_TEXT_COLUMN,
    )
    vectors = DEFAULT_EMBEDDING_MODEL.embed_documents(fake_texts)

    added_ids = vectorsearch.add_texts(fake_texts[0])
    index.upsert.assert_called_once_with(
        [
            {
                DEFAULT_PRIMARY_KEY: id_,
                DEFAULT_TEXT_COLUMN: text,
                DEFAULT_VECTOR_COLUMN: vector,
            }
            for text, vector, id_ in zip(fake_texts, vectors, added_ids)
        ]
    )
    assert len(added_ids) == 1
    assert is_valid_uuid(added_ids[0])


@pytest.mark.requires("databricks", "databricks.vector_search")
def test_add_texts_with_default_id() -> None:
    index = mock_index(DIRECT_ACCESS_INDEX)
    vectorsearch = default_databricks_vector_search(index)
    vectors = DEFAULT_EMBEDDING_MODEL.embed_documents(fake_texts)

    added_ids = vectorsearch.add_texts(fake_texts)
    index.upsert.assert_called_once_with(
        [
            {
                DEFAULT_PRIMARY_KEY: id_,
                DEFAULT_TEXT_COLUMN: text,
                DEFAULT_VECTOR_COLUMN: vector,
            }
            for text, vector, id_ in zip(fake_texts, vectors, added_ids)
        ]
    )
    assert len(added_ids) == len(fake_texts)
    assert all([is_valid_uuid(id_) for id_ in added_ids])


@pytest.mark.requires("databricks", "databricks.vector_search")
def test_add_texts_with_metadata() -> None:
    index = mock_index(DIRECT_ACCESS_INDEX)
    vectorsearch = default_databricks_vector_search(index)
    vectors = DEFAULT_EMBEDDING_MODEL.embed_documents(fake_texts)
    metadatas = [{"feat1": str(i), "feat2": i + 1000} for i in range(len(fake_texts))]

    added_ids = vectorsearch.add_texts(fake_texts, metadatas=metadatas)
    index.upsert.assert_called_once_with(
        [
            {
                DEFAULT_PRIMARY_KEY: id_,
                DEFAULT_TEXT_COLUMN: text,
                DEFAULT_VECTOR_COLUMN: vector,
                **metadata,
            }
            for text, vector, id_, metadata in zip(
                fake_texts, vectors, added_ids, metadatas
            )
        ]
    )
    assert len(added_ids) == len(fake_texts)
    assert all([is_valid_uuid(id_) for id_ in added_ids])


@pytest.mark.requires("databricks", "databricks.vector_search")
@pytest.mark.parametrize(
    "index_details",
    [DELTA_SYNC_INDEX_SELF_MANAGED_EMBEDDINGS, DIRECT_ACCESS_INDEX],
)
def test_embeddings_property(index_details: dict) -> None:
    index = mock_index(index_details)
    vectorsearch = default_databricks_vector_search(index)
    assert vectorsearch.embeddings == DEFAULT_EMBEDDING_MODEL


@pytest.mark.requires("databricks", "databricks.vector_search")
@pytest.mark.parametrize(
    "index_details",
    [DELTA_SYNC_INDEX_MANAGED_EMBEDDINGS, DELTA_SYNC_INDEX_SELF_MANAGED_EMBEDDINGS],
)
def test_delete_not_supported_for_delta_sync_index(index_details: dict) -> None:
    index = mock_index(index_details)
    vectorsearch = default_databricks_vector_search(index)
    with pytest.raises(ValueError) as ex:
        vectorsearch.delete(["some id"])
    assert "`delete` is only supported for direct-access index." in str(ex.value)


@pytest.mark.requires("databricks", "databricks.vector_search")
def test_delete() -> None:
    index = mock_index(DIRECT_ACCESS_INDEX)
    vectorsearch = default_databricks_vector_search(index)

    vectorsearch.delete(["some id"])
    index.delete.assert_called_once_with(["some id"])


@pytest.mark.requires("databricks", "databricks.vector_search")
def test_delete_fail_no_ids() -> None:
    index = mock_index(DIRECT_ACCESS_INDEX)
    vectorsearch = default_databricks_vector_search(index)

    with pytest.raises(ValueError) as ex:
        vectorsearch.delete()
    assert "ids must be provided." in str(ex.value)


@pytest.mark.requires("databricks", "databricks.vector_search")
@pytest.mark.parametrize("index_details", ALL_INDEXES)
def test_similarity_search(index_details: dict) -> None:
    index = mock_index(index_details)
    index.similarity_search.return_value = EXAMPLE_SEARCH_RESPONSE
    vectorsearch = default_databricks_vector_search(index)
    query = "foo"
    filters = {"some filter": True}
    limit = 7

    search_result = vectorsearch.similarity_search(query, k=limit, filters=filters)
    if index_details == DELTA_SYNC_INDEX_MANAGED_EMBEDDINGS:
        index.similarity_search.assert_called_once_with(
            columns=[DEFAULT_PRIMARY_KEY, DEFAULT_TEXT_COLUMN],
            query_text=query,
            query_vector=None,
            filters=filters,
            num_results=limit,
        )
    else:
        index.similarity_search.assert_called_once_with(
            columns=[DEFAULT_PRIMARY_KEY, DEFAULT_TEXT_COLUMN],
            query_text=None,
            query_vector=DEFAULT_EMBEDDING_MODEL.embed_query(query),
            filters=filters,
            num_results=limit,
        )
    assert len(search_result) == len(fake_texts)
    assert sorted([d.page_content for d in search_result]) == sorted(fake_texts)
    assert all([DEFAULT_PRIMARY_KEY in d.metadata for d in search_result])


@pytest.mark.requires("databricks", "databricks.vector_search")
@pytest.mark.parametrize(
    "index_details", [DELTA_SYNC_INDEX_SELF_MANAGED_EMBEDDINGS, DIRECT_ACCESS_INDEX]
)
def test_similarity_search_by_vector(index_details: dict) -> None:
    index = mock_index(index_details)
    index.similarity_search.return_value = EXAMPLE_SEARCH_RESPONSE
    vectorsearch = default_databricks_vector_search(index)
    query_embedding = DEFAULT_EMBEDDING_MODEL.embed_query("foo")
    filters = {"some filter": True}
    limit = 7

    search_result = vectorsearch.similarity_search_by_vector(
        query_embedding, k=limit, filters=filters
    )
    index.similarity_search.assert_called_once_with(
        columns=[DEFAULT_PRIMARY_KEY, DEFAULT_TEXT_COLUMN],
        query_vector=query_embedding,
        filters=filters,
        num_results=limit,
    )
    assert len(search_result) == len(fake_texts)
    assert sorted([d.page_content for d in search_result]) == sorted(fake_texts)
    assert all([DEFAULT_PRIMARY_KEY in d.metadata for d in search_result])


@pytest.mark.requires("databricks", "databricks.vector_search")
@pytest.mark.parametrize("index_details", ALL_INDEXES)
def test_similarity_search_empty_result(index_details: dict) -> None:
    index = mock_index(index_details)
    index.similarity_search.return_value = {
        "manifest": {
            "column_count": 3,
            "columns": [
                {"name": DEFAULT_PRIMARY_KEY},
                {"name": DEFAULT_TEXT_COLUMN},
                {"name": "score"},
            ],
        },
        "result": {
            "row_count": 0,
            "data_array": [],
        },
        "next_page_token": "",
    }
    vectorsearch = default_databricks_vector_search(index)

    search_result = vectorsearch.similarity_search("foo")
    assert len(search_result) == 0


@pytest.mark.requires("databricks", "databricks.vector_search")
def test_similarity_search_by_vector_not_supported_for_managed_embedding() -> None:
    index = mock_index(DELTA_SYNC_INDEX_MANAGED_EMBEDDINGS)
    index.similarity_search.return_value = EXAMPLE_SEARCH_RESPONSE
    vectorsearch = default_databricks_vector_search(index)
    query_embedding = DEFAULT_EMBEDDING_MODEL.embed_query("foo")
    filters = {"some filter": True}
    limit = 7

    with pytest.raises(ValueError) as ex:
        vectorsearch.similarity_search_by_vector(
            query_embedding, k=limit, filters=filters
        )
    assert (
        "`similarity_search_by_vector` is not supported for index with "
        "Databricks-managed embeddings." in str(ex.value)
    )
@@ -17,6 +17,7 @@ _EXPECTED = [
     "Clickhouse",
     "ClickhouseSettings",
     "DashVector",
+    "DatabricksVectorSearch",
     "DeepLake",
     "Dingo",
     "DocArrayHnswSearch",