retry to parsing (#3696)

This commit is contained in:
Harrison Chase 2023-05-01 22:05:42 -07:00 committed by GitHub
parent 3993166b5e
commit ca08a34a98
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -17,7 +17,9 @@ from langchain.base_language import BaseLanguageModel
from langchain.callbacks.base import BaseCallbackManager from langchain.callbacks.base import BaseCallbackManager
from langchain.callbacks.manager import ( from langchain.callbacks.manager import (
AsyncCallbackManagerForChainRun, AsyncCallbackManagerForChainRun,
AsyncCallbackManagerForToolRun,
CallbackManagerForChainRun, CallbackManagerForChainRun,
CallbackManagerForToolRun,
Callbacks, Callbacks,
) )
from langchain.chains.base import Chain from langchain.chains.base import Chain
@ -578,6 +580,25 @@ class Agent(BaseSingleActionAgent):
} }
class ExceptionTool(BaseTool):
name = "_Exception"
description = "Exception tool"
def _run(
self,
query: str,
run_manager: Optional[CallbackManagerForToolRun] = None,
) -> str:
return query
async def _arun(
self,
query: str,
run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
) -> str:
return query
class AgentExecutor(Chain): class AgentExecutor(Chain):
"""Consists of an agent using tools.""" """Consists of an agent using tools."""
@ -587,6 +608,7 @@ class AgentExecutor(Chain):
max_iterations: Optional[int] = 15 max_iterations: Optional[int] = 15
max_execution_time: Optional[float] = None max_execution_time: Optional[float] = None
early_stopping_method: str = "force" early_stopping_method: str = "force"
handle_parsing_errors: bool = False
@classmethod @classmethod
def from_agent_and_tools( def from_agent_and_tools(
@ -714,12 +736,28 @@ class AgentExecutor(Chain):
Override this to take control of how the agent makes and acts on choices. Override this to take control of how the agent makes and acts on choices.
""" """
try:
# Call the LLM to see what to do. # Call the LLM to see what to do.
output = self.agent.plan( output = self.agent.plan(
intermediate_steps, intermediate_steps,
callbacks=run_manager.get_child() if run_manager else None, callbacks=run_manager.get_child() if run_manager else None,
**inputs, **inputs,
) )
except Exception as e:
if not self.handle_parsing_errors:
raise e
text = str(e).split("`")[1]
observation = "Invalid or incomplete response"
output = AgentAction("_Exception", observation, text)
tool_run_kwargs = self.agent.tool_run_logging_kwargs()
observation = ExceptionTool().run(
output.tool,
verbose=self.verbose,
color=None,
callbacks=run_manager.get_child() if run_manager else None,
**tool_run_kwargs,
)
return [(output, observation)]
# If the tool chosen is the finishing tool, then we end and return. # If the tool chosen is the finishing tool, then we end and return.
if isinstance(output, AgentFinish): if isinstance(output, AgentFinish):
return output return output
@ -772,12 +810,28 @@ class AgentExecutor(Chain):
Override this to take control of how the agent makes and acts on choices. Override this to take control of how the agent makes and acts on choices.
""" """
try:
# Call the LLM to see what to do. # Call the LLM to see what to do.
output = await self.agent.aplan( output = await self.agent.aplan(
intermediate_steps, intermediate_steps,
callbacks=run_manager.get_child() if run_manager else None, callbacks=run_manager.get_child() if run_manager else None,
**inputs, **inputs,
) )
except Exception as e:
if not self.handle_parsing_errors:
raise e
text = str(e).split("`")[1]
observation = "Invalid or incomplete response"
output = AgentAction("_Exception", observation, text)
tool_run_kwargs = self.agent.tool_run_logging_kwargs()
observation = await ExceptionTool().arun(
output.tool,
verbose=self.verbose,
color=None,
callbacks=run_manager.get_child() if run_manager else None,
**tool_run_kwargs,
)
return [(output, observation)]
# If the tool chosen is the finishing tool, then we end and return. # If the tool chosen is the finishing tool, then we end and return.
if isinstance(output, AgentFinish): if isinstance(output, AgentFinish):
return output return output