databricks: add vector search and embeddings (#25648)

### Summary

Add `DatabricksVectorSearch` and `DatabricksEmbeddings` classes to the
`langchain-databricks` partner package. Core functionality is
unchanged, but the vector search class has been largely refactored for
readability and maintainability.
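For context, a minimal usage sketch of the two new classes (the endpoint and index names below are placeholders, not real resources):

```python
from langchain_databricks import DatabricksEmbeddings, DatabricksVectorSearch

# Embeddings served by a Databricks Model Serving endpoint.
embeddings = DatabricksEmbeddings(endpoint="databricks-bge-large-en")

# A direct-access index (or a delta-sync index with self-managed embeddings)
# needs `embedding` and `text_column`; Databricks-managed embeddings need neither.
vector_store = DatabricksVectorSearch(
    endpoint="<endpoint-name>",
    index_name="<catalog>.<schema>.<index>",
    embedding=embeddings,
    text_column="text",
)
docs = vector_store.similarity_search("what is databricks vector search?", k=4)
```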

This PR does not add integration tests yet; they will be added once the
Databricks test workspace is ready.

Tagging @efriis as POC


### Tracker
[x] Create a package and migrate ChatDatabricks
[✍️] Migrate DatabricksVectorSearch, DatabricksEmbeddings, and their docs
~[ ] Migrate UCFunctionToolkit and its doc~
[ ] Add provider document and update README.md
[ ] Add integration tests and set up secrets (after moving to an external package)
[ ] Add deprecation note to the community implementations.

---------

Signed-off-by: B-Step62 <yuki.watanabe@databricks.com>
Co-authored-by: Erick Friis <erick@langchain.dev>
Author: Yuki Watanabe
Date: 2024-08-24 09:40:21 +09:00
Committed by: GitHub
parent 71c039571a
commit c7a8af2e75
12 changed files with 2321 additions and 215 deletions


@@ -1,6 +1,8 @@
from importlib import metadata
from langchain_databricks.chat_models import ChatDatabricks
from langchain_databricks.embeddings import DatabricksEmbeddings
from langchain_databricks.vectorstores import DatabricksVectorSearch
try:
__version__ = metadata.version(__package__)
@@ -11,5 +13,7 @@ del metadata # optional, avoids polluting the results of dir(__package__)
__all__ = [
"ChatDatabricks",
"DatabricksEmbeddings",
"DatabricksVectorSearch",
"__version__",
]


@@ -15,7 +15,6 @@ from typing import (
Type,
Union,
)
-from urllib.parse import urlparse
from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.language_models import BaseChatModel
@@ -50,6 +49,8 @@ from langchain_core.runnables import Runnable
from langchain_core.tools import BaseTool
from langchain_core.utils.function_calling import convert_to_openai_tool
+from langchain_databricks.utils import get_deployment_client
logger = logging.getLogger(__name__)
@@ -230,25 +231,7 @@ class ChatDatabricks(BaseChatModel):
def __init__(self, **kwargs: Any):
super().__init__(**kwargs)
-self._validate_uri()
-try:
-from mlflow.deployments import get_deploy_client  # type: ignore
-self._client = get_deploy_client(self.target_uri)
-except ImportError as e:
-raise ImportError(
-"Failed to create the client. Please run `pip install mlflow` to "
-"install required dependencies."
-) from e
-def _validate_uri(self) -> None:
-if self.target_uri == "databricks":
-return
-if urlparse(self.target_uri).scheme != "databricks":
-raise ValueError(
-"Invalid target URI. The target URI must be a valid databricks URI."
-)
+self._client = get_deployment_client(self.target_uri)
@property
def _default_params(self) -> Dict[str, Any]:


@@ -0,0 +1,91 @@
from typing import Any, Dict, Iterator, List
from langchain_core.embeddings import Embeddings
from langchain_core.pydantic_v1 import BaseModel, PrivateAttr
from langchain_databricks.utils import get_deployment_client
class DatabricksEmbeddings(Embeddings, BaseModel):
"""Databricks embedding model integration.
Setup:
Install ``langchain-databricks``.
.. code-block:: bash
pip install -U langchain-databricks
If you are outside Databricks, set the Databricks workspace
hostname and personal access token to environment variables:
.. code-block:: bash
export DATABRICKS_HOSTNAME="https://your-databricks-workspace"
export DATABRICKS_TOKEN="your-personal-access-token"
Key init args — embedding params:
endpoint: str
Name of Databricks Model Serving endpoint to query.
target_uri: str
The target URI to use. Defaults to ``databricks``.
query_params: Dict[str, str]
The parameters to use for queries.
documents_params: Dict[str, str]
The parameters to use for documents.
Instantiate:
.. code-block:: python
from langchain_databricks import DatabricksEmbeddings
embed = DatabricksEmbeddings(
endpoint="databricks-bge-large-en",
)
Embed single text:
.. code-block:: python
input_text = "The meaning of life is 42"
embed.embed_query(input_text)
.. code-block:: python
[
0.01605224609375,
-0.0298309326171875,
...
]
"""
endpoint: str
"""The endpoint to use."""
target_uri: str = "databricks"
"""The target URI to use."""
query_params: Dict[str, Any] = {}
"""The parameters to use for queries."""
documents_params: Dict[str, Any] = {}
"""The parameters to use for documents."""
_client: Any = PrivateAttr()
def __init__(self, **kwargs: Any):
super().__init__(**kwargs)
self._client = get_deployment_client(self.target_uri)
def embed_documents(self, texts: List[str]) -> List[List[float]]:
return self._embed(texts, params=self.documents_params)
def embed_query(self, text: str) -> List[float]:
return self._embed([text], params=self.query_params)[0]
def _embed(self, texts: List[str], params: Dict[str, str]) -> List[List[float]]:
embeddings: List[List[float]] = []
# Query the endpoint in batches of 20 texts per request.
for txt in _chunk(texts, 20):
resp = self._client.predict(
endpoint=self.endpoint,
inputs={"input": txt, **params}, # type: ignore[arg-type]
)
embeddings.extend(r["embedding"] for r in resp["data"])
return embeddings
def _chunk(texts: List[str], size: int) -> Iterator[List[str]]:
for i in range(0, len(texts), size):
yield texts[i : i + size]
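Because `_embed` sends at most 20 texts per `predict` call, large inputs fan out into multiple endpoint requests; the unit tests below rely on this (30 documents → 2 calls). A minimal sketch of the batching arithmetic:

```python
# Sketch only: 45 texts split into batches of 20, 20, and 5 (three requests).
texts = [f"doc {i}" for i in range(45)]
batches = [texts[i : i + 20] for i in range(0, len(texts), 20)]
assert [len(b) for b in batches] == [20, 20, 5]
```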


@@ -0,0 +1,101 @@
from typing import Any, List, Union
from urllib.parse import urlparse
import numpy as np
def get_deployment_client(target_uri: str) -> Any:
if (target_uri != "databricks") and (urlparse(target_uri).scheme != "databricks"):
raise ValueError(
"Invalid target URI. The target URI must be a valid databricks URI."
)
try:
from mlflow.deployments import get_deploy_client # type: ignore[import-untyped]
return get_deploy_client(target_uri)
except ImportError as e:
raise ImportError(
"Failed to create the client. "
"Please run `pip install mlflow` to install "
"required dependencies."
) from e
# Utility function for Maximal Marginal Relevance (MMR) reranking.
# Copied from langchain_community/vectorstores/utils.py to avoid cross-dependency
Matrix = Union[List[List[float]], List[np.ndarray], np.ndarray]
def maximal_marginal_relevance(
query_embedding: np.ndarray,
embedding_list: list,
lambda_mult: float = 0.5,
k: int = 4,
) -> List[int]:
"""Calculate maximal marginal relevance.
Args:
query_embedding: Query embedding.
embedding_list: List of embeddings to select from.
lambda_mult: Number between 0 and 1 that determines the degree
of diversity among the results with 0 corresponding
to maximum diversity and 1 to minimum diversity.
Defaults to 0.5.
k: Number of Documents to return. Defaults to 4.
Returns:
List of indices of embeddings selected by maximal marginal relevance.
"""
if min(k, len(embedding_list)) <= 0:
return []
if query_embedding.ndim == 1:
query_embedding = np.expand_dims(query_embedding, axis=0)
similarity_to_query = cosine_similarity(query_embedding, embedding_list)[0]
most_similar = int(np.argmax(similarity_to_query))
idxs = [most_similar]
selected = np.array([embedding_list[most_similar]])
while len(idxs) < min(k, len(embedding_list)):
best_score = -np.inf
idx_to_add = -1
similarity_to_selected = cosine_similarity(embedding_list, selected)
for i, query_score in enumerate(similarity_to_query):
if i in idxs:
continue
redundant_score = max(similarity_to_selected[i])
equation_score = (
lambda_mult * query_score - (1 - lambda_mult) * redundant_score
)
if equation_score > best_score:
best_score = equation_score
idx_to_add = i
idxs.append(idx_to_add)
selected = np.append(selected, [embedding_list[idx_to_add]], axis=0)
return idxs
def cosine_similarity(X: Matrix, Y: Matrix) -> np.ndarray:
"""Row-wise cosine similarity between two equal-width matrices.
Raises:
ValueError: If the number of columns in X and Y are not the same.
"""
if len(X) == 0 or len(Y) == 0:
return np.array([])
X = np.array(X)
Y = np.array(Y)
if X.shape[1] != Y.shape[1]:
raise ValueError(
"Number of columns in X and Y must be the same. "
f"X has shape {X.shape} and Y has shape {Y.shape}."
)
X_norm = np.linalg.norm(X, axis=1)
Y_norm = np.linalg.norm(Y, axis=1)
# Ignore divide-by-zero and invalid-value runtime warnings; NaN/inf entries are zeroed below.
with np.errstate(divide="ignore", invalid="ignore"):
similarity = np.dot(X, Y.T) / np.outer(X_norm, Y_norm)
similarity[np.isnan(similarity) | np.isinf(similarity)] = 0.0
return similarity
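Two illustrative sketches of the helpers above. First, `get_deployment_client` accepts either the literal `databricks` target or a `databricks`-scheme URI and rejects anything else before attempting the mlflow import; second, a toy run of `maximal_marginal_relevance` showing how `lambda_mult` trades relevance for diversity. All values below are made up for illustration, and the snippet assumes these helpers are in scope and mlflow is installed:

```python
import numpy as np

# URI validation: both forms below pass; other schemes raise ValueError.
# ("profile" is a placeholder mlflow/Databricks profile name.)
get_deployment_client("databricks")
get_deployment_client("databricks://profile")

# MMR on toy 2-D vectors: a near-duplicate pair plus one dissimilar vector.
query = np.array([1.0, 0.0], dtype=np.float32)
candidates = [
    np.array([0.9, 0.1]),    # most similar to the query
    np.array([0.89, 0.11]),  # near-duplicate of the first candidate
    np.array([0.0, 1.0]),    # dissimilar, adds diversity
]
# High lambda_mult favors relevance, keeping the near-duplicate ...
print(maximal_marginal_relevance(query, candidates, lambda_mult=0.9, k=2))  # [0, 1]
# ... while low lambda_mult favors diversity, picking the orthogonal vector.
print(maximal_marginal_relevance(query, candidates, lambda_mult=0.1, k=2))  # [0, 2]
```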


@@ -0,0 +1,837 @@
from __future__ import annotations
import asyncio
import json
import logging
import uuid
from enum import Enum
from functools import partial
from typing import (
Any,
Callable,
Dict,
Iterable,
List,
Optional,
Tuple,
Type,
)
import numpy as np
from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings
from langchain_core.vectorstores import VST, VectorStore
from langchain_databricks.utils import maximal_marginal_relevance
logger = logging.getLogger(__name__)
class IndexType(str, Enum):
DIRECT_ACCESS = "DIRECT_ACCESS"
DELTA_SYNC = "DELTA_SYNC"
_DIRECT_ACCESS_ONLY_MSG = "`%s` is only supported for direct-access index."
_NON_MANAGED_EMB_ONLY_MSG = (
"`%s` is not supported for index with Databricks-managed embeddings."
)
class DatabricksVectorSearch(VectorStore):
"""Databricks vector store integration.
Setup:
Install ``langchain-databricks`` and ``databricks-vectorsearch`` python packages.
.. code-block:: bash
pip install -U langchain-databricks databricks-vectorsearch
If you don't have a Databricks Vector Search endpoint already, you can create one by following the instructions here: https://docs.databricks.com/en/generative-ai/create-query-vector-search.html
If you are outside Databricks, set the Databricks workspace
hostname and personal access token to environment variables:
.. code-block:: bash
export DATABRICKS_HOSTNAME="https://your-databricks-workspace"
export DATABRICKS_TOKEN="your-personal-access-token"
Key init args — indexing params:
endpoint: The name of the Databricks Vector Search endpoint.
index_name: The name of the index to use. Format: "catalog.schema.index".
embedding: The embedding model.
Required for direct-access index or delta-sync index
with self-managed embeddings.
text_column: The name of the text column to use for the embeddings.
Required for direct-access index or delta-sync index
with self-managed embeddings.
Make sure the text column specified is in the index.
columns: The list of column names to get when doing the search.
Defaults to ``[primary_key, text_column]``.
Instantiate:
`DatabricksVectorSearch` supports two types of indexes:
* **Delta Sync Index** automatically syncs with a source Delta Table, incrementally updating the index as the underlying data in the Delta Table changes.
* **Direct Vector Access Index** supports direct read and write of vectors and metadata. The user is responsible for updating this table using the REST API or the Python SDK.
Also, for a delta-sync index, you can choose to use Databricks-managed embeddings or self-managed embeddings (via LangChain embeddings classes).
If you are using a delta-sync index with Databricks-managed embeddings:
.. code-block:: python
from langchain_databricks.vectorstores import DatabricksVectorSearch
vector_store = DatabricksVectorSearch(
endpoint="<your-endpoint-name>",
index_name="<your-index-name>"
)
If you are using a direct-access index or a delta-sync index with self-managed embeddings,
you also need to provide the embedding model and text column in your source table to
use for the embeddings:
.. code-block:: python
from langchain_openai import OpenAIEmbeddings
vector_store = DatabricksVectorSearch(
endpoint="<your-endpoint-name>",
index_name="<your-index-name>",
embedding=OpenAIEmbeddings(),
text_column="document_content"
)
Add Documents:
.. code-block:: python
from langchain_core.documents import Document
document_1 = Document(page_content="foo", metadata={"baz": "bar"})
document_2 = Document(page_content="thud", metadata={"bar": "baz"})
document_3 = Document(page_content="i will be deleted :(")
documents = [document_1, document_2, document_3]
ids = ["1", "2", "3"]
vector_store.add_documents(documents=documents, ids=ids)
Delete Documents:
.. code-block:: python
vector_store.delete(ids=["3"])
.. note::
The `delete` method is only supported for direct-access index.
Search:
.. code-block:: python
results = vector_store.similarity_search(query="thud", k=1)
for doc in results:
print(f"* {doc.page_content} [{doc.metadata}]")
.. code-block:: python
* thud [{'id': '2'}]
.. note::
By default, similarity search only returns the primary key and text column.
If you want to retrieve the custom metadata associated with the document,
pass the additional columns in the `columns` parameter when initializing the vector store.
.. code-block:: python
vector_store = DatabricksVectorSearch(
endpoint="<your-endpoint-name>",
index_name="<your-index-name>",
columns=["baz", "bar"],
)
vector_store.similarity_search(query="thud", k=1)
# Output: * thud [{'bar': 'baz', 'baz': None, 'id': '2'}]
Search with filter:
.. code-block:: python
results = vector_store.similarity_search(query="thud", k=1, filter={"bar": "baz"})
for doc in results:
print(f"* {doc.page_content} [{doc.metadata}]")
.. code-block:: python
* thud [{'id': '2'}]
Search with score:
.. code-block:: python
results = vector_store.similarity_search_with_score(query="qux", k=1)
for doc, score in results:
print(f"* [SIM={score:3f}] {doc.page_content} [{doc.metadata}]")
.. code-block:: python
* [SIM=0.748804] foo [{'id': '1'}]
Async:
.. code-block:: python
# add documents
await vector_store.aadd_documents(documents=documents, ids=ids)
# delete documents
await vector_store.adelete(ids=["3"])
# search
results = await vector_store.asimilarity_search(query="thud", k=1)
# search with score
results = await vector_store.asimilarity_search_with_score(query="qux", k=1)
for doc, score in results:
print(f"* [SIM={score:3f}] {doc.page_content} [{doc.metadata}]")
.. code-block:: python
* [SIM=0.748807] foo [{'id': '1'}]
Use as Retriever:
.. code-block:: python
retriever = vector_store.as_retriever(
search_type="mmr",
search_kwargs={"k": 1, "fetch_k": 2, "lambda_mult": 0.5},
)
retriever.invoke("thud")
.. code-block:: python
[Document(metadata={'id': '2'}, page_content='thud')]
""" # noqa: E501
def __init__(
self,
endpoint: str,
index_name: str,
embedding: Optional[Embeddings] = None,
text_column: Optional[str] = None,
columns: Optional[List[str]] = None,
):
try:
from databricks.vector_search.client import ( # type: ignore[import]
VectorSearchClient,
)
except ImportError as e:
raise ImportError(
"Could not import databricks-vectorsearch python package. "
"Please install it with `pip install databricks-vectorsearch`."
) from e
self.index = VectorSearchClient().get_index(endpoint, index_name)
self._index_details = IndexDetails(self.index)
_validate_embedding(embedding, self._index_details)
self._embeddings = embedding
self._text_column = _validate_and_get_text_column(
text_column, self._index_details
)
self._columns = _validate_and_get_return_columns(
columns or [], self._text_column, self._index_details
)
self._primary_key = self._index_details.primary_key
@property
def embeddings(self) -> Optional[Embeddings]:
"""Access the query embedding object if available."""
return self._embeddings
@classmethod
def from_texts(
cls: Type[VST],
texts: List[str],
embedding: Embeddings,
metadatas: Optional[List[Dict]] = None,
**kwargs: Any,
) -> VST:
raise NotImplementedError(
"`from_texts` is not supported. "
"Use `add_texts` to add to existing direct-access index."
)
def add_texts(
self,
texts: Iterable[str],
metadatas: Optional[List[Dict]] = None,
ids: Optional[List[Any]] = None,
**kwargs: Any,
) -> List[str]:
"""Add texts to the index.
.. note::
This method is only supported for a direct-access index.
Args:
texts: List of texts to add.
metadatas: List of metadata for each text. Defaults to None.
ids: List of ids for each text. Defaults to None.
If not provided, a random uuid will be generated for each text.
Returns:
List of ids from adding the texts into the index.
"""
if self._index_details.is_delta_sync_index():
raise NotImplementedError(_DIRECT_ACCESS_ONLY_MSG % "add_texts")
# Wrap to list if input texts is a single string
if isinstance(texts, str):
texts = [texts]
texts = list(texts)
vectors = self._embeddings.embed_documents(texts) # type: ignore[union-attr]
ids = ids or [str(uuid.uuid4()) for _ in texts]
metadatas = metadatas or [{} for _ in texts]
updates = [
{
self._primary_key: id_,
self._text_column: text,
self._index_details.embedding_vector_column["name"]: vector,
**metadata,
}
for text, vector, id_, metadata in zip(texts, vectors, ids, metadatas)
]
upsert_resp = self.index.upsert(updates)
if upsert_resp.get("status") in ("PARTIAL_SUCCESS", "FAILURE"):
failed_ids = upsert_resp.get("result", dict()).get(
"failed_primary_keys", []
)
if upsert_resp.get("status") == "FAILURE":
logger.error("Failed to add texts to the index.")
else:
logger.warning("Some texts failed to be added to the index.")
return [id_ for id_ in ids if id_ not in failed_ids]
return ids
async def aadd_texts(
self,
texts: Iterable[str],
metadatas: Optional[List[dict]] = None,
**kwargs: Any,
) -> List[str]:
return await asyncio.get_running_loop().run_in_executor(
None, partial(self.add_texts, **kwargs), texts, metadatas
)
def delete(self, ids: Optional[List[Any]] = None, **kwargs: Any) -> Optional[bool]:
"""Delete documents from the index.
.. note::
This method is only supported for a direct-access index.
Args:
ids: List of ids of documents to delete.
Returns:
True if successful.
"""
if self._index_details.is_delta_sync_index():
raise NotImplementedError(_DIRECT_ACCESS_ONLY_MSG % "delete")
if ids is None:
raise ValueError("ids must be provided.")
self.index.delete(ids)
return True
def similarity_search(
self,
query: str,
k: int = 4,
filter: Optional[Dict[str, Any]] = None,
*,
query_type: Optional[str] = None,
**kwargs: Any,
) -> List[Document]:
"""Return docs most similar to query.
Args:
query: Text to look up documents similar to.
k: Number of Documents to return. Defaults to 4.
filter: Filters to apply to the query. Defaults to None.
query_type: The type of this query. Supported values are "ANN" and "HYBRID".
Returns:
List of Documents most similar to the embedding.
"""
docs_with_score = self.similarity_search_with_score(
query=query,
k=k,
filter=filter,
query_type=query_type,
**kwargs,
)
return [doc for doc, _ in docs_with_score]
async def asimilarity_search(
self, query: str, k: int = 4, **kwargs: Any
) -> List[Document]:
# This is a temporary workaround to make the similarity search
# asynchronous. The proper solution is to make the similarity search
# asynchronous in the vector store implementations.
func = partial(self.similarity_search, query, k=k, **kwargs)
return await asyncio.get_event_loop().run_in_executor(None, func)
def similarity_search_with_score(
self,
query: str,
k: int = 4,
filter: Optional[Dict[str, Any]] = None,
*,
query_type: Optional[str] = None,
**kwargs: Any,
) -> List[Tuple[Document, float]]:
"""Return docs most similar to query, along with scores.
Args:
query: Text to look up documents similar to.
k: Number of Documents to return. Defaults to 4.
filter: Filters to apply to the query. Defaults to None.
query_type: The type of this query. Supported values are "ANN" and "HYBRID".
Returns:
List of Documents most similar to the embedding and score for each.
"""
if self._index_details.is_databricks_managed_embeddings():
query_text = query
query_vector = None
else:
# The value for `query_text` needs to be specified only for hybrid search.
if query_type is not None and query_type.upper() == "HYBRID":
query_text = query
else:
query_text = None
query_vector = self._embeddings.embed_query(query) # type: ignore[union-attr]
search_resp = self.index.similarity_search(
columns=self._columns,
query_text=query_text,
query_vector=query_vector,
filters=filter,
num_results=k,
query_type=query_type,
)
return self._parse_search_response(search_resp)
def _select_relevance_score_fn(self) -> Callable[[float], float]:
"""
Databricks Vector Search uses a normalized score 1/(1+d), where d
is the L2 distance. Hence, we simply return the identity function.
"""
return lambda score: score
async def asimilarity_search_with_score(
self, *args: Any, **kwargs: Any
) -> List[Tuple[Document, float]]:
# This is a temporary workaround to make the similarity search
# asynchronous. The proper solution is to make the similarity search
# asynchronous in the vector store implementations.
func = partial(self.similarity_search_with_score, *args, **kwargs)
return await asyncio.get_event_loop().run_in_executor(None, func)
def similarity_search_by_vector(
self,
embedding: List[float],
k: int = 4,
filter: Optional[Any] = None,
*,
query_type: Optional[str] = None,
query: Optional[str] = None,
**kwargs: Any,
) -> List[Document]:
"""Return docs most similar to embedding vector.
Args:
embedding: Embedding to look up documents similar to.
k: Number of Documents to return. Defaults to 4.
filter: Filters to apply to the query. Defaults to None.
query_type: The type of this query. Supported values are "ANN" and "HYBRID".
Returns:
List of Documents most similar to the embedding.
"""
if self._index_details.is_databricks_managed_embeddings():
raise NotImplementedError(
_NON_MANAGED_EMB_ONLY_MSG % "similarity_search_by_vector"
)
docs_with_score = self.similarity_search_by_vector_with_score(
embedding=embedding,
k=k,
filter=filter,
query_type=query_type,
query=query,
**kwargs,
)
return [doc for doc, _ in docs_with_score]
async def asimilarity_search_by_vector(
self, embedding: List[float], k: int = 4, **kwargs: Any
) -> List[Document]:
# This is a temporary workaround to make the similarity search
# asynchronous. The proper solution is to make the similarity search
# asynchronous in the vector store implementations.
func = partial(self.similarity_search_by_vector, embedding, k=k, **kwargs)
return await asyncio.get_event_loop().run_in_executor(None, func)
def similarity_search_by_vector_with_score(
self,
embedding: List[float],
k: int = 4,
filter: Optional[Any] = None,
*,
query_type: Optional[str] = None,
query: Optional[str] = None,
**kwargs: Any,
) -> List[Tuple[Document, float]]:
"""Return docs most similar to embedding vector, along with scores.
.. note::
This method is not supported for index with Databricks-managed embeddings.
Args:
embedding: Embedding to look up documents similar to.
k: Number of Documents to return. Defaults to 4.
filter: Filters to apply to the query. Defaults to None.
query_type: The type of this query. Supported values are "ANN" and "HYBRID".
Returns:
List of Documents most similar to the embedding and score for each.
"""
if self._index_details.is_databricks_managed_embeddings():
raise NotImplementedError(
_NON_MANAGED_EMB_ONLY_MSG % "similarity_search_by_vector_with_score"
)
if query_type is not None and query_type.upper() == "HYBRID":
if query is None:
raise ValueError(
"A value for `query` must be specified for hybrid search."
)
query_text = query
else:
if query is not None:
raise ValueError(
(
"Cannot specify both `embedding` and "
'`query` unless `query_type="HYBRID"'
)
)
query_text = None
search_resp = self.index.similarity_search(
columns=self._columns,
query_vector=embedding,
query_text=query_text,
filters=filter,
num_results=k,
query_type=query_type,
)
return self._parse_search_response(search_resp)
def max_marginal_relevance_search(
self,
query: str,
k: int = 4,
fetch_k: int = 20,
lambda_mult: float = 0.5,
filter: Optional[Dict[str, Any]] = None,
*,
query_type: Optional[str] = None,
**kwargs: Any,
) -> List[Document]:
"""Return docs selected using the maximal marginal relevance.
Maximal marginal relevance optimizes for similarity to query AND diversity
among selected documents.
.. note::
This method is not supported for index with Databricks-managed embeddings.
Args:
query: Text to look up documents similar to.
k: Number of Documents to return. Defaults to 4.
fetch_k: Number of Documents to fetch to pass to MMR algorithm.
lambda_mult: Number between 0 and 1 that determines the degree
of diversity among the results with 0 corresponding
to maximum diversity and 1 to minimum diversity.
Defaults to 0.5.
filter: Filters to apply to the query. Defaults to None.
query_type: The type of this query. Supported values are "ANN" and "HYBRID".
Returns:
List of Documents selected by maximal marginal relevance.
"""
if self._index_details.is_databricks_managed_embeddings():
raise NotImplementedError(
_NON_MANAGED_EMB_ONLY_MSG % "max_marginal_relevance_search"
)
query_vector = self._embeddings.embed_query(query) # type: ignore[union-attr]
docs = self.max_marginal_relevance_search_by_vector(
query_vector,
k,
fetch_k,
lambda_mult=lambda_mult,
filter=filter,
query_type=query_type,
)
return docs
async def amax_marginal_relevance_search(
self,
query: str,
k: int = 4,
fetch_k: int = 20,
lambda_mult: float = 0.5,
**kwargs: Any,
) -> List[Document]:
# This is a temporary workaround to make the similarity search
# asynchronous. The proper solution is to make the similarity search
# asynchronous in the vector store implementations.
func = partial(
self.max_marginal_relevance_search,
query,
k=k,
fetch_k=fetch_k,
lambda_mult=lambda_mult,
**kwargs,
)
return await asyncio.get_event_loop().run_in_executor(None, func)
def max_marginal_relevance_search_by_vector(
self,
embedding: List[float],
k: int = 4,
fetch_k: int = 20,
lambda_mult: float = 0.5,
filter: Optional[Any] = None,
*,
query_type: Optional[str] = None,
**kwargs: Any,
) -> List[Document]:
"""Return docs selected using the maximal marginal relevance.
Maximal marginal relevance optimizes for similarity to query AND diversity
among selected documents.
.. note::
This method is not supported for index with Databricks-managed embeddings.
Args:
embedding: Embedding to look up documents similar to.
k: Number of Documents to return. Defaults to 4.
fetch_k: Number of Documents to fetch to pass to MMR algorithm.
lambda_mult: Number between 0 and 1 that determines the degree
of diversity among the results with 0 corresponding
to maximum diversity and 1 to minimum diversity.
Defaults to 0.5.
filter: Filters to apply to the query. Defaults to None.
query_type: The type of this query. Supported values are "ANN" and "HYBRID".
Returns:
List of Documents selected by maximal marginal relevance.
"""
if self._index_details.is_databricks_managed_embeddings():
raise NotImplementedError(
_NON_MANAGED_EMB_ONLY_MSG % "max_marginal_relevance_search_by_vector"
)
embedding_column = self._index_details.embedding_vector_column["name"]
search_resp = self.index.similarity_search(
columns=list(set(self._columns + [embedding_column])),
query_text=None,
query_vector=embedding,
filters=filter,
num_results=fetch_k,
query_type=query_type,
)
embeddings_result_index = (
search_resp.get("manifest").get("columns").index({"name": embedding_column})
)
embeddings = [
doc[embeddings_result_index]
for doc in search_resp.get("result").get("data_array")
]
mmr_selected = maximal_marginal_relevance(
np.array(embedding, dtype=np.float32),
embeddings,
k=k,
lambda_mult=lambda_mult,
)
ignore_cols: List = (
[embedding_column] if embedding_column not in self._columns else []
)
candidates = self._parse_search_response(search_resp, ignore_cols=ignore_cols)
selected_results = [r[0] for i, r in enumerate(candidates) if i in mmr_selected]
return selected_results
async def amax_marginal_relevance_search_by_vector(
self,
embedding: List[float],
k: int = 4,
fetch_k: int = 20,
lambda_mult: float = 0.5,
**kwargs: Any,
) -> List[Document]:
raise NotImplementedError
def _parse_search_response(
self, search_resp: Dict, ignore_cols: Optional[List[str]] = None
) -> List[Tuple[Document, float]]:
"""Parse the search response into a list of Documents with score."""
if ignore_cols is None:
ignore_cols = []
columns = [
col["name"]
for col in search_resp.get("manifest", dict()).get("columns", [])
]
# The primary key and text column are surfaced separately, so exclude them
# from metadata; compute this once instead of growing the list per row.
ignore_cols = [self._primary_key, self._text_column] + ignore_cols
docs_with_score = []
for result in search_resp.get("result", dict()).get("data_array", []):
doc_id = result[columns.index(self._primary_key)]
text_content = result[columns.index(self._text_column)]
metadata = {
col: value
for col, value in zip(columns[:-1], result[:-1])
if col not in ignore_cols
}
metadata[self._primary_key] = doc_id
score = result[-1]
doc = Document(page_content=text_content, metadata=metadata)
docs_with_score.append((doc, score))
return docs_with_score
def _validate_and_get_text_column(
text_column: Optional[str], index_details: IndexDetails
) -> str:
if index_details.is_databricks_managed_embeddings():
index_source_column: str = index_details.embedding_source_column["name"]
# check if input text column matches the source column of the index
if text_column is not None:
raise ValueError(
f"The index '{index_details.name}' has the source column configured as "
f"'{index_source_column}'. Do not pass the `text_column` parameter."
)
return index_source_column
else:
if text_column is None:
raise ValueError("The `text_column` parameter is required for this index.")
return text_column
def _validate_and_get_return_columns(
columns: List[str], text_column: str, index_details: IndexDetails
) -> List[str]:
"""
Get a list of columns to retrieve from the index.
If the index is direct-access index, validate the given columns against the schema.
"""
# add primary key column and source column if not in columns
if index_details.primary_key not in columns:
columns.append(index_details.primary_key)
if text_column and text_column not in columns:
columns.append(text_column)
# Validate specified columns are in the index
if index_details.is_direct_access_index() and (
index_schema := index_details.schema
):
if missing_columns := [c for c in columns if c not in index_schema]:
raise ValueError(
"Some columns specified in `columns` are not "
f"in the index schema: {missing_columns}"
)
return columns
def _validate_embedding(
embedding: Optional[Embeddings], index_details: IndexDetails
) -> None:
if index_details.is_databricks_managed_embeddings():
if embedding is not None:
raise ValueError(
f"The index '{index_details.name}' uses Databricks-managed embeddings. "
"Do not pass the `embedding` parameter when initializing vector store."
)
else:
if not embedding:
raise ValueError(
"The `embedding` parameter is required for a direct-access index "
"or delta-sync index with self-managed embedding."
)
_validate_embedding_dimension(embedding, index_details)
def _validate_embedding_dimension(
embeddings: Embeddings, index_details: IndexDetails
) -> None:
"""validate if the embedding dimension matches with the index's configuration."""
if index_embedding_dimension := index_details.embedding_vector_column.get(
"embedding_dimension"
):
# Infer the embedding dimension from the embedding function.
actual_dimension = len(embeddings.embed_query("test"))
if actual_dimension != index_embedding_dimension:
raise ValueError(
f"The specified embedding model's dimension '{actual_dimension}' does "
f"not match with the index configuration '{index_embedding_dimension}'."
)
class IndexDetails:
"""An utility class to store the configuration details of an index."""
def __init__(self, index: Any):
self._index_details = index.describe()
@property
def name(self) -> str:
return self._index_details["name"]
@property
def schema(self) -> Optional[Dict]:
if self.is_direct_access_index():
schema_json = self.index_spec.get("schema_json")
if schema_json is not None:
return json.loads(schema_json)
return None
@property
def primary_key(self) -> str:
return self._index_details["primary_key"]
@property
def index_spec(self) -> Dict:
return (
self._index_details.get("delta_sync_index_spec", {})
if self.is_delta_sync_index()
else self._index_details.get("direct_access_index_spec", {})
)
@property
def embedding_vector_column(self) -> Dict:
if vector_columns := self.index_spec.get("embedding_vector_columns"):
return vector_columns[0]
return {}
@property
def embedding_source_column(self) -> Dict:
if source_columns := self.index_spec.get("embedding_source_columns"):
return source_columns[0]
return {}
def is_delta_sync_index(self) -> bool:
return self._index_details["index_type"] == IndexType.DELTA_SYNC.value
def is_direct_access_index(self) -> bool:
return self._index_details["index_type"] == IndexType.DIRECT_ACCESS.value
def is_databricks_managed_embeddings(self) -> bool:
return (
self.is_delta_sync_index()
and self.embedding_source_column.get("name") is not None
)
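For reference, `_parse_search_response` above consumes the raw payload returned by the index client, in which the last manifest column is the relevance score. A standalone sketch of that mapping, with made-up column names and values:

```python
# Illustrative payload in the shape the index client returns.
resp = {
    "manifest": {"columns": [{"name": "id"}, {"name": "text"}, {"name": "score"}]},
    "result": {"data_array": [["1", "foo", 0.87], ["2", "thud", 0.54]]},
}
columns = [col["name"] for col in resp["manifest"]["columns"]]
for row in resp["result"]["data_array"]:
    text, score = row[columns.index("text")], row[-1]
    metadata = {"id": row[columns.index("id")]}
    print(text, score, metadata)  # e.g. foo 0.87 {'id': '1'}
```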


@@ -339,6 +339,22 @@ requests = ">=2.28.1,<3"
dev = ["autoflake", "databricks-connect", "ipython", "ipywidgets", "isort", "pycodestyle", "pyfakefs", "pytest", "pytest-cov", "pytest-mock", "pytest-rerunfailures", "pytest-xdist", "requests-mock", "wheel", "yapf"]
notebook = ["ipython (>=8,<9)", "ipywidgets (>=8,<9)"]
[[package]]
name = "databricks-vectorsearch"
version = "0.40"
description = "Databricks Vector Search Client"
optional = false
python-versions = ">=3.7"
files = [
{file = "databricks_vectorsearch-0.40-py3-none-any.whl", hash = "sha256:c684291e1b0472ece8f6df8c6ff7982f49ce7075e1df5b93459e148dea1d70d7"},
]
[package.dependencies]
deprecation = ">=2"
mlflow-skinny = ">=2.11.3,<3"
protobuf = ">=3.12.0,<5"
requests = ">=2"
[[package]]
name = "deprecated"
version = "1.2.14"
@@ -356,6 +372,20 @@ wrapt = ">=1.10,<2"
[package.extras]
dev = ["PyTest", "PyTest-Cov", "bump2version (<1)", "sphinx (<2)", "tox"]
[[package]]
name = "deprecation"
version = "2.1.0"
description = "A library to handle automated deprecations"
optional = false
python-versions = "*"
files = [
{file = "deprecation-2.1.0-py2.py3-none-any.whl", hash = "sha256:a10811591210e1fb0e768a8c25517cabeabcba6f0bf96564f8ff45189f90b14a"},
{file = "deprecation-2.1.0.tar.gz", hash = "sha256:72b3bde64e5d778694b0cf68178aed03d15e15477116add3fb773e581f9518ff"},
]
[package.dependencies]
packaging = "*"
[[package]]
name = "docker"
version = "7.1.0"
@@ -1469,8 +1499,8 @@ files = [
[package.dependencies]
numpy = [
{version = ">=1.20.3", markers = "python_version < \"3.10\""},
-{version = ">=1.21.0", markers = "python_version >= \"3.10\" and python_version < \"3.11\""},
{version = ">=1.23.2", markers = "python_version >= \"3.11\""},
+{version = ">=1.21.0", markers = "python_version >= \"3.10\" and python_version < \"3.11\""},
]
python-dateutil = ">=2.8.2"
pytz = ">=2020.1"
@@ -1613,22 +1643,22 @@ testing = ["pytest", "pytest-benchmark"]
[[package]]
name = "protobuf"
version = "5.27.3"
version = "4.25.4"
description = ""
optional = false
python-versions = ">=3.8"
files = [
{file = "protobuf-5.27.3-cp310-abi3-win32.whl", hash = "sha256:dcb307cd4ef8fec0cf52cb9105a03d06fbb5275ce6d84a6ae33bc6cf84e0a07b"},
{file = "protobuf-5.27.3-cp310-abi3-win_amd64.whl", hash = "sha256:16ddf3f8c6c41e1e803da7abea17b1793a97ef079a912e42351eabb19b2cffe7"},
{file = "protobuf-5.27.3-cp38-abi3-macosx_10_9_universal2.whl", hash = "sha256:68248c60d53f6168f565a8c76dc58ba4fa2ade31c2d1ebdae6d80f969cdc2d4f"},
{file = "protobuf-5.27.3-cp38-abi3-manylinux2014_aarch64.whl", hash = "sha256:b8a994fb3d1c11156e7d1e427186662b64694a62b55936b2b9348f0a7c6625ce"},
{file = "protobuf-5.27.3-cp38-abi3-manylinux2014_x86_64.whl", hash = "sha256:a55c48f2a2092d8e213bd143474df33a6ae751b781dd1d1f4d953c128a415b25"},
{file = "protobuf-5.27.3-cp38-cp38-win32.whl", hash = "sha256:043853dcb55cc262bf2e116215ad43fa0859caab79bb0b2d31b708f128ece035"},
{file = "protobuf-5.27.3-cp38-cp38-win_amd64.whl", hash = "sha256:c2a105c24f08b1e53d6c7ffe69cb09d0031512f0b72f812dd4005b8112dbe91e"},
{file = "protobuf-5.27.3-cp39-cp39-win32.whl", hash = "sha256:c84eee2c71ed83704f1afbf1a85c3171eab0fd1ade3b399b3fad0884cbcca8bf"},
{file = "protobuf-5.27.3-cp39-cp39-win_amd64.whl", hash = "sha256:af7c0b7cfbbb649ad26132e53faa348580f844d9ca46fd3ec7ca48a1ea5db8a1"},
{file = "protobuf-5.27.3-py3-none-any.whl", hash = "sha256:8572c6533e544ebf6899c360e91d6bcbbee2549251643d32c52cf8a5de295ba5"},
{file = "protobuf-5.27.3.tar.gz", hash = "sha256:82460903e640f2b7e34ee81a947fdaad89de796d324bcbc38ff5430bcdead82c"},
{file = "protobuf-4.25.4-cp310-abi3-win32.whl", hash = "sha256:db9fd45183e1a67722cafa5c1da3e85c6492a5383f127c86c4c4aa4845867dc4"},
{file = "protobuf-4.25.4-cp310-abi3-win_amd64.whl", hash = "sha256:ba3d8504116a921af46499471c63a85260c1a5fc23333154a427a310e015d26d"},
{file = "protobuf-4.25.4-cp37-abi3-macosx_10_9_universal2.whl", hash = "sha256:eecd41bfc0e4b1bd3fa7909ed93dd14dd5567b98c941d6c1ad08fdcab3d6884b"},
{file = "protobuf-4.25.4-cp37-abi3-manylinux2014_aarch64.whl", hash = "sha256:4c8a70fdcb995dcf6c8966cfa3a29101916f7225e9afe3ced4395359955d3835"},
{file = "protobuf-4.25.4-cp37-abi3-manylinux2014_x86_64.whl", hash = "sha256:3319e073562e2515c6ddc643eb92ce20809f5d8f10fead3332f71c63be6a7040"},
{file = "protobuf-4.25.4-cp38-cp38-win32.whl", hash = "sha256:7e372cbbda66a63ebca18f8ffaa6948455dfecc4e9c1029312f6c2edcd86c4e1"},
{file = "protobuf-4.25.4-cp38-cp38-win_amd64.whl", hash = "sha256:051e97ce9fa6067a4546e75cb14f90cf0232dcb3e3d508c448b8d0e4265b61c1"},
{file = "protobuf-4.25.4-cp39-cp39-win32.whl", hash = "sha256:90bf6fd378494eb698805bbbe7afe6c5d12c8e17fca817a646cd6a1818c696ca"},
{file = "protobuf-4.25.4-cp39-cp39-win_amd64.whl", hash = "sha256:ac79a48d6b99dfed2729ccccee547b34a1d3d63289c71cef056653a846a2240f"},
{file = "protobuf-4.25.4-py3-none-any.whl", hash = "sha256:bfbebc1c8e4793cfd58589acfb8a1026be0003e852b9da7db5a4285bde996978"},
{file = "protobuf-4.25.4.tar.gz", hash = "sha256:0dc4a62cc4052a036ee2204d26fe4d835c62827c855c8a03f29fe6da146b380d"},
]
[[package]]
@@ -2492,4 +2522,4 @@ test = ["big-O", "importlib-resources", "jaraco.functools", "jaraco.itertools",
[metadata]
lock-version = "2.0"
python-versions = ">=3.8.1,<3.12"
content-hash = "6da52f0e39bc7da1a80cc181bdd481e57f4644daf2f3c6da6a6b0ead2e813be9"
content-hash = "857f47603d9dd6fe8882c7525613a54a54ee459a9ee012f3d19e510c5477f3db"


@@ -26,6 +26,7 @@ scipy = [
{version = ">=1.11", python = ">=3.12"},
{version = "<2", python = "<3.12"}
]
databricks-vectorsearch = "^0.40"
[tool.poetry.group.test]
optional = true


@@ -0,0 +1,69 @@
"""Test Together AI embeddings."""
from typing import Any, Dict, Generator
from unittest import mock
import pytest
from mlflow.deployments import BaseDeploymentClient # type: ignore[import-untyped]
from langchain_databricks import DatabricksEmbeddings
def _mock_embeddings(endpoint: str, inputs: Dict[str, Any]) -> Dict[str, Any]:
return {
"object": "list",
"data": [
{
"object": "embedding",
"embedding": list(range(1536)),
"index": 0,
}
for _ in inputs["input"]
],
"model": "text-embedding-3-small",
"usage": {"prompt_tokens": 8, "total_tokens": 8},
}
@pytest.fixture
def mock_client() -> Generator:
client = mock.MagicMock()
client.predict.side_effect = _mock_embeddings
with mock.patch("mlflow.deployments.get_deploy_client", return_value=client):
yield client
@pytest.fixture
def embeddings() -> DatabricksEmbeddings:
return DatabricksEmbeddings(
endpoint="text-embedding-3-small",
documents_params={"fruit": "apple"},
query_params={"fruit": "banana"},
)
def test_embed_documents(
mock_client: BaseDeploymentClient, embeddings: DatabricksEmbeddings
) -> None:
documents = ["foo"] * 30
output = embeddings.embed_documents(documents)
assert len(output) == 30
assert len(output[0]) == 1536
assert mock_client.predict.call_count == 2
assert all(
call_arg[1]["inputs"]["fruit"] == "apple"
for call_arg in mock_client.predict.call_args_list
)
def test_embed_query(
mock_client: BaseDeploymentClient, embeddings: DatabricksEmbeddings
) -> None:
query = "foo bar"
output = embeddings.embed_query(query)
assert len(output) == 1536
mock_client.predict.assert_called_once()
assert mock_client.predict.call_args[1] == {
"endpoint": "text-embedding-3-small",
"inputs": {"input": [query], "fruit": "banana"},
}
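These tests patch `mlflow.deployments.get_deploy_client` rather than anything inside `langchain_databricks` because `get_deployment_client` performs the mlflow import lazily at call time, so the patched attribute is what gets resolved. A minimal sketch of the same pattern, assuming mlflow is installed and the endpoint name is a placeholder:

```python
from unittest import mock

from langchain_databricks import DatabricksEmbeddings

with mock.patch("mlflow.deployments.get_deploy_client") as mock_get:
    # __init__ resolves the client through the patched factory.
    emb = DatabricksEmbeddings(endpoint="some-endpoint")
    assert emb._client is mock_get.return_value
```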


@@ -2,6 +2,8 @@ from langchain_databricks import __all__
EXPECTED_ALL = [
"ChatDatabricks",
"DatabricksEmbeddings",
"DatabricksVectorSearch",
"__version__",
]


@@ -0,0 +1,629 @@
import uuid
from typing import Any, Dict, Generator, List, Optional, Set
from unittest import mock
from unittest.mock import MagicMock, patch
import pytest
from langchain_core.embeddings import Embeddings
from langchain_databricks.vectorstores import DatabricksVectorSearch
INPUT_TEXTS = ["foo", "bar", "baz"]
DEFAULT_VECTOR_DIMENSION = 4
class FakeEmbeddings(Embeddings):
"""Fake embeddings functionality for testing."""
def __init__(self, dimension: int = DEFAULT_VECTOR_DIMENSION):
super().__init__()
self.dimension = dimension
def embed_documents(self, embedding_texts: List[str]) -> List[List[float]]:
"""Return simple embeddings."""
return [
[float(1.0)] * (self.dimension - 1) + [float(i)]
for i in range(len(embedding_texts))
]
def embed_query(self, text: str) -> List[float]:
"""Return simple embeddings."""
return [float(1.0)] * (self.dimension - 1) + [float(0.0)]
EMBEDDING_MODEL = FakeEmbeddings()
### Dummy similarity_search() Response ###
EXAMPLE_SEARCH_RESPONSE = {
"manifest": {
"column_count": 3,
"columns": [
{"name": "id"},
{"name": "text"},
{"name": "text_vector"},
{"name": "score"},
],
},
"result": {
"row_count": len(INPUT_TEXTS),
"data_array": sorted(
[
[str(uuid.uuid4()), s, e, 0.5]
for s, e in zip(
INPUT_TEXTS, EMBEDDING_MODEL.embed_documents(INPUT_TEXTS)
)
],
key=lambda x: x[2], # type: ignore
reverse=True,
),
},
"next_page_token": "",
}
### Dummy Indices ####
ENDPOINT_NAME = "test-endpoint"
DIRECT_ACCESS_INDEX = "test-direct-access-index"
DELTA_SYNC_INDEX = "test-delta-sync-index"
DELTA_SYNC_SELF_MANAGED_EMBEDDINGS_INDEX = "test-delta-sync-self-managed-index"
ALL_INDEX_NAMES = {
DIRECT_ACCESS_INDEX,
DELTA_SYNC_INDEX,
DELTA_SYNC_SELF_MANAGED_EMBEDDINGS_INDEX,
}
INDEX_DETAILS = {
DELTA_SYNC_INDEX: {
"name": DELTA_SYNC_INDEX,
"endpoint_name": ENDPOINT_NAME,
"index_type": "DELTA_SYNC",
"primary_key": "id",
"delta_sync_index_spec": {
"source_table": "ml.llm.source_table",
"pipeline_type": "CONTINUOUS",
"embedding_source_columns": [
{
"name": "text",
"embedding_model_endpoint_name": "openai-text-embedding",
}
],
},
},
DELTA_SYNC_SELF_MANAGED_EMBEDDINGS_INDEX: {
"name": DELTA_SYNC_SELF_MANAGED_EMBEDDINGS_INDEX,
"endpoint_name": ENDPOINT_NAME,
"index_type": "DELTA_SYNC",
"primary_key": "id",
"delta_sync_index_spec": {
"source_table": "ml.llm.source_table",
"pipeline_type": "CONTINUOUS",
"embedding_vector_columns": [
{
"name": "text_vector",
"embedding_dimension": DEFAULT_VECTOR_DIMENSION,
}
],
},
},
DIRECT_ACCESS_INDEX: {
"name": DIRECT_ACCESS_INDEX,
"endpoint_name": ENDPOINT_NAME,
"index_type": "DIRECT_ACCESS",
"primary_key": "id",
"direct_access_index_spec": {
"embedding_vector_columns": [
{
"name": "text_vector",
"embedding_dimension": DEFAULT_VECTOR_DIMENSION,
}
],
"schema_json": f"{{"
f'"{"id"}": "int", '
f'"feat1": "str", '
f'"feat2": "float", '
f'"text": "string", '
f'"{"text_vector"}": "array<float>"'
f"}}",
},
},
}
@pytest.fixture(autouse=True)
def mock_vs_client() -> Generator:
def _get_index(endpoint: str, index_name: str) -> MagicMock:
from databricks.vector_search.client import VectorSearchIndex # type: ignore
if endpoint != ENDPOINT_NAME:
raise ValueError(f"Unknown endpoint: {endpoint}")
index = MagicMock(spec=VectorSearchIndex)
index.describe.return_value = INDEX_DETAILS[index_name]
index.similarity_search.return_value = EXAMPLE_SEARCH_RESPONSE
return index
mock_client = MagicMock()
mock_client.get_index.side_effect = _get_index
with mock.patch(
"databricks.vector_search.client.VectorSearchClient",
return_value=mock_client,
):
yield
def init_vector_search(
index_name: str, columns: Optional[List[str]] = None
) -> DatabricksVectorSearch:
kwargs: Dict[str, Any] = {
"endpoint": ENDPOINT_NAME,
"index_name": index_name,
"columns": columns,
}
if index_name != DELTA_SYNC_INDEX:
kwargs.update(
{
"embedding": EMBEDDING_MODEL,
"text_column": "text",
}
)
return DatabricksVectorSearch(**kwargs) # type: ignore[arg-type]
@pytest.mark.parametrize("index_name", ALL_INDEX_NAMES)
def test_init(index_name: str) -> None:
vectorsearch = init_vector_search(index_name)
assert vectorsearch.index.describe() == INDEX_DETAILS[index_name]
def test_init_fail_text_column_mismatch() -> None:
with pytest.raises(ValueError, match=f"The index '{DELTA_SYNC_INDEX}' has"):
DatabricksVectorSearch(
endpoint=ENDPOINT_NAME,
index_name=DELTA_SYNC_INDEX,
text_column="some_other_column",
)
@pytest.mark.parametrize("index_name", ALL_INDEX_NAMES - {DELTA_SYNC_INDEX})
def test_init_fail_no_text_column(index_name: str) -> None:
with pytest.raises(ValueError, match="The `text_column` parameter is required"):
DatabricksVectorSearch(
endpoint=ENDPOINT_NAME,
index_name=index_name,
embedding=EMBEDDING_MODEL,
)
def test_init_fail_columns_not_in_schema() -> None:
columns = ["some_random_column"]
with pytest.raises(ValueError, match="Some columns specified in `columns`"):
init_vector_search(DIRECT_ACCESS_INDEX, columns=columns)
@pytest.mark.parametrize("index_name", ALL_INDEX_NAMES - {DELTA_SYNC_INDEX})
def test_init_fail_no_embedding(index_name: str) -> None:
with pytest.raises(ValueError, match="The `embedding` parameter is required"):
DatabricksVectorSearch(
endpoint=ENDPOINT_NAME,
index_name=index_name,
text_column="text",
)
def test_init_fail_embedding_already_specified_in_source() -> None:
with pytest.raises(ValueError, match=f"The index '{DELTA_SYNC_INDEX}' uses"):
DatabricksVectorSearch(
endpoint=ENDPOINT_NAME,
index_name=DELTA_SYNC_INDEX,
embedding=EMBEDDING_MODEL,
)
@pytest.mark.parametrize("index_name", ALL_INDEX_NAMES - {DELTA_SYNC_INDEX})
def test_init_fail_embedding_dim_mismatch(index_name: str) -> None:
with pytest.raises(
ValueError, match="embedding model's dimension '1000' does not match"
):
DatabricksVectorSearch(
endpoint=ENDPOINT_NAME,
index_name=index_name,
text_column="text",
embedding=FakeEmbeddings(1000),
)
def test_from_texts_not_supported() -> None:
with pytest.raises(NotImplementedError, match="`from_texts` is not supported"):
DatabricksVectorSearch.from_texts(INPUT_TEXTS, EMBEDDING_MODEL)
@pytest.mark.parametrize("index_name", ALL_INDEX_NAMES - {DIRECT_ACCESS_INDEX})
def test_add_texts_not_supported_for_delta_sync_index(index_name: str) -> None:
vectorsearch = init_vector_search(index_name)
with pytest.raises(
NotImplementedError,
match="`add_texts` is only supported for direct-access index.",
):
vectorsearch.add_texts(INPUT_TEXTS)
def is_valid_uuid(val: str) -> bool:
try:
uuid.UUID(str(val))
return True
except ValueError:
return False
def test_add_texts() -> None:
vectorsearch = init_vector_search(DIRECT_ACCESS_INDEX)
ids = [idx for idx, i in enumerate(INPUT_TEXTS)]
vectors = EMBEDDING_MODEL.embed_documents(INPUT_TEXTS)
added_ids = vectorsearch.add_texts(INPUT_TEXTS, ids=ids)
vectorsearch.index.upsert.assert_called_once_with(
[
{
"id": id_,
"text": text,
"text_vector": vector,
}
for text, vector, id_ in zip(INPUT_TEXTS, vectors, ids)
]
)
assert len(added_ids) == len(INPUT_TEXTS)
assert added_ids == ids
def test_add_texts_handle_single_text() -> None:
vectorsearch = init_vector_search(DIRECT_ACCESS_INDEX)
vectors = EMBEDDING_MODEL.embed_documents(INPUT_TEXTS)
added_ids = vectorsearch.add_texts(INPUT_TEXTS[0])
vectorsearch.index.upsert.assert_called_once_with(
[
{
"id": id_,
"text": text,
"text_vector": vector,
}
for text, vector, id_ in zip(INPUT_TEXTS, vectors, added_ids)
]
)
assert len(added_ids) == 1
assert is_valid_uuid(added_ids[0])
def test_add_texts_with_default_id() -> None:
vectorsearch = init_vector_search(DIRECT_ACCESS_INDEX)
vectors = EMBEDDING_MODEL.embed_documents(INPUT_TEXTS)
added_ids = vectorsearch.add_texts(INPUT_TEXTS)
vectorsearch.index.upsert.assert_called_once_with(
[
{
"id": id_,
"text": text,
"text_vector": vector,
}
for text, vector, id_ in zip(INPUT_TEXTS, vectors, added_ids)
]
)
assert len(added_ids) == len(INPUT_TEXTS)
assert all([is_valid_uuid(id_) for id_ in added_ids])
def test_add_texts_with_metadata() -> None:
vectorsearch = init_vector_search(DIRECT_ACCESS_INDEX)
vectors = EMBEDDING_MODEL.embed_documents(INPUT_TEXTS)
metadatas = [{"feat1": str(i), "feat2": i + 1000} for i in range(len(INPUT_TEXTS))]
added_ids = vectorsearch.add_texts(INPUT_TEXTS, metadatas=metadatas)
vectorsearch.index.upsert.assert_called_once_with(
[
{
"id": id_,
"text": text,
"text_vector": vector,
**metadata, # type: ignore[arg-type]
}
for text, vector, id_, metadata in zip(
INPUT_TEXTS, vectors, added_ids, metadatas
)
]
)
assert len(added_ids) == len(INPUT_TEXTS)
assert all([is_valid_uuid(id_) for id_ in added_ids])
@pytest.mark.parametrize("index_name", ALL_INDEX_NAMES - {DELTA_SYNC_INDEX})
def test_embeddings_property(index_name: str) -> None:
vectorsearch = init_vector_search(index_name)
assert vectorsearch.embeddings == EMBEDDING_MODEL
def test_delete() -> None:
vectorsearch = init_vector_search(DIRECT_ACCESS_INDEX)
vectorsearch.delete(["some id"])
vectorsearch.index.delete.assert_called_once_with(["some id"])
def test_delete_fail_no_ids() -> None:
vectorsearch = init_vector_search(DIRECT_ACCESS_INDEX)
with pytest.raises(ValueError, match="ids must be provided."):
vectorsearch.delete()
@pytest.mark.parametrize("index_name", ALL_INDEX_NAMES - {DIRECT_ACCESS_INDEX})
def test_delete_not_supported_for_delta_sync_index(index_name: str) -> None:
vectorsearch = init_vector_search(index_name)
with pytest.raises(
NotImplementedError, match="`delete` is only supported for direct-access"
):
vectorsearch.delete(["some id"])
@pytest.mark.parametrize("index_name", ALL_INDEX_NAMES)
@pytest.mark.parametrize("query_type", [None, "ANN"])
def test_similarity_search(index_name: str, query_type: Optional[str]) -> None:
vectorsearch = init_vector_search(index_name)
query = "foo"
filters = {"some filter": True}
limit = 7
search_result = vectorsearch.similarity_search(
query, k=limit, filter=filters, query_type=query_type
)
if index_name == DELTA_SYNC_INDEX:
vectorsearch.index.similarity_search.assert_called_once_with(
columns=["id", "text"],
query_text=query,
query_vector=None,
filters=filters,
num_results=limit,
query_type=query_type,
)
else:
vectorsearch.index.similarity_search.assert_called_once_with(
columns=["id", "text"],
query_text=None,
query_vector=EMBEDDING_MODEL.embed_query(query),
filters=filters,
num_results=limit,
query_type=query_type,
)
assert len(search_result) == len(INPUT_TEXTS)
assert sorted([d.page_content for d in search_result]) == sorted(INPUT_TEXTS)
assert all(["id" in d.metadata for d in search_result])
@pytest.mark.parametrize("index_name", ALL_INDEX_NAMES)
def test_similarity_search_hybrid(index_name: str) -> None:
vectorsearch = init_vector_search(index_name)
query = "foo"
filters = {"some filter": True}
limit = 7
search_result = vectorsearch.similarity_search(
query, k=limit, filter=filters, query_type="HYBRID"
)
if index_name == DELTA_SYNC_INDEX:
vectorsearch.index.similarity_search.assert_called_once_with(
columns=["id", "text"],
query_text=query,
query_vector=None,
filters=filters,
num_results=limit,
query_type="HYBRID",
)
else:
vectorsearch.index.similarity_search.assert_called_once_with(
columns=["id", "text"],
query_text=query,
query_vector=EMBEDDING_MODEL.embed_query(query),
filters=filters,
num_results=limit,
query_type="HYBRID",
)
assert len(search_result) == len(INPUT_TEXTS)
assert sorted([d.page_content for d in search_result]) == sorted(INPUT_TEXTS)
assert all(["id" in d.metadata for d in search_result])
def test_similarity_search_both_filter_and_filters_passed() -> None:
vectorsearch = init_vector_search(DIRECT_ACCESS_INDEX)
query = "foo"
filter = {"some filter": True}
filters = {"some other filter": False}
vectorsearch.similarity_search(query, filter=filter, filters=filters)
vectorsearch.index.similarity_search.assert_called_once_with(
columns=["id", "text"],
query_vector=EMBEDDING_MODEL.embed_query(query),
# `filter` should prevail over `filters`
filters=filter,
num_results=4,
query_text=None,
query_type=None,
)
@pytest.mark.parametrize("index_name", ALL_INDEX_NAMES - {DELTA_SYNC_INDEX})
@pytest.mark.parametrize(
"columns, expected_columns",
[
(None, {"id"}),
(["id", "text", "text_vector"], {"text_vector", "id"}),
],
)
def test_mmr_search(
index_name: str, columns: Optional[List[str]], expected_columns: Set[str]
) -> None:
vectorsearch = init_vector_search(index_name, columns=columns)
query = INPUT_TEXTS[0]
filters = {"some filter": True}
limit = 1
search_result = vectorsearch.max_marginal_relevance_search(
query, k=limit, filters=filters
)
assert [doc.page_content for doc in search_result] == [INPUT_TEXTS[0]]
assert [set(doc.metadata.keys()) for doc in search_result] == [expected_columns]
@pytest.mark.parametrize("index_name", ALL_INDEX_NAMES - {DELTA_SYNC_INDEX})
def test_mmr_parameters(index_name: str) -> None:
vectorsearch = init_vector_search(index_name)
query = INPUT_TEXTS[0]
limit = 1
fetch_k = 3
lambda_mult = 0.25
filters = {"some filter": True}
with patch(
"langchain_databricks.vectorstores.maximal_marginal_relevance"
) as mock_mmr:
mock_mmr.return_value = [2]
retriever = vectorsearch.as_retriever(
search_type="mmr",
search_kwargs={
"k": limit,
"fetch_k": fetch_k,
"lambda_mult": lambda_mult,
"filter": filters,
},
)
search_result = retriever.invoke(query)
mock_mmr.assert_called_once()
assert mock_mmr.call_args[1]["lambda_mult"] == lambda_mult
assert vectorsearch.index.similarity_search.call_args[1]["num_results"] == fetch_k
assert vectorsearch.index.similarity_search.call_args[1]["filters"] == filters
assert len(search_result) == limit
@pytest.mark.parametrize("index_name", ALL_INDEX_NAMES)
@pytest.mark.parametrize("threshold", [0.4, 0.5, 0.8])
def test_similarity_score_threshold(index_name: str, threshold: float) -> None:
query = INPUT_TEXTS[0]
limit = len(INPUT_TEXTS)
vectorsearch = init_vector_search(index_name)
retriever = vectorsearch.as_retriever(
search_type="similarity_score_threshold",
search_kwargs={"k": limit, "score_threshold": threshold},
)
search_result = retriever.invoke(query)
if threshold <= 0.5:
assert len(search_result) == len(INPUT_TEXTS)
else:
assert len(search_result) == 0
def test_standard_params() -> None:
vectorstore = init_vector_search(DIRECT_ACCESS_INDEX)
retriever = vectorstore.as_retriever()
ls_params = retriever._get_ls_params()
assert ls_params == {
"ls_retriever_name": "vectorstore",
"ls_vector_store_provider": "DatabricksVectorSearch",
"ls_embedding_provider": "FakeEmbeddings",
}
vectorstore = init_vector_search(DELTA_SYNC_INDEX)
retriever = vectorstore.as_retriever()
ls_params = retriever._get_ls_params()
assert ls_params == {
"ls_retriever_name": "vectorstore",
"ls_vector_store_provider": "DatabricksVectorSearch",
}
@pytest.mark.parametrize("index_name", ALL_INDEX_NAMES - {DELTA_SYNC_INDEX})
@pytest.mark.parametrize("query_type", [None, "ANN"])
def test_similarity_search_by_vector(
index_name: str, query_type: Optional[str]
) -> None:
vectorsearch = init_vector_search(index_name)
query_embedding = EMBEDDING_MODEL.embed_query("foo")
filters = {"some filter": True}
limit = 7
search_result = vectorsearch.similarity_search_by_vector(
query_embedding, k=limit, filter=filters, query_type=query_type
)
vectorsearch.index.similarity_search.assert_called_once_with(
columns=["id", "text"],
query_vector=query_embedding,
filters=filters,
num_results=limit,
query_type=query_type,
query_text=None,
)
assert len(search_result) == len(INPUT_TEXTS)
assert sorted([d.page_content for d in search_result]) == sorted(INPUT_TEXTS)
assert all(["id" in d.metadata for d in search_result])
@pytest.mark.parametrize("index_name", ALL_INDEX_NAMES - {DELTA_SYNC_INDEX})
def test_similarity_search_by_vector_hybrid(index_name: str) -> None:
vectorsearch = init_vector_search(index_name)
query_embedding = EMBEDDING_MODEL.embed_query("foo")
filters = {"some filter": True}
limit = 7
search_result = vectorsearch.similarity_search_by_vector(
query_embedding, k=limit, filter=filters, query_type="HYBRID", query="foo"
)
vectorsearch.index.similarity_search.assert_called_once_with(
columns=["id", "text"],
query_vector=query_embedding,
filters=filters,
num_results=limit,
query_type="HYBRID",
query_text="foo",
)
assert len(search_result) == len(INPUT_TEXTS)
assert sorted([d.page_content for d in search_result]) == sorted(INPUT_TEXTS)
assert all(["id" in d.metadata for d in search_result])
@pytest.mark.parametrize("index_name", ALL_INDEX_NAMES)
def test_similarity_search_empty_result(index_name: str) -> None:
vectorsearch = init_vector_search(index_name)
vectorsearch.index.similarity_search.return_value = {
"manifest": {
"column_count": 3,
"columns": [
{"name": "id"},
{"name": "text"},
{"name": "score"},
],
},
"result": {
"row_count": 0,
"data_array": [],
},
"next_page_token": "",
}
search_result = vectorsearch.similarity_search("foo")
assert len(search_result) == 0
def test_similarity_search_by_vector_not_supported_for_managed_embedding() -> None:
vectorsearch = init_vector_search(DELTA_SYNC_INDEX)
query_embedding = EMBEDDING_MODEL.embed_query("foo")
filters = {"some filter": True}
limit = 7
with pytest.raises(
NotImplementedError, match="`similarity_search_by_vector` is not supported"
):
vectorsearch.similarity_search_by_vector(
query_embedding, k=limit, filters=filters
)