mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-25 04:49:17 +00:00
community[patch]: Performant filter columns option for Hanavector (#21971)
**Description:** Backwards compatible extension of the initialisation interface of HanaDB to allow the user to specify specific_metadata_columns that are used for metadata storage of selected keys which yields increased filter performance. Any not-mentioned metadata remains in the general metadata column as part of a JSON string. Furthermore switched to executemany for batch inserts into HanaDB. **Issue:** N/A **Dependencies:** no new dependencies added **Twitter handle:** @sapopensource --------- Co-authored-by: Martin Kolb <martin.kolb@sap.com> Co-authored-by: Bagatur <baskaryan@gmail.com>
This commit is contained in:
@@ -1,4 +1,5 @@
|
||||
"""SAP HANA Cloud Vector Engine"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import importlib.util
|
||||
@@ -85,6 +86,8 @@ class HanaDB(VectorStore):
|
||||
metadata_column: str = default_metadata_column,
|
||||
vector_column: str = default_vector_column,
|
||||
vector_column_length: int = default_vector_column_length,
|
||||
*,
|
||||
specific_metadata_columns: Optional[List[str]] = None,
|
||||
):
|
||||
# Check if the hdbcli package is installed
|
||||
if importlib.util.find_spec("hdbcli") is None:
|
||||
@@ -110,6 +113,9 @@ class HanaDB(VectorStore):
|
||||
self.metadata_column = HanaDB._sanitize_name(metadata_column)
|
||||
self.vector_column = HanaDB._sanitize_name(vector_column)
|
||||
self.vector_column_length = HanaDB._sanitize_int(vector_column_length)
|
||||
self.specific_metadata_columns = HanaDB._sanitize_specific_metadata_columns(
|
||||
specific_metadata_columns or []
|
||||
)
|
||||
|
||||
# Check if the table exists, and eventually create it
|
||||
if not self._table_exists(self.table_name):
|
||||
@@ -139,6 +145,8 @@ class HanaDB(VectorStore):
|
||||
["REAL_VECTOR"],
|
||||
self.vector_column_length,
|
||||
)
|
||||
for column_name in self.specific_metadata_columns:
|
||||
self._check_column(self.table_name, column_name)
|
||||
|
||||
def _table_exists(self, table_name) -> bool: # type: ignore[no-untyped-def]
|
||||
sql_str = (
|
||||
@@ -156,7 +164,9 @@ class HanaDB(VectorStore):
|
||||
cur.close()
|
||||
return False
|
||||
|
||||
def _check_column(self, table_name, column_name, column_type, column_length=None): # type: ignore[no-untyped-def]
|
||||
def _check_column( # type: ignore[no-untyped-def]
|
||||
self, table_name, column_name, column_type=None, column_length=None
|
||||
):
|
||||
sql_str = (
|
||||
"SELECT DATA_TYPE_NAME, LENGTH FROM SYS.TABLE_COLUMNS WHERE "
|
||||
"SCHEMA_NAME = CURRENT_SCHEMA "
|
||||
@@ -170,10 +180,11 @@ class HanaDB(VectorStore):
|
||||
if len(rows) == 0:
|
||||
raise AttributeError(f"Column {column_name} does not exist")
|
||||
# Check data type
|
||||
if rows[0][0] not in column_type:
|
||||
raise AttributeError(
|
||||
f"Column {column_name} has the wrong type: {rows[0][0]}"
|
||||
)
|
||||
if column_type:
|
||||
if rows[0][0] not in column_type:
|
||||
raise AttributeError(
|
||||
f"Column {column_name} has the wrong type: {rows[0][0]}"
|
||||
)
|
||||
# Check length, if parameter was provided
|
||||
if column_length is not None:
|
||||
if rows[0][1] != column_length:
|
||||
@@ -189,17 +200,20 @@ class HanaDB(VectorStore):
|
||||
def embeddings(self) -> Embeddings:
|
||||
return self.embedding
|
||||
|
||||
@staticmethod
|
||||
def _sanitize_name(input_str: str) -> str: # type: ignore[misc]
|
||||
# Remove characters that are not alphanumeric or underscores
|
||||
return re.sub(r"[^a-zA-Z0-9_]", "", input_str)
|
||||
|
||||
@staticmethod
|
||||
def _sanitize_int(input_int: any) -> int: # type: ignore[valid-type]
|
||||
value = int(str(input_int))
|
||||
if value < -1:
|
||||
raise ValueError(f"Value ({value}) must not be smaller than -1")
|
||||
return int(str(input_int))
|
||||
|
||||
def _sanitize_list_float(embedding: List[float]) -> List[float]: # type: ignore[misc]
|
||||
@staticmethod
|
||||
def _sanitize_list_float(embedding: List[float]) -> List[float]:
|
||||
for value in embedding:
|
||||
if not isinstance(value, float):
|
||||
raise ValueError(f"Value ({value}) does not have type float")
|
||||
@@ -208,13 +222,36 @@ class HanaDB(VectorStore):
|
||||
# Compile pattern only once, for better performance
|
||||
_compiled_pattern = re.compile("^[_a-zA-Z][_a-zA-Z0-9]*$")
|
||||
|
||||
def _sanitize_metadata_keys(metadata: dict) -> dict: # type: ignore[misc]
|
||||
@staticmethod
|
||||
def _sanitize_metadata_keys(metadata: dict) -> dict:
|
||||
for key in metadata.keys():
|
||||
if not HanaDB._compiled_pattern.match(key):
|
||||
raise ValueError(f"Invalid metadata key {key}")
|
||||
|
||||
return metadata
|
||||
|
||||
@staticmethod
|
||||
def _sanitize_specific_metadata_columns(
|
||||
specific_metadata_columns: List[str],
|
||||
) -> List[str]:
|
||||
metadata_columns = []
|
||||
for c in specific_metadata_columns:
|
||||
sanitized_name = HanaDB._sanitize_name(c)
|
||||
metadata_columns.append(sanitized_name)
|
||||
return metadata_columns
|
||||
|
||||
def _split_off_special_metadata(self, metadata: dict) -> Tuple[dict, list]:
|
||||
# Use provided values by default or fallback
|
||||
special_metadata = []
|
||||
|
||||
if not metadata:
|
||||
return {}, []
|
||||
|
||||
for column_name in self.specific_metadata_columns:
|
||||
special_metadata.append(metadata.get(column_name, None))
|
||||
|
||||
return metadata, special_metadata
|
||||
|
||||
def add_texts( # type: ignore[override]
|
||||
self,
|
||||
texts: Iterable[str],
|
||||
@@ -238,30 +275,45 @@ class HanaDB(VectorStore):
|
||||
if embeddings is None:
|
||||
embeddings = self.embedding.embed_documents(list(texts))
|
||||
|
||||
# Create sql parameters array
|
||||
sql_params = []
|
||||
for i, text in enumerate(texts):
|
||||
metadata = metadatas[i] if metadatas else {}
|
||||
metadata, extracted_special_metadata = self._split_off_special_metadata(
|
||||
metadata
|
||||
)
|
||||
embedding = (
|
||||
embeddings[i]
|
||||
if embeddings
|
||||
else self.embedding.embed_documents([text])[0]
|
||||
)
|
||||
sql_params.append(
|
||||
(
|
||||
text,
|
||||
json.dumps(HanaDB._sanitize_metadata_keys(metadata)),
|
||||
f"[{','.join(map(str, embedding))}]",
|
||||
*extracted_special_metadata,
|
||||
)
|
||||
)
|
||||
|
||||
# Insert data into the table
|
||||
cur = self.connection.cursor()
|
||||
try:
|
||||
# Insert data into the table
|
||||
for i, text in enumerate(texts):
|
||||
# Use provided values by default or fallback
|
||||
metadata = metadatas[i] if metadatas else {}
|
||||
embedding = (
|
||||
embeddings[i]
|
||||
if embeddings
|
||||
else self.embedding.embed_documents([text])[0]
|
||||
)
|
||||
sql_str = (
|
||||
f'INSERT INTO "{self.table_name}" ("{self.content_column}", '
|
||||
f'"{self.metadata_column}", "{self.vector_column}") '
|
||||
f"VALUES (?, ?, TO_REAL_VECTOR (?));"
|
||||
)
|
||||
cur.execute(
|
||||
sql_str,
|
||||
(
|
||||
text,
|
||||
json.dumps(HanaDB._sanitize_metadata_keys(metadata)),
|
||||
f"[{','.join(map(str, embedding))}]",
|
||||
),
|
||||
specific_metadata_columns_string = '", "'.join(
|
||||
self.specific_metadata_columns
|
||||
)
|
||||
if specific_metadata_columns_string:
|
||||
specific_metadata_columns_string = (
|
||||
', "' + specific_metadata_columns_string + '"'
|
||||
)
|
||||
sql_str = (
|
||||
f'INSERT INTO "{self.table_name}" ("{self.content_column}", '
|
||||
f'"{self.metadata_column}", '
|
||||
f'"{self.vector_column}"{specific_metadata_columns_string}) '
|
||||
f"VALUES (?, ?, TO_REAL_VECTOR (?)"
|
||||
f"{', ?' * len(self.specific_metadata_columns)});"
|
||||
)
|
||||
cur.executemany(sql_str, sql_params)
|
||||
finally:
|
||||
cur.close()
|
||||
return []
|
||||
@@ -279,6 +331,8 @@ class HanaDB(VectorStore):
|
||||
metadata_column: str = default_metadata_column,
|
||||
vector_column: str = default_vector_column,
|
||||
vector_column_length: int = default_vector_column_length,
|
||||
*,
|
||||
specific_metadata_columns: Optional[List[str]] = None,
|
||||
):
|
||||
"""Create a HanaDB instance from raw documents.
|
||||
This is a user-friendly interface that:
|
||||
@@ -297,6 +351,7 @@ class HanaDB(VectorStore):
|
||||
metadata_column=metadata_column,
|
||||
vector_column=vector_column,
|
||||
vector_column_length=vector_column_length, # -1 means dynamic length
|
||||
specific_metadata_columns=specific_metadata_columns,
|
||||
)
|
||||
instance.add_texts(texts, metadatas)
|
||||
return instance
|
||||
@@ -514,10 +569,12 @@ class HanaDB(VectorStore):
|
||||
f"Unsupported filter data-type: {type(filter_value)}"
|
||||
)
|
||||
|
||||
where_str += (
|
||||
f" JSON_VALUE({self.metadata_column}, '$.{key}')"
|
||||
f" {operator} {sql_param}"
|
||||
selector = (
|
||||
f' "{key}"'
|
||||
if key in self.specific_metadata_columns
|
||||
else f"JSON_VALUE({self.metadata_column}, '$.{key}')"
|
||||
)
|
||||
where_str += f"{selector} " f"{operator} {sql_param}"
|
||||
|
||||
return where_str, query_tuple
|
||||
|
||||
|
Reference in New Issue
Block a user