mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-17 07:26:16 +00:00
community[minor]: Add mmr and similarity_score_threshold retrieval to DatabricksVectorSearch (#16829)
- **Description:** This PR adds support for `search_types="mmr"` and `search_type="similarity_score_threshold"` to retrievers using `DatabricksVectorSearch`, - **Issue:** - **Dependencies:** - **Twitter handle:** --------- Co-authored-by: Bagatur <baskaryan@gmail.com>
This commit is contained in:
@@ -1,7 +1,8 @@
|
||||
import itertools
|
||||
import random
|
||||
import uuid
|
||||
from typing import List
|
||||
from unittest.mock import MagicMock
|
||||
from typing import Dict, List, Optional, Set
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
@@ -120,6 +121,52 @@ EXAMPLE_SEARCH_RESPONSE = {
|
||||
"next_page_token": "",
|
||||
}
|
||||
|
||||
EXAMPLE_SEARCH_RESPONSE_FIXED_SCORE: Dict = {
|
||||
"manifest": {
|
||||
"column_count": 3,
|
||||
"columns": [
|
||||
{"name": DEFAULT_PRIMARY_KEY},
|
||||
{"name": DEFAULT_TEXT_COLUMN},
|
||||
{"name": "score"},
|
||||
],
|
||||
},
|
||||
"result": {
|
||||
"row_count": len(fake_texts),
|
||||
"data_array": sorted(
|
||||
[[str(uuid.uuid4()), s, 0.5] for s in fake_texts],
|
||||
key=lambda x: x[2], # type: ignore
|
||||
reverse=True,
|
||||
),
|
||||
},
|
||||
"next_page_token": "",
|
||||
}
|
||||
|
||||
EXAMPLE_SEARCH_RESPONSE_WITH_EMBEDDING = {
|
||||
"manifest": {
|
||||
"column_count": 3,
|
||||
"columns": [
|
||||
{"name": DEFAULT_PRIMARY_KEY},
|
||||
{"name": DEFAULT_TEXT_COLUMN},
|
||||
{"name": DEFAULT_VECTOR_COLUMN},
|
||||
{"name": "score"},
|
||||
],
|
||||
},
|
||||
"result": {
|
||||
"row_count": len(fake_texts),
|
||||
"data_array": sorted(
|
||||
[
|
||||
[str(uuid.uuid4()), s, e, random.uniform(0, 1)]
|
||||
for s, e in zip(
|
||||
fake_texts, DEFAULT_EMBEDDING_MODEL.embed_documents(fake_texts)
|
||||
)
|
||||
],
|
||||
key=lambda x: x[2], # type: ignore
|
||||
reverse=True,
|
||||
),
|
||||
},
|
||||
"next_page_token": "",
|
||||
}
|
||||
|
||||
|
||||
def mock_index(index_details: dict) -> MagicMock:
|
||||
from databricks.vector_search.client import VectorSearchIndex
|
||||
@@ -129,11 +176,14 @@ def mock_index(index_details: dict) -> MagicMock:
|
||||
return index
|
||||
|
||||
|
||||
def default_databricks_vector_search(index: MagicMock) -> DatabricksVectorSearch:
|
||||
def default_databricks_vector_search(
|
||||
index: MagicMock, columns: Optional[List[str]] = None
|
||||
) -> DatabricksVectorSearch:
|
||||
return DatabricksVectorSearch(
|
||||
index,
|
||||
embedding=DEFAULT_EMBEDDING_MODEL,
|
||||
text_column=DEFAULT_TEXT_COLUMN,
|
||||
columns=columns,
|
||||
)
|
||||
|
||||
|
||||
@@ -456,6 +506,100 @@ def test_similarity_search(index_details: dict) -> None:
|
||||
assert all([DEFAULT_PRIMARY_KEY in d.metadata for d in search_result])
|
||||
|
||||
|
||||
@pytest.mark.requires("databricks", "databricks.vector_search")
|
||||
@pytest.mark.parametrize(
|
||||
"index_details, columns, expected_columns",
|
||||
[
|
||||
(DELTA_SYNC_INDEX_SELF_MANAGED_EMBEDDINGS, None, {"id"}),
|
||||
(DIRECT_ACCESS_INDEX, None, {"id"}),
|
||||
(
|
||||
DELTA_SYNC_INDEX_SELF_MANAGED_EMBEDDINGS,
|
||||
[DEFAULT_PRIMARY_KEY, DEFAULT_TEXT_COLUMN, DEFAULT_VECTOR_COLUMN],
|
||||
{"text_vector", "id"},
|
||||
),
|
||||
(
|
||||
DIRECT_ACCESS_INDEX,
|
||||
[DEFAULT_PRIMARY_KEY, DEFAULT_TEXT_COLUMN, DEFAULT_VECTOR_COLUMN],
|
||||
{"text_vector", "id"},
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_mmr_search(
|
||||
index_details: dict, columns: Optional[List[str]], expected_columns: Set[str]
|
||||
) -> None:
|
||||
index = mock_index(index_details)
|
||||
index.similarity_search.return_value = EXAMPLE_SEARCH_RESPONSE_WITH_EMBEDDING
|
||||
vectorsearch = default_databricks_vector_search(index, columns)
|
||||
query = fake_texts[0]
|
||||
filters = {"some filter": True}
|
||||
limit = 1
|
||||
|
||||
search_result = vectorsearch.max_marginal_relevance_search(
|
||||
query, k=limit, filters=filters
|
||||
)
|
||||
assert [doc.page_content for doc in search_result] == [fake_texts[0]]
|
||||
assert [set(doc.metadata.keys()) for doc in search_result] == [expected_columns]
|
||||
|
||||
|
||||
@pytest.mark.requires("databricks", "databricks.vector_search")
|
||||
@pytest.mark.parametrize(
|
||||
"index_details", [DELTA_SYNC_INDEX_SELF_MANAGED_EMBEDDINGS, DIRECT_ACCESS_INDEX]
|
||||
)
|
||||
def test_mmr_parameters(index_details: dict) -> None:
|
||||
index = mock_index(index_details)
|
||||
index.similarity_search.return_value = EXAMPLE_SEARCH_RESPONSE_WITH_EMBEDDING
|
||||
query = fake_texts[0]
|
||||
limit = 1
|
||||
fetch_k = 3
|
||||
lambda_mult = 0.25
|
||||
filters = {"some filter": True}
|
||||
|
||||
with patch(
|
||||
"langchain_community.vectorstores.databricks_vector_search.maximal_marginal_relevance"
|
||||
) as mock_mmr:
|
||||
mock_mmr.return_value = [2]
|
||||
retriever = default_databricks_vector_search(index).as_retriever(
|
||||
search_type="mmr",
|
||||
search_kwargs={
|
||||
"k": limit,
|
||||
"fetch_k": fetch_k,
|
||||
"lambda_mult": lambda_mult,
|
||||
"filters": filters,
|
||||
},
|
||||
)
|
||||
search_result = retriever.get_relevant_documents(query)
|
||||
|
||||
mock_mmr.assert_called_once()
|
||||
assert mock_mmr.call_args[1]["lambda_mult"] == lambda_mult
|
||||
assert index.similarity_search.call_args[1]["num_results"] == fetch_k
|
||||
assert index.similarity_search.call_args[1]["filters"] == filters
|
||||
assert len(search_result) == limit
|
||||
|
||||
|
||||
@pytest.mark.requires("databricks", "databricks.vector_search")
|
||||
@pytest.mark.parametrize(
|
||||
"index_details, threshold", itertools.product(ALL_INDEXES, [0.4, 0.5, 0.8])
|
||||
)
|
||||
def test_similarity_score_threshold(index_details: dict, threshold: float) -> None:
|
||||
index = mock_index(index_details)
|
||||
index.similarity_search.return_value = EXAMPLE_SEARCH_RESPONSE_FIXED_SCORE
|
||||
uniform_response_score = EXAMPLE_SEARCH_RESPONSE_FIXED_SCORE["result"][
|
||||
"data_array"
|
||||
][0][2]
|
||||
query = fake_texts[0]
|
||||
limit = len(fake_texts)
|
||||
|
||||
retriever = default_databricks_vector_search(index).as_retriever(
|
||||
search_type="similarity_score_threshold",
|
||||
search_kwargs={"k": limit, "score_threshold": threshold},
|
||||
)
|
||||
search_result = retriever.get_relevant_documents(query)
|
||||
if uniform_response_score >= threshold:
|
||||
assert len(search_result) == len(fake_texts)
|
||||
else:
|
||||
assert len(search_result) == 0
|
||||
|
||||
|
||||
@pytest.mark.requires("databricks", "databricks.vector_search")
|
||||
@pytest.mark.parametrize(
|
||||
"index_details", [DELTA_SYNC_INDEX_SELF_MANAGED_EMBEDDINGS, DIRECT_ACCESS_INDEX]
|
||||
|
Reference in New Issue
Block a user