community: add 'get_col_comments' option to retrieve database column comments (#30646)

## Description
Added support for retrieving column comments in the SQL Database
utility. This feature allows users to see comments associated with
database columns when querying table information. Column comments
provide valuable metadata that helps LLMs better understand the
semantics and purpose of database columns.

A new optional parameter `get_col_comments` was added to the
`get_table_info` method, defaulting to `False` for backward
compatibility. When set to `True`, it retrieves and formats column
comments for each table.

Currently, this feature is supported on PostgreSQL, MySQL, and Oracle
databases.

## Implementation
Before using this option, create the table with comments defined on its columns.

```python
db = SQLDatabase.from_uri("YOUR_DB_URI")
print(db.get_table_info(get_col_comments=True)) 
```
## Result
```
CREATE TABLE test_table (
	name VARCHAR,
	school VARCHAR
)
/*
Column Comments: {'name': 'person name', 'school': 'school_name'}
*/

/*
3 rows from test_table:
name
a
b
c
*/
```

## Benefits
1. Enhances LLM's understanding of database schema semantics
2. Preserves valuable domain knowledge embedded in database design
3. Improves accuracy of SQL query generation
4. Provides more context for data interpretation

Tests are available in
`langchain/libs/community/tests/test_sql_get_table_info.py`.

---------

Co-authored-by: chbae <chbae@gcsc.co.kr>
Co-authored-by: Chester Curme <chester.curme@gmail.com>
This commit is contained in:
Bae-ChangHyun 2025-04-29 00:19:46 +09:00 committed by GitHub
parent 3fb0a55122
commit a2863f8757
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 308 additions and 6 deletions

View File

@ -316,7 +316,9 @@ class SQLDatabase:
"""Information about all tables in the database."""
return self.get_table_info()
def get_table_info(self, table_names: Optional[List[str]] = None) -> str:
def get_table_info(
self, table_names: Optional[List[str]] = None, get_col_comments: bool = False
) -> str:
"""Get information about specified tables.
Follows best practices as specified in: Rajkumar et al, 2022
@ -356,14 +358,39 @@ class SQLDatabase:
tables.append(self._custom_table_info[table.name])
continue
# Ignore JSON datatyped columns
for k, v in table.columns.items(): # AttributeError: items in sqlalchemy v1
if type(v.type) is NullType:
table._columns.remove(v)
# Ignore JSON datatyped columns - SQLAlchemy v1.x compatibility
try:
# For SQLAlchemy v2.x
for k, v in table.columns.items():
if type(v.type) is NullType:
table._columns.remove(v)
except AttributeError:
# For SQLAlchemy v1.x
for k, v in dict(table.columns).items():
if type(v.type) is NullType:
table._columns.remove(v)
# add create table command
create_table = str(CreateTable(table).compile(self._engine))
table_info = f"{create_table.rstrip()}"
# Add column comments as dictionary
if get_col_comments:
try:
column_comments_dict = {}
for column in table.columns:
if column.comment:
column_comments_dict[column.name] = column.comment
if column_comments_dict:
table_info += (
f"\n\n/*\nColumn Comments: {column_comments_dict}\n*/"
)
except Exception:
raise ValueError(
"Column comments are available on PostgreSQL, MySQL, Oracle"
)
has_extra_info = (
self._indexes_in_table_info or self._sample_rows_in_table_info
)

View File

@ -0,0 +1,259 @@
import unittest
from typing import Dict, Optional
from unittest.mock import MagicMock, patch
from sqlalchemy import Column, Integer, MetaData, String, Table
from langchain_community.utilities.sql_database import SQLDatabase
class TestSQLDatabaseComments(unittest.TestCase):
    """Tests for the ``get_col_comments`` option of ``SQLDatabase.get_table_info``.

    Every test runs against a ``MagicMock`` engine and a mocked inspector, so
    no real database is needed. ``CreateTable`` is patched while
    ``get_table_info`` is called, which pins the CREATE TABLE text and lets the
    assertions focus on the "Column Comments" block appended after it.
    """

    # Text returned by the mocked ``CreateTable(...).compile()`` call.
    EXPECTED_CREATE_TABLE_SQL = (
        "CREATE TABLE test_table (\n\tid INTEGER NOT NULL, "
        "\n\tname VARCHAR(100), \n\tage INTEGER, \n\tPRIMARY KEY (id)\n)"
    )

    @staticmethod
    def _reflected_columns(comments: Dict[str, str]) -> list:
        """Build the value the mocked ``inspector.get_columns`` returns.

        Args:
            comments: Mapping of column name to comment. A column missing
                from the mapping gets ``"comment": None``.

        Returns:
            A list with one reflection dict per column (``id``, ``name``,
            ``age``).
        """
        return [
            {
                "name": "id",
                "type": Integer(),
                "nullable": False,
                "default": None,
                "autoincrement": "auto",
                "comment": comments.get("id"),
            },
            {
                "name": "name",
                "type": String(100),
                "nullable": True,
                "default": None,
                "autoincrement": "auto",
                "comment": comments.get("name"),
            },
            {
                "name": "age",
                "type": Integer(),
                "nullable": True,
                "default": None,
                "autoincrement": "auto",
                "comment": comments.get("age"),
            },
        ]

    def setUp(self) -> None:
        """Create a ``SQLDatabase`` backed by a mocked engine and inspector."""
        # Mock engine; individual tests may override the dialect name.
        self.mock_engine = MagicMock()
        self.mock_engine.dialect.name = "postgresql"

        # Mock inspector, configured before SQLDatabase is instantiated.
        self.mock_inspector = MagicMock()
        self.mock_inspector.get_table_names.return_value = ["test_table"]
        self.mock_inspector.get_view_names.return_value = []
        self.mock_inspector.get_indexes.return_value = []
        # Start with columns that carry no comments.
        self.mock_inspector.get_columns.return_value = self._reflected_columns({})
        self.mock_inspector.get_pk_constraint.return_value = {
            "constrained_columns": ["id"],
            "name": None,
        }
        self.mock_inspector.get_foreign_keys.return_value = []

        # Patch the ``inspect`` name used by the sql_database module so that
        # it hands back the mock inspector.
        self.patch_inspector = patch(
            "langchain_community.utilities.sql_database.inspect",
            return_value=self.mock_inspector,
        )
        self.mock_inspect = self.patch_inspector.start()
        # Register the stop as a cleanup rather than in tearDown: cleanups run
        # even when setUp fails below, so the patch can never leak into
        # other tests.
        self.addCleanup(self.patch_inspector.stop)

        self.metadata = MetaData()

        # Create the database object only *after* ``inspect`` is patched.
        try:
            self.db = SQLDatabase(
                engine=self.mock_engine,
                metadata=self.metadata,
                lazy_table_reflection=True,
            )
        except Exception as e:
            self.fail(f"Unexpected exception during SQLDatabase init: {e}")

    def setup_mock_table_with_comments(
        self, dialect: str, comments: Optional[Dict[str, str]] = None
    ) -> Table:
        """Register ``test_table`` with the given column comments.

        Args:
            dialect: Dialect name assigned to the mock engine
                (e.g. ``"postgresql"``, ``"mysql"``, ``"oracle"``).
            comments: Column name to comment mapping. When ``None``, a default
                comment is used for each of ``id``, ``name`` and ``age``.

        Returns:
            The ``Table`` object created on ``self.metadata``.
        """
        if comments is None:
            comments = {
                "id": "Primary key",
                "name": "Name of the person",
                "age": "Age of the person",
            }

        self.mock_engine.dialect.name = dialect

        # Drop any table definitions left in the shared MetaData object.
        self.metadata.clear()

        test_table = Table(
            "test_table",
            self.metadata,
            Column("id", Integer, primary_key=True, comment=comments.get("id")),
            Column("name", String(100), comment=comments.get("name")),
            Column("age", Integer, comment=comments.get("age")),
        )

        # Keep the mocked reflection data in sync with the table definition.
        self.mock_inspector.get_columns.return_value = self._reflected_columns(
            comments
        )
        # Ensure the table is discoverable.
        self.mock_inspector.get_table_names.return_value = ["test_table"]

        return test_table

    def _get_table_info_with_mocked_createtable(self) -> str:
        """Call ``get_table_info(get_col_comments=True)`` with CreateTable mocked.

        Returns:
            The string produced by ``self.db.get_table_info``.
        """
        with patch(
            "langchain_community.utilities.sql_database.CreateTable"
        ) as MockCreateTable:
            # Make ``CreateTable(...).compile(...)`` return the fixed SQL text.
            mock_compiler = MockCreateTable.return_value.compile
            mock_compiler.return_value = self.EXPECTED_CREATE_TABLE_SQL
            return self.db.get_table_info(get_col_comments=True)

    def _run_test_with_mocked_createtable(self, dialect: str) -> None:
        """Check that default comments show up in table info for ``dialect``."""
        self.setup_mock_table_with_comments(dialect)

        table_info = self._get_table_info_with_mocked_createtable()

        # The (mocked) CREATE TABLE statement is present.
        self.assertIn(self.EXPECTED_CREATE_TABLE_SQL.strip(), table_info)

        # The comments are rendered as a dict inside a SQL comment block.
        self.assertIn("/*\nColumn Comments:", table_info)
        self.assertIn("'id': 'Primary key'", table_info)
        self.assertIn("'name': 'Name of the person'", table_info)
        self.assertIn("'age': 'Age of the person'", table_info)
        self.assertIn("*/", table_info)

    def test_postgres_get_col_comments(self) -> None:
        """Test retrieving column comments from PostgreSQL"""
        self._run_test_with_mocked_createtable("postgresql")

    def test_mysql_get_col_comments(self) -> None:
        """Test retrieving column comments from MySQL"""
        self._run_test_with_mocked_createtable("mysql")

    def test_oracle_get_col_comments(self) -> None:
        """Test retrieving column comments from Oracle"""
        self._run_test_with_mocked_createtable("oracle")

    def test_sqlite_no_comments(self) -> None:
        """Test that SQLite does not add a comment block when comments are missing."""
        # An empty mapping means every column is created, and reflected, with
        # ``comment=None``; no further inspector setup is required.
        self.setup_mock_table_with_comments("sqlite", comments={})

        # get_col_comments=True is still passed, but there is nothing to add.
        table_info = self._get_table_info_with_mocked_createtable()

        # The (mocked) CREATE TABLE statement is present.
        self.assertIn(self.EXPECTED_CREATE_TABLE_SQL.strip(), table_info)
        # The comments block is NOT included.
        self.assertNotIn("Column Comments:", table_info)


if __name__ == "__main__":
    unittest.main()

View File

@ -35,6 +35,8 @@ def create_sql_query_chain(
db: SQLDatabase,
prompt: Optional[BasePromptTemplate] = None,
k: int = 5,
*,
get_col_comments: Optional[bool] = None,
) -> Runnable[Union[SQLInput, SQLInputWithTables, dict[str, Any]], str]:
"""Create a chain that generates SQL queries.
@ -59,6 +61,8 @@ def create_sql_query_chain(
prompt: The prompt to use. If none is provided, will choose one
based on dialect. Defaults to None. See Prompt section below for more.
k: The number of results per select statement to return. Defaults to 5.
get_col_comments: Whether to retrieve column comments along with table info.
Defaults to False.
Returns:
A chain that takes in a question and generates a SQL query that answers
@ -127,10 +131,22 @@ def create_sql_query_chain(
if "dialect" in prompt_to_use.input_variables:
prompt_to_use = prompt_to_use.partial(dialect=db.dialect)
table_info_kwargs = {}
if get_col_comments:
if db.dialect not in ("postgresql", "mysql", "oracle"):
raise ValueError(
f"get_col_comments=True is only supported for dialects "
f"'postgresql', 'mysql', and 'oracle'. Received dialect: "
f"{db.dialect}"
)
else:
table_info_kwargs["get_col_comments"] = True
inputs = {
"input": lambda x: x["question"] + "\nSQLQuery: ",
"table_info": lambda x: db.get_table_info(
table_names=x.get("table_names_to_use")
table_names=x.get("table_names_to_use"),
**table_info_kwargs,
),
}
return (