community[patch]: fix QuerySQLDatabaseTool name (#28659)

Co-authored-by: Chester Curme <chester.curme@gmail.com>
This commit is contained in:
Bagatur
2024-12-12 19:16:03 -08:00
committed by GitHub
parent 94c22c3f48
commit fa06188834
7 changed files with 84 additions and 65 deletions

View File

@@ -302,6 +302,7 @@ if TYPE_CHECKING:
ListSQLDatabaseTool,
QuerySQLCheckerTool,
QuerySQLDataBaseTool,
QuerySQLDatabaseTool,
)
from langchain_community.tools.stackexchange.tool import (
StackExchangeTool,
@@ -453,7 +454,8 @@ __all__ = [
"QueryCheckerTool",
"QueryPowerBITool",
"QuerySQLCheckerTool",
"QuerySQLDataBaseTool",
"QuerySQLDatabaseTool",
"QuerySQLDataBaseTool", # Legacy, kept for backwards compatibility.
"QuerySparkSQLTool",
"ReadFileTool",
"RedditSearchRun",
@@ -606,6 +608,8 @@ _module_lookup = {
"QueryCheckerTool": "langchain_community.tools.spark_sql.tool",
"QueryPowerBITool": "langchain_community.tools.powerbi.tool",
"QuerySQLCheckerTool": "langchain_community.tools.sql_database.tool",
"QuerySQLDatabaseTool": "langchain_community.tools.sql_database.tool",
# Legacy, kept for backwards compatibility.
"QuerySQLDataBaseTool": "langchain_community.tools.sql_database.tool",
"QuerySparkSQLTool": "langchain_community.tools.spark_sql.tool",
"ReadFileTool": "langchain_community.tools.file_management",

View File

@@ -28,12 +28,18 @@ class BaseSQLDatabaseTool(BaseModel):
)
class _QuerySQLDataBaseToolInput(BaseModel):
class _QuerySQLDatabaseToolInput(BaseModel):
query: str = Field(..., description="A detailed and correct SQL query.")
class QuerySQLDataBaseTool(BaseSQLDatabaseTool, BaseTool): # type: ignore[override, override]
"""Tool for querying a SQL database."""
class QuerySQLDatabaseTool(BaseSQLDatabaseTool, BaseTool): # type: ignore[override, override]
"""Tool for querying a SQL database.
.. versionchanged:: 0.3.12
Renamed from QuerySQLDataBaseTool to QuerySQLDatabaseTool.
Legacy name still works for backwards compatibility.
"""
name: str = "sql_db_query"
description: str = """
@@ -41,7 +47,7 @@ class QuerySQLDataBaseTool(BaseSQLDatabaseTool, BaseTool): # type: ignore[overr
If the query is not correct, an error message will be returned.
If an error is returned, rewrite the query, check the query, and try again.
"""
args_schema: Type[BaseModel] = _QuerySQLDataBaseToolInput
args_schema: Type[BaseModel] = _QuerySQLDatabaseToolInput
def _run(
self,
@@ -52,6 +58,10 @@ class QuerySQLDataBaseTool(BaseSQLDatabaseTool, BaseTool): # type: ignore[overr
return self.db.run_no_throw(query)
# Backwards compatibility.
QuerySQLDataBaseTool = QuerySQLDatabaseTool
class _InfoSQLDatabaseToolInput(BaseModel):
table_names: str = Field(
...,
@@ -80,7 +90,7 @@ class InfoSQLDatabaseTool(BaseSQLDatabaseTool, BaseTool): # type: ignore[overri
)
class _ListSQLDataBaseToolInput(BaseModel):
class _ListSQLDatabaseToolInput(BaseModel):
tool_input: str = Field("", description="An empty string")
@@ -89,7 +99,7 @@ class ListSQLDatabaseTool(BaseSQLDatabaseTool, BaseTool): # type: ignore[overri
name: str = "sql_db_list_tables"
description: str = "Input is an empty string, output is a comma-separated list of tables in the database."
args_schema: Type[BaseModel] = _ListSQLDataBaseToolInput
args_schema: Type[BaseModel] = _ListSQLDatabaseToolInput
def _run(
self,