mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-20 22:03:52 +00:00
parent
1b480914b4
commit
1a34c65e01
@ -1,7 +1,6 @@
|
|||||||
"""SQL agent."""
|
"""SQL agent."""
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import warnings
|
|
||||||
from typing import (
|
from typing import (
|
||||||
TYPE_CHECKING,
|
TYPE_CHECKING,
|
||||||
Any,
|
Any,
|
||||||
@ -93,7 +92,7 @@ def create_sql_agent(
|
|||||||
using 'db' and 'llm'. Must provide exactly one of 'db' or 'toolkit'.
|
using 'db' and 'llm'. Must provide exactly one of 'db' or 'toolkit'.
|
||||||
prompt: Complete agent prompt. prompt and {prefix, suffix, format_instructions,
|
prompt: Complete agent prompt. prompt and {prefix, suffix, format_instructions,
|
||||||
input_variables} are mutually exclusive.
|
input_variables} are mutually exclusive.
|
||||||
**kwargs: DEPRECATED. Not used, kept for backwards compatibility.
|
**kwargs: Arbitrary additional Agent args.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
An AgentExecutor with the specified agent_type agent.
|
An AgentExecutor with the specified agent_type agent.
|
||||||
@ -131,13 +130,6 @@ def create_sql_agent(
|
|||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Must provide exactly one of 'toolkit' or 'db'. Received both."
|
"Must provide exactly one of 'toolkit' or 'db'. Received both."
|
||||||
)
|
)
|
||||||
if input_variables:
|
|
||||||
kwargs = kwargs or {}
|
|
||||||
kwargs["input_variables"] = input_variables
|
|
||||||
if kwargs:
|
|
||||||
warnings.warn(
|
|
||||||
f"Received additional kwargs {kwargs} which are no longer supported."
|
|
||||||
)
|
|
||||||
|
|
||||||
toolkit = toolkit or SQLDatabaseToolkit(llm=llm, db=db)
|
toolkit = toolkit or SQLDatabaseToolkit(llm=llm, db=db)
|
||||||
agent_type = agent_type or AgentType.ZERO_SHOT_REACT_DESCRIPTION
|
agent_type = agent_type or AgentType.ZERO_SHOT_REACT_DESCRIPTION
|
||||||
@ -183,6 +175,7 @@ def create_sql_agent(
|
|||||||
runnable=create_react_agent(llm, tools, prompt),
|
runnable=create_react_agent(llm, tools, prompt),
|
||||||
input_keys_arg=["input"],
|
input_keys_arg=["input"],
|
||||||
return_keys_arg=["output"],
|
return_keys_arg=["output"],
|
||||||
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
elif agent_type == AgentType.OPENAI_FUNCTIONS:
|
elif agent_type == AgentType.OPENAI_FUNCTIONS:
|
||||||
@ -198,6 +191,7 @@ def create_sql_agent(
|
|||||||
runnable=create_openai_functions_agent(llm, tools, prompt),
|
runnable=create_openai_functions_agent(llm, tools, prompt),
|
||||||
input_keys_arg=["input"],
|
input_keys_arg=["input"],
|
||||||
return_keys_arg=["output"],
|
return_keys_arg=["output"],
|
||||||
|
**kwargs,
|
||||||
)
|
)
|
||||||
elif agent_type == "openai-tools":
|
elif agent_type == "openai-tools":
|
||||||
if prompt is None:
|
if prompt is None:
|
||||||
@ -212,6 +206,7 @@ def create_sql_agent(
|
|||||||
runnable=create_openai_tools_agent(llm, tools, prompt),
|
runnable=create_openai_tools_agent(llm, tools, prompt),
|
||||||
input_keys_arg=["input"],
|
input_keys_arg=["input"],
|
||||||
return_keys_arg=["output"],
|
return_keys_arg=["output"],
|
||||||
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
|
Loading…
Reference in New Issue
Block a user