def sanitize_schema(schema: str) -> str:
    """Validate a schema name before it is interpolated into raw SQL.

    The name is accepted only if it consists entirely of ASCII letters,
    digits, and underscores. The value is returned unchanged so the call can
    be used inline, e.g. ``f"SET SCHEMA {sanitize_schema(name)}"``.

    Args:
        schema: Candidate schema name.

    Returns:
        The same ``schema`` string, once it has been validated.

    Raises:
        ValueError: If ``schema`` is empty or contains any character other
            than a letter, digit, or underscore.
    """
    # ``fullmatch`` rather than ``match`` with ``^...$``: ``$`` also matches
    # just before a trailing newline, which would let "name\n" slip through.
    if re.fullmatch(r"[a-zA-Z0-9_]+", schema) is None:
        raise ValueError(
            f"Schema name '{schema}' contains invalid characters. "
            "Schema names must contain only letters, digits, and underscores."
        )
    return schema
def test_sanitize_schema() -> None:
    """Check that sanitize_schema passes clean names and rejects the rest."""
    # Names made only of letters, digits, and underscores come back unchanged.
    accepted = ("test_schema", "schema123", "TEST_SCHEMA_123", "_schema_")
    for name in accepted:
        assert sanitize_schema(name) == name

    # Any other character (dash, dot, dollar sign, space) must be rejected.
    rejected = ("test-schema", "schema.name", "schema$", "schema name")
    for name in rejected:
        with pytest.raises(ValueError) as raised:
            sanitize_schema(name)
        expected_fragment = f"Schema name '{name}' contains invalid characters"
        assert expected_fragment in str(raised.value)