mirror of
https://github.com/hwchase17/langchain.git
synced 2025-05-29 19:18:53 +00:00
docstrings agents
(#7866)
Added/Updated docstrings for `agents` @baskaryan
This commit is contained in:
parent
c6f2d27789
commit
17956ff08e
@ -43,7 +43,7 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class BaseSingleActionAgent(BaseModel):
|
||||
"""Base Agent class."""
|
||||
"""Base Single Action Agent class."""
|
||||
|
||||
@property
|
||||
def return_values(self) -> List[str]:
|
||||
@ -179,7 +179,7 @@ class BaseSingleActionAgent(BaseModel):
|
||||
|
||||
|
||||
class BaseMultiActionAgent(BaseModel):
|
||||
"""Base Agent class."""
|
||||
"""Base Multi Action Agent class."""
|
||||
|
||||
@property
|
||||
def return_values(self) -> List[str]:
|
||||
@ -200,7 +200,7 @@ class BaseMultiActionAgent(BaseModel):
|
||||
|
||||
Args:
|
||||
intermediate_steps: Steps the LLM has taken to date,
|
||||
along with observations
|
||||
along with the observations.
|
||||
callbacks: Callbacks to run.
|
||||
**kwargs: User inputs.
|
||||
|
||||
@ -219,7 +219,7 @@ class BaseMultiActionAgent(BaseModel):
|
||||
|
||||
Args:
|
||||
intermediate_steps: Steps the LLM has taken to date,
|
||||
along with observations
|
||||
along with the observations.
|
||||
callbacks: Callbacks to run.
|
||||
**kwargs: User inputs.
|
||||
|
||||
@ -299,18 +299,30 @@ class BaseMultiActionAgent(BaseModel):
|
||||
|
||||
|
||||
class AgentOutputParser(BaseOutputParser):
|
||||
"""Base class for parsing agent output into agent action/finish."""
|
||||
|
||||
@abstractmethod
|
||||
def parse(self, text: str) -> Union[AgentAction, AgentFinish]:
|
||||
"""Parse text into agent action/finish."""
|
||||
|
||||
|
||||
class LLMSingleActionAgent(BaseSingleActionAgent):
|
||||
"""Base class for single action agents."""
|
||||
|
||||
llm_chain: LLMChain
|
||||
"""LLMChain to use for agent."""
|
||||
output_parser: AgentOutputParser
|
||||
"""Output parser to use for agent."""
|
||||
stop: List[str]
|
||||
"""List of strings to stop on."""
|
||||
|
||||
@property
|
||||
def input_keys(self) -> List[str]:
|
||||
"""Return the input keys.
|
||||
|
||||
Returns:
|
||||
List of input keys.
|
||||
"""
|
||||
return list(set(self.llm_chain.input_keys) - {"intermediate_steps"})
|
||||
|
||||
def dict(self, **kwargs: Any) -> Dict:
|
||||
@ -329,7 +341,7 @@ class LLMSingleActionAgent(BaseSingleActionAgent):
|
||||
|
||||
Args:
|
||||
intermediate_steps: Steps the LLM has taken to date,
|
||||
along with observations
|
||||
along with the observations.
|
||||
callbacks: Callbacks to run.
|
||||
**kwargs: User inputs.
|
||||
|
||||
@ -377,7 +389,7 @@ class LLMSingleActionAgent(BaseSingleActionAgent):
|
||||
|
||||
|
||||
class Agent(BaseSingleActionAgent):
|
||||
"""Class responsible for calling the language model and deciding the action.
|
||||
"""Agent that calls the language model and deciding the action.
|
||||
|
||||
This is driven by an LLMChain. The prompt in the LLMChain MUST include
|
||||
a variable called "agent_scratchpad" where the agent can put its
|
||||
@ -599,8 +611,12 @@ class Agent(BaseSingleActionAgent):
|
||||
|
||||
|
||||
class ExceptionTool(BaseTool):
|
||||
"""Tool that just returns the query."""
|
||||
|
||||
name = "_Exception"
|
||||
"""Name of the tool."""
|
||||
description = "Exception tool"
|
||||
"""Description of the tool."""
|
||||
|
||||
def _run(
|
||||
self,
|
||||
@ -618,7 +634,7 @@ class ExceptionTool(BaseTool):
|
||||
|
||||
|
||||
class AgentExecutor(Chain):
|
||||
"""Consists of an agent using tools."""
|
||||
"""Agent that is using tools."""
|
||||
|
||||
agent: Union[BaseSingleActionAgent, BaseMultiActionAgent]
|
||||
"""The agent to run for creating a plan and determining actions
|
||||
|
@ -8,7 +8,7 @@ from langchain.tools import BaseTool
|
||||
|
||||
|
||||
class BaseToolkit(BaseModel, ABC):
|
||||
"""Class representing a collection of related tools."""
|
||||
"""Base Toolkit representing a collection of related tools."""
|
||||
|
||||
@abstractmethod
|
||||
def get_tools(self) -> List[BaseTool]:
|
||||
|
@ -1,4 +1,3 @@
|
||||
"""Agent for working with csv files."""
|
||||
from typing import Any, List, Optional, Union
|
||||
|
||||
from langchain.agents.agent import AgentExecutor
|
||||
|
@ -1,4 +1,3 @@
|
||||
"""Toolkit for interacting with the local filesystem."""
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import List, Optional
|
||||
|
@ -1,4 +1,3 @@
|
||||
"""Jira Toolkit."""
|
||||
from typing import List
|
||||
|
||||
from langchain.agents.agent_toolkits.base import BaseToolkit
|
||||
|
@ -1,4 +1,3 @@
|
||||
"""Toolkit for interacting with a JSON spec."""
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import List
|
||||
|
@ -1,4 +1,4 @@
|
||||
"""Tool for interacting with a single API with natural language efinition."""
|
||||
"""Tool for interacting with a single API with natural language definition."""
|
||||
|
||||
|
||||
from typing import Any, Optional
|
||||
|
@ -1,4 +1,3 @@
|
||||
"""Toolkit for interacting with API's using natural language."""
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, List, Optional, Sequence
|
||||
@ -15,7 +14,7 @@ from langchain.tools.plugin import AIPlugin
|
||||
|
||||
|
||||
class NLAToolkit(BaseToolkit):
|
||||
"""Natural Language API Toolkit Definition."""
|
||||
"""Natural Language API Toolkit."""
|
||||
|
||||
nla_tools: Sequence[NLATool] = Field(...)
|
||||
"""List of API Endpoint Tools."""
|
||||
|
@ -30,7 +30,7 @@ def create_openapi_agent(
|
||||
agent_executor_kwargs: Optional[Dict[str, Any]] = None,
|
||||
**kwargs: Dict[str, Any],
|
||||
) -> AgentExecutor:
|
||||
"""Construct a json agent from an LLM and tools."""
|
||||
"""Construct an OpenAPI agent from an LLM and tools."""
|
||||
tools = toolkit.get_tools()
|
||||
prompt = ZeroShotAgent.create_prompt(
|
||||
tools,
|
||||
|
@ -46,6 +46,7 @@ from langchain.tools.requests.tool import BaseRequestsTool
|
||||
# information in the response.
|
||||
# However, the goal for now is to have only a single inference step.
|
||||
MAX_RESPONSE_LENGTH = 5000
|
||||
"""Maximum length of the response to be returned."""
|
||||
|
||||
|
||||
def _get_default_llm_chain(prompt: BasePromptTemplate) -> LLMChain:
|
||||
@ -63,12 +64,18 @@ def _get_default_llm_chain_factory(
|
||||
|
||||
|
||||
class RequestsGetToolWithParsing(BaseRequestsTool, BaseTool):
|
||||
"""Requests GET tool with LLM-instructed extraction of truncated responses."""
|
||||
|
||||
name = "requests_get"
|
||||
"""Tool name."""
|
||||
description = REQUESTS_GET_TOOL_DESCRIPTION
|
||||
"""Tool description."""
|
||||
response_length: Optional[int] = MAX_RESPONSE_LENGTH
|
||||
"""Maximum length of the response to be returned."""
|
||||
llm_chain: LLMChain = Field(
|
||||
default_factory=_get_default_llm_chain_factory(PARSING_GET_PROMPT)
|
||||
)
|
||||
"""LLMChain used to extract the response."""
|
||||
|
||||
def _run(self, text: str) -> str:
|
||||
try:
|
||||
@ -87,13 +94,18 @@ class RequestsGetToolWithParsing(BaseRequestsTool, BaseTool):
|
||||
|
||||
|
||||
class RequestsPostToolWithParsing(BaseRequestsTool, BaseTool):
|
||||
name = "requests_post"
|
||||
description = REQUESTS_POST_TOOL_DESCRIPTION
|
||||
"""Requests POST tool with LLM-instructed extraction of truncated responses."""
|
||||
|
||||
name = "requests_post"
|
||||
"""Tool name."""
|
||||
description = REQUESTS_POST_TOOL_DESCRIPTION
|
||||
"""Tool description."""
|
||||
response_length: Optional[int] = MAX_RESPONSE_LENGTH
|
||||
"""Maximum length of the response to be returned."""
|
||||
llm_chain: LLMChain = Field(
|
||||
default_factory=_get_default_llm_chain_factory(PARSING_POST_PROMPT)
|
||||
)
|
||||
"""LLMChain used to extract the response."""
|
||||
|
||||
def _run(self, text: str) -> str:
|
||||
try:
|
||||
@ -111,13 +123,18 @@ class RequestsPostToolWithParsing(BaseRequestsTool, BaseTool):
|
||||
|
||||
|
||||
class RequestsPatchToolWithParsing(BaseRequestsTool, BaseTool):
|
||||
name = "requests_patch"
|
||||
description = REQUESTS_PATCH_TOOL_DESCRIPTION
|
||||
"""Requests PATCH tool with LLM-instructed extraction of truncated responses."""
|
||||
|
||||
name = "requests_patch"
|
||||
"""Tool name."""
|
||||
description = REQUESTS_PATCH_TOOL_DESCRIPTION
|
||||
"""Tool description."""
|
||||
response_length: Optional[int] = MAX_RESPONSE_LENGTH
|
||||
"""Maximum length of the response to be returned."""
|
||||
llm_chain: LLMChain = Field(
|
||||
default_factory=_get_default_llm_chain_factory(PARSING_PATCH_PROMPT)
|
||||
)
|
||||
"""LLMChain used to extract the response."""
|
||||
|
||||
def _run(self, text: str) -> str:
|
||||
try:
|
||||
@ -135,13 +152,19 @@ class RequestsPatchToolWithParsing(BaseRequestsTool, BaseTool):
|
||||
|
||||
|
||||
class RequestsDeleteToolWithParsing(BaseRequestsTool, BaseTool):
|
||||
"""A tool that sends a DELETE request and parses the response."""
|
||||
|
||||
name = "requests_delete"
|
||||
"""The name of the tool."""
|
||||
description = REQUESTS_DELETE_TOOL_DESCRIPTION
|
||||
"""The description of the tool."""
|
||||
|
||||
response_length: Optional[int] = MAX_RESPONSE_LENGTH
|
||||
"""The maximum length of the response."""
|
||||
llm_chain: LLMChain = Field(
|
||||
default_factory=_get_default_llm_chain_factory(PARSING_DELETE_PROMPT)
|
||||
)
|
||||
"""The LLM chain used to parse the response."""
|
||||
|
||||
def _run(self, text: str) -> str:
|
||||
try:
|
||||
@ -265,7 +288,7 @@ def create_openapi_agent(
|
||||
agent_executor_kwargs: Optional[Dict[str, Any]] = None,
|
||||
**kwargs: Dict[str, Any],
|
||||
) -> AgentExecutor:
|
||||
"""Instantiate API planner and controller for a given spec.
|
||||
"""Instantiate OpenAI API planner and controller for a given spec.
|
||||
|
||||
Inject credentials via requests_wrapper.
|
||||
|
||||
|
@ -23,7 +23,7 @@ from langchain.tools.requests.tool import (
|
||||
|
||||
|
||||
class RequestsToolkit(BaseToolkit):
|
||||
"""Toolkit for making requests."""
|
||||
"""Toolkit for making REST requests."""
|
||||
|
||||
requests_wrapper: TextRequestsWrapper
|
||||
|
||||
|
@ -32,7 +32,7 @@ else:
|
||||
|
||||
|
||||
class PlayWrightBrowserToolkit(BaseToolkit):
|
||||
"""Toolkit for web browser tools."""
|
||||
"""Toolkit for PlayWright browser tools."""
|
||||
|
||||
sync_browser: Optional["SyncBrowser"] = None
|
||||
async_browser: Optional["AsyncBrowser"] = None
|
||||
|
@ -30,7 +30,7 @@ def create_pbi_agent(
|
||||
agent_executor_kwargs: Optional[Dict[str, Any]] = None,
|
||||
**kwargs: Dict[str, Any],
|
||||
) -> AgentExecutor:
|
||||
"""Construct a pbi agent from an LLM and tools."""
|
||||
"""Construct a Power BI agent from an LLM and tools."""
|
||||
if toolkit is None:
|
||||
if powerbi is None:
|
||||
raise ValueError("Must provide either a toolkit or powerbi dataset")
|
||||
|
@ -43,7 +43,7 @@ def create_spark_dataframe_agent(
|
||||
agent_executor_kwargs: Optional[Dict[str, Any]] = None,
|
||||
**kwargs: Dict[str, Any],
|
||||
) -> AgentExecutor:
|
||||
"""Construct a spark agent from an LLM and dataframe."""
|
||||
"""Construct a Spark agent from an LLM and dataframe."""
|
||||
|
||||
if not _validate_spark_df(df) and not _validate_spark_connect_df(df):
|
||||
raise ValueError("Spark is not installed. run `pip install pyspark`.")
|
||||
|
@ -27,7 +27,7 @@ def create_spark_sql_agent(
|
||||
agent_executor_kwargs: Optional[Dict[str, Any]] = None,
|
||||
**kwargs: Dict[str, Any],
|
||||
) -> AgentExecutor:
|
||||
"""Construct a sql agent from an LLM and tools."""
|
||||
"""Construct a Spark SQL agent from an LLM and tools."""
|
||||
tools = toolkit.get_tools()
|
||||
prefix = prefix.format(top_k=top_k)
|
||||
prompt = ZeroShotAgent.create_prompt(
|
||||
|
@ -40,7 +40,7 @@ def create_sql_agent(
|
||||
agent_executor_kwargs: Optional[Dict[str, Any]] = None,
|
||||
**kwargs: Dict[str, Any],
|
||||
) -> AgentExecutor:
|
||||
"""Construct a sql agent from an LLM and tools."""
|
||||
"""Construct an SQL agent from an LLM and tools."""
|
||||
tools = toolkit.get_tools()
|
||||
prefix = prefix.format(dialect=toolkit.dialect, top_k=top_k)
|
||||
agent: BaseSingleActionAgent
|
||||
|
@ -1,4 +1,4 @@
|
||||
"""Toolkit for interacting with a SQL database."""
|
||||
"""Toolkit for interacting with an SQL database."""
|
||||
from typing import List
|
||||
|
||||
from pydantic import Field
|
||||
@ -23,7 +23,7 @@ class SQLDatabaseToolkit(BaseToolkit):
|
||||
|
||||
@property
|
||||
def dialect(self) -> str:
|
||||
"""Return string representation of dialect to use."""
|
||||
"""Return string representation of SQL dialect to use."""
|
||||
return self.db.dialect
|
||||
|
||||
class Config:
|
||||
|
@ -22,7 +22,7 @@ def create_vectorstore_agent(
|
||||
agent_executor_kwargs: Optional[Dict[str, Any]] = None,
|
||||
**kwargs: Dict[str, Any],
|
||||
) -> AgentExecutor:
|
||||
"""Construct a vectorstore agent from an LLM and tools."""
|
||||
"""Construct a VectorStore agent from an LLM and tools."""
|
||||
tools = toolkit.get_tools()
|
||||
prompt = ZeroShotAgent.create_prompt(tools, prefix=prefix)
|
||||
llm_chain = LLMChain(
|
||||
@ -50,7 +50,7 @@ def create_vectorstore_router_agent(
|
||||
agent_executor_kwargs: Optional[Dict[str, Any]] = None,
|
||||
**kwargs: Dict[str, Any],
|
||||
) -> AgentExecutor:
|
||||
"""Construct a vectorstore router agent from an LLM and tools."""
|
||||
"""Construct a VectorStore router agent from an LLM and tools."""
|
||||
tools = toolkit.get_tools()
|
||||
prompt = ZeroShotAgent.create_prompt(tools, prefix=prefix)
|
||||
llm_chain = LLMChain(
|
||||
|
@ -15,7 +15,7 @@ from langchain.vectorstores.base import VectorStore
|
||||
|
||||
|
||||
class VectorStoreInfo(BaseModel):
|
||||
"""Information about a vectorstore."""
|
||||
"""Information about a VectorStore."""
|
||||
|
||||
vectorstore: VectorStore = Field(exclude=True)
|
||||
name: str
|
||||
@ -28,7 +28,7 @@ class VectorStoreInfo(BaseModel):
|
||||
|
||||
|
||||
class VectorStoreToolkit(BaseToolkit):
|
||||
"""Toolkit for interacting with a vector store."""
|
||||
"""Toolkit for interacting with a Vector Store."""
|
||||
|
||||
vectorstore_info: VectorStoreInfo = Field(exclude=True)
|
||||
llm: BaseLanguageModel = Field(default_factory=lambda: OpenAI(temperature=0))
|
||||
@ -62,7 +62,7 @@ class VectorStoreToolkit(BaseToolkit):
|
||||
|
||||
|
||||
class VectorStoreRouterToolkit(BaseToolkit):
|
||||
"""Toolkit for routing between vector stores."""
|
||||
"""Toolkit for routing between Vector Stores."""
|
||||
|
||||
vectorstores: List[VectorStoreInfo] = Field(exclude=True)
|
||||
llm: BaseLanguageModel = Field(default_factory=lambda: OpenAI(temperature=0))
|
||||
|
@ -35,7 +35,7 @@ def create_xorbits_agent(
|
||||
from xorbits import numpy as np
|
||||
from xorbits import pandas as pd
|
||||
except ImportError:
|
||||
raise ValueError(
|
||||
raise ImportError(
|
||||
"Xorbits package not installed, please install with `pip install xorbits`"
|
||||
)
|
||||
|
||||
|
@ -24,7 +24,10 @@ from langchain.tools.base import BaseTool
|
||||
|
||||
|
||||
class ChatAgent(Agent):
|
||||
"""Chat Agent."""
|
||||
|
||||
output_parser: AgentOutputParser = Field(default_factory=ChatOutputParser)
|
||||
"""Output parser for the agent."""
|
||||
|
||||
@property
|
||||
def observation_prefix(self) -> str:
|
||||
|
@ -10,7 +10,10 @@ FINAL_ANSWER_ACTION = "Final Answer:"
|
||||
|
||||
|
||||
class ChatOutputParser(AgentOutputParser):
|
||||
"""Output parser for the chat agent."""
|
||||
|
||||
pattern = re.compile(r"^.*?`{3}(?:json)?\n(.*?)`{3}.*?$", re.DOTALL)
|
||||
"""Regex pattern to parse the output."""
|
||||
|
||||
def get_format_instructions(self) -> str:
|
||||
return FORMAT_INSTRUCTIONS
|
||||
|
@ -18,10 +18,12 @@ from langchain.tools.base import BaseTool
|
||||
|
||||
|
||||
class ConversationalAgent(Agent):
|
||||
"""An agent designed to hold a conversation in addition to using tools."""
|
||||
"""An agent that holds a conversation in addition to using tools."""
|
||||
|
||||
ai_prefix: str = "AI"
|
||||
"""Prefix to use before AI output."""
|
||||
output_parser: AgentOutputParser = Field(default_factory=ConvoOutputParser)
|
||||
"""Output parser for the agent."""
|
||||
|
||||
@classmethod
|
||||
def _get_default_output_parser(
|
||||
@ -55,7 +57,7 @@ class ConversationalAgent(Agent):
|
||||
human_prefix: str = "Human",
|
||||
input_variables: Optional[List[str]] = None,
|
||||
) -> PromptTemplate:
|
||||
"""Create prompt in the style of the zero shot agent.
|
||||
"""Create prompt in the style of the zero-shot agent.
|
||||
|
||||
Args:
|
||||
tools: List of tools the agent will have access to, used to format the
|
||||
|
@ -7,7 +7,10 @@ from langchain.schema import AgentAction, AgentFinish, OutputParserException
|
||||
|
||||
|
||||
class ConvoOutputParser(AgentOutputParser):
|
||||
"""Output parser for the conversational agent."""
|
||||
|
||||
ai_prefix: str = "AI"
|
||||
"""Prefix to use before AI output."""
|
||||
|
||||
def get_format_instructions(self) -> str:
|
||||
return FORMAT_INSTRUCTIONS
|
||||
|
@ -9,6 +9,8 @@ from langchain.schema import AgentAction, AgentFinish, OutputParserException
|
||||
|
||||
|
||||
class ConvoOutputParser(AgentOutputParser):
|
||||
"""Output parser for the conversational agent."""
|
||||
|
||||
def get_format_instructions(self) -> str:
|
||||
return FORMAT_INSTRUCTIONS
|
||||
|
||||
|
@ -421,7 +421,7 @@ def load_tools(
|
||||
|
||||
Args:
|
||||
tool_names: name of tools to load.
|
||||
llm: Optional language model, may be needed to initialize certain tools.
|
||||
llm: An optional language model, may be needed to initialize certain tools.
|
||||
callbacks: Optional callback manager or list of callback handlers.
|
||||
If not provided, default global callback manager will be used.
|
||||
|
||||
|
@ -36,7 +36,17 @@ def load_agent_from_config(
|
||||
tools: Optional[List[Tool]] = None,
|
||||
**kwargs: Any,
|
||||
) -> Union[BaseSingleActionAgent, BaseMultiActionAgent]:
|
||||
"""Load agent from Config Dict."""
|
||||
"""Load agent from Config Dict.
|
||||
|
||||
Args:
|
||||
config: Config dict to load agent from.
|
||||
llm: Language model to use as the agent.
|
||||
tools: List of tools this agent has access to.
|
||||
**kwargs: Additional key word arguments passed to the agent executor.
|
||||
|
||||
Returns:
|
||||
An agent executor.
|
||||
"""
|
||||
if "_type" not in config:
|
||||
raise ValueError("Must specify an agent Type in config")
|
||||
load_from_tools = config.pop("load_from_llm_and_tools", False)
|
||||
@ -78,7 +88,15 @@ def load_agent_from_config(
|
||||
def load_agent(
|
||||
path: Union[str, Path], **kwargs: Any
|
||||
) -> Union[BaseSingleActionAgent, BaseMultiActionAgent]:
|
||||
"""Unified method for loading a agent from LangChainHub or local fs."""
|
||||
"""Unified method for loading an agent from LangChainHub or local fs.
|
||||
|
||||
Args:
|
||||
path: Path to the agent file.
|
||||
**kwargs: Additional key word arguments passed to the agent executor.
|
||||
|
||||
Returns:
|
||||
An agent executor.
|
||||
"""
|
||||
if hub_result := try_load_from_hub(
|
||||
path, _load_agent_from_file, "agents", {"json", "yaml"}
|
||||
):
|
||||
|
@ -9,6 +9,8 @@ FINAL_ANSWER_ACTION = "Final Answer:"
|
||||
|
||||
|
||||
class MRKLOutputParser(AgentOutputParser):
|
||||
"""MRKL Output parser for the chat agent."""
|
||||
|
||||
def get_format_instructions(self) -> str:
|
||||
return FORMAT_INSTRUCTIONS
|
||||
|
||||
|
@ -28,7 +28,7 @@ class ReActDocstoreAgent(Agent):
|
||||
|
||||
@property
|
||||
def _agent_type(self) -> str:
|
||||
"""Return Identifier of agent type."""
|
||||
"""Return Identifier of an agent type."""
|
||||
return AgentType.REACT_DOCSTORE
|
||||
|
||||
@classmethod
|
||||
|
@ -6,6 +6,8 @@ from langchain.schema import AgentAction, AgentFinish, OutputParserException
|
||||
|
||||
|
||||
class ReActOutputParser(AgentOutputParser):
|
||||
"""Output parser for the ReAct agent."""
|
||||
|
||||
def parse(self, text: str) -> Union[AgentAction, AgentFinish]:
|
||||
action_prefix = "Action: "
|
||||
if not text.strip().split("\n")[-1].startswith(action_prefix):
|
||||
|
@ -5,6 +5,8 @@ from langchain.schema import AgentAction
|
||||
|
||||
|
||||
class AgentScratchPadChatPromptTemplate(ChatPromptTemplate):
|
||||
"""Chat prompt template for the agent scratchpad."""
|
||||
|
||||
def _construct_agent_scratchpad(
|
||||
self, intermediate_steps: List[Tuple[AgentAction, str]]
|
||||
) -> str:
|
||||
|
@ -27,7 +27,7 @@ class SelfAskWithSearchAgent(Agent):
|
||||
|
||||
@property
|
||||
def _agent_type(self) -> str:
|
||||
"""Return Identifier of agent type."""
|
||||
"""Return Identifier of an agent type."""
|
||||
return AgentType.SELF_ASK_WITH_SEARCH
|
||||
|
||||
@classmethod
|
||||
@ -75,7 +75,7 @@ class SelfAskWithSearchChain(AgentExecutor):
|
||||
search_chain: Union[GoogleSerperAPIWrapper, SerpAPIWrapper],
|
||||
**kwargs: Any,
|
||||
):
|
||||
"""Initialize with just an LLM and a search chain."""
|
||||
"""Initialize only with an LLM and a search chain."""
|
||||
search_tool = Tool(
|
||||
name="Intermediate Answer",
|
||||
func=search_chain.run,
|
||||
|
@ -5,6 +5,8 @@ from langchain.schema import AgentAction, AgentFinish, OutputParserException
|
||||
|
||||
|
||||
class SelfAskOutputParser(AgentOutputParser):
|
||||
"""Output parser for the self-ask agent."""
|
||||
|
||||
followups: Sequence[str] = ("Follow up:", "Followup:")
|
||||
finish_string: str = "So the final answer is: "
|
||||
|
||||
|
@ -23,9 +23,12 @@ HUMAN_MESSAGE_TEMPLATE = "{input}\n\n{agent_scratchpad}"
|
||||
|
||||
|
||||
class StructuredChatAgent(Agent):
|
||||
"""Structured Chat Agent."""
|
||||
|
||||
output_parser: AgentOutputParser = Field(
|
||||
default_factory=StructuredChatOutputParserWithRetries
|
||||
)
|
||||
"""Output parser for the agent."""
|
||||
|
||||
@property
|
||||
def observation_prefix(self) -> str:
|
||||
|
@ -17,6 +17,8 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class StructuredChatOutputParser(AgentOutputParser):
|
||||
"""Output parser for the structured chat agent."""
|
||||
|
||||
def get_format_instructions(self) -> str:
|
||||
return FORMAT_INSTRUCTIONS
|
||||
|
||||
@ -46,8 +48,12 @@ class StructuredChatOutputParser(AgentOutputParser):
|
||||
|
||||
|
||||
class StructuredChatOutputParserWithRetries(AgentOutputParser):
|
||||
"""Output parser with retries for the structured chat agent."""
|
||||
|
||||
base_parser: AgentOutputParser = Field(default_factory=StructuredChatOutputParser)
|
||||
"""The base parser to use."""
|
||||
output_fixing_parser: Optional[OutputFixingParser] = None
|
||||
"""The output fixing parser to use."""
|
||||
|
||||
def get_format_instructions(self) -> str:
|
||||
return FORMAT_INSTRUCTIONS
|
||||
|
@ -12,7 +12,9 @@ class InvalidTool(BaseTool):
|
||||
"""Tool that is run when invalid tool name is encountered by agent."""
|
||||
|
||||
name = "invalid_tool"
|
||||
"""Name of the tool."""
|
||||
description = "Called when tool name is invalid."
|
||||
"""Description of the tool."""
|
||||
|
||||
def _run(
|
||||
self, tool_name: str, run_manager: Optional[CallbackManagerForToolRun] = None
|
||||
|
@ -27,7 +27,7 @@ class BaseLLMOutputParser(Serializable, ABC, Generic[T]):
|
||||
|
||||
|
||||
class BaseOutputParser(BaseLLMOutputParser, ABC, Generic[T]):
|
||||
"""Class to parse the output of an LLM call.
|
||||
"""Base class to parse the output of an LLM call.
|
||||
|
||||
Output parsers help structure language model responses.
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user