community[patch], experimental[patch]: support tool-calling sql and p… (#20639)

d agents
This commit is contained in:
Bagatur
2024-04-21 15:43:09 -07:00
committed by GitHub
parent d0cee65cdc
commit 1c7b3c75a7
2 changed files with 49 additions and 20 deletions

View File

@@ -44,7 +44,9 @@ if TYPE_CHECKING:
def create_sql_agent(
llm: BaseLanguageModel,
toolkit: Optional[SQLDatabaseToolkit] = None,
agent_type: Optional[Union[AgentType, Literal["openai-tools"]]] = None,
agent_type: Optional[
Union[AgentType, Literal["openai-tools", "tool-calling"]]
] = None,
callback_manager: Optional[BaseCallbackManager] = None,
prefix: Optional[str] = None,
suffix: Optional[str] = None,
@@ -65,13 +67,15 @@ def create_sql_agent(
"""Construct a SQL agent from an LLM and toolkit or database.
Args:
llm: Language model to use for the agent.
llm: Language model to use for the agent. If agent_type is "tool-calling" then
llm is expected to support tool calling.
toolkit: SQLDatabaseToolkit for the agent to use. Must provide exactly one of
'toolkit' or 'db'. Specify 'toolkit' if you want to use a different model
for the agent and the toolkit.
agent_type: One of "openai-tools", "openai-functions", or
agent_type: One of "tool-calling", "openai-tools", "openai-functions", or
"zero-shot-react-description". Defaults to "zero-shot-react-description".
"openai-tools" is recommended over "openai-functions".
"tool-calling" is recommended over the legacy "openai-tools" and
"openai-functions" types.
callback_manager: DEPRECATED. Pass "callbacks" key into 'agent_executor_kwargs'
instead to pass constructor callbacks to AgentExecutor.
prefix: Prompt prefix string. Must contain variables "top_k" and "dialect".
@@ -107,13 +111,14 @@ def create_sql_agent(
db = SQLDatabase.from_uri("sqlite:///Chinook.db")
llm = ChatOpenAI(model="gpt-3.5-turbo", temperature=0)
agent_executor = create_sql_agent(llm, db=db, agent_type="openai-tools", verbose=True)
agent_executor = create_sql_agent(llm, db=db, agent_type="tool-calling", verbose=True)
""" # noqa: E501
from langchain.agents import (
create_openai_functions_agent,
create_openai_tools_agent,
create_react_agent,
create_tool_calling_agent,
)
from langchain.agents.agent import (
AgentExecutor,
@@ -193,7 +198,7 @@ def create_sql_agent(
return_keys_arg=["output"],
**kwargs,
)
elif agent_type == "openai-tools":
elif agent_type in ("openai-tools", "tool-calling"):
if prompt is None:
messages = [
SystemMessage(content=cast(str, prefix)),
@@ -202,8 +207,12 @@ def create_sql_agent(
MessagesPlaceholder(variable_name="agent_scratchpad"),
]
prompt = ChatPromptTemplate.from_messages(messages)
if agent_type == "openai-tools":
runnable = create_openai_tools_agent(llm, tools, prompt)
else:
runnable = create_tool_calling_agent(llm, tools, prompt)
agent = RunnableMultiActionAgent(
runnable=create_openai_tools_agent(llm, tools, prompt),
runnable=runnable,
input_keys_arg=["input"],
return_keys_arg=["output"],
**kwargs,
@@ -212,7 +221,8 @@ def create_sql_agent(
else:
raise ValueError(
f"Agent type {agent_type} not supported at the moment. Must be one of "
"'openai-tools', 'openai-functions', or 'zero-shot-react-description'."
"'tool-calling', 'openai-tools', 'openai-functions', or "
"'zero-shot-react-description'."
)
return AgentExecutor(