community[minor]: Add mmr and similarity_score_threshold retrieval to DatabricksVectorSearch (#16829)

- **Description:** This PR adds support for `search_type="mmr"` and
`search_type="similarity_score_threshold"` to retrievers using
`DatabricksVectorSearch`.
  - **Issue:** 
  - **Dependencies:**
  - **Twitter handle:**

---------

Co-authored-by: Bagatur <baskaryan@gmail.com>
This commit is contained in:
david-tempelmann
2024-02-12 21:51:37 +01:00
committed by GitHub
parent 42648061ad
commit 93da18b667
2 changed files with 278 additions and 6 deletions

View File

@@ -1,7 +1,8 @@
import itertools
import random
import uuid
from typing import List
from unittest.mock import MagicMock
from typing import Dict, List, Optional, Set
from unittest.mock import MagicMock, patch
import pytest
@@ -120,6 +121,52 @@ EXAMPLE_SEARCH_RESPONSE = {
"next_page_token": "",
}
# Canned search response with one row per entry of ``fake_texts``:
# a fresh UUID as primary key, the text itself, and a constant score of 0.5.
# Because every score is identical, the descending sort keeps the rows in
# ``fake_texts`` order.
EXAMPLE_SEARCH_RESPONSE_FIXED_SCORE: Dict = {
    "manifest": {
        "column_count": 3,
        "columns": [
            {"name": DEFAULT_PRIMARY_KEY},
            {"name": DEFAULT_TEXT_COLUMN},
            {"name": "score"},
        ],
    },
    "result": {
        "row_count": len(fake_texts),
        "data_array": sorted(
            ([str(uuid.uuid4()), text, 0.5] for text in fake_texts),
            key=lambda row: row[2],  # type: ignore
            reverse=True,
        ),
    },
    "next_page_token": "",
}
# Canned search response whose rows also carry the embedding vector.
# Each row is [primary key, text, embedding, score]: a fresh UUID, one entry
# of ``fake_texts``, that text's embedding from ``DEFAULT_EMBEDDING_MODEL``,
# and a random score in [0, 1].
EXAMPLE_SEARCH_RESPONSE_WITH_EMBEDDING = {
    "manifest": {
        # Four columns are listed below, so the count must be 4.
        "column_count": 4,
        "columns": [
            {"name": DEFAULT_PRIMARY_KEY},
            {"name": DEFAULT_TEXT_COLUMN},
            {"name": DEFAULT_VECTOR_COLUMN},
            {"name": "score"},
        ],
    },
    "result": {
        "row_count": len(fake_texts),
        "data_array": sorted(
            [
                [str(uuid.uuid4()), s, e, random.uniform(0, 1)]
                for s, e in zip(
                    fake_texts, DEFAULT_EMBEDDING_MODEL.embed_documents(fake_texts)
                )
            ],
            # The score is the last column (index 3); index 2 is the
            # embedding vector, which is not a meaningful sort key.
            key=lambda x: x[3],  # type: ignore
            reverse=True,
        ),
    },
    "next_page_token": "",
}
def mock_index(index_details: dict) -> MagicMock:
from databricks.vector_search.client import VectorSearchIndex
@@ -129,11 +176,14 @@ def mock_index(index_details: dict) -> MagicMock:
return index
def default_databricks_vector_search(index: MagicMock) -> DatabricksVectorSearch:
def default_databricks_vector_search(
    index: MagicMock, columns: Optional[List[str]] = None
) -> DatabricksVectorSearch:
    """Create a ``DatabricksVectorSearch`` around ``index`` using the test defaults.

    Args:
        index: The (mocked) index object handed to the vector store.
        columns: Optional list of column names forwarded as ``columns``;
            ``None`` is passed through unchanged.

    Returns:
        A vector store configured with ``DEFAULT_EMBEDDING_MODEL`` and
        ``DEFAULT_TEXT_COLUMN``.
    """
    vector_store = DatabricksVectorSearch(
        index,
        embedding=DEFAULT_EMBEDDING_MODEL,
        text_column=DEFAULT_TEXT_COLUMN,
        columns=columns,
    )
    return vector_store
@@ -456,6 +506,100 @@ def test_similarity_search(index_details: dict) -> None:
assert all([DEFAULT_PRIMARY_KEY in d.metadata for d in search_result])
@pytest.mark.requires("databricks", "databricks.vector_search")
@pytest.mark.parametrize(
    "index_details, columns, expected_columns",
    [
        (DELTA_SYNC_INDEX_SELF_MANAGED_EMBEDDINGS, None, {"id"}),
        (DIRECT_ACCESS_INDEX, None, {"id"}),
        (
            DELTA_SYNC_INDEX_SELF_MANAGED_EMBEDDINGS,
            [DEFAULT_PRIMARY_KEY, DEFAULT_TEXT_COLUMN, DEFAULT_VECTOR_COLUMN],
            {"text_vector", "id"},
        ),
        (
            DIRECT_ACCESS_INDEX,
            [DEFAULT_PRIMARY_KEY, DEFAULT_TEXT_COLUMN, DEFAULT_VECTOR_COLUMN],
            {"text_vector", "id"},
        ),
    ],
)
def test_mmr_search(
    index_details: dict, columns: Optional[List[str]], expected_columns: Set[str]
) -> None:
    """MMR search with ``k=1`` returns the query text and the expected metadata keys."""
    fake_index = mock_index(index_details)
    fake_index.similarity_search.return_value = EXAMPLE_SEARCH_RESPONSE_WITH_EMBEDDING
    store = default_databricks_vector_search(fake_index, columns)

    # The query is the first fake text, so that same text is the expected hit.
    documents = store.max_marginal_relevance_search(
        fake_texts[0], k=1, filters={"some filter": True}
    )

    contents = [document.page_content for document in documents]
    assert contents == [fake_texts[0]]
    metadata_keys = [set(document.metadata.keys()) for document in documents]
    assert metadata_keys == [expected_columns]
@pytest.mark.requires("databricks", "databricks.vector_search")
@pytest.mark.parametrize(
    "index_details", [DELTA_SYNC_INDEX_SELF_MANAGED_EMBEDDINGS, DIRECT_ACCESS_INDEX]
)
def test_mmr_parameters(index_details: dict) -> None:
    """The ``mmr`` retriever forwards ``k``, ``fetch_k``, ``lambda_mult`` and ``filters``."""
    fake_index = mock_index(index_details)
    fake_index.similarity_search.return_value = EXAMPLE_SEARCH_RESPONSE_WITH_EMBEDDING
    top_k = 1
    candidates = 3
    diversity = 0.25
    query_filters = {"some filter": True}

    with patch(
        "langchain_community.vectorstores.databricks_vector_search.maximal_marginal_relevance"
    ) as mmr_mock:
        # The patched MMR routine selects a single candidate, row index 2.
        mmr_mock.return_value = [2]
        retriever = default_databricks_vector_search(fake_index).as_retriever(
            search_type="mmr",
            search_kwargs={
                "k": top_k,
                "fetch_k": candidates,
                "lambda_mult": diversity,
                "filters": query_filters,
            },
        )
        documents = retriever.get_relevant_documents(fake_texts[0])

    # Both mocks keep their recorded calls once the patch has been undone.
    mmr_mock.assert_called_once()
    assert mmr_mock.call_args.kwargs["lambda_mult"] == diversity
    index_call_kwargs = fake_index.similarity_search.call_args.kwargs
    assert index_call_kwargs["num_results"] == candidates
    assert index_call_kwargs["filters"] == query_filters
    assert len(documents) == top_k
@pytest.mark.requires("databricks", "databricks.vector_search")
@pytest.mark.parametrize(
    "index_details, threshold", itertools.product(ALL_INDEXES, [0.4, 0.5, 0.8])
)
def test_similarity_score_threshold(index_details: dict, threshold: float) -> None:
    """With one shared score, the retriever returns every document or none."""
    fake_index = mock_index(index_details)
    fake_index.similarity_search.return_value = EXAMPLE_SEARCH_RESPONSE_FIXED_SCORE
    # Every row of the fixture has the same score, so the first row's is enough.
    first_row = EXAMPLE_SEARCH_RESPONSE_FIXED_SCORE["result"]["data_array"][0]
    fixed_score = first_row[2]

    retriever = default_databricks_vector_search(fake_index).as_retriever(
        search_type="similarity_score_threshold",
        search_kwargs={"k": len(fake_texts), "score_threshold": threshold},
    )
    documents = retriever.get_relevant_documents(fake_texts[0])

    expected_count = len(fake_texts) if fixed_score >= threshold else 0
    assert len(documents) == expected_count
@pytest.mark.requires("databricks", "databricks.vector_search")
@pytest.mark.parametrize(
"index_details", [DELTA_SYNC_INDEX_SELF_MANAGED_EMBEDDINGS, DIRECT_ACCESS_INDEX]