mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-08 04:25:46 +00:00
community[minor]: Add mmr and similarity_score_threshold retrieval to DatabricksVectorSearch (#16829)
- **Description:** This PR adds support for `search_type="mmr"` and `search_type="similarity_score_threshold"` to retrievers using `DatabricksVectorSearch`. - **Issue:** - **Dependencies:** - **Twitter handle:** --------- Co-authored-by: Bagatur <baskaryan@gmail.com>
This commit is contained in:
parent
42648061ad
commit
93da18b667
@ -3,12 +3,15 @@ from __future__ import annotations
|
||||
import json
|
||||
import logging
|
||||
import uuid
|
||||
from typing import TYPE_CHECKING, Any, Iterable, List, Optional, Tuple, Type
|
||||
from typing import TYPE_CHECKING, Any, Callable, Iterable, List, Optional, Tuple, Type
|
||||
|
||||
import numpy as np
|
||||
from langchain_core.documents import Document
|
||||
from langchain_core.embeddings import Embeddings
|
||||
from langchain_core.vectorstores import VST, VectorStore
|
||||
|
||||
from langchain_community.vectorstores.utils import maximal_marginal_relevance
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from databricks.vector_search.client import VectorSearchIndex
|
||||
|
||||
@ -321,6 +324,126 @@ class DatabricksVectorSearch(VectorStore):
|
||||
)
|
||||
return self._parse_search_response(search_resp)
|
||||
|
||||
@staticmethod
def _identity_fn(score: float) -> float:
    # Databricks already returns a normalized relevance score (1/(1+d)),
    # so no transformation is needed — the score is passed through as-is.
    return score
|
||||
|
||||
def _select_relevance_score_fn(self) -> Callable[[float], float]:
    """Return the function used to map raw scores to relevance scores.

    Databricks Vector Search uses a normalized score 1/(1+d) where d
    is the L2 distance. Hence, we simply return the identity function.
    """

    return self._identity_fn
|
||||
|
||||
def max_marginal_relevance_search(
    self,
    query: str,
    k: int = 4,
    fetch_k: int = 20,
    lambda_mult: float = 0.5,
    filters: Optional[Any] = None,
    **kwargs: Any,
) -> List[Document]:
    """Return docs selected using the maximal marginal relevance.

    Maximal marginal relevance optimizes for similarity to query AND diversity
    among selected documents.

    Args:
        query: Text to look up documents similar to.
        k: Number of Documents to return. Defaults to 4.
        fetch_k: Number of Documents to fetch to pass to MMR algorithm.
        lambda_mult: Number between 0 and 1 that determines the degree
            of diversity among the results with 0 corresponding
            to maximum diversity and 1 to minimum diversity.
            Defaults to 0.5.
        filters: Filters to apply to the query. Defaults to None.

    Returns:
        List of Documents selected by maximal marginal relevance.

    Raises:
        ValueError: If the index uses Databricks-managed embeddings, since
            MMR requires the raw query vector which is then unavailable.
    """
    # Guard: MMR needs a locally-computed query vector, which a
    # Databricks-managed-embeddings index cannot provide.
    if self._is_databricks_managed_embeddings():
        raise ValueError(
            "`max_marginal_relevance_search` is not supported for index with "
            "Databricks-managed embeddings."
        )

    assert self.embeddings is not None, "embedding model is required."
    return self.max_marginal_relevance_search_by_vector(
        self.embeddings.embed_query(query),
        k,
        fetch_k,
        lambda_mult=lambda_mult,
        filters=filters,
    )
|
||||
|
||||
def max_marginal_relevance_search_by_vector(
    self,
    embedding: List[float],
    k: int = 4,
    fetch_k: int = 20,
    lambda_mult: float = 0.5,
    filters: Optional[Any] = None,
    **kwargs: Any,
) -> List[Document]:
    """Return docs selected using the maximal marginal relevance.

    Maximal marginal relevance optimizes for similarity to query AND diversity
    among selected documents.

    Args:
        embedding: Embedding to look up documents similar to.
        k: Number of Documents to return. Defaults to 4.
        fetch_k: Number of Documents to fetch to pass to MMR algorithm.
        lambda_mult: Number between 0 and 1 that determines the degree
            of diversity among the results with 0 corresponding
            to maximum diversity and 1 to minimum diversity.
            Defaults to 0.5.
        filters: Filters to apply to the query. Defaults to None.

    Returns:
        List of Documents selected by maximal marginal relevance.

    Raises:
        ValueError: If the index uses Databricks-managed embeddings.
    """
    if self._is_databricks_managed_embeddings():
        raise ValueError(
            "`max_marginal_relevance_search` is not supported for index with "
            "Databricks-managed embeddings."
        )
    embedding_column = self._embedding_vector_column_name()

    # Over-fetch (fetch_k) so MMR has a candidate pool to diversify from;
    # the embedding column is requested explicitly so we can re-rank locally.
    search_resp = self.index.similarity_search(
        columns=list(set(self.columns + [embedding_column])),
        query_text=None,
        query_vector=embedding,
        filters=filters,
        num_results=fetch_k,
    )

    # Locate the embedding column in the response manifest, then pull the
    # candidate vectors out of each result row.
    manifest_columns = search_resp.get("manifest").get("columns")
    vector_idx = manifest_columns.index({"name": embedding_column})
    candidate_vectors = [
        row[vector_idx] for row in search_resp.get("result").get("data_array")
    ]

    picked = maximal_marginal_relevance(
        np.array(embedding, dtype=np.float32),
        candidate_vectors,
        k=k,
        lambda_mult=lambda_mult,
    )

    # Strip the raw vector from document metadata unless the caller
    # explicitly configured that column.
    ignore_cols: List = (
        [] if embedding_column in self.columns else [embedding_column]
    )
    candidates = self._parse_search_response(search_resp, ignore_cols=ignore_cols)
    return [doc for i, (doc, _) in enumerate(candidates) if i in picked]
|
||||
|
||||
def similarity_search_by_vector(
|
||||
self,
|
||||
embedding: List[float],
|
||||
@ -373,8 +496,13 @@ class DatabricksVectorSearch(VectorStore):
|
||||
)
|
||||
return self._parse_search_response(search_resp)
|
||||
|
||||
def _parse_search_response(self, search_resp: dict) -> List[Tuple[Document, float]]:
|
||||
def _parse_search_response(
|
||||
self, search_resp: dict, ignore_cols: Optional[List[str]] = None
|
||||
) -> List[Tuple[Document, float]]:
|
||||
"""Parse the search response into a list of Documents with score."""
|
||||
if ignore_cols is None:
|
||||
ignore_cols = []
|
||||
|
||||
columns = [
|
||||
col["name"]
|
||||
for col in search_resp.get("manifest", dict()).get("columns", [])
|
||||
@ -386,7 +514,7 @@ class DatabricksVectorSearch(VectorStore):
|
||||
metadata = {
|
||||
col: value
|
||||
for col, value in zip(columns[:-1], result[:-1])
|
||||
if col not in [self.primary_key, self.text_column]
|
||||
if col not in ([self.primary_key, self.text_column] + ignore_cols)
|
||||
}
|
||||
metadata[self.primary_key] = doc_id
|
||||
score = result[-1]
|
||||
|
@ -1,7 +1,8 @@
|
||||
import itertools
|
||||
import random
|
||||
import uuid
|
||||
from typing import List
|
||||
from unittest.mock import MagicMock
|
||||
from typing import Dict, List, Optional, Set
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
@ -120,6 +121,52 @@ EXAMPLE_SEARCH_RESPONSE = {
|
||||
"next_page_token": "",
|
||||
}
|
||||
|
||||
# Canned search response where every row carries the same fixed score (0.5).
# Used to exercise similarity_score_threshold retrieval: all rows pass or
# all rows fail depending on the threshold.
EXAMPLE_SEARCH_RESPONSE_FIXED_SCORE: Dict = {
    "manifest": {
        "column_count": 3,
        "columns": [
            {"name": name}
            for name in (DEFAULT_PRIMARY_KEY, DEFAULT_TEXT_COLUMN, "score")
        ],
    },
    "result": {
        "row_count": len(fake_texts),
        "data_array": sorted(
            ([str(uuid.uuid4()), text, 0.5] for text in fake_texts),
            key=lambda row: row[2],  # type: ignore
            reverse=True,
        ),
    },
    "next_page_token": "",
}
|
||||
|
||||
# Canned search response that also carries the embedding vector column,
# as returned when MMR search requests it for local re-ranking.
# NOTE(review): "column_count" is 3 but four columns are listed — confirm
# whether the count is meant to track the columns list.
EXAMPLE_SEARCH_RESPONSE_WITH_EMBEDDING = {
    "manifest": {
        "column_count": 3,
        "columns": [
            {"name": DEFAULT_PRIMARY_KEY},
            {"name": DEFAULT_TEXT_COLUMN},
            {"name": DEFAULT_VECTOR_COLUMN},
            {"name": "score"},
        ],
    },
    "result": {
        "row_count": len(fake_texts),
        "data_array": sorted(
            [
                [str(uuid.uuid4()), text, vector, random.uniform(0, 1)]
                for text, vector in zip(
                    fake_texts, DEFAULT_EMBEDDING_MODEL.embed_documents(fake_texts)
                )
            ],
            # NOTE(review): sorts on index 2 (the embedding vector), not the
            # score at index 3 — confirm this ordering is intentional.
            key=lambda row: row[2],  # type: ignore
            reverse=True,
        ),
    },
    "next_page_token": "",
}
|
||||
|
||||
|
||||
def mock_index(index_details: dict) -> MagicMock:
|
||||
from databricks.vector_search.client import VectorSearchIndex
|
||||
@ -129,11 +176,14 @@ def mock_index(index_details: dict) -> MagicMock:
|
||||
return index
|
||||
|
||||
|
||||
def default_databricks_vector_search(
    index: MagicMock, columns: Optional[List[str]] = None
) -> DatabricksVectorSearch:
    """Build a DatabricksVectorSearch over *index* with the default test config.

    Args:
        index: Mocked VectorSearchIndex to wrap.
        columns: Optional list of columns to include in search results.
    """
    config: Dict = {
        "embedding": DEFAULT_EMBEDDING_MODEL,
        "text_column": DEFAULT_TEXT_COLUMN,
        "columns": columns,
    }
    return DatabricksVectorSearch(index, **config)
|
||||
|
||||
|
||||
@ -456,6 +506,100 @@ def test_similarity_search(index_details: dict) -> None:
|
||||
assert all([DEFAULT_PRIMARY_KEY in d.metadata for d in search_result])
|
||||
|
||||
|
||||
@pytest.mark.requires("databricks", "databricks.vector_search")
@pytest.mark.parametrize(
    "index_details, columns, expected_columns",
    [
        (DELTA_SYNC_INDEX_SELF_MANAGED_EMBEDDINGS, None, {"id"}),
        (DIRECT_ACCESS_INDEX, None, {"id"}),
        (
            DELTA_SYNC_INDEX_SELF_MANAGED_EMBEDDINGS,
            [DEFAULT_PRIMARY_KEY, DEFAULT_TEXT_COLUMN, DEFAULT_VECTOR_COLUMN],
            {"text_vector", "id"},
        ),
        (
            DIRECT_ACCESS_INDEX,
            [DEFAULT_PRIMARY_KEY, DEFAULT_TEXT_COLUMN, DEFAULT_VECTOR_COLUMN],
            {"text_vector", "id"},
        ),
    ],
)
def test_mmr_search(
    index_details: dict, columns: Optional[List[str]], expected_columns: Set[str]
) -> None:
    """MMR search returns the top document and keeps only expected metadata."""
    index = mock_index(index_details)
    index.similarity_search.return_value = EXAMPLE_SEARCH_RESPONSE_WITH_EMBEDDING
    store = default_databricks_vector_search(index, columns)

    # Query with the first fake text; with k=1 the best match must be itself.
    results = store.max_marginal_relevance_search(
        fake_texts[0], k=1, filters={"some filter": True}
    )

    contents = [doc.page_content for doc in results]
    metadata_keys = [set(doc.metadata.keys()) for doc in results]
    assert contents == [fake_texts[0]]
    assert metadata_keys == [expected_columns]
|
||||
|
||||
|
||||
@pytest.mark.requires("databricks", "databricks.vector_search")
@pytest.mark.parametrize(
    "index_details", [DELTA_SYNC_INDEX_SELF_MANAGED_EMBEDDINGS, DIRECT_ACCESS_INDEX]
)
def test_mmr_parameters(index_details: dict) -> None:
    """MMR retriever kwargs are forwarded to the MMR helper and the index."""
    index = mock_index(index_details)
    index.similarity_search.return_value = EXAMPLE_SEARCH_RESPONSE_WITH_EMBEDDING

    query = fake_texts[0]
    limit, fetch_k, lambda_mult = 1, 3, 0.25
    filters = {"some filter": True}

    # Patch the MMR helper so we can inspect how it was invoked; force it
    # to pick a single candidate (index 2) so the result length is `limit`.
    with patch(
        "langchain_community.vectorstores.databricks_vector_search.maximal_marginal_relevance"
    ) as mock_mmr:
        mock_mmr.return_value = [2]
        retriever = default_databricks_vector_search(index).as_retriever(
            search_type="mmr",
            search_kwargs={
                "k": limit,
                "fetch_k": fetch_k,
                "lambda_mult": lambda_mult,
                "filters": filters,
            },
        )
        search_result = retriever.get_relevant_documents(query)

    mock_mmr.assert_called_once()
    assert mock_mmr.call_args[1]["lambda_mult"] == lambda_mult
    assert index.similarity_search.call_args[1]["num_results"] == fetch_k
    assert index.similarity_search.call_args[1]["filters"] == filters
    assert len(search_result) == limit
|
||||
|
||||
|
||||
@pytest.mark.requires("databricks", "databricks.vector_search")
@pytest.mark.parametrize(
    "index_details, threshold", itertools.product(ALL_INDEXES, [0.4, 0.5, 0.8])
)
def test_similarity_score_threshold(index_details: dict, threshold: float) -> None:
    """Threshold retrieval keeps all rows or none, since scores are uniform."""
    index = mock_index(index_details)
    index.similarity_search.return_value = EXAMPLE_SEARCH_RESPONSE_FIXED_SCORE
    # Every row in the fixed-score fixture shares this score value.
    uniform_score = EXAMPLE_SEARCH_RESPONSE_FIXED_SCORE["result"]["data_array"][0][2]

    retriever = default_databricks_vector_search(index).as_retriever(
        search_type="similarity_score_threshold",
        search_kwargs={"k": len(fake_texts), "score_threshold": threshold},
    )
    search_result = retriever.get_relevant_documents(fake_texts[0])

    expected = len(fake_texts) if uniform_score >= threshold else 0
    assert len(search_result) == expected
|
||||
|
||||
|
||||
@pytest.mark.requires("databricks", "databricks.vector_search")
|
||||
@pytest.mark.parametrize(
|
||||
"index_details", [DELTA_SYNC_INDEX_SELF_MANAGED_EMBEDDINGS, DIRECT_ACCESS_INDEX]
|
||||
|
Loading…
Reference in New Issue
Block a user