community[minor]: Add mmr and similarity_score_threshold retrieval to DatabricksVectorSearch (#16829)

- **Description:** This PR adds support for `search_type="mmr"` and
`search_type="similarity_score_threshold"` to retrievers using
`DatabricksVectorSearch`.
  - **Issue:** 
  - **Dependencies:**
  - **Twitter handle:**

---------

Co-authored-by: Bagatur <baskaryan@gmail.com>
This commit is contained in:
david-tempelmann 2024-02-12 21:51:37 +01:00 committed by GitHub
parent 42648061ad
commit 93da18b667
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 278 additions and 6 deletions

View File

@ -3,12 +3,15 @@ from __future__ import annotations
import json
import logging
import uuid
from typing import TYPE_CHECKING, Any, Iterable, List, Optional, Tuple, Type
from typing import TYPE_CHECKING, Any, Callable, Iterable, List, Optional, Tuple, Type
import numpy as np
from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings
from langchain_core.vectorstores import VST, VectorStore
from langchain_community.vectorstores.utils import maximal_marginal_relevance
if TYPE_CHECKING:
from databricks.vector_search.client import VectorSearchIndex
@ -321,6 +324,126 @@ class DatabricksVectorSearch(VectorStore):
)
return self._parse_search_response(search_resp)
@staticmethod
def _identity_fn(score: float) -> float:
return score
def _select_relevance_score_fn(self) -> Callable[[float], float]:
"""
Databricks Vector search uses a normalized score 1/(1+d) where d
is the L2 distance. Hence, we simply return the identity function.
"""
return self._identity_fn
def max_marginal_relevance_search(
    self,
    query: str,
    k: int = 4,
    fetch_k: int = 20,
    lambda_mult: float = 0.5,
    filters: Optional[Any] = None,
    **kwargs: Any,
) -> List[Document]:
    """Return docs selected using the maximal marginal relevance.

    Maximal marginal relevance optimizes for similarity to query AND diversity
    among selected documents.

    Args:
        query: Text to look up documents similar to.
        k: Number of Documents to return. Defaults to 4.
        fetch_k: Number of Documents to fetch to pass to MMR algorithm.
        lambda_mult: Number between 0 and 1 that determines the degree
            of diversity among the results with 0 corresponding
            to maximum diversity and 1 to minimum diversity.
            Defaults to 0.5.
        filters: Filters to apply to the query. Defaults to None.

    Returns:
        List of Documents selected by maximal marginal relevance.

    Raises:
        ValueError: If the index uses Databricks-managed embeddings; MMR
            needs the raw query vector, which is never exposed client-side
            in that configuration.
    """
    # Guard first: with Databricks-managed embeddings we cannot embed the
    # query locally, so MMR re-ranking is impossible.
    if self._is_databricks_managed_embeddings():
        raise ValueError(
            "`max_marginal_relevance_search` is not supported for index with "
            "Databricks-managed embeddings."
        )
    assert self.embeddings is not None, "embedding model is required."
    query_vector = self.embeddings.embed_query(query)

    # Fix: forward **kwargs instead of silently dropping them, so extra
    # options reach max_marginal_relevance_search_by_vector.
    return self.max_marginal_relevance_search_by_vector(
        query_vector,
        k,
        fetch_k,
        lambda_mult=lambda_mult,
        filters=filters,
        **kwargs,
    )
def max_marginal_relevance_search_by_vector(
    self,
    embedding: List[float],
    k: int = 4,
    fetch_k: int = 20,
    lambda_mult: float = 0.5,
    filters: Optional[Any] = None,
    **kwargs: Any,
) -> List[Document]:
    """Return docs selected using the maximal marginal relevance.

    Maximal marginal relevance optimizes for similarity to query AND diversity
    among selected documents.

    Args:
        embedding: Embedding to look up documents similar to.
        k: Number of Documents to return. Defaults to 4.
        fetch_k: Number of Documents to fetch to pass to MMR algorithm.
        lambda_mult: Number between 0 and 1 that determines the degree
            of diversity among the results with 0 corresponding
            to maximum diversity and 1 to minimum diversity.
            Defaults to 0.5.
        filters: Filters to apply to the query. Defaults to None.

    Returns:
        List of Documents selected by maximal marginal relevance.

    Raises:
        ValueError: If the index uses Databricks-managed embeddings.
    """
    if self._is_databricks_managed_embeddings():
        raise ValueError(
            "`max_marginal_relevance_search` is not supported for index with "
            "Databricks-managed embeddings."
        )
    embedding_column = self._embedding_vector_column_name()

    # Over-fetch fetch_k candidates, asking the index to also return the
    # stored vectors so MMR can re-rank for diversity client-side.
    search_resp = self.index.similarity_search(
        columns=list(set(self.columns + [embedding_column])),
        query_text=None,
        query_vector=embedding,
        filters=filters,
        num_results=fetch_k,
    )

    # Locate the embedding column within the response manifest, then pull
    # each candidate's vector out of the data rows.
    manifest_columns = search_resp.get("manifest").get("columns")
    vector_position = manifest_columns.index({"name": embedding_column})
    candidate_vectors = [
        row[vector_position]
        for row in search_resp.get("result").get("data_array")
    ]

    mmr_selected = maximal_marginal_relevance(
        np.array(embedding, dtype=np.float32),
        candidate_vectors,
        k=k,
        lambda_mult=lambda_mult,
    )

    # Hide the embedding column from document metadata unless the caller
    # explicitly configured it in self.columns.
    ignore_cols: List = (
        [] if embedding_column in self.columns else [embedding_column]
    )
    candidates = self._parse_search_response(search_resp, ignore_cols=ignore_cols)
    return [doc for idx, (doc, _) in enumerate(candidates) if idx in mmr_selected]
def similarity_search_by_vector(
self,
embedding: List[float],
@ -373,8 +496,13 @@ class DatabricksVectorSearch(VectorStore):
)
return self._parse_search_response(search_resp)
def _parse_search_response(self, search_resp: dict) -> List[Tuple[Document, float]]:
def _parse_search_response(
self, search_resp: dict, ignore_cols: Optional[List[str]] = None
) -> List[Tuple[Document, float]]:
"""Parse the search response into a list of Documents with score."""
if ignore_cols is None:
ignore_cols = []
columns = [
col["name"]
for col in search_resp.get("manifest", dict()).get("columns", [])
@ -386,7 +514,7 @@ class DatabricksVectorSearch(VectorStore):
metadata = {
col: value
for col, value in zip(columns[:-1], result[:-1])
if col not in [self.primary_key, self.text_column]
if col not in ([self.primary_key, self.text_column] + ignore_cols)
}
metadata[self.primary_key] = doc_id
score = result[-1]

View File

@ -1,7 +1,8 @@
import itertools
import random
import uuid
from typing import List
from unittest.mock import MagicMock
from typing import Dict, List, Optional, Set
from unittest.mock import MagicMock, patch
import pytest
@ -120,6 +121,52 @@ EXAMPLE_SEARCH_RESPONSE = {
"next_page_token": "",
}
# Every row carries the same fixed score (0.5) so score-threshold tests can
# flip between "all returned" and "all filtered" purely by moving the cutoff.
_fixed_score_rows = sorted(
    [[str(uuid.uuid4()), text, 0.5] for text in fake_texts],
    key=lambda row: row[2],  # type: ignore
    reverse=True,
)
EXAMPLE_SEARCH_RESPONSE_FIXED_SCORE: Dict = {
    "manifest": {
        "column_count": 3,
        "columns": [
            {"name": DEFAULT_PRIMARY_KEY},
            {"name": DEFAULT_TEXT_COLUMN},
            {"name": "score"},
        ],
    },
    "result": {
        "row_count": len(fake_texts),
        "data_array": _fixed_score_rows,
    },
    "next_page_token": "",
}
# Response fixture that also returns the stored embedding vector, used to
# exercise MMR search (4 columns: id, text, vector, score).
EXAMPLE_SEARCH_RESPONSE_WITH_EMBEDDING = {
    "manifest": {
        # Fixed: the manifest advertised 3 columns while listing 4.
        "column_count": 4,
        "columns": [
            {"name": DEFAULT_PRIMARY_KEY},
            {"name": DEFAULT_TEXT_COLUMN},
            {"name": DEFAULT_VECTOR_COLUMN},
            {"name": "score"},
        ],
    },
    "result": {
        "row_count": len(fake_texts),
        "data_array": sorted(
            [
                [str(uuid.uuid4()), s, e, random.uniform(0, 1)]
                for s, e in zip(
                    fake_texts, DEFAULT_EMBEDDING_MODEL.embed_documents(fake_texts)
                )
            ],
            # Fixed: sort by the score column (index 3), not the embedding
            # vector (index 2) — mirrors the fixed-score fixture's intent.
            key=lambda x: x[3],  # type: ignore
            reverse=True,
        ),
    },
    "next_page_token": "",
}
def mock_index(index_details: dict) -> MagicMock:
from databricks.vector_search.client import VectorSearchIndex
@ -129,11 +176,14 @@ def mock_index(index_details: dict) -> MagicMock:
return index
def default_databricks_vector_search(index: MagicMock) -> DatabricksVectorSearch:
def default_databricks_vector_search(
    index: MagicMock, columns: Optional[List[str]] = None
) -> DatabricksVectorSearch:
    """Build a DatabricksVectorSearch over the mocked index, using the default
    embedding model and text column; `columns` optionally restricts which
    columns are surfaced as document metadata."""
    return DatabricksVectorSearch(
        index,
        text_column=DEFAULT_TEXT_COLUMN,
        embedding=DEFAULT_EMBEDDING_MODEL,
        columns=columns,
    )
@ -456,6 +506,100 @@ def test_similarity_search(index_details: dict) -> None:
assert all([DEFAULT_PRIMARY_KEY in d.metadata for d in search_result])
@pytest.mark.requires("databricks", "databricks.vector_search")
@pytest.mark.parametrize(
    "index_details, columns, expected_columns",
    [
        (DELTA_SYNC_INDEX_SELF_MANAGED_EMBEDDINGS, None, {"id"}),
        (DIRECT_ACCESS_INDEX, None, {"id"}),
        (
            DELTA_SYNC_INDEX_SELF_MANAGED_EMBEDDINGS,
            [DEFAULT_PRIMARY_KEY, DEFAULT_TEXT_COLUMN, DEFAULT_VECTOR_COLUMN],
            {"text_vector", "id"},
        ),
        (
            DIRECT_ACCESS_INDEX,
            [DEFAULT_PRIMARY_KEY, DEFAULT_TEXT_COLUMN, DEFAULT_VECTOR_COLUMN],
            {"text_vector", "id"},
        ),
    ],
)
def test_mmr_search(
    index_details: dict, columns: Optional[List[str]], expected_columns: Set[str]
) -> None:
    """MMR search should return the best-matching document and expose exactly
    the metadata columns the store was configured with."""
    index = mock_index(index_details)
    index.similarity_search.return_value = EXAMPLE_SEARCH_RESPONSE_WITH_EMBEDDING
    store = default_databricks_vector_search(index, columns)

    docs = store.max_marginal_relevance_search(
        fake_texts[0], k=1, filters={"some filter": True}
    )

    # The query IS the first fake text, so MMR must rank it first.
    assert [doc.page_content for doc in docs] == [fake_texts[0]]
    assert [set(doc.metadata.keys()) for doc in docs] == [expected_columns]
@pytest.mark.requires("databricks", "databricks.vector_search")
@pytest.mark.parametrize(
    "index_details", [DELTA_SYNC_INDEX_SELF_MANAGED_EMBEDDINGS, DIRECT_ACCESS_INDEX]
)
def test_mmr_parameters(index_details: dict) -> None:
    """Every retriever search_kwarg must be threaded through to the MMR helper
    and the underlying index call."""
    index = mock_index(index_details)
    index.similarity_search.return_value = EXAMPLE_SEARCH_RESPONSE_WITH_EMBEDDING
    limit = 1
    fetch_k = 3
    lambda_mult = 0.25
    filters = {"some filter": True}

    mmr_target = (
        "langchain_community.vectorstores.databricks_vector_search."
        "maximal_marginal_relevance"
    )
    with patch(mmr_target) as mock_mmr:
        mock_mmr.return_value = [2]
        retriever = default_databricks_vector_search(index).as_retriever(
            search_type="mmr",
            search_kwargs={
                "k": limit,
                "fetch_k": fetch_k,
                "lambda_mult": lambda_mult,
                "filters": filters,
            },
        )
        results = retriever.get_relevant_documents(fake_texts[0])

    mock_mmr.assert_called_once()
    assert mock_mmr.call_args[1]["lambda_mult"] == lambda_mult
    assert index.similarity_search.call_args[1]["num_results"] == fetch_k
    assert index.similarity_search.call_args[1]["filters"] == filters
    assert len(results) == limit
@pytest.mark.requires("databricks", "databricks.vector_search")
@pytest.mark.parametrize(
    "index_details, threshold", itertools.product(ALL_INDEXES, [0.4, 0.5, 0.8])
)
def test_similarity_score_threshold(index_details: dict, threshold: float) -> None:
    """Because every fixture row shares one fixed score, the threshold
    retriever must return either all documents or none of them."""
    index = mock_index(index_details)
    index.similarity_search.return_value = EXAMPLE_SEARCH_RESPONSE_FIXED_SCORE
    # Score of the first row — identical for all rows in this fixture.
    fixed_score = EXAMPLE_SEARCH_RESPONSE_FIXED_SCORE["result"]["data_array"][0][2]

    retriever = default_databricks_vector_search(index).as_retriever(
        search_type="similarity_score_threshold",
        search_kwargs={"k": len(fake_texts), "score_threshold": threshold},
    )
    results = retriever.get_relevant_documents(fake_texts[0])

    expected_count = len(fake_texts) if fixed_score >= threshold else 0
    assert len(results) == expected_count
@pytest.mark.requires("databricks", "databricks.vector_search")
@pytest.mark.parametrize(
"index_details", [DELTA_SYNC_INDEX_SELF_MANAGED_EMBEDDINGS, DIRECT_ACCESS_INDEX]