mirror of
https://github.com/hwchase17/langchain.git
synced 2025-05-29 19:18:53 +00:00
Fix: Pass along kwargs when creating a sql agent (#2350)
Currently, `agent_toolkits.sql.create_sql_agent()` passes kwargs to the `ZeroShotAgent` that it creates but not to `AgentExecutor` that it also creates. This prevents the caller from providing some useful arguments like `max_iterations` and `early_stopping_method` This PR changes `create_sql_agent` so that it passes kwargs to both constructors. --------- Co-authored-by: Zachary Jones <zjones@zetaglobal.com>
This commit is contained in:
parent
7ed8d00bba
commit
c969a779c9
@ -20,6 +20,8 @@ def create_sql_agent(
|
|||||||
format_instructions: str = FORMAT_INSTRUCTIONS,
|
format_instructions: str = FORMAT_INSTRUCTIONS,
|
||||||
input_variables: Optional[List[str]] = None,
|
input_variables: Optional[List[str]] = None,
|
||||||
top_k: int = 10,
|
top_k: int = 10,
|
||||||
|
max_iterations: Optional[int] = 15,
|
||||||
|
early_stopping_method: str = "force",
|
||||||
verbose: bool = False,
|
verbose: bool = False,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> AgentExecutor:
|
) -> AgentExecutor:
|
||||||
@ -41,5 +43,9 @@ def create_sql_agent(
|
|||||||
tool_names = [tool.name for tool in tools]
|
tool_names = [tool.name for tool in tools]
|
||||||
agent = ZeroShotAgent(llm_chain=llm_chain, allowed_tools=tool_names, **kwargs)
|
agent = ZeroShotAgent(llm_chain=llm_chain, allowed_tools=tool_names, **kwargs)
|
||||||
return AgentExecutor.from_agent_and_tools(
|
return AgentExecutor.from_agent_and_tools(
|
||||||
agent=agent, tools=toolkit.get_tools(), verbose=verbose
|
agent=agent,
|
||||||
|
tools=toolkit.get_tools(),
|
||||||
|
verbose=verbose,
|
||||||
|
max_iterations=max_iterations,
|
||||||
|
early_stopping_method=early_stopping_method,
|
||||||
)
|
)
|
||||||
|
Loading…
Reference in New Issue
Block a user