runnable powered agent (#10407)

This commit is contained in:
Harrison Chase
2023-09-09 15:22:13 -07:00
committed by GitHub
parent 6ad6bb46c4
commit 40d9191955
3 changed files with 288 additions and 2 deletions

View File

@@ -7,7 +7,16 @@ import logging
import time
from abc import abstractmethod
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union
from typing import (
Any,
Callable,
Dict,
List,
Optional,
Sequence,
Tuple,
Union,
)
import yaml
@@ -36,6 +45,7 @@ from langchain.schema import (
)
from langchain.schema.language_model import BaseLanguageModel
from langchain.schema.messages import BaseMessage
from langchain.schema.runnable import Runnable
from langchain.tools.base import BaseTool
from langchain.utilities.asyncio import asyncio_timeout
from langchain.utils.input import get_color_mapping
@@ -307,6 +317,71 @@ class AgentOutputParser(BaseOutputParser):
"""Parse text into agent action/finish."""
class RunnableAgent(BaseSingleActionAgent):
"""Agent powered by runnables."""
runnable: Runnable[dict, Union[AgentAction, AgentFinish]]
"""Runnable to call to get agent action."""
_input_keys: List[str] = []
"""Input keys."""
class Config:
"""Configuration for this pydantic object."""
arbitrary_types_allowed = True
@property
def input_keys(self) -> List[str]:
"""Return the input keys.
Returns:
List of input keys.
"""
return self._input_keys
def plan(
self,
intermediate_steps: List[Tuple[AgentAction, str]],
callbacks: Callbacks = None,
**kwargs: Any,
) -> Union[AgentAction, AgentFinish]:
"""Given input, decided what to do.
Args:
intermediate_steps: Steps the LLM has taken to date,
along with the observations.
callbacks: Callbacks to run.
**kwargs: User inputs.
Returns:
Action specifying what tool to use.
"""
inputs = {**kwargs, **{"intermediate_steps": intermediate_steps}}
output = self.runnable.invoke(inputs, config={"callbacks": callbacks})
return output
async def aplan(
self,
intermediate_steps: List[Tuple[AgentAction, str]],
callbacks: Callbacks = None,
**kwargs: Any,
) -> Union[AgentAction, AgentFinish]:
"""Given input, decided what to do.
Args:
intermediate_steps: Steps the LLM has taken to date,
along with observations
callbacks: Callbacks to run.
**kwargs: User inputs.
Returns:
Action specifying what tool to use.
"""
inputs = {**kwargs, **{"intermediate_steps": intermediate_steps}}
output = await self.runnable.ainvoke(inputs, config={"callbacks": callbacks})
return output
class LLMSingleActionAgent(BaseSingleActionAgent):
"""Base class for single action agents."""
@@ -725,6 +800,14 @@ s
)
return values
@root_validator(pre=True)
def validate_runnable_agent(cls, values: Dict) -> Dict:
"""Convert runnable to agent if passed in."""
agent = values["agent"]
if isinstance(agent, Runnable):
values["agent"] = RunnableAgent(runnable=agent)
return values
def save(self, file_path: Union[Path, str]) -> None:
"""Raise error - saving not supported for Agent Executors."""
raise ValueError(