mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-03 12:07:36 +00:00
databricks: add vector search and embeddings (#25648)
### Summary Add `DatabricksVectorSearch` and `DatabricksEmbeddings` classes to the `langchain-databricks` partner packages. Core functionality is unchanged, but the vector search class is largely refactored for readability and maintainability. This PR does not add integration tests yet. This will be added once the Databricks test workspace is ready. Tagging @efriis as POC ### Tracker [✅] Create a package and imgrate ChatDatabricks [✍️] Migrate DatabricksVectorSearch, DatabricksEmbeddings, and their docs ~[ ] Migrate UCFunctionToolkit and its doc~ [ ] Add provider document and update README.md [ ] Add integration tests and set up secrets (after moved to an external package) [ ] Add deprecation note to the community implementations. --------- Signed-off-by: B-Step62 <yuki.watanabe@databricks.com> Co-authored-by: Erick Friis <erick@langchain.dev>
This commit is contained in:
@@ -1,6 +1,8 @@
|
||||
from importlib import metadata
|
||||
|
||||
from langchain_databricks.chat_models import ChatDatabricks
|
||||
from langchain_databricks.embeddings import DatabricksEmbeddings
|
||||
from langchain_databricks.vectorstores import DatabricksVectorSearch
|
||||
|
||||
try:
|
||||
__version__ = metadata.version(__package__)
|
||||
@@ -11,5 +13,7 @@ del metadata # optional, avoids polluting the results of dir(__package__)
|
||||
|
||||
__all__ = [
|
||||
"ChatDatabricks",
|
||||
"DatabricksEmbeddings",
|
||||
"DatabricksVectorSearch",
|
||||
"__version__",
|
||||
]
|
||||
|
@@ -15,7 +15,6 @@ from typing import (
|
||||
Type,
|
||||
Union,
|
||||
)
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from langchain_core.callbacks import CallbackManagerForLLMRun
|
||||
from langchain_core.language_models import BaseChatModel
|
||||
@@ -50,6 +49,8 @@ from langchain_core.runnables import Runnable
|
||||
from langchain_core.tools import BaseTool
|
||||
from langchain_core.utils.function_calling import convert_to_openai_tool
|
||||
|
||||
from langchain_databricks.utils import get_deployment_client
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -230,25 +231,7 @@ class ChatDatabricks(BaseChatModel):
|
||||
|
||||
def __init__(self, **kwargs: Any):
|
||||
super().__init__(**kwargs)
|
||||
self._validate_uri()
|
||||
try:
|
||||
from mlflow.deployments import get_deploy_client # type: ignore
|
||||
|
||||
self._client = get_deploy_client(self.target_uri)
|
||||
except ImportError as e:
|
||||
raise ImportError(
|
||||
"Failed to create the client. Please run `pip install mlflow` to "
|
||||
"install required dependencies."
|
||||
) from e
|
||||
|
||||
def _validate_uri(self) -> None:
|
||||
if self.target_uri == "databricks":
|
||||
return
|
||||
|
||||
if urlparse(self.target_uri).scheme != "databricks":
|
||||
raise ValueError(
|
||||
"Invalid target URI. The target URI must be a valid databricks URI."
|
||||
)
|
||||
self._client = get_deployment_client(self.target_uri)
|
||||
|
||||
@property
|
||||
def _default_params(self) -> Dict[str, Any]:
|
||||
|
91
libs/partners/databricks/langchain_databricks/embeddings.py
Normal file
91
libs/partners/databricks/langchain_databricks/embeddings.py
Normal file
@@ -0,0 +1,91 @@
|
||||
from typing import Any, Dict, Iterator, List
|
||||
|
||||
from langchain_core.embeddings import Embeddings
|
||||
from langchain_core.pydantic_v1 import BaseModel, PrivateAttr
|
||||
|
||||
from langchain_databricks.utils import get_deployment_client
|
||||
|
||||
|
||||
class DatabricksEmbeddings(Embeddings, BaseModel):
|
||||
"""Databricks embedding model integration.
|
||||
|
||||
Setup:
|
||||
Install ``langchain-databricks``.
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
pip install -U langchain-databricks
|
||||
|
||||
If you are outside Databricks, set the Databricks workspace
|
||||
hostname and personal access token to environment variables:
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
export DATABRICKS_HOSTNAME="https://your-databricks-workspace"
|
||||
export DATABRICKS_TOKEN="your-personal-access-token"
|
||||
|
||||
Key init args — completion params:
|
||||
endpoint: str
|
||||
Name of Databricks Model Serving endpoint to query.
|
||||
target_uri: str
|
||||
The target URI to use. Defaults to ``databricks``.
|
||||
query_params: Dict[str, str]
|
||||
The parameters to use for queries.
|
||||
documents_params: Dict[str, str]
|
||||
The parameters to use for documents.
|
||||
|
||||
Instantiate:
|
||||
.. code-block:: python
|
||||
from langchain_databricks import DatabricksEmbeddings
|
||||
embed = DatabricksEmbeddings(
|
||||
endpoint="databricks-bge-large-en",
|
||||
)
|
||||
|
||||
Embed single text:
|
||||
.. code-block:: python
|
||||
input_text = "The meaning of life is 42"
|
||||
embed.embed_query(input_text)
|
||||
|
||||
.. code-block:: python
|
||||
[
|
||||
0.01605224609375,
|
||||
-0.0298309326171875,
|
||||
...
|
||||
]
|
||||
|
||||
"""
|
||||
|
||||
endpoint: str
|
||||
"""The endpoint to use."""
|
||||
target_uri: str = "databricks"
|
||||
"""The parameters to use for queries."""
|
||||
query_params: Dict[str, Any] = {}
|
||||
"""The parameters to use for documents."""
|
||||
documents_params: Dict[str, Any] = {}
|
||||
"""The target URI to use."""
|
||||
_client: Any = PrivateAttr()
|
||||
|
||||
def __init__(self, **kwargs: Any):
|
||||
super().__init__(**kwargs)
|
||||
self._client = get_deployment_client(self.target_uri)
|
||||
|
||||
def embed_documents(self, texts: List[str]) -> List[List[float]]:
|
||||
return self._embed(texts, params=self.documents_params)
|
||||
|
||||
def embed_query(self, text: str) -> List[float]:
|
||||
return self._embed([text], params=self.query_params)[0]
|
||||
|
||||
def _embed(self, texts: List[str], params: Dict[str, str]) -> List[List[float]]:
|
||||
embeddings: List[List[float]] = []
|
||||
for txt in _chunk(texts, 20):
|
||||
resp = self._client.predict(
|
||||
endpoint=self.endpoint,
|
||||
inputs={"input": txt, **params}, # type: ignore[arg-type]
|
||||
)
|
||||
embeddings.extend(r["embedding"] for r in resp["data"])
|
||||
return embeddings
|
||||
|
||||
|
||||
def _chunk(texts: List[str], size: int) -> Iterator[List[str]]:
|
||||
for i in range(0, len(texts), size):
|
||||
yield texts[i : i + size]
|
101
libs/partners/databricks/langchain_databricks/utils.py
Normal file
101
libs/partners/databricks/langchain_databricks/utils.py
Normal file
@@ -0,0 +1,101 @@
|
||||
from typing import Any, List, Union
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
def get_deployment_client(target_uri: str) -> Any:
|
||||
if (target_uri != "databricks") and (urlparse(target_uri).scheme != "databricks"):
|
||||
raise ValueError(
|
||||
"Invalid target URI. The target URI must be a valid databricks URI."
|
||||
)
|
||||
|
||||
try:
|
||||
from mlflow.deployments import get_deploy_client # type: ignore[import-untyped]
|
||||
|
||||
return get_deploy_client(target_uri)
|
||||
except ImportError as e:
|
||||
raise ImportError(
|
||||
"Failed to create the client. "
|
||||
"Please run `pip install mlflow` to install "
|
||||
"required dependencies."
|
||||
) from e
|
||||
|
||||
|
||||
# Utility function for Maximal Marginal Relevance (MMR) reranking.
|
||||
# Copied from langchain_community/vectorstores/utils.py to avoid cross-dependency
|
||||
Matrix = Union[List[List[float]], List[np.ndarray], np.ndarray]
|
||||
|
||||
|
||||
def maximal_marginal_relevance(
|
||||
query_embedding: np.ndarray,
|
||||
embedding_list: list,
|
||||
lambda_mult: float = 0.5,
|
||||
k: int = 4,
|
||||
) -> List[int]:
|
||||
"""Calculate maximal marginal relevance.
|
||||
|
||||
Args:
|
||||
query_embedding: Query embedding.
|
||||
embedding_list: List of embeddings to select from.
|
||||
lambda_mult: Number between 0 and 1 that determines the degree
|
||||
of diversity among the results with 0 corresponding
|
||||
to maximum diversity and 1 to minimum diversity.
|
||||
Defaults to 0.5.
|
||||
k: Number of Documents to return. Defaults to 4.
|
||||
|
||||
Returns:
|
||||
List of indices of embeddings selected by maximal marginal relevance.
|
||||
"""
|
||||
if min(k, len(embedding_list)) <= 0:
|
||||
return []
|
||||
if query_embedding.ndim == 1:
|
||||
query_embedding = np.expand_dims(query_embedding, axis=0)
|
||||
similarity_to_query = cosine_similarity(query_embedding, embedding_list)[0]
|
||||
most_similar = int(np.argmax(similarity_to_query))
|
||||
idxs = [most_similar]
|
||||
selected = np.array([embedding_list[most_similar]])
|
||||
while len(idxs) < min(k, len(embedding_list)):
|
||||
best_score = -np.inf
|
||||
idx_to_add = -1
|
||||
similarity_to_selected = cosine_similarity(embedding_list, selected)
|
||||
for i, query_score in enumerate(similarity_to_query):
|
||||
if i in idxs:
|
||||
continue
|
||||
redundant_score = max(similarity_to_selected[i])
|
||||
equation_score = (
|
||||
lambda_mult * query_score - (1 - lambda_mult) * redundant_score
|
||||
)
|
||||
if equation_score > best_score:
|
||||
best_score = equation_score
|
||||
idx_to_add = i
|
||||
idxs.append(idx_to_add)
|
||||
selected = np.append(selected, [embedding_list[idx_to_add]], axis=0)
|
||||
return idxs
|
||||
|
||||
|
||||
def cosine_similarity(X: Matrix, Y: Matrix) -> np.ndarray:
|
||||
"""Row-wise cosine similarity between two equal-width matrices.
|
||||
|
||||
Raises:
|
||||
ValueError: If the number of columns in X and Y are not the same.
|
||||
"""
|
||||
if len(X) == 0 or len(Y) == 0:
|
||||
return np.array([])
|
||||
|
||||
X = np.array(X)
|
||||
Y = np.array(Y)
|
||||
if X.shape[1] != Y.shape[1]:
|
||||
raise ValueError(
|
||||
"Number of columns in X and Y must be the same. X has shape"
|
||||
f"{X.shape} "
|
||||
f"and Y has shape {Y.shape}."
|
||||
)
|
||||
|
||||
X_norm = np.linalg.norm(X, axis=1)
|
||||
Y_norm = np.linalg.norm(Y, axis=1)
|
||||
# Ignore divide by zero errors run time warnings as those are handled below.
|
||||
with np.errstate(divide="ignore", invalid="ignore"):
|
||||
similarity = np.dot(X, Y.T) / np.outer(X_norm, Y_norm)
|
||||
similarity[np.isnan(similarity) | np.isinf(similarity)] = 0.0
|
||||
return similarity
|
837
libs/partners/databricks/langchain_databricks/vectorstores.py
Normal file
837
libs/partners/databricks/langchain_databricks/vectorstores.py
Normal file
@@ -0,0 +1,837 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import uuid
|
||||
from enum import Enum
|
||||
from functools import partial
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
Dict,
|
||||
Iterable,
|
||||
List,
|
||||
Optional,
|
||||
Tuple,
|
||||
Type,
|
||||
)
|
||||
|
||||
import numpy as np
|
||||
from langchain_core.documents import Document
|
||||
from langchain_core.embeddings import Embeddings
|
||||
from langchain_core.vectorstores import VST, VectorStore
|
||||
|
||||
from langchain_databricks.utils import maximal_marginal_relevance
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class IndexType(str, Enum):
|
||||
DIRECT_ACCESS = "DIRECT_ACCESS"
|
||||
DELTA_SYNC = "DELTA_SYNC"
|
||||
|
||||
|
||||
_DIRECT_ACCESS_ONLY_MSG = "`%s` is only supported for direct-access index."
|
||||
_NON_MANAGED_EMB_ONLY_MSG = (
|
||||
"`%s` is not supported for index with Databricks-managed embeddings."
|
||||
)
|
||||
|
||||
|
||||
class DatabricksVectorSearch(VectorStore):
|
||||
"""Databricks vector store integration.
|
||||
|
||||
Setup:
|
||||
Install ``langchain-databricks`` and ``databricks-vectorsearch`` python packages.
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
pip install -U langchain-databricks databricks-vectorsearch
|
||||
|
||||
If you don't have a Databricks Vector Search endpoint already, you can create one by following the instructions here: https://docs.databricks.com/en/generative-ai/create-query-vector-search.html
|
||||
|
||||
If you are outside Databricks, set the Databricks workspace
|
||||
hostname and personal access token to environment variables:
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
export DATABRICKS_HOSTNAME="https://your-databricks-workspace"
|
||||
export DATABRICKS_TOKEN="your-personal-access-token"
|
||||
|
||||
Key init args — indexing params:
|
||||
|
||||
endpoint: The name of the Databricks Vector Search endpoint.
|
||||
index_name: The name of the index to use. Format: "catalog.schema.index".
|
||||
embedding: The embedding model.
|
||||
Required for direct-access index or delta-sync index
|
||||
with self-managed embeddings.
|
||||
text_column: The name of the text column to use for the embeddings.
|
||||
Required for direct-access index or delta-sync index
|
||||
with self-managed embeddings.
|
||||
Make sure the text column specified is in the index.
|
||||
columns: The list of column names to get when doing the search.
|
||||
Defaults to ``[primary_key, text_column]``.
|
||||
|
||||
Instantiate:
|
||||
|
||||
`DatabricksVectorSearch` supports two types of indexes:
|
||||
|
||||
* **Delta Sync Index** automatically syncs with a source Delta Table, automatically and incrementally updating the index as the underlying data in the Delta Table changes.
|
||||
|
||||
* **Direct Vector Access Index** supports direct read and write of vectors and metadata. The user is responsible for updating this table using the REST API or the Python SDK.
|
||||
|
||||
Also for delta-sync index, you can choose to use Databricks-managed embeddings or self-managed embeddings (via LangChain embeddings classes).
|
||||
|
||||
If you are using a delta-sync index with Databricks-managed embeddings:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
from langchain_databricks.vectorstores import DatabricksVectorSearch
|
||||
|
||||
vector_store = DatabricksVectorSearch(
|
||||
endpoint="<your-endpoint-name>",
|
||||
index_name="<your-index-name>"
|
||||
)
|
||||
|
||||
If you are using a direct-access index or a delta-sync index with self-managed embeddings,
|
||||
you also need to provide the embedding model and text column in your source table to
|
||||
use for the embeddings:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
from langchain_openai import OpenAIEmbeddings
|
||||
|
||||
vector_store = DatabricksVectorSearch(
|
||||
endpoint="<your-endpoint-name>",
|
||||
index_name="<your-index-name>",
|
||||
embedding=OpenAIEmbeddings(),
|
||||
text_column="document_content"
|
||||
)
|
||||
|
||||
Add Documents:
|
||||
.. code-block:: python
|
||||
from langchain_core.documents import Document
|
||||
|
||||
document_1 = Document(page_content="foo", metadata={"baz": "bar"})
|
||||
document_2 = Document(page_content="thud", metadata={"bar": "baz"})
|
||||
document_3 = Document(page_content="i will be deleted :(")
|
||||
documents = [document_1, document_2, document_3]
|
||||
ids = ["1", "2", "3"]
|
||||
vector_store.add_documents(documents=documents, ids=ids)
|
||||
|
||||
Delete Documents:
|
||||
.. code-block:: python
|
||||
vector_store.delete(ids=["3"])
|
||||
|
||||
.. note::
|
||||
|
||||
The `delete` method is only supported for direct-access index.
|
||||
|
||||
Search:
|
||||
.. code-block:: python
|
||||
results = vector_store.similarity_search(query="thud",k=1)
|
||||
for doc in results:
|
||||
print(f"* {doc.page_content} [{doc.metadata}]")
|
||||
.. code-block:: python
|
||||
* thud [{'id': '2'}]
|
||||
|
||||
.. note:
|
||||
|
||||
By default, similarity search only returns the primary key and text column.
|
||||
If you want to retrieve the custom metadata associated with the document,
|
||||
pass the additional columns in the `columns` parameter when initializing the vector store.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
vector_store = DatabricksVectorSearch(
|
||||
endpoint="<your-endpoint-name>",
|
||||
index_name="<your-index-name>",
|
||||
columns=["baz", "bar"],
|
||||
)
|
||||
|
||||
vector_store.similarity_search(query="thud",k=1)
|
||||
# Output: * thud [{'bar': 'baz', 'baz': None, 'id': '2'}]
|
||||
|
||||
Search with filter:
|
||||
.. code-block:: python
|
||||
results = vector_store.similarity_search(query="thud",k=1,filter={"bar": "baz"})
|
||||
for doc in results:
|
||||
print(f"* {doc.page_content} [{doc.metadata}]")
|
||||
.. code-block:: python
|
||||
* thud [{'id': '2'}]
|
||||
|
||||
Search with score:
|
||||
.. code-block:: python
|
||||
results = vector_store.similarity_search_with_score(query="qux",k=1)
|
||||
for doc, score in results:
|
||||
print(f"* [SIM={score:3f}] {doc.page_content} [{doc.metadata}]")
|
||||
.. code-block:: python
|
||||
* [SIM=0.748804] foo [{'id': '1'}]
|
||||
|
||||
Async:
|
||||
.. code-block:: python
|
||||
# add documents
|
||||
await vector_store.aadd_documents(documents=documents, ids=ids)
|
||||
# delete documents
|
||||
await vector_store.adelete(ids=["3"])
|
||||
# search
|
||||
results = vector_store.asimilarity_search(query="thud",k=1)
|
||||
# search with score
|
||||
results = await vector_store.asimilarity_search_with_score(query="qux",k=1)
|
||||
for doc,score in results:
|
||||
print(f"* [SIM={score:3f}] {doc.page_content} [{doc.metadata}]")
|
||||
.. code-block:: python
|
||||
* [SIM=0.748807] foo [{'id': '1'}]
|
||||
|
||||
Use as Retriever:
|
||||
.. code-block:: python
|
||||
retriever = vector_store.as_retriever(
|
||||
search_type="mmr",
|
||||
search_kwargs={"k": 1, "fetch_k": 2, "lambda_mult": 0.5},
|
||||
)
|
||||
retriever.invoke("thud")
|
||||
.. code-block:: python
|
||||
[Document(metadata={'id': '2'}, page_content='thud')]
|
||||
""" # noqa: E501
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
endpoint: str,
|
||||
index_name: str,
|
||||
embedding: Optional[Embeddings] = None,
|
||||
text_column: Optional[str] = None,
|
||||
columns: Optional[List[str]] = None,
|
||||
):
|
||||
try:
|
||||
from databricks.vector_search.client import ( # type: ignore[import]
|
||||
VectorSearchClient,
|
||||
)
|
||||
except ImportError as e:
|
||||
raise ImportError(
|
||||
"Could not import databricks-vectorsearch python package. "
|
||||
"Please install it with `pip install databricks-vectorsearch`."
|
||||
) from e
|
||||
|
||||
self.index = VectorSearchClient().get_index(endpoint, index_name)
|
||||
self._index_details = IndexDetails(self.index)
|
||||
|
||||
_validate_embedding(embedding, self._index_details)
|
||||
self._embeddings = embedding
|
||||
self._text_column = _validate_and_get_text_column(
|
||||
text_column, self._index_details
|
||||
)
|
||||
self._columns = _validate_and_get_return_columns(
|
||||
columns or [], self._text_column, self._index_details
|
||||
)
|
||||
self._primary_key = self._index_details.primary_key
|
||||
|
||||
@property
|
||||
def embeddings(self) -> Optional[Embeddings]:
|
||||
"""Access the query embedding object if available."""
|
||||
return self._embeddings
|
||||
|
||||
@classmethod
|
||||
def from_texts(
|
||||
cls: Type[VST],
|
||||
texts: List[str],
|
||||
embedding: Embeddings,
|
||||
metadatas: Optional[List[Dict]] = None,
|
||||
**kwargs: Any,
|
||||
) -> VST:
|
||||
raise NotImplementedError(
|
||||
"`from_texts` is not supported. "
|
||||
"Use `add_texts` to add to existing direct-access index."
|
||||
)
|
||||
|
||||
def add_texts(
|
||||
self,
|
||||
texts: Iterable[str],
|
||||
metadatas: Optional[List[Dict]] = None,
|
||||
ids: Optional[List[Any]] = None,
|
||||
**kwargs: Any,
|
||||
) -> List[str]:
|
||||
"""Add texts to the index.
|
||||
|
||||
.. note::
|
||||
|
||||
This method is only supported for a direct-access index.
|
||||
|
||||
Args:
|
||||
texts: List of texts to add.
|
||||
metadatas: List of metadata for each text. Defaults to None.
|
||||
ids: List of ids for each text. Defaults to None.
|
||||
If not provided, a random uuid will be generated for each text.
|
||||
|
||||
Returns:
|
||||
List of ids from adding the texts into the index.
|
||||
"""
|
||||
if self._index_details.is_delta_sync_index():
|
||||
raise NotImplementedError(_DIRECT_ACCESS_ONLY_MSG % "add_texts")
|
||||
|
||||
# Wrap to list if input texts is a single string
|
||||
if isinstance(texts, str):
|
||||
texts = [texts]
|
||||
texts = list(texts)
|
||||
vectors = self._embeddings.embed_documents(texts) # type: ignore[union-attr]
|
||||
ids = ids or [str(uuid.uuid4()) for _ in texts]
|
||||
metadatas = metadatas or [{} for _ in texts]
|
||||
|
||||
updates = [
|
||||
{
|
||||
self._primary_key: id_,
|
||||
self._text_column: text,
|
||||
self._index_details.embedding_vector_column["name"]: vector,
|
||||
**metadata,
|
||||
}
|
||||
for text, vector, id_, metadata in zip(texts, vectors, ids, metadatas)
|
||||
]
|
||||
|
||||
upsert_resp = self.index.upsert(updates)
|
||||
if upsert_resp.get("status") in ("PARTIAL_SUCCESS", "FAILURE"):
|
||||
failed_ids = upsert_resp.get("result", dict()).get(
|
||||
"failed_primary_keys", []
|
||||
)
|
||||
if upsert_resp.get("status") == "FAILURE":
|
||||
logger.error("Failed to add texts to the index.")
|
||||
else:
|
||||
logger.warning("Some texts failed to be added to the index.")
|
||||
return [id_ for id_ in ids if id_ not in failed_ids]
|
||||
|
||||
return ids
|
||||
|
||||
async def aadd_texts(
|
||||
self,
|
||||
texts: Iterable[str],
|
||||
metadatas: Optional[List[dict]] = None,
|
||||
**kwargs: Any,
|
||||
) -> List[str]:
|
||||
return await asyncio.get_running_loop().run_in_executor(
|
||||
None, partial(self.add_texts, **kwargs), texts, metadatas
|
||||
)
|
||||
|
||||
def delete(self, ids: Optional[List[Any]] = None, **kwargs: Any) -> Optional[bool]:
|
||||
"""Delete documents from the index.
|
||||
|
||||
.. note::
|
||||
|
||||
This method is only supported for a direct-access index.
|
||||
|
||||
Args:
|
||||
ids: List of ids of documents to delete.
|
||||
|
||||
Returns:
|
||||
True if successful.
|
||||
"""
|
||||
if self._index_details.is_delta_sync_index():
|
||||
raise NotImplementedError(_DIRECT_ACCESS_ONLY_MSG % "delete")
|
||||
|
||||
if ids is None:
|
||||
raise ValueError("ids must be provided.")
|
||||
self.index.delete(ids)
|
||||
return True
|
||||
|
||||
def similarity_search(
|
||||
self,
|
||||
query: str,
|
||||
k: int = 4,
|
||||
filter: Optional[Dict[str, Any]] = None,
|
||||
*,
|
||||
query_type: Optional[str] = None,
|
||||
**kwargs: Any,
|
||||
) -> List[Document]:
|
||||
"""Return docs most similar to query.
|
||||
|
||||
Args:
|
||||
query: Text to look up documents similar to.
|
||||
k: Number of Documents to return. Defaults to 4.
|
||||
filter: Filters to apply to the query. Defaults to None.
|
||||
query_type: The type of this query. Supported values are "ANN" and "HYBRID".
|
||||
|
||||
Returns:
|
||||
List of Documents most similar to the embedding.
|
||||
"""
|
||||
docs_with_score = self.similarity_search_with_score(
|
||||
query=query,
|
||||
k=k,
|
||||
filter=filter,
|
||||
query_type=query_type,
|
||||
**kwargs,
|
||||
)
|
||||
return [doc for doc, _ in docs_with_score]
|
||||
|
||||
async def asimilarity_search(
|
||||
self, query: str, k: int = 4, **kwargs: Any
|
||||
) -> List[Document]:
|
||||
# This is a temporary workaround to make the similarity search
|
||||
# asynchronous. The proper solution is to make the similarity search
|
||||
# asynchronous in the vector store implementations.
|
||||
func = partial(self.similarity_search, query, k=k, **kwargs)
|
||||
return await asyncio.get_event_loop().run_in_executor(None, func)
|
||||
|
||||
def similarity_search_with_score(
|
||||
self,
|
||||
query: str,
|
||||
k: int = 4,
|
||||
filter: Optional[Dict[str, Any]] = None,
|
||||
*,
|
||||
query_type: Optional[str] = None,
|
||||
**kwargs: Any,
|
||||
) -> List[Tuple[Document, float]]:
|
||||
"""Return docs most similar to query, along with scores.
|
||||
|
||||
Args:
|
||||
query: Text to look up documents similar to.
|
||||
k: Number of Documents to return. Defaults to 4.
|
||||
filter: Filters to apply to the query. Defaults to None.
|
||||
query_type: The type of this query. Supported values are "ANN" and "HYBRID".
|
||||
|
||||
Returns:
|
||||
List of Documents most similar to the embedding and score for each.
|
||||
"""
|
||||
if self._index_details.is_databricks_managed_embeddings():
|
||||
query_text = query
|
||||
query_vector = None
|
||||
else:
|
||||
# The value for `query_text` needs to be specified only for hybrid search.
|
||||
if query_type is not None and query_type.upper() == "HYBRID":
|
||||
query_text = query
|
||||
else:
|
||||
query_text = None
|
||||
query_vector = self._embeddings.embed_query(query) # type: ignore[union-attr]
|
||||
|
||||
search_resp = self.index.similarity_search(
|
||||
columns=self._columns,
|
||||
query_text=query_text,
|
||||
query_vector=query_vector,
|
||||
filters=filter,
|
||||
num_results=k,
|
||||
query_type=query_type,
|
||||
)
|
||||
return self._parse_search_response(search_resp)
|
||||
|
||||
def _select_relevance_score_fn(self) -> Callable[[float], float]:
|
||||
"""
|
||||
Databricks Vector search uses a normalized score 1/(1+d) where d
|
||||
is the L2 distance. Hence, we simply return the identity function.
|
||||
"""
|
||||
return lambda score: score
|
||||
|
||||
async def asimilarity_search_with_score(
|
||||
self, *args: Any, **kwargs: Any
|
||||
) -> List[Tuple[Document, float]]:
|
||||
# This is a temporary workaround to make the similarity search
|
||||
# asynchronous. The proper solution is to make the similarity search
|
||||
# asynchronous in the vector store implementations.
|
||||
func = partial(self.similarity_search_with_score, *args, **kwargs)
|
||||
return await asyncio.get_event_loop().run_in_executor(None, func)
|
||||
|
||||
def similarity_search_by_vector(
|
||||
self,
|
||||
embedding: List[float],
|
||||
k: int = 4,
|
||||
filter: Optional[Any] = None,
|
||||
*,
|
||||
query_type: Optional[str] = None,
|
||||
query: Optional[str] = None,
|
||||
**kwargs: Any,
|
||||
) -> List[Document]:
|
||||
"""Return docs most similar to embedding vector.
|
||||
|
||||
Args:
|
||||
embedding: Embedding to look up documents similar to.
|
||||
k: Number of Documents to return. Defaults to 4.
|
||||
filter: Filters to apply to the query. Defaults to None.
|
||||
query_type: The type of this query. Supported values are "ANN" and "HYBRID".
|
||||
|
||||
Returns:
|
||||
List of Documents most similar to the embedding.
|
||||
"""
|
||||
if self._index_details.is_databricks_managed_embeddings():
|
||||
raise NotImplementedError(
|
||||
_NON_MANAGED_EMB_ONLY_MSG % "similarity_search_by_vector"
|
||||
)
|
||||
|
||||
docs_with_score = self.similarity_search_by_vector_with_score(
|
||||
embedding=embedding,
|
||||
k=k,
|
||||
filter=filter,
|
||||
query_type=query_type,
|
||||
query=query,
|
||||
**kwargs,
|
||||
)
|
||||
return [doc for doc, _ in docs_with_score]
|
||||
|
||||
async def asimilarity_search_by_vector(
|
||||
self, embedding: List[float], k: int = 4, **kwargs: Any
|
||||
) -> List[Document]:
|
||||
# This is a temporary workaround to make the similarity search
|
||||
# asynchronous. The proper solution is to make the similarity search
|
||||
# asynchronous in the vector store implementations.
|
||||
func = partial(self.similarity_search_by_vector, embedding, k=k, **kwargs)
|
||||
return await asyncio.get_event_loop().run_in_executor(None, func)
|
||||
|
||||
def similarity_search_by_vector_with_score(
|
||||
self,
|
||||
embedding: List[float],
|
||||
k: int = 4,
|
||||
filter: Optional[Any] = None,
|
||||
*,
|
||||
query_type: Optional[str] = None,
|
||||
query: Optional[str] = None,
|
||||
**kwargs: Any,
|
||||
) -> List[Tuple[Document, float]]:
|
||||
"""Return docs most similar to embedding vector, along with scores.
|
||||
|
||||
.. note::
|
||||
|
||||
This method is not supported for index with Databricks-managed embeddings.
|
||||
|
||||
Args:
|
||||
embedding: Embedding to look up documents similar to.
|
||||
k: Number of Documents to return. Defaults to 4.
|
||||
filter: Filters to apply to the query. Defaults to None.
|
||||
query_type: The type of this query. Supported values are "ANN" and "HYBRID".
|
||||
|
||||
Returns:
|
||||
List of Documents most similar to the embedding and score for each.
|
||||
"""
|
||||
if self._index_details.is_databricks_managed_embeddings():
|
||||
raise NotImplementedError(
|
||||
_NON_MANAGED_EMB_ONLY_MSG % "similarity_search_by_vector_with_score"
|
||||
)
|
||||
|
||||
if query_type is not None and query_type.upper() == "HYBRID":
|
||||
if query is None:
|
||||
raise ValueError(
|
||||
"A value for `query` must be specified for hybrid search."
|
||||
)
|
||||
query_text = query
|
||||
else:
|
||||
if query is not None:
|
||||
raise ValueError(
|
||||
(
|
||||
"Cannot specify both `embedding` and "
|
||||
'`query` unless `query_type="HYBRID"'
|
||||
)
|
||||
)
|
||||
query_text = None
|
||||
|
||||
search_resp = self.index.similarity_search(
|
||||
columns=self._columns,
|
||||
query_vector=embedding,
|
||||
query_text=query_text,
|
||||
filters=filter,
|
||||
num_results=k,
|
||||
query_type=query_type,
|
||||
)
|
||||
return self._parse_search_response(search_resp)
|
||||
|
||||
def max_marginal_relevance_search(
|
||||
self,
|
||||
query: str,
|
||||
k: int = 4,
|
||||
fetch_k: int = 20,
|
||||
lambda_mult: float = 0.5,
|
||||
filter: Optional[Dict[str, Any]] = None,
|
||||
*,
|
||||
query_type: Optional[str] = None,
|
||||
**kwargs: Any,
|
||||
) -> List[Document]:
|
||||
"""Return docs selected using the maximal marginal relevance.
|
||||
|
||||
Maximal marginal relevance optimizes for similarity to query AND diversity
|
||||
among selected documents.
|
||||
|
||||
.. note::
|
||||
|
||||
This method is not supported for index with Databricks-managed embeddings.
|
||||
|
||||
Args:
|
||||
query: Text to look up documents similar to.
|
||||
k: Number of Documents to return. Defaults to 4.
|
||||
fetch_k: Number of Documents to fetch to pass to MMR algorithm.
|
||||
lambda_mult: Number between 0 and 1 that determines the degree
|
||||
of diversity among the results with 0 corresponding
|
||||
to maximum diversity and 1 to minimum diversity.
|
||||
Defaults to 0.5.
|
||||
filter: Filters to apply to the query. Defaults to None.
|
||||
query_type: The type of this query. Supported values are "ANN" and "HYBRID".
|
||||
Returns:
|
||||
List of Documents selected by maximal marginal relevance.
|
||||
"""
|
||||
if self._index_details.is_databricks_managed_embeddings():
|
||||
raise NotImplementedError(
|
||||
_NON_MANAGED_EMB_ONLY_MSG % "max_marginal_relevance_search"
|
||||
)
|
||||
|
||||
query_vector = self._embeddings.embed_query(query) # type: ignore[union-attr]
|
||||
docs = self.max_marginal_relevance_search_by_vector(
|
||||
query_vector,
|
||||
k,
|
||||
fetch_k,
|
||||
lambda_mult=lambda_mult,
|
||||
filter=filter,
|
||||
query_type=query_type,
|
||||
)
|
||||
return docs
|
||||
|
||||
async def amax_marginal_relevance_search(
|
||||
self,
|
||||
query: str,
|
||||
k: int = 4,
|
||||
fetch_k: int = 20,
|
||||
lambda_mult: float = 0.5,
|
||||
**kwargs: Any,
|
||||
) -> List[Document]:
|
||||
# This is a temporary workaround to make the similarity search
|
||||
# asynchronous. The proper solution is to make the similarity search
|
||||
# asynchronous in the vector store implementations.
|
||||
func = partial(
|
||||
self.max_marginal_relevance_search,
|
||||
query,
|
||||
k=k,
|
||||
fetch_k=fetch_k,
|
||||
lambda_mult=lambda_mult,
|
||||
**kwargs,
|
||||
)
|
||||
return await asyncio.get_event_loop().run_in_executor(None, func)
|
||||
|
||||
def max_marginal_relevance_search_by_vector(
|
||||
self,
|
||||
embedding: List[float],
|
||||
k: int = 4,
|
||||
fetch_k: int = 20,
|
||||
lambda_mult: float = 0.5,
|
||||
filter: Optional[Any] = None,
|
||||
*,
|
||||
query_type: Optional[str] = None,
|
||||
**kwargs: Any,
|
||||
) -> List[Document]:
|
||||
"""Return docs selected using the maximal marginal relevance.
|
||||
|
||||
Maximal marginal relevance optimizes for similarity to query AND diversity
|
||||
among selected documents.
|
||||
|
||||
.. note::
|
||||
|
||||
This method is not supported for index with Databricks-managed embeddings.
|
||||
|
||||
Args:
|
||||
embedding: Embedding to look up documents similar to.
|
||||
k: Number of Documents to return. Defaults to 4.
|
||||
fetch_k: Number of Documents to fetch to pass to MMR algorithm.
|
||||
lambda_mult: Number between 0 and 1 that determines the degree
|
||||
of diversity among the results with 0 corresponding
|
||||
to maximum diversity and 1 to minimum diversity.
|
||||
Defaults to 0.5.
|
||||
filter: Filters to apply to the query. Defaults to None.
|
||||
query_type: The type of this query. Supported values are "ANN" and "HYBRID".
|
||||
Returns:
|
||||
List of Documents selected by maximal marginal relevance.
|
||||
"""
|
||||
if self._index_details.is_databricks_managed_embeddings():
|
||||
raise NotImplementedError(
|
||||
_NON_MANAGED_EMB_ONLY_MSG % "max_marginal_relevance_search_by_vector"
|
||||
)
|
||||
|
||||
embedding_column = self._index_details.embedding_vector_column["name"]
|
||||
search_resp = self.index.similarity_search(
|
||||
columns=list(set(self._columns + [embedding_column])),
|
||||
query_text=None,
|
||||
query_vector=embedding,
|
||||
filters=filter,
|
||||
num_results=fetch_k,
|
||||
query_type=query_type,
|
||||
)
|
||||
|
||||
embeddings_result_index = (
|
||||
search_resp.get("manifest").get("columns").index({"name": embedding_column})
|
||||
)
|
||||
embeddings = [
|
||||
doc[embeddings_result_index]
|
||||
for doc in search_resp.get("result").get("data_array")
|
||||
]
|
||||
|
||||
mmr_selected = maximal_marginal_relevance(
|
||||
np.array(embedding, dtype=np.float32),
|
||||
embeddings,
|
||||
k=k,
|
||||
lambda_mult=lambda_mult,
|
||||
)
|
||||
|
||||
ignore_cols: List = (
|
||||
[embedding_column] if embedding_column not in self._columns else []
|
||||
)
|
||||
candidates = self._parse_search_response(search_resp, ignore_cols=ignore_cols)
|
||||
selected_results = [r[0] for i, r in enumerate(candidates) if i in mmr_selected]
|
||||
return selected_results
|
||||
|
||||
async def amax_marginal_relevance_search_by_vector(
|
||||
self,
|
||||
embedding: List[float],
|
||||
k: int = 4,
|
||||
fetch_k: int = 20,
|
||||
lambda_mult: float = 0.5,
|
||||
**kwargs: Any,
|
||||
) -> List[Document]:
|
||||
raise NotImplementedError
|
||||
|
||||
def _parse_search_response(
|
||||
self, search_resp: Dict, ignore_cols: Optional[List[str]] = None
|
||||
) -> List[Tuple[Document, float]]:
|
||||
"""Parse the search response into a list of Documents with score."""
|
||||
if ignore_cols is None:
|
||||
ignore_cols = []
|
||||
|
||||
columns = [
|
||||
col["name"]
|
||||
for col in search_resp.get("manifest", dict()).get("columns", [])
|
||||
]
|
||||
docs_with_score = []
|
||||
for result in search_resp.get("result", dict()).get("data_array", []):
|
||||
doc_id = result[columns.index(self._primary_key)]
|
||||
text_content = result[columns.index(self._text_column)]
|
||||
ignore_cols = [self._primary_key, self._text_column] + ignore_cols
|
||||
metadata = {
|
||||
col: value
|
||||
for col, value in zip(columns[:-1], result[:-1])
|
||||
if col not in ignore_cols
|
||||
}
|
||||
metadata[self._primary_key] = doc_id
|
||||
score = result[-1]
|
||||
doc = Document(page_content=text_content, metadata=metadata)
|
||||
docs_with_score.append((doc, score))
|
||||
return docs_with_score
|
||||
|
||||
|
||||
def _validate_and_get_text_column(
|
||||
text_column: Optional[str], index_details: IndexDetails
|
||||
) -> str:
|
||||
if index_details.is_databricks_managed_embeddings():
|
||||
index_source_column: str = index_details.embedding_source_column["name"]
|
||||
# check if input text column matches the source column of the index
|
||||
if text_column is not None:
|
||||
raise ValueError(
|
||||
f"The index '{index_details.name}' has the source column configured as "
|
||||
f"'{index_source_column}'. Do not pass the `text_column` parameter."
|
||||
)
|
||||
return index_source_column
|
||||
else:
|
||||
if text_column is None:
|
||||
raise ValueError("The `text_column` parameter is required for this index.")
|
||||
return text_column
|
||||
|
||||
|
||||
def _validate_and_get_return_columns(
|
||||
columns: List[str], text_column: str, index_details: IndexDetails
|
||||
) -> List[str]:
|
||||
"""
|
||||
Get a list of columns to retrieve from the index.
|
||||
|
||||
If the index is direct-access index, validate the given columns against the schema.
|
||||
"""
|
||||
# add primary key column and source column if not in columns
|
||||
if index_details.primary_key not in columns:
|
||||
columns.append(index_details.primary_key)
|
||||
if text_column and text_column not in columns:
|
||||
columns.append(text_column)
|
||||
|
||||
# Validate specified columns are in the index
|
||||
if index_details.is_direct_access_index() and (
|
||||
index_schema := index_details.schema
|
||||
):
|
||||
if missing_columns := [c for c in columns if c not in index_schema]:
|
||||
raise ValueError(
|
||||
"Some columns specified in `columns` are not "
|
||||
f"in the index schema: {missing_columns}"
|
||||
)
|
||||
return columns
|
||||
|
||||
|
||||
def _validate_embedding(
|
||||
embedding: Optional[Embeddings], index_details: IndexDetails
|
||||
) -> None:
|
||||
if index_details.is_databricks_managed_embeddings():
|
||||
if embedding is not None:
|
||||
raise ValueError(
|
||||
f"The index '{index_details.name}' uses Databricks-managed embeddings. "
|
||||
"Do not pass the `embedding` parameter when initializing vector store."
|
||||
)
|
||||
else:
|
||||
if not embedding:
|
||||
raise ValueError(
|
||||
"The `embedding` parameter is required for a direct-access index "
|
||||
"or delta-sync index with self-managed embedding."
|
||||
)
|
||||
_validate_embedding_dimension(embedding, index_details)
|
||||
|
||||
|
||||
def _validate_embedding_dimension(
|
||||
embeddings: Embeddings, index_details: IndexDetails
|
||||
) -> None:
|
||||
"""validate if the embedding dimension matches with the index's configuration."""
|
||||
if index_embedding_dimension := index_details.embedding_vector_column.get(
|
||||
"embedding_dimension"
|
||||
):
|
||||
# Infer the embedding dimension from the embedding function."""
|
||||
actual_dimension = len(embeddings.embed_query("test"))
|
||||
if actual_dimension != index_embedding_dimension:
|
||||
raise ValueError(
|
||||
f"The specified embedding model's dimension '{actual_dimension}' does "
|
||||
f"not match with the index configuration '{index_embedding_dimension}'."
|
||||
)
|
||||
|
||||
|
||||
class IndexDetails:
|
||||
"""An utility class to store the configuration details of an index."""
|
||||
|
||||
def __init__(self, index: Any):
|
||||
self._index_details = index.describe()
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return self._index_details["name"]
|
||||
|
||||
@property
|
||||
def schema(self) -> Optional[Dict]:
|
||||
if self.is_direct_access_index():
|
||||
schema_json = self.index_spec.get("schema_json")
|
||||
if schema_json is not None:
|
||||
return json.loads(schema_json)
|
||||
return None
|
||||
|
||||
@property
|
||||
def primary_key(self) -> str:
|
||||
return self._index_details["primary_key"]
|
||||
|
||||
@property
|
||||
def index_spec(self) -> Dict:
|
||||
return (
|
||||
self._index_details.get("delta_sync_index_spec", {})
|
||||
if self.is_delta_sync_index()
|
||||
else self._index_details.get("direct_access_index_spec", {})
|
||||
)
|
||||
|
||||
@property
|
||||
def embedding_vector_column(self) -> Dict:
|
||||
if vector_columns := self.index_spec.get("embedding_vector_columns"):
|
||||
return vector_columns[0]
|
||||
return {}
|
||||
|
||||
@property
|
||||
def embedding_source_column(self) -> Dict:
|
||||
if source_columns := self.index_spec.get("embedding_source_columns"):
|
||||
return source_columns[0]
|
||||
return {}
|
||||
|
||||
def is_delta_sync_index(self) -> bool:
|
||||
return self._index_details["index_type"] == IndexType.DELTA_SYNC.value
|
||||
|
||||
def is_direct_access_index(self) -> bool:
|
||||
return self._index_details["index_type"] == IndexType.DIRECT_ACCESS.value
|
||||
|
||||
def is_databricks_managed_embeddings(self) -> bool:
|
||||
return (
|
||||
self.is_delta_sync_index()
|
||||
and self.embedding_source_column.get("name") is not None
|
||||
)
|
58
libs/partners/databricks/poetry.lock
generated
58
libs/partners/databricks/poetry.lock
generated
@@ -339,6 +339,22 @@ requests = ">=2.28.1,<3"
|
||||
dev = ["autoflake", "databricks-connect", "ipython", "ipywidgets", "isort", "pycodestyle", "pyfakefs", "pytest", "pytest-cov", "pytest-mock", "pytest-rerunfailures", "pytest-xdist", "requests-mock", "wheel", "yapf"]
|
||||
notebook = ["ipython (>=8,<9)", "ipywidgets (>=8,<9)"]
|
||||
|
||||
[[package]]
|
||||
name = "databricks-vectorsearch"
|
||||
version = "0.40"
|
||||
description = "Databricks Vector Search Client"
|
||||
optional = false
|
||||
python-versions = ">=3.7"
|
||||
files = [
|
||||
{file = "databricks_vectorsearch-0.40-py3-none-any.whl", hash = "sha256:c684291e1b0472ece8f6df8c6ff7982f49ce7075e1df5b93459e148dea1d70d7"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
deprecation = ">=2"
|
||||
mlflow-skinny = ">=2.11.3,<3"
|
||||
protobuf = ">=3.12.0,<5"
|
||||
requests = ">=2"
|
||||
|
||||
[[package]]
|
||||
name = "deprecated"
|
||||
version = "1.2.14"
|
||||
@@ -356,6 +372,20 @@ wrapt = ">=1.10,<2"
|
||||
[package.extras]
|
||||
dev = ["PyTest", "PyTest-Cov", "bump2version (<1)", "sphinx (<2)", "tox"]
|
||||
|
||||
[[package]]
|
||||
name = "deprecation"
|
||||
version = "2.1.0"
|
||||
description = "A library to handle automated deprecations"
|
||||
optional = false
|
||||
python-versions = "*"
|
||||
files = [
|
||||
{file = "deprecation-2.1.0-py2.py3-none-any.whl", hash = "sha256:a10811591210e1fb0e768a8c25517cabeabcba6f0bf96564f8ff45189f90b14a"},
|
||||
{file = "deprecation-2.1.0.tar.gz", hash = "sha256:72b3bde64e5d778694b0cf68178aed03d15e15477116add3fb773e581f9518ff"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
packaging = "*"
|
||||
|
||||
[[package]]
|
||||
name = "docker"
|
||||
version = "7.1.0"
|
||||
@@ -1469,8 +1499,8 @@ files = [
|
||||
[package.dependencies]
|
||||
numpy = [
|
||||
{version = ">=1.20.3", markers = "python_version < \"3.10\""},
|
||||
{version = ">=1.21.0", markers = "python_version >= \"3.10\" and python_version < \"3.11\""},
|
||||
{version = ">=1.23.2", markers = "python_version >= \"3.11\""},
|
||||
{version = ">=1.21.0", markers = "python_version >= \"3.10\" and python_version < \"3.11\""},
|
||||
]
|
||||
python-dateutil = ">=2.8.2"
|
||||
pytz = ">=2020.1"
|
||||
@@ -1613,22 +1643,22 @@ testing = ["pytest", "pytest-benchmark"]
|
||||
|
||||
[[package]]
|
||||
name = "protobuf"
|
||||
version = "5.27.3"
|
||||
version = "4.25.4"
|
||||
description = ""
|
||||
optional = false
|
||||
python-versions = ">=3.8"
|
||||
files = [
|
||||
{file = "protobuf-5.27.3-cp310-abi3-win32.whl", hash = "sha256:dcb307cd4ef8fec0cf52cb9105a03d06fbb5275ce6d84a6ae33bc6cf84e0a07b"},
|
||||
{file = "protobuf-5.27.3-cp310-abi3-win_amd64.whl", hash = "sha256:16ddf3f8c6c41e1e803da7abea17b1793a97ef079a912e42351eabb19b2cffe7"},
|
||||
{file = "protobuf-5.27.3-cp38-abi3-macosx_10_9_universal2.whl", hash = "sha256:68248c60d53f6168f565a8c76dc58ba4fa2ade31c2d1ebdae6d80f969cdc2d4f"},
|
||||
{file = "protobuf-5.27.3-cp38-abi3-manylinux2014_aarch64.whl", hash = "sha256:b8a994fb3d1c11156e7d1e427186662b64694a62b55936b2b9348f0a7c6625ce"},
|
||||
{file = "protobuf-5.27.3-cp38-abi3-manylinux2014_x86_64.whl", hash = "sha256:a55c48f2a2092d8e213bd143474df33a6ae751b781dd1d1f4d953c128a415b25"},
|
||||
{file = "protobuf-5.27.3-cp38-cp38-win32.whl", hash = "sha256:043853dcb55cc262bf2e116215ad43fa0859caab79bb0b2d31b708f128ece035"},
|
||||
{file = "protobuf-5.27.3-cp38-cp38-win_amd64.whl", hash = "sha256:c2a105c24f08b1e53d6c7ffe69cb09d0031512f0b72f812dd4005b8112dbe91e"},
|
||||
{file = "protobuf-5.27.3-cp39-cp39-win32.whl", hash = "sha256:c84eee2c71ed83704f1afbf1a85c3171eab0fd1ade3b399b3fad0884cbcca8bf"},
|
||||
{file = "protobuf-5.27.3-cp39-cp39-win_amd64.whl", hash = "sha256:af7c0b7cfbbb649ad26132e53faa348580f844d9ca46fd3ec7ca48a1ea5db8a1"},
|
||||
{file = "protobuf-5.27.3-py3-none-any.whl", hash = "sha256:8572c6533e544ebf6899c360e91d6bcbbee2549251643d32c52cf8a5de295ba5"},
|
||||
{file = "protobuf-5.27.3.tar.gz", hash = "sha256:82460903e640f2b7e34ee81a947fdaad89de796d324bcbc38ff5430bcdead82c"},
|
||||
{file = "protobuf-4.25.4-cp310-abi3-win32.whl", hash = "sha256:db9fd45183e1a67722cafa5c1da3e85c6492a5383f127c86c4c4aa4845867dc4"},
|
||||
{file = "protobuf-4.25.4-cp310-abi3-win_amd64.whl", hash = "sha256:ba3d8504116a921af46499471c63a85260c1a5fc23333154a427a310e015d26d"},
|
||||
{file = "protobuf-4.25.4-cp37-abi3-macosx_10_9_universal2.whl", hash = "sha256:eecd41bfc0e4b1bd3fa7909ed93dd14dd5567b98c941d6c1ad08fdcab3d6884b"},
|
||||
{file = "protobuf-4.25.4-cp37-abi3-manylinux2014_aarch64.whl", hash = "sha256:4c8a70fdcb995dcf6c8966cfa3a29101916f7225e9afe3ced4395359955d3835"},
|
||||
{file = "protobuf-4.25.4-cp37-abi3-manylinux2014_x86_64.whl", hash = "sha256:3319e073562e2515c6ddc643eb92ce20809f5d8f10fead3332f71c63be6a7040"},
|
||||
{file = "protobuf-4.25.4-cp38-cp38-win32.whl", hash = "sha256:7e372cbbda66a63ebca18f8ffaa6948455dfecc4e9c1029312f6c2edcd86c4e1"},
|
||||
{file = "protobuf-4.25.4-cp38-cp38-win_amd64.whl", hash = "sha256:051e97ce9fa6067a4546e75cb14f90cf0232dcb3e3d508c448b8d0e4265b61c1"},
|
||||
{file = "protobuf-4.25.4-cp39-cp39-win32.whl", hash = "sha256:90bf6fd378494eb698805bbbe7afe6c5d12c8e17fca817a646cd6a1818c696ca"},
|
||||
{file = "protobuf-4.25.4-cp39-cp39-win_amd64.whl", hash = "sha256:ac79a48d6b99dfed2729ccccee547b34a1d3d63289c71cef056653a846a2240f"},
|
||||
{file = "protobuf-4.25.4-py3-none-any.whl", hash = "sha256:bfbebc1c8e4793cfd58589acfb8a1026be0003e852b9da7db5a4285bde996978"},
|
||||
{file = "protobuf-4.25.4.tar.gz", hash = "sha256:0dc4a62cc4052a036ee2204d26fe4d835c62827c855c8a03f29fe6da146b380d"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -2492,4 +2522,4 @@ test = ["big-O", "importlib-resources", "jaraco.functools", "jaraco.itertools",
|
||||
[metadata]
|
||||
lock-version = "2.0"
|
||||
python-versions = ">=3.8.1,<3.12"
|
||||
content-hash = "6da52f0e39bc7da1a80cc181bdd481e57f4644daf2f3c6da6a6b0ead2e813be9"
|
||||
content-hash = "857f47603d9dd6fe8882c7525613a54a54ee459a9ee012f3d19e510c5477f3db"
|
||||
|
@@ -26,6 +26,7 @@ scipy = [
|
||||
{version = ">=1.11", python = ">=3.12"},
|
||||
{version = "<2", python = "<3.12"}
|
||||
]
|
||||
databricks-vectorsearch = "^0.40"
|
||||
|
||||
[tool.poetry.group.test]
|
||||
optional = true
|
||||
|
69
libs/partners/databricks/tests/unit_tests/test_embeddings.py
Normal file
69
libs/partners/databricks/tests/unit_tests/test_embeddings.py
Normal file
@@ -0,0 +1,69 @@
|
||||
"""Test Together AI embeddings."""
|
||||
|
||||
from typing import Any, Dict, Generator
|
||||
from unittest import mock
|
||||
|
||||
import pytest
|
||||
from mlflow.deployments import BaseDeploymentClient # type: ignore[import-untyped]
|
||||
|
||||
from langchain_databricks import DatabricksEmbeddings
|
||||
|
||||
|
||||
def _mock_embeddings(endpoint: str, inputs: Dict[str, Any]) -> Dict[str, Any]:
|
||||
return {
|
||||
"object": "list",
|
||||
"data": [
|
||||
{
|
||||
"object": "embedding",
|
||||
"embedding": list(range(1536)),
|
||||
"index": 0,
|
||||
}
|
||||
for _ in inputs["input"]
|
||||
],
|
||||
"model": "text-embedding-3-small",
|
||||
"usage": {"prompt_tokens": 8, "total_tokens": 8},
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_client() -> Generator:
|
||||
client = mock.MagicMock()
|
||||
client.predict.side_effect = _mock_embeddings
|
||||
with mock.patch("mlflow.deployments.get_deploy_client", return_value=client):
|
||||
yield client
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def embeddings() -> DatabricksEmbeddings:
|
||||
return DatabricksEmbeddings(
|
||||
endpoint="text-embedding-3-small",
|
||||
documents_params={"fruit": "apple"},
|
||||
query_params={"fruit": "banana"},
|
||||
)
|
||||
|
||||
|
||||
def test_embed_documents(
|
||||
mock_client: BaseDeploymentClient, embeddings: DatabricksEmbeddings
|
||||
) -> None:
|
||||
documents = ["foo"] * 30
|
||||
output = embeddings.embed_documents(documents)
|
||||
assert len(output) == 30
|
||||
assert len(output[0]) == 1536
|
||||
assert mock_client.predict.call_count == 2
|
||||
assert all(
|
||||
call_arg[1]["inputs"]["fruit"] == "apple"
|
||||
for call_arg in mock_client().predict.call_args_list
|
||||
)
|
||||
|
||||
|
||||
def test_embed_query(
|
||||
mock_client: BaseDeploymentClient, embeddings: DatabricksEmbeddings
|
||||
) -> None:
|
||||
query = "foo bar"
|
||||
output = embeddings.embed_query(query)
|
||||
assert len(output) == 1536
|
||||
mock_client.predict.assert_called_once()
|
||||
assert mock_client.predict.call_args[1] == {
|
||||
"endpoint": "text-embedding-3-small",
|
||||
"inputs": {"input": [query], "fruit": "banana"},
|
||||
}
|
@@ -2,6 +2,8 @@ from langchain_databricks import __all__
|
||||
|
||||
EXPECTED_ALL = [
|
||||
"ChatDatabricks",
|
||||
"DatabricksEmbeddings",
|
||||
"DatabricksVectorSearch",
|
||||
"__version__",
|
||||
]
|
||||
|
||||
|
629
libs/partners/databricks/tests/unit_tests/test_vectorstore.py
Normal file
629
libs/partners/databricks/tests/unit_tests/test_vectorstore.py
Normal file
@@ -0,0 +1,629 @@
|
||||
import uuid
|
||||
from typing import Any, Dict, Generator, List, Optional, Set
|
||||
from unittest import mock
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from langchain_core.embeddings import Embeddings
|
||||
|
||||
from langchain_databricks.vectorstores import DatabricksVectorSearch
|
||||
|
||||
INPUT_TEXTS = ["foo", "bar", "baz"]
|
||||
DEFAULT_VECTOR_DIMENSION = 4
|
||||
|
||||
|
||||
class FakeEmbeddings(Embeddings):
|
||||
"""Fake embeddings functionality for testing."""
|
||||
|
||||
def __init__(self, dimension: int = DEFAULT_VECTOR_DIMENSION):
|
||||
super().__init__()
|
||||
self.dimension = dimension
|
||||
|
||||
def embed_documents(self, embedding_texts: List[str]) -> List[List[float]]:
|
||||
"""Return simple embeddings."""
|
||||
return [
|
||||
[float(1.0)] * (self.dimension - 1) + [float(i)]
|
||||
for i in range(len(embedding_texts))
|
||||
]
|
||||
|
||||
def embed_query(self, text: str) -> List[float]:
|
||||
"""Return simple embeddings."""
|
||||
return [float(1.0)] * (self.dimension - 1) + [float(0.0)]
|
||||
|
||||
|
||||
EMBEDDING_MODEL = FakeEmbeddings()
|
||||
|
||||
|
||||
### Dummy similarity_search() Response ###
|
||||
EXAMPLE_SEARCH_RESPONSE = {
|
||||
"manifest": {
|
||||
"column_count": 3,
|
||||
"columns": [
|
||||
{"name": "id"},
|
||||
{"name": "text"},
|
||||
{"name": "text_vector"},
|
||||
{"name": "score"},
|
||||
],
|
||||
},
|
||||
"result": {
|
||||
"row_count": len(INPUT_TEXTS),
|
||||
"data_array": sorted(
|
||||
[
|
||||
[str(uuid.uuid4()), s, e, 0.5]
|
||||
for s, e in zip(
|
||||
INPUT_TEXTS, EMBEDDING_MODEL.embed_documents(INPUT_TEXTS)
|
||||
)
|
||||
],
|
||||
key=lambda x: x[2], # type: ignore
|
||||
reverse=True,
|
||||
),
|
||||
},
|
||||
"next_page_token": "",
|
||||
}
|
||||
|
||||
|
||||
### Dummy Indices ####
|
||||
|
||||
ENDPOINT_NAME = "test-endpoint"
|
||||
DIRECT_ACCESS_INDEX = "test-direct-access-index"
|
||||
DELTA_SYNC_INDEX = "test-delta-sync-index"
|
||||
DELTA_SYNC_SELF_MANAGED_EMBEDDINGS_INDEX = "test-delta-sync-self-managed-index"
|
||||
ALL_INDEX_NAMES = {
|
||||
DIRECT_ACCESS_INDEX,
|
||||
DELTA_SYNC_INDEX,
|
||||
DELTA_SYNC_SELF_MANAGED_EMBEDDINGS_INDEX,
|
||||
}
|
||||
|
||||
INDEX_DETAILS = {
|
||||
DELTA_SYNC_INDEX: {
|
||||
"name": DELTA_SYNC_INDEX,
|
||||
"endpoint_name": ENDPOINT_NAME,
|
||||
"index_type": "DELTA_SYNC",
|
||||
"primary_key": "id",
|
||||
"delta_sync_index_spec": {
|
||||
"source_table": "ml.llm.source_table",
|
||||
"pipeline_type": "CONTINUOUS",
|
||||
"embedding_source_columns": [
|
||||
{
|
||||
"name": "text",
|
||||
"embedding_model_endpoint_name": "openai-text-embedding",
|
||||
}
|
||||
],
|
||||
},
|
||||
},
|
||||
DELTA_SYNC_SELF_MANAGED_EMBEDDINGS_INDEX: {
|
||||
"name": DELTA_SYNC_SELF_MANAGED_EMBEDDINGS_INDEX,
|
||||
"endpoint_name": ENDPOINT_NAME,
|
||||
"index_type": "DELTA_SYNC",
|
||||
"primary_key": "id",
|
||||
"delta_sync_index_spec": {
|
||||
"source_table": "ml.llm.source_table",
|
||||
"pipeline_type": "CONTINUOUS",
|
||||
"embedding_vector_columns": [
|
||||
{
|
||||
"name": "text_vector",
|
||||
"embedding_dimension": DEFAULT_VECTOR_DIMENSION,
|
||||
}
|
||||
],
|
||||
},
|
||||
},
|
||||
DIRECT_ACCESS_INDEX: {
|
||||
"name": DIRECT_ACCESS_INDEX,
|
||||
"endpoint_name": ENDPOINT_NAME,
|
||||
"index_type": "DIRECT_ACCESS",
|
||||
"primary_key": "id",
|
||||
"direct_access_index_spec": {
|
||||
"embedding_vector_columns": [
|
||||
{
|
||||
"name": "text_vector",
|
||||
"embedding_dimension": DEFAULT_VECTOR_DIMENSION,
|
||||
}
|
||||
],
|
||||
"schema_json": f"{{"
|
||||
f'"{"id"}": "int", '
|
||||
f'"feat1": "str", '
|
||||
f'"feat2": "float", '
|
||||
f'"text": "string", '
|
||||
f'"{"text_vector"}": "array<float>"'
|
||||
f"}}",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def mock_vs_client() -> Generator:
|
||||
def _get_index(endpoint: str, index_name: str) -> MagicMock:
|
||||
from databricks.vector_search.client import VectorSearchIndex # type: ignore
|
||||
|
||||
if endpoint != ENDPOINT_NAME:
|
||||
raise ValueError(f"Unknown endpoint: {endpoint}")
|
||||
|
||||
index = MagicMock(spec=VectorSearchIndex)
|
||||
index.describe.return_value = INDEX_DETAILS[index_name]
|
||||
index.similarity_search.return_value = EXAMPLE_SEARCH_RESPONSE
|
||||
return index
|
||||
|
||||
mock_client = MagicMock()
|
||||
mock_client.get_index.side_effect = _get_index
|
||||
with mock.patch(
|
||||
"databricks.vector_search.client.VectorSearchClient",
|
||||
return_value=mock_client,
|
||||
):
|
||||
yield
|
||||
|
||||
|
||||
def init_vector_search(
|
||||
index_name: str, columns: Optional[List[str]] = None
|
||||
) -> DatabricksVectorSearch:
|
||||
kwargs: Dict[str, Any] = {
|
||||
"endpoint": ENDPOINT_NAME,
|
||||
"index_name": index_name,
|
||||
"columns": columns,
|
||||
}
|
||||
if index_name != DELTA_SYNC_INDEX:
|
||||
kwargs.update(
|
||||
{
|
||||
"embedding": EMBEDDING_MODEL,
|
||||
"text_column": "text",
|
||||
}
|
||||
)
|
||||
return DatabricksVectorSearch(**kwargs) # type: ignore[arg-type]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("index_name", ALL_INDEX_NAMES)
|
||||
def test_init(index_name: str) -> None:
|
||||
vectorsearch = init_vector_search(index_name)
|
||||
assert vectorsearch.index.describe() == INDEX_DETAILS[index_name]
|
||||
|
||||
|
||||
def test_init_fail_text_column_mismatch() -> None:
|
||||
with pytest.raises(ValueError, match=f"The index '{DELTA_SYNC_INDEX}' has"):
|
||||
DatabricksVectorSearch(
|
||||
endpoint=ENDPOINT_NAME,
|
||||
index_name=DELTA_SYNC_INDEX,
|
||||
text_column="some_other_column",
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("index_name", ALL_INDEX_NAMES - {DELTA_SYNC_INDEX})
|
||||
def test_init_fail_no_text_column(index_name: str) -> None:
|
||||
with pytest.raises(ValueError, match="The `text_column` parameter is required"):
|
||||
DatabricksVectorSearch(
|
||||
endpoint=ENDPOINT_NAME,
|
||||
index_name=index_name,
|
||||
embedding=EMBEDDING_MODEL,
|
||||
)
|
||||
|
||||
|
||||
def test_init_fail_columns_not_in_schema() -> None:
|
||||
columns = ["some_random_column"]
|
||||
with pytest.raises(ValueError, match="Some columns specified in `columns`"):
|
||||
init_vector_search(DIRECT_ACCESS_INDEX, columns=columns)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("index_name", ALL_INDEX_NAMES - {DELTA_SYNC_INDEX})
|
||||
def test_init_fail_no_embedding(index_name: str) -> None:
|
||||
with pytest.raises(ValueError, match="The `embedding` parameter is required"):
|
||||
DatabricksVectorSearch(
|
||||
endpoint=ENDPOINT_NAME,
|
||||
index_name=index_name,
|
||||
text_column="text",
|
||||
)
|
||||
|
||||
|
||||
def test_init_fail_embedding_already_specified_in_source() -> None:
|
||||
with pytest.raises(ValueError, match=f"The index '{DELTA_SYNC_INDEX}' uses"):
|
||||
DatabricksVectorSearch(
|
||||
endpoint=ENDPOINT_NAME,
|
||||
index_name=DELTA_SYNC_INDEX,
|
||||
embedding=EMBEDDING_MODEL,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("index_name", ALL_INDEX_NAMES - {DELTA_SYNC_INDEX})
|
||||
def test_init_fail_embedding_dim_mismatch(index_name: str) -> None:
|
||||
with pytest.raises(
|
||||
ValueError, match="embedding model's dimension '1000' does not match"
|
||||
):
|
||||
DatabricksVectorSearch(
|
||||
endpoint=ENDPOINT_NAME,
|
||||
index_name=index_name,
|
||||
text_column="text",
|
||||
embedding=FakeEmbeddings(1000),
|
||||
)
|
||||
|
||||
|
||||
def test_from_texts_not_supported() -> None:
|
||||
with pytest.raises(NotImplementedError, match="`from_texts` is not supported"):
|
||||
DatabricksVectorSearch.from_texts(INPUT_TEXTS, EMBEDDING_MODEL)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("index_name", ALL_INDEX_NAMES - {DIRECT_ACCESS_INDEX})
|
||||
def test_add_texts_not_supported_for_delta_sync_index(index_name: str) -> None:
|
||||
vectorsearch = init_vector_search(index_name)
|
||||
with pytest.raises(
|
||||
NotImplementedError,
|
||||
match="`add_texts` is only supported for direct-access index.",
|
||||
):
|
||||
vectorsearch.add_texts(INPUT_TEXTS)
|
||||
|
||||
|
||||
def is_valid_uuid(val: str) -> bool:
|
||||
try:
|
||||
uuid.UUID(str(val))
|
||||
return True
|
||||
except ValueError:
|
||||
return False
|
||||
|
||||
|
||||
def test_add_texts() -> None:
|
||||
vectorsearch = init_vector_search(DIRECT_ACCESS_INDEX)
|
||||
ids = [idx for idx, i in enumerate(INPUT_TEXTS)]
|
||||
vectors = EMBEDDING_MODEL.embed_documents(INPUT_TEXTS)
|
||||
|
||||
added_ids = vectorsearch.add_texts(INPUT_TEXTS, ids=ids)
|
||||
vectorsearch.index.upsert.assert_called_once_with(
|
||||
[
|
||||
{
|
||||
"id": id_,
|
||||
"text": text,
|
||||
"text_vector": vector,
|
||||
}
|
||||
for text, vector, id_ in zip(INPUT_TEXTS, vectors, ids)
|
||||
]
|
||||
)
|
||||
assert len(added_ids) == len(INPUT_TEXTS)
|
||||
assert added_ids == ids
|
||||
|
||||
|
||||
def test_add_texts_handle_single_text() -> None:
|
||||
vectorsearch = init_vector_search(DIRECT_ACCESS_INDEX)
|
||||
vectors = EMBEDDING_MODEL.embed_documents(INPUT_TEXTS)
|
||||
|
||||
added_ids = vectorsearch.add_texts(INPUT_TEXTS[0])
|
||||
vectorsearch.index.upsert.assert_called_once_with(
|
||||
[
|
||||
{
|
||||
"id": id_,
|
||||
"text": text,
|
||||
"text_vector": vector,
|
||||
}
|
||||
for text, vector, id_ in zip(INPUT_TEXTS, vectors, added_ids)
|
||||
]
|
||||
)
|
||||
assert len(added_ids) == 1
|
||||
assert is_valid_uuid(added_ids[0])
|
||||
|
||||
|
||||
def test_add_texts_with_default_id() -> None:
|
||||
vectorsearch = init_vector_search(DIRECT_ACCESS_INDEX)
|
||||
vectors = EMBEDDING_MODEL.embed_documents(INPUT_TEXTS)
|
||||
|
||||
added_ids = vectorsearch.add_texts(INPUT_TEXTS)
|
||||
vectorsearch.index.upsert.assert_called_once_with(
|
||||
[
|
||||
{
|
||||
"id": id_,
|
||||
"text": text,
|
||||
"text_vector": vector,
|
||||
}
|
||||
for text, vector, id_ in zip(INPUT_TEXTS, vectors, added_ids)
|
||||
]
|
||||
)
|
||||
assert len(added_ids) == len(INPUT_TEXTS)
|
||||
assert all([is_valid_uuid(id_) for id_ in added_ids])
|
||||
|
||||
|
||||
def test_add_texts_with_metadata() -> None:
|
||||
vectorsearch = init_vector_search(DIRECT_ACCESS_INDEX)
|
||||
vectors = EMBEDDING_MODEL.embed_documents(INPUT_TEXTS)
|
||||
metadatas = [{"feat1": str(i), "feat2": i + 1000} for i in range(len(INPUT_TEXTS))]
|
||||
|
||||
added_ids = vectorsearch.add_texts(INPUT_TEXTS, metadatas=metadatas)
|
||||
vectorsearch.index.upsert.assert_called_once_with(
|
||||
[
|
||||
{
|
||||
"id": id_,
|
||||
"text": text,
|
||||
"text_vector": vector,
|
||||
**metadata, # type: ignore[arg-type]
|
||||
}
|
||||
for text, vector, id_, metadata in zip(
|
||||
INPUT_TEXTS, vectors, added_ids, metadatas
|
||||
)
|
||||
]
|
||||
)
|
||||
assert len(added_ids) == len(INPUT_TEXTS)
|
||||
assert all([is_valid_uuid(id_) for id_ in added_ids])
|
||||
|
||||
|
||||
@pytest.mark.parametrize("index_name", ALL_INDEX_NAMES - {DELTA_SYNC_INDEX})
|
||||
def test_embeddings_property(index_name: str) -> None:
|
||||
vectorsearch = init_vector_search(index_name)
|
||||
assert vectorsearch.embeddings == EMBEDDING_MODEL
|
||||
|
||||
|
||||
def test_delete() -> None:
|
||||
vectorsearch = init_vector_search(DIRECT_ACCESS_INDEX)
|
||||
vectorsearch.delete(["some id"])
|
||||
vectorsearch.index.delete.assert_called_once_with(["some id"])
|
||||
|
||||
|
||||
def test_delete_fail_no_ids() -> None:
|
||||
vectorsearch = init_vector_search(DIRECT_ACCESS_INDEX)
|
||||
with pytest.raises(ValueError, match="ids must be provided."):
|
||||
vectorsearch.delete()
|
||||
|
||||
|
||||
@pytest.mark.parametrize("index_name", ALL_INDEX_NAMES - {DIRECT_ACCESS_INDEX})
|
||||
def test_delete_not_supported_for_delta_sync_index(index_name: str) -> None:
|
||||
vectorsearch = init_vector_search(index_name)
|
||||
with pytest.raises(
|
||||
NotImplementedError, match="`delete` is only supported for direct-access"
|
||||
):
|
||||
vectorsearch.delete(["some id"])
|
||||
|
||||
|
||||
@pytest.mark.parametrize("index_name", ALL_INDEX_NAMES)
|
||||
@pytest.mark.parametrize("query_type", [None, "ANN"])
|
||||
def test_similarity_search(index_name: str, query_type: Optional[str]) -> None:
|
||||
vectorsearch = init_vector_search(index_name)
|
||||
query = "foo"
|
||||
filters = {"some filter": True}
|
||||
limit = 7
|
||||
|
||||
search_result = vectorsearch.similarity_search(
|
||||
query, k=limit, filter=filters, query_type=query_type
|
||||
)
|
||||
if index_name == DELTA_SYNC_INDEX:
|
||||
vectorsearch.index.similarity_search.assert_called_once_with(
|
||||
columns=["id", "text"],
|
||||
query_text=query,
|
||||
query_vector=None,
|
||||
filters=filters,
|
||||
num_results=limit,
|
||||
query_type=query_type,
|
||||
)
|
||||
else:
|
||||
vectorsearch.index.similarity_search.assert_called_once_with(
|
||||
columns=["id", "text"],
|
||||
query_text=None,
|
||||
query_vector=EMBEDDING_MODEL.embed_query(query),
|
||||
filters=filters,
|
||||
num_results=limit,
|
||||
query_type=query_type,
|
||||
)
|
||||
assert len(search_result) == len(INPUT_TEXTS)
|
||||
assert sorted([d.page_content for d in search_result]) == sorted(INPUT_TEXTS)
|
||||
assert all(["id" in d.metadata for d in search_result])
|
||||
|
||||
|
||||
@pytest.mark.parametrize("index_name", ALL_INDEX_NAMES)
|
||||
def test_similarity_search_hybrid(index_name: str) -> None:
|
||||
vectorsearch = init_vector_search(index_name)
|
||||
query = "foo"
|
||||
filters = {"some filter": True}
|
||||
limit = 7
|
||||
|
||||
search_result = vectorsearch.similarity_search(
|
||||
query, k=limit, filter=filters, query_type="HYBRID"
|
||||
)
|
||||
if index_name == DELTA_SYNC_INDEX:
|
||||
vectorsearch.index.similarity_search.assert_called_once_with(
|
||||
columns=["id", "text"],
|
||||
query_text=query,
|
||||
query_vector=None,
|
||||
filters=filters,
|
||||
num_results=limit,
|
||||
query_type="HYBRID",
|
||||
)
|
||||
else:
|
||||
vectorsearch.index.similarity_search.assert_called_once_with(
|
||||
columns=["id", "text"],
|
||||
query_text=query,
|
||||
query_vector=EMBEDDING_MODEL.embed_query(query),
|
||||
filters=filters,
|
||||
num_results=limit,
|
||||
query_type="HYBRID",
|
||||
)
|
||||
assert len(search_result) == len(INPUT_TEXTS)
|
||||
assert sorted([d.page_content for d in search_result]) == sorted(INPUT_TEXTS)
|
||||
assert all(["id" in d.metadata for d in search_result])
|
||||
|
||||
|
||||
def test_similarity_search_both_filter_and_filters_passed() -> None:
|
||||
vectorsearch = init_vector_search(DIRECT_ACCESS_INDEX)
|
||||
query = "foo"
|
||||
filter = {"some filter": True}
|
||||
filters = {"some other filter": False}
|
||||
|
||||
vectorsearch.similarity_search(query, filter=filter, filters=filters)
|
||||
vectorsearch.index.similarity_search.assert_called_once_with(
|
||||
columns=["id", "text"],
|
||||
query_vector=EMBEDDING_MODEL.embed_query(query),
|
||||
# `filter` should prevail over `filters`
|
||||
filters=filter,
|
||||
num_results=4,
|
||||
query_text=None,
|
||||
query_type=None,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("index_name", ALL_INDEX_NAMES - {DELTA_SYNC_INDEX})
|
||||
@pytest.mark.parametrize(
|
||||
"columns, expected_columns",
|
||||
[
|
||||
(None, {"id"}),
|
||||
(["id", "text", "text_vector"], {"text_vector", "id"}),
|
||||
],
|
||||
)
|
||||
def test_mmr_search(
|
||||
index_name: str, columns: Optional[List[str]], expected_columns: Set[str]
|
||||
) -> None:
|
||||
vectorsearch = init_vector_search(index_name, columns=columns)
|
||||
|
||||
query = INPUT_TEXTS[0]
|
||||
filters = {"some filter": True}
|
||||
limit = 1
|
||||
|
||||
search_result = vectorsearch.max_marginal_relevance_search(
|
||||
query, k=limit, filters=filters
|
||||
)
|
||||
assert [doc.page_content for doc in search_result] == [INPUT_TEXTS[0]]
|
||||
assert [set(doc.metadata.keys()) for doc in search_result] == [expected_columns]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("index_name", ALL_INDEX_NAMES - {DELTA_SYNC_INDEX})
|
||||
def test_mmr_parameters(index_name: str) -> None:
|
||||
vectorsearch = init_vector_search(index_name)
|
||||
|
||||
query = INPUT_TEXTS[0]
|
||||
limit = 1
|
||||
fetch_k = 3
|
||||
lambda_mult = 0.25
|
||||
filters = {"some filter": True}
|
||||
|
||||
with patch(
|
||||
"langchain_databricks.vectorstores.maximal_marginal_relevance"
|
||||
) as mock_mmr:
|
||||
mock_mmr.return_value = [2]
|
||||
retriever = vectorsearch.as_retriever(
|
||||
search_type="mmr",
|
||||
search_kwargs={
|
||||
"k": limit,
|
||||
"fetch_k": fetch_k,
|
||||
"lambda_mult": lambda_mult,
|
||||
"filter": filters,
|
||||
},
|
||||
)
|
||||
search_result = retriever.invoke(query)
|
||||
|
||||
mock_mmr.assert_called_once()
|
||||
assert mock_mmr.call_args[1]["lambda_mult"] == lambda_mult
|
||||
assert vectorsearch.index.similarity_search.call_args[1]["num_results"] == fetch_k
|
||||
assert vectorsearch.index.similarity_search.call_args[1]["filters"] == filters
|
||||
assert len(search_result) == limit
|
||||
|
||||
|
||||
@pytest.mark.parametrize("index_name", ALL_INDEX_NAMES)
|
||||
@pytest.mark.parametrize("threshold", [0.4, 0.5, 0.8])
|
||||
def test_similarity_score_threshold(index_name: str, threshold: float) -> None:
|
||||
query = INPUT_TEXTS[0]
|
||||
limit = len(INPUT_TEXTS)
|
||||
|
||||
vectorsearch = init_vector_search(index_name)
|
||||
retriever = vectorsearch.as_retriever(
|
||||
search_type="similarity_score_threshold",
|
||||
search_kwargs={"k": limit, "score_threshold": threshold},
|
||||
)
|
||||
search_result = retriever.invoke(query)
|
||||
if threshold <= 0.5:
|
||||
assert len(search_result) == len(INPUT_TEXTS)
|
||||
else:
|
||||
assert len(search_result) == 0
|
||||
|
||||
|
||||
def test_standard_params() -> None:
|
||||
vectorstore = init_vector_search(DIRECT_ACCESS_INDEX)
|
||||
retriever = vectorstore.as_retriever()
|
||||
ls_params = retriever._get_ls_params()
|
||||
assert ls_params == {
|
||||
"ls_retriever_name": "vectorstore",
|
||||
"ls_vector_store_provider": "DatabricksVectorSearch",
|
||||
"ls_embedding_provider": "FakeEmbeddings",
|
||||
}
|
||||
|
||||
vectorstore = init_vector_search(DELTA_SYNC_INDEX)
|
||||
retriever = vectorstore.as_retriever()
|
||||
ls_params = retriever._get_ls_params()
|
||||
assert ls_params == {
|
||||
"ls_retriever_name": "vectorstore",
|
||||
"ls_vector_store_provider": "DatabricksVectorSearch",
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.parametrize("index_name", ALL_INDEX_NAMES - {DELTA_SYNC_INDEX})
|
||||
@pytest.mark.parametrize("query_type", [None, "ANN"])
|
||||
def test_similarity_search_by_vector(
|
||||
index_name: str, query_type: Optional[str]
|
||||
) -> None:
|
||||
vectorsearch = init_vector_search(index_name)
|
||||
query_embedding = EMBEDDING_MODEL.embed_query("foo")
|
||||
filters = {"some filter": True}
|
||||
limit = 7
|
||||
|
||||
search_result = vectorsearch.similarity_search_by_vector(
|
||||
query_embedding, k=limit, filter=filters, query_type=query_type
|
||||
)
|
||||
vectorsearch.index.similarity_search.assert_called_once_with(
|
||||
columns=["id", "text"],
|
||||
query_vector=query_embedding,
|
||||
filters=filters,
|
||||
num_results=limit,
|
||||
query_type=query_type,
|
||||
query_text=None,
|
||||
)
|
||||
assert len(search_result) == len(INPUT_TEXTS)
|
||||
assert sorted([d.page_content for d in search_result]) == sorted(INPUT_TEXTS)
|
||||
assert all(["id" in d.metadata for d in search_result])
|
||||
|
||||
|
||||
@pytest.mark.parametrize("index_name", ALL_INDEX_NAMES - {DELTA_SYNC_INDEX})
|
||||
def test_similarity_search_by_vector_hybrid(index_name: str) -> None:
|
||||
vectorsearch = init_vector_search(index_name)
|
||||
query_embedding = EMBEDDING_MODEL.embed_query("foo")
|
||||
filters = {"some filter": True}
|
||||
limit = 7
|
||||
|
||||
search_result = vectorsearch.similarity_search_by_vector(
|
||||
query_embedding, k=limit, filter=filters, query_type="HYBRID", query="foo"
|
||||
)
|
||||
vectorsearch.index.similarity_search.assert_called_once_with(
|
||||
columns=["id", "text"],
|
||||
query_vector=query_embedding,
|
||||
filters=filters,
|
||||
num_results=limit,
|
||||
query_type="HYBRID",
|
||||
query_text="foo",
|
||||
)
|
||||
assert len(search_result) == len(INPUT_TEXTS)
|
||||
assert sorted([d.page_content for d in search_result]) == sorted(INPUT_TEXTS)
|
||||
assert all(["id" in d.metadata for d in search_result])
|
||||
|
||||
|
||||
@pytest.mark.parametrize("index_name", ALL_INDEX_NAMES)
|
||||
def test_similarity_search_empty_result(index_name: str) -> None:
|
||||
vectorsearch = init_vector_search(index_name)
|
||||
vectorsearch.index.similarity_search.return_value = {
|
||||
"manifest": {
|
||||
"column_count": 3,
|
||||
"columns": [
|
||||
{"name": "id"},
|
||||
{"name": "text"},
|
||||
{"name": "score"},
|
||||
],
|
||||
},
|
||||
"result": {
|
||||
"row_count": 0,
|
||||
"data_array": [],
|
||||
},
|
||||
"next_page_token": "",
|
||||
}
|
||||
|
||||
search_result = vectorsearch.similarity_search("foo")
|
||||
assert len(search_result) == 0
|
||||
|
||||
|
||||
def test_similarity_search_by_vector_not_supported_for_managed_embedding() -> None:
|
||||
vectorsearch = init_vector_search(DELTA_SYNC_INDEX)
|
||||
query_embedding = EMBEDDING_MODEL.embed_query("foo")
|
||||
filters = {"some filter": True}
|
||||
limit = 7
|
||||
|
||||
with pytest.raises(
|
||||
NotImplementedError, match="`similarity_search_by_vector` is not supported"
|
||||
):
|
||||
vectorsearch.similarity_search_by_vector(
|
||||
query_embedding, k=limit, filters=filters
|
||||
)
|
Reference in New Issue
Block a user