mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-21 14:18:52 +00:00
community[patch]: add args_schema to sql_database tools for langGraph integration (#18595)
- **Description:** This modification adds pydantic input definition for sql_database tools. This helps for function calling capability in LangGraph. Since actions nodes will usually check for the args_schema attribute on tools, This update should make these tools compatible with it (only implemented on the InfoSQLDatabaseTool) - **Issue:** N/A - **Dependencies:** N/A - **Twitter handle:** juanfe8881
This commit is contained in:
parent
c922ea36cb
commit
f5b9aedc48
@ -39,6 +39,7 @@ class QuerySQLDataBaseTool(BaseSQLDatabaseTool, BaseTool):
|
|||||||
If the query is not correct, an error message will be returned.
|
If the query is not correct, an error message will be returned.
|
||||||
If an error is returned, rewrite the query, check the query, and try again.
|
If an error is returned, rewrite the query, check the query, and try again.
|
||||||
"""
|
"""
|
||||||
|
args_schema: Type[BaseModel] = _QuerySQLDataBaseToolInput
|
||||||
|
|
||||||
def _run(
|
def _run(
|
||||||
self,
|
self,
|
||||||
@ -77,11 +78,16 @@ class InfoSQLDatabaseTool(BaseSQLDatabaseTool, BaseTool):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class _ListSQLDataBaseToolInput(BaseModel):
|
||||||
|
tool_input: str = Field(..., description="An empty string")
|
||||||
|
|
||||||
|
|
||||||
class ListSQLDatabaseTool(BaseSQLDatabaseTool, BaseTool):
|
class ListSQLDatabaseTool(BaseSQLDatabaseTool, BaseTool):
|
||||||
"""Tool for getting tables names."""
|
"""Tool for getting tables names."""
|
||||||
|
|
||||||
name: str = "sql_db_list_tables"
|
name: str = "sql_db_list_tables"
|
||||||
description: str = "Input is an empty string, output is a comma separated list of tables in the database."
|
description: str = "Input is an empty string, output is a comma separated list of tables in the database."
|
||||||
|
args_schema: Type[BaseModel] = _ListSQLDataBaseToolInput
|
||||||
|
|
||||||
def _run(
|
def _run(
|
||||||
self,
|
self,
|
||||||
@ -92,6 +98,10 @@ class ListSQLDatabaseTool(BaseSQLDatabaseTool, BaseTool):
|
|||||||
return ", ".join(self.db.get_usable_table_names())
|
return ", ".join(self.db.get_usable_table_names())
|
||||||
|
|
||||||
|
|
||||||
|
class _QuerySQLCheckerToolInput(BaseModel):
|
||||||
|
query: str = Field(..., description="A detailed and SQL query to be checked.")
|
||||||
|
|
||||||
|
|
||||||
class QuerySQLCheckerTool(BaseSQLDatabaseTool, BaseTool):
|
class QuerySQLCheckerTool(BaseSQLDatabaseTool, BaseTool):
|
||||||
"""Use an LLM to check if a query is correct.
|
"""Use an LLM to check if a query is correct.
|
||||||
Adapted from https://www.patterns.app/blog/2023/01/18/crunchbot-sql-analyst-gpt/"""
|
Adapted from https://www.patterns.app/blog/2023/01/18/crunchbot-sql-analyst-gpt/"""
|
||||||
@ -104,6 +114,7 @@ class QuerySQLCheckerTool(BaseSQLDatabaseTool, BaseTool):
|
|||||||
Use this tool to double check if your query is correct before executing it.
|
Use this tool to double check if your query is correct before executing it.
|
||||||
Always use this tool before executing a query with sql_db_query!
|
Always use this tool before executing a query with sql_db_query!
|
||||||
"""
|
"""
|
||||||
|
args_schema: Type[BaseModel] = _QuerySQLCheckerToolInput
|
||||||
|
|
||||||
@root_validator(pre=True)
|
@root_validator(pre=True)
|
||||||
def initialize_llm_chain(cls, values: Dict[str, Any]) -> Dict[str, Any]:
|
def initialize_llm_chain(cls, values: Dict[str, Any]) -> Dict[str, Any]:
|
||||||
|
Loading…
Reference in New Issue
Block a user