community[patch]: add args_schema to sql_database tools for langGraph integration (#18595)

- **Description:** This modification adds pydantic input definition for
sql_database tools. This helps for function calling capability in
LangGraph. Since actions nodes will usually check for the args_schema
attribute on tools, This update should make these tools compatible with
it (only implemented on the InfoSQLDatabaseTool)
  - **Issue:** N/A
  - **Dependencies:** N/A
  - **Twitter handle:** juanfe8881
This commit is contained in:
Juan Felipe Arias 2024-03-15 20:03:36 +01:00 committed by GitHub
parent c922ea36cb
commit f5b9aedc48
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -39,6 +39,7 @@ class QuerySQLDataBaseTool(BaseSQLDatabaseTool, BaseTool):
If the query is not correct, an error message will be returned.
If an error is returned, rewrite the query, check the query, and try again.
"""
args_schema: Type[BaseModel] = _QuerySQLDataBaseToolInput
def _run(
self,
@ -77,11 +78,16 @@ class InfoSQLDatabaseTool(BaseSQLDatabaseTool, BaseTool):
)
class _ListSQLDataBaseToolInput(BaseModel):
tool_input: str = Field(..., description="An empty string")
class ListSQLDatabaseTool(BaseSQLDatabaseTool, BaseTool):
"""Tool for getting tables names."""
name: str = "sql_db_list_tables"
description: str = "Input is an empty string, output is a comma separated list of tables in the database."
args_schema: Type[BaseModel] = _ListSQLDataBaseToolInput
def _run(
self,
@ -92,6 +98,10 @@ class ListSQLDatabaseTool(BaseSQLDatabaseTool, BaseTool):
return ", ".join(self.db.get_usable_table_names())
class _QuerySQLCheckerToolInput(BaseModel):
query: str = Field(..., description="A detailed and SQL query to be checked.")
class QuerySQLCheckerTool(BaseSQLDatabaseTool, BaseTool):
"""Use an LLM to check if a query is correct.
Adapted from https://www.patterns.app/blog/2023/01/18/crunchbot-sql-analyst-gpt/"""
@ -104,6 +114,7 @@ class QuerySQLCheckerTool(BaseSQLDatabaseTool, BaseTool):
Use this tool to double check if your query is correct before executing it.
Always use this tool before executing a query with sql_db_query!
"""
args_schema: Type[BaseModel] = _QuerySQLCheckerToolInput
@root_validator(pre=True)
def initialize_llm_chain(cls, values: Dict[str, Any]) -> Dict[str, Any]: