Compare commits

...

1 Commits

Author SHA1 Message Date
Harrison Chase
6af613e2b9 retrieval agents 2023-08-21 22:05:11 -07:00
4 changed files with 408 additions and 0 deletions

File diff suppressed because one or more lines are too long

View File

@@ -0,0 +1,162 @@
import xml.etree.ElementTree as ET
from typing import Any, List, Tuple, Union
from langchain.agents.agent import (
AgentExecutor,
AgentOutputParser,
BaseSingleActionAgent,
)
from langchain.agents.xml.prompt import agent_instructions
from langchain.callbacks.base import Callbacks
from langchain.callbacks.manager import CallbackManagerForRetrieverRun
from langchain.chains.llm import LLMChain
from langchain.prompts.chat import AIMessagePromptTemplate, ChatPromptTemplate
from langchain.schema import AgentAction, AgentFinish, BaseRetriever, Document
from langchain.tools.base import BaseTool
class XMLAgentOutputParser(AgentOutputParser):
    """Output parser for XMLAgent.

    Two shapes of model output are recognised:

    * ``<tool>NAME</tool><tool_input>INPUT`` is turned into an ``AgentAction``.
    * ``<final_answer>ANSWER`` is turned into an ``AgentFinish`` whose
      ``"output"`` return value is ``ANSWER``.

    A trailing ``</tool_input>`` or ``</final_answer>`` tag is optional: when
    present, it and anything after it are dropped from the extracted value.
    """

    def parse(self, text: str) -> Union[AgentAction, AgentFinish]:
        """Parse raw model output into an agent action or a final answer.

        Args:
            text: The raw text produced by the language model.

        Returns:
            An ``AgentAction`` when ``text`` contains a tool invocation,
            otherwise an ``AgentFinish`` when it contains a final answer.

        Raises:
            ValueError: If ``text`` contains neither a well-formed tool
                invocation nor a ``<final_answer>`` tag.
        """
        if "</tool>" in text:
            # partition() splits on the FIRST closing tag only, so output with
            # a repeated "</tool>" no longer fails on tuple unpacking.
            tool_part, _, input_part = text.partition("</tool>")
            if "<tool>" not in tool_part:
                raise ValueError(f"Could not find an opening <tool> tag in: {text!r}")
            if "<tool_input>" not in input_part:
                raise ValueError(f"Could not find a <tool_input> tag in: {text!r}")
            _tool = tool_part.split("<tool>")[1]
            _tool_input = input_part.split("<tool_input>")[1]
            # Drop the closing tag (and anything after it) if it was emitted.
            _tool_input = _tool_input.split("</tool_input>")[0]
            return AgentAction(tool=_tool, tool_input=_tool_input, log=text)

        if "<final_answer>" in text:
            _, _, answer = text.partition("<final_answer>")
            # Drop the closing tag (and anything after it) if it was emitted.
            answer = answer.split("</final_answer>")[0]
            return AgentFinish(return_values={"output": answer}, log=text)

        raise ValueError(
            f"Could not parse output, expected <tool> or <final_answer> tags: {text!r}"
        )

    def get_format_instructions(self) -> str:
        """Not supported by this parser; always raises ``NotImplementedError``."""
        raise NotImplementedError

    @property
    def _type(self) -> str:
        """Identifier string for this parser type."""
        return "xml-retrieval-agent"
class XMLRetrievalAgent(BaseSingleActionAgent):
    """Agent that uses XML tags to do retrieval.

    Every document seen in the tool observations is given a running integer
    id. The model finishes by listing the ids it considers relevant, and the
    agent returns the corresponding documents as its ``"output"``.

    ``plan`` / ``aplan`` expect the value stored under the chain's output key
    to already be an ``AgentAction`` or ``AgentFinish`` (i.e. the chain output
    has been parsed, e.g. by ``XMLAgentOutputParser``).

    Args:
        tools: list of tools the agent can choose from
        llm_chain: The LLMChain to call to predict the next action

    Examples:

        .. code-block:: python

            from langchain.chains.llm import LLMChain

            tools = ...  # tools whose observations are lists of documents
            model = ...  # a language model accepted by LLMChain
            llm_chain = LLMChain(
                llm=model,
                prompt=XMLRetrievalAgent.get_default_prompt(),
                output_parser=XMLRetrievalAgent.get_default_output_parser(),
            )
            agent = XMLRetrievalAgent(tools=tools, llm_chain=llm_chain)
    """

    tools: List[BaseTool]
    """List of tools this agent has access to."""
    llm_chain: LLMChain
    """Chain to use to predict action."""

    @property
    def input_keys(self) -> List[str]:
        """Input keys the agent expects: just ``"input"``."""
        return ["input"]

    @staticmethod
    def get_default_prompt() -> ChatPromptTemplate:
        """Return the instruction prompt followed by an AI message holding the step log."""
        return ChatPromptTemplate.from_template(
            agent_instructions
        ) + AIMessagePromptTemplate.from_template("{intermediate_steps}")

    @staticmethod
    def get_default_output_parser() -> XMLAgentOutputParser:
        """Return the parser matching the XML format used by the default prompt."""
        return XMLAgentOutputParser()

    def _prepare_inputs(self, intermediate_steps, question):
        """Render previous steps and tools into chain inputs.

        Returns a ``(inputs, doc_mapping)`` pair where ``doc_mapping`` maps the
        integer ids shown to the model back to the observed documents.
        """
        doc_mapping = {}
        step_logs = []
        for action, observation in intermediate_steps:
            rendered_docs = []
            for doc in observation:
                # Ids run across all steps so that each one is unique.
                doc_id = len(doc_mapping)
                doc_mapping[doc_id] = doc
                # Each document is wrapped in <doc> so the observation matches
                # the format described in the agent instructions.
                rendered_docs.append(
                    f"<doc><id>{doc_id}</id>"
                    f"<content>{doc.page_content}</content></doc>"
                )
            doc_string = "".join(rendered_docs)
            step_logs.append(
                f"<tool>{action.tool}</tool><tool_input>{action.tool_input}"
                f"</tool_input><observation>{doc_string}</observation>"
            )
        tools = "".join(f"{tool.name}: {tool.description}\n" for tool in self.tools)
        inputs = {
            "intermediate_steps": "".join(step_logs),
            "tools": tools,
            "question": question,
            "stop": ["</tool_input>", "</final_answer>"],
        }
        return inputs, doc_mapping

    def _finalize(self, result, doc_mapping):
        """Pass actions through; turn a final list of ids into documents."""
        if isinstance(result, AgentAction):
            return result
        answer = result.return_values["output"]
        # A stray closing tag would make the wrapped snippet invalid XML.
        answer = answer.split("</final_answer>")[0]
        root = ET.fromstring("<root>" + answer + "</root>")
        docs = []
        for elem in root.findall("id"):
            try:
                doc_id = int(elem.text)
            except (TypeError, ValueError):
                # Empty or non-numeric id produced by the model: ignore it.
                continue
            # Ids the model invented have no document to return; skip them
            # rather than failing the whole run.
            if doc_id in doc_mapping:
                docs.append(doc_mapping[doc_id])
        return AgentFinish(return_values={"output": docs}, log=result.log)

    def plan(
        self,
        intermediate_steps: List[Tuple[AgentAction, str]],
        callbacks: Callbacks = None,
        **kwargs: Any,
    ) -> Union[AgentAction, AgentFinish]:
        """Decide the next step given the steps taken so far.

        Args:
            intermediate_steps: Pairs of (action, observation); each
                observation is iterated as a collection of documents exposing
                ``page_content``.
            callbacks: Callbacks forwarded to the LLM chain.
            **kwargs: Must contain ``"input"``, the user's question.

        Returns:
            The next ``AgentAction``, or an ``AgentFinish`` whose ``"output"``
            is the list of documents selected by the model.

        Raises:
            KeyError: If ``"input"`` is missing from ``kwargs``.
            xml.etree.ElementTree.ParseError: If the final answer is not
                well-formed XML.
        """
        inputs, doc_mapping = self._prepare_inputs(intermediate_steps, kwargs["input"])
        response = self.llm_chain(inputs, callbacks=callbacks)
        return self._finalize(response[self.llm_chain.output_key], doc_mapping)

    async def aplan(
        self,
        intermediate_steps: List[Tuple[AgentAction, str]],
        callbacks: Callbacks = None,
        **kwargs: Any,
    ) -> Union[AgentAction, AgentFinish]:
        """Async counterpart of ``plan``; same arguments, result and errors."""
        inputs, doc_mapping = self._prepare_inputs(intermediate_steps, kwargs["input"])
        response = await self.llm_chain.acall(inputs, callbacks=callbacks)
        return self._finalize(response[self.llm_chain.output_key], doc_mapping)
class AgentRetriever(BaseRetriever, AgentExecutor):
    """Retriever that answers a query by running the agent executor."""

    def _get_relevant_documents(
        self, query: str, *, run_manager: CallbackManagerForRetrieverRun
    ) -> List[Document]:
        """Run the executor on ``query`` and return its ``"output"`` value.

        Args:
            query: Passed to the executor as the ``"input"`` key.
            run_manager: Retriever run manager; its child callbacks are
                forwarded to the executor call.
        """
        child_callbacks = run_manager.get_child()
        return self({"input": query}, callbacks=child_callbacks)["output"]

View File

@@ -0,0 +1,41 @@
# flake8: noqa
# Instruction prompt template for the XML retrieval agent.
#
# Placeholders:
#   {tools}    - listing of the available tools, one "name: description" per line
#   {question} - the user's question
#
# The template tells the model how to call a tool with <tool>/<tool_input> tags,
# how observations are formatted, and to finish with a <final_answer> listing
# the ids of the relevant documents.
# The trailing backslash inside the string joins that line with the next one.
agent_instructions = """You are a helpful research assistant who helps find relevant documents for the user's questions.
You have access to the following tools:
{tools}
In order to use a tool, you MUST use <tool></tool> AND <tool_input></tool_input> tags.
Each tool will return a list of documents. This will be returned in the format of: <observation><doc><id>...</id><content>...</content></doc>...</observation>
For example, if you have a tool called 'search' that could run a google search, in order to search for the weather in SF you would respond:
<tool>search</tool>
<tool_input>weather in SF</tool_input>
<observation><doc><id>1</id><content>64 degrees</content></doc></observation>
When you are done, respond with a list of the document ids that are relevant. The documents corresponding to these ids will be returned. For example:
<final_answer><id>1</id><id>4</id>...</final_answer>
Only respond with the ids of the documents that are actually relevant to the question at hand. \
You can make as many queries as are necessary in order to get the correct documents.
Some example of how you should act:
Scenario 1:
- The user asks for topic X
- You run a query for Y but don't get any good results
- You run another for Z and get better results, and so you return some from those
Scenario 2:
- The user asks for topic X and Y
- You run a query for X
- You run a query for Y
- You return the relevant documents from each
Ready?
Begin!
Question: {question}"""