community[patch]: Replace filters argument to filter in DatabricksVectorSearch (#24530)

The
[DatabricksVectorSearch](https://github.com/langchain-ai/langchain/blob/master/libs/community/langchain_community/vectorstores/databricks_vector_search.py#L21)
class exposes similarity search APIs with argument `filters`, which is
inconsistent with other VS classes who uses `filter` (singular). This PR
updates the argument and add alias for backward compatibility.

---------

Signed-off-by: B-Step62 <yuki.watanabe@databricks.com>
This commit is contained in:
Yuki Watanabe
2024-07-26 13:20:18 +09:00
committed by GitHub
parent 148766ddc1
commit 2b6a262f84
2 changed files with 114 additions and 30 deletions

View File

@@ -493,7 +493,7 @@ def test_similarity_search(index_details: dict, query_type: Optional[str]) -> No
limit = 7
search_result = vectorsearch.similarity_search(
query, k=limit, filters=filters, query_type=query_type
query, k=limit, filter=filters, query_type=query_type
)
if index_details == DELTA_SYNC_INDEX_MANAGED_EMBEDDINGS:
index.similarity_search.assert_called_once_with(
@@ -518,6 +518,27 @@ def test_similarity_search(index_details: dict, query_type: Optional[str]) -> No
assert all([DEFAULT_PRIMARY_KEY in d.metadata for d in search_result])
@pytest.mark.requires("databricks", "databricks.vector_search")
def test_similarity_search_both_filter_and_filters_passed() -> None:
index = mock_index(DIRECT_ACCESS_INDEX)
index.similarity_search.return_value = EXAMPLE_SEARCH_RESPONSE
vectorsearch = default_databricks_vector_search(index)
query = "foo"
filter = {"some filter": True}
filters = {"some other filter": False}
vectorsearch.similarity_search(query, filter=filter, filters=filters)
index.similarity_search.assert_called_once_with(
columns=[DEFAULT_PRIMARY_KEY, DEFAULT_TEXT_COLUMN],
query_vector=DEFAULT_EMBEDDING_MODEL.embed_query(query),
# `filter` should prevail over `filters`
filters=filter,
num_results=4,
query_text=None,
query_type=None,
)
@pytest.mark.requires("databricks", "databricks.vector_search")
@pytest.mark.parametrize(
"index_details, columns, expected_columns",
@@ -576,7 +597,7 @@ def test_mmr_parameters(index_details: dict) -> None:
"k": limit,
"fetch_k": fetch_k,
"lambda_mult": lambda_mult,
"filters": filters,
"filter": filters,
},
)
search_result = retriever.invoke(query)
@@ -625,7 +646,7 @@ def test_similarity_search_by_vector(index_details: dict) -> None:
limit = 7
search_result = vectorsearch.similarity_search_by_vector(
query_embedding, k=limit, filters=filters
query_embedding, k=limit, filter=filters
)
index.similarity_search.assert_called_once_with(
columns=[DEFAULT_PRIMARY_KEY, DEFAULT_TEXT_COLUMN],
@@ -681,3 +702,32 @@ def test_similarity_search_by_vector_not_supported_for_managed_embedding() -> No
"`similarity_search_by_vector` is not supported for index with "
"Databricks-managed embeddings." in str(ex.value)
)
@pytest.mark.requires("databricks", "databricks.vector_search")
@pytest.mark.parametrize(
"method",
[
"similarity_search",
"similarity_search_with_score",
"similarity_search_by_vector",
"similarity_search_by_vector_with_score",
"max_marginal_relevance_search",
"max_marginal_relevance_search_by_vector",
],
)
def test_filter_arg_alias(method: str) -> None:
index = mock_index(DIRECT_ACCESS_INDEX)
vectorsearch = default_databricks_vector_search(index)
query = "foo"
query_embedding = DEFAULT_EMBEDDING_MODEL.embed_query("foo")
filters = {"some filter": True}
limit = 7
if "by_vector" in method:
getattr(vectorsearch, method)(query_embedding, k=limit, filters=filters)
else:
getattr(vectorsearch, method)(query, k=limit, filters=filters)
index_call_args = index.similarity_search.call_args[1]
assert index_call_args["filters"] == filters