mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-12 22:28:03 +00:00
community: add HANA dialect to SQLDatabase (#30475)
This PR includes support for HANA dialect in SQLDatabase, which is a wrapper class for SQLAlchemy. Currently, it is unable to set schema name when using HANA DB with Langchain. And, it does not show any message to user so that it makes hard for user to figure out why the SQL does not work as expected. Here is the reference document for HANA DB to set schema for the session. - [SET SCHEMA Statement (Session Management)](https://help.sap.com/docs/SAP_HANA_PLATFORM/4fe29514fd584807ac9f2a04f6754767/20fd550375191014b886a338afb4cd5f.html)
This commit is contained in:
parent
1cf91a2386
commit
e6b6c07395
@ -2,6 +2,7 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
from typing import Any, Dict, Iterable, List, Literal, Optional, Sequence, Union
|
||||
|
||||
import sqlalchemy
|
||||
@ -44,6 +45,16 @@ def truncate_word(content: Any, *, length: int, suffix: str = "...") -> str:
|
||||
return content[: length - len(suffix)].rsplit(" ", 1)[0] + suffix
|
||||
|
||||
|
||||
def sanitize_schema(schema: str) -> str:
|
||||
"""Sanitize a schema name to only contain letters, digits, and underscores."""
|
||||
if not re.match(r"^[a-zA-Z0-9_]+$", schema):
|
||||
raise ValueError(
|
||||
f"Schema name '{schema}' contains invalid characters. "
|
||||
"Schema names must contain only letters, digits, and underscores."
|
||||
)
|
||||
return schema
|
||||
|
||||
|
||||
class SQLDatabase:
|
||||
"""SQLAlchemy wrapper around a database."""
|
||||
|
||||
@ -465,6 +476,11 @@ class SQLDatabase:
|
||||
(self._schema,),
|
||||
execution_options=execution_options,
|
||||
)
|
||||
elif self.dialect == "hana":
|
||||
connection.exec_driver_sql(
|
||||
f"SET SCHEMA {sanitize_schema(self._schema)}",
|
||||
execution_options=execution_options,
|
||||
)
|
||||
|
||||
if isinstance(command, str):
|
||||
command = text(command)
|
||||
|
@ -16,7 +16,11 @@ from sqlalchemy import (
|
||||
)
|
||||
from sqlalchemy.engine import Engine, Result
|
||||
|
||||
from langchain_community.utilities.sql_database import SQLDatabase, truncate_word
|
||||
from langchain_community.utilities.sql_database import (
|
||||
SQLDatabase,
|
||||
sanitize_schema,
|
||||
truncate_word,
|
||||
)
|
||||
|
||||
is_sqlalchemy_v1 = version.parse(sa.__version__).major == 1
|
||||
|
||||
@ -262,3 +266,25 @@ def test_truncate_word() -> None:
|
||||
assert truncate_word("Hello World", length=-10) == "Hello World"
|
||||
assert truncate_word("Hello World", length=5, suffix="!!!") == "He!!!"
|
||||
assert truncate_word("Hello World", length=12, suffix="!!!") == "Hello World"
|
||||
|
||||
|
||||
def test_sanitize_schema() -> None:
|
||||
valid_schema_names = [
|
||||
"test_schema",
|
||||
"schema123",
|
||||
"TEST_SCHEMA_123",
|
||||
"_schema_",
|
||||
]
|
||||
for schema in valid_schema_names:
|
||||
assert sanitize_schema(schema) == schema
|
||||
|
||||
invalid_schema_names = [
|
||||
"test-schema",
|
||||
"schema.name",
|
||||
"schema$",
|
||||
"schema name",
|
||||
]
|
||||
for schema in invalid_schema_names:
|
||||
with pytest.raises(ValueError) as ex:
|
||||
sanitize_schema(schema)
|
||||
assert f"Schema name '{schema}' contains invalid characters" in str(ex.value)
|
||||
|
Loading…
Reference in New Issue
Block a user