mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-22 23:00:00 +00:00
community: add 'get_col_comments' option for retrieve database columns comments (#30646)
## Description Added support for retrieving column comments in the SQL Database utility. This feature allows users to see comments associated with database columns when querying table information. Column comments provide valuable metadata that helps LLMs better understand the semantics and purpose of database columns. A new optional parameter `get_col_comments` was added to the `get_table_info` method, defaulting to `False` for backward compatibility. When set to `True`, it retrieves and formats column comments for each table. Currently, this feature is supported on PostgreSQL, MySQL, and Oracle databases. ## Implementation You should create Table with column comments before. ```python db = SQLDatabase.from_uri("YOUR_DB_URI") print(db.get_table_info(get_col_comments=True)) ``` ## Result ``` CREATE TABLE test_table ( name VARCHAR school VARCHAR) /* Column Comments: {'name': person name, 'school":school_name} */ /* 3 rows from test_table: name a b c */ ``` ## Benefits 1. Enhances LLM's understanding of database schema semantics 2. Preserves valuable domain knowledge embedded in database design 3. Improves accuracy of SQL query generation 4. Provides more context for data interpretation Tests are available in `langchain/libs/community/tests/test_sql_get_table_info.py`. --------- Co-authored-by: chbae <chbae@gcsc.co.kr> Co-authored-by: Chester Curme <chester.curme@gmail.com>
This commit is contained in:
parent
3fb0a55122
commit
a2863f8757
@ -316,7 +316,9 @@ class SQLDatabase:
|
||||
"""Information about all tables in the database."""
|
||||
return self.get_table_info()
|
||||
|
||||
def get_table_info(self, table_names: Optional[List[str]] = None) -> str:
|
||||
def get_table_info(
|
||||
self, table_names: Optional[List[str]] = None, get_col_comments: bool = False
|
||||
) -> str:
|
||||
"""Get information about specified tables.
|
||||
|
||||
Follows best practices as specified in: Rajkumar et al, 2022
|
||||
@ -356,14 +358,39 @@ class SQLDatabase:
|
||||
tables.append(self._custom_table_info[table.name])
|
||||
continue
|
||||
|
||||
# Ignore JSON datatyped columns
|
||||
for k, v in table.columns.items(): # AttributeError: items in sqlalchemy v1
|
||||
if type(v.type) is NullType:
|
||||
table._columns.remove(v)
|
||||
# Ignore JSON datatyped columns - SQLAlchemy v1.x compatibility
|
||||
try:
|
||||
# For SQLAlchemy v2.x
|
||||
for k, v in table.columns.items():
|
||||
if type(v.type) is NullType:
|
||||
table._columns.remove(v)
|
||||
except AttributeError:
|
||||
# For SQLAlchemy v1.x
|
||||
for k, v in dict(table.columns).items():
|
||||
if type(v.type) is NullType:
|
||||
table._columns.remove(v)
|
||||
|
||||
# add create table command
|
||||
create_table = str(CreateTable(table).compile(self._engine))
|
||||
table_info = f"{create_table.rstrip()}"
|
||||
|
||||
# Add column comments as dictionary
|
||||
if get_col_comments:
|
||||
try:
|
||||
column_comments_dict = {}
|
||||
for column in table.columns:
|
||||
if column.comment:
|
||||
column_comments_dict[column.name] = column.comment
|
||||
|
||||
if column_comments_dict:
|
||||
table_info += (
|
||||
f"\n\n/*\nColumn Comments: {column_comments_dict}\n*/"
|
||||
)
|
||||
except Exception:
|
||||
raise ValueError(
|
||||
"Column comments are available on PostgreSQL, MySQL, Oracle"
|
||||
)
|
||||
|
||||
has_extra_info = (
|
||||
self._indexes_in_table_info or self._sample_rows_in_table_info
|
||||
)
|
||||
|
259
libs/community/tests/unit_tests/test_sql_get_table_info.py
Normal file
259
libs/community/tests/unit_tests/test_sql_get_table_info.py
Normal file
@ -0,0 +1,259 @@
|
||||
import unittest
|
||||
from typing import Dict, Optional
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from sqlalchemy import Column, Integer, MetaData, String, Table
|
||||
|
||||
from langchain_community.utilities.sql_database import SQLDatabase
|
||||
|
||||
|
||||
class TestSQLDatabaseComments(unittest.TestCase):
|
||||
"""Test class for column comment functionality in SQLDatabase"""
|
||||
|
||||
def setUp(self) -> None:
|
||||
"""Setup before each test"""
|
||||
# Mock Engine
|
||||
self.mock_engine = MagicMock()
|
||||
self.mock_engine.dialect.name = "postgresql" # Default to PostgreSQL
|
||||
|
||||
# Mock inspector and start patch *before* SQLDatabase initialization
|
||||
self.mock_inspector = MagicMock()
|
||||
# Mock table name list and other inspector methods called during init
|
||||
self.mock_inspector.get_table_names.return_value = ["test_table"]
|
||||
self.mock_inspector.get_view_names.return_value = []
|
||||
self.mock_inspector.get_indexes.return_value = []
|
||||
# Mock get_columns to return something reasonable for reflection
|
||||
self.mock_inspector.get_columns.return_value = [
|
||||
{
|
||||
"name": "id",
|
||||
"type": Integer(),
|
||||
"nullable": False,
|
||||
"default": None,
|
||||
"autoincrement": "auto",
|
||||
"comment": None,
|
||||
},
|
||||
{
|
||||
"name": "name",
|
||||
"type": String(100),
|
||||
"nullable": True,
|
||||
"default": None,
|
||||
"autoincrement": "auto",
|
||||
"comment": None,
|
||||
},
|
||||
{
|
||||
"name": "age",
|
||||
"type": Integer(),
|
||||
"nullable": True,
|
||||
"default": None,
|
||||
"autoincrement": "auto",
|
||||
"comment": None,
|
||||
},
|
||||
]
|
||||
# Mock get_pk_constraint for reflection
|
||||
self.mock_inspector.get_pk_constraint.return_value = {
|
||||
"constrained_columns": ["id"],
|
||||
"name": None,
|
||||
}
|
||||
# Mock get_foreign_keys for reflection
|
||||
self.mock_inspector.get_foreign_keys.return_value = []
|
||||
|
||||
# Patch sqlalchemy.inspect to return our mock inspector
|
||||
self.patch_inspector = patch(
|
||||
"langchain_community.utilities.sql_database.inspect",
|
||||
return_value=self.mock_inspector,
|
||||
)
|
||||
# Start the patch *before* creating the SQLDatabase instance
|
||||
self.mock_inspect = self.patch_inspector.start()
|
||||
|
||||
# Mock metadata
|
||||
self.metadata = MetaData()
|
||||
|
||||
# Create test database object *after* patching inspect
|
||||
try:
|
||||
self.db = SQLDatabase(
|
||||
engine=self.mock_engine,
|
||||
metadata=self.metadata,
|
||||
lazy_table_reflection=True,
|
||||
)
|
||||
except Exception as e:
|
||||
self.fail(f"Unexpected exception during SQLDatabase init: {e}")
|
||||
|
||||
def tearDown(self) -> None:
|
||||
"""Cleanup after each test"""
|
||||
self.patch_inspector.stop()
|
||||
|
||||
def setup_mock_table_with_comments(
|
||||
self, dialect: str, comments: Optional[Dict[str, str]] = None
|
||||
) -> Table:
|
||||
"""Setup a mock table with comments
|
||||
|
||||
Args:
|
||||
dialect (str): Database dialect to test (postgresql, mysql, oracle)
|
||||
comments (dict, optional): Column comments. Uses default comments if None
|
||||
|
||||
Returns:
|
||||
Table: The created mock table
|
||||
"""
|
||||
# Default comments
|
||||
if comments is None:
|
||||
comments = {
|
||||
"id": "Primary key",
|
||||
"name": "Name of the person",
|
||||
"age": "Age of the person",
|
||||
}
|
||||
|
||||
# Set engine dialect
|
||||
self.mock_engine.dialect.name = dialect
|
||||
|
||||
# Clear existing metadata if necessary, or use a fresh MetaData object
|
||||
self.metadata.clear()
|
||||
|
||||
# Create test table
|
||||
test_table = Table(
|
||||
"test_table",
|
||||
self.metadata,
|
||||
Column("id", Integer, primary_key=True, comment=comments.get("id")),
|
||||
Column("name", String(100), comment=comments.get("name")),
|
||||
Column("age", Integer, comment=comments.get("age")),
|
||||
)
|
||||
|
||||
# Mock reflection to return the columns with comments
|
||||
# This is crucial because lazy reflection will call inspect later
|
||||
self.mock_inspector.get_columns.return_value = [
|
||||
{
|
||||
"name": "id",
|
||||
"type": Integer(),
|
||||
"nullable": False,
|
||||
"default": None,
|
||||
"autoincrement": "auto",
|
||||
"comment": comments.get("id"),
|
||||
},
|
||||
{
|
||||
"name": "name",
|
||||
"type": String(100),
|
||||
"nullable": True,
|
||||
"default": None,
|
||||
"autoincrement": "auto",
|
||||
"comment": comments.get("name"),
|
||||
},
|
||||
{
|
||||
"name": "age",
|
||||
"type": Integer(),
|
||||
"nullable": True,
|
||||
"default": None,
|
||||
"autoincrement": "auto",
|
||||
"comment": comments.get("age"),
|
||||
},
|
||||
]
|
||||
self.mock_inspector.get_table_names.return_value = [
|
||||
"test_table"
|
||||
] # Ensure table is discoverable
|
||||
|
||||
# No need to mock CreateTable here, let the actual code call it.
|
||||
# We will patch it during the get_table_info call in the tests.
|
||||
|
||||
# No need to manually add table to metadata, reflection handles it
|
||||
# self.metadata._add_table("test_table", None, test_table)
|
||||
|
||||
return test_table
|
||||
|
||||
def _run_test_with_mocked_createtable(self, dialect: str) -> None:
|
||||
"""Helper function to run comment tests with CreateTable mocked."""
|
||||
self.setup_mock_table_with_comments(dialect)
|
||||
|
||||
# Define the expected CREATE TABLE string
|
||||
expected_create_table_sql = (
|
||||
"CREATE TABLE test_table (\n\tid INTEGER NOT NULL, "
|
||||
"\n\tname VARCHAR(100), \n\tage INTEGER, \n\tPRIMARY KEY (id)\n)"
|
||||
)
|
||||
|
||||
# Patch CreateTable specifically for the get_table_info call
|
||||
with patch(
|
||||
"langchain_community.utilities.sql_database.CreateTable"
|
||||
) as MockCreateTable:
|
||||
# Mock the compile method to return a specific string
|
||||
mock_compiler = MockCreateTable.return_value.compile
|
||||
mock_compiler.return_value = expected_create_table_sql
|
||||
|
||||
# Call get_table_info with get_col_comments=True
|
||||
table_info = self.db.get_table_info(get_col_comments=True)
|
||||
|
||||
# Verify CREATE TABLE statement (using the mocked value)
|
||||
self.assertIn(expected_create_table_sql.strip(), table_info)
|
||||
|
||||
# Verify comments are included in table info in the correct format
|
||||
self.assertIn("/*\nColumn Comments:", table_info)
|
||||
self.assertIn("'id': 'Primary key'", table_info)
|
||||
self.assertIn("'name': 'Name of the person'", table_info)
|
||||
self.assertIn("'age': 'Age of the person'", table_info)
|
||||
self.assertIn("*/", table_info)
|
||||
|
||||
def test_postgres_get_col_comments(self) -> None:
|
||||
"""Test retrieving column comments from PostgreSQL"""
|
||||
self._run_test_with_mocked_createtable("postgresql")
|
||||
|
||||
def test_mysql_get_col_comments(self) -> None:
|
||||
"""Test retrieving column comments from MySQL"""
|
||||
self._run_test_with_mocked_createtable("mysql")
|
||||
|
||||
def test_oracle_get_col_comments(self) -> None:
|
||||
"""Test retrieving column comments from Oracle"""
|
||||
self._run_test_with_mocked_createtable("oracle")
|
||||
|
||||
def test_sqlite_no_comments(self) -> None:
|
||||
"""Test that SQLite does not add a comment block when comments are missing."""
|
||||
# Setup SQLite table (comments will be ignored by SQLAlchemy for SQLite)
|
||||
self.setup_mock_table_with_comments("sqlite", comments={})
|
||||
# Mock reflection to return columns *without* comments
|
||||
self.mock_inspector.get_columns.return_value = [
|
||||
{
|
||||
"name": "id",
|
||||
"type": Integer(),
|
||||
"nullable": False,
|
||||
"default": None,
|
||||
"autoincrement": "auto",
|
||||
"comment": None,
|
||||
},
|
||||
{
|
||||
"name": "name",
|
||||
"type": String(100),
|
||||
"nullable": True,
|
||||
"default": None,
|
||||
"autoincrement": "auto",
|
||||
"comment": None,
|
||||
},
|
||||
{
|
||||
"name": "age",
|
||||
"type": Integer(),
|
||||
"nullable": True,
|
||||
"default": None,
|
||||
"autoincrement": "auto",
|
||||
"comment": None,
|
||||
},
|
||||
]
|
||||
|
||||
# Define the expected CREATE TABLE string
|
||||
expected_create_table_sql = (
|
||||
"CREATE TABLE test_table (\n\tid INTEGER NOT NULL, "
|
||||
"\n\tname VARCHAR(100), \n\tage INTEGER, \n\tPRIMARY KEY (id)\n)"
|
||||
)
|
||||
|
||||
# Patch CreateTable specifically for the get_table_info call
|
||||
with patch(
|
||||
"langchain_community.utilities.sql_database.CreateTable"
|
||||
) as MockCreateTable:
|
||||
mock_compiler = MockCreateTable.return_value.compile
|
||||
mock_compiler.return_value = expected_create_table_sql
|
||||
|
||||
# Call get_table_info with get_col_comments=True
|
||||
# Even if True, SQLite won't have comments to add.
|
||||
table_info = self.db.get_table_info(get_col_comments=True)
|
||||
|
||||
# Verify CREATE TABLE statement
|
||||
self.assertIn(expected_create_table_sql.strip(), table_info)
|
||||
# Verify comments block is NOT included
|
||||
self.assertNotIn("Column Comments:", table_info)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
@ -35,6 +35,8 @@ def create_sql_query_chain(
|
||||
db: SQLDatabase,
|
||||
prompt: Optional[BasePromptTemplate] = None,
|
||||
k: int = 5,
|
||||
*,
|
||||
get_col_comments: Optional[bool] = None,
|
||||
) -> Runnable[Union[SQLInput, SQLInputWithTables, dict[str, Any]], str]:
|
||||
"""Create a chain that generates SQL queries.
|
||||
|
||||
@ -59,6 +61,8 @@ def create_sql_query_chain(
|
||||
prompt: The prompt to use. If none is provided, will choose one
|
||||
based on dialect. Defaults to None. See Prompt section below for more.
|
||||
k: The number of results per select statement to return. Defaults to 5.
|
||||
get_col_comments: Whether to retrieve column comments along with table info.
|
||||
Defaults to False.
|
||||
|
||||
Returns:
|
||||
A chain that takes in a question and generates a SQL query that answers
|
||||
@ -127,10 +131,22 @@ def create_sql_query_chain(
|
||||
if "dialect" in prompt_to_use.input_variables:
|
||||
prompt_to_use = prompt_to_use.partial(dialect=db.dialect)
|
||||
|
||||
table_info_kwargs = {}
|
||||
if get_col_comments:
|
||||
if db.dialect not in ("postgresql", "mysql", "oracle"):
|
||||
raise ValueError(
|
||||
f"get_col_comments=True is only supported for dialects "
|
||||
f"'postgresql', 'mysql', and 'oracle'. Received dialect: "
|
||||
f"{db.dialect}"
|
||||
)
|
||||
else:
|
||||
table_info_kwargs["get_col_comments"] = True
|
||||
|
||||
inputs = {
|
||||
"input": lambda x: x["question"] + "\nSQLQuery: ",
|
||||
"table_info": lambda x: db.get_table_info(
|
||||
table_names=x.get("table_names_to_use")
|
||||
table_names=x.get("table_names_to_use"),
|
||||
**table_info_kwargs,
|
||||
),
|
||||
}
|
||||
return (
|
||||
|
Loading…
Reference in New Issue
Block a user