From 2b6a262f84502d5ff165dd5316f3ef8b7bc095f8 Mon Sep 17 00:00:00 2001 From: Yuki Watanabe <31463517+B-Step62@users.noreply.github.com> Date: Fri, 26 Jul 2024 13:20:18 +0900 Subject: [PATCH] community[patch]: Replace `filters` argument to `filter` in DatabricksVectorSearch (#24530) The [DatabricksVectorSearch](https://github.com/langchain-ai/langchain/blob/master/libs/community/langchain_community/vectorstores/databricks_vector_search.py#L21) class exposes similarity search APIs with argument `filters`, which is inconsistent with other VS classes who uses `filter` (singular). This PR updates the argument and add alias for backward compatibility. --------- Signed-off-by: B-Step62 --- .../vectorstores/databricks_vector_search.py | 88 +++++++++++++------ .../test_databricks_vector_search.py | 56 +++++++++++- 2 files changed, 114 insertions(+), 30 deletions(-) diff --git a/libs/community/langchain_community/vectorstores/databricks_vector_search.py b/libs/community/langchain_community/vectorstores/databricks_vector_search.py index dd76b61797a..556f7c7878a 100644 --- a/libs/community/langchain_community/vectorstores/databricks_vector_search.py +++ b/libs/community/langchain_community/vectorstores/databricks_vector_search.py @@ -3,9 +3,20 @@ from __future__ import annotations import json import logging import uuid -from typing import TYPE_CHECKING, Any, Callable, Iterable, List, Optional, Tuple, Type +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Dict, + Iterable, + List, + Optional, + Tuple, + Type, +) import numpy as np +from langchain_core._api import warn_deprecated from langchain_core.documents import Document from langchain_core.embeddings import Embeddings from langchain_core.vectorstores import VST, VectorStore @@ -193,7 +204,7 @@ class DatabricksVectorSearch(VectorStore): cls: Type[VST], texts: List[str], embedding: Embeddings, - metadatas: Optional[List[dict]] = None, + metadatas: Optional[List[Dict]] = None, **kwargs: Any, ) -> VST: raise NotImplementedError( @@ -204,7 +215,7 @@ class DatabricksVectorSearch(VectorStore): def add_texts( self, texts: Iterable[str], - metadatas: Optional[List[dict]] = None, + metadatas: Optional[List[Dict]] = None, ids: Optional[List[Any]] = None, **kwargs: Any, ) -> List[str]: @@ -280,7 +291,7 @@ class DatabricksVectorSearch(VectorStore): self, query: str, k: int = 4, - filters: Optional[Any] = None, + filter: Optional[Dict[str, Any]] = None, *, query_type: Optional[str] = None, **kwargs: Any, @@ -290,14 +301,18 @@ class DatabricksVectorSearch(VectorStore): Args: query: Text to look up documents similar to. k: Number of Documents to return. Defaults to 4. - filters: Filters to apply to the query. Defaults to None. + filter: Filters to apply to the query. Defaults to None. query_type: The type of this query. Supported values are "ANN" and "HYBRID". Returns: List of Documents most similar to the embedding. """ docs_with_score = self.similarity_search_with_score( - query=query, k=k, filters=filters, query_type=query_type, **kwargs + query=query, + k=k, + filter=filter, + query_type=query_type, + **kwargs, ) return [doc for doc, _ in docs_with_score] @@ -305,7 +320,7 @@ class DatabricksVectorSearch(VectorStore): self, query: str, k: int = 4, - filters: Optional[Any] = None, + filter: Optional[Dict[str, Any]] = None, *, query_type: Optional[str] = None, **kwargs: Any, @@ -315,7 +330,7 @@ class DatabricksVectorSearch(VectorStore): Args: query: Text to look up documents similar to. k: Number of Documents to return. Defaults to 4. - filters: Filters to apply to the query. Defaults to None. + filter: Filters to apply to the query. Defaults to None. query_type: The type of this query. Supported values are "ANN" and "HYBRID". Returns: @@ -328,12 +343,11 @@ class DatabricksVectorSearch(VectorStore): assert self.embeddings is not None, "embedding model is required." query_text = None query_vector = self.embeddings.embed_query(query) - search_resp = self.index.similarity_search( columns=self.columns, query_text=query_text, query_vector=query_vector, - filters=filters, + filters=filter or _alias_filters(kwargs), num_results=k, query_type=query_type, ) @@ -357,7 +371,7 @@ class DatabricksVectorSearch(VectorStore): k: int = 4, fetch_k: int = 20, lambda_mult: float = 0.5, - filters: Optional[Any] = None, + filter: Optional[Dict[str, Any]] = None, *, query_type: Optional[str] = None, **kwargs: Any, @@ -375,7 +389,7 @@ class DatabricksVectorSearch(VectorStore): of diversity among the results with 0 corresponding to maximum diversity and 1 to minimum diversity. Defaults to 0.5. - filters: Filters to apply to the query. Defaults to None. + filter: Filters to apply to the query. Defaults to None. query_type: The type of this query. Supported values are "ANN" and "HYBRID". Returns: List of Documents selected by maximal marginal relevance. @@ -394,7 +408,7 @@ class DatabricksVectorSearch(VectorStore): k, fetch_k, lambda_mult=lambda_mult, - filters=filters, + filter=filter or _alias_filters(kwargs), query_type=query_type, ) return docs @@ -405,7 +419,7 @@ class DatabricksVectorSearch(VectorStore): k: int = 4, fetch_k: int = 20, lambda_mult: float = 0.5, - filters: Optional[Any] = None, + filter: Optional[Any] = None, *, query_type: Optional[str] = None, **kwargs: Any, @@ -423,7 +437,7 @@ class DatabricksVectorSearch(VectorStore): of diversity among the results with 0 corresponding to maximum diversity and 1 to minimum diversity. Defaults to 0.5. - filters: Filters to apply to the query. Defaults to None. + filter: Filters to apply to the query. Defaults to None. query_type: The type of this query. Supported values are "ANN" and "HYBRID". Returns: List of Documents selected by maximal marginal relevance. @@ -435,12 +449,11 @@ class DatabricksVectorSearch(VectorStore): "`max_marginal_relevance_search` is not supported for index with " "Databricks-managed embeddings." ) - search_resp = self.index.similarity_search( columns=list(set(self.columns + [embedding_column])), query_text=None, query_vector=embedding, - filters=filters, + filters=filter or _alias_filters(kwargs), num_results=fetch_k, query_type=query_type, ) @@ -471,7 +484,7 @@ class DatabricksVectorSearch(VectorStore): self, embedding: List[float], k: int = 4, - filters: Optional[Any] = None, + filter: Optional[Any] = None, *, query_type: Optional[str] = None, **kwargs: Any, @@ -481,14 +494,18 @@ class DatabricksVectorSearch(VectorStore): Args: embedding: Embedding to look up documents similar to. k: Number of Documents to return. Defaults to 4. - filters: Filters to apply to the query. Defaults to None. + filter: Filters to apply to the query. Defaults to None. query_type: The type of this query. Supported values are "ANN" and "HYBRID". Returns: List of Documents most similar to the embedding. """ docs_with_score = self.similarity_search_by_vector_with_score( - embedding=embedding, k=k, filters=filters, query_type=query_type, **kwargs + embedding=embedding, + k=k, + filter=filter, + query_type=query_type, + **kwargs, ) return [doc for doc, _ in docs_with_score] @@ -496,7 +513,7 @@ class DatabricksVectorSearch(VectorStore): self, embedding: List[float], k: int = 4, - filters: Optional[Any] = None, + filter: Optional[Any] = None, *, query_type: Optional[str] = None, **kwargs: Any, @@ -506,7 +523,7 @@ class DatabricksVectorSearch(VectorStore): Args: embedding: Embedding to look up documents similar to. k: Number of Documents to return. Defaults to 4. - filters: Filters to apply to the query. Defaults to None. + filter: Filters to apply to the query. Defaults to None. query_type: The type of this query. Supported values are "ANN" and "HYBRID". Returns: @@ -520,14 +537,14 @@ class DatabricksVectorSearch(VectorStore): search_resp = self.index.similarity_search( columns=self.columns, query_vector=embedding, - filters=filters, + filters=filter or _alias_filters(kwargs), num_results=k, query_type=query_type, ) return self._parse_search_response(search_resp) def _parse_search_response( - self, search_resp: dict, ignore_cols: Optional[List[str]] = None + self, search_resp: Dict, ignore_cols: Optional[List[str]] = None ) -> List[Tuple[Document, float]]: """Parse the search response into a list of Documents with score.""" if ignore_cols is None: @@ -552,7 +569,7 @@ class DatabricksVectorSearch(VectorStore): docs_with_score.append((doc, score)) return docs_with_score - def _index_schema(self) -> Optional[dict]: + def _index_schema(self) -> Optional[Dict]: """Return the index schema as a dictionary. Return None if no schema found. """ @@ -574,7 +591,7 @@ class DatabricksVectorSearch(VectorStore): """ return self._embedding_vector_column().get("embedding_dimension") - def _embedding_vector_column(self) -> dict: + def _embedding_vector_column(self) -> Dict: """Return the embedding vector column configs as a dictionary. Empty if the index is not a self-managed embedding index. """ @@ -591,7 +608,7 @@ class DatabricksVectorSearch(VectorStore): """ return self._embedding_source_column().get("name") - def _embedding_source_column(self) -> dict: + def _embedding_source_column(self) -> Dict: """Return the embedding source column configs as a dictionary. Empty if the index is not a Databricks-managed embedding index. """ @@ -629,3 +646,20 @@ class DatabricksVectorSearch(VectorStore): """Raise ValueError if the required arg with name `arg_name` is None.""" if not arg: raise ValueError(f"`{arg_name}` is required for this index.") + + +def _alias_filters(kwargs: Dict[str, Any]) -> Optional[Dict[str, Any]]: + """ + The `filters` argument was used in the previous versions. It is now + replaced with `filter` for consistency with other vector stores, but + we still support `filters` for backward compatibility. + """ + if "filters" in kwargs: + warn_deprecated( + since="0.2.11", + removal="0.3", + message="DatabricksVectorSearch received a key `filters` in search_kwargs. " + "`filters` was deprecated since langchain-community 0.2.11 and will " + "be removed in 0.3. Please use `filter` instead.", + ) + return kwargs.pop("filters", None) diff --git a/libs/community/tests/unit_tests/vectorstores/test_databricks_vector_search.py b/libs/community/tests/unit_tests/vectorstores/test_databricks_vector_search.py index 703a6e84a82..75c040d126b 100644 --- a/libs/community/tests/unit_tests/vectorstores/test_databricks_vector_search.py +++ b/libs/community/tests/unit_tests/vectorstores/test_databricks_vector_search.py @@ -493,7 +493,7 @@ def test_similarity_search(index_details: dict, query_type: Optional[str]) -> No limit = 7 search_result = vectorsearch.similarity_search( - query, k=limit, filters=filters, query_type=query_type + query, k=limit, filter=filters, query_type=query_type ) if index_details == DELTA_SYNC_INDEX_MANAGED_EMBEDDINGS: index.similarity_search.assert_called_once_with( @@ -518,6 +518,27 @@ def test_similarity_search(index_details: dict, query_type: Optional[str]) -> No assert all([DEFAULT_PRIMARY_KEY in d.metadata for d in search_result]) +@pytest.mark.requires("databricks", "databricks.vector_search") +def test_similarity_search_both_filter_and_filters_passed() -> None: + index = mock_index(DIRECT_ACCESS_INDEX) + index.similarity_search.return_value = EXAMPLE_SEARCH_RESPONSE + vectorsearch = default_databricks_vector_search(index) + query = "foo" + filter = {"some filter": True} + filters = {"some other filter": False} + + vectorsearch.similarity_search(query, filter=filter, filters=filters) + index.similarity_search.assert_called_once_with( + columns=[DEFAULT_PRIMARY_KEY, DEFAULT_TEXT_COLUMN], + query_vector=DEFAULT_EMBEDDING_MODEL.embed_query(query), + # `filter` should prevail over `filters` + filters=filter, + num_results=4, + query_text=None, + query_type=None, + ) + + @pytest.mark.requires("databricks", "databricks.vector_search") @pytest.mark.parametrize( "index_details, columns, expected_columns", @@ -576,7 +597,7 @@ def test_mmr_parameters(index_details: dict) -> None: "k": limit, "fetch_k": fetch_k, "lambda_mult": lambda_mult, - "filters": filters, + "filter": filters, }, ) search_result = retriever.invoke(query) @@ -625,7 +646,7 @@ def test_similarity_search_by_vector(index_details: dict) -> None: limit = 7 search_result = vectorsearch.similarity_search_by_vector( - query_embedding, k=limit, filters=filters + query_embedding, k=limit, filter=filters ) index.similarity_search.assert_called_once_with( columns=[DEFAULT_PRIMARY_KEY, DEFAULT_TEXT_COLUMN], @@ -681,3 +702,32 @@ def test_similarity_search_by_vector_not_supported_for_managed_embedding() -> No "`similarity_search_by_vector` is not supported for index with " "Databricks-managed embeddings." in str(ex.value) ) + + +@pytest.mark.requires("databricks", "databricks.vector_search") +@pytest.mark.parametrize( + "method", + [ + "similarity_search", + "similarity_search_with_score", + "similarity_search_by_vector", + "similarity_search_by_vector_with_score", + "max_marginal_relevance_search", + "max_marginal_relevance_search_by_vector", + ], +) +def test_filter_arg_alias(method: str) -> None: + index = mock_index(DIRECT_ACCESS_INDEX) + vectorsearch = default_databricks_vector_search(index) + query = "foo" + query_embedding = DEFAULT_EMBEDDING_MODEL.embed_query("foo") + filters = {"some filter": True} + limit = 7 + + if "by_vector" in method: + getattr(vectorsearch, method)(query_embedding, k=limit, filters=filters) + else: + getattr(vectorsearch, method)(query, k=limit, filters=filters) + + index_call_args = index.similarity_search.call_args[1] + assert index_call_args["filters"] == filters