community[patch]: Performant filter columns option for Hanavector (#21971)

**Description:** Backwards compatible extension of the initialisation
interface of HanaDB to allow the user to specify
specific_metadata_columns that are used for metadata storage of selected
keys which yields increased filter performance. Any not-mentioned
metadata remains in the general metadata column as part of a JSON
string. Furthermore switched to executemany for batch inserts into
HanaDB.

**Issue:** N/A

**Dependencies:** no new dependencies added

**Twitter handle:** @sapopensource

---------

Co-authored-by: Martin Kolb <martin.kolb@sap.com>
Co-authored-by: Bagatur <baskaryan@gmail.com>
This commit is contained in:
SaschaStoll
2024-05-22 22:21:21 +02:00
committed by GitHub
parent 16b55b0704
commit 709664a079
3 changed files with 551 additions and 49 deletions

View File

@@ -1,4 +1,5 @@
"""SAP HANA Cloud Vector Engine"""
from __future__ import annotations
import importlib.util
@@ -85,6 +86,8 @@ class HanaDB(VectorStore):
metadata_column: str = default_metadata_column,
vector_column: str = default_vector_column,
vector_column_length: int = default_vector_column_length,
*,
specific_metadata_columns: Optional[List[str]] = None,
):
# Check if the hdbcli package is installed
if importlib.util.find_spec("hdbcli") is None:
@@ -110,6 +113,9 @@ class HanaDB(VectorStore):
self.metadata_column = HanaDB._sanitize_name(metadata_column)
self.vector_column = HanaDB._sanitize_name(vector_column)
self.vector_column_length = HanaDB._sanitize_int(vector_column_length)
self.specific_metadata_columns = HanaDB._sanitize_specific_metadata_columns(
specific_metadata_columns or []
)
# Check if the table exists, and eventually create it
if not self._table_exists(self.table_name):
@@ -139,6 +145,8 @@ class HanaDB(VectorStore):
["REAL_VECTOR"],
self.vector_column_length,
)
for column_name in self.specific_metadata_columns:
self._check_column(self.table_name, column_name)
def _table_exists(self, table_name) -> bool: # type: ignore[no-untyped-def]
sql_str = (
@@ -156,7 +164,9 @@ class HanaDB(VectorStore):
cur.close()
return False
def _check_column(self, table_name, column_name, column_type, column_length=None): # type: ignore[no-untyped-def]
def _check_column( # type: ignore[no-untyped-def]
self, table_name, column_name, column_type=None, column_length=None
):
sql_str = (
"SELECT DATA_TYPE_NAME, LENGTH FROM SYS.TABLE_COLUMNS WHERE "
"SCHEMA_NAME = CURRENT_SCHEMA "
@@ -170,10 +180,11 @@ class HanaDB(VectorStore):
if len(rows) == 0:
raise AttributeError(f"Column {column_name} does not exist")
# Check data type
if rows[0][0] not in column_type:
raise AttributeError(
f"Column {column_name} has the wrong type: {rows[0][0]}"
)
if column_type:
if rows[0][0] not in column_type:
raise AttributeError(
f"Column {column_name} has the wrong type: {rows[0][0]}"
)
# Check length, if parameter was provided
if column_length is not None:
if rows[0][1] != column_length:
@@ -189,17 +200,20 @@ class HanaDB(VectorStore):
def embeddings(self) -> Embeddings:
return self.embedding
@staticmethod
def _sanitize_name(input_str: str) -> str: # type: ignore[misc]
# Remove characters that are not alphanumeric or underscores
return re.sub(r"[^a-zA-Z0-9_]", "", input_str)
@staticmethod
def _sanitize_int(input_int: any) -> int: # type: ignore[valid-type]
value = int(str(input_int))
if value < -1:
raise ValueError(f"Value ({value}) must not be smaller than -1")
return int(str(input_int))
def _sanitize_list_float(embedding: List[float]) -> List[float]: # type: ignore[misc]
@staticmethod
def _sanitize_list_float(embedding: List[float]) -> List[float]:
for value in embedding:
if not isinstance(value, float):
raise ValueError(f"Value ({value}) does not have type float")
@@ -208,13 +222,36 @@ class HanaDB(VectorStore):
# Compile pattern only once, for better performance
_compiled_pattern = re.compile("^[_a-zA-Z][_a-zA-Z0-9]*$")
def _sanitize_metadata_keys(metadata: dict) -> dict: # type: ignore[misc]
@staticmethod
def _sanitize_metadata_keys(metadata: dict) -> dict:
for key in metadata.keys():
if not HanaDB._compiled_pattern.match(key):
raise ValueError(f"Invalid metadata key {key}")
return metadata
@staticmethod
def _sanitize_specific_metadata_columns(
specific_metadata_columns: List[str],
) -> List[str]:
metadata_columns = []
for c in specific_metadata_columns:
sanitized_name = HanaDB._sanitize_name(c)
metadata_columns.append(sanitized_name)
return metadata_columns
def _split_off_special_metadata(self, metadata: dict) -> Tuple[dict, list]:
# Use provided values by default or fallback
special_metadata = []
if not metadata:
return {}, []
for column_name in self.specific_metadata_columns:
special_metadata.append(metadata.get(column_name, None))
return metadata, special_metadata
def add_texts( # type: ignore[override]
self,
texts: Iterable[str],
@@ -238,30 +275,45 @@ class HanaDB(VectorStore):
if embeddings is None:
embeddings = self.embedding.embed_documents(list(texts))
# Create sql parameters array
sql_params = []
for i, text in enumerate(texts):
metadata = metadatas[i] if metadatas else {}
metadata, extracted_special_metadata = self._split_off_special_metadata(
metadata
)
embedding = (
embeddings[i]
if embeddings
else self.embedding.embed_documents([text])[0]
)
sql_params.append(
(
text,
json.dumps(HanaDB._sanitize_metadata_keys(metadata)),
f"[{','.join(map(str, embedding))}]",
*extracted_special_metadata,
)
)
# Insert data into the table
cur = self.connection.cursor()
try:
# Insert data into the table
for i, text in enumerate(texts):
# Use provided values by default or fallback
metadata = metadatas[i] if metadatas else {}
embedding = (
embeddings[i]
if embeddings
else self.embedding.embed_documents([text])[0]
)
sql_str = (
f'INSERT INTO "{self.table_name}" ("{self.content_column}", '
f'"{self.metadata_column}", "{self.vector_column}") '
f"VALUES (?, ?, TO_REAL_VECTOR (?));"
)
cur.execute(
sql_str,
(
text,
json.dumps(HanaDB._sanitize_metadata_keys(metadata)),
f"[{','.join(map(str, embedding))}]",
),
specific_metadata_columns_string = '", "'.join(
self.specific_metadata_columns
)
if specific_metadata_columns_string:
specific_metadata_columns_string = (
', "' + specific_metadata_columns_string + '"'
)
sql_str = (
f'INSERT INTO "{self.table_name}" ("{self.content_column}", '
f'"{self.metadata_column}", '
f'"{self.vector_column}"{specific_metadata_columns_string}) '
f"VALUES (?, ?, TO_REAL_VECTOR (?)"
f"{', ?' * len(self.specific_metadata_columns)});"
)
cur.executemany(sql_str, sql_params)
finally:
cur.close()
return []
@@ -279,6 +331,8 @@ class HanaDB(VectorStore):
metadata_column: str = default_metadata_column,
vector_column: str = default_vector_column,
vector_column_length: int = default_vector_column_length,
*,
specific_metadata_columns: Optional[List[str]] = None,
):
"""Create a HanaDB instance from raw documents.
This is a user-friendly interface that:
@@ -297,6 +351,7 @@ class HanaDB(VectorStore):
metadata_column=metadata_column,
vector_column=vector_column,
vector_column_length=vector_column_length, # -1 means dynamic length
specific_metadata_columns=specific_metadata_columns,
)
instance.add_texts(texts, metadatas)
return instance
@@ -514,10 +569,12 @@ class HanaDB(VectorStore):
f"Unsupported filter data-type: {type(filter_value)}"
)
where_str += (
f" JSON_VALUE({self.metadata_column}, '$.{key}')"
f" {operator} {sql_param}"
selector = (
f' "{key}"'
if key in self.specific_metadata_columns
else f"JSON_VALUE({self.metadata_column}, '$.{key}')"
)
where_str += f"{selector} " f"{operator} {sql_param}"
return where_str, query_tuple

View File

@@ -65,6 +65,7 @@ test_setup = ConfigData()
def generateSchemaName(cursor): # type: ignore[no-untyped-def]
# return "Langchain"
cursor.execute(
"SELECT REPLACE(CURRENT_UTCDATE, '-', '') || '_' || BINTOHEX(SYSUUID) FROM "
"DUMMY;"
@@ -85,6 +86,7 @@ def setup_module(module): # type: ignore[no-untyped-def]
password=os.environ.get("HANA_DB_PASSWORD"),
autocommit=True,
sslValidateCertificate=False,
# encrypt=True
)
try:
cur = test_setup.conn.cursor()
@@ -100,6 +102,7 @@ def setup_module(module): # type: ignore[no-untyped-def]
def teardown_module(module): # type: ignore[no-untyped-def]
# return
try:
cur = test_setup.conn.cursor()
sql_str = f"DROP SCHEMA {test_setup.schema_name} CASCADE"
@@ -112,7 +115,7 @@ def teardown_module(module): # type: ignore[no-untyped-def]
@pytest.fixture
def texts() -> List[str]:
return ["foo", "bar", "baz"]
return ["foo", "bar", "baz", "bak", "cat"]
@pytest.fixture
@@ -121,6 +124,8 @@ def metadatas() -> List[str]:
{"start": 0, "end": 100, "quality": "good", "ready": True}, # type: ignore[list-item]
{"start": 100, "end": 200, "quality": "bad", "ready": False}, # type: ignore[list-item]
{"start": 200, "end": 300, "quality": "ugly", "ready": True}, # type: ignore[list-item]
{"start": 200, "quality": "ugly", "ready": True, "Owner": "Steve"}, # type: ignore[list-item]
{"start": 300, "quality": "ugly", "Owner": "Steve"}, # type: ignore[list-item]
]
@@ -640,14 +645,14 @@ def test_hanavector_delete_with_filter(texts: List[str], metadatas: List[dict])
table_name=table_name,
)
search_result = vectorDB.similarity_search(texts[0], 3)
assert len(search_result) == 3
search_result = vectorDB.similarity_search(texts[0], 10)
assert len(search_result) == 5
# Delete one of the three entries
assert vectorDB.delete(filter={"start": 100, "end": 200})
search_result = vectorDB.similarity_search(texts[0], 3)
assert len(search_result) == 2
search_result = vectorDB.similarity_search(texts[0], 10)
assert len(search_result) == 4
@pytest.mark.skipif(not hanadb_installed, reason="hanadb not installed")
@@ -667,14 +672,14 @@ async def test_hanavector_delete_with_filter_async(
table_name=table_name,
)
search_result = vectorDB.similarity_search(texts[0], 3)
assert len(search_result) == 3
search_result = vectorDB.similarity_search(texts[0], 10)
assert len(search_result) == 5
# Delete one of the three entries
assert await vectorDB.adelete(filter={"start": 100, "end": 200})
search_result = vectorDB.similarity_search(texts[0], 3)
assert len(search_result) == 2
search_result = vectorDB.similarity_search(texts[0], 10)
assert len(search_result) == 4
@pytest.mark.skipif(not hanadb_installed, reason="hanadb not installed")
@@ -861,7 +866,7 @@ def test_hanavector_filter_prepared_statement_params(
sql_str = f"SELECT * FROM {table_name} WHERE JSON_VALUE(VEC_META, '$.ready') = ?"
cur.execute(sql_str, (query_value))
rows = cur.fetchall()
assert len(rows) == 2
assert len(rows) == 3
# query_value = False
query_value = "false" # type: ignore[assignment]
@@ -1094,3 +1099,336 @@ def test_pgvector_with_with_metadata_filters_5(
ids = [doc.metadata["id"] for doc in docs]
assert len(ids) == len(expected_ids), test_filter
assert set(ids).issubset(expected_ids), test_filter
@pytest.mark.skipif(not hanadb_installed, reason="hanadb not installed")
def test_preexisting_specific_columns_for_metadata_fill(
texts: List[str], metadatas: List[dict]
) -> None:
table_name = "PREEXISTING_FILTER_COLUMNS"
# drop_table(test_setup.conn, table_name)
sql_str = (
f'CREATE TABLE "{table_name}" ('
f'"VEC_TEXT" NCLOB, '
f'"VEC_META" NCLOB, '
f'"VEC_VECTOR" REAL_VECTOR, '
f'"Owner" NVARCHAR(100), '
f'"quality" NVARCHAR(100));'
)
try:
cur = test_setup.conn.cursor()
cur.execute(sql_str)
finally:
cur.close()
vectorDB = HanaDB.from_texts(
connection=test_setup.conn,
texts=texts,
metadatas=metadatas,
embedding=embedding,
table_name=table_name,
specific_metadata_columns=["Owner", "quality"],
)
c = 0
try:
sql_str = f'SELECT COUNT(*) FROM {table_name} WHERE "quality"=' f"'ugly'"
cur = test_setup.conn.cursor()
cur.execute(sql_str)
if cur.has_result_set():
rows = cur.fetchall()
c = rows[0][0]
finally:
cur.close()
assert c == 3
docs = vectorDB.similarity_search("hello", k=5, filter={"quality": "good"})
assert len(docs) == 1
assert docs[0].page_content == "foo"
docs = vectorDB.similarity_search("hello", k=5, filter={"start": 100})
assert len(docs) == 1
assert docs[0].page_content == "bar"
docs = vectorDB.similarity_search(
"hello", k=5, filter={"start": 100, "quality": "good"}
)
assert len(docs) == 0
docs = vectorDB.similarity_search(
"hello", k=5, filter={"start": 0, "quality": "good"}
)
assert len(docs) == 1
assert docs[0].page_content == "foo"
@pytest.mark.skipif(not hanadb_installed, reason="hanadb not installed")
def test_preexisting_specific_columns_for_metadata_via_array(
texts: List[str], metadatas: List[dict]
) -> None:
table_name = "PREEXISTING_FILTER_COLUMNS_VIA_ARRAY"
# drop_table(test_setup.conn, table_name)
sql_str = (
f'CREATE TABLE "{table_name}" ('
f'"VEC_TEXT" NCLOB, '
f'"VEC_META" NCLOB, '
f'"VEC_VECTOR" REAL_VECTOR, '
f'"Owner" NVARCHAR(100), '
f'"quality" NVARCHAR(100));'
)
try:
cur = test_setup.conn.cursor()
cur.execute(sql_str)
finally:
cur.close()
vectorDB = HanaDB.from_texts(
connection=test_setup.conn,
texts=texts,
metadatas=metadatas,
embedding=embedding,
table_name=table_name,
specific_metadata_columns=["quality"],
)
c = 0
try:
sql_str = f'SELECT COUNT(*) FROM {table_name} WHERE "quality"=' f"'ugly'"
cur = test_setup.conn.cursor()
cur.execute(sql_str)
if cur.has_result_set():
rows = cur.fetchall()
c = rows[0][0]
finally:
cur.close()
assert c == 3
try:
sql_str = f'SELECT COUNT(*) FROM {table_name} WHERE "Owner"=' f"'Steve'"
cur = test_setup.conn.cursor()
cur.execute(sql_str)
if cur.has_result_set():
rows = cur.fetchall()
c = rows[0][0]
finally:
cur.close()
assert c == 0
docs = vectorDB.similarity_search("hello", k=5, filter={"quality": "good"})
assert len(docs) == 1
assert docs[0].page_content == "foo"
docs = vectorDB.similarity_search("hello", k=5, filter={"start": 100})
assert len(docs) == 1
assert docs[0].page_content == "bar"
docs = vectorDB.similarity_search(
"hello", k=5, filter={"start": 100, "quality": "good"}
)
assert len(docs) == 0
docs = vectorDB.similarity_search(
"hello", k=5, filter={"start": 0, "quality": "good"}
)
assert len(docs) == 1
assert docs[0].page_content == "foo"
@pytest.mark.skipif(not hanadb_installed, reason="hanadb not installed")
def test_preexisting_specific_columns_for_metadata_multiple_columns(
texts: List[str], metadatas: List[dict]
) -> None:
table_name = "PREEXISTING_FILTER_MULTIPLE_COLUMNS"
# drop_table(test_setup.conn, table_name)
sql_str = (
f'CREATE TABLE "{table_name}" ('
f'"VEC_TEXT" NCLOB, '
f'"VEC_META" NCLOB, '
f'"VEC_VECTOR" REAL_VECTOR, '
f'"quality" NVARCHAR(100), '
f'"start" INTEGER);'
)
try:
cur = test_setup.conn.cursor()
cur.execute(sql_str)
finally:
cur.close()
vectorDB = HanaDB.from_texts(
connection=test_setup.conn,
texts=texts,
metadatas=metadatas,
embedding=embedding,
table_name=table_name,
specific_metadata_columns=["quality", "start"],
)
docs = vectorDB.similarity_search("hello", k=5, filter={"quality": "good"})
assert len(docs) == 1
assert docs[0].page_content == "foo"
docs = vectorDB.similarity_search("hello", k=5, filter={"start": 100})
assert len(docs) == 1
assert docs[0].page_content == "bar"
docs = vectorDB.similarity_search(
"hello", k=5, filter={"start": 100, "quality": "good"}
)
assert len(docs) == 0
docs = vectorDB.similarity_search(
"hello", k=5, filter={"start": 0, "quality": "good"}
)
assert len(docs) == 1
assert docs[0].page_content == "foo"
@pytest.mark.skipif(not hanadb_installed, reason="hanadb not installed")
def test_preexisting_specific_columns_for_metadata_empty_columns(
texts: List[str], metadatas: List[dict]
) -> None:
table_name = "PREEXISTING_FILTER_MULTIPLE_COLUMNS_EMPTY"
# drop_table(test_setup.conn, table_name)
sql_str = (
f'CREATE TABLE "{table_name}" ('
f'"VEC_TEXT" NCLOB, '
f'"VEC_META" NCLOB, '
f'"VEC_VECTOR" REAL_VECTOR, '
f'"quality" NVARCHAR(100), '
f'"ready" BOOLEAN, '
f'"start" INTEGER);'
)
try:
cur = test_setup.conn.cursor()
cur.execute(sql_str)
finally:
cur.close()
vectorDB = HanaDB.from_texts(
connection=test_setup.conn,
texts=texts,
metadatas=metadatas,
embedding=embedding,
table_name=table_name,
specific_metadata_columns=["quality", "ready", "start"],
)
docs = vectorDB.similarity_search("hello", k=5, filter={"quality": "good"})
assert len(docs) == 1
assert docs[0].page_content == "foo"
docs = vectorDB.similarity_search("hello", k=5, filter={"start": 100})
assert len(docs) == 1
assert docs[0].page_content == "bar"
docs = vectorDB.similarity_search(
"hello", k=5, filter={"start": 100, "quality": "good"}
)
assert len(docs) == 0
docs = vectorDB.similarity_search(
"hello", k=5, filter={"start": 0, "quality": "good"}
)
assert len(docs) == 1
assert docs[0].page_content == "foo"
docs = vectorDB.similarity_search("hello", k=5, filter={"ready": True})
assert len(docs) == 3
@pytest.mark.skipif(not hanadb_installed, reason="hanadb not installed")
def test_preexisting_specific_columns_for_metadata_wrong_type_or_non_existing(
texts: List[str], metadatas: List[dict]
) -> None:
table_name = "PREEXISTING_FILTER_COLUMNS_WRONG_TYPE"
# drop_table(test_setup.conn, table_name)
sql_str = (
f'CREATE TABLE "{table_name}" ('
f'"VEC_TEXT" NCLOB, '
f'"VEC_META" NCLOB, '
f'"VEC_VECTOR" REAL_VECTOR, '
f'"quality" INTEGER); '
)
try:
cur = test_setup.conn.cursor()
cur.execute(sql_str)
finally:
cur.close()
# Check if table is created
exception_occured = False
try:
HanaDB.from_texts(
connection=test_setup.conn,
texts=texts,
metadatas=metadatas,
embedding=embedding,
table_name=table_name,
specific_metadata_columns=["quality"],
)
exception_occured = False
except dbapi.Error: # Nothing we should do here, hdbcli will throw an error
exception_occured = True
assert exception_occured # Check if table is created
exception_occured = False
try:
HanaDB.from_texts(
connection=test_setup.conn,
texts=texts,
metadatas=metadatas,
embedding=embedding,
table_name=table_name,
specific_metadata_columns=["NonExistingColumn"],
)
exception_occured = False
except AttributeError: # Nothing we should do here, hdbcli will throw an error
exception_occured = True
assert exception_occured
@pytest.mark.skipif(not hanadb_installed, reason="hanadb not installed")
def test_preexisting_specific_columns_for_returned_metadata_completeness(
texts: List[str], metadatas: List[dict]
) -> None:
table_name = "PREEXISTING_FILTER_COLUMNS_METADATA_COMPLETENESS"
# drop_table(test_setup.conn, table_name)
sql_str = (
f'CREATE TABLE "{table_name}" ('
f'"VEC_TEXT" NCLOB, '
f'"VEC_META" NCLOB, '
f'"VEC_VECTOR" REAL_VECTOR, '
f'"quality" NVARCHAR(100), '
f'"NonExisting" NVARCHAR(100), '
f'"ready" BOOLEAN, '
f'"start" INTEGER);'
)
try:
cur = test_setup.conn.cursor()
cur.execute(sql_str)
finally:
cur.close()
vectorDB = HanaDB.from_texts(
connection=test_setup.conn,
texts=texts,
metadatas=metadatas,
embedding=embedding,
table_name=table_name,
specific_metadata_columns=["quality", "ready", "start", "NonExisting"],
)
docs = vectorDB.similarity_search("hello", k=5, filter={"quality": "good"})
assert len(docs) == 1
assert docs[0].page_content == "foo"
assert docs[0].metadata["end"] == 100
assert docs[0].metadata["start"] == 0
assert docs[0].metadata["quality"] == "good"
assert docs[0].metadata["ready"]
assert "NonExisting" not in docs[0].metadata.keys()