Compare commits

...

1 Commits

Author SHA1 Message Date
Eugene Yurtsev
ff1de3d347 x 2023-12-13 16:47:03 -05:00
3 changed files with 87 additions and 34 deletions

View File

@@ -1034,18 +1034,17 @@ class AgentExecutor(Chain):
intermediate_steps: List[Tuple[AgentAction, str]],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Union[AgentFinish, List[Tuple[AgentAction, str]]]:
return self._consume_next_step(
[
a
for a in self._iter_next_step(
name_to_tool_map,
color_mapping,
inputs,
intermediate_steps,
run_manager,
)
]
"""Take a single step."""
next_steps = list(
self._iter_next_step(
name_to_tool_map,
color_mapping,
inputs,
intermediate_steps,
run_manager,
)
)
return self._consume_next_step(next_steps)
def _iter_next_step(
self,

View File

@@ -1,6 +1,7 @@
from typing import Union
from langchain_core.agents import AgentAction, AgentFinish
from langchain_core.exceptions import OutputParserException
from langchain.agents import AgentOutputParser
@@ -32,9 +33,22 @@ class XMLAgentOutputParser(AgentOutputParser):
if "</tool>" in text:
tool, tool_input = text.split("</tool>")
_tool = tool.split("<tool>")[1]
_tool_input = tool_input.split("<tool_input>")[1]
if "</tool_input>" in _tool_input:
_tool_input = _tool_input.split("</tool_input>")[0]
if "<tool_input>" in tool_input:
_tool_input = tool_input.split("<tool_input>")[1]
if "</tool_input>" in _tool_input:
_tool_input = _tool_input.split("</tool_input>")[0]
else:
raise OutputParserException(
error=ValueError("Invalid format for output."),
llm_output=text,
observation=(
"ERROR: For a fool invocation, be sure to include a <tool_input> and"
"</tool_input> tags. A function without parameters could be invoked with a "
"an empty dictionary as the tool input.\n"
"To invoke a tool, use the format "
"`<tool>$TOOL_NAME</tool><tool_input>$TOOL_INPUT</tool_input>`.\n "
),
)
return AgentAction(tool=_tool, tool_input=_tool_input, log=text)
elif "<final_answer>" in text:
_, answer = text.split("<final_answer>")
@@ -42,10 +56,25 @@ class XMLAgentOutputParser(AgentOutputParser):
answer = answer.split("</final_answer>")[0]
return AgentFinish(return_values={"output": answer}, log=text)
else:
raise ValueError
raise OutputParserException(
error=ValueError("Invalid format for output."),
llm_output=text,
observation=(
"ERROR: Please either invoke a tool or provide a final answer."
"To invoke a tool, use the format "
"`<tool>$TOOL_NAME</tool><tool_input>$TOOL_INPUT</tool_input>`. "
"where $TOOL_NAME is one of the provided tools and $TOOL_INPUT "
"is a dictionary of arguments to pass to the tool, "
"matching the schema.\n"
),
send_to_llm=True,
)
def get_format_instructions(self) -> str:
raise NotImplementedError
"""Get the format instructions for this output parser."""
raise NotImplementedError(
"XMLAgentOutputParser does contain format instructions."
)
@property
def _type(self) -> str:

View File

@@ -14,6 +14,11 @@ from langchain.chains.llm import LLMChain
class XMLAgent(BaseSingleActionAgent):
"""Agent that uses XML tags.
This agent only works with LLMs not chat models!
Ability of agent to invoke tools varies a lot depending on how good the underlying
LLM is!
Args:
tools: list of tools the agent can choose from
llm_chain: The LLMChain to call to predict the next action
@@ -22,13 +27,25 @@ class XMLAgent(BaseSingleActionAgent):
.. code-block:: python
from langchain.agents import XMLAgent
from langchain
from langchain.agents import AgentExecutor, XMLAgent
from langchain.chains import LLMChain
tools = ...
model =
chain = LLMChain(
llm=model,
prompt=XMLAgent.get_default_prompt(),
output_parser=XMLAgent.get_default_output_parser(),
)
agent = XMLAgent(tools=tools, llm_chain=chain)
agent_executor = AgentExecutor(
agent=agent,
tools=tools,
verbose=True,
handle_parsing_errors=True
)
agent_executor.invoke({"input": "what's the weather in New york?"})
"""
tools: List[BaseTool]
@@ -38,6 +55,7 @@ class XMLAgent(BaseSingleActionAgent):
@property
def input_keys(self) -> List[str]:
"""Get the input keys."""
return ["input"]
@staticmethod
@@ -48,25 +66,38 @@ class XMLAgent(BaseSingleActionAgent):
@staticmethod
def get_default_output_parser() -> XMLAgentOutputParser:
"""Get the default output parser."""
return XMLAgentOutputParser()
def _format_intermediate_steps(
self, intermediate_steps: List[Tuple[AgentAction, str]]
) -> str:
"""Format the steps."""
log = ""
for action, observation in intermediate_steps:
if action.tool == "_Exception":
# This only works correctly when handle_parsing_errors=True
log += action.log # Will contain the llm output from the exception
log += "\n{observation}\n"
pass
else:
log += (
f"<tool>{action.tool}</tool><tool_input>{action.tool_input}"
f"</tool_input>\n<observation>{observation}</observation>\n"
)
return log
def plan(
self,
intermediate_steps: List[Tuple[AgentAction, str]],
callbacks: Callbacks = None,
**kwargs: Any,
) -> Union[AgentAction, AgentFinish]:
log = ""
for action, observation in intermediate_steps:
log += (
f"<tool>{action.tool}</tool><tool_input>{action.tool_input}"
f"</tool_input><observation>{observation}</observation>"
)
tools = ""
for tool in self.tools:
tools += f"{tool.name}: {tool.description}\n"
inputs = {
"intermediate_steps": log,
"intermediate_steps": self._format_intermediate_steps(intermediate_steps),
"tools": tools,
"question": kwargs["input"],
"stop": ["</tool_input>", "</final_answer>"],
@@ -80,17 +111,11 @@ class XMLAgent(BaseSingleActionAgent):
callbacks: Callbacks = None,
**kwargs: Any,
) -> Union[AgentAction, AgentFinish]:
log = ""
for action, observation in intermediate_steps:
log += (
f"<tool>{action.tool}</tool><tool_input>{action.tool_input}"
f"</tool_input><observation>{observation}</observation>"
)
tools = ""
for tool in self.tools:
tools += f"{tool.name}: {tool.description}\n"
inputs = {
"intermediate_steps": log,
"intermediate_steps": self._format_intermediate_steps(intermediate_steps),
"tools": tools,
"question": kwargs["input"],
"stop": ["</tool_input>", "</final_answer>"],