mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-31 10:23:18 +00:00
community[patch]: fix QuerySQLDatabaseTool name (#28659)
Co-authored-by: Chester Curme <chester.curme@gmail.com>
This commit is contained in:
@@ -13,7 +13,10 @@ from langchain_community.tools.sql_database.tool import (
|
||||
InfoSQLDatabaseTool,
|
||||
ListSQLDatabaseTool,
|
||||
QuerySQLCheckerTool,
|
||||
QuerySQLDataBaseTool,
|
||||
QuerySQLDatabaseTool,
|
||||
)
|
||||
from langchain_community.tools.sql_database.tool import (
|
||||
QuerySQLDataBaseTool as QuerySQLDataBaseTool, # keep import for backwards compat.
|
||||
)
|
||||
from langchain_community.utilities.sql_database import SQLDatabase
|
||||
|
||||
@@ -110,7 +113,7 @@ class SQLDatabaseToolkit(BaseToolkit):
|
||||
f"'xxxx' in 'field list', use {info_sql_database_tool.name} "
|
||||
"to query the correct table fields."
|
||||
)
|
||||
query_sql_database_tool = QuerySQLDataBaseTool(
|
||||
query_sql_database_tool = QuerySQLDatabaseTool(
|
||||
db=self.db, description=query_sql_database_tool_description
|
||||
)
|
||||
query_sql_checker_tool_description = (
|
||||
|
@@ -302,6 +302,7 @@ if TYPE_CHECKING:
|
||||
ListSQLDatabaseTool,
|
||||
QuerySQLCheckerTool,
|
||||
QuerySQLDataBaseTool,
|
||||
QuerySQLDatabaseTool,
|
||||
)
|
||||
from langchain_community.tools.stackexchange.tool import (
|
||||
StackExchangeTool,
|
||||
@@ -453,7 +454,8 @@ __all__ = [
|
||||
"QueryCheckerTool",
|
||||
"QueryPowerBITool",
|
||||
"QuerySQLCheckerTool",
|
||||
"QuerySQLDataBaseTool",
|
||||
"QuerySQLDatabaseTool",
|
||||
"QuerySQLDataBaseTool", # Legacy, kept for backwards compatibility.
|
||||
"QuerySparkSQLTool",
|
||||
"ReadFileTool",
|
||||
"RedditSearchRun",
|
||||
@@ -606,6 +608,8 @@ _module_lookup = {
|
||||
"QueryCheckerTool": "langchain_community.tools.spark_sql.tool",
|
||||
"QueryPowerBITool": "langchain_community.tools.powerbi.tool",
|
||||
"QuerySQLCheckerTool": "langchain_community.tools.sql_database.tool",
|
||||
"QuerySQLDatabaseTool": "langchain_community.tools.sql_database.tool",
|
||||
# Legacy, kept for backwards compatibility.
|
||||
"QuerySQLDataBaseTool": "langchain_community.tools.sql_database.tool",
|
||||
"QuerySparkSQLTool": "langchain_community.tools.spark_sql.tool",
|
||||
"ReadFileTool": "langchain_community.tools.file_management",
|
||||
|
@@ -28,12 +28,18 @@ class BaseSQLDatabaseTool(BaseModel):
|
||||
)
|
||||
|
||||
|
||||
class _QuerySQLDataBaseToolInput(BaseModel):
|
||||
class _QuerySQLDatabaseToolInput(BaseModel):
|
||||
query: str = Field(..., description="A detailed and correct SQL query.")
|
||||
|
||||
|
||||
class QuerySQLDataBaseTool(BaseSQLDatabaseTool, BaseTool): # type: ignore[override, override]
|
||||
"""Tool for querying a SQL database."""
|
||||
class QuerySQLDatabaseTool(BaseSQLDatabaseTool, BaseTool): # type: ignore[override, override]
|
||||
"""Tool for querying a SQL database.
|
||||
|
||||
.. versionchanged:: 0.3.12
|
||||
|
||||
Renamed from QuerySQLDataBaseTool to QuerySQLDatabaseTool.
|
||||
Legacy name still works for backwards compatibility.
|
||||
"""
|
||||
|
||||
name: str = "sql_db_query"
|
||||
description: str = """
|
||||
@@ -41,7 +47,7 @@ class QuerySQLDataBaseTool(BaseSQLDatabaseTool, BaseTool): # type: ignore[overr
|
||||
If the query is not correct, an error message will be returned.
|
||||
If an error is returned, rewrite the query, check the query, and try again.
|
||||
"""
|
||||
args_schema: Type[BaseModel] = _QuerySQLDataBaseToolInput
|
||||
args_schema: Type[BaseModel] = _QuerySQLDatabaseToolInput
|
||||
|
||||
def _run(
|
||||
self,
|
||||
@@ -52,6 +58,10 @@ class QuerySQLDataBaseTool(BaseSQLDatabaseTool, BaseTool): # type: ignore[overr
|
||||
return self.db.run_no_throw(query)
|
||||
|
||||
|
||||
# Backwards compatibility.
|
||||
QuerySQLDataBaseTool = QuerySQLDatabaseTool
|
||||
|
||||
|
||||
class _InfoSQLDatabaseToolInput(BaseModel):
|
||||
table_names: str = Field(
|
||||
...,
|
||||
@@ -80,7 +90,7 @@ class InfoSQLDatabaseTool(BaseSQLDatabaseTool, BaseTool): # type: ignore[overri
|
||||
)
|
||||
|
||||
|
||||
class _ListSQLDataBaseToolInput(BaseModel):
|
||||
class _ListSQLDatabaseToolInput(BaseModel):
|
||||
tool_input: str = Field("", description="An empty string")
|
||||
|
||||
|
||||
@@ -89,7 +99,7 @@ class ListSQLDatabaseTool(BaseSQLDatabaseTool, BaseTool): # type: ignore[overri
|
||||
|
||||
name: str = "sql_db_list_tables"
|
||||
description: str = "Input is an empty string, output is a comma-separated list of tables in the database."
|
||||
args_schema: Type[BaseModel] = _ListSQLDataBaseToolInput
|
||||
args_schema: Type[BaseModel] = _ListSQLDatabaseToolInput
|
||||
|
||||
def _run(
|
||||
self,
|
||||
|
@@ -41,4 +41,5 @@ def test_tool_names_unique() -> None:
|
||||
tool_classes = _get_tool_classes(skip_tools_without_default_names=True)
|
||||
names = sorted([tool_cls.model_fields["name"].default for tool_cls in tool_classes])
|
||||
duplicated_names = [name for name in names if names.count(name) > 1]
|
||||
assert not duplicated_names
|
||||
# only one duplicate currently exists
|
||||
assert duplicated_names == ["sql_db_query", "sql_db_query"]
|
||||
|
@@ -105,6 +105,7 @@ EXPECTED_ALL = [
|
||||
"QueryCheckerTool",
|
||||
"QueryPowerBITool",
|
||||
"QuerySQLCheckerTool",
|
||||
"QuerySQLDatabaseTool",
|
||||
"QuerySQLDataBaseTool",
|
||||
"QuerySparkSQLTool",
|
||||
"ReadFileTool",
|
||||
|
Reference in New Issue
Block a user