mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-31 10:23:18 +00:00
community[patch]: Advanced filtering for HANA Cloud Vector Engine (#20821)
- **Description:** This PR adds support for advanced filtering to the HANA Cloud Vector Engine integration. The newly supported filter operators are: `$eq`, `$ne`, `$gt`, `$gte`, `$lt`, `$lte`, `$between`, `$in`, `$nin`, `$like`, `$and`, `$or`.
- **Issue:** N/A
- **Dependencies:** No new dependencies added.
- Integration tests added to `libs/community/tests/integration_tests/vectorstores/test_hanavector.py`; the new capabilities are described in `docs/docs/integrations/vectorstores/hanavector.ipynb`.
This commit is contained in:
@@ -357,6 +357,179 @@
|
||||
"print(len(docs))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Advanced filtering\n",
|
||||
"In addition to the basic value-based filtering capabilities, it is possible to use more advanced filtering.\n",
|
||||
"The table below shows the available filter operators.\n",
|
||||
"\n",
|
||||
"| Operator | Semantic |\n",
|
||||
"|----------|-------------------------|\n",
|
||||
"| `$eq` | Equality (==) |\n",
|
||||
"| `$ne` | Inequality (!=) |\n",
|
||||
"| `$lt` | Less than (<) |\n",
|
||||
"| `$lte` | Less than or equal (<=) |\n",
|
||||
"| `$gt` | Greater than (>) |\n",
|
||||
"| `$gte` | Greater than or equal (>=) |\n",
|
||||
"| `$in` | Contained in a set of given values (in) |\n",
|
||||
"| `$nin` | Not contained in a set of given values (not in) |\n",
|
||||
"| `$between` | Between the range of two boundary values |\n",
|
||||
"| `$like` | Text equality based on the \"LIKE\" semantics in SQL (using \"%\" as wildcard) |\n",
|
||||
"| `$and` | Logical \"and\", supporting 2 or more operands |\n",
|
||||
"| `$or` | Logical \"or\", supporting 2 or more operands |"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Prepare some test documents\n",
|
||||
"docs = [\n",
|
||||
" Document(\n",
|
||||
" page_content=\"First\",\n",
|
||||
" metadata={\"name\": \"adam\", \"is_active\": True, \"id\": 1, \"height\": 10.0},\n",
|
||||
" ),\n",
|
||||
" Document(\n",
|
||||
" page_content=\"Second\",\n",
|
||||
" metadata={\"name\": \"bob\", \"is_active\": False, \"id\": 2, \"height\": 5.7},\n",
|
||||
" ),\n",
|
||||
" Document(\n",
|
||||
" page_content=\"Third\",\n",
|
||||
" metadata={\"name\": \"jane\", \"is_active\": True, \"id\": 3, \"height\": 2.4},\n",
|
||||
" ),\n",
|
||||
"]\n",
|
||||
"\n",
|
||||
"db = HanaDB(\n",
|
||||
" connection=connection,\n",
|
||||
" embedding=embeddings,\n",
|
||||
" table_name=\"LANGCHAIN_DEMO_ADVANCED_FILTER\",\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"# Delete already existing documents from the table\n",
|
||||
"db.delete(filter={})\n",
|
||||
"db.add_documents(docs)\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"# Helper function for printing filter results\n",
|
||||
"def print_filter_result(result):\n",
|
||||
" if len(result) == 0:\n",
|
||||
" print(\"<empty result>\")\n",
|
||||
" for doc in result:\n",
|
||||
" print(doc.metadata)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Filtering with `$ne`, `$gt`, `$gte`, `$lt`, `$lte`"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"advanced_filter = {\"id\": {\"$ne\": 1}}\n",
|
||||
"print(f\"Filter: {advanced_filter}\")\n",
|
||||
"print_filter_result(db.similarity_search(\"just testing\", k=5, filter=advanced_filter))\n",
|
||||
"\n",
|
||||
"advanced_filter = {\"id\": {\"$gt\": 1}}\n",
|
||||
"print(f\"Filter: {advanced_filter}\")\n",
|
||||
"print_filter_result(db.similarity_search(\"just testing\", k=5, filter=advanced_filter))\n",
|
||||
"\n",
|
||||
"advanced_filter = {\"id\": {\"$gte\": 1}}\n",
|
||||
"print(f\"Filter: {advanced_filter}\")\n",
|
||||
"print_filter_result(db.similarity_search(\"just testing\", k=5, filter=advanced_filter))\n",
|
||||
"\n",
|
||||
"advanced_filter = {\"id\": {\"$lt\": 1}}\n",
|
||||
"print(f\"Filter: {advanced_filter}\")\n",
|
||||
"print_filter_result(db.similarity_search(\"just testing\", k=5, filter=advanced_filter))\n",
|
||||
"\n",
|
||||
"advanced_filter = {\"id\": {\"$lte\": 1}}\n",
|
||||
"print(f\"Filter: {advanced_filter}\")\n",
|
||||
"print_filter_result(db.similarity_search(\"just testing\", k=5, filter=advanced_filter))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Filtering with `$between`, `$in`, `$nin`"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"advanced_filter = {\"id\": {\"$between\": (1, 2)}}\n",
|
||||
"print(f\"Filter: {advanced_filter}\")\n",
|
||||
"print_filter_result(db.similarity_search(\"just testing\", k=5, filter=advanced_filter))\n",
|
||||
"\n",
|
||||
"advanced_filter = {\"name\": {\"$in\": [\"adam\", \"bob\"]}}\n",
|
||||
"print(f\"Filter: {advanced_filter}\")\n",
|
||||
"print_filter_result(db.similarity_search(\"just testing\", k=5, filter=advanced_filter))\n",
|
||||
"\n",
|
||||
"advanced_filter = {\"name\": {\"$nin\": [\"adam\", \"bob\"]}}\n",
|
||||
"print(f\"Filter: {advanced_filter}\")\n",
|
||||
"print_filter_result(db.similarity_search(\"just testing\", k=5, filter=advanced_filter))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Text filtering with `$like`"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"advanced_filter = {\"name\": {\"$like\": \"a%\"}}\n",
|
||||
"print(f\"Filter: {advanced_filter}\")\n",
|
||||
"print_filter_result(db.similarity_search(\"just testing\", k=5, filter=advanced_filter))\n",
|
||||
"\n",
|
||||
"advanced_filter = {\"name\": {\"$like\": \"%a%\"}}\n",
|
||||
"print(f\"Filter: {advanced_filter}\")\n",
|
||||
"print_filter_result(db.similarity_search(\"just testing\", k=5, filter=advanced_filter))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Combined filtering with `$and`, `$or`"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"advanced_filter = {\"$or\": [{\"id\": 1}, {\"name\": \"bob\"}]}\n",
|
||||
"print(f\"Filter: {advanced_filter}\")\n",
|
||||
"print_filter_result(db.similarity_search(\"just testing\", k=5, filter=advanced_filter))\n",
|
||||
"\n",
|
||||
"advanced_filter = {\"$and\": [{\"id\": 1}, {\"id\": 2}]}\n",
|
||||
"print(f\"Filter: {advanced_filter}\")\n",
|
||||
"print_filter_result(db.similarity_search(\"just testing\", k=5, filter=advanced_filter))\n",
|
||||
"\n",
|
||||
"advanced_filter = {\"$or\": [{\"id\": 1}, {\"id\": 2}, {\"id\": 3}]}\n",
|
||||
"print(f\"Filter: {advanced_filter}\")\n",
|
||||
"print_filter_result(db.similarity_search(\"just testing\", k=5, filter=advanced_filter))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
|
@@ -8,6 +8,7 @@ from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Callable,
|
||||
Dict,
|
||||
Iterable,
|
||||
List,
|
||||
Optional,
|
||||
@@ -34,6 +35,27 @@ HANA_DISTANCE_FUNCTION: dict = {
|
||||
DistanceStrategy.EUCLIDEAN_DISTANCE: ("L2DISTANCE", "ASC"),
|
||||
}
|
||||
|
||||
# Mapping from the "$"-prefixed comparison operators accepted in metadata
# filters to their SQL counterparts.
COMPARISONS_TO_SQL = {
    "$eq": "=",
    "$ne": "<>",
    "$lt": "<",
    "$lte": "<=",
    "$gt": ">",
    "$gte": ">=",
}

# Set-membership operators; the operand must be a list of candidate values.
IN_OPERATORS_TO_SQL = {
    "$in": "IN",
    "$nin": "NOT IN",
}

# Range operator; the operand is a two-element sequence (lower, upper).
BETWEEN_OPERATOR = "$between"

# Text matching with SQL LIKE semantics ("%" as wildcard).
LIKE_OPERATOR = "$like"

# Logical combinators; each takes a list of two or more operand filters.
LOGICAL_OPERATORS_TO_SQL = {"$and": "AND", "$or": "OR"}
|
||||
|
||||
|
||||
# Defaults applied when the caller does not override them at construction.
default_distance_strategy = DistanceStrategy.COSINE
default_table_name: str = "EMBEDDINGS"
default_content_column: str = "VEC_TEXT"
|
||||
@@ -404,29 +426,99 @@ class HanaDB(VectorStore):
|
||||
return [doc for doc, _ in docs_and_scores]
|
||||
|
||||
def _create_where_by_filter(self, filter):  # type: ignore[no-untyped-def]
    """Build the SQL WHERE clause and its parameters for *filter*.

    Returns a tuple ``(where_str, query_tuple)``: ``where_str`` is an empty
    string when no filter is given, otherwise a string starting with
    ``" WHERE "``; ``query_tuple`` holds the positional SQL parameters that
    match the ``?`` placeholders in ``where_str``.
    """
    query_tuple = []
    where_str = ""
    if filter:
        where_str, query_tuple = self._process_filter_object(filter)
        where_str = " WHERE " + where_str
    return where_str, query_tuple
|
||||
|
||||
def _process_filter_object(self, filter):  # type: ignore[no-untyped-def]
    """Recursively translate a metadata *filter* dict into a SQL condition.

    Supported forms: plain equality (``{"key": value}``), the comparison
    operators ``$eq``/``$ne``/``$lt``/``$lte``/``$gt``/``$gte``,
    ``$between``, ``$like``, set membership ``$in``/``$nin``, and the
    logical combinators ``$and``/``$or`` (whose operands are themselves
    filter objects).

    Returns a tuple ``(where_str, query_tuple)``: the condition text with
    ``?`` placeholders and the matching list of parameter values.

    Raises:
        ValueError: on an unknown ``$`` operator, a non-list operand for
            ``$in``/``$nin``, or an unsupported filter value type.
    """
    query_tuple = []
    where_str = ""
    if filter:
        for i, key in enumerate(filter.keys()):
            filter_value = filter[key]
            # Conditions on different top-level keys are AND-ed together.
            if i != 0:
                where_str += " AND "

            # Handling of 'special' boolean operators "$and", "$or"
            if key in LOGICAL_OPERATORS_TO_SQL:
                logical_operator = LOGICAL_OPERATORS_TO_SQL[key]
                logical_operands = filter_value
                for j, logical_operand in enumerate(logical_operands):
                    if j != 0:
                        where_str += f" {logical_operator} "
                    (
                        where_str_logical,
                        query_tuple_logical,
                    ) = self._process_filter_object(logical_operand)
                    where_str += where_str_logical
                    query_tuple += query_tuple_logical
                continue

            operator = "="
            sql_param = "?"

            if isinstance(filter_value, bool):
                # JSON_VALUE yields the literal strings "true"/"false".
                query_tuple.append("true" if filter_value else "false")
            elif isinstance(filter_value, int) or isinstance(filter_value, str):
                query_tuple.append(filter_value)
            elif isinstance(filter_value, Dict):
                # Handling of 'special' operators starting with "$"
                special_op = next(iter(filter_value))
                special_val = filter_value[special_op]
                # "$eq", "$ne", "$lt", "$lte", "$gt", "$gte"
                if special_op in COMPARISONS_TO_SQL:
                    operator = COMPARISONS_TO_SQL[special_op]
                    if isinstance(special_val, bool):
                        # Bug fix: test the boolean operand itself, not the
                        # enclosing dict `filter_value` (always truthy).
                        query_tuple.append("true" if special_val else "false")
                    elif isinstance(special_val, float):
                        # JSON_VALUE returns text; cast for numeric compare.
                        sql_param = "CAST(? as float)"
                        query_tuple.append(special_val)
                    else:
                        query_tuple.append(special_val)
                # "$between"
                elif special_op == BETWEEN_OPERATOR:
                    between_from = special_val[0]
                    between_to = special_val[1]
                    operator = "BETWEEN"
                    sql_param = "? AND ?"
                    query_tuple.append(between_from)
                    query_tuple.append(between_to)
                # "$like"
                elif special_op == LIKE_OPERATOR:
                    operator = "LIKE"
                    query_tuple.append(special_val)
                # "$in", "$nin"
                elif special_op in IN_OPERATORS_TO_SQL:
                    operator = IN_OPERATORS_TO_SQL[special_op]
                    if isinstance(special_val, list):
                        # Build "(?,?,...,?)" with one placeholder per entry.
                        for i, list_entry in enumerate(special_val):
                            if i == 0:
                                sql_param = "("
                            sql_param = sql_param + "?"
                            if i == (len(special_val) - 1):
                                sql_param = sql_param + ")"
                            else:
                                sql_param = sql_param + ","
                            query_tuple.append(list_entry)
                    else:
                        raise ValueError(
                            f"Unsupported value for {operator}: {special_val}"
                        )
                else:
                    raise ValueError(f"Unsupported operator: {special_op}")
            else:
                raise ValueError(
                    f"Unsupported filter data-type: {type(filter_value)}"
                )

            where_str += (
                f" JSON_VALUE({self.metadata_column}, '$.{key}')"
                f" {operator} {sql_param}"
            )

    return where_str, query_tuple
|
||||
|
||||
def delete( # type: ignore[override]
|
||||
|
@@ -2,7 +2,7 @@
|
||||
|
||||
import os
|
||||
import random
|
||||
from typing import List
|
||||
from typing import Any, Dict, List
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
@@ -12,6 +12,23 @@ from langchain_community.vectorstores.utils import DistanceStrategy
|
||||
from tests.integration_tests.vectorstores.fake_embeddings import (
|
||||
ConsistentFakeEmbeddings,
|
||||
)
|
||||
from tests.integration_tests.vectorstores.fixtures.filtering_test_cases import (
|
||||
DOCUMENTS,
|
||||
TYPE_1_FILTERING_TEST_CASES,
|
||||
TYPE_2_FILTERING_TEST_CASES,
|
||||
TYPE_3_FILTERING_TEST_CASES,
|
||||
TYPE_4_FILTERING_TEST_CASES,
|
||||
TYPE_5_FILTERING_TEST_CASES,
|
||||
)
|
||||
|
||||
TYPE_4B_FILTERING_TEST_CASES = [
    # Test $nin, which is missing in TYPE_4_FILTERING_TEST_CASES
    (
        {"name": {"$nin": ["adam", "bob"]}},
        [3],
    ),
]
|
||||
|
||||
|
||||
try:
|
||||
from hdbcli import dbapi
|
||||
@@ -924,3 +941,156 @@ def test_hanavector_table_mixed_case_names(texts: List[str]) -> None:
|
||||
|
||||
# check results of similarity search
|
||||
assert texts[0] == vectordb.similarity_search(texts[0], 1)[0].page_content
|
||||
|
||||
|
||||
def _run_filter_test(
    table_name: str,
    test_filter: Dict[str, Any],
    expected_ids: List[int],
) -> None:
    """Shared driver for the metadata-filter tests.

    Drops and recreates *table_name*, indexes DOCUMENTS, runs a similarity
    search with *test_filter* and asserts the returned document ids match
    *expected_ids*.
    """
    drop_table(test_setup.conn, table_name)

    vectorDB = HanaDB(
        connection=test_setup.conn,
        embedding=embedding,
        table_name=table_name,
    )
    vectorDB.add_documents(DOCUMENTS)

    docs = vectorDB.similarity_search("meow", k=5, filter=test_filter)
    ids = [doc.metadata["id"] for doc in docs]
    assert len(ids) == len(expected_ids), test_filter
    assert set(ids).issubset(expected_ids), test_filter


@pytest.mark.skipif(not hanadb_installed, reason="hanadb not installed")
def test_hanavector_enhanced_filter_1() -> None:
    # Smoke test: table creation + document indexing succeed.
    table_name = "TEST_TABLE_ENHANCED_FILTER_1"
    # Delete table if it exists
    drop_table(test_setup.conn, table_name)

    vectorDB = HanaDB(
        connection=test_setup.conn,
        embedding=embedding,
        table_name=table_name,
    )
    vectorDB.add_documents(DOCUMENTS)


@pytest.mark.parametrize("test_filter, expected_ids", TYPE_1_FILTERING_TEST_CASES)
@pytest.mark.skipif(not hanadb_installed, reason="hanadb not installed")
def test_pgvector_with_with_metadata_filters_1(
    test_filter: Dict[str, Any],
    expected_ids: List[int],
) -> None:
    _run_filter_test("TEST_TABLE_ENHANCED_FILTER_1", test_filter, expected_ids)


@pytest.mark.parametrize("test_filter, expected_ids", TYPE_2_FILTERING_TEST_CASES)
@pytest.mark.skipif(not hanadb_installed, reason="hanadb not installed")
def test_pgvector_with_with_metadata_filters_2(
    test_filter: Dict[str, Any],
    expected_ids: List[int],
) -> None:
    _run_filter_test("TEST_TABLE_ENHANCED_FILTER_2", test_filter, expected_ids)


@pytest.mark.parametrize("test_filter, expected_ids", TYPE_3_FILTERING_TEST_CASES)
@pytest.mark.skipif(not hanadb_installed, reason="hanadb not installed")
def test_pgvector_with_with_metadata_filters_3(
    test_filter: Dict[str, Any],
    expected_ids: List[int],
) -> None:
    _run_filter_test("TEST_TABLE_ENHANCED_FILTER_3", test_filter, expected_ids)


@pytest.mark.parametrize("test_filter, expected_ids", TYPE_4_FILTERING_TEST_CASES)
@pytest.mark.skipif(not hanadb_installed, reason="hanadb not installed")
def test_pgvector_with_with_metadata_filters_4(
    test_filter: Dict[str, Any],
    expected_ids: List[int],
) -> None:
    _run_filter_test("TEST_TABLE_ENHANCED_FILTER_4", test_filter, expected_ids)


@pytest.mark.parametrize("test_filter, expected_ids", TYPE_4B_FILTERING_TEST_CASES)
@pytest.mark.skipif(not hanadb_installed, reason="hanadb not installed")
def test_pgvector_with_with_metadata_filters_4b(
    test_filter: Dict[str, Any],
    expected_ids: List[int],
) -> None:
    _run_filter_test("TEST_TABLE_ENHANCED_FILTER_4B", test_filter, expected_ids)


@pytest.mark.parametrize("test_filter, expected_ids", TYPE_5_FILTERING_TEST_CASES)
@pytest.mark.skipif(not hanadb_installed, reason="hanadb not installed")
def test_pgvector_with_with_metadata_filters_5(
    test_filter: Dict[str, Any],
    expected_ids: List[int],
) -> None:
    _run_filter_test("TEST_TABLE_ENHANCED_FILTER_5", test_filter, expected_ids)
|
||||
|
Reference in New Issue
Block a user