mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-05 21:12:48 +00:00
core[patch], community[patch], langchain[patch], docs: Update SQL chains/agents/docs (#16168)
Revamp SQL use cases docs. In the process update SQL chains and agents.
This commit is contained in:
@@ -1,11 +1,11 @@
|
||||
"""SQL agent."""
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence
|
||||
import warnings
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Sequence, Union
|
||||
|
||||
from langchain_core.callbacks import BaseCallbackManager
|
||||
from langchain_core.language_models import BaseLanguageModel
|
||||
from langchain_core.messages import AIMessage, SystemMessage
|
||||
from langchain_core.prompts import BasePromptTemplate, PromptTemplate
|
||||
from langchain_core.prompts.chat import (
|
||||
ChatPromptTemplate,
|
||||
HumanMessagePromptTemplate,
|
||||
@@ -15,22 +15,29 @@ from langchain_core.prompts.chat import (
|
||||
from langchain_community.agent_toolkits.sql.prompt import (
|
||||
SQL_FUNCTIONS_SUFFIX,
|
||||
SQL_PREFIX,
|
||||
SQL_SUFFIX,
|
||||
)
|
||||
from langchain_community.agent_toolkits.sql.toolkit import SQLDatabaseToolkit
|
||||
from langchain_community.tools import BaseTool
|
||||
from langchain_community.tools.sql_database.tool import (
|
||||
InfoSQLDatabaseTool,
|
||||
ListSQLDatabaseTool,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from langchain.agents.agent import AgentExecutor
|
||||
from langchain.agents.agent_types import AgentType
|
||||
from langchain_core.callbacks import BaseCallbackManager
|
||||
from langchain_core.language_models import BaseLanguageModel
|
||||
from langchain_core.tools import BaseTool
|
||||
|
||||
from langchain_community.utilities.sql_database import SQLDatabase
|
||||
|
||||
|
||||
def create_sql_agent(
|
||||
llm: BaseLanguageModel,
|
||||
toolkit: SQLDatabaseToolkit,
|
||||
agent_type: Optional[AgentType] = None,
|
||||
toolkit: Optional[SQLDatabaseToolkit] = None,
|
||||
agent_type: Optional[Union[AgentType, Literal["openai-tools"]]] = None,
|
||||
callback_manager: Optional[BaseCallbackManager] = None,
|
||||
prefix: str = SQL_PREFIX,
|
||||
prefix: Optional[str] = None,
|
||||
suffix: Optional[str] = None,
|
||||
format_instructions: Optional[str] = None,
|
||||
input_variables: Optional[List[str]] = None,
|
||||
@@ -41,62 +48,165 @@ def create_sql_agent(
|
||||
verbose: bool = False,
|
||||
agent_executor_kwargs: Optional[Dict[str, Any]] = None,
|
||||
extra_tools: Sequence[BaseTool] = (),
|
||||
*,
|
||||
db: Optional[SQLDatabase] = None,
|
||||
prompt: Optional[BasePromptTemplate] = None,
|
||||
**kwargs: Any,
|
||||
) -> AgentExecutor:
|
||||
"""Construct an SQL agent from an LLM and tools."""
|
||||
from langchain.agents.agent import AgentExecutor, BaseSingleActionAgent
|
||||
from langchain.agents.agent_types import AgentType
|
||||
from langchain.agents.mrkl.base import ZeroShotAgent
|
||||
from langchain.agents.openai_functions_agent.base import OpenAIFunctionsAgent
|
||||
from langchain.chains.llm import LLMChain
|
||||
"""Construct a SQL agent from an LLM and toolkit or database.
|
||||
|
||||
Args:
|
||||
llm: Language model to use for the agent.
|
||||
toolkit: SQLDatabaseToolkit for the agent to use. Must provide exactly one of
|
||||
'toolkit' or 'db'. Specify 'toolkit' if you want to use a different model
|
||||
for the agent and the toolkit.
|
||||
agent_type: One of "openai-tools", "openai-functions", or
|
||||
"zero-shot-react-description". Defaults to "zero-shot-react-description".
|
||||
"openai-tools" is recommended over "openai-functions".
|
||||
callback_manager: DEPRECATED. Pass "callbacks" key into 'agent_executor_kwargs'
|
||||
instead to pass constructor callbacks to AgentExecutor.
|
||||
prefix: Prompt prefix string. Must contain variables "top_k" and "dialect".
|
||||
suffix: Prompt suffix string. Default depends on agent type.
|
||||
format_instructions: Formatting instructions to pass to
|
||||
ZeroShotAgent.create_prompt() when 'agent_type' is
|
||||
"zero-shot-react-description". Otherwise ignored.
|
||||
input_variables: DEPRECATED. Input variables to explicitly specify as part of
|
||||
ZeroShotAgent.create_prompt() when 'agent_type' is
|
||||
"zero-shot-react-description". Otherwise ignored.
|
||||
top_k: Number of rows to query for by default.
|
||||
max_iterations: Passed to AgentExecutor init.
|
||||
max_execution_time: Passed to AgentExecutor init.
|
||||
early_stopping_method: Passed to AgentExecutor init.
|
||||
verbose: AgentExecutor verbosity.
|
||||
agent_executor_kwargs: Arbitrary additional AgentExecutor args.
|
||||
extra_tools: Additional tools to give to agent on top of the ones that come with
|
||||
SQLDatabaseToolkit.
|
||||
db: SQLDatabase from which to create a SQLDatabaseToolkit. Toolkit is created
|
||||
using 'db' and 'llm'. Must provide exactly one of 'db' or 'toolkit'.
|
||||
prompt: Complete agent prompt. prompt and {prefix, suffix, format_instructions,
|
||||
input_variables} are mutually exclusive.
|
||||
**kwargs: DEPRECATED. Not used, kept for backwards compatibility.
|
||||
|
||||
Returns:
|
||||
An AgentExecutor with the specified agent_type agent.
|
||||
|
||||
Example:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
from langchain_openai import ChatOpenAI
|
||||
from langchain_community.agent_toolkits import create_sql_agent
|
||||
from langchain_community.utilities import SQLDatabase
|
||||
|
||||
db = SQLDatabase.from_uri("sqlite:///Chinook.db")
|
||||
llm = ChatOpenAI(model="gpt-3.5-turbo", temperature=0)
|
||||
agent_executor = create_sql_agent(llm, db=db, agent_type="openai-tools", verbose=True)
|
||||
|
||||
""" # noqa: E501
|
||||
from langchain.agents import (
|
||||
create_openai_functions_agent,
|
||||
create_openai_tools_agent,
|
||||
create_react_agent,
|
||||
)
|
||||
from langchain.agents.agent import (
|
||||
AgentExecutor,
|
||||
RunnableAgent,
|
||||
RunnableMultiActionAgent,
|
||||
)
|
||||
from langchain.agents.agent_types import AgentType
|
||||
|
||||
if toolkit is None and db is None:
|
||||
raise ValueError(
|
||||
"Must provide exactly one of 'toolkit' or 'db'. Received neither."
|
||||
)
|
||||
if toolkit and db:
|
||||
raise ValueError(
|
||||
"Must provide exactly one of 'toolkit' or 'db'. Received both."
|
||||
)
|
||||
if kwargs:
|
||||
warnings.warn(
|
||||
f"Received additional kwargs {kwargs} which are no longer supported."
|
||||
)
|
||||
|
||||
toolkit = toolkit or SQLDatabaseToolkit(llm=llm, db=db)
|
||||
agent_type = agent_type or AgentType.ZERO_SHOT_REACT_DESCRIPTION
|
||||
tools = toolkit.get_tools() + list(extra_tools)
|
||||
prefix = prefix.format(dialect=toolkit.dialect, top_k=top_k)
|
||||
agent: BaseSingleActionAgent
|
||||
if prompt is None:
|
||||
prefix = prefix or SQL_PREFIX
|
||||
prefix = prefix.format(dialect=toolkit.dialect, top_k=top_k)
|
||||
else:
|
||||
if "top_k" in prompt.input_variables:
|
||||
prompt = prompt.partial(top_k=str(top_k))
|
||||
if "dialect" in prompt.input_variables:
|
||||
prompt = prompt.partial(dialect=toolkit.dialect)
|
||||
db_context = toolkit.get_context()
|
||||
if "table_info" in prompt.input_variables:
|
||||
prompt = prompt.partial(table_info=db_context["table_info"])
|
||||
tools = [
|
||||
tool for tool in tools if not isinstance(tool, InfoSQLDatabaseTool)
|
||||
]
|
||||
if "table_names" in prompt.input_variables:
|
||||
prompt = prompt.partial(table_names=db_context["table_names"])
|
||||
tools = [
|
||||
tool for tool in tools if not isinstance(tool, ListSQLDatabaseTool)
|
||||
]
|
||||
|
||||
if agent_type == AgentType.ZERO_SHOT_REACT_DESCRIPTION:
|
||||
prompt_params = (
|
||||
{"format_instructions": format_instructions}
|
||||
if format_instructions is not None
|
||||
else {}
|
||||
if prompt is None:
|
||||
from langchain.agents.mrkl import prompt as react_prompt
|
||||
|
||||
format_instructions = (
|
||||
format_instructions or react_prompt.FORMAT_INSTRUCTIONS
|
||||
)
|
||||
template = "\n\n".join(
|
||||
[
|
||||
react_prompt.PREFIX,
|
||||
"{tools}",
|
||||
format_instructions,
|
||||
react_prompt.SUFFIX,
|
||||
]
|
||||
)
|
||||
prompt = PromptTemplate.from_template(template)
|
||||
agent = RunnableAgent(
|
||||
runnable=create_react_agent(llm, tools, prompt),
|
||||
input_keys_arg=["input"],
|
||||
return_keys_arg=["output"],
|
||||
)
|
||||
prompt = ZeroShotAgent.create_prompt(
|
||||
tools,
|
||||
prefix=prefix,
|
||||
suffix=suffix or SQL_SUFFIX,
|
||||
input_variables=input_variables,
|
||||
**prompt_params,
|
||||
)
|
||||
llm_chain = LLMChain(
|
||||
llm=llm,
|
||||
prompt=prompt,
|
||||
callback_manager=callback_manager,
|
||||
)
|
||||
tool_names = [tool.name for tool in tools]
|
||||
agent = ZeroShotAgent(llm_chain=llm_chain, allowed_tools=tool_names, **kwargs)
|
||||
|
||||
elif agent_type == AgentType.OPENAI_FUNCTIONS:
|
||||
messages = [
|
||||
SystemMessage(content=prefix),
|
||||
HumanMessagePromptTemplate.from_template("{input}"),
|
||||
AIMessage(content=suffix or SQL_FUNCTIONS_SUFFIX),
|
||||
MessagesPlaceholder(variable_name="agent_scratchpad"),
|
||||
]
|
||||
input_variables = ["input", "agent_scratchpad"]
|
||||
_prompt = ChatPromptTemplate(input_variables=input_variables, messages=messages)
|
||||
|
||||
agent = OpenAIFunctionsAgent(
|
||||
llm=llm,
|
||||
prompt=_prompt,
|
||||
tools=tools,
|
||||
callback_manager=callback_manager,
|
||||
**kwargs,
|
||||
if prompt is None:
|
||||
messages = [
|
||||
SystemMessage(content=prefix),
|
||||
HumanMessagePromptTemplate.from_template("{input}"),
|
||||
AIMessage(content=suffix or SQL_FUNCTIONS_SUFFIX),
|
||||
MessagesPlaceholder(variable_name="agent_scratchpad"),
|
||||
]
|
||||
prompt = ChatPromptTemplate.from_messages(messages)
|
||||
agent = RunnableAgent(
|
||||
runnable=create_openai_functions_agent(llm, tools, prompt),
|
||||
input_keys_arg=["input"],
|
||||
return_keys_arg=["output"],
|
||||
)
|
||||
elif agent_type == "openai-tools":
|
||||
if prompt is None:
|
||||
messages = [
|
||||
SystemMessage(content=prefix),
|
||||
HumanMessagePromptTemplate.from_template("{input}"),
|
||||
AIMessage(content=suffix or SQL_FUNCTIONS_SUFFIX),
|
||||
MessagesPlaceholder(variable_name="agent_scratchpad"),
|
||||
]
|
||||
prompt = ChatPromptTemplate.from_messages(messages)
|
||||
agent = RunnableMultiActionAgent(
|
||||
runnable=create_openai_tools_agent(llm, tools, prompt),
|
||||
input_keys_arg=["input"],
|
||||
return_keys_arg=["output"],
|
||||
)
|
||||
|
||||
else:
|
||||
raise ValueError(f"Agent type {agent_type} not supported at the moment.")
|
||||
|
||||
return AgentExecutor.from_agent_and_tools(
|
||||
return AgentExecutor(
|
||||
name="SQL Agent Executor",
|
||||
agent=agent,
|
||||
tools=tools,
|
||||
callback_manager=callback_manager,
|
||||
|
Reference in New Issue
Block a user