From 1372296dc8863443e80039169db2f26e9eefc0ff Mon Sep 17 00:00:00 2001 From: Bagatur <22008038+baskaryan@users.noreply.github.com> Date: Wed, 15 Nov 2023 13:58:14 -0800 Subject: [PATCH] FIX: Infer runnable agent single or multi action (#13412) --- libs/langchain/langchain/agents/agent.py | 102 ++++++++++++++++++++--- 1 file changed, 90 insertions(+), 12 deletions(-) diff --git a/libs/langchain/langchain/agents/agent.py b/libs/langchain/langchain/agents/agent.py index da9db9f566d..48de6b1ee2f 100644 --- a/libs/langchain/langchain/agents/agent.py +++ b/libs/langchain/langchain/agents/agent.py @@ -319,7 +319,7 @@ class BaseMultiActionAgent(BaseModel): return {} -class AgentOutputParser(BaseOutputParser): +class AgentOutputParser(BaseOutputParser[Union[AgentAction, AgentFinish]]): """Base class for parsing agent output into agent action/finish.""" @abstractmethod @@ -327,7 +327,9 @@ class AgentOutputParser(BaseOutputParser): """Parse text into agent action/finish.""" -class MultiActionAgentOutputParser(BaseOutputParser): +class MultiActionAgentOutputParser( + BaseOutputParser[Union[List[AgentAction], AgentFinish]] +): """Base class for parsing agent output into agent actions/finish.""" @abstractmethod @@ -335,17 +337,87 @@ class MultiActionAgentOutputParser(BaseOutputParser): """Parse text into agent actions/finish.""" -class RunnableAgent(BaseMultiActionAgent): +class RunnableAgent(BaseSingleActionAgent): """Agent powered by runnables.""" - runnable: Union[ - Runnable[dict, Union[AgentAction, AgentFinish]], - Runnable[dict, Union[List[AgentAction], AgentFinish]], - ] + runnable: Runnable[dict, Union[AgentAction, AgentFinish]] """Runnable to call to get agent action.""" _input_keys: List[str] = [] """Input keys.""" + class Config: + """Configuration for this pydantic object.""" + + arbitrary_types_allowed = True + + @property + def return_values(self) -> List[str]: + """Return values of the agent.""" + return [] + + @property + def input_keys(self) -> List[str]: + """Return the input keys. + + Returns: + List of input keys. + """ + return self._input_keys + + def plan( + self, + intermediate_steps: List[Tuple[AgentAction, str]], + callbacks: Callbacks = None, + **kwargs: Any, + ) -> Union[AgentAction, AgentFinish]: + """Given input, decided what to do. + + Args: + intermediate_steps: Steps the LLM has taken to date, + along with the observations. + callbacks: Callbacks to run. + **kwargs: User inputs. + + Returns: + Action specifying what tool to use. + """ + inputs = {**kwargs, **{"intermediate_steps": intermediate_steps}} + output = self.runnable.invoke(inputs, config={"callbacks": callbacks}) + return output + + async def aplan( + self, + intermediate_steps: List[Tuple[AgentAction, str]], + callbacks: Callbacks = None, + **kwargs: Any, + ) -> Union[ + AgentAction, + AgentFinish, + ]: + """Given input, decided what to do. + + Args: + intermediate_steps: Steps the LLM has taken to date, + along with observations + callbacks: Callbacks to run. + **kwargs: User inputs. + + Returns: + Action specifying what tool to use. + """ + inputs = {**kwargs, **{"intermediate_steps": intermediate_steps}} + output = await self.runnable.ainvoke(inputs, config={"callbacks": callbacks}) + return output + + +class RunnableMultiActionAgent(BaseMultiActionAgent): + """Agent powered by runnables.""" + + runnable: Runnable[dict, Union[List[AgentAction], AgentFinish]] + """Runnable to call to get agent actions.""" + _input_keys: List[str] = [] + """Input keys.""" + class Config: """Configuration for this pydantic object.""" @@ -387,8 +459,6 @@ class RunnableAgent(BaseMultiActionAgent): """ inputs = {**kwargs, **{"intermediate_steps": intermediate_steps}} output = self.runnable.invoke(inputs, config={"callbacks": callbacks}) - if isinstance(output, AgentAction): - output = [output] return output async def aplan( @@ -413,8 +483,6 @@ class RunnableAgent(BaseMultiActionAgent): """ inputs = {**kwargs, **{"intermediate_steps": intermediate_steps}} output = await self.runnable.ainvoke(inputs, config={"callbacks": callbacks}) - if isinstance(output, AgentAction): - output = [output] return output @@ -840,7 +908,17 @@ class AgentExecutor(Chain): """Convert runnable to agent if passed in.""" agent = values["agent"] if isinstance(agent, Runnable): - values["agent"] = RunnableAgent(runnable=agent) + try: + output_type = agent.OutputType + except Exception as _: + multi_action = False + else: + multi_action = output_type == Union[List[AgentAction], AgentFinish] + + if multi_action: + values["agent"] = RunnableMultiActionAgent(runnable=agent) + else: + values["agent"] = RunnableAgent(runnable=agent) return values def save(self, file_path: Union[Path, str]) -> None: