mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-09 15:03:21 +00:00
runnable powered agent (#10407)
This commit is contained in:
@@ -7,7 +7,16 @@ import logging
|
||||
import time
|
||||
from abc import abstractmethod
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
Dict,
|
||||
List,
|
||||
Optional,
|
||||
Sequence,
|
||||
Tuple,
|
||||
Union,
|
||||
)
|
||||
|
||||
import yaml
|
||||
|
||||
@@ -36,6 +45,7 @@ from langchain.schema import (
|
||||
)
|
||||
from langchain.schema.language_model import BaseLanguageModel
|
||||
from langchain.schema.messages import BaseMessage
|
||||
from langchain.schema.runnable import Runnable
|
||||
from langchain.tools.base import BaseTool
|
||||
from langchain.utilities.asyncio import asyncio_timeout
|
||||
from langchain.utils.input import get_color_mapping
|
||||
@@ -307,6 +317,71 @@ class AgentOutputParser(BaseOutputParser):
|
||||
"""Parse text into agent action/finish."""
|
||||
|
||||
|
||||
class RunnableAgent(BaseSingleActionAgent):
|
||||
"""Agent powered by runnables."""
|
||||
|
||||
runnable: Runnable[dict, Union[AgentAction, AgentFinish]]
|
||||
"""Runnable to call to get agent action."""
|
||||
_input_keys: List[str] = []
|
||||
"""Input keys."""
|
||||
|
||||
class Config:
|
||||
"""Configuration for this pydantic object."""
|
||||
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
@property
|
||||
def input_keys(self) -> List[str]:
|
||||
"""Return the input keys.
|
||||
|
||||
Returns:
|
||||
List of input keys.
|
||||
"""
|
||||
return self._input_keys
|
||||
|
||||
def plan(
|
||||
self,
|
||||
intermediate_steps: List[Tuple[AgentAction, str]],
|
||||
callbacks: Callbacks = None,
|
||||
**kwargs: Any,
|
||||
) -> Union[AgentAction, AgentFinish]:
|
||||
"""Given input, decided what to do.
|
||||
|
||||
Args:
|
||||
intermediate_steps: Steps the LLM has taken to date,
|
||||
along with the observations.
|
||||
callbacks: Callbacks to run.
|
||||
**kwargs: User inputs.
|
||||
|
||||
Returns:
|
||||
Action specifying what tool to use.
|
||||
"""
|
||||
inputs = {**kwargs, **{"intermediate_steps": intermediate_steps}}
|
||||
output = self.runnable.invoke(inputs, config={"callbacks": callbacks})
|
||||
return output
|
||||
|
||||
async def aplan(
|
||||
self,
|
||||
intermediate_steps: List[Tuple[AgentAction, str]],
|
||||
callbacks: Callbacks = None,
|
||||
**kwargs: Any,
|
||||
) -> Union[AgentAction, AgentFinish]:
|
||||
"""Given input, decided what to do.
|
||||
|
||||
Args:
|
||||
intermediate_steps: Steps the LLM has taken to date,
|
||||
along with observations
|
||||
callbacks: Callbacks to run.
|
||||
**kwargs: User inputs.
|
||||
|
||||
Returns:
|
||||
Action specifying what tool to use.
|
||||
"""
|
||||
inputs = {**kwargs, **{"intermediate_steps": intermediate_steps}}
|
||||
output = await self.runnable.ainvoke(inputs, config={"callbacks": callbacks})
|
||||
return output
|
||||
|
||||
|
||||
class LLMSingleActionAgent(BaseSingleActionAgent):
|
||||
"""Base class for single action agents."""
|
||||
|
||||
@@ -725,6 +800,14 @@ s
|
||||
)
|
||||
return values
|
||||
|
||||
@root_validator(pre=True)
|
||||
def validate_runnable_agent(cls, values: Dict) -> Dict:
|
||||
"""Convert runnable to agent if passed in."""
|
||||
agent = values["agent"]
|
||||
if isinstance(agent, Runnable):
|
||||
values["agent"] = RunnableAgent(runnable=agent)
|
||||
return values
|
||||
|
||||
def save(self, file_path: Union[Path, str]) -> None:
|
||||
"""Raise error - saving not supported for Agent Executors."""
|
||||
raise ValueError(
|
||||
|
Reference in New Issue
Block a user