mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-25 04:49:17 +00:00
community[patch]: Replace filters
argument to filter
in DatabricksVectorSearch (#24530)
The [DatabricksVectorSearch](https://github.com/langchain-ai/langchain/blob/master/libs/community/langchain_community/vectorstores/databricks_vector_search.py#L21) class exposes similarity search APIs with argument `filters`, which is inconsistent with other VS classes who uses `filter` (singular). This PR updates the argument and add alias for backward compatibility. --------- Signed-off-by: B-Step62 <yuki.watanabe@databricks.com>
This commit is contained in:
@@ -493,7 +493,7 @@ def test_similarity_search(index_details: dict, query_type: Optional[str]) -> No
|
||||
limit = 7
|
||||
|
||||
search_result = vectorsearch.similarity_search(
|
||||
query, k=limit, filters=filters, query_type=query_type
|
||||
query, k=limit, filter=filters, query_type=query_type
|
||||
)
|
||||
if index_details == DELTA_SYNC_INDEX_MANAGED_EMBEDDINGS:
|
||||
index.similarity_search.assert_called_once_with(
|
||||
@@ -518,6 +518,27 @@ def test_similarity_search(index_details: dict, query_type: Optional[str]) -> No
|
||||
assert all([DEFAULT_PRIMARY_KEY in d.metadata for d in search_result])
|
||||
|
||||
|
||||
@pytest.mark.requires("databricks", "databricks.vector_search")
|
||||
def test_similarity_search_both_filter_and_filters_passed() -> None:
|
||||
index = mock_index(DIRECT_ACCESS_INDEX)
|
||||
index.similarity_search.return_value = EXAMPLE_SEARCH_RESPONSE
|
||||
vectorsearch = default_databricks_vector_search(index)
|
||||
query = "foo"
|
||||
filter = {"some filter": True}
|
||||
filters = {"some other filter": False}
|
||||
|
||||
vectorsearch.similarity_search(query, filter=filter, filters=filters)
|
||||
index.similarity_search.assert_called_once_with(
|
||||
columns=[DEFAULT_PRIMARY_KEY, DEFAULT_TEXT_COLUMN],
|
||||
query_vector=DEFAULT_EMBEDDING_MODEL.embed_query(query),
|
||||
# `filter` should prevail over `filters`
|
||||
filters=filter,
|
||||
num_results=4,
|
||||
query_text=None,
|
||||
query_type=None,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.requires("databricks", "databricks.vector_search")
|
||||
@pytest.mark.parametrize(
|
||||
"index_details, columns, expected_columns",
|
||||
@@ -576,7 +597,7 @@ def test_mmr_parameters(index_details: dict) -> None:
|
||||
"k": limit,
|
||||
"fetch_k": fetch_k,
|
||||
"lambda_mult": lambda_mult,
|
||||
"filters": filters,
|
||||
"filter": filters,
|
||||
},
|
||||
)
|
||||
search_result = retriever.invoke(query)
|
||||
@@ -625,7 +646,7 @@ def test_similarity_search_by_vector(index_details: dict) -> None:
|
||||
limit = 7
|
||||
|
||||
search_result = vectorsearch.similarity_search_by_vector(
|
||||
query_embedding, k=limit, filters=filters
|
||||
query_embedding, k=limit, filter=filters
|
||||
)
|
||||
index.similarity_search.assert_called_once_with(
|
||||
columns=[DEFAULT_PRIMARY_KEY, DEFAULT_TEXT_COLUMN],
|
||||
@@ -681,3 +702,32 @@ def test_similarity_search_by_vector_not_supported_for_managed_embedding() -> No
|
||||
"`similarity_search_by_vector` is not supported for index with "
|
||||
"Databricks-managed embeddings." in str(ex.value)
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.requires("databricks", "databricks.vector_search")
|
||||
@pytest.mark.parametrize(
|
||||
"method",
|
||||
[
|
||||
"similarity_search",
|
||||
"similarity_search_with_score",
|
||||
"similarity_search_by_vector",
|
||||
"similarity_search_by_vector_with_score",
|
||||
"max_marginal_relevance_search",
|
||||
"max_marginal_relevance_search_by_vector",
|
||||
],
|
||||
)
|
||||
def test_filter_arg_alias(method: str) -> None:
|
||||
index = mock_index(DIRECT_ACCESS_INDEX)
|
||||
vectorsearch = default_databricks_vector_search(index)
|
||||
query = "foo"
|
||||
query_embedding = DEFAULT_EMBEDDING_MODEL.embed_query("foo")
|
||||
filters = {"some filter": True}
|
||||
limit = 7
|
||||
|
||||
if "by_vector" in method:
|
||||
getattr(vectorsearch, method)(query_embedding, k=limit, filters=filters)
|
||||
else:
|
||||
getattr(vectorsearch, method)(query, k=limit, filters=filters)
|
||||
|
||||
index_call_args = index.similarity_search.call_args[1]
|
||||
assert index_call_args["filters"] == filters
|
||||
|
Reference in New Issue
Block a user