community[patch]: Replace filters argument to filter in DatabricksVectorSearch (#24530)

The
[DatabricksVectorSearch](https://github.com/langchain-ai/langchain/blob/master/libs/community/langchain_community/vectorstores/databricks_vector_search.py#L21)
class exposes similarity search APIs with argument `filters`, which is
inconsistent with other VS classes who uses `filter` (singular). This PR
updates the argument and add alias for backward compatibility.

---------

Signed-off-by: B-Step62 <yuki.watanabe@databricks.com>
This commit is contained in:
Yuki Watanabe 2024-07-26 13:20:18 +09:00 committed by GitHub
parent 148766ddc1
commit 2b6a262f84
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 114 additions and 30 deletions

View File

@ -3,9 +3,20 @@ from __future__ import annotations
import json import json
import logging import logging
import uuid import uuid
from typing import TYPE_CHECKING, Any, Callable, Iterable, List, Optional, Tuple, Type from typing import (
TYPE_CHECKING,
Any,
Callable,
Dict,
Iterable,
List,
Optional,
Tuple,
Type,
)
import numpy as np import numpy as np
from langchain_core._api import warn_deprecated
from langchain_core.documents import Document from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings from langchain_core.embeddings import Embeddings
from langchain_core.vectorstores import VST, VectorStore from langchain_core.vectorstores import VST, VectorStore
@ -193,7 +204,7 @@ class DatabricksVectorSearch(VectorStore):
cls: Type[VST], cls: Type[VST],
texts: List[str], texts: List[str],
embedding: Embeddings, embedding: Embeddings,
metadatas: Optional[List[dict]] = None, metadatas: Optional[List[Dict]] = None,
**kwargs: Any, **kwargs: Any,
) -> VST: ) -> VST:
raise NotImplementedError( raise NotImplementedError(
@ -204,7 +215,7 @@ class DatabricksVectorSearch(VectorStore):
def add_texts( def add_texts(
self, self,
texts: Iterable[str], texts: Iterable[str],
metadatas: Optional[List[dict]] = None, metadatas: Optional[List[Dict]] = None,
ids: Optional[List[Any]] = None, ids: Optional[List[Any]] = None,
**kwargs: Any, **kwargs: Any,
) -> List[str]: ) -> List[str]:
@ -280,7 +291,7 @@ class DatabricksVectorSearch(VectorStore):
self, self,
query: str, query: str,
k: int = 4, k: int = 4,
filters: Optional[Any] = None, filter: Optional[Dict[str, Any]] = None,
*, *,
query_type: Optional[str] = None, query_type: Optional[str] = None,
**kwargs: Any, **kwargs: Any,
@ -290,14 +301,18 @@ class DatabricksVectorSearch(VectorStore):
Args: Args:
query: Text to look up documents similar to. query: Text to look up documents similar to.
k: Number of Documents to return. Defaults to 4. k: Number of Documents to return. Defaults to 4.
filters: Filters to apply to the query. Defaults to None. filter: Filters to apply to the query. Defaults to None.
query_type: The type of this query. Supported values are "ANN" and "HYBRID". query_type: The type of this query. Supported values are "ANN" and "HYBRID".
Returns: Returns:
List of Documents most similar to the embedding. List of Documents most similar to the embedding.
""" """
docs_with_score = self.similarity_search_with_score( docs_with_score = self.similarity_search_with_score(
query=query, k=k, filters=filters, query_type=query_type, **kwargs query=query,
k=k,
filter=filter,
query_type=query_type,
**kwargs,
) )
return [doc for doc, _ in docs_with_score] return [doc for doc, _ in docs_with_score]
@ -305,7 +320,7 @@ class DatabricksVectorSearch(VectorStore):
self, self,
query: str, query: str,
k: int = 4, k: int = 4,
filters: Optional[Any] = None, filter: Optional[Dict[str, Any]] = None,
*, *,
query_type: Optional[str] = None, query_type: Optional[str] = None,
**kwargs: Any, **kwargs: Any,
@ -315,7 +330,7 @@ class DatabricksVectorSearch(VectorStore):
Args: Args:
query: Text to look up documents similar to. query: Text to look up documents similar to.
k: Number of Documents to return. Defaults to 4. k: Number of Documents to return. Defaults to 4.
filters: Filters to apply to the query. Defaults to None. filter: Filters to apply to the query. Defaults to None.
query_type: The type of this query. Supported values are "ANN" and "HYBRID". query_type: The type of this query. Supported values are "ANN" and "HYBRID".
Returns: Returns:
@ -328,12 +343,11 @@ class DatabricksVectorSearch(VectorStore):
assert self.embeddings is not None, "embedding model is required." assert self.embeddings is not None, "embedding model is required."
query_text = None query_text = None
query_vector = self.embeddings.embed_query(query) query_vector = self.embeddings.embed_query(query)
search_resp = self.index.similarity_search( search_resp = self.index.similarity_search(
columns=self.columns, columns=self.columns,
query_text=query_text, query_text=query_text,
query_vector=query_vector, query_vector=query_vector,
filters=filters, filters=filter or _alias_filters(kwargs),
num_results=k, num_results=k,
query_type=query_type, query_type=query_type,
) )
@ -357,7 +371,7 @@ class DatabricksVectorSearch(VectorStore):
k: int = 4, k: int = 4,
fetch_k: int = 20, fetch_k: int = 20,
lambda_mult: float = 0.5, lambda_mult: float = 0.5,
filters: Optional[Any] = None, filter: Optional[Dict[str, Any]] = None,
*, *,
query_type: Optional[str] = None, query_type: Optional[str] = None,
**kwargs: Any, **kwargs: Any,
@ -375,7 +389,7 @@ class DatabricksVectorSearch(VectorStore):
of diversity among the results with 0 corresponding of diversity among the results with 0 corresponding
to maximum diversity and 1 to minimum diversity. to maximum diversity and 1 to minimum diversity.
Defaults to 0.5. Defaults to 0.5.
filters: Filters to apply to the query. Defaults to None. filter: Filters to apply to the query. Defaults to None.
query_type: The type of this query. Supported values are "ANN" and "HYBRID". query_type: The type of this query. Supported values are "ANN" and "HYBRID".
Returns: Returns:
List of Documents selected by maximal marginal relevance. List of Documents selected by maximal marginal relevance.
@ -394,7 +408,7 @@ class DatabricksVectorSearch(VectorStore):
k, k,
fetch_k, fetch_k,
lambda_mult=lambda_mult, lambda_mult=lambda_mult,
filters=filters, filter=filter or _alias_filters(kwargs),
query_type=query_type, query_type=query_type,
) )
return docs return docs
@ -405,7 +419,7 @@ class DatabricksVectorSearch(VectorStore):
k: int = 4, k: int = 4,
fetch_k: int = 20, fetch_k: int = 20,
lambda_mult: float = 0.5, lambda_mult: float = 0.5,
filters: Optional[Any] = None, filter: Optional[Any] = None,
*, *,
query_type: Optional[str] = None, query_type: Optional[str] = None,
**kwargs: Any, **kwargs: Any,
@ -423,7 +437,7 @@ class DatabricksVectorSearch(VectorStore):
of diversity among the results with 0 corresponding of diversity among the results with 0 corresponding
to maximum diversity and 1 to minimum diversity. to maximum diversity and 1 to minimum diversity.
Defaults to 0.5. Defaults to 0.5.
filters: Filters to apply to the query. Defaults to None. filter: Filters to apply to the query. Defaults to None.
query_type: The type of this query. Supported values are "ANN" and "HYBRID". query_type: The type of this query. Supported values are "ANN" and "HYBRID".
Returns: Returns:
List of Documents selected by maximal marginal relevance. List of Documents selected by maximal marginal relevance.
@ -435,12 +449,11 @@ class DatabricksVectorSearch(VectorStore):
"`max_marginal_relevance_search` is not supported for index with " "`max_marginal_relevance_search` is not supported for index with "
"Databricks-managed embeddings." "Databricks-managed embeddings."
) )
search_resp = self.index.similarity_search( search_resp = self.index.similarity_search(
columns=list(set(self.columns + [embedding_column])), columns=list(set(self.columns + [embedding_column])),
query_text=None, query_text=None,
query_vector=embedding, query_vector=embedding,
filters=filters, filters=filter or _alias_filters(kwargs),
num_results=fetch_k, num_results=fetch_k,
query_type=query_type, query_type=query_type,
) )
@ -471,7 +484,7 @@ class DatabricksVectorSearch(VectorStore):
self, self,
embedding: List[float], embedding: List[float],
k: int = 4, k: int = 4,
filters: Optional[Any] = None, filter: Optional[Any] = None,
*, *,
query_type: Optional[str] = None, query_type: Optional[str] = None,
**kwargs: Any, **kwargs: Any,
@ -481,14 +494,18 @@ class DatabricksVectorSearch(VectorStore):
Args: Args:
embedding: Embedding to look up documents similar to. embedding: Embedding to look up documents similar to.
k: Number of Documents to return. Defaults to 4. k: Number of Documents to return. Defaults to 4.
filters: Filters to apply to the query. Defaults to None. filter: Filters to apply to the query. Defaults to None.
query_type: The type of this query. Supported values are "ANN" and "HYBRID". query_type: The type of this query. Supported values are "ANN" and "HYBRID".
Returns: Returns:
List of Documents most similar to the embedding. List of Documents most similar to the embedding.
""" """
docs_with_score = self.similarity_search_by_vector_with_score( docs_with_score = self.similarity_search_by_vector_with_score(
embedding=embedding, k=k, filters=filters, query_type=query_type, **kwargs embedding=embedding,
k=k,
filter=filter,
query_type=query_type,
**kwargs,
) )
return [doc for doc, _ in docs_with_score] return [doc for doc, _ in docs_with_score]
@ -496,7 +513,7 @@ class DatabricksVectorSearch(VectorStore):
self, self,
embedding: List[float], embedding: List[float],
k: int = 4, k: int = 4,
filters: Optional[Any] = None, filter: Optional[Any] = None,
*, *,
query_type: Optional[str] = None, query_type: Optional[str] = None,
**kwargs: Any, **kwargs: Any,
@ -506,7 +523,7 @@ class DatabricksVectorSearch(VectorStore):
Args: Args:
embedding: Embedding to look up documents similar to. embedding: Embedding to look up documents similar to.
k: Number of Documents to return. Defaults to 4. k: Number of Documents to return. Defaults to 4.
filters: Filters to apply to the query. Defaults to None. filter: Filters to apply to the query. Defaults to None.
query_type: The type of this query. Supported values are "ANN" and "HYBRID". query_type: The type of this query. Supported values are "ANN" and "HYBRID".
Returns: Returns:
@ -520,14 +537,14 @@ class DatabricksVectorSearch(VectorStore):
search_resp = self.index.similarity_search( search_resp = self.index.similarity_search(
columns=self.columns, columns=self.columns,
query_vector=embedding, query_vector=embedding,
filters=filters, filters=filter or _alias_filters(kwargs),
num_results=k, num_results=k,
query_type=query_type, query_type=query_type,
) )
return self._parse_search_response(search_resp) return self._parse_search_response(search_resp)
def _parse_search_response( def _parse_search_response(
self, search_resp: dict, ignore_cols: Optional[List[str]] = None self, search_resp: Dict, ignore_cols: Optional[List[str]] = None
) -> List[Tuple[Document, float]]: ) -> List[Tuple[Document, float]]:
"""Parse the search response into a list of Documents with score.""" """Parse the search response into a list of Documents with score."""
if ignore_cols is None: if ignore_cols is None:
@ -552,7 +569,7 @@ class DatabricksVectorSearch(VectorStore):
docs_with_score.append((doc, score)) docs_with_score.append((doc, score))
return docs_with_score return docs_with_score
def _index_schema(self) -> Optional[dict]: def _index_schema(self) -> Optional[Dict]:
"""Return the index schema as a dictionary. """Return the index schema as a dictionary.
Return None if no schema found. Return None if no schema found.
""" """
@ -574,7 +591,7 @@ class DatabricksVectorSearch(VectorStore):
""" """
return self._embedding_vector_column().get("embedding_dimension") return self._embedding_vector_column().get("embedding_dimension")
def _embedding_vector_column(self) -> dict: def _embedding_vector_column(self) -> Dict:
"""Return the embedding vector column configs as a dictionary. """Return the embedding vector column configs as a dictionary.
Empty if the index is not a self-managed embedding index. Empty if the index is not a self-managed embedding index.
""" """
@ -591,7 +608,7 @@ class DatabricksVectorSearch(VectorStore):
""" """
return self._embedding_source_column().get("name") return self._embedding_source_column().get("name")
def _embedding_source_column(self) -> dict: def _embedding_source_column(self) -> Dict:
"""Return the embedding source column configs as a dictionary. """Return the embedding source column configs as a dictionary.
Empty if the index is not a Databricks-managed embedding index. Empty if the index is not a Databricks-managed embedding index.
""" """
@ -629,3 +646,20 @@ class DatabricksVectorSearch(VectorStore):
"""Raise ValueError if the required arg with name `arg_name` is None.""" """Raise ValueError if the required arg with name `arg_name` is None."""
if not arg: if not arg:
raise ValueError(f"`{arg_name}` is required for this index.") raise ValueError(f"`{arg_name}` is required for this index.")
def _alias_filters(kwargs: Dict[str, Any]) -> Optional[Dict[str, Any]]:
"""
The `filters` argument was used in the previous versions. It is now
replaced with `filter` for consistency with other vector stores, but
we still support `filters` for backward compatibility.
"""
if "filters" in kwargs:
warn_deprecated(
since="0.2.11",
removal="0.3",
message="DatabricksVectorSearch received a key `filters` in search_kwargs. "
"`filters` was deprecated since langchain-community 0.2.11 and will "
"be removed in 0.3. Please use `filter` instead.",
)
return kwargs.pop("filters", None)

View File

@ -493,7 +493,7 @@ def test_similarity_search(index_details: dict, query_type: Optional[str]) -> No
limit = 7 limit = 7
search_result = vectorsearch.similarity_search( search_result = vectorsearch.similarity_search(
query, k=limit, filters=filters, query_type=query_type query, k=limit, filter=filters, query_type=query_type
) )
if index_details == DELTA_SYNC_INDEX_MANAGED_EMBEDDINGS: if index_details == DELTA_SYNC_INDEX_MANAGED_EMBEDDINGS:
index.similarity_search.assert_called_once_with( index.similarity_search.assert_called_once_with(
@ -518,6 +518,27 @@ def test_similarity_search(index_details: dict, query_type: Optional[str]) -> No
assert all([DEFAULT_PRIMARY_KEY in d.metadata for d in search_result]) assert all([DEFAULT_PRIMARY_KEY in d.metadata for d in search_result])
@pytest.mark.requires("databricks", "databricks.vector_search")
def test_similarity_search_both_filter_and_filters_passed() -> None:
index = mock_index(DIRECT_ACCESS_INDEX)
index.similarity_search.return_value = EXAMPLE_SEARCH_RESPONSE
vectorsearch = default_databricks_vector_search(index)
query = "foo"
filter = {"some filter": True}
filters = {"some other filter": False}
vectorsearch.similarity_search(query, filter=filter, filters=filters)
index.similarity_search.assert_called_once_with(
columns=[DEFAULT_PRIMARY_KEY, DEFAULT_TEXT_COLUMN],
query_vector=DEFAULT_EMBEDDING_MODEL.embed_query(query),
# `filter` should prevail over `filters`
filters=filter,
num_results=4,
query_text=None,
query_type=None,
)
@pytest.mark.requires("databricks", "databricks.vector_search") @pytest.mark.requires("databricks", "databricks.vector_search")
@pytest.mark.parametrize( @pytest.mark.parametrize(
"index_details, columns, expected_columns", "index_details, columns, expected_columns",
@ -576,7 +597,7 @@ def test_mmr_parameters(index_details: dict) -> None:
"k": limit, "k": limit,
"fetch_k": fetch_k, "fetch_k": fetch_k,
"lambda_mult": lambda_mult, "lambda_mult": lambda_mult,
"filters": filters, "filter": filters,
}, },
) )
search_result = retriever.invoke(query) search_result = retriever.invoke(query)
@ -625,7 +646,7 @@ def test_similarity_search_by_vector(index_details: dict) -> None:
limit = 7 limit = 7
search_result = vectorsearch.similarity_search_by_vector( search_result = vectorsearch.similarity_search_by_vector(
query_embedding, k=limit, filters=filters query_embedding, k=limit, filter=filters
) )
index.similarity_search.assert_called_once_with( index.similarity_search.assert_called_once_with(
columns=[DEFAULT_PRIMARY_KEY, DEFAULT_TEXT_COLUMN], columns=[DEFAULT_PRIMARY_KEY, DEFAULT_TEXT_COLUMN],
@ -681,3 +702,32 @@ def test_similarity_search_by_vector_not_supported_for_managed_embedding() -> No
"`similarity_search_by_vector` is not supported for index with " "`similarity_search_by_vector` is not supported for index with "
"Databricks-managed embeddings." in str(ex.value) "Databricks-managed embeddings." in str(ex.value)
) )
@pytest.mark.requires("databricks", "databricks.vector_search")
@pytest.mark.parametrize(
"method",
[
"similarity_search",
"similarity_search_with_score",
"similarity_search_by_vector",
"similarity_search_by_vector_with_score",
"max_marginal_relevance_search",
"max_marginal_relevance_search_by_vector",
],
)
def test_filter_arg_alias(method: str) -> None:
index = mock_index(DIRECT_ACCESS_INDEX)
vectorsearch = default_databricks_vector_search(index)
query = "foo"
query_embedding = DEFAULT_EMBEDDING_MODEL.embed_query("foo")
filters = {"some filter": True}
limit = 7
if "by_vector" in method:
getattr(vectorsearch, method)(query_embedding, k=limit, filters=filters)
else:
getattr(vectorsearch, method)(query, k=limit, filters=filters)
index_call_args = index.similarity_search.call_args[1]
assert index_call_args["filters"] == filters