community[patch]: Advanced filtering for HANA Cloud Vector Engine (#20821)

- **Description:**
This PR adds advanced filtering support to the HANA Cloud Vector Engine
integration.
The newly supported filtering operators are: $eq, $ne, $gt, $gte, $lt,
$lte, $between, $in, $nin, $like, $and, $or

  - **Issue:** N/A
  - **Dependencies:** no new dependencies added

Added integration tests to:
`libs/community/tests/integration_tests/vectorstores/test_hanavector.py`

Description of the new capabilities in notebook:
`docs/docs/integrations/vectorstores/hanavector.ipynb`
This commit is contained in:
Martin Kolb
2024-04-24 22:47:27 +02:00
committed by GitHub
parent 12e5ec6de3
commit 0186e4e633
3 changed files with 447 additions and 12 deletions

View File

@@ -357,6 +357,179 @@
"print(len(docs))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Advanced filtering\n",
"In addition to the basic value-based filtering capabilities, it is possible to use more advanced filtering.\n",
"The table below shows the available filter operators.\n",
"\n",
"| Operator | Semantic |\n",
"|----------|-------------------------|\n",
"| `$eq` | Equality (==) |\n",
"| `$ne` | Inequality (!=) |\n",
"| `$lt` | Less than (<) |\n",
"| `$lte` | Less than or equal (<=) |\n",
"| `$gt` | Greater than (>) |\n",
"| `$gte` | Greater than or equal (>=) |\n",
"| `$in` | Contained in a set of given values (in) |\n",
"| `$nin` | Not contained in a set of given values (not in) |\n",
"| `$between` | Within the range of two boundary values (inclusive) |\n",
"| `$like` | Text pattern matching based on the \"LIKE\" semantics in SQL (using \"%\" as wildcard) |\n",
"| `$and` | Logical \"and\", supporting 2 or more operands |\n",
"| `$or` | Logical \"or\", supporting 2 or more operands |"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Prepare some test documents\n",
"docs = [\n",
" Document(\n",
" page_content=\"First\",\n",
" metadata={\"name\": \"adam\", \"is_active\": True, \"id\": 1, \"height\": 10.0},\n",
" ),\n",
" Document(\n",
" page_content=\"Second\",\n",
" metadata={\"name\": \"bob\", \"is_active\": False, \"id\": 2, \"height\": 5.7},\n",
" ),\n",
" Document(\n",
" page_content=\"Third\",\n",
" metadata={\"name\": \"jane\", \"is_active\": True, \"id\": 3, \"height\": 2.4},\n",
" ),\n",
"]\n",
"\n",
"db = HanaDB(\n",
" connection=connection,\n",
" embedding=embeddings,\n",
" table_name=\"LANGCHAIN_DEMO_ADVANCED_FILTER\",\n",
")\n",
"\n",
"# Delete already existing documents from the table\n",
"db.delete(filter={})\n",
"db.add_documents(docs)\n",
"\n",
"\n",
"# Helper function for printing filter results\n",
"def print_filter_result(result):\n",
" if len(result) == 0:\n",
" print(\"<empty result>\")\n",
" for doc in result:\n",
" print(doc.metadata)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Filtering with `$ne`, `$gt`, `$gte`, `$lt`, `$lte`"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"advanced_filter = {\"id\": {\"$ne\": 1}}\n",
"print(f\"Filter: {advanced_filter}\")\n",
"print_filter_result(db.similarity_search(\"just testing\", k=5, filter=advanced_filter))\n",
"\n",
"advanced_filter = {\"id\": {\"$gt\": 1}}\n",
"print(f\"Filter: {advanced_filter}\")\n",
"print_filter_result(db.similarity_search(\"just testing\", k=5, filter=advanced_filter))\n",
"\n",
"advanced_filter = {\"id\": {\"$gte\": 1}}\n",
"print(f\"Filter: {advanced_filter}\")\n",
"print_filter_result(db.similarity_search(\"just testing\", k=5, filter=advanced_filter))\n",
"\n",
"advanced_filter = {\"id\": {\"$lt\": 1}}\n",
"print(f\"Filter: {advanced_filter}\")\n",
"print_filter_result(db.similarity_search(\"just testing\", k=5, filter=advanced_filter))\n",
"\n",
"advanced_filter = {\"id\": {\"$lte\": 1}}\n",
"print(f\"Filter: {advanced_filter}\")\n",
"print_filter_result(db.similarity_search(\"just testing\", k=5, filter=advanced_filter))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Filtering with `$between`, `$in`, `$nin`"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"advanced_filter = {\"id\": {\"$between\": (1, 2)}}\n",
"print(f\"Filter: {advanced_filter}\")\n",
"print_filter_result(db.similarity_search(\"just testing\", k=5, filter=advanced_filter))\n",
"\n",
"advanced_filter = {\"name\": {\"$in\": [\"adam\", \"bob\"]}}\n",
"print(f\"Filter: {advanced_filter}\")\n",
"print_filter_result(db.similarity_search(\"just testing\", k=5, filter=advanced_filter))\n",
"\n",
"advanced_filter = {\"name\": {\"$nin\": [\"adam\", \"bob\"]}}\n",
"print(f\"Filter: {advanced_filter}\")\n",
"print_filter_result(db.similarity_search(\"just testing\", k=5, filter=advanced_filter))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Text filtering with `$like`"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"advanced_filter = {\"name\": {\"$like\": \"a%\"}}\n",
"print(f\"Filter: {advanced_filter}\")\n",
"print_filter_result(db.similarity_search(\"just testing\", k=5, filter=advanced_filter))\n",
"\n",
"advanced_filter = {\"name\": {\"$like\": \"%a%\"}}\n",
"print(f\"Filter: {advanced_filter}\")\n",
"print_filter_result(db.similarity_search(\"just testing\", k=5, filter=advanced_filter))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Combined filtering with `$and`, `$or`"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"advanced_filter = {\"$or\": [{\"id\": 1}, {\"name\": \"bob\"}]}\n",
"print(f\"Filter: {advanced_filter}\")\n",
"print_filter_result(db.similarity_search(\"just testing\", k=5, filter=advanced_filter))\n",
"\n",
"advanced_filter = {\"$and\": [{\"id\": 1}, {\"id\": 2}]}\n",
"print(f\"Filter: {advanced_filter}\")\n",
"print_filter_result(db.similarity_search(\"just testing\", k=5, filter=advanced_filter))\n",
"\n",
"advanced_filter = {\"$or\": [{\"id\": 1}, {\"id\": 2}, {\"id\": 3}]}\n",
"print(f\"Filter: {advanced_filter}\")\n",
"print_filter_result(db.similarity_search(\"just testing\", k=5, filter=advanced_filter))"
]
},
{
"cell_type": "markdown",
"metadata": {},

View File

@@ -8,6 +8,7 @@ from typing import (
TYPE_CHECKING,
Any,
Callable,
Dict,
Iterable,
List,
Optional,
@@ -34,6 +35,27 @@ HANA_DISTANCE_FUNCTION: dict = {
DistanceStrategy.EUCLIDEAN_DISTANCE: ("L2DISTANCE", "ASC"),
}
# Maps the comparison filter operators to the SQL comparison operator that is
# placed between the metadata value and the bound parameter.
COMPARISONS_TO_SQL = {
"$eq": "=",
"$ne": "<>",
"$lt": "<",
"$lte": "<=",
"$gt": ">",
"$gte": ">=",
}
# Maps the set-membership filter operators to their SQL keywords.
IN_OPERATORS_TO_SQL = {
"$in": "IN",
"$nin": "NOT IN",
}
# Names of the range and pattern-matching filter operators.
BETWEEN_OPERATOR = "$between"
LIKE_OPERATOR = "$like"
# Maps the logical filter operators to the SQL keyword that joins their operands.
LOGICAL_OPERATORS_TO_SQL = {"$and": "AND", "$or": "OR"}
default_distance_strategy = DistanceStrategy.COSINE
default_table_name: str = "EMBEDDINGS"
default_content_column: str = "VEC_TEXT"
@@ -404,29 +426,99 @@ class HanaDB(VectorStore):
return [doc for doc, _ in docs_and_scores]
def _create_where_by_filter(self, filter): # type: ignore[no-untyped-def]
query_tuple = []
where_str = ""
if filter:
where_str, query_tuple = self._process_filter_object(filter)
where_str = " WHERE " + where_str
return where_str, query_tuple
def _process_filter_object(self, filter):  # type: ignore[no-untyped-def]
    """Translate a filter object into a SQL condition with bind parameters.

    Every key of ``filter`` is either a logical operator (``$and``/``$or``)
    whose value is a sequence of nested filter objects, or a metadata field
    name. A field value is a plain ``bool``/``int``/``str`` (equality) or a
    dict holding one operator: ``$eq``, ``$ne``, ``$lt``, ``$lte``, ``$gt``,
    ``$gte``, ``$between``, ``$like``, ``$in`` or ``$nin``. The conditions of
    several keys are combined with ``AND``.

    Args:
        filter: The filter object (or nested logical operand) to translate.

    Returns:
        A ``(where_str, query_tuple)`` pair: the SQL condition without the
        ``WHERE`` keyword and the list of values for its ``?`` placeholders
        in placeholder order. Both are empty for a falsy ``filter``.

    Raises:
        ValueError: For an unknown operator, an unsupported value type, an
            ``$in``/``$nin`` value that is not a non-empty list, or an
            ``$and``/``$or`` without operands.
    """
    query_tuple = []
    conditions = []
    if not filter:
        return "", query_tuple

    for key, filter_value in filter.items():
        # Handling of 'special' boolean operators "$and", "$or"
        if key in LOGICAL_OPERATORS_TO_SQL:
            logical_operator = LOGICAL_OPERATORS_TO_SQL[key]
            operand_conditions = []
            for logical_operand in filter_value:
                (
                    where_str_logical,
                    query_tuple_logical,
                ) = self._process_filter_object(logical_operand)
                # Parenthesize each operand so that SQL precedence (AND binds
                # tighter than OR) cannot re-associate nested expressions.
                operand_conditions.append(f"({where_str_logical})")
                query_tuple += query_tuple_logical
            if not operand_conditions:
                raise ValueError(f"Unsupported value for {key}: {filter_value}")
            # The group itself is parenthesized as well because it may be
            # AND-combined with the conditions of sibling keys.
            conditions.append(
                "(" + f" {logical_operator} ".join(operand_conditions) + ")"
            )
            continue

        operator = "="
        sql_param = "?"

        # bool must be checked before int because bool is a subclass of int
        if isinstance(filter_value, bool):
            query_tuple.append("true" if filter_value else "false")
        elif isinstance(filter_value, (int, str)):
            query_tuple.append(filter_value)
        elif isinstance(filter_value, dict):
            # Handling of 'special' operators starting with "$".
            # Only the first entry of the operator dict is evaluated.
            special_op = next(iter(filter_value))
            special_val = filter_value[special_op]
            # "$eq", "$ne", "$lt", "$lte", "$gt", "$gte"
            if special_op in COMPARISONS_TO_SQL:
                operator = COMPARISONS_TO_SQL[special_op]
                if isinstance(special_val, bool):
                    # Evaluate the operand itself; the enclosing operator
                    # dict is never empty here and therefore always truthy.
                    query_tuple.append("true" if special_val else "false")
                elif isinstance(special_val, float):
                    sql_param = "CAST(? as float)"
                    query_tuple.append(special_val)
                else:
                    query_tuple.append(special_val)
            # "$between"
            elif special_op == BETWEEN_OPERATOR:
                operator = "BETWEEN"
                sql_param = "? AND ?"
                query_tuple.append(special_val[0])
                query_tuple.append(special_val[1])
            # "$like"
            elif special_op == LIKE_OPERATOR:
                operator = "LIKE"
                query_tuple.append(special_val)
            # "$in", "$nin"
            elif special_op in IN_OPERATORS_TO_SQL:
                operator = IN_OPERATORS_TO_SQL[special_op]
                # An empty list would leave a placeholder without a value
                if not isinstance(special_val, list) or not special_val:
                    raise ValueError(
                        f"Unsupported value for {operator}: {special_val}"
                    )
                sql_param = "(" + ",".join(["?"] * len(special_val)) + ")"
                query_tuple.extend(special_val)
            else:
                raise ValueError(f"Unsupported operator: {special_op}")
        else:
            raise ValueError(
                f"Unsupported filter data-type: {type(filter_value)}"
            )

        # NOTE(review): the key is interpolated into the SQL text (only the
        # values are bound parameters), so keys must come from a trusted
        # source - confirm with the callers.
        conditions.append(
            f"JSON_VALUE({self.metadata_column}, '$.{key}')"
            f" {operator} {sql_param}"
        )

    return " AND ".join(conditions), query_tuple
def delete( # type: ignore[override]

View File

@@ -2,7 +2,7 @@
import os
import random
from typing import List
from typing import Any, Dict, List
import numpy as np
import pytest
@@ -12,6 +12,23 @@ from langchain_community.vectorstores.utils import DistanceStrategy
from tests.integration_tests.vectorstores.fake_embeddings import (
ConsistentFakeEmbeddings,
)
from tests.integration_tests.vectorstores.fixtures.filtering_test_cases import (
DOCUMENTS,
TYPE_1_FILTERING_TEST_CASES,
TYPE_2_FILTERING_TEST_CASES,
TYPE_3_FILTERING_TEST_CASES,
TYPE_4_FILTERING_TEST_CASES,
TYPE_5_FILTERING_TEST_CASES,
)
TYPE_4B_FILTERING_TEST_CASES = [
# Test $nin, which is missing in TYPE_4_FILTERING_TEST_CASES.
# Each entry is a (filter, expected document ids) pair.
(
{"name": {"$nin": ["adam", "bob"]}},
[3],
),
]
try:
from hdbcli import dbapi
@@ -924,3 +941,156 @@ def test_hanavector_table_mixed_case_names(texts: List[str]) -> None:
# check results of similarity search
assert texts[0] == vectordb.similarity_search(texts[0], 1)[0].page_content
@pytest.mark.skipif(not hanadb_installed, reason="hanadb not installed")
def test_hanavector_enhanced_filter_1() -> None:
    """Create a fresh table and load the shared filtering documents into it.

    The test makes no assertions; it only checks that the setup steps run.
    """
    target_table = "TEST_TABLE_ENHANCED_FILTER_1"

    # Delete the table if it already exists
    drop_table(test_setup.conn, target_table)

    store = HanaDB(
        connection=test_setup.conn,
        embedding=embedding,
        table_name=target_table,
    )
    store.add_documents(DOCUMENTS)
def _assert_filter_returns_expected_ids(
    table_name: str,
    test_filter: Dict[str, Any],
    expected_ids: List[int],
) -> None:
    """Run a filtered similarity search on a fresh table and check the ids.

    The table is dropped, recreated through ``HanaDB`` and filled with
    ``DOCUMENTS`` before the search, so every call starts from the same data.

    Args:
        table_name: Name of the table to (re)create for this check.
        test_filter: Filter object passed to ``similarity_search``.
        expected_ids: Metadata ``id`` values of the documents that must match.
    """
    drop_table(test_setup.conn, table_name)

    vectorDB = HanaDB(
        connection=test_setup.conn,
        embedding=embedding,
        table_name=table_name,
    )
    vectorDB.add_documents(DOCUMENTS)

    docs = vectorDB.similarity_search("meow", k=5, filter=test_filter)
    ids = [doc.metadata["id"] for doc in docs]
    # Compare as sorted lists: a length check plus a subset check would also
    # accept a result that contains one expected id twice.
    assert sorted(ids) == sorted(expected_ids), test_filter


# NOTE(review): the test names below say "pgvector" although the tests run
# against HanaDB; the names are kept unchanged to preserve existing test ids.
@pytest.mark.parametrize("test_filter, expected_ids", TYPE_1_FILTERING_TEST_CASES)
@pytest.mark.skipif(not hanadb_installed, reason="hanadb not installed")
def test_pgvector_with_with_metadata_filters_1(
    test_filter: Dict[str, Any],
    expected_ids: List[int],
) -> None:
    """Check the filters of TYPE_1_FILTERING_TEST_CASES."""
    _assert_filter_returns_expected_ids(
        "TEST_TABLE_ENHANCED_FILTER_1", test_filter, expected_ids
    )


@pytest.mark.parametrize("test_filter, expected_ids", TYPE_2_FILTERING_TEST_CASES)
@pytest.mark.skipif(not hanadb_installed, reason="hanadb not installed")
def test_pgvector_with_with_metadata_filters_2(
    test_filter: Dict[str, Any],
    expected_ids: List[int],
) -> None:
    """Check the filters of TYPE_2_FILTERING_TEST_CASES."""
    _assert_filter_returns_expected_ids(
        "TEST_TABLE_ENHANCED_FILTER_2", test_filter, expected_ids
    )


@pytest.mark.parametrize("test_filter, expected_ids", TYPE_3_FILTERING_TEST_CASES)
@pytest.mark.skipif(not hanadb_installed, reason="hanadb not installed")
def test_pgvector_with_with_metadata_filters_3(
    test_filter: Dict[str, Any],
    expected_ids: List[int],
) -> None:
    """Check the filters of TYPE_3_FILTERING_TEST_CASES."""
    _assert_filter_returns_expected_ids(
        "TEST_TABLE_ENHANCED_FILTER_3", test_filter, expected_ids
    )


@pytest.mark.parametrize("test_filter, expected_ids", TYPE_4_FILTERING_TEST_CASES)
@pytest.mark.skipif(not hanadb_installed, reason="hanadb not installed")
def test_pgvector_with_with_metadata_filters_4(
    test_filter: Dict[str, Any],
    expected_ids: List[int],
) -> None:
    """Check the filters of TYPE_4_FILTERING_TEST_CASES."""
    _assert_filter_returns_expected_ids(
        "TEST_TABLE_ENHANCED_FILTER_4", test_filter, expected_ids
    )


@pytest.mark.parametrize("test_filter, expected_ids", TYPE_4B_FILTERING_TEST_CASES)
@pytest.mark.skipif(not hanadb_installed, reason="hanadb not installed")
def test_pgvector_with_with_metadata_filters_4b(
    test_filter: Dict[str, Any],
    expected_ids: List[int],
) -> None:
    """Check the filters of TYPE_4B_FILTERING_TEST_CASES."""
    _assert_filter_returns_expected_ids(
        "TEST_TABLE_ENHANCED_FILTER_4B", test_filter, expected_ids
    )


@pytest.mark.parametrize("test_filter, expected_ids", TYPE_5_FILTERING_TEST_CASES)
@pytest.mark.skipif(not hanadb_installed, reason="hanadb not installed")
def test_pgvector_with_with_metadata_filters_5(
    test_filter: Dict[str, Any],
    expected_ids: List[int],
) -> None:
    """Check the filters of TYPE_5_FILTERING_TEST_CASES."""
    _assert_filter_returns_expected_ids(
        "TEST_TABLE_ENHANCED_FILTER_5", test_filter, expected_ids
    )