community[patch]: fix QuerySQLDatabaseTool name (#28659)

Co-authored-by: Chester Curme <chester.curme@gmail.com>
This commit is contained in:
Bagatur
2024-12-12 19:16:03 -08:00
committed by GitHub
parent 94c22c3f48
commit fa06188834
7 changed files with 84 additions and 65 deletions

View File

@@ -13,7 +13,10 @@ from langchain_community.tools.sql_database.tool import (
InfoSQLDatabaseTool,
ListSQLDatabaseTool,
QuerySQLCheckerTool,
QuerySQLDataBaseTool,
QuerySQLDatabaseTool,
)
from langchain_community.tools.sql_database.tool import (
QuerySQLDataBaseTool as QuerySQLDataBaseTool, # keep import for backwards compat.
)
from langchain_community.utilities.sql_database import SQLDatabase
@@ -110,7 +113,7 @@ class SQLDatabaseToolkit(BaseToolkit):
f"'xxxx' in 'field list', use {info_sql_database_tool.name} "
"to query the correct table fields."
)
query_sql_database_tool = QuerySQLDataBaseTool(
query_sql_database_tool = QuerySQLDatabaseTool(
db=self.db, description=query_sql_database_tool_description
)
query_sql_checker_tool_description = (

View File

@@ -302,6 +302,7 @@ if TYPE_CHECKING:
ListSQLDatabaseTool,
QuerySQLCheckerTool,
QuerySQLDataBaseTool,
QuerySQLDatabaseTool,
)
from langchain_community.tools.stackexchange.tool import (
StackExchangeTool,
@@ -453,7 +454,8 @@ __all__ = [
"QueryCheckerTool",
"QueryPowerBITool",
"QuerySQLCheckerTool",
"QuerySQLDataBaseTool",
"QuerySQLDatabaseTool",
"QuerySQLDataBaseTool", # Legacy, kept for backwards compatibility.
"QuerySparkSQLTool",
"ReadFileTool",
"RedditSearchRun",
@@ -606,6 +608,8 @@ _module_lookup = {
"QueryCheckerTool": "langchain_community.tools.spark_sql.tool",
"QueryPowerBITool": "langchain_community.tools.powerbi.tool",
"QuerySQLCheckerTool": "langchain_community.tools.sql_database.tool",
"QuerySQLDatabaseTool": "langchain_community.tools.sql_database.tool",
# Legacy, kept for backwards compatibility.
"QuerySQLDataBaseTool": "langchain_community.tools.sql_database.tool",
"QuerySparkSQLTool": "langchain_community.tools.spark_sql.tool",
"ReadFileTool": "langchain_community.tools.file_management",

View File

@@ -28,12 +28,18 @@ class BaseSQLDatabaseTool(BaseModel):
)
class _QuerySQLDataBaseToolInput(BaseModel):
class _QuerySQLDatabaseToolInput(BaseModel):
query: str = Field(..., description="A detailed and correct SQL query.")
class QuerySQLDataBaseTool(BaseSQLDatabaseTool, BaseTool): # type: ignore[override, override]
"""Tool for querying a SQL database."""
class QuerySQLDatabaseTool(BaseSQLDatabaseTool, BaseTool): # type: ignore[override, override]
"""Tool for querying a SQL database.
.. versionchanged:: 0.3.12
Renamed from QuerySQLDataBaseTool to QuerySQLDatabaseTool.
Legacy name still works for backwards compatibility.
"""
name: str = "sql_db_query"
description: str = """
@@ -41,7 +47,7 @@ class QuerySQLDataBaseTool(BaseSQLDatabaseTool, BaseTool): # type: ignore[overr
If the query is not correct, an error message will be returned.
If an error is returned, rewrite the query, check the query, and try again.
"""
args_schema: Type[BaseModel] = _QuerySQLDataBaseToolInput
args_schema: Type[BaseModel] = _QuerySQLDatabaseToolInput
def _run(
self,
@@ -52,6 +58,10 @@ class QuerySQLDataBaseTool(BaseSQLDatabaseTool, BaseTool): # type: ignore[overr
return self.db.run_no_throw(query)
# Backwards compatibility.
QuerySQLDataBaseTool = QuerySQLDatabaseTool
class _InfoSQLDatabaseToolInput(BaseModel):
table_names: str = Field(
...,
@@ -80,7 +90,7 @@ class InfoSQLDatabaseTool(BaseSQLDatabaseTool, BaseTool): # type: ignore[overri
)
class _ListSQLDataBaseToolInput(BaseModel):
class _ListSQLDatabaseToolInput(BaseModel):
tool_input: str = Field("", description="An empty string")
@@ -89,7 +99,7 @@ class ListSQLDatabaseTool(BaseSQLDatabaseTool, BaseTool): # type: ignore[overri
name: str = "sql_db_list_tables"
description: str = "Input is an empty string, output is a comma-separated list of tables in the database."
args_schema: Type[BaseModel] = _ListSQLDataBaseToolInput
args_schema: Type[BaseModel] = _ListSQLDatabaseToolInput
def _run(
self,

View File

@@ -41,4 +41,5 @@ def test_tool_names_unique() -> None:
tool_classes = _get_tool_classes(skip_tools_without_default_names=True)
names = sorted([tool_cls.model_fields["name"].default for tool_cls in tool_classes])
duplicated_names = [name for name in names if names.count(name) > 1]
assert not duplicated_names
# only one duplicate currently exists
assert duplicated_names == ["sql_db_query", "sql_db_query"]

View File

@@ -105,6 +105,7 @@ EXPECTED_ALL = [
"QueryCheckerTool",
"QueryPowerBITool",
"QuerySQLCheckerTool",
"QuerySQLDatabaseTool",
"QuerySQLDataBaseTool",
"QuerySparkSQLTool",
"ReadFileTool",