mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-27 05:20:34 +00:00
First iteration of sqlagent (#633)
Co-authored-by: lesscomfortable <pancho_ingham@hotmail.com>
This commit is contained in:
parent
d54fd20ba4
commit
456b329baa
@ -3,7 +3,7 @@ from langchain.agents.agent import Agent, AgentExecutor
|
||||
from langchain.agents.conversational.base import ConversationalAgent
|
||||
from langchain.agents.load_tools import get_all_tool_names, load_tools
|
||||
from langchain.agents.loading import initialize_agent
|
||||
from langchain.agents.mrkl.base import MRKLChain, ZeroShotAgent
|
||||
from langchain.agents.mrkl.base import MRKLChain, SQLAgent, ZeroShotAgent
|
||||
from langchain.agents.react.base import ReActChain, ReActTextWorldAgent
|
||||
from langchain.agents.self_ask_with_search.base import SelfAskWithSearchChain
|
||||
from langchain.agents.tools import Tool
|
||||
@ -17,6 +17,7 @@ __all__ = [
|
||||
"Tool",
|
||||
"initialize_agent",
|
||||
"ZeroShotAgent",
|
||||
"SQLAgent",
|
||||
"ReActTextWorldAgent",
|
||||
"load_tools",
|
||||
"get_all_tool_names",
|
||||
|
@ -4,10 +4,12 @@ from __future__ import annotations
|
||||
import re
|
||||
from typing import Any, Callable, List, NamedTuple, Optional, Tuple
|
||||
|
||||
from langchain import LLMChain
|
||||
from langchain.agents.agent import Agent, AgentExecutor
|
||||
from langchain.agents.mrkl.prompt import FORMAT_INSTRUCTIONS, PREFIX, SUFFIX
|
||||
from langchain.agents.mrkl.sql_prompt import SQL_PREFIX, SQL_SUFFIX
|
||||
from langchain.agents.tools import Tool
|
||||
from langchain.llms.base import BaseLLM
|
||||
from langchain.llms.base import BaseLLM, BaseCallbackManager
|
||||
from langchain.prompts import PromptTemplate
|
||||
|
||||
FINAL_ANSWER_ACTION = "Final Answer:"
|
||||
@ -100,6 +102,53 @@ class ZeroShotAgent(Agent):
|
||||
return get_action_and_input(text)
|
||||
|
||||
|
||||
class SQLAgent(ZeroShotAgent):
|
||||
@classmethod
|
||||
def create_prompt(
|
||||
cls,
|
||||
tools: List[Tool],
|
||||
prefix: str = SQL_PREFIX,
|
||||
suffix: str = SQL_SUFFIX,
|
||||
input_variables: Optional[List[str]] = None,
|
||||
) -> PromptTemplate:
|
||||
return super().create_prompt(tools, prefix, suffix, input_variables)
|
||||
|
||||
@classmethod
|
||||
def from_llm_and_sql_tool(
|
||||
cls,
|
||||
llm: BaseLLM,
|
||||
sql_tool: Tool,
|
||||
callback_manager: Optional[BaseCallbackManager] = None,
|
||||
**kwargs: Any,
|
||||
) -> Agent:
|
||||
"""Construct an agent from an LLM and SQL Chain tool."""
|
||||
|
||||
cls._validate_tool(sql_tool)
|
||||
llm_chain = LLMChain(
|
||||
llm=llm,
|
||||
prompt=cls.create_prompt([sql_tool]),
|
||||
callback_manager=callback_manager,
|
||||
)
|
||||
return cls(llm_chain=llm_chain, **kwargs)
|
||||
|
||||
@classmethod
|
||||
def _validate_tool(cls, tool: Tool) -> None:
|
||||
|
||||
if isinstance(tool, List):
|
||||
raise TypeError("The SQLAgent must be used with only one tool.")
|
||||
|
||||
if tool.func.__self__.__class__.__name__ != "SQLDatabaseChain":
|
||||
raise ValueError(
|
||||
"The SQLAgent must be used with an 'SQLDatabaseChain' based tool."
|
||||
)
|
||||
|
||||
if tool.description is None:
|
||||
raise ValueError(
|
||||
f"Got a tool {tool.name} without a description. For this agent, "
|
||||
f"a description must always be provided."
|
||||
)
|
||||
|
||||
|
||||
class MRKLChain(AgentExecutor):
|
||||
"""Chain that implements the MRKL system.
|
||||
|
||||
@ -109,9 +158,8 @@ class MRKLChain(AgentExecutor):
|
||||
from langchain import OpenAI, MRKLChain
|
||||
from langchain.chains.mrkl.base import ChainConfig
|
||||
llm = OpenAI(temperature=0)
|
||||
prompt = PromptTemplate(...)
|
||||
chains = [...]
|
||||
mrkl = MRKLChain.from_chains(llm=llm, prompt=prompt)
|
||||
mrkl = MRKLChain.from_chains(llm=llm, chains=chains)
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
@ -157,5 +205,6 @@ class MRKLChain(AgentExecutor):
|
||||
Tool(name=c.action_name, func=c.action, description=c.action_description)
|
||||
for c in chains
|
||||
]
|
||||
|
||||
agent = ZeroShotAgent.from_llm_and_tools(llm, tools)
|
||||
return cls(agent=agent, tools=tools, **kwargs)
|
||||
|
13
langchain/agents/mrkl/sql_prompt.py
Normal file
13
langchain/agents/mrkl/sql_prompt.py
Normal file
@ -0,0 +1,13 @@
|
||||
# flake8: noqa
|
||||
SQL_PREFIX = """Answer the question as best you can.
|
||||
You should only use data in the SQL database to answer the query. The answer you return should come directly from the database. If you don't find an answer, say "There is not enough information in the DB to answer the question."
|
||||
Your first query can be exploratory, to understand the data in the table. As an example, you can query what the first 5 examples of a column are before querying that column.
|
||||
When possible, don't query exactly but always use 'LIKE' to make your queries more robust.
|
||||
Finally, be mindful of not repeating queries.
|
||||
|
||||
You have access to the following DB:"""
|
||||
|
||||
SQL_SUFFIX = """Begin!
|
||||
|
||||
Question: {input}
|
||||
Thought:{agent_scratchpad}"""
|
Loading…
Reference in New Issue
Block a user