mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-19 00:58:32 +00:00
core[patch], community[patch], langchain[patch], docs: Update SQL chains/agents/docs (#16168)
Revamp SQL use cases docs. In the process update SQL chains and agents.
This commit is contained in:
@@ -1,11 +1,11 @@
|
||||
"""SQL agent."""
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence
|
||||
import warnings
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Sequence, Union
|
||||
|
||||
from langchain_core.callbacks import BaseCallbackManager
|
||||
from langchain_core.language_models import BaseLanguageModel
|
||||
from langchain_core.messages import AIMessage, SystemMessage
|
||||
from langchain_core.prompts import BasePromptTemplate, PromptTemplate
|
||||
from langchain_core.prompts.chat import (
|
||||
ChatPromptTemplate,
|
||||
HumanMessagePromptTemplate,
|
||||
@@ -15,22 +15,29 @@ from langchain_core.prompts.chat import (
|
||||
from langchain_community.agent_toolkits.sql.prompt import (
|
||||
SQL_FUNCTIONS_SUFFIX,
|
||||
SQL_PREFIX,
|
||||
SQL_SUFFIX,
|
||||
)
|
||||
from langchain_community.agent_toolkits.sql.toolkit import SQLDatabaseToolkit
|
||||
from langchain_community.tools import BaseTool
|
||||
from langchain_community.tools.sql_database.tool import (
|
||||
InfoSQLDatabaseTool,
|
||||
ListSQLDatabaseTool,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from langchain.agents.agent import AgentExecutor
|
||||
from langchain.agents.agent_types import AgentType
|
||||
from langchain_core.callbacks import BaseCallbackManager
|
||||
from langchain_core.language_models import BaseLanguageModel
|
||||
from langchain_core.tools import BaseTool
|
||||
|
||||
from langchain_community.utilities.sql_database import SQLDatabase
|
||||
|
||||
|
||||
def create_sql_agent(
|
||||
llm: BaseLanguageModel,
|
||||
toolkit: SQLDatabaseToolkit,
|
||||
agent_type: Optional[AgentType] = None,
|
||||
toolkit: Optional[SQLDatabaseToolkit] = None,
|
||||
agent_type: Optional[Union[AgentType, Literal["openai-tools"]]] = None,
|
||||
callback_manager: Optional[BaseCallbackManager] = None,
|
||||
prefix: str = SQL_PREFIX,
|
||||
prefix: Optional[str] = None,
|
||||
suffix: Optional[str] = None,
|
||||
format_instructions: Optional[str] = None,
|
||||
input_variables: Optional[List[str]] = None,
|
||||
@@ -41,62 +48,165 @@ def create_sql_agent(
|
||||
verbose: bool = False,
|
||||
agent_executor_kwargs: Optional[Dict[str, Any]] = None,
|
||||
extra_tools: Sequence[BaseTool] = (),
|
||||
*,
|
||||
db: Optional[SQLDatabase] = None,
|
||||
prompt: Optional[BasePromptTemplate] = None,
|
||||
**kwargs: Any,
|
||||
) -> AgentExecutor:
|
||||
"""Construct an SQL agent from an LLM and tools."""
|
||||
from langchain.agents.agent import AgentExecutor, BaseSingleActionAgent
|
||||
from langchain.agents.agent_types import AgentType
|
||||
from langchain.agents.mrkl.base import ZeroShotAgent
|
||||
from langchain.agents.openai_functions_agent.base import OpenAIFunctionsAgent
|
||||
from langchain.chains.llm import LLMChain
|
||||
"""Construct a SQL agent from an LLM and toolkit or database.
|
||||
|
||||
Args:
|
||||
llm: Language model to use for the agent.
|
||||
toolkit: SQLDatabaseToolkit for the agent to use. Must provide exactly one of
|
||||
'toolkit' or 'db'. Specify 'toolkit' if you want to use a different model
|
||||
for the agent and the toolkit.
|
||||
agent_type: One of "openai-tools", "openai-functions", or
|
||||
"zero-shot-react-description". Defaults to "zero-shot-react-description".
|
||||
"openai-tools" is recommended over "openai-functions".
|
||||
callback_manager: DEPRECATED. Pass "callbacks" key into 'agent_executor_kwargs'
|
||||
instead to pass constructor callbacks to AgentExecutor.
|
||||
prefix: Prompt prefix string. Must contain variables "top_k" and "dialect".
|
||||
suffix: Prompt suffix string. Default depends on agent type.
|
||||
format_instructions: Formatting instructions to pass to
|
||||
ZeroShotAgent.create_prompt() when 'agent_type' is
|
||||
"zero-shot-react-description". Otherwise ignored.
|
||||
input_variables: DEPRECATED. Input variables to explicitly specify as part of
|
||||
ZeroShotAgent.create_prompt() when 'agent_type' is
|
||||
"zero-shot-react-description". Otherwise ignored.
|
||||
top_k: Number of rows to query for by default.
|
||||
max_iterations: Passed to AgentExecutor init.
|
||||
max_execution_time: Passed to AgentExecutor init.
|
||||
early_stopping_method: Passed to AgentExecutor init.
|
||||
verbose: AgentExecutor verbosity.
|
||||
agent_executor_kwargs: Arbitrary additional AgentExecutor args.
|
||||
extra_tools: Additional tools to give to agent on top of the ones that come with
|
||||
SQLDatabaseToolkit.
|
||||
db: SQLDatabase from which to create a SQLDatabaseToolkit. Toolkit is created
|
||||
using 'db' and 'llm'. Must provide exactly one of 'db' or 'toolkit'.
|
||||
prompt: Complete agent prompt. prompt and {prefix, suffix, format_instructions,
|
||||
input_variables} are mutually exclusive.
|
||||
**kwargs: DEPRECATED. Not used, kept for backwards compatibility.
|
||||
|
||||
Returns:
|
||||
An AgentExecutor with the specified agent_type agent.
|
||||
|
||||
Example:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
from langchain_openai import ChatOpenAI
|
||||
from langchain_community.agent_toolkits import create_sql_agent
|
||||
from langchain_community.utilities import SQLDatabase
|
||||
|
||||
db = SQLDatabase.from_uri("sqlite:///Chinook.db")
|
||||
llm = ChatOpenAI(model="gpt-3.5-turbo", temperature=0)
|
||||
agent_executor = create_sql_agent(llm, db=db, agent_type="openai-tools", verbose=True)
|
||||
|
||||
""" # noqa: E501
|
||||
from langchain.agents import (
|
||||
create_openai_functions_agent,
|
||||
create_openai_tools_agent,
|
||||
create_react_agent,
|
||||
)
|
||||
from langchain.agents.agent import (
|
||||
AgentExecutor,
|
||||
RunnableAgent,
|
||||
RunnableMultiActionAgent,
|
||||
)
|
||||
from langchain.agents.agent_types import AgentType
|
||||
|
||||
if toolkit is None and db is None:
|
||||
raise ValueError(
|
||||
"Must provide exactly one of 'toolkit' or 'db'. Received neither."
|
||||
)
|
||||
if toolkit and db:
|
||||
raise ValueError(
|
||||
"Must provide exactly one of 'toolkit' or 'db'. Received both."
|
||||
)
|
||||
if kwargs:
|
||||
warnings.warn(
|
||||
f"Received additional kwargs {kwargs} which are no longer supported."
|
||||
)
|
||||
|
||||
toolkit = toolkit or SQLDatabaseToolkit(llm=llm, db=db)
|
||||
agent_type = agent_type or AgentType.ZERO_SHOT_REACT_DESCRIPTION
|
||||
tools = toolkit.get_tools() + list(extra_tools)
|
||||
prefix = prefix.format(dialect=toolkit.dialect, top_k=top_k)
|
||||
agent: BaseSingleActionAgent
|
||||
if prompt is None:
|
||||
prefix = prefix or SQL_PREFIX
|
||||
prefix = prefix.format(dialect=toolkit.dialect, top_k=top_k)
|
||||
else:
|
||||
if "top_k" in prompt.input_variables:
|
||||
prompt = prompt.partial(top_k=str(top_k))
|
||||
if "dialect" in prompt.input_variables:
|
||||
prompt = prompt.partial(dialect=toolkit.dialect)
|
||||
db_context = toolkit.get_context()
|
||||
if "table_info" in prompt.input_variables:
|
||||
prompt = prompt.partial(table_info=db_context["table_info"])
|
||||
tools = [
|
||||
tool for tool in tools if not isinstance(tool, InfoSQLDatabaseTool)
|
||||
]
|
||||
if "table_names" in prompt.input_variables:
|
||||
prompt = prompt.partial(table_names=db_context["table_names"])
|
||||
tools = [
|
||||
tool for tool in tools if not isinstance(tool, ListSQLDatabaseTool)
|
||||
]
|
||||
|
||||
if agent_type == AgentType.ZERO_SHOT_REACT_DESCRIPTION:
|
||||
prompt_params = (
|
||||
{"format_instructions": format_instructions}
|
||||
if format_instructions is not None
|
||||
else {}
|
||||
if prompt is None:
|
||||
from langchain.agents.mrkl import prompt as react_prompt
|
||||
|
||||
format_instructions = (
|
||||
format_instructions or react_prompt.FORMAT_INSTRUCTIONS
|
||||
)
|
||||
template = "\n\n".join(
|
||||
[
|
||||
react_prompt.PREFIX,
|
||||
"{tools}",
|
||||
format_instructions,
|
||||
react_prompt.SUFFIX,
|
||||
]
|
||||
)
|
||||
prompt = PromptTemplate.from_template(template)
|
||||
agent = RunnableAgent(
|
||||
runnable=create_react_agent(llm, tools, prompt),
|
||||
input_keys_arg=["input"],
|
||||
return_keys_arg=["output"],
|
||||
)
|
||||
prompt = ZeroShotAgent.create_prompt(
|
||||
tools,
|
||||
prefix=prefix,
|
||||
suffix=suffix or SQL_SUFFIX,
|
||||
input_variables=input_variables,
|
||||
**prompt_params,
|
||||
)
|
||||
llm_chain = LLMChain(
|
||||
llm=llm,
|
||||
prompt=prompt,
|
||||
callback_manager=callback_manager,
|
||||
)
|
||||
tool_names = [tool.name for tool in tools]
|
||||
agent = ZeroShotAgent(llm_chain=llm_chain, allowed_tools=tool_names, **kwargs)
|
||||
|
||||
elif agent_type == AgentType.OPENAI_FUNCTIONS:
|
||||
messages = [
|
||||
SystemMessage(content=prefix),
|
||||
HumanMessagePromptTemplate.from_template("{input}"),
|
||||
AIMessage(content=suffix or SQL_FUNCTIONS_SUFFIX),
|
||||
MessagesPlaceholder(variable_name="agent_scratchpad"),
|
||||
]
|
||||
input_variables = ["input", "agent_scratchpad"]
|
||||
_prompt = ChatPromptTemplate(input_variables=input_variables, messages=messages)
|
||||
|
||||
agent = OpenAIFunctionsAgent(
|
||||
llm=llm,
|
||||
prompt=_prompt,
|
||||
tools=tools,
|
||||
callback_manager=callback_manager,
|
||||
**kwargs,
|
||||
if prompt is None:
|
||||
messages = [
|
||||
SystemMessage(content=prefix),
|
||||
HumanMessagePromptTemplate.from_template("{input}"),
|
||||
AIMessage(content=suffix or SQL_FUNCTIONS_SUFFIX),
|
||||
MessagesPlaceholder(variable_name="agent_scratchpad"),
|
||||
]
|
||||
prompt = ChatPromptTemplate.from_messages(messages)
|
||||
agent = RunnableAgent(
|
||||
runnable=create_openai_functions_agent(llm, tools, prompt),
|
||||
input_keys_arg=["input"],
|
||||
return_keys_arg=["output"],
|
||||
)
|
||||
elif agent_type == "openai-tools":
|
||||
if prompt is None:
|
||||
messages = [
|
||||
SystemMessage(content=prefix),
|
||||
HumanMessagePromptTemplate.from_template("{input}"),
|
||||
AIMessage(content=suffix or SQL_FUNCTIONS_SUFFIX),
|
||||
MessagesPlaceholder(variable_name="agent_scratchpad"),
|
||||
]
|
||||
prompt = ChatPromptTemplate.from_messages(messages)
|
||||
agent = RunnableMultiActionAgent(
|
||||
runnable=create_openai_tools_agent(llm, tools, prompt),
|
||||
input_keys_arg=["input"],
|
||||
return_keys_arg=["output"],
|
||||
)
|
||||
|
||||
else:
|
||||
raise ValueError(f"Agent type {agent_type} not supported at the moment.")
|
||||
|
||||
return AgentExecutor.from_agent_and_tools(
|
||||
return AgentExecutor(
|
||||
name="SQL Agent Executor",
|
||||
agent=agent,
|
||||
tools=tools,
|
||||
callback_manager=callback_manager,
|
||||
|
@@ -69,3 +69,7 @@ class SQLDatabaseToolkit(BaseToolkit):
|
||||
list_sql_database_tool,
|
||||
query_sql_checker_tool,
|
||||
]
|
||||
|
||||
def get_context(self) -> dict:
    """Return db context that you may want in agent prompt.

    Delegates directly to ``self.db.get_context()`` and returns its result
    unchanged.
    """
    return self.db.get_context()
|
||||
|
@@ -1,10 +1,10 @@
|
||||
"""SQLAlchemy wrapper around a database."""
|
||||
from __future__ import annotations
|
||||
|
||||
import warnings
|
||||
from typing import Any, Dict, Iterable, List, Literal, Optional, Sequence
|
||||
|
||||
import sqlalchemy
|
||||
from langchain_core._api import deprecated
|
||||
from langchain_core.utils import get_from_env
|
||||
from sqlalchemy import MetaData, Table, create_engine, inspect, select, text
|
||||
from sqlalchemy.engine import Engine
|
||||
@@ -272,11 +272,9 @@ class SQLDatabase:
|
||||
return sorted(self._include_tables)
|
||||
return sorted(self._all_tables - self._ignore_tables)
|
||||
|
||||
@deprecated("0.0.1", alternative="get_usable_table_names", removal="0.2.0")
def get_table_names(self) -> Iterable[str]:
    """Get names of tables available.

    Deprecated alias: simply forwards to ``get_usable_table_names()``.
    """
    # The @deprecated decorator is responsible for the deprecation notice, so
    # the former manual warnings.warn(...) call is not repeated here (it would
    # produce a second, duplicate warning).
    return self.get_usable_table_names()
|
||||
|
||||
@property
|
||||
@@ -487,3 +485,9 @@ class SQLDatabase:
|
||||
except SQLAlchemyError as e:
|
||||
"""Format the error message"""
|
||||
return f"Error: {e}"
|
||||
|
||||
def get_context(self) -> Dict[str, Any]:
    """Return db context that you may want in agent prompt.

    Returns:
        A dict with two keys:

        * ``"table_info"``: whatever ``get_table_info_no_throw()`` returns.
        * ``"table_names"``: the names from ``get_usable_table_names()``
          joined into a single string with ``", "``.
    """
    table_names = list(self.get_usable_table_names())
    table_info = self.get_table_info_no_throw()
    return {"table_info": table_info, "table_names": ", ".join(table_names)}
|
||||
|
@@ -74,6 +74,9 @@ class SemanticSimilarityExampleSelector(BaseExampleSelector, BaseModel):
|
||||
vectorstore_cls: Type[VectorStore],
|
||||
k: int = 4,
|
||||
input_keys: Optional[List[str]] = None,
|
||||
*,
|
||||
example_keys: Optional[List[str]] = None,
|
||||
vectorstore_kwargs: Optional[dict] = None,
|
||||
**vectorstore_cls_kwargs: Any,
|
||||
) -> SemanticSimilarityExampleSelector:
|
||||
"""Create k-shot example selector using example list and embeddings.
|
||||
@@ -102,7 +105,13 @@ class SemanticSimilarityExampleSelector(BaseExampleSelector, BaseModel):
|
||||
vectorstore = vectorstore_cls.from_texts(
|
||||
string_examples, embeddings, metadatas=examples, **vectorstore_cls_kwargs
|
||||
)
|
||||
return cls(vectorstore=vectorstore, k=k, input_keys=input_keys)
|
||||
return cls(
|
||||
vectorstore=vectorstore,
|
||||
k=k,
|
||||
input_keys=input_keys,
|
||||
example_keys=example_keys,
|
||||
vectorstore_kwargs=vectorstore_kwargs,
|
||||
)
|
||||
|
||||
|
||||
class MaxMarginalRelevanceExampleSelector(SemanticSimilarityExampleSelector):
|
||||
|
@@ -343,8 +343,8 @@ class RunnableAgent(BaseSingleActionAgent):
|
||||
|
||||
runnable: Runnable[dict, Union[AgentAction, AgentFinish]]
|
||||
"""Runnable to call to get agent action."""
|
||||
_input_keys: List[str] = []
|
||||
"""Input keys."""
|
||||
input_keys_arg: List[str] = []
|
||||
return_keys_arg: List[str] = []
|
||||
|
||||
class Config:
|
||||
"""Configuration for this pydantic object."""
|
||||
@@ -354,16 +354,11 @@ class RunnableAgent(BaseSingleActionAgent):
|
||||
@property
def return_values(self) -> List[str]:
    """Return values of the agent.

    Returns:
        The list stored in ``self.return_keys_arg``.
    """
    # The stale pre-change ``return []`` is removed; it shadowed this return.
    return self.return_keys_arg
|
||||
|
||||
@property
def input_keys(self) -> List[str]:
    """Return the input keys.

    Returns:
        List of input keys, as stored in ``self.input_keys_arg``.
    """
    # The stale pre-change ``return self._input_keys`` is removed: it shadowed
    # this return and referred to the field that ``input_keys_arg`` replaced.
    return self.input_keys_arg
|
||||
|
||||
def plan(
|
||||
self,
|
||||
@@ -439,8 +434,8 @@ class RunnableMultiActionAgent(BaseMultiActionAgent):
|
||||
|
||||
runnable: Runnable[dict, Union[List[AgentAction], AgentFinish]]
|
||||
"""Runnable to call to get agent actions."""
|
||||
_input_keys: List[str] = []
|
||||
"""Input keys."""
|
||||
input_keys_arg: List[str] = []
|
||||
return_keys_arg: List[str] = []
|
||||
|
||||
class Config:
|
||||
"""Configuration for this pydantic object."""
|
||||
@@ -450,7 +445,7 @@ class RunnableMultiActionAgent(BaseMultiActionAgent):
|
||||
@property
def return_values(self) -> List[str]:
    """Return values of the agent.

    Returns:
        The list stored in ``self.return_keys_arg``.
    """
    # The stale pre-change ``return []`` is removed; it shadowed this return.
    return self.return_keys_arg
|
||||
|
||||
@property
|
||||
def input_keys(self) -> List[str]:
|
||||
@@ -459,7 +454,7 @@ class RunnableMultiActionAgent(BaseMultiActionAgent):
|
||||
Returns:
|
||||
List of input keys.
|
||||
"""
|
||||
return self._input_keys
|
||||
return self.input_keys_arg
|
||||
|
||||
def plan(
|
||||
self,
|
||||
|
@@ -83,9 +83,9 @@ class ZeroShotAgent(Agent):
|
||||
tool_names = ", ".join([tool.name for tool in tools])
|
||||
format_instructions = format_instructions.format(tool_names=tool_names)
|
||||
template = "\n\n".join([prefix, tool_strings, format_instructions, suffix])
|
||||
if input_variables is None:
|
||||
input_variables = ["input", "agent_scratchpad"]
|
||||
return PromptTemplate(template=template, input_variables=input_variables)
|
||||
if input_variables:
|
||||
return PromptTemplate(template=template, input_variables=input_variables)
|
||||
return PromptTemplate.from_template(template)
|
||||
|
||||
@classmethod
|
||||
def from_llm_and_tools(
|
||||
|
@@ -131,7 +131,7 @@ class Chain(RunnableSerializable[Dict[str, Any], Dict[str, Any]], ABC):
|
||||
callbacks = config.get("callbacks")
|
||||
tags = config.get("tags")
|
||||
metadata = config.get("metadata")
|
||||
run_name = config.get("run_name")
|
||||
run_name = config.get("run_name") or self.get_name()
|
||||
include_run_info = kwargs.get("include_run_info", False)
|
||||
return_only_outputs = kwargs.get("return_only_outputs", False)
|
||||
|
||||
@@ -178,7 +178,7 @@ class Chain(RunnableSerializable[Dict[str, Any], Dict[str, Any]], ABC):
|
||||
callbacks = config.get("callbacks")
|
||||
tags = config.get("tags")
|
||||
metadata = config.get("metadata")
|
||||
run_name = config.get("run_name")
|
||||
run_name = config.get("run_name") or self.get_name()
|
||||
include_run_info = kwargs.get("include_run_info", False)
|
||||
return_only_outputs = kwargs.get("return_only_outputs", False)
|
||||
|
||||
|
@@ -1,10 +1,10 @@
|
||||
from typing import List, Optional, TypedDict, Union
|
||||
from typing import Any, Dict, List, Optional, TypedDict, Union
|
||||
|
||||
from langchain_community.utilities.sql_database import SQLDatabase
|
||||
from langchain_core.language_models import BaseLanguageModel
|
||||
from langchain_core.output_parsers import StrOutputParser
|
||||
from langchain_core.prompts import BasePromptTemplate
|
||||
from langchain_core.runnables import Runnable, RunnableParallel
|
||||
from langchain_core.runnables import Runnable, RunnablePassthrough
|
||||
|
||||
from langchain.chains.sql_database.prompt import PROMPT, SQL_PROMPTS
|
||||
|
||||
@@ -31,7 +31,7 @@ def create_sql_query_chain(
|
||||
db: SQLDatabase,
|
||||
prompt: Optional[BasePromptTemplate] = None,
|
||||
k: int = 5,
|
||||
) -> Runnable[Union[SQLInput, SQLInputWithTables], str]:
|
||||
) -> Runnable[Union[SQLInput, SQLInputWithTables, Dict[str, Any]], str]:
|
||||
"""Create a chain that generates SQL queries.
|
||||
|
||||
*Security Note*: This chain generates SQL queries for the given database.
|
||||
@@ -50,34 +50,93 @@ def create_sql_query_chain(
|
||||
See https://python.langchain.com/docs/security for more information.
|
||||
|
||||
Args:
|
||||
llm: The language model to use
|
||||
db: The SQLDatabase to generate the query for
|
||||
llm: The language model to use.
|
||||
db: The SQLDatabase to generate the query for.
|
||||
prompt: The prompt to use. If none is provided, will choose one
|
||||
based on dialect. Defaults to None.
|
||||
based on dialect. Defaults to None. See Prompt section below for more.
|
||||
k: The number of results per select statement to return. Defaults to 5.
|
||||
|
||||
Returns:
|
||||
A chain that takes in a question and generates a SQL query that answers
|
||||
that question.
|
||||
"""
|
||||
|
||||
Example:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
# pip install -U langchain langchain-community langchain-openai
|
||||
from langchain_openai import ChatOpenAI
|
||||
from langchain.chains import create_sql_query_chain
|
||||
from langchain_community.utilities import SQLDatabase
|
||||
|
||||
db = SQLDatabase.from_uri("sqlite:///Chinook.db")
|
||||
llm = ChatOpenAI(model="gpt-3.5-turbo", temperature=0)
|
||||
chain = create_sql_query_chain(llm, db)
|
||||
response = chain.invoke({"question": "How many employees are there"})
|
||||
|
||||
Prompt:
|
||||
If no prompt is provided, a default prompt is selected based on the SQLDatabase dialect. If one is provided, it must support input variables:
|
||||
* input: The user question plus suffix "\nSQLQuery: " is passed here.
|
||||
* top_k: The number of results per select statement (the `k` argument to
|
||||
this function) is passed in here.
|
||||
* table_info: Table definitions and sample rows are passed in here. If the
|
||||
user specifies "table_names_to_use" when invoking chain, only those
|
||||
will be included. Otherwise, all tables are included.
|
||||
* dialect (optional): If dialect input variable is in prompt, the db
|
||||
dialect will be passed in here.
|
||||
|
||||
Here's an example prompt:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
from langchain_core.prompts import PromptTemplate
|
||||
|
||||
template = '''Given an input question, first create a syntactically correct {dialect} query to run, then look at the results of the query and return the answer.
|
||||
Use the following format:
|
||||
|
||||
Question: "Question here"
|
||||
SQLQuery: "SQL Query to run"
|
||||
SQLResult: "Result of the SQLQuery"
|
||||
Answer: "Final answer here"
|
||||
|
||||
Only use the following tables:
|
||||
|
||||
{table_info}.
|
||||
|
||||
Question: {input}'''
|
||||
prompt = PromptTemplate.from_template(template)
|
||||
""" # noqa: E501
|
||||
if prompt is not None:
|
||||
prompt_to_use = prompt
|
||||
elif db.dialect in SQL_PROMPTS:
|
||||
prompt_to_use = SQL_PROMPTS[db.dialect]
|
||||
else:
|
||||
prompt_to_use = PROMPT
|
||||
if {"input", "top_k", "table_info"}.difference(prompt_to_use.input_variables):
|
||||
raise ValueError(
|
||||
f"Prompt must have input variables: 'input', 'top_k', "
|
||||
f"'table_info'. Received prompt with input variables: "
|
||||
f"{prompt_to_use.input_variables}. Full prompt:\n\n{prompt_to_use}"
|
||||
)
|
||||
if "dialect" in prompt_to_use.input_variables:
|
||||
prompt_to_use = prompt_to_use.partial(dialect=db.dialect)
|
||||
|
||||
inputs = {
|
||||
"input": lambda x: x["question"] + "\nSQLQuery: ",
|
||||
"top_k": lambda _: k,
|
||||
"table_info": lambda x: db.get_table_info(
|
||||
table_names=x.get("table_names_to_use")
|
||||
),
|
||||
}
|
||||
if "dialect" in prompt_to_use.input_variables:
|
||||
inputs["dialect"] = lambda _: (db.dialect, prompt_to_use)
|
||||
return (
|
||||
RunnableParallel(inputs)
|
||||
| prompt_to_use
|
||||
RunnablePassthrough.assign(**inputs) # type: ignore
|
||||
| (
|
||||
lambda x: {
|
||||
k: v
|
||||
for k, v in x.items()
|
||||
if k not in ("question", "table_names_to_use")
|
||||
}
|
||||
)
|
||||
| prompt_to_use.partial(top_k=str(k))
|
||||
| llm.bind(stop=["\nSQLResult:"])
|
||||
| StrOutputParser()
|
||||
| _strip
|
||||
|
@@ -1,3 +1,7 @@
|
||||
from functools import partial
|
||||
from typing import Optional
|
||||
|
||||
from langchain_core.prompts import BasePromptTemplate, PromptTemplate, format_document
|
||||
from langchain_core.pydantic_v1 import BaseModel, Field
|
||||
from langchain_core.retrievers import BaseRetriever
|
||||
|
||||
@@ -10,8 +14,37 @@ class RetrieverInput(BaseModel):
|
||||
query: str = Field(description="query to look up in retriever")
|
||||
|
||||
|
||||
def _get_relevant_documents(
    query: str,
    retriever: BaseRetriever,
    document_prompt: BasePromptTemplate,
    document_separator: str,
) -> str:
    """Retrieve documents for ``query`` and render them as a single string.

    Args:
        query: Query passed to ``retriever.get_relevant_documents``.
        retriever: Retriever used to fetch the documents.
        document_prompt: Prompt handed to ``format_document`` for each
            retrieved document.
        document_separator: String placed between the formatted documents.

    Returns:
        The formatted documents joined with ``document_separator``.
    """
    docs = retriever.get_relevant_documents(query)
    return document_separator.join(
        format_document(doc, document_prompt) for doc in docs
    )
|
||||
|
||||
|
||||
async def _aget_relevant_documents(
    query: str,
    retriever: BaseRetriever,
    document_prompt: BasePromptTemplate,
    document_separator: str,
) -> str:
    """Asynchronously retrieve documents for ``query`` and render them as a string.

    Args:
        query: Query passed to ``retriever.aget_relevant_documents``.
        retriever: Retriever used to fetch the documents.
        document_prompt: Prompt handed to ``format_document`` for each
            retrieved document.
        document_separator: String placed between the formatted documents.

    Returns:
        The formatted documents joined with ``document_separator``.
    """
    docs = await retriever.aget_relevant_documents(query)
    return document_separator.join(
        format_document(doc, document_prompt) for doc in docs
    )
|
||||
|
||||
|
||||
def create_retriever_tool(
|
||||
retriever: BaseRetriever, name: str, description: str
|
||||
retriever: BaseRetriever,
|
||||
name: str,
|
||||
description: str,
|
||||
*,
|
||||
document_prompt: Optional[BasePromptTemplate] = None,
|
||||
document_separator: str = "\n\n",
|
||||
) -> Tool:
|
||||
"""Create a tool to do retrieval of documents.
|
||||
|
||||
@@ -25,10 +58,23 @@ def create_retriever_tool(
|
||||
Returns:
|
||||
Tool class to pass to an agent
|
||||
"""
|
||||
document_prompt = document_prompt or PromptTemplate.from_template("{page_content}")
|
||||
func = partial(
|
||||
_get_relevant_documents,
|
||||
retriever=retriever,
|
||||
document_prompt=document_prompt,
|
||||
document_separator=document_separator,
|
||||
)
|
||||
afunc = partial(
|
||||
_aget_relevant_documents,
|
||||
retriever=retriever,
|
||||
document_prompt=document_prompt,
|
||||
document_separator=document_separator,
|
||||
)
|
||||
return Tool(
|
||||
name=name,
|
||||
description=description,
|
||||
func=retriever.get_relevant_documents,
|
||||
coroutine=retriever.aget_relevant_documents,
|
||||
func=func,
|
||||
coroutine=afunc,
|
||||
args_schema=RetrieverInput,
|
||||
)
|
||||
|
Reference in New Issue
Block a user