community[patch]: add args_schema to sql_database tools for langGraph integration (#18595)

- **Description:** This modification adds pydantic input definition for
sql_database tools. This helps for function calling capability in
LangGraph. Since actions nodes will usually check for the args_schema
attribute on tools, This update should make these tools compatible with
it (only implemented on the InfoSQLDatabaseTool)
  - **Issue:** N/A
  - **Dependencies:** N/A
  - **Twitter handle:** juanfe8881
This commit is contained in:
Juan Felipe Arias 2024-03-15 20:03:36 +01:00 committed by GitHub
parent c922ea36cb
commit f5b9aedc48
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -39,6 +39,7 @@ class QuerySQLDataBaseTool(BaseSQLDatabaseTool, BaseTool):
If the query is not correct, an error message will be returned. If the query is not correct, an error message will be returned.
If an error is returned, rewrite the query, check the query, and try again. If an error is returned, rewrite the query, check the query, and try again.
""" """
args_schema: Type[BaseModel] = _QuerySQLDataBaseToolInput
def _run( def _run(
self, self,
@ -77,11 +78,16 @@ class InfoSQLDatabaseTool(BaseSQLDatabaseTool, BaseTool):
) )
class _ListSQLDataBaseToolInput(BaseModel):
tool_input: str = Field(..., description="An empty string")
class ListSQLDatabaseTool(BaseSQLDatabaseTool, BaseTool): class ListSQLDatabaseTool(BaseSQLDatabaseTool, BaseTool):
"""Tool for getting tables names.""" """Tool for getting tables names."""
name: str = "sql_db_list_tables" name: str = "sql_db_list_tables"
description: str = "Input is an empty string, output is a comma separated list of tables in the database." description: str = "Input is an empty string, output is a comma separated list of tables in the database."
args_schema: Type[BaseModel] = _ListSQLDataBaseToolInput
def _run( def _run(
self, self,
@ -92,6 +98,10 @@ class ListSQLDatabaseTool(BaseSQLDatabaseTool, BaseTool):
return ", ".join(self.db.get_usable_table_names()) return ", ".join(self.db.get_usable_table_names())
class _QuerySQLCheckerToolInput(BaseModel):
query: str = Field(..., description="A detailed and SQL query to be checked.")
class QuerySQLCheckerTool(BaseSQLDatabaseTool, BaseTool): class QuerySQLCheckerTool(BaseSQLDatabaseTool, BaseTool):
"""Use an LLM to check if a query is correct. """Use an LLM to check if a query is correct.
Adapted from https://www.patterns.app/blog/2023/01/18/crunchbot-sql-analyst-gpt/""" Adapted from https://www.patterns.app/blog/2023/01/18/crunchbot-sql-analyst-gpt/"""
@ -104,6 +114,7 @@ class QuerySQLCheckerTool(BaseSQLDatabaseTool, BaseTool):
Use this tool to double check if your query is correct before executing it. Use this tool to double check if your query is correct before executing it.
Always use this tool before executing a query with sql_db_query! Always use this tool before executing a query with sql_db_query!
""" """
args_schema: Type[BaseModel] = _QuerySQLCheckerToolInput
@root_validator(pre=True) @root_validator(pre=True)
def initialize_llm_chain(cls, values: Dict[str, Any]) -> Dict[str, Any]: def initialize_llm_chain(cls, values: Dict[str, Any]) -> Dict[str, Any]: