Mirror of https://github.com/hwchase17/langchain.git, synced 2025-09-09 15:03:21 +00:00
community[patch]: Performant filter columns option for Hanavector (#21971)
**Description:** Backwards-compatible extension of the HanaDB initialisation interface that lets the user specify `specific_metadata_columns`, which store selected metadata keys in dedicated columns and thereby yield better filter performance. Any metadata key not listed there remains in the general metadata column as part of a JSON string. In addition, batch inserts into HanaDB now use `executemany`.

**Issue:** N/A

**Dependencies:** no new dependencies added

**Twitter handle:** @sapopensource

---------

Co-authored-by: Martin Kolb <martin.kolb@sap.com>
Co-authored-by: Bagatur <baskaryan@gmail.com>
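For context, a minimal usage sketch of the new option (the connection details, table name, texts, and embedding model below are illustrative placeholders, not part of this commit): keys listed in `specific_metadata_columns` are written to dedicated columns at insert time and are addressed directly in filter WHERE clauses, while all other metadata stays in the JSON metadata column.

```python
from hdbcli import dbapi
from langchain_community.embeddings import HuggingFaceEmbeddings  # any Embeddings implementation works
from langchain_community.vectorstores.hanavector import HanaDB

# Placeholder connection; replace with real SAP HANA Cloud credentials.
connection = dbapi.connect(
    address="<hana-host>", port=443, user="<user>", password="<password>", autocommit=True
)

db = HanaDB.from_texts(
    connection=connection,
    texts=["foo", "bar"],
    metadatas=[
        {"quality": "good", "start": 0},
        {"quality": "bad", "start": 100},
    ],
    embedding=HuggingFaceEmbeddings(),
    table_name="LC_DEMO",  # illustrative table name
    # Keys listed here get their own columns; everything else stays in the JSON column.
    specific_metadata_columns=["quality"],
)

# The "quality" filter now resolves to the dedicated column rather than JSON_VALUE().
docs = db.similarity_search("hello", k=2, filter={"quality": "good"})
```

The integration tests added below exercise the same pattern against pre-created tables that already contain dedicated `"Owner"`, `"quality"`, `"ready"`, and `"start"` columns.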
@@ -1,4 +1,5 @@
"""SAP HANA Cloud Vector Engine"""

from __future__ import annotations

import importlib.util
@@ -85,6 +86,8 @@ class HanaDB(VectorStore):
        metadata_column: str = default_metadata_column,
        vector_column: str = default_vector_column,
        vector_column_length: int = default_vector_column_length,
        *,
        specific_metadata_columns: Optional[List[str]] = None,
    ):
        # Check if the hdbcli package is installed
        if importlib.util.find_spec("hdbcli") is None:
@@ -110,6 +113,9 @@ class HanaDB(VectorStore):
        self.metadata_column = HanaDB._sanitize_name(metadata_column)
        self.vector_column = HanaDB._sanitize_name(vector_column)
        self.vector_column_length = HanaDB._sanitize_int(vector_column_length)
        self.specific_metadata_columns = HanaDB._sanitize_specific_metadata_columns(
            specific_metadata_columns or []
        )

        # Check if the table exists, and eventually create it
        if not self._table_exists(self.table_name):
@@ -139,6 +145,8 @@ class HanaDB(VectorStore):
            ["REAL_VECTOR"],
            self.vector_column_length,
        )
        for column_name in self.specific_metadata_columns:
            self._check_column(self.table_name, column_name)

    def _table_exists(self, table_name) -> bool:  # type: ignore[no-untyped-def]
        sql_str = (
@@ -156,7 +164,9 @@ class HanaDB(VectorStore):
            cur.close()
        return False

    def _check_column(self, table_name, column_name, column_type, column_length=None):  # type: ignore[no-untyped-def]
    def _check_column(  # type: ignore[no-untyped-def]
        self, table_name, column_name, column_type=None, column_length=None
    ):
        sql_str = (
            "SELECT DATA_TYPE_NAME, LENGTH FROM SYS.TABLE_COLUMNS WHERE "
            "SCHEMA_NAME = CURRENT_SCHEMA "
@@ -170,10 +180,11 @@ class HanaDB(VectorStore):
                if len(rows) == 0:
                    raise AttributeError(f"Column {column_name} does not exist")
                # Check data type
                if rows[0][0] not in column_type:
                    raise AttributeError(
                        f"Column {column_name} has the wrong type: {rows[0][0]}"
                    )
                if column_type:
                    if rows[0][0] not in column_type:
                        raise AttributeError(
                            f"Column {column_name} has the wrong type: {rows[0][0]}"
                        )
                # Check length, if parameter was provided
                if column_length is not None:
                    if rows[0][1] != column_length:
@@ -189,17 +200,20 @@ class HanaDB(VectorStore):
    def embeddings(self) -> Embeddings:
        return self.embedding

    @staticmethod
    def _sanitize_name(input_str: str) -> str:  # type: ignore[misc]
        # Remove characters that are not alphanumeric or underscores
        return re.sub(r"[^a-zA-Z0-9_]", "", input_str)

    @staticmethod
    def _sanitize_int(input_int: any) -> int:  # type: ignore[valid-type]
        value = int(str(input_int))
        if value < -1:
            raise ValueError(f"Value ({value}) must not be smaller than -1")
        return int(str(input_int))

    def _sanitize_list_float(embedding: List[float]) -> List[float]:  # type: ignore[misc]
    @staticmethod
    def _sanitize_list_float(embedding: List[float]) -> List[float]:
        for value in embedding:
            if not isinstance(value, float):
                raise ValueError(f"Value ({value}) does not have type float")
@@ -208,13 +222,36 @@ class HanaDB(VectorStore):
    # Compile pattern only once, for better performance
    _compiled_pattern = re.compile("^[_a-zA-Z][_a-zA-Z0-9]*$")

    def _sanitize_metadata_keys(metadata: dict) -> dict:  # type: ignore[misc]
    @staticmethod
    def _sanitize_metadata_keys(metadata: dict) -> dict:
        for key in metadata.keys():
            if not HanaDB._compiled_pattern.match(key):
                raise ValueError(f"Invalid metadata key {key}")

        return metadata

    @staticmethod
    def _sanitize_specific_metadata_columns(
        specific_metadata_columns: List[str],
    ) -> List[str]:
        metadata_columns = []
        for c in specific_metadata_columns:
            sanitized_name = HanaDB._sanitize_name(c)
            metadata_columns.append(sanitized_name)
        return metadata_columns

    def _split_off_special_metadata(self, metadata: dict) -> Tuple[dict, list]:
        # Use provided values by default or fallback
        special_metadata = []

        if not metadata:
            return {}, []

        for column_name in self.specific_metadata_columns:
            special_metadata.append(metadata.get(column_name, None))

        return metadata, special_metadata

    def add_texts(  # type: ignore[override]
        self,
        texts: Iterable[str],
@@ -238,30 +275,45 @@ class HanaDB(VectorStore):
        if embeddings is None:
            embeddings = self.embedding.embed_documents(list(texts))

        # Create sql parameters array
        sql_params = []
        for i, text in enumerate(texts):
            metadata = metadatas[i] if metadatas else {}
            metadata, extracted_special_metadata = self._split_off_special_metadata(
                metadata
            )
            embedding = (
                embeddings[i]
                if embeddings
                else self.embedding.embed_documents([text])[0]
            )
            sql_params.append(
                (
                    text,
                    json.dumps(HanaDB._sanitize_metadata_keys(metadata)),
                    f"[{','.join(map(str, embedding))}]",
                    *extracted_special_metadata,
                )
            )

        # Insert data into the table
        cur = self.connection.cursor()
        try:
            # Insert data into the table
            for i, text in enumerate(texts):
                # Use provided values by default or fallback
                metadata = metadatas[i] if metadatas else {}
                embedding = (
                    embeddings[i]
                    if embeddings
                    else self.embedding.embed_documents([text])[0]
                )
                sql_str = (
                    f'INSERT INTO "{self.table_name}" ("{self.content_column}", '
                    f'"{self.metadata_column}", "{self.vector_column}") '
                    f"VALUES (?, ?, TO_REAL_VECTOR (?));"
                )
                cur.execute(
                    sql_str,
                    (
                        text,
                        json.dumps(HanaDB._sanitize_metadata_keys(metadata)),
                        f"[{','.join(map(str, embedding))}]",
                    ),
            specific_metadata_columns_string = '", "'.join(
                self.specific_metadata_columns
            )
            if specific_metadata_columns_string:
                specific_metadata_columns_string = (
                    ', "' + specific_metadata_columns_string + '"'
                )
            sql_str = (
                f'INSERT INTO "{self.table_name}" ("{self.content_column}", '
                f'"{self.metadata_column}", '
                f'"{self.vector_column}"{specific_metadata_columns_string}) '
                f"VALUES (?, ?, TO_REAL_VECTOR (?)"
                f"{', ?' * len(self.specific_metadata_columns)});"
            )
            cur.executemany(sql_str, sql_params)
        finally:
            cur.close()
        return []
@@ -279,6 +331,8 @@ class HanaDB(VectorStore):
        metadata_column: str = default_metadata_column,
        vector_column: str = default_vector_column,
        vector_column_length: int = default_vector_column_length,
        *,
        specific_metadata_columns: Optional[List[str]] = None,
    ):
        """Create a HanaDB instance from raw documents.
        This is a user-friendly interface that:
@@ -297,6 +351,7 @@ class HanaDB(VectorStore):
            metadata_column=metadata_column,
            vector_column=vector_column,
            vector_column_length=vector_column_length,  # -1 means dynamic length
            specific_metadata_columns=specific_metadata_columns,
        )
        instance.add_texts(texts, metadatas)
        return instance
@@ -514,10 +569,12 @@ class HanaDB(VectorStore):
                    f"Unsupported filter data-type: {type(filter_value)}"
                )

            where_str += (
                f" JSON_VALUE({self.metadata_column}, '$.{key}')"
                f" {operator} {sql_param}"
            selector = (
                f' "{key}"'
                if key in self.specific_metadata_columns
                else f"JSON_VALUE({self.metadata_column}, '$.{key}')"
            )
            where_str += f"{selector} " f"{operator} {sql_param}"

        return where_str, query_tuple
@@ -65,6 +65,7 @@ test_setup = ConfigData()


def generateSchemaName(cursor):  # type: ignore[no-untyped-def]
    # return "Langchain"
    cursor.execute(
        "SELECT REPLACE(CURRENT_UTCDATE, '-', '') || '_' || BINTOHEX(SYSUUID) FROM "
        "DUMMY;"
@@ -85,6 +86,7 @@ def setup_module(module):  # type: ignore[no-untyped-def]
        password=os.environ.get("HANA_DB_PASSWORD"),
        autocommit=True,
        sslValidateCertificate=False,
        # encrypt=True
    )
    try:
        cur = test_setup.conn.cursor()
@@ -100,6 +102,7 @@ def setup_module(module):  # type: ignore[no-untyped-def]


def teardown_module(module):  # type: ignore[no-untyped-def]
    # return
    try:
        cur = test_setup.conn.cursor()
        sql_str = f"DROP SCHEMA {test_setup.schema_name} CASCADE"
@@ -112,7 +115,7 @@ def teardown_module(module):  # type: ignore[no-untyped-def]

@pytest.fixture
def texts() -> List[str]:
    return ["foo", "bar", "baz"]
    return ["foo", "bar", "baz", "bak", "cat"]


@pytest.fixture
@@ -121,6 +124,8 @@ def metadatas() -> List[str]:
        {"start": 0, "end": 100, "quality": "good", "ready": True},  # type: ignore[list-item]
        {"start": 100, "end": 200, "quality": "bad", "ready": False},  # type: ignore[list-item]
        {"start": 200, "end": 300, "quality": "ugly", "ready": True},  # type: ignore[list-item]
        {"start": 200, "quality": "ugly", "ready": True, "Owner": "Steve"},  # type: ignore[list-item]
        {"start": 300, "quality": "ugly", "Owner": "Steve"},  # type: ignore[list-item]
    ]

@@ -640,14 +645,14 @@ def test_hanavector_delete_with_filter(texts: List[str], metadatas: List[dict])
        table_name=table_name,
    )

    search_result = vectorDB.similarity_search(texts[0], 3)
    assert len(search_result) == 3
    search_result = vectorDB.similarity_search(texts[0], 10)
    assert len(search_result) == 5

    # Delete one of the three entries
    assert vectorDB.delete(filter={"start": 100, "end": 200})

    search_result = vectorDB.similarity_search(texts[0], 3)
    assert len(search_result) == 2
    search_result = vectorDB.similarity_search(texts[0], 10)
    assert len(search_result) == 4


@pytest.mark.skipif(not hanadb_installed, reason="hanadb not installed")
@@ -667,14 +672,14 @@ async def test_hanavector_delete_with_filter_async(
        table_name=table_name,
    )

    search_result = vectorDB.similarity_search(texts[0], 3)
    assert len(search_result) == 3
    search_result = vectorDB.similarity_search(texts[0], 10)
    assert len(search_result) == 5

    # Delete one of the three entries
    assert await vectorDB.adelete(filter={"start": 100, "end": 200})

    search_result = vectorDB.similarity_search(texts[0], 3)
    assert len(search_result) == 2
    search_result = vectorDB.similarity_search(texts[0], 10)
    assert len(search_result) == 4


@pytest.mark.skipif(not hanadb_installed, reason="hanadb not installed")
@@ -861,7 +866,7 @@ def test_hanavector_filter_prepared_statement_params(
    sql_str = f"SELECT * FROM {table_name} WHERE JSON_VALUE(VEC_META, '$.ready') = ?"
    cur.execute(sql_str, (query_value))
    rows = cur.fetchall()
    assert len(rows) == 2
    assert len(rows) == 3

    # query_value = False
    query_value = "false"  # type: ignore[assignment]
@@ -1094,3 +1099,336 @@ def test_pgvector_with_with_metadata_filters_5(
    ids = [doc.metadata["id"] for doc in docs]
    assert len(ids) == len(expected_ids), test_filter
    assert set(ids).issubset(expected_ids), test_filter


@pytest.mark.skipif(not hanadb_installed, reason="hanadb not installed")
def test_preexisting_specific_columns_for_metadata_fill(
    texts: List[str], metadatas: List[dict]
) -> None:
    table_name = "PREEXISTING_FILTER_COLUMNS"
    # drop_table(test_setup.conn, table_name)

    sql_str = (
        f'CREATE TABLE "{table_name}" ('
        f'"VEC_TEXT" NCLOB, '
        f'"VEC_META" NCLOB, '
        f'"VEC_VECTOR" REAL_VECTOR, '
        f'"Owner" NVARCHAR(100), '
        f'"quality" NVARCHAR(100));'
    )
    try:
        cur = test_setup.conn.cursor()
        cur.execute(sql_str)
    finally:
        cur.close()

    vectorDB = HanaDB.from_texts(
        connection=test_setup.conn,
        texts=texts,
        metadatas=metadatas,
        embedding=embedding,
        table_name=table_name,
        specific_metadata_columns=["Owner", "quality"],
    )

    c = 0
    try:
        sql_str = f'SELECT COUNT(*) FROM {table_name} WHERE "quality"=' f"'ugly'"
        cur = test_setup.conn.cursor()
        cur.execute(sql_str)
        if cur.has_result_set():
            rows = cur.fetchall()
            c = rows[0][0]
    finally:
        cur.close()
    assert c == 3

    docs = vectorDB.similarity_search("hello", k=5, filter={"quality": "good"})
    assert len(docs) == 1
    assert docs[0].page_content == "foo"

    docs = vectorDB.similarity_search("hello", k=5, filter={"start": 100})
    assert len(docs) == 1
    assert docs[0].page_content == "bar"

    docs = vectorDB.similarity_search(
        "hello", k=5, filter={"start": 100, "quality": "good"}
    )
    assert len(docs) == 0

    docs = vectorDB.similarity_search(
        "hello", k=5, filter={"start": 0, "quality": "good"}
    )
    assert len(docs) == 1
    assert docs[0].page_content == "foo"


@pytest.mark.skipif(not hanadb_installed, reason="hanadb not installed")
def test_preexisting_specific_columns_for_metadata_via_array(
    texts: List[str], metadatas: List[dict]
) -> None:
    table_name = "PREEXISTING_FILTER_COLUMNS_VIA_ARRAY"
    # drop_table(test_setup.conn, table_name)

    sql_str = (
        f'CREATE TABLE "{table_name}" ('
        f'"VEC_TEXT" NCLOB, '
        f'"VEC_META" NCLOB, '
        f'"VEC_VECTOR" REAL_VECTOR, '
        f'"Owner" NVARCHAR(100), '
        f'"quality" NVARCHAR(100));'
    )
    try:
        cur = test_setup.conn.cursor()
        cur.execute(sql_str)
    finally:
        cur.close()

    vectorDB = HanaDB.from_texts(
        connection=test_setup.conn,
        texts=texts,
        metadatas=metadatas,
        embedding=embedding,
        table_name=table_name,
        specific_metadata_columns=["quality"],
    )

    c = 0
    try:
        sql_str = f'SELECT COUNT(*) FROM {table_name} WHERE "quality"=' f"'ugly'"
        cur = test_setup.conn.cursor()
        cur.execute(sql_str)
        if cur.has_result_set():
            rows = cur.fetchall()
            c = rows[0][0]
    finally:
        cur.close()
    assert c == 3

    try:
        sql_str = f'SELECT COUNT(*) FROM {table_name} WHERE "Owner"=' f"'Steve'"
        cur = test_setup.conn.cursor()
        cur.execute(sql_str)
        if cur.has_result_set():
            rows = cur.fetchall()
            c = rows[0][0]
    finally:
        cur.close()
    assert c == 0

    docs = vectorDB.similarity_search("hello", k=5, filter={"quality": "good"})
    assert len(docs) == 1
    assert docs[0].page_content == "foo"

    docs = vectorDB.similarity_search("hello", k=5, filter={"start": 100})
    assert len(docs) == 1
    assert docs[0].page_content == "bar"

    docs = vectorDB.similarity_search(
        "hello", k=5, filter={"start": 100, "quality": "good"}
    )
    assert len(docs) == 0

    docs = vectorDB.similarity_search(
        "hello", k=5, filter={"start": 0, "quality": "good"}
    )
    assert len(docs) == 1
    assert docs[0].page_content == "foo"


@pytest.mark.skipif(not hanadb_installed, reason="hanadb not installed")
def test_preexisting_specific_columns_for_metadata_multiple_columns(
    texts: List[str], metadatas: List[dict]
) -> None:
    table_name = "PREEXISTING_FILTER_MULTIPLE_COLUMNS"
    # drop_table(test_setup.conn, table_name)

    sql_str = (
        f'CREATE TABLE "{table_name}" ('
        f'"VEC_TEXT" NCLOB, '
        f'"VEC_META" NCLOB, '
        f'"VEC_VECTOR" REAL_VECTOR, '
        f'"quality" NVARCHAR(100), '
        f'"start" INTEGER);'
    )
    try:
        cur = test_setup.conn.cursor()
        cur.execute(sql_str)
    finally:
        cur.close()

    vectorDB = HanaDB.from_texts(
        connection=test_setup.conn,
        texts=texts,
        metadatas=metadatas,
        embedding=embedding,
        table_name=table_name,
        specific_metadata_columns=["quality", "start"],
    )

    docs = vectorDB.similarity_search("hello", k=5, filter={"quality": "good"})
    assert len(docs) == 1
    assert docs[0].page_content == "foo"

    docs = vectorDB.similarity_search("hello", k=5, filter={"start": 100})
    assert len(docs) == 1
    assert docs[0].page_content == "bar"

    docs = vectorDB.similarity_search(
        "hello", k=5, filter={"start": 100, "quality": "good"}
    )
    assert len(docs) == 0

    docs = vectorDB.similarity_search(
        "hello", k=5, filter={"start": 0, "quality": "good"}
    )
    assert len(docs) == 1
    assert docs[0].page_content == "foo"


@pytest.mark.skipif(not hanadb_installed, reason="hanadb not installed")
def test_preexisting_specific_columns_for_metadata_empty_columns(
    texts: List[str], metadatas: List[dict]
) -> None:
    table_name = "PREEXISTING_FILTER_MULTIPLE_COLUMNS_EMPTY"
    # drop_table(test_setup.conn, table_name)

    sql_str = (
        f'CREATE TABLE "{table_name}" ('
        f'"VEC_TEXT" NCLOB, '
        f'"VEC_META" NCLOB, '
        f'"VEC_VECTOR" REAL_VECTOR, '
        f'"quality" NVARCHAR(100), '
        f'"ready" BOOLEAN, '
        f'"start" INTEGER);'
    )
    try:
        cur = test_setup.conn.cursor()
        cur.execute(sql_str)
    finally:
        cur.close()

    vectorDB = HanaDB.from_texts(
        connection=test_setup.conn,
        texts=texts,
        metadatas=metadatas,
        embedding=embedding,
        table_name=table_name,
        specific_metadata_columns=["quality", "ready", "start"],
    )

    docs = vectorDB.similarity_search("hello", k=5, filter={"quality": "good"})
    assert len(docs) == 1
    assert docs[0].page_content == "foo"

    docs = vectorDB.similarity_search("hello", k=5, filter={"start": 100})
    assert len(docs) == 1
    assert docs[0].page_content == "bar"

    docs = vectorDB.similarity_search(
        "hello", k=5, filter={"start": 100, "quality": "good"}
    )
    assert len(docs) == 0

    docs = vectorDB.similarity_search(
        "hello", k=5, filter={"start": 0, "quality": "good"}
    )
    assert len(docs) == 1
    assert docs[0].page_content == "foo"

    docs = vectorDB.similarity_search("hello", k=5, filter={"ready": True})
    assert len(docs) == 3


@pytest.mark.skipif(not hanadb_installed, reason="hanadb not installed")
def test_preexisting_specific_columns_for_metadata_wrong_type_or_non_existing(
    texts: List[str], metadatas: List[dict]
) -> None:
    table_name = "PREEXISTING_FILTER_COLUMNS_WRONG_TYPE"
    # drop_table(test_setup.conn, table_name)

    sql_str = (
        f'CREATE TABLE "{table_name}" ('
        f'"VEC_TEXT" NCLOB, '
        f'"VEC_META" NCLOB, '
        f'"VEC_VECTOR" REAL_VECTOR, '
        f'"quality" INTEGER); '
    )
    try:
        cur = test_setup.conn.cursor()
        cur.execute(sql_str)
    finally:
        cur.close()

    # Check if table is created
    exception_occured = False
    try:
        HanaDB.from_texts(
            connection=test_setup.conn,
            texts=texts,
            metadatas=metadatas,
            embedding=embedding,
            table_name=table_name,
            specific_metadata_columns=["quality"],
        )
        exception_occured = False
    except dbapi.Error:  # Nothing we should do here, hdbcli will throw an error
        exception_occured = True
    assert exception_occured  # Check if table is created

    exception_occured = False
    try:
        HanaDB.from_texts(
            connection=test_setup.conn,
            texts=texts,
            metadatas=metadatas,
            embedding=embedding,
            table_name=table_name,
            specific_metadata_columns=["NonExistingColumn"],
        )
        exception_occured = False
    except AttributeError:  # Nothing we should do here, hdbcli will throw an error
        exception_occured = True
    assert exception_occured


@pytest.mark.skipif(not hanadb_installed, reason="hanadb not installed")
def test_preexisting_specific_columns_for_returned_metadata_completeness(
    texts: List[str], metadatas: List[dict]
) -> None:
    table_name = "PREEXISTING_FILTER_COLUMNS_METADATA_COMPLETENESS"
    # drop_table(test_setup.conn, table_name)

    sql_str = (
        f'CREATE TABLE "{table_name}" ('
        f'"VEC_TEXT" NCLOB, '
        f'"VEC_META" NCLOB, '
        f'"VEC_VECTOR" REAL_VECTOR, '
        f'"quality" NVARCHAR(100), '
        f'"NonExisting" NVARCHAR(100), '
        f'"ready" BOOLEAN, '
        f'"start" INTEGER);'
    )
    try:
        cur = test_setup.conn.cursor()
        cur.execute(sql_str)
    finally:
        cur.close()

    vectorDB = HanaDB.from_texts(
        connection=test_setup.conn,
        texts=texts,
        metadatas=metadatas,
        embedding=embedding,
        table_name=table_name,
        specific_metadata_columns=["quality", "ready", "start", "NonExisting"],
    )

    docs = vectorDB.similarity_search("hello", k=5, filter={"quality": "good"})
    assert len(docs) == 1
    assert docs[0].page_content == "foo"
    assert docs[0].metadata["end"] == 100
    assert docs[0].metadata["start"] == 0
    assert docs[0].metadata["quality"] == "good"
    assert docs[0].metadata["ready"]
    assert "NonExisting" not in docs[0].metadata.keys()