mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-19 21:33:51 +00:00
community[patch]: fix QuerySQLDatabaseTool name (#28659)
Co-authored-by: Chester Curme <chester.curme@gmail.com>
This commit is contained in:
parent
94c22c3f48
commit
fa06188834
@ -192,10 +192,10 @@
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"[QuerySQLDataBaseTool(description=\"Input to this tool is a detailed and correct SQL query, output is a result from the database. If the query is not correct, an error message will be returned. If an error is returned, rewrite the query, check the query, and try again. If you encounter an issue with Unknown column 'xxxx' in 'field list', use sql_db_schema to query the correct table fields.\", db=<langchain_community.utilities.sql_database.SQLDatabase object at 0x105e02860>),\n",
|
||||
" InfoSQLDatabaseTool(description='Input to this tool is a comma-separated list of tables, output is the schema and sample rows for those tables. Be sure that the tables actually exist by calling sql_db_list_tables first! Example Input: table1, table2, table3', db=<langchain_community.utilities.sql_database.SQLDatabase object at 0x105e02860>),\n",
|
||||
" ListSQLDatabaseTool(db=<langchain_community.utilities.sql_database.SQLDatabase object at 0x105e02860>),\n",
|
||||
" QuerySQLCheckerTool(description='Use this tool to double check if your query is correct before executing it. Always use this tool before executing a query with sql_db_query!', db=<langchain_community.utilities.sql_database.SQLDatabase object at 0x105e02860>, llm=ChatOpenAI(client=<openai.resources.chat.completions.Completions object at 0x1148a97b0>, async_client=<openai.resources.chat.completions.AsyncCompletions object at 0x1148aaec0>, temperature=0.0, openai_api_key=SecretStr('**********'), openai_proxy=''), llm_chain=LLMChain(prompt=PromptTemplate(input_variables=['dialect', 'query'], template='\\n{query}\\nDouble check the {dialect} query above for common mistakes, including:\\n- Using NOT IN with NULL values\\n- Using UNION when UNION ALL should have been used\\n- Using BETWEEN for exclusive ranges\\n- Data type mismatch in predicates\\n- Properly quoting identifiers\\n- Using the correct number of arguments for functions\\n- Casting to the correct data type\\n- Using the proper columns for joins\\n\\nIf there are any of the above mistakes, rewrite the query. If there are no mistakes, just reproduce the original query.\\n\\nOutput the final SQL query only.\\n\\nSQL Query: '), llm=ChatOpenAI(client=<openai.resources.chat.completions.Completions object at 0x1148a97b0>, async_client=<openai.resources.chat.completions.AsyncCompletions object at 0x1148aaec0>, temperature=0.0, openai_api_key=SecretStr('**********'), openai_proxy='')))]"
|
||||
"[QuerySQLDatabaseTool(description=\"Input to this tool is a detailed and correct SQL query, output is a result from the database. If the query is not correct, an error message will be returned. If an error is returned, rewrite the query, check the query, and try again. If you encounter an issue with Unknown column 'xxxx' in 'field list', use sql_db_schema to query the correct table fields.\", db=<langchain_community.utilities.sql_database.SQLDatabase object at 0x103d5fa60>),\n",
|
||||
" InfoSQLDatabaseTool(description='Input to this tool is a comma-separated list of tables, output is the schema and sample rows for those tables. Be sure that the tables actually exist by calling sql_db_list_tables first! Example Input: table1, table2, table3', db=<langchain_community.utilities.sql_database.SQLDatabase object at 0x103d5fa60>),\n",
|
||||
" ListSQLDatabaseTool(db=<langchain_community.utilities.sql_database.SQLDatabase object at 0x103d5fa60>),\n",
|
||||
" QuerySQLCheckerTool(description='Use this tool to double check if your query is correct before executing it. Always use this tool before executing a query with sql_db_query!', db=<langchain_community.utilities.sql_database.SQLDatabase object at 0x103d5fa60>, llm=ChatOpenAI(client=<openai.resources.chat.completions.Completions object at 0x10742d720>, async_client=<openai.resources.chat.completions.AsyncCompletions object at 0x10742f7f0>, root_client=<openai.OpenAI object at 0x103d5fac0>, root_async_client=<openai.AsyncOpenAI object at 0x10742d780>, temperature=0.0, model_kwargs={}, openai_api_key=SecretStr('**********')), llm_chain=LLMChain(verbose=False, prompt=PromptTemplate(input_variables=['dialect', 'query'], input_types={}, partial_variables={}, template='\\n{query}\\nDouble check the {dialect} query above for common mistakes, including:\\n- Using NOT IN with NULL values\\n- Using UNION when UNION ALL should have been used\\n- Using BETWEEN for exclusive ranges\\n- Data type mismatch in predicates\\n- Properly quoting identifiers\\n- Using the correct number of arguments for functions\\n- Casting to the correct data type\\n- Using the proper columns for joins\\n\\nIf there are any of the above mistakes, rewrite the query. If there are no mistakes, just reproduce the original query.\\n\\nOutput the final SQL query only.\\n\\nSQL Query: '), llm=ChatOpenAI(client=<openai.resources.chat.completions.Completions object at 0x10742d720>, async_client=<openai.resources.chat.completions.AsyncCompletions object at 0x10742f7f0>, root_client=<openai.OpenAI object at 0x103d5fac0>, root_async_client=<openai.AsyncOpenAI object at 0x10742d780>, temperature=0.0, model_kwargs={}, openai_api_key=SecretStr('**********')), output_parser=StrOutputParser(), llm_kwargs={}))]"
|
||||
]
|
||||
},
|
||||
"execution_count": 4,
|
||||
@ -226,7 +226,7 @@
|
||||
" InfoSQLDatabaseTool,\n",
|
||||
" ListSQLDatabaseTool,\n",
|
||||
" QuerySQLCheckerTool,\n",
|
||||
" QuerySQLDataBaseTool,\n",
|
||||
" QuerySQLDatabaseTool,\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
@ -242,7 +242,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"execution_count": 6,
|
||||
"id": "eda12f8b-be90-4697-ac84-2ece9e2d1708",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
@ -265,7 +265,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 6,
|
||||
"execution_count": 7,
|
||||
"id": "3470ae96-e5e5-4717-a6d6-d7d28c7b7347",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
@ -283,7 +283,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 7,
|
||||
"execution_count": 8,
|
||||
"id": "48bca92c-9b4b-4d5c-bcce-1b239c9e901c",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
@ -305,7 +305,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 8,
|
||||
"execution_count": 9,
|
||||
"id": "39e6d2bf-3194-4aba-854b-63faf919157b",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
@ -318,8 +318,8 @@
|
||||
"Which country's customers spent the most?\n",
|
||||
"==================================\u001b[1m Ai Message \u001b[0m==================================\n",
|
||||
"Tool Calls:\n",
|
||||
" sql_db_list_tables (call_eiheSxiL0s90KE50XyBnBtJY)\n",
|
||||
" Call ID: call_eiheSxiL0s90KE50XyBnBtJY\n",
|
||||
" sql_db_list_tables (call_EBPjyfzqXzFutDn8BklYACLj)\n",
|
||||
" Call ID: call_EBPjyfzqXzFutDn8BklYACLj\n",
|
||||
" Args:\n",
|
||||
"=================================\u001b[1m Tool Message \u001b[0m=================================\n",
|
||||
"Name: sql_db_list_tables\n",
|
||||
@ -327,8 +327,8 @@
|
||||
"Album, Artist, Customer, Employee, Genre, Invoice, InvoiceLine, MediaType, Playlist, PlaylistTrack, Track\n",
|
||||
"==================================\u001b[1m Ai Message \u001b[0m==================================\n",
|
||||
"Tool Calls:\n",
|
||||
" sql_db_schema (call_YKwGWt4UUVmxxY7vjjBDzFLJ)\n",
|
||||
" Call ID: call_YKwGWt4UUVmxxY7vjjBDzFLJ\n",
|
||||
" sql_db_schema (call_kGcnKpxRVFIY8dPjYIJbRoVU)\n",
|
||||
" Call ID: call_kGcnKpxRVFIY8dPjYIJbRoVU\n",
|
||||
" Args:\n",
|
||||
" table_names: Customer, Invoice, InvoiceLine\n",
|
||||
"=================================\u001b[1m Tool Message \u001b[0m=================================\n",
|
||||
@ -405,14 +405,14 @@
|
||||
"*/\n",
|
||||
"==================================\u001b[1m Ai Message \u001b[0m==================================\n",
|
||||
"Tool Calls:\n",
|
||||
" sql_db_query (call_7WBDcMxl1h7MnI05njx1q8V9)\n",
|
||||
" Call ID: call_7WBDcMxl1h7MnI05njx1q8V9\n",
|
||||
" sql_db_query (call_cTfI7OrY64FzJaDd49ILFWw7)\n",
|
||||
" Call ID: call_cTfI7OrY64FzJaDd49ILFWw7\n",
|
||||
" Args:\n",
|
||||
" query: SELECT c.Country, SUM(i.Total) AS TotalSpent FROM Customer c JOIN Invoice i ON c.CustomerId = i.CustomerId GROUP BY c.Country ORDER BY TotalSpent DESC LIMIT 1\n",
|
||||
"=================================\u001b[1m Tool Message \u001b[0m=================================\n",
|
||||
"Name: sql_db_query\n",
|
||||
"\n",
|
||||
"[('USA', 523.0600000000003)]\n",
|
||||
"[('USA', 523.06)]\n",
|
||||
"==================================\u001b[1m Ai Message \u001b[0m==================================\n",
|
||||
"\n",
|
||||
"Customers from the USA spent the most, with a total amount spent of $523.06.\n"
|
||||
@ -440,7 +440,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 9,
|
||||
"execution_count": 10,
|
||||
"id": "23c1235c-6d18-43e4-98ab-85b426b53d94",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
@ -453,8 +453,8 @@
|
||||
"Who are the top 3 best selling artists?\n",
|
||||
"==================================\u001b[1m Ai Message \u001b[0m==================================\n",
|
||||
"Tool Calls:\n",
|
||||
" sql_db_query (call_9F6Bp2vwsDkeLW6FsJFqLiet)\n",
|
||||
" Call ID: call_9F6Bp2vwsDkeLW6FsJFqLiet\n",
|
||||
" sql_db_query (call_xAkvYiRFM7nCMKXsDNvk1OMx)\n",
|
||||
" Call ID: call_xAkvYiRFM7nCMKXsDNvk1OMx\n",
|
||||
" Args:\n",
|
||||
" query: SELECT artist_name, SUM(quantity) AS total_sold FROM sales GROUP BY artist_name ORDER BY total_sold DESC LIMIT 3\n",
|
||||
"=================================\u001b[1m Tool Message \u001b[0m=================================\n",
|
||||
@ -465,8 +465,8 @@
|
||||
"(Background on this error at: https://sqlalche.me/e/20/e3q8)\n",
|
||||
"==================================\u001b[1m Ai Message \u001b[0m==================================\n",
|
||||
"Tool Calls:\n",
|
||||
" sql_db_list_tables (call_Gx5adzWnrBDIIxzUDzsn83zO)\n",
|
||||
" Call ID: call_Gx5adzWnrBDIIxzUDzsn83zO\n",
|
||||
" sql_db_list_tables (call_K4Zvbowsq7XPgGFepbvc5G7i)\n",
|
||||
" Call ID: call_K4Zvbowsq7XPgGFepbvc5G7i\n",
|
||||
" Args:\n",
|
||||
"=================================\u001b[1m Tool Message \u001b[0m=================================\n",
|
||||
"Name: sql_db_list_tables\n",
|
||||
@ -474,8 +474,8 @@
|
||||
"Album, Artist, Customer, Employee, Genre, Invoice, InvoiceLine, MediaType, Playlist, PlaylistTrack, Track\n",
|
||||
"==================================\u001b[1m Ai Message \u001b[0m==================================\n",
|
||||
"Tool Calls:\n",
|
||||
" sql_db_schema (call_ftywrZgEgGWLrnk9dYC0xtZv)\n",
|
||||
" Call ID: call_ftywrZgEgGWLrnk9dYC0xtZv\n",
|
||||
" sql_db_schema (call_tUztueSK7VO2klZ99xT4ZVhM)\n",
|
||||
" Call ID: call_tUztueSK7VO2klZ99xT4ZVhM\n",
|
||||
" Args:\n",
|
||||
" table_names: Artist, Album, InvoiceLine\n",
|
||||
"=================================\u001b[1m Tool Message \u001b[0m=================================\n",
|
||||
@ -534,8 +534,8 @@
|
||||
"*/\n",
|
||||
"==================================\u001b[1m Ai Message \u001b[0m==================================\n",
|
||||
"Tool Calls:\n",
|
||||
" sql_db_query (call_i6n3lmS7E2ZivN758VOayTiy)\n",
|
||||
" Call ID: call_i6n3lmS7E2ZivN758VOayTiy\n",
|
||||
" sql_db_query (call_tVtLQIRPmCM6pukgpHFfq86A)\n",
|
||||
" Call ID: call_tVtLQIRPmCM6pukgpHFfq86A\n",
|
||||
" Args:\n",
|
||||
" query: SELECT Artist.Name AS artist_name, SUM(InvoiceLine.Quantity) AS total_sold FROM Artist JOIN Album ON Artist.ArtistId = Album.ArtistId JOIN Track ON Album.AlbumId = Track.AlbumId JOIN InvoiceLine ON Track.TrackId = InvoiceLine.TrackId GROUP BY Artist.Name ORDER BY total_sold DESC LIMIT 3\n",
|
||||
"=================================\u001b[1m Tool Message \u001b[0m=================================\n",
|
||||
@ -614,7 +614,7 @@
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.10.12"
|
||||
"version": "3.10.4"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
|
File diff suppressed because one or more lines are too long
@ -13,7 +13,10 @@ from langchain_community.tools.sql_database.tool import (
|
||||
InfoSQLDatabaseTool,
|
||||
ListSQLDatabaseTool,
|
||||
QuerySQLCheckerTool,
|
||||
QuerySQLDataBaseTool,
|
||||
QuerySQLDatabaseTool,
|
||||
)
|
||||
from langchain_community.tools.sql_database.tool import (
|
||||
QuerySQLDataBaseTool as QuerySQLDataBaseTool, # keep import for backwards compat.
|
||||
)
|
||||
from langchain_community.utilities.sql_database import SQLDatabase
|
||||
|
||||
@ -110,7 +113,7 @@ class SQLDatabaseToolkit(BaseToolkit):
|
||||
f"'xxxx' in 'field list', use {info_sql_database_tool.name} "
|
||||
"to query the correct table fields."
|
||||
)
|
||||
query_sql_database_tool = QuerySQLDataBaseTool(
|
||||
query_sql_database_tool = QuerySQLDatabaseTool(
|
||||
db=self.db, description=query_sql_database_tool_description
|
||||
)
|
||||
query_sql_checker_tool_description = (
|
||||
|
@ -302,6 +302,7 @@ if TYPE_CHECKING:
|
||||
ListSQLDatabaseTool,
|
||||
QuerySQLCheckerTool,
|
||||
QuerySQLDataBaseTool,
|
||||
QuerySQLDatabaseTool,
|
||||
)
|
||||
from langchain_community.tools.stackexchange.tool import (
|
||||
StackExchangeTool,
|
||||
@ -453,7 +454,8 @@ __all__ = [
|
||||
"QueryCheckerTool",
|
||||
"QueryPowerBITool",
|
||||
"QuerySQLCheckerTool",
|
||||
"QuerySQLDataBaseTool",
|
||||
"QuerySQLDatabaseTool",
|
||||
"QuerySQLDataBaseTool", # Legacy, kept for backwards compatibility.
|
||||
"QuerySparkSQLTool",
|
||||
"ReadFileTool",
|
||||
"RedditSearchRun",
|
||||
@ -606,6 +608,8 @@ _module_lookup = {
|
||||
"QueryCheckerTool": "langchain_community.tools.spark_sql.tool",
|
||||
"QueryPowerBITool": "langchain_community.tools.powerbi.tool",
|
||||
"QuerySQLCheckerTool": "langchain_community.tools.sql_database.tool",
|
||||
"QuerySQLDatabaseTool": "langchain_community.tools.sql_database.tool",
|
||||
# Legacy, kept for backwards compatibility.
|
||||
"QuerySQLDataBaseTool": "langchain_community.tools.sql_database.tool",
|
||||
"QuerySparkSQLTool": "langchain_community.tools.spark_sql.tool",
|
||||
"ReadFileTool": "langchain_community.tools.file_management",
|
||||
|
@ -28,12 +28,18 @@ class BaseSQLDatabaseTool(BaseModel):
|
||||
)
|
||||
|
||||
|
||||
class _QuerySQLDataBaseToolInput(BaseModel):
|
||||
class _QuerySQLDatabaseToolInput(BaseModel):
|
||||
query: str = Field(..., description="A detailed and correct SQL query.")
|
||||
|
||||
|
||||
class QuerySQLDataBaseTool(BaseSQLDatabaseTool, BaseTool): # type: ignore[override, override]
|
||||
"""Tool for querying a SQL database."""
|
||||
class QuerySQLDatabaseTool(BaseSQLDatabaseTool, BaseTool): # type: ignore[override, override]
|
||||
"""Tool for querying a SQL database.
|
||||
|
||||
.. versionchanged:: 0.3.12
|
||||
|
||||
Renamed from QuerySQLDataBaseTool to QuerySQLDatabaseTool.
|
||||
Legacy name still works for backwards compatibility.
|
||||
"""
|
||||
|
||||
name: str = "sql_db_query"
|
||||
description: str = """
|
||||
@ -41,7 +47,7 @@ class QuerySQLDataBaseTool(BaseSQLDatabaseTool, BaseTool): # type: ignore[overr
|
||||
If the query is not correct, an error message will be returned.
|
||||
If an error is returned, rewrite the query, check the query, and try again.
|
||||
"""
|
||||
args_schema: Type[BaseModel] = _QuerySQLDataBaseToolInput
|
||||
args_schema: Type[BaseModel] = _QuerySQLDatabaseToolInput
|
||||
|
||||
def _run(
|
||||
self,
|
||||
@ -52,6 +58,10 @@ class QuerySQLDataBaseTool(BaseSQLDatabaseTool, BaseTool): # type: ignore[overr
|
||||
return self.db.run_no_throw(query)
|
||||
|
||||
|
||||
# Backwards compatibility.
|
||||
QuerySQLDataBaseTool = QuerySQLDatabaseTool
|
||||
|
||||
|
||||
class _InfoSQLDatabaseToolInput(BaseModel):
|
||||
table_names: str = Field(
|
||||
...,
|
||||
@ -80,7 +90,7 @@ class InfoSQLDatabaseTool(BaseSQLDatabaseTool, BaseTool): # type: ignore[overri
|
||||
)
|
||||
|
||||
|
||||
class _ListSQLDataBaseToolInput(BaseModel):
|
||||
class _ListSQLDatabaseToolInput(BaseModel):
|
||||
tool_input: str = Field("", description="An empty string")
|
||||
|
||||
|
||||
@ -89,7 +99,7 @@ class ListSQLDatabaseTool(BaseSQLDatabaseTool, BaseTool): # type: ignore[overri
|
||||
|
||||
name: str = "sql_db_list_tables"
|
||||
description: str = "Input is an empty string, output is a comma-separated list of tables in the database."
|
||||
args_schema: Type[BaseModel] = _ListSQLDataBaseToolInput
|
||||
args_schema: Type[BaseModel] = _ListSQLDatabaseToolInput
|
||||
|
||||
def _run(
|
||||
self,
|
||||
|
@ -41,4 +41,5 @@ def test_tool_names_unique() -> None:
|
||||
tool_classes = _get_tool_classes(skip_tools_without_default_names=True)
|
||||
names = sorted([tool_cls.model_fields["name"].default for tool_cls in tool_classes])
|
||||
duplicated_names = [name for name in names if names.count(name) > 1]
|
||||
assert not duplicated_names
|
||||
# only one duplicate currently exists
|
||||
assert duplicated_names == ["sql_db_query", "sql_db_query"]
|
||||
|
@ -105,6 +105,7 @@ EXPECTED_ALL = [
|
||||
"QueryCheckerTool",
|
||||
"QueryPowerBITool",
|
||||
"QuerySQLCheckerTool",
|
||||
"QuerySQLDatabaseTool",
|
||||
"QuerySQLDataBaseTool",
|
||||
"QuerySparkSQLTool",
|
||||
"ReadFileTool",
|
||||
|
Loading…
Reference in New Issue
Block a user