community: add HANA dialect to SQLDatabase (#30475)

This PR includes support for HANA dialect in SQLDatabase, which is a
wrapper class for SQLAlchemy.

Currently, it is unable to set schema name when using HANA DB with
Langchain. And, it does not show any message to user so that it makes
hard for user to figure out why the SQL does not work as expected.

Here is the reference document for HANA DB to set schema for the
session.

- [SET SCHEMA Statement (Session
Management)](https://help.sap.com/docs/SAP_HANA_PLATFORM/4fe29514fd584807ac9f2a04f6754767/20fd550375191014b886a338afb4cd5f.html)
This commit is contained in:
Kyungho Byoun 2025-03-28 04:19:50 +09:00 committed by GitHub
parent 1cf91a2386
commit e6b6c07395
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 43 additions and 1 deletions

View File

@ -2,6 +2,7 @@
from __future__ import annotations
import re
from typing import Any, Dict, Iterable, List, Literal, Optional, Sequence, Union
import sqlalchemy
@ -44,6 +45,16 @@ def truncate_word(content: Any, *, length: int, suffix: str = "...") -> str:
return content[: length - len(suffix)].rsplit(" ", 1)[0] + suffix
def sanitize_schema(schema: str) -> str:
"""Sanitize a schema name to only contain letters, digits, and underscores."""
if not re.match(r"^[a-zA-Z0-9_]+$", schema):
raise ValueError(
f"Schema name '{schema}' contains invalid characters. "
"Schema names must contain only letters, digits, and underscores."
)
return schema
class SQLDatabase:
"""SQLAlchemy wrapper around a database."""
@ -465,6 +476,11 @@ class SQLDatabase:
(self._schema,),
execution_options=execution_options,
)
elif self.dialect == "hana":
connection.exec_driver_sql(
f"SET SCHEMA {sanitize_schema(self._schema)}",
execution_options=execution_options,
)
if isinstance(command, str):
command = text(command)

View File

@ -16,7 +16,11 @@ from sqlalchemy import (
)
from sqlalchemy.engine import Engine, Result
from langchain_community.utilities.sql_database import SQLDatabase, truncate_word
from langchain_community.utilities.sql_database import (
SQLDatabase,
sanitize_schema,
truncate_word,
)
is_sqlalchemy_v1 = version.parse(sa.__version__).major == 1
@ -262,3 +266,25 @@ def test_truncate_word() -> None:
assert truncate_word("Hello World", length=-10) == "Hello World"
assert truncate_word("Hello World", length=5, suffix="!!!") == "He!!!"
assert truncate_word("Hello World", length=12, suffix="!!!") == "Hello World"
def test_sanitize_schema() -> None:
valid_schema_names = [
"test_schema",
"schema123",
"TEST_SCHEMA_123",
"_schema_",
]
for schema in valid_schema_names:
assert sanitize_schema(schema) == schema
invalid_schema_names = [
"test-schema",
"schema.name",
"schema$",
"schema name",
]
for schema in invalid_schema_names:
with pytest.raises(ValueError) as ex:
sanitize_schema(schema)
assert f"Schema name '{schema}' contains invalid characters" in str(ex.value)