community[minor]: Add mmr and similarity_score_threshold retrieval to DatabricksVectorSearch (#16829)

- **Description:** This PR adds support for `search_type="mmr"` and
`search_type="similarity_score_threshold"` to retrievers using
`DatabricksVectorSearch`.
  - **Issue:** 
  - **Dependencies:**
  - **Twitter handle:**

---------

Co-authored-by: Bagatur <baskaryan@gmail.com>
This commit is contained in:
david-tempelmann 2024-02-12 21:51:37 +01:00 committed by GitHub
parent 42648061ad
commit 93da18b667
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 278 additions and 6 deletions

View File

@ -3,12 +3,15 @@ from __future__ import annotations
import json import json
import logging import logging
import uuid import uuid
from typing import TYPE_CHECKING, Any, Iterable, List, Optional, Tuple, Type from typing import TYPE_CHECKING, Any, Callable, Iterable, List, Optional, Tuple, Type
import numpy as np
from langchain_core.documents import Document from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings from langchain_core.embeddings import Embeddings
from langchain_core.vectorstores import VST, VectorStore from langchain_core.vectorstores import VST, VectorStore
from langchain_community.vectorstores.utils import maximal_marginal_relevance
if TYPE_CHECKING: if TYPE_CHECKING:
from databricks.vector_search.client import VectorSearchIndex from databricks.vector_search.client import VectorSearchIndex
@ -321,6 +324,126 @@ class DatabricksVectorSearch(VectorStore):
) )
return self._parse_search_response(search_resp) return self._parse_search_response(search_resp)
    @staticmethod
    def _identity_fn(score: float) -> float:
        # No-op score transform: the index already returns a normalized
        # similarity score, so it is passed through unchanged.
        return score
def _select_relevance_score_fn(self) -> Callable[[float], float]:
"""
Databricks Vector search uses a normalized score 1/(1+d) where d
is the L2 distance. Hence, we simply return the identity function.
"""
return self._identity_fn
def max_marginal_relevance_search(
self,
query: str,
k: int = 4,
fetch_k: int = 20,
lambda_mult: float = 0.5,
filters: Optional[Any] = None,
**kwargs: Any,
) -> List[Document]:
"""Return docs selected using the maximal marginal relevance.
Maximal marginal relevance optimizes for similarity to query AND diversity
among selected documents.
Args:
query: Text to look up documents similar to.
k: Number of Documents to return. Defaults to 4.
fetch_k: Number of Documents to fetch to pass to MMR algorithm.
lambda_mult: Number between 0 and 1 that determines the degree
of diversity among the results with 0 corresponding
to maximum diversity and 1 to minimum diversity.
Defaults to 0.5.
filters: Filters to apply to the query. Defaults to None.
Returns:
List of Documents selected by maximal marginal relevance.
"""
if not self._is_databricks_managed_embeddings():
assert self.embeddings is not None, "embedding model is required."
query_vector = self.embeddings.embed_query(query)
else:
raise ValueError(
"`max_marginal_relevance_search` is not supported for index with "
"Databricks-managed embeddings."
)
docs = self.max_marginal_relevance_search_by_vector(
query_vector,
k,
fetch_k,
lambda_mult=lambda_mult,
filters=filters,
)
return docs
    def max_marginal_relevance_search_by_vector(
        self,
        embedding: List[float],
        k: int = 4,
        fetch_k: int = 20,
        lambda_mult: float = 0.5,
        filters: Optional[Any] = None,
        **kwargs: Any,
    ) -> List[Document]:
        """Return docs selected using the maximal marginal relevance.

        Maximal marginal relevance optimizes for similarity to query AND
        diversity among selected documents.

        Args:
            embedding: Embedding to look up documents similar to.
            k: Number of Documents to return. Defaults to 4.
            fetch_k: Number of Documents to fetch to pass to MMR algorithm.
            lambda_mult: Number between 0 and 1 that determines the degree
                of diversity among the results with 0 corresponding to
                maximum diversity and 1 to minimum diversity. Defaults to 0.5.
            filters: Filters to apply to the query. Defaults to None.

        Returns:
            List of Documents selected by maximal marginal relevance.

        Raises:
            ValueError: If the index uses Databricks-managed embeddings, since
                the stored document vectors needed for MMR re-ranking are not
                accessible in that configuration.
        """
        if not self._is_databricks_managed_embeddings():
            embedding_column = self._embedding_vector_column_name()
        else:
            raise ValueError(
                "`max_marginal_relevance_search` is not supported for index with "
                "Databricks-managed embeddings."
            )
        # Fetch a larger candidate pool (fetch_k) and explicitly include the
        # embedding column so MMR can re-rank the candidates locally.
        search_resp = self.index.similarity_search(
            columns=list(set(self.columns + [embedding_column])),
            query_text=None,
            query_vector=embedding,
            filters=filters,
            num_results=fetch_k,
        )
        # Locate the embedding column in the response manifest; each manifest
        # column entry is a dict of the form {"name": <column_name>}.
        embeddings_result_index = (
            search_resp.get("manifest").get("columns").index({"name": embedding_column})
        )
        embeddings = [
            doc[embeddings_result_index]
            for doc in search_resp.get("result").get("data_array")
        ]
        mmr_selected = maximal_marginal_relevance(
            np.array(embedding, dtype=np.float32),
            embeddings,
            k=k,
            lambda_mult=lambda_mult,
        )
        # Strip the embedding column from result metadata unless the caller
        # explicitly asked for it via `self.columns`.
        ignore_cols: List[str] = (
            [embedding_column] if embedding_column not in self.columns else []
        )
        candidates = self._parse_search_response(search_resp, ignore_cols=ignore_cols)
        # `mmr_selected` holds indices into the candidate list; keep only the
        # Document (drop the score) for each selected candidate.
        selected_results = [r[0] for i, r in enumerate(candidates) if i in mmr_selected]
        return selected_results
def similarity_search_by_vector( def similarity_search_by_vector(
self, self,
embedding: List[float], embedding: List[float],
@ -373,8 +496,13 @@ class DatabricksVectorSearch(VectorStore):
) )
return self._parse_search_response(search_resp) return self._parse_search_response(search_resp)
def _parse_search_response(self, search_resp: dict) -> List[Tuple[Document, float]]: def _parse_search_response(
self, search_resp: dict, ignore_cols: Optional[List[str]] = None
) -> List[Tuple[Document, float]]:
"""Parse the search response into a list of Documents with score.""" """Parse the search response into a list of Documents with score."""
if ignore_cols is None:
ignore_cols = []
columns = [ columns = [
col["name"] col["name"]
for col in search_resp.get("manifest", dict()).get("columns", []) for col in search_resp.get("manifest", dict()).get("columns", [])
@ -386,7 +514,7 @@ class DatabricksVectorSearch(VectorStore):
metadata = { metadata = {
col: value col: value
for col, value in zip(columns[:-1], result[:-1]) for col, value in zip(columns[:-1], result[:-1])
if col not in [self.primary_key, self.text_column] if col not in ([self.primary_key, self.text_column] + ignore_cols)
} }
metadata[self.primary_key] = doc_id metadata[self.primary_key] = doc_id
score = result[-1] score = result[-1]

View File

@ -1,7 +1,8 @@
import itertools
import random import random
import uuid import uuid
from typing import List from typing import Dict, List, Optional, Set
from unittest.mock import MagicMock from unittest.mock import MagicMock, patch
import pytest import pytest
@ -120,6 +121,52 @@ EXAMPLE_SEARCH_RESPONSE = {
"next_page_token": "", "next_page_token": "",
} }
# Search-response fixture in which every row carries the same fixed score
# (0.5), so similarity_score_threshold tests can assert all-or-nothing
# filtering depending on which side of the threshold the score falls.
EXAMPLE_SEARCH_RESPONSE_FIXED_SCORE: Dict = {
    "manifest": {
        "column_count": 3,
        "columns": [
            {"name": DEFAULT_PRIMARY_KEY},
            {"name": DEFAULT_TEXT_COLUMN},
            {"name": "score"},
        ],
    },
    "result": {
        "row_count": len(fake_texts),
        "data_array": sorted(
            [[str(uuid.uuid4()), s, 0.5] for s in fake_texts],
            key=lambda x: x[2],  # type: ignore
            reverse=True,
        ),
    },
    "next_page_token": "",
}
# Search-response fixture that also carries the stored embedding column, as
# returned when an MMR search requests the vector column from the index.
EXAMPLE_SEARCH_RESPONSE_WITH_EMBEDDING: Dict = {
    "manifest": {
        # Fixed: four columns are listed below (was incorrectly 3).
        "column_count": 4,
        "columns": [
            {"name": DEFAULT_PRIMARY_KEY},
            {"name": DEFAULT_TEXT_COLUMN},
            {"name": DEFAULT_VECTOR_COLUMN},
            {"name": "score"},
        ],
    },
    "result": {
        "row_count": len(fake_texts),
        "data_array": sorted(
            [
                [str(uuid.uuid4()), s, e, random.uniform(0, 1)]
                for s, e in zip(
                    fake_texts, DEFAULT_EMBEDDING_MODEL.embed_documents(fake_texts)
                )
            ],
            # Fixed: sort by the score column (index 3); index 2 is the
            # embedding vector, which is not a meaningful sort key.
            key=lambda x: x[3],  # type: ignore
            reverse=True,
        ),
    },
    "next_page_token": "",
}
def mock_index(index_details: dict) -> MagicMock: def mock_index(index_details: dict) -> MagicMock:
from databricks.vector_search.client import VectorSearchIndex from databricks.vector_search.client import VectorSearchIndex
@ -129,11 +176,14 @@ def mock_index(index_details: dict) -> MagicMock:
return index return index
def default_databricks_vector_search(index: MagicMock) -> DatabricksVectorSearch: def default_databricks_vector_search(
index: MagicMock, columns: Optional[List[str]] = None
) -> DatabricksVectorSearch:
return DatabricksVectorSearch( return DatabricksVectorSearch(
index, index,
embedding=DEFAULT_EMBEDDING_MODEL, embedding=DEFAULT_EMBEDDING_MODEL,
text_column=DEFAULT_TEXT_COLUMN, text_column=DEFAULT_TEXT_COLUMN,
columns=columns,
) )
@ -456,6 +506,100 @@ def test_similarity_search(index_details: dict) -> None:
assert all([DEFAULT_PRIMARY_KEY in d.metadata for d in search_result]) assert all([DEFAULT_PRIMARY_KEY in d.metadata for d in search_result])
@pytest.mark.requires("databricks", "databricks.vector_search")
@pytest.mark.parametrize(
    "index_details, columns, expected_columns",
    [
        (DELTA_SYNC_INDEX_SELF_MANAGED_EMBEDDINGS, None, {"id"}),
        (DIRECT_ACCESS_INDEX, None, {"id"}),
        (
            DELTA_SYNC_INDEX_SELF_MANAGED_EMBEDDINGS,
            [DEFAULT_PRIMARY_KEY, DEFAULT_TEXT_COLUMN, DEFAULT_VECTOR_COLUMN],
            {"text_vector", "id"},
        ),
        (
            DIRECT_ACCESS_INDEX,
            [DEFAULT_PRIMARY_KEY, DEFAULT_TEXT_COLUMN, DEFAULT_VECTOR_COLUMN],
            {"text_vector", "id"},
        ),
    ],
)
def test_mmr_search(
    index_details: dict, columns: Optional[List[str]], expected_columns: Set[str]
) -> None:
    """MMR search returns the best match and exposes only the expected metadata."""
    mocked_index = mock_index(index_details)
    mocked_index.similarity_search.return_value = EXAMPLE_SEARCH_RESPONSE_WITH_EMBEDDING
    store = default_databricks_vector_search(mocked_index, columns)

    docs = store.max_marginal_relevance_search(
        fake_texts[0], k=1, filters={"some filter": True}
    )

    assert [d.page_content for d in docs] == [fake_texts[0]]
    assert [set(d.metadata) for d in docs] == [expected_columns]
@pytest.mark.requires("databricks", "databricks.vector_search")
@pytest.mark.parametrize(
    "index_details", [DELTA_SYNC_INDEX_SELF_MANAGED_EMBEDDINGS, DIRECT_ACCESS_INDEX]
)
def test_mmr_parameters(index_details: dict) -> None:
    """MMR retriever kwargs are forwarded to the index and the MMR helper."""
    mocked_index = mock_index(index_details)
    mocked_index.similarity_search.return_value = EXAMPLE_SEARCH_RESPONSE_WITH_EMBEDDING
    expected_filters = {"some filter": True}

    with patch(
        "langchain_community.vectorstores.databricks_vector_search.maximal_marginal_relevance"
    ) as mmr_mock:
        mmr_mock.return_value = [2]
        retriever = default_databricks_vector_search(mocked_index).as_retriever(
            search_type="mmr",
            search_kwargs={
                "k": 1,
                "fetch_k": 3,
                "lambda_mult": 0.25,
                "filters": expected_filters,
            },
        )
        docs = retriever.get_relevant_documents(fake_texts[0])

    # The mock retains its call record after the patch context exits.
    mmr_mock.assert_called_once()
    assert mmr_mock.call_args[1]["lambda_mult"] == 0.25
    assert mocked_index.similarity_search.call_args[1]["num_results"] == 3
    assert mocked_index.similarity_search.call_args[1]["filters"] == expected_filters
    assert len(docs) == 1
@pytest.mark.requires("databricks", "databricks.vector_search")
@pytest.mark.parametrize(
    "index_details, threshold", itertools.product(ALL_INDEXES, [0.4, 0.5, 0.8])
)
def test_similarity_score_threshold(index_details: dict, threshold: float) -> None:
    """Docs come back only when their score clears the configured threshold."""
    mocked_index = mock_index(index_details)
    mocked_index.similarity_search.return_value = EXAMPLE_SEARCH_RESPONSE_FIXED_SCORE
    # Every fixture row carries the same score; read it off the first row.
    fixture_score = EXAMPLE_SEARCH_RESPONSE_FIXED_SCORE["result"]["data_array"][0][2]

    retriever = default_databricks_vector_search(mocked_index).as_retriever(
        search_type="similarity_score_threshold",
        search_kwargs={"k": len(fake_texts), "score_threshold": threshold},
    )
    docs = retriever.get_relevant_documents(fake_texts[0])

    # All rows share one score, so the result is all-or-nothing.
    expected_count = len(fake_texts) if fixture_score >= threshold else 0
    assert len(docs) == expected_count
@pytest.mark.requires("databricks", "databricks.vector_search") @pytest.mark.requires("databricks", "databricks.vector_search")
@pytest.mark.parametrize( @pytest.mark.parametrize(
"index_details", [DELTA_SYNC_INDEX_SELF_MANAGED_EMBEDDINGS, DIRECT_ACCESS_INDEX] "index_details", [DELTA_SYNC_INDEX_SELF_MANAGED_EMBEDDINGS, DIRECT_ACCESS_INDEX]