From 93da18b667091aaa242b7e8b2c2003cdcda535e2 Mon Sep 17 00:00:00 2001
From: david-tempelmann <89131897+david-tempelmann@users.noreply.github.com>
Date: Mon, 12 Feb 2024 21:51:37 +0100
Subject: [PATCH] community[minor]: Add mmr and similarity_score_threshold retrieval to DatabricksVectorSearch (#16829)

- **Description:** This PR adds support for `search_type="mmr"` and
  `search_type="similarity_score_threshold"` to retrievers using
  `DatabricksVectorSearch`.
- **Issue:**
- **Dependencies:**
- **Twitter handle:**

---------

Co-authored-by: Bagatur
---
 .../vectorstores/databricks_vector_search.py  | 134 +++++++++++++++-
 .../test_databricks_vector_search.py          | 150 +++++++++++++++++-
 2 files changed, 278 insertions(+), 6 deletions(-)

diff --git a/libs/community/langchain_community/vectorstores/databricks_vector_search.py b/libs/community/langchain_community/vectorstores/databricks_vector_search.py
index 8505ea9cbc0..0a5e7a6d514 100644
--- a/libs/community/langchain_community/vectorstores/databricks_vector_search.py
+++ b/libs/community/langchain_community/vectorstores/databricks_vector_search.py
@@ -3,12 +3,15 @@ from __future__ import annotations
 import json
 import logging
 import uuid
-from typing import TYPE_CHECKING, Any, Iterable, List, Optional, Tuple, Type
+from typing import TYPE_CHECKING, Any, Callable, Iterable, List, Optional, Tuple, Type
 
+import numpy as np
 from langchain_core.documents import Document
 from langchain_core.embeddings import Embeddings
 from langchain_core.vectorstores import VST, VectorStore
 
+from langchain_community.vectorstores.utils import maximal_marginal_relevance
+
 if TYPE_CHECKING:
     from databricks.vector_search.client import VectorSearchIndex
 
@@ -321,6 +324,126 @@ class DatabricksVectorSearch(VectorStore):
         )
         return self._parse_search_response(search_resp)
 
+    @staticmethod
+    def _identity_fn(score: float) -> float:
+        return score
+
+    def _select_relevance_score_fn(self) -> Callable[[float], float]:
+        """
+        Databricks Vector search uses a normalized score 1/(1+d) where d
+        is the L2 distance. Hence, we simply return the identity function.
+        """
+
+        return self._identity_fn
+
+    def max_marginal_relevance_search(
+        self,
+        query: str,
+        k: int = 4,
+        fetch_k: int = 20,
+        lambda_mult: float = 0.5,
+        filters: Optional[Any] = None,
+        **kwargs: Any,
+    ) -> List[Document]:
+        """Return docs selected using the maximal marginal relevance.
+
+        Maximal marginal relevance optimizes for similarity to query AND diversity
+        among selected documents.
+
+        Args:
+            query: Text to look up documents similar to.
+            k: Number of Documents to return. Defaults to 4.
+            fetch_k: Number of Documents to fetch to pass to MMR algorithm.
+            lambda_mult: Number between 0 and 1 that determines the degree
+                        of diversity among the results with 0 corresponding
+                        to maximum diversity and 1 to minimum diversity.
+                        Defaults to 0.5.
+            filters: Filters to apply to the query. Defaults to None.
+        Returns:
+            List of Documents selected by maximal marginal relevance.
+        """
+        if not self._is_databricks_managed_embeddings():
+            assert self.embeddings is not None, "embedding model is required."
+            query_vector = self.embeddings.embed_query(query)
+        else:
+            raise ValueError(
+                "`max_marginal_relevance_search` is not supported for index with "
+                "Databricks-managed embeddings."
+            )
+
+        docs = self.max_marginal_relevance_search_by_vector(
+            query_vector,
+            k,
+            fetch_k,
+            lambda_mult=lambda_mult,
+            filters=filters,
+        )
+        return docs
+
+    def max_marginal_relevance_search_by_vector(
+        self,
+        embedding: List[float],
+        k: int = 4,
+        fetch_k: int = 20,
+        lambda_mult: float = 0.5,
+        filters: Optional[Any] = None,
+        **kwargs: Any,
+    ) -> List[Document]:
+        """Return docs selected using the maximal marginal relevance.
+
+        Maximal marginal relevance optimizes for similarity to query AND diversity
+        among selected documents.
+
+        Args:
+            embedding: Embedding to look up documents similar to.
+            k: Number of Documents to return. Defaults to 4.
+            fetch_k: Number of Documents to fetch to pass to MMR algorithm.
+            lambda_mult: Number between 0 and 1 that determines the degree
+                        of diversity among the results with 0 corresponding
+                        to maximum diversity and 1 to minimum diversity.
+                        Defaults to 0.5.
+            filters: Filters to apply to the query. Defaults to None.
+        Returns:
+            List of Documents selected by maximal marginal relevance.
+        """
+        if not self._is_databricks_managed_embeddings():
+            embedding_column = self._embedding_vector_column_name()
+        else:
+            raise ValueError(
+                "`max_marginal_relevance_search` is not supported for index with "
+                "Databricks-managed embeddings."
+            )
+
+        search_resp = self.index.similarity_search(
+            columns=list(set(self.columns + [embedding_column])),
+            query_text=None,
+            query_vector=embedding,
+            filters=filters,
+            num_results=fetch_k,
+        )
+
+        embeddings_result_index = (
+            search_resp.get("manifest").get("columns").index({"name": embedding_column})
+        )
+        embeddings = [
+            doc[embeddings_result_index]
+            for doc in search_resp.get("result").get("data_array")
+        ]
+
+        mmr_selected = maximal_marginal_relevance(
+            np.array(embedding, dtype=np.float32),
+            embeddings,
+            k=k,
+            lambda_mult=lambda_mult,
+        )
+
+        ignore_cols: List = (
+            [embedding_column] if embedding_column not in self.columns else []
+        )
+        candidates = self._parse_search_response(search_resp, ignore_cols=ignore_cols)
+        selected_results = [candidates[i][0] for i in mmr_selected]  # MMR order
+        return selected_results
+
     def similarity_search_by_vector(
         self,
         embedding: List[float],
@@ -373,8 +496,13 @@ class DatabricksVectorSearch(VectorStore):
         )
         return self._parse_search_response(search_resp)
 
-    def _parse_search_response(self, search_resp: dict) -> List[Tuple[Document, float]]:
+    def _parse_search_response(
+        self, search_resp: dict, ignore_cols: Optional[List[str]] = None
+    ) -> List[Tuple[Document, float]]:
         """Parse the search response into a list of Documents with score."""
+        if ignore_cols is None:
+            ignore_cols = []
+
         columns = [
             col["name"]
             for col in search_resp.get("manifest", dict()).get("columns", [])
@@ -386,7 +514,7 @@ class DatabricksVectorSearch(VectorStore):
             metadata = {
                 col: value
                 for col, value in zip(columns[:-1], result[:-1])
-                if col not in [self.primary_key, self.text_column]
+                if col not in ([self.primary_key, self.text_column] + ignore_cols)
             }
             metadata[self.primary_key] = doc_id
             score = result[-1]
diff --git a/libs/community/tests/unit_tests/vectorstores/test_databricks_vector_search.py b/libs/community/tests/unit_tests/vectorstores/test_databricks_vector_search.py
index 1a03890f637..4bdcee9acfb 100644
--- a/libs/community/tests/unit_tests/vectorstores/test_databricks_vector_search.py
+++ b/libs/community/tests/unit_tests/vectorstores/test_databricks_vector_search.py
@@ -1,7 +1,8 @@
+import itertools
 import random
 import uuid
-from typing import List
-from unittest.mock import MagicMock
+from typing import Dict, List, Optional, Set
+from unittest.mock import MagicMock, patch
 
 import pytest
 
@@ -120,6 +121,52 @@ EXAMPLE_SEARCH_RESPONSE = {
     "next_page_token": "",
 }
 
+EXAMPLE_SEARCH_RESPONSE_FIXED_SCORE: Dict = {
+    "manifest": {
+        "column_count": 3,
+        "columns": [
+            {"name": DEFAULT_PRIMARY_KEY},
+            {"name": DEFAULT_TEXT_COLUMN},
+            {"name": "score"},
+        ],
+    },
+    "result": {
+        "row_count": len(fake_texts),
+        "data_array": sorted(
+            [[str(uuid.uuid4()), s, 0.5] for s in fake_texts],
+            key=lambda x: x[2],  # type: ignore
+            reverse=True,
+        ),
+    },
+    "next_page_token": "",
+}
+
+EXAMPLE_SEARCH_RESPONSE_WITH_EMBEDDING = {
+    "manifest": {
+        "column_count": 4,
+        "columns": [
+            {"name": DEFAULT_PRIMARY_KEY},
+            {"name": DEFAULT_TEXT_COLUMN},
+            {"name": DEFAULT_VECTOR_COLUMN},
+            {"name": "score"},
+        ],
+    },
+    "result": {
+        "row_count": len(fake_texts),
+        "data_array": sorted(
+            [
+                [str(uuid.uuid4()), s, e, random.uniform(0, 1)]
+                for s, e in zip(
+                    fake_texts, DEFAULT_EMBEDDING_MODEL.embed_documents(fake_texts)
+                )
+            ],
+            key=lambda x: x[3],  # type: ignore
+            reverse=True,
+        ),
+    },
+    "next_page_token": "",
+}
+
 
 def mock_index(index_details: dict) -> MagicMock:
     from databricks.vector_search.client import VectorSearchIndex
@@ -129,11 +176,14 @@ def mock_index(index_details: dict) -> MagicMock:
     return index
 
 
-def default_databricks_vector_search(index: MagicMock) -> DatabricksVectorSearch:
+def default_databricks_vector_search(
+    index: MagicMock, columns: Optional[List[str]] = None
+) -> DatabricksVectorSearch:
     return DatabricksVectorSearch(
         index,
         embedding=DEFAULT_EMBEDDING_MODEL,
         text_column=DEFAULT_TEXT_COLUMN,
+        columns=columns,
     )
 
 
@@ -456,6 +506,100 @@ def test_similarity_search(index_details: dict) -> None:
     assert all([DEFAULT_PRIMARY_KEY in d.metadata for d in search_result])
 
 
+@pytest.mark.requires("databricks", "databricks.vector_search")
+@pytest.mark.parametrize(
+    "index_details, columns, expected_columns",
+    [
+        (DELTA_SYNC_INDEX_SELF_MANAGED_EMBEDDINGS, None, {"id"}),
+        (DIRECT_ACCESS_INDEX, None, {"id"}),
+        (
+            DELTA_SYNC_INDEX_SELF_MANAGED_EMBEDDINGS,
+            [DEFAULT_PRIMARY_KEY, DEFAULT_TEXT_COLUMN, DEFAULT_VECTOR_COLUMN],
+            {"text_vector", "id"},
+        ),
+        (
+            DIRECT_ACCESS_INDEX,
+            [DEFAULT_PRIMARY_KEY, DEFAULT_TEXT_COLUMN, DEFAULT_VECTOR_COLUMN],
+            {"text_vector", "id"},
+        ),
+    ],
+)
+def test_mmr_search(
+    index_details: dict, columns: Optional[List[str]], expected_columns: Set[str]
+) -> None:
+    index = mock_index(index_details)
+    index.similarity_search.return_value = EXAMPLE_SEARCH_RESPONSE_WITH_EMBEDDING
+    vectorsearch = default_databricks_vector_search(index, columns)
+    query = fake_texts[0]
+    filters = {"some filter": True}
+    limit = 1
+
+    search_result = vectorsearch.max_marginal_relevance_search(
+        query, k=limit, filters=filters
+    )
+    assert [doc.page_content for doc in search_result] == [fake_texts[0]]
+    assert [set(doc.metadata.keys()) for doc in search_result] == [expected_columns]
+
+
+@pytest.mark.requires("databricks", "databricks.vector_search")
+@pytest.mark.parametrize(
+    "index_details", [DELTA_SYNC_INDEX_SELF_MANAGED_EMBEDDINGS, DIRECT_ACCESS_INDEX]
+)
+def test_mmr_parameters(index_details: dict) -> None:
+    index = mock_index(index_details)
+    index.similarity_search.return_value = EXAMPLE_SEARCH_RESPONSE_WITH_EMBEDDING
+    query = fake_texts[0]
+    limit = 1
+    fetch_k = 3
+    lambda_mult = 0.25
+    filters = {"some filter": True}
+
+    with patch(
+        "langchain_community.vectorstores.databricks_vector_search.maximal_marginal_relevance"
+    ) as mock_mmr:
+        mock_mmr.return_value = [2]
+        retriever = default_databricks_vector_search(index).as_retriever(
+            search_type="mmr",
+            search_kwargs={
+                "k": limit,
+                "fetch_k": fetch_k,
+                "lambda_mult": lambda_mult,
+                "filters": filters,
+            },
+        )
+        search_result = retriever.get_relevant_documents(query)
+
+    mock_mmr.assert_called_once()
+    assert mock_mmr.call_args[1]["lambda_mult"] == lambda_mult
+    assert index.similarity_search.call_args[1]["num_results"] == fetch_k
+    assert index.similarity_search.call_args[1]["filters"] == filters
+    assert len(search_result) == limit
+
+
+@pytest.mark.requires("databricks", "databricks.vector_search")
+@pytest.mark.parametrize(
+    "index_details, threshold", itertools.product(ALL_INDEXES, [0.4, 0.5, 0.8])
+)
+def test_similarity_score_threshold(index_details: dict, threshold: float) -> None:
+    index = mock_index(index_details)
+    index.similarity_search.return_value = EXAMPLE_SEARCH_RESPONSE_FIXED_SCORE
+    uniform_response_score = EXAMPLE_SEARCH_RESPONSE_FIXED_SCORE["result"][
+        "data_array"
+    ][0][2]
+    query = fake_texts[0]
+    limit = len(fake_texts)
+
+    retriever = default_databricks_vector_search(index).as_retriever(
+        search_type="similarity_score_threshold",
+        search_kwargs={"k": limit, "score_threshold": threshold},
+    )
+    search_result = retriever.get_relevant_documents(query)
+    if uniform_response_score >= threshold:
+        assert len(search_result) == len(fake_texts)
+    else:
+        assert len(search_result) == 0
+
+
 @pytest.mark.requires("databricks", "databricks.vector_search")
 @pytest.mark.parametrize(
     "index_details", [DELTA_SYNC_INDEX_SELF_MANAGED_EMBEDDINGS, DIRECT_ACCESS_INDEX]