mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-08 04:25:46 +00:00
community[minor]: Add mmr and similarity_score_threshold retrieval to DatabricksVectorSearch (#16829)
- **Description:** This PR adds support for `search_type="mmr"` and `search_type="similarity_score_threshold"` to retrievers using `DatabricksVectorSearch`. - **Issue:** - **Dependencies:** - **Twitter handle:** --------- Co-authored-by: Bagatur <baskaryan@gmail.com>
This commit is contained in:
parent
42648061ad
commit
93da18b667
@ -3,12 +3,15 @@ from __future__ import annotations
|
|||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import uuid
|
import uuid
|
||||||
from typing import TYPE_CHECKING, Any, Iterable, List, Optional, Tuple, Type
|
from typing import TYPE_CHECKING, Any, Callable, Iterable, List, Optional, Tuple, Type
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
from langchain_core.documents import Document
|
from langchain_core.documents import Document
|
||||||
from langchain_core.embeddings import Embeddings
|
from langchain_core.embeddings import Embeddings
|
||||||
from langchain_core.vectorstores import VST, VectorStore
|
from langchain_core.vectorstores import VST, VectorStore
|
||||||
|
|
||||||
|
from langchain_community.vectorstores.utils import maximal_marginal_relevance
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from databricks.vector_search.client import VectorSearchIndex
|
from databricks.vector_search.client import VectorSearchIndex
|
||||||
|
|
||||||
@ -321,6 +324,126 @@ class DatabricksVectorSearch(VectorStore):
|
|||||||
)
|
)
|
||||||
return self._parse_search_response(search_resp)
|
return self._parse_search_response(search_resp)
|
||||||
|
|
||||||
|
@staticmethod
def _identity_fn(score: float) -> float:
    """Return ``score`` unchanged.

    Serves as the relevance-score function when the raw score returned by
    the index can be used as-is.
    """
    return score
|
||||||
|
|
||||||
|
def _select_relevance_score_fn(self) -> Callable[[float], float]:
    """Choose the function that converts raw scores into relevance scores.

    Databricks Vector Search already reports a normalized score of
    1/(1+d), where d is the L2 distance, so no rescaling is needed and the
    identity function is returned.
    """
    relevance_fn = self._identity_fn
    return relevance_fn
|
||||||
|
|
||||||
|
def max_marginal_relevance_search(
    self,
    query: str,
    k: int = 4,
    fetch_k: int = 20,
    lambda_mult: float = 0.5,
    filters: Optional[Any] = None,
    **kwargs: Any,
) -> List[Document]:
    """Return docs selected using the maximal marginal relevance.

    Maximal marginal relevance optimizes for similarity to query AND diversity
    among selected documents.

    Args:
        query: Text to look up documents similar to.
        k: Number of Documents to return. Defaults to 4.
        fetch_k: Number of Documents to fetch to pass to MMR algorithm.
        lambda_mult: Number between 0 and 1 that determines the degree
                    of diversity among the results with 0 corresponding
                    to maximum diversity and 1 to minimum diversity.
                    Defaults to 0.5.
        filters: Filters to apply to the query. Defaults to None.
        **kwargs: Extra keyword arguments, forwarded unchanged to
            ``max_marginal_relevance_search_by_vector``.

    Returns:
        List of Documents selected by maximal marginal relevance.

    Raises:
        ValueError: If the index uses Databricks-managed embeddings, or if
            no embedding model is configured on this store.
    """
    if self._is_databricks_managed_embeddings():
        raise ValueError(
            "`max_marginal_relevance_search` is not supported for index with "
            "Databricks-managed embeddings."
        )
    # Explicit check rather than `assert`: assertions are removed under
    # `python -O`, which would turn a missing model into an AttributeError.
    if self.embeddings is None:
        raise ValueError("embedding model is required.")

    query_vector = self.embeddings.embed_query(query)
    return self.max_marginal_relevance_search_by_vector(
        query_vector,
        k,
        fetch_k,
        lambda_mult=lambda_mult,
        filters=filters,
        **kwargs,
    )
|
||||||
|
|
||||||
|
def max_marginal_relevance_search_by_vector(
    self,
    embedding: List[float],
    k: int = 4,
    fetch_k: int = 20,
    lambda_mult: float = 0.5,
    filters: Optional[Any] = None,
    **kwargs: Any,
) -> List[Document]:
    """Return docs selected using the maximal marginal relevance.

    Maximal marginal relevance optimizes for similarity to query AND diversity
    among selected documents.

    Args:
        embedding: Embedding to look up documents similar to.
        k: Number of Documents to return. Defaults to 4.
        fetch_k: Number of Documents to fetch to pass to MMR algorithm.
        lambda_mult: Number between 0 and 1 that determines the degree
                    of diversity among the results with 0 corresponding
                    to maximum diversity and 1 to minimum diversity.
                    Defaults to 0.5.
        filters: Filters to apply to the query. Defaults to None.
        **kwargs: Accepted for interface compatibility; not used here.

    Returns:
        List of Documents selected by maximal marginal relevance, in the
        order in which the search response listed them. Empty if the
        search returned no rows.

    Raises:
        ValueError: If the index uses Databricks-managed embeddings, or if
            the search response does not contain the embedding column.
    """
    if self._is_databricks_managed_embeddings():
        raise ValueError(
            "`max_marginal_relevance_search` is not supported for index with "
            "Databricks-managed embeddings."
        )
    embedding_column = self._embedding_vector_column_name()

    # The embedding column is always requested, because MMR needs the stored
    # vectors of the candidates; duplicates with self.columns are collapsed.
    search_resp = self.index.similarity_search(
        columns=list(set(self.columns + [embedding_column])),
        query_text=None,
        query_vector=embedding,
        filters=filters,
        num_results=fetch_k,
    )

    # Tolerate absent/None sections the same way _parse_search_response does.
    rows = (search_resp.get("result") or {}).get("data_array") or []
    if not rows:
        return []

    # Locate the column by name so that extra keys in the manifest entries
    # (anything besides "name") do not break the lookup.
    column_names = [
        col.get("name")
        for col in ((search_resp.get("manifest") or {}).get("columns") or [])
    ]
    if embedding_column not in column_names:
        raise ValueError(
            f"Embedding column `{embedding_column}` is missing from the "
            "search response."
        )
    embeddings_result_index = column_names.index(embedding_column)
    embeddings = [row[embeddings_result_index] for row in rows]

    mmr_selected = maximal_marginal_relevance(
        np.array(embedding, dtype=np.float32),
        embeddings,
        k=k,
        lambda_mult=lambda_mult,
    )

    # Hide the embedding from metadata unless the caller asked for that column.
    ignore_cols: List[str] = (
        [embedding_column] if embedding_column not in self.columns else []
    )
    candidates = self._parse_search_response(search_resp, ignore_cols=ignore_cols)

    # Candidates are (Document, score) pairs; keep only the MMR picks.
    selected_positions = set(mmr_selected)
    return [
        doc
        for position, (doc, _score) in enumerate(candidates)
        if position in selected_positions
    ]
|
||||||
|
|
||||||
def similarity_search_by_vector(
|
def similarity_search_by_vector(
|
||||||
self,
|
self,
|
||||||
embedding: List[float],
|
embedding: List[float],
|
||||||
@ -373,8 +496,13 @@ class DatabricksVectorSearch(VectorStore):
|
|||||||
)
|
)
|
||||||
return self._parse_search_response(search_resp)
|
return self._parse_search_response(search_resp)
|
||||||
|
|
||||||
def _parse_search_response(self, search_resp: dict) -> List[Tuple[Document, float]]:
|
def _parse_search_response(
|
||||||
|
self, search_resp: dict, ignore_cols: Optional[List[str]] = None
|
||||||
|
) -> List[Tuple[Document, float]]:
|
||||||
"""Parse the search response into a list of Documents with score."""
|
"""Parse the search response into a list of Documents with score."""
|
||||||
|
if ignore_cols is None:
|
||||||
|
ignore_cols = []
|
||||||
|
|
||||||
columns = [
|
columns = [
|
||||||
col["name"]
|
col["name"]
|
||||||
for col in search_resp.get("manifest", dict()).get("columns", [])
|
for col in search_resp.get("manifest", dict()).get("columns", [])
|
||||||
@ -386,7 +514,7 @@ class DatabricksVectorSearch(VectorStore):
|
|||||||
metadata = {
|
metadata = {
|
||||||
col: value
|
col: value
|
||||||
for col, value in zip(columns[:-1], result[:-1])
|
for col, value in zip(columns[:-1], result[:-1])
|
||||||
if col not in [self.primary_key, self.text_column]
|
if col not in ([self.primary_key, self.text_column] + ignore_cols)
|
||||||
}
|
}
|
||||||
metadata[self.primary_key] = doc_id
|
metadata[self.primary_key] = doc_id
|
||||||
score = result[-1]
|
score = result[-1]
|
||||||
|
@ -1,7 +1,8 @@
|
|||||||
|
import itertools
|
||||||
import random
|
import random
|
||||||
import uuid
|
import uuid
|
||||||
from typing import List
|
from typing import Dict, List, Optional, Set
|
||||||
from unittest.mock import MagicMock
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
@ -120,6 +121,52 @@ EXAMPLE_SEARCH_RESPONSE = {
|
|||||||
"next_page_token": "",
|
"next_page_token": "",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
# Search response in which every row carries the same score (0.5), so a
# score threshold either keeps all rows or drops all of them.
_fixed_score_rows = [[str(uuid.uuid4()), text, 0.5] for text in fake_texts]
# Stable sort on a constant key: the rows stay in the order they were built.
_fixed_score_rows.sort(key=lambda row: row[2], reverse=True)  # type: ignore

EXAMPLE_SEARCH_RESPONSE_FIXED_SCORE: Dict = {
    "manifest": {
        "column_count": 3,
        "columns": [
            {"name": DEFAULT_PRIMARY_KEY},
            {"name": DEFAULT_TEXT_COLUMN},
            {"name": "score"},
        ],
    },
    "result": {
        "row_count": len(fake_texts),
        "data_array": _fixed_score_rows,
    },
    "next_page_token": "",
}
|
||||||
|
|
||||||
|
# Search response whose rows also carry the embedding vector, as needed by the
# MMR tests. Row layout: [primary key, text, embedding, score].
EXAMPLE_SEARCH_RESPONSE_WITH_EMBEDDING = {
    "manifest": {
        # Four columns are listed below, so the count must be 4.
        "column_count": 4,
        "columns": [
            {"name": DEFAULT_PRIMARY_KEY},
            {"name": DEFAULT_TEXT_COLUMN},
            {"name": DEFAULT_VECTOR_COLUMN},
            {"name": "score"},
        ],
    },
    "result": {
        "row_count": len(fake_texts),
        "data_array": sorted(
            [
                [str(uuid.uuid4()), s, e, random.uniform(0, 1)]
                for s, e in zip(
                    fake_texts, DEFAULT_EMBEDDING_MODEL.embed_documents(fake_texts)
                )
            ],
            # Order by score, which is the last column (index 3); index 2 is
            # the embedding vector.
            key=lambda row: row[3],  # type: ignore
            reverse=True,
        ),
    },
    "next_page_token": "",
}
|
||||||
|
|
||||||
|
|
||||||
def mock_index(index_details: dict) -> MagicMock:
|
def mock_index(index_details: dict) -> MagicMock:
|
||||||
from databricks.vector_search.client import VectorSearchIndex
|
from databricks.vector_search.client import VectorSearchIndex
|
||||||
@ -129,11 +176,14 @@ def mock_index(index_details: dict) -> MagicMock:
|
|||||||
return index
|
return index
|
||||||
|
|
||||||
|
|
||||||
def default_databricks_vector_search(
    index: MagicMock, columns: Optional[List[str]] = None
) -> DatabricksVectorSearch:
    """Build a ``DatabricksVectorSearch`` over ``index`` with the test defaults.

    Args:
        index: The (mocked) vector search index to wrap.
        columns: Optional column list, passed through as ``columns``.

    Returns:
        A store using ``DEFAULT_EMBEDDING_MODEL`` and ``DEFAULT_TEXT_COLUMN``.
    """
    store_options = {
        "embedding": DEFAULT_EMBEDDING_MODEL,
        "text_column": DEFAULT_TEXT_COLUMN,
        "columns": columns,
    }
    return DatabricksVectorSearch(index, **store_options)
|
||||||
|
|
||||||
|
|
||||||
@ -456,6 +506,100 @@ def test_similarity_search(index_details: dict) -> None:
|
|||||||
assert all([DEFAULT_PRIMARY_KEY in d.metadata for d in search_result])
|
assert all([DEFAULT_PRIMARY_KEY in d.metadata for d in search_result])
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.requires("databricks", "databricks.vector_search")
@pytest.mark.parametrize(
    "index_details, columns, expected_columns",
    [
        (DELTA_SYNC_INDEX_SELF_MANAGED_EMBEDDINGS, None, {"id"}),
        (DIRECT_ACCESS_INDEX, None, {"id"}),
        (
            DELTA_SYNC_INDEX_SELF_MANAGED_EMBEDDINGS,
            [DEFAULT_PRIMARY_KEY, DEFAULT_TEXT_COLUMN, DEFAULT_VECTOR_COLUMN],
            {"text_vector", "id"},
        ),
        (
            DIRECT_ACCESS_INDEX,
            [DEFAULT_PRIMARY_KEY, DEFAULT_TEXT_COLUMN, DEFAULT_VECTOR_COLUMN],
            {"text_vector", "id"},
        ),
    ],
)
def test_mmr_search(
    index_details: dict, columns: Optional[List[str]], expected_columns: Set[str]
) -> None:
    """MMR search with k=1 returns the queried text and the expected metadata keys.

    The embedding column must appear in the metadata only when it was
    explicitly requested through ``columns``.
    """
    mocked_index = mock_index(index_details)
    mocked_index.similarity_search.return_value = (
        EXAMPLE_SEARCH_RESPONSE_WITH_EMBEDDING
    )
    store = default_databricks_vector_search(mocked_index, columns)
    query_text = fake_texts[0]

    found = store.max_marginal_relevance_search(
        query_text, k=1, filters={"some filter": True}
    )

    returned_texts = [doc.page_content for doc in found]
    returned_metadata_keys = [set(doc.metadata) for doc in found]
    assert returned_texts == [fake_texts[0]]
    assert returned_metadata_keys == [expected_columns]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.requires("databricks", "databricks.vector_search")
@pytest.mark.parametrize(
    "index_details", [DELTA_SYNC_INDEX_SELF_MANAGED_EMBEDDINGS, DIRECT_ACCESS_INDEX]
)
def test_mmr_parameters(index_details: dict) -> None:
    """The "mmr" retriever forwards its search kwargs to the index and to MMR."""
    mocked_index = mock_index(index_details)
    mocked_index.similarity_search.return_value = (
        EXAMPLE_SEARCH_RESPONSE_WITH_EMBEDDING
    )
    search_kwargs = {
        "k": 1,
        "fetch_k": 3,
        "lambda_mult": 0.25,
        "filters": {"some filter": True},
    }

    # Replace the MMR routine so the test controls which candidate is selected.
    with patch(
        "langchain_community.vectorstores.databricks_vector_search.maximal_marginal_relevance"
    ) as mock_mmr:
        mock_mmr.return_value = [2]
        store = default_databricks_vector_search(mocked_index)
        retriever = store.as_retriever(
            search_type="mmr", search_kwargs=search_kwargs
        )
        search_result = retriever.get_relevant_documents(fake_texts[0])

    mock_mmr.assert_called_once()
    mmr_call_kwargs = mock_mmr.call_args[1]
    index_call_kwargs = mocked_index.similarity_search.call_args[1]
    assert mmr_call_kwargs["lambda_mult"] == search_kwargs["lambda_mult"]
    assert index_call_kwargs["num_results"] == search_kwargs["fetch_k"]
    assert index_call_kwargs["filters"] == search_kwargs["filters"]
    assert len(search_result) == search_kwargs["k"]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.requires("databricks", "databricks.vector_search")
@pytest.mark.parametrize(
    "index_details, threshold", itertools.product(ALL_INDEXES, [0.4, 0.5, 0.8])
)
def test_similarity_score_threshold(index_details: dict, threshold: float) -> None:
    """Threshold retrieval keeps all rows or none when every row has the same score."""
    mocked_index = mock_index(index_details)
    mocked_index.similarity_search.return_value = EXAMPLE_SEARCH_RESPONSE_FIXED_SCORE
    # Every row of the fixture carries the same score; read it from the first.
    first_row = EXAMPLE_SEARCH_RESPONSE_FIXED_SCORE["result"]["data_array"][0]
    shared_score = first_row[2]

    retriever = default_databricks_vector_search(mocked_index).as_retriever(
        search_type="similarity_score_threshold",
        search_kwargs={"k": len(fake_texts), "score_threshold": threshold},
    )
    search_result = retriever.get_relevant_documents(fake_texts[0])

    expected_count = len(fake_texts) if shared_score >= threshold else 0
    assert len(search_result) == expected_count
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.requires("databricks", "databricks.vector_search")
|
@pytest.mark.requires("databricks", "databricks.vector_search")
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"index_details", [DELTA_SYNC_INDEX_SELF_MANAGED_EMBEDDINGS, DIRECT_ACCESS_INDEX]
|
"index_details", [DELTA_SYNC_INDEX_SELF_MANAGED_EMBEDDINGS, DIRECT_ACCESS_INDEX]
|
||||||
|
Loading…
Reference in New Issue
Block a user