community[patch]: Replace filters argument to filter in DatabricksVectorSearch (#24530)

The
[DatabricksVectorSearch](https://github.com/langchain-ai/langchain/blob/master/libs/community/langchain_community/vectorstores/databricks_vector_search.py#L21)
class exposes similarity search APIs with argument `filters`, which is
inconsistent with other VS classes who uses `filter` (singular). This PR
updates the argument and add alias for backward compatibility.

---------

Signed-off-by: B-Step62 <yuki.watanabe@databricks.com>
This commit is contained in:
Yuki Watanabe 2024-07-26 13:20:18 +09:00 committed by GitHub
parent 148766ddc1
commit 2b6a262f84
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 114 additions and 30 deletions

View File

@ -3,9 +3,20 @@ from __future__ import annotations
import json
import logging
import uuid
from typing import TYPE_CHECKING, Any, Callable, Iterable, List, Optional, Tuple, Type
from typing import (
TYPE_CHECKING,
Any,
Callable,
Dict,
Iterable,
List,
Optional,
Tuple,
Type,
)
import numpy as np
from langchain_core._api import warn_deprecated
from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings
from langchain_core.vectorstores import VST, VectorStore
@ -193,7 +204,7 @@ class DatabricksVectorSearch(VectorStore):
cls: Type[VST],
texts: List[str],
embedding: Embeddings,
metadatas: Optional[List[dict]] = None,
metadatas: Optional[List[Dict]] = None,
**kwargs: Any,
) -> VST:
raise NotImplementedError(
@ -204,7 +215,7 @@ class DatabricksVectorSearch(VectorStore):
def add_texts(
self,
texts: Iterable[str],
metadatas: Optional[List[dict]] = None,
metadatas: Optional[List[Dict]] = None,
ids: Optional[List[Any]] = None,
**kwargs: Any,
) -> List[str]:
@ -280,7 +291,7 @@ class DatabricksVectorSearch(VectorStore):
self,
query: str,
k: int = 4,
filters: Optional[Any] = None,
filter: Optional[Dict[str, Any]] = None,
*,
query_type: Optional[str] = None,
**kwargs: Any,
@ -290,14 +301,18 @@ class DatabricksVectorSearch(VectorStore):
Args:
query: Text to look up documents similar to.
k: Number of Documents to return. Defaults to 4.
filters: Filters to apply to the query. Defaults to None.
filter: Filters to apply to the query. Defaults to None.
query_type: The type of this query. Supported values are "ANN" and "HYBRID".
Returns:
List of Documents most similar to the embedding.
"""
docs_with_score = self.similarity_search_with_score(
query=query, k=k, filters=filters, query_type=query_type, **kwargs
query=query,
k=k,
filter=filter,
query_type=query_type,
**kwargs,
)
return [doc for doc, _ in docs_with_score]
@ -305,7 +320,7 @@ class DatabricksVectorSearch(VectorStore):
self,
query: str,
k: int = 4,
filters: Optional[Any] = None,
filter: Optional[Dict[str, Any]] = None,
*,
query_type: Optional[str] = None,
**kwargs: Any,
@ -315,7 +330,7 @@ class DatabricksVectorSearch(VectorStore):
Args:
query: Text to look up documents similar to.
k: Number of Documents to return. Defaults to 4.
filters: Filters to apply to the query. Defaults to None.
filter: Filters to apply to the query. Defaults to None.
query_type: The type of this query. Supported values are "ANN" and "HYBRID".
Returns:
@ -328,12 +343,11 @@ class DatabricksVectorSearch(VectorStore):
assert self.embeddings is not None, "embedding model is required."
query_text = None
query_vector = self.embeddings.embed_query(query)
search_resp = self.index.similarity_search(
columns=self.columns,
query_text=query_text,
query_vector=query_vector,
filters=filters,
filters=filter or _alias_filters(kwargs),
num_results=k,
query_type=query_type,
)
@ -357,7 +371,7 @@ class DatabricksVectorSearch(VectorStore):
k: int = 4,
fetch_k: int = 20,
lambda_mult: float = 0.5,
filters: Optional[Any] = None,
filter: Optional[Dict[str, Any]] = None,
*,
query_type: Optional[str] = None,
**kwargs: Any,
@ -375,7 +389,7 @@ class DatabricksVectorSearch(VectorStore):
of diversity among the results with 0 corresponding
to maximum diversity and 1 to minimum diversity.
Defaults to 0.5.
filters: Filters to apply to the query. Defaults to None.
filter: Filters to apply to the query. Defaults to None.
query_type: The type of this query. Supported values are "ANN" and "HYBRID".
Returns:
List of Documents selected by maximal marginal relevance.
@ -394,7 +408,7 @@ class DatabricksVectorSearch(VectorStore):
k,
fetch_k,
lambda_mult=lambda_mult,
filters=filters,
filter=filter or _alias_filters(kwargs),
query_type=query_type,
)
return docs
@ -405,7 +419,7 @@ class DatabricksVectorSearch(VectorStore):
k: int = 4,
fetch_k: int = 20,
lambda_mult: float = 0.5,
filters: Optional[Any] = None,
filter: Optional[Any] = None,
*,
query_type: Optional[str] = None,
**kwargs: Any,
@ -423,7 +437,7 @@ class DatabricksVectorSearch(VectorStore):
of diversity among the results with 0 corresponding
to maximum diversity and 1 to minimum diversity.
Defaults to 0.5.
filters: Filters to apply to the query. Defaults to None.
filter: Filters to apply to the query. Defaults to None.
query_type: The type of this query. Supported values are "ANN" and "HYBRID".
Returns:
List of Documents selected by maximal marginal relevance.
@ -435,12 +449,11 @@ class DatabricksVectorSearch(VectorStore):
"`max_marginal_relevance_search` is not supported for index with "
"Databricks-managed embeddings."
)
search_resp = self.index.similarity_search(
columns=list(set(self.columns + [embedding_column])),
query_text=None,
query_vector=embedding,
filters=filters,
filters=filter or _alias_filters(kwargs),
num_results=fetch_k,
query_type=query_type,
)
@ -471,7 +484,7 @@ class DatabricksVectorSearch(VectorStore):
self,
embedding: List[float],
k: int = 4,
filters: Optional[Any] = None,
filter: Optional[Any] = None,
*,
query_type: Optional[str] = None,
**kwargs: Any,
@ -481,14 +494,18 @@ class DatabricksVectorSearch(VectorStore):
Args:
embedding: Embedding to look up documents similar to.
k: Number of Documents to return. Defaults to 4.
filters: Filters to apply to the query. Defaults to None.
filter: Filters to apply to the query. Defaults to None.
query_type: The type of this query. Supported values are "ANN" and "HYBRID".
Returns:
List of Documents most similar to the embedding.
"""
docs_with_score = self.similarity_search_by_vector_with_score(
embedding=embedding, k=k, filters=filters, query_type=query_type, **kwargs
embedding=embedding,
k=k,
filter=filter,
query_type=query_type,
**kwargs,
)
return [doc for doc, _ in docs_with_score]
@ -496,7 +513,7 @@ class DatabricksVectorSearch(VectorStore):
self,
embedding: List[float],
k: int = 4,
filters: Optional[Any] = None,
filter: Optional[Any] = None,
*,
query_type: Optional[str] = None,
**kwargs: Any,
@ -506,7 +523,7 @@ class DatabricksVectorSearch(VectorStore):
Args:
embedding: Embedding to look up documents similar to.
k: Number of Documents to return. Defaults to 4.
filters: Filters to apply to the query. Defaults to None.
filter: Filters to apply to the query. Defaults to None.
query_type: The type of this query. Supported values are "ANN" and "HYBRID".
Returns:
@ -520,14 +537,14 @@ class DatabricksVectorSearch(VectorStore):
search_resp = self.index.similarity_search(
columns=self.columns,
query_vector=embedding,
filters=filters,
filters=filter or _alias_filters(kwargs),
num_results=k,
query_type=query_type,
)
return self._parse_search_response(search_resp)
def _parse_search_response(
self, search_resp: dict, ignore_cols: Optional[List[str]] = None
self, search_resp: Dict, ignore_cols: Optional[List[str]] = None
) -> List[Tuple[Document, float]]:
"""Parse the search response into a list of Documents with score."""
if ignore_cols is None:
@ -552,7 +569,7 @@ class DatabricksVectorSearch(VectorStore):
docs_with_score.append((doc, score))
return docs_with_score
def _index_schema(self) -> Optional[dict]:
def _index_schema(self) -> Optional[Dict]:
"""Return the index schema as a dictionary.
Return None if no schema found.
"""
@ -574,7 +591,7 @@ class DatabricksVectorSearch(VectorStore):
"""
return self._embedding_vector_column().get("embedding_dimension")
def _embedding_vector_column(self) -> dict:
def _embedding_vector_column(self) -> Dict:
"""Return the embedding vector column configs as a dictionary.
Empty if the index is not a self-managed embedding index.
"""
@ -591,7 +608,7 @@ class DatabricksVectorSearch(VectorStore):
"""
return self._embedding_source_column().get("name")
def _embedding_source_column(self) -> dict:
def _embedding_source_column(self) -> Dict:
"""Return the embedding source column configs as a dictionary.
Empty if the index is not a Databricks-managed embedding index.
"""
@ -629,3 +646,20 @@ class DatabricksVectorSearch(VectorStore):
"""Raise ValueError if the required arg with name `arg_name` is None."""
if not arg:
raise ValueError(f"`{arg_name}` is required for this index.")
def _alias_filters(kwargs: Dict[str, Any]) -> Optional[Dict[str, Any]]:
"""
The `filters` argument was used in the previous versions. It is now
replaced with `filter` for consistency with other vector stores, but
we still support `filters` for backward compatibility.
"""
if "filters" in kwargs:
warn_deprecated(
since="0.2.11",
removal="0.3",
message="DatabricksVectorSearch received a key `filters` in search_kwargs. "
"`filters` was deprecated since langchain-community 0.2.11 and will "
"be removed in 0.3. Please use `filter` instead.",
)
return kwargs.pop("filters", None)

View File

@ -493,7 +493,7 @@ def test_similarity_search(index_details: dict, query_type: Optional[str]) -> No
limit = 7
search_result = vectorsearch.similarity_search(
query, k=limit, filters=filters, query_type=query_type
query, k=limit, filter=filters, query_type=query_type
)
if index_details == DELTA_SYNC_INDEX_MANAGED_EMBEDDINGS:
index.similarity_search.assert_called_once_with(
@ -518,6 +518,27 @@ def test_similarity_search(index_details: dict, query_type: Optional[str]) -> No
assert all([DEFAULT_PRIMARY_KEY in d.metadata for d in search_result])
@pytest.mark.requires("databricks", "databricks.vector_search")
def test_similarity_search_both_filter_and_filters_passed() -> None:
index = mock_index(DIRECT_ACCESS_INDEX)
index.similarity_search.return_value = EXAMPLE_SEARCH_RESPONSE
vectorsearch = default_databricks_vector_search(index)
query = "foo"
filter = {"some filter": True}
filters = {"some other filter": False}
vectorsearch.similarity_search(query, filter=filter, filters=filters)
index.similarity_search.assert_called_once_with(
columns=[DEFAULT_PRIMARY_KEY, DEFAULT_TEXT_COLUMN],
query_vector=DEFAULT_EMBEDDING_MODEL.embed_query(query),
# `filter` should prevail over `filters`
filters=filter,
num_results=4,
query_text=None,
query_type=None,
)
@pytest.mark.requires("databricks", "databricks.vector_search")
@pytest.mark.parametrize(
"index_details, columns, expected_columns",
@ -576,7 +597,7 @@ def test_mmr_parameters(index_details: dict) -> None:
"k": limit,
"fetch_k": fetch_k,
"lambda_mult": lambda_mult,
"filters": filters,
"filter": filters,
},
)
search_result = retriever.invoke(query)
@ -625,7 +646,7 @@ def test_similarity_search_by_vector(index_details: dict) -> None:
limit = 7
search_result = vectorsearch.similarity_search_by_vector(
query_embedding, k=limit, filters=filters
query_embedding, k=limit, filter=filters
)
index.similarity_search.assert_called_once_with(
columns=[DEFAULT_PRIMARY_KEY, DEFAULT_TEXT_COLUMN],
@ -681,3 +702,32 @@ def test_similarity_search_by_vector_not_supported_for_managed_embedding() -> No
"`similarity_search_by_vector` is not supported for index with "
"Databricks-managed embeddings." in str(ex.value)
)
@pytest.mark.requires("databricks", "databricks.vector_search")
@pytest.mark.parametrize(
"method",
[
"similarity_search",
"similarity_search_with_score",
"similarity_search_by_vector",
"similarity_search_by_vector_with_score",
"max_marginal_relevance_search",
"max_marginal_relevance_search_by_vector",
],
)
def test_filter_arg_alias(method: str) -> None:
index = mock_index(DIRECT_ACCESS_INDEX)
vectorsearch = default_databricks_vector_search(index)
query = "foo"
query_embedding = DEFAULT_EMBEDDING_MODEL.embed_query("foo")
filters = {"some filter": True}
limit = 7
if "by_vector" in method:
getattr(vectorsearch, method)(query_embedding, k=limit, filters=filters)
else:
getattr(vectorsearch, method)(query, k=limit, filters=filters)
index_call_args = index.similarity_search.call_args[1]
assert index_call_args["filters"] == filters