mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-20 13:54:48 +00:00
community[patch]: More flexible handling for entity names in vector store "HANA Cloud" (#19523)
- **Description:** Added support for lower-case and mixed-case names The names for tables and columns previouly had to be UPPER_CASE. With this enhancement, also lower_case and MixedCase are supported, - **Issue:** N/A - **Dependencies:** no new dependecies added - **Twitter handle:** @sapopensource
This commit is contained in:
parent
a1ff21f90f
commit
e5bdb26f76
@ -92,10 +92,10 @@ class HanaDB(VectorStore):
|
|||||||
# Check if the table exists, and eventually create it
|
# Check if the table exists, and eventually create it
|
||||||
if not self._table_exists(self.table_name):
|
if not self._table_exists(self.table_name):
|
||||||
sql_str = (
|
sql_str = (
|
||||||
f"CREATE TABLE {self.table_name}("
|
f'CREATE TABLE "{self.table_name}"('
|
||||||
f"{self.content_column} NCLOB, "
|
f'"{self.content_column}" NCLOB, '
|
||||||
f"{self.metadata_column} NCLOB, "
|
f'"{self.metadata_column}" NCLOB, '
|
||||||
f"{self.vector_column} REAL_VECTOR "
|
f'"{self.vector_column}" REAL_VECTOR '
|
||||||
)
|
)
|
||||||
if self.vector_column_length == -1:
|
if self.vector_column_length == -1:
|
||||||
sql_str += ");"
|
sql_str += ");"
|
||||||
@ -228,8 +228,8 @@ class HanaDB(VectorStore):
|
|||||||
else self.embedding.embed_documents([text])[0]
|
else self.embedding.embed_documents([text])[0]
|
||||||
)
|
)
|
||||||
sql_str = (
|
sql_str = (
|
||||||
f"INSERT INTO {self.table_name} ({self.content_column}, "
|
f'INSERT INTO "{self.table_name}" ("{self.content_column}", '
|
||||||
f"{self.metadata_column}, {self.vector_column}) "
|
f'"{self.metadata_column}", "{self.vector_column}") '
|
||||||
f"VALUES (?, ?, TO_REAL_VECTOR (?));"
|
f"VALUES (?, ?, TO_REAL_VECTOR (?));"
|
||||||
)
|
)
|
||||||
cur.execute(
|
cur.execute(
|
||||||
@ -340,12 +340,12 @@ class HanaDB(VectorStore):
|
|||||||
embedding_as_str = ",".join(map(str, embedding))
|
embedding_as_str = ",".join(map(str, embedding))
|
||||||
sql_str = (
|
sql_str = (
|
||||||
f"SELECT TOP {k}"
|
f"SELECT TOP {k}"
|
||||||
f" {self.content_column}, " # row[0]
|
f' "{self.content_column}", ' # row[0]
|
||||||
f" {self.metadata_column}, " # row[1]
|
f' "{self.metadata_column}", ' # row[1]
|
||||||
f" TO_NVARCHAR({self.vector_column}), " # row[2]
|
f' TO_NVARCHAR("{self.vector_column}"), ' # row[2]
|
||||||
f" {distance_func_name}({self.vector_column}, TO_REAL_VECTOR "
|
f' {distance_func_name}("{self.vector_column}", TO_REAL_VECTOR '
|
||||||
f" (ARRAY({embedding_as_str}))) AS CS " # row[3]
|
f" (ARRAY({embedding_as_str}))) AS CS " # row[3]
|
||||||
f"FROM {self.table_name}"
|
f'FROM "{self.table_name}"'
|
||||||
)
|
)
|
||||||
order_str = f" order by CS {HANA_DISTANCE_FUNCTION[self.distance_strategy][1]}"
|
order_str = f" order by CS {HANA_DISTANCE_FUNCTION[self.distance_strategy][1]}"
|
||||||
where_str, query_tuple = self._create_where_by_filter(filter)
|
where_str, query_tuple = self._create_where_by_filter(filter)
|
||||||
@ -451,7 +451,7 @@ class HanaDB(VectorStore):
|
|||||||
raise ValueError("Parameter 'filter' is required when calling 'delete'")
|
raise ValueError("Parameter 'filter' is required when calling 'delete'")
|
||||||
|
|
||||||
where_str, query_tuple = self._create_where_by_filter(filter)
|
where_str, query_tuple = self._create_where_by_filter(filter)
|
||||||
sql_str = f"DELETE FROM {self.table_name} {where_str}"
|
sql_str = f'DELETE FROM "{self.table_name}" {where_str}'
|
||||||
|
|
||||||
try:
|
try:
|
||||||
cur = self.connection.cursor()
|
cur = self.connection.cursor()
|
||||||
|
@ -890,3 +890,37 @@ def test_invalid_metadata_keys(texts: List[str], metadatas: List[dict]) -> None:
|
|||||||
except ValueError:
|
except ValueError:
|
||||||
exception_occured = True
|
exception_occured = True
|
||||||
assert exception_occured
|
assert exception_occured
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif(not hanadb_installed, reason="hanadb not installed")
|
||||||
|
def test_hanavector_table_mixed_case_names(texts: List[str]) -> None:
|
||||||
|
table_name = "MyTableName"
|
||||||
|
content_column = "TextColumn"
|
||||||
|
metadata_column = "MetaColumn"
|
||||||
|
vector_column = "VectorColumn"
|
||||||
|
|
||||||
|
vectordb = HanaDB(
|
||||||
|
connection=test_setup.conn,
|
||||||
|
embedding=embedding,
|
||||||
|
distance_strategy=DistanceStrategy.COSINE,
|
||||||
|
table_name=table_name,
|
||||||
|
content_column=content_column,
|
||||||
|
metadata_column=metadata_column,
|
||||||
|
vector_column=vector_column,
|
||||||
|
)
|
||||||
|
|
||||||
|
vectordb.add_texts(texts=texts)
|
||||||
|
|
||||||
|
# check that embeddings have been created in the table
|
||||||
|
number_of_texts = len(texts)
|
||||||
|
number_of_rows = -1
|
||||||
|
sql_str = f'SELECT COUNT(*) FROM "{table_name}"'
|
||||||
|
cur = test_setup.conn.cursor()
|
||||||
|
cur.execute(sql_str)
|
||||||
|
if cur.has_result_set():
|
||||||
|
rows = cur.fetchall()
|
||||||
|
number_of_rows = rows[0][0]
|
||||||
|
assert number_of_rows == number_of_texts
|
||||||
|
|
||||||
|
# check results of similarity search
|
||||||
|
assert texts[0] == vectordb.similarity_search(texts[0], 1)[0].page_content
|
||||||
|
Loading…
Reference in New Issue
Block a user