core[patch], community[patch], langchain[patch], docs: Update SQL chains/agents/docs (#16168)

Revamp SQL use cases docs. In the process update SQL chains and agents.
This commit is contained in:
Bagatur
2024-01-22 08:19:08 -08:00
committed by GitHub
parent 05162928c0
commit 1dc6c1ce06
22 changed files with 3618 additions and 1361 deletions

View File

@@ -1,11 +1,11 @@
"""SQL agent."""
from __future__ import annotations
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence
import warnings
from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Sequence, Union
from langchain_core.callbacks import BaseCallbackManager
from langchain_core.language_models import BaseLanguageModel
from langchain_core.messages import AIMessage, SystemMessage
from langchain_core.prompts import BasePromptTemplate, PromptTemplate
from langchain_core.prompts.chat import (
ChatPromptTemplate,
HumanMessagePromptTemplate,
@@ -15,22 +15,29 @@ from langchain_core.prompts.chat import (
from langchain_community.agent_toolkits.sql.prompt import (
SQL_FUNCTIONS_SUFFIX,
SQL_PREFIX,
SQL_SUFFIX,
)
from langchain_community.agent_toolkits.sql.toolkit import SQLDatabaseToolkit
from langchain_community.tools import BaseTool
from langchain_community.tools.sql_database.tool import (
InfoSQLDatabaseTool,
ListSQLDatabaseTool,
)
if TYPE_CHECKING:
from langchain.agents.agent import AgentExecutor
from langchain.agents.agent_types import AgentType
from langchain_core.callbacks import BaseCallbackManager
from langchain_core.language_models import BaseLanguageModel
from langchain_core.tools import BaseTool
from langchain_community.utilities.sql_database import SQLDatabase
def create_sql_agent(
llm: BaseLanguageModel,
toolkit: SQLDatabaseToolkit,
agent_type: Optional[AgentType] = None,
toolkit: Optional[SQLDatabaseToolkit] = None,
agent_type: Optional[Union[AgentType, Literal["openai-tools"]]] = None,
callback_manager: Optional[BaseCallbackManager] = None,
prefix: str = SQL_PREFIX,
prefix: Optional[str] = None,
suffix: Optional[str] = None,
format_instructions: Optional[str] = None,
input_variables: Optional[List[str]] = None,
@@ -41,62 +48,165 @@ def create_sql_agent(
verbose: bool = False,
agent_executor_kwargs: Optional[Dict[str, Any]] = None,
extra_tools: Sequence[BaseTool] = (),
*,
db: Optional[SQLDatabase] = None,
prompt: Optional[BasePromptTemplate] = None,
**kwargs: Any,
) -> AgentExecutor:
"""Construct an SQL agent from an LLM and tools."""
from langchain.agents.agent import AgentExecutor, BaseSingleActionAgent
from langchain.agents.agent_types import AgentType
from langchain.agents.mrkl.base import ZeroShotAgent
from langchain.agents.openai_functions_agent.base import OpenAIFunctionsAgent
from langchain.chains.llm import LLMChain
"""Construct a SQL agent from an LLM and toolkit or database.
Args:
llm: Language model to use for the agent.
toolkit: SQLDatabaseToolkit for the agent to use. Must provide exactly one of
'toolkit' or 'db'. Specify 'toolkit' if you want to use a different model
for the agent and the toolkit.
agent_type: One of "openai-tools", "openai-functions", or
"zero-shot-react-description". Defaults to "zero-shot-react-description".
"openai-tools" is recommended over "openai-functions".
callback_manager: DEPRECATED. Pass "callbacks" key into 'agent_executor_kwargs'
instead to pass constructor callbacks to AgentExecutor.
prefix: Prompt prefix string. Must contain variables "top_k" and "dialect".
suffix: Prompt suffix string. Default depends on agent type.
format_instructions: Formatting instructions to pass to
ZeroShotAgent.create_prompt() when 'agent_type' is
"zero-shot-react-description". Otherwise ignored.
input_variables: DEPRECATED. Input variables to explicitly specify as part of
ZeroShotAgent.create_prompt() when 'agent_type' is
"zero-shot-react-description". Otherwise ignored.
top_k: Number of rows to query for by default.
max_iterations: Passed to AgentExecutor init.
max_execution_time: Passed to AgentExecutor init.
early_stopping_method: Passed to AgentExecutor init.
verbose: AgentExecutor verbosity.
agent_executor_kwargs: Arbitrary additional AgentExecutor args.
extra_tools: Additional tools to give to agent on top of the ones that come with
SQLDatabaseToolkit.
db: SQLDatabase from which to create a SQLDatabaseToolkit. Toolkit is created
using 'db' and 'llm'. Must provide exactly one of 'db' or 'toolkit'.
prompt: Complete agent prompt. prompt and {prefix, suffix, format_instructions,
input_variables} are mutually exclusive.
**kwargs: DEPRECATED. Not used, kept for backwards compatibility.
Returns:
An AgentExecutor with the specified agent_type agent.
Example:
.. code-block:: python
from langchain_openai import ChatOpenAI
from langchain_community.agent_toolkits import create_sql_agent
from langchain_community.utilities import SQLDatabase
db = SQLDatabase.from_uri("sqlite:///Chinook.db")
llm = ChatOpenAI(model="gpt-3.5-turbo", temperature=0)
agent_executor = create_sql_agent(llm, db=db, agent_type="openai-tools", verbose=True)
""" # noqa: E501
from langchain.agents import (
create_openai_functions_agent,
create_openai_tools_agent,
create_react_agent,
)
from langchain.agents.agent import (
AgentExecutor,
RunnableAgent,
RunnableMultiActionAgent,
)
from langchain.agents.agent_types import AgentType
if toolkit is None and db is None:
raise ValueError(
"Must provide exactly one of 'toolkit' or 'db'. Received neither."
)
if toolkit and db:
raise ValueError(
"Must provide exactly one of 'toolkit' or 'db'. Received both."
)
if kwargs:
warnings.warn(
f"Received additional kwargs {kwargs} which are no longer supported."
)
toolkit = toolkit or SQLDatabaseToolkit(llm=llm, db=db)
agent_type = agent_type or AgentType.ZERO_SHOT_REACT_DESCRIPTION
tools = toolkit.get_tools() + list(extra_tools)
prefix = prefix.format(dialect=toolkit.dialect, top_k=top_k)
agent: BaseSingleActionAgent
if prompt is None:
prefix = prefix or SQL_PREFIX
prefix = prefix.format(dialect=toolkit.dialect, top_k=top_k)
else:
if "top_k" in prompt.input_variables:
prompt = prompt.partial(top_k=str(top_k))
if "dialect" in prompt.input_variables:
prompt = prompt.partial(dialect=toolkit.dialect)
db_context = toolkit.get_context()
if "table_info" in prompt.input_variables:
prompt = prompt.partial(table_info=db_context["table_info"])
tools = [
tool for tool in tools if not isinstance(tool, InfoSQLDatabaseTool)
]
if "table_names" in prompt.input_variables:
prompt = prompt.partial(table_names=db_context["table_names"])
tools = [
tool for tool in tools if not isinstance(tool, ListSQLDatabaseTool)
]
if agent_type == AgentType.ZERO_SHOT_REACT_DESCRIPTION:
prompt_params = (
{"format_instructions": format_instructions}
if format_instructions is not None
else {}
if prompt is None:
from langchain.agents.mrkl import prompt as react_prompt
format_instructions = (
format_instructions or react_prompt.FORMAT_INSTRUCTIONS
)
template = "\n\n".join(
[
react_prompt.PREFIX,
"{tools}",
format_instructions,
react_prompt.SUFFIX,
]
)
prompt = PromptTemplate.from_template(template)
agent = RunnableAgent(
runnable=create_react_agent(llm, tools, prompt),
input_keys_arg=["input"],
return_keys_arg=["output"],
)
prompt = ZeroShotAgent.create_prompt(
tools,
prefix=prefix,
suffix=suffix or SQL_SUFFIX,
input_variables=input_variables,
**prompt_params,
)
llm_chain = LLMChain(
llm=llm,
prompt=prompt,
callback_manager=callback_manager,
)
tool_names = [tool.name for tool in tools]
agent = ZeroShotAgent(llm_chain=llm_chain, allowed_tools=tool_names, **kwargs)
elif agent_type == AgentType.OPENAI_FUNCTIONS:
messages = [
SystemMessage(content=prefix),
HumanMessagePromptTemplate.from_template("{input}"),
AIMessage(content=suffix or SQL_FUNCTIONS_SUFFIX),
MessagesPlaceholder(variable_name="agent_scratchpad"),
]
input_variables = ["input", "agent_scratchpad"]
_prompt = ChatPromptTemplate(input_variables=input_variables, messages=messages)
agent = OpenAIFunctionsAgent(
llm=llm,
prompt=_prompt,
tools=tools,
callback_manager=callback_manager,
**kwargs,
if prompt is None:
messages = [
SystemMessage(content=prefix),
HumanMessagePromptTemplate.from_template("{input}"),
AIMessage(content=suffix or SQL_FUNCTIONS_SUFFIX),
MessagesPlaceholder(variable_name="agent_scratchpad"),
]
prompt = ChatPromptTemplate.from_messages(messages)
agent = RunnableAgent(
runnable=create_openai_functions_agent(llm, tools, prompt),
input_keys_arg=["input"],
return_keys_arg=["output"],
)
elif agent_type == "openai-tools":
if prompt is None:
messages = [
SystemMessage(content=prefix),
HumanMessagePromptTemplate.from_template("{input}"),
AIMessage(content=suffix or SQL_FUNCTIONS_SUFFIX),
MessagesPlaceholder(variable_name="agent_scratchpad"),
]
prompt = ChatPromptTemplate.from_messages(messages)
agent = RunnableMultiActionAgent(
runnable=create_openai_tools_agent(llm, tools, prompt),
input_keys_arg=["input"],
return_keys_arg=["output"],
)
else:
raise ValueError(f"Agent type {agent_type} not supported at the moment.")
return AgentExecutor.from_agent_and_tools(
return AgentExecutor(
name="SQL Agent Executor",
agent=agent,
tools=tools,
callback_manager=callback_manager,

View File

@@ -69,3 +69,7 @@ class SQLDatabaseToolkit(BaseToolkit):
list_sql_database_tool,
query_sql_checker_tool,
]
def get_context(self) -> dict:
"""Return db context that you may want in agent prompt."""
return self.db.get_context()

View File

@@ -1,10 +1,10 @@
"""SQLAlchemy wrapper around a database."""
from __future__ import annotations
import warnings
from typing import Any, Dict, Iterable, List, Literal, Optional, Sequence
import sqlalchemy
from langchain_core._api import deprecated
from langchain_core.utils import get_from_env
from sqlalchemy import MetaData, Table, create_engine, inspect, select, text
from sqlalchemy.engine import Engine
@@ -272,11 +272,9 @@ class SQLDatabase:
return sorted(self._include_tables)
return sorted(self._all_tables - self._ignore_tables)
@deprecated("0.0.1", alternative="get_usable_table_name", removal="0.2.0")
def get_table_names(self) -> Iterable[str]:
"""Get names of tables available."""
warnings.warn(
"This method is deprecated - please use `get_usable_table_names`."
)
return self.get_usable_table_names()
@property
@@ -487,3 +485,9 @@ class SQLDatabase:
except SQLAlchemyError as e:
"""Format the error message"""
return f"Error: {e}"
def get_context(self) -> Dict[str, Any]:
"""Return db context that you may want in agent prompt."""
table_names = list(self.get_usable_table_names())
table_info = self.get_table_info_no_throw()
return {"table_info": table_info, "table_names": ", ".join(table_names)}

View File

@@ -74,6 +74,9 @@ class SemanticSimilarityExampleSelector(BaseExampleSelector, BaseModel):
vectorstore_cls: Type[VectorStore],
k: int = 4,
input_keys: Optional[List[str]] = None,
*,
example_keys: Optional[List[str]] = None,
vectorstore_kwargs: Optional[dict] = None,
**vectorstore_cls_kwargs: Any,
) -> SemanticSimilarityExampleSelector:
"""Create k-shot example selector using example list and embeddings.
@@ -102,7 +105,13 @@ class SemanticSimilarityExampleSelector(BaseExampleSelector, BaseModel):
vectorstore = vectorstore_cls.from_texts(
string_examples, embeddings, metadatas=examples, **vectorstore_cls_kwargs
)
return cls(vectorstore=vectorstore, k=k, input_keys=input_keys)
return cls(
vectorstore=vectorstore,
k=k,
input_keys=input_keys,
example_keys=example_keys,
vectorstore_kwargs=vectorstore_kwargs,
)
class MaxMarginalRelevanceExampleSelector(SemanticSimilarityExampleSelector):

View File

@@ -343,8 +343,8 @@ class RunnableAgent(BaseSingleActionAgent):
runnable: Runnable[dict, Union[AgentAction, AgentFinish]]
"""Runnable to call to get agent action."""
_input_keys: List[str] = []
"""Input keys."""
input_keys_arg: List[str] = []
return_keys_arg: List[str] = []
class Config:
"""Configuration for this pydantic object."""
@@ -354,16 +354,11 @@ class RunnableAgent(BaseSingleActionAgent):
@property
def return_values(self) -> List[str]:
"""Return values of the agent."""
return []
return self.return_keys_arg
@property
def input_keys(self) -> List[str]:
"""Return the input keys.
Returns:
List of input keys.
"""
return self._input_keys
return self.input_keys_arg
def plan(
self,
@@ -439,8 +434,8 @@ class RunnableMultiActionAgent(BaseMultiActionAgent):
runnable: Runnable[dict, Union[List[AgentAction], AgentFinish]]
"""Runnable to call to get agent actions."""
_input_keys: List[str] = []
"""Input keys."""
input_keys_arg: List[str] = []
return_keys_arg: List[str] = []
class Config:
"""Configuration for this pydantic object."""
@@ -450,7 +445,7 @@ class RunnableMultiActionAgent(BaseMultiActionAgent):
@property
def return_values(self) -> List[str]:
"""Return values of the agent."""
return []
return self.return_keys_arg
@property
def input_keys(self) -> List[str]:
@@ -459,7 +454,7 @@ class RunnableMultiActionAgent(BaseMultiActionAgent):
Returns:
List of input keys.
"""
return self._input_keys
return self.input_keys_arg
def plan(
self,

View File

@@ -83,9 +83,9 @@ class ZeroShotAgent(Agent):
tool_names = ", ".join([tool.name for tool in tools])
format_instructions = format_instructions.format(tool_names=tool_names)
template = "\n\n".join([prefix, tool_strings, format_instructions, suffix])
if input_variables is None:
input_variables = ["input", "agent_scratchpad"]
return PromptTemplate(template=template, input_variables=input_variables)
if input_variables:
return PromptTemplate(template=template, input_variables=input_variables)
return PromptTemplate.from_template(template)
@classmethod
def from_llm_and_tools(

View File

@@ -131,7 +131,7 @@ class Chain(RunnableSerializable[Dict[str, Any], Dict[str, Any]], ABC):
callbacks = config.get("callbacks")
tags = config.get("tags")
metadata = config.get("metadata")
run_name = config.get("run_name")
run_name = config.get("run_name") or self.get_name()
include_run_info = kwargs.get("include_run_info", False)
return_only_outputs = kwargs.get("return_only_outputs", False)
@@ -178,7 +178,7 @@ class Chain(RunnableSerializable[Dict[str, Any], Dict[str, Any]], ABC):
callbacks = config.get("callbacks")
tags = config.get("tags")
metadata = config.get("metadata")
run_name = config.get("run_name")
run_name = config.get("run_name") or self.get_name()
include_run_info = kwargs.get("include_run_info", False)
return_only_outputs = kwargs.get("return_only_outputs", False)

View File

@@ -1,10 +1,10 @@
from typing import List, Optional, TypedDict, Union
from typing import Any, Dict, List, Optional, TypedDict, Union
from langchain_community.utilities.sql_database import SQLDatabase
from langchain_core.language_models import BaseLanguageModel
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import BasePromptTemplate
from langchain_core.runnables import Runnable, RunnableParallel
from langchain_core.runnables import Runnable, RunnablePassthrough
from langchain.chains.sql_database.prompt import PROMPT, SQL_PROMPTS
@@ -31,7 +31,7 @@ def create_sql_query_chain(
db: SQLDatabase,
prompt: Optional[BasePromptTemplate] = None,
k: int = 5,
) -> Runnable[Union[SQLInput, SQLInputWithTables], str]:
) -> Runnable[Union[SQLInput, SQLInputWithTables, Dict[str, Any]], str]:
"""Create a chain that generates SQL queries.
*Security Note*: This chain generates SQL queries for the given database.
@@ -50,34 +50,93 @@ def create_sql_query_chain(
See https://python.langchain.com/docs/security for more information.
Args:
llm: The language model to use
db: The SQLDatabase to generate the query for
llm: The language model to use.
db: The SQLDatabase to generate the query for.
prompt: The prompt to use. If none is provided, will choose one
based on dialect. Defaults to None.
based on dialect. Defaults to None. See Prompt section below for more.
k: The number of results per select statement to return. Defaults to 5.
Returns:
A chain that takes in a question and generates a SQL query that answers
that question.
"""
Example:
.. code-block:: python
# pip install -U langchain langchain-community langchain-openai
from langchain_openai import ChatOpenAI
from langchain.chains import create_sql_query_chain
from langchain_community.utilities import SQLDatabase
db = SQLDatabase.from_uri("sqlite:///Chinook.db")
llm = ChatOpenAI(model="gpt-3.5-turbo", temperature=0)
chain = create_sql_query_chain(llm, db)
response = chain.invoke({"question": "How many employees are there"})
Prompt:
If no prompt is provided, a default prompt is selected based on the SQLDatabase dialect. If one is provided, it must support input variables:
* input: The user question plus suffix "\nSQLQuery: " is passed here.
* top_k: The number of results per select statement (the `k` argument to
this function) is passed in here.
* table_info: Table definitions and sample rows are passed in here. If the
user specifies "table_names_to_use" when invoking chain, only those
will be included. Otherwise, all tables are included.
* dialect (optional): If dialect input variable is in prompt, the db
dialect will be passed in here.
Here's an example prompt:
.. code-block:: python
from langchain_core.prompts import PromptTemplate
template = '''Given an input question, first create a syntactically correct {dialect} query to run, then look at the results of the query and return the answer.
Use the following format:
Question: "Question here"
SQLQuery: "SQL Query to run"
SQLResult: "Result of the SQLQuery"
Answer: "Final answer here"
Only use the following tables:
{table_info}.
Question: {input}'''
prompt = PromptTemplate.from_template(template)
""" # noqa: E501
if prompt is not None:
prompt_to_use = prompt
elif db.dialect in SQL_PROMPTS:
prompt_to_use = SQL_PROMPTS[db.dialect]
else:
prompt_to_use = PROMPT
if {"input", "top_k", "table_info"}.difference(prompt_to_use.input_variables):
raise ValueError(
f"Prompt must have input variables: 'input', 'top_k', "
f"'table_info'. Received prompt with input variables: "
f"{prompt_to_use.input_variables}. Full prompt:\n\n{prompt_to_use}"
)
if "dialect" in prompt_to_use.input_variables:
prompt_to_use = prompt_to_use.partial(dialect=db.dialect)
inputs = {
"input": lambda x: x["question"] + "\nSQLQuery: ",
"top_k": lambda _: k,
"table_info": lambda x: db.get_table_info(
table_names=x.get("table_names_to_use")
),
}
if "dialect" in prompt_to_use.input_variables:
inputs["dialect"] = lambda _: (db.dialect, prompt_to_use)
return (
RunnableParallel(inputs)
| prompt_to_use
RunnablePassthrough.assign(**inputs) # type: ignore
| (
lambda x: {
k: v
for k, v in x.items()
if k not in ("question", "table_names_to_use")
}
)
| prompt_to_use.partial(top_k=str(k))
| llm.bind(stop=["\nSQLResult:"])
| StrOutputParser()
| _strip

View File

@@ -1,3 +1,7 @@
from functools import partial
from typing import Optional
from langchain_core.prompts import BasePromptTemplate, PromptTemplate, format_document
from langchain_core.pydantic_v1 import BaseModel, Field
from langchain_core.retrievers import BaseRetriever
@@ -10,8 +14,37 @@ class RetrieverInput(BaseModel):
query: str = Field(description="query to look up in retriever")
def _get_relevant_documents(
query: str,
retriever: BaseRetriever,
document_prompt: BasePromptTemplate,
document_separator: str,
) -> str:
docs = retriever.get_relevant_documents(query)
return document_separator.join(
format_document(doc, document_prompt) for doc in docs
)
async def _aget_relevant_documents(
query: str,
retriever: BaseRetriever,
document_prompt: BasePromptTemplate,
document_separator: str,
) -> str:
docs = await retriever.aget_relevant_documents(query)
return document_separator.join(
format_document(doc, document_prompt) for doc in docs
)
def create_retriever_tool(
retriever: BaseRetriever, name: str, description: str
retriever: BaseRetriever,
name: str,
description: str,
*,
document_prompt: Optional[BasePromptTemplate] = None,
document_separator: str = "\n\n",
) -> Tool:
"""Create a tool to do retrieval of documents.
@@ -25,10 +58,23 @@ def create_retriever_tool(
Returns:
Tool class to pass to an agent
"""
document_prompt = document_prompt or PromptTemplate.from_template("{page_content}")
func = partial(
_get_relevant_documents,
retriever=retriever,
document_prompt=document_prompt,
document_separator=document_separator,
)
afunc = partial(
_aget_relevant_documents,
retriever=retriever,
document_prompt=document_prompt,
document_separator=document_separator,
)
return Tool(
name=name,
description=description,
func=retriever.get_relevant_documents,
coroutine=retriever.aget_relevant_documents,
func=func,
coroutine=afunc,
args_schema=RetrieverInput,
)