docstrings agents (#7866)

Added/Updated docstrings for `agents`
@baskaryan
This commit is contained in:
Leonid Ganeline 2023-07-18 02:23:24 -07:00 committed by GitHub
parent c6f2d27789
commit 17956ff08e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
39 changed files with 130 additions and 46 deletions

View File

@ -43,7 +43,7 @@ logger = logging.getLogger(__name__)
class BaseSingleActionAgent(BaseModel):
"""Base Agent class."""
"""Base Single Action Agent class."""
@property
def return_values(self) -> List[str]:
@ -179,7 +179,7 @@ class BaseSingleActionAgent(BaseModel):
class BaseMultiActionAgent(BaseModel):
"""Base Agent class."""
"""Base Multi Action Agent class."""
@property
def return_values(self) -> List[str]:
@ -200,7 +200,7 @@ class BaseMultiActionAgent(BaseModel):
Args:
intermediate_steps: Steps the LLM has taken to date,
along with observations
along with the observations.
callbacks: Callbacks to run.
**kwargs: User inputs.
@ -219,7 +219,7 @@ class BaseMultiActionAgent(BaseModel):
Args:
intermediate_steps: Steps the LLM has taken to date,
along with observations
along with the observations.
callbacks: Callbacks to run.
**kwargs: User inputs.
@ -299,18 +299,30 @@ class BaseMultiActionAgent(BaseModel):
class AgentOutputParser(BaseOutputParser):
"""Base class for parsing agent output into agent action/finish."""
@abstractmethod
def parse(self, text: str) -> Union[AgentAction, AgentFinish]:
"""Parse text into agent action/finish."""
class LLMSingleActionAgent(BaseSingleActionAgent):
"""Base class for single action agents."""
llm_chain: LLMChain
"""LLMChain to use for agent."""
output_parser: AgentOutputParser
"""Output parser to use for agent."""
stop: List[str]
"""List of strings to stop on."""
@property
def input_keys(self) -> List[str]:
"""Return the input keys.
Returns:
List of input keys.
"""
return list(set(self.llm_chain.input_keys) - {"intermediate_steps"})
def dict(self, **kwargs: Any) -> Dict:
@ -329,7 +341,7 @@ class LLMSingleActionAgent(BaseSingleActionAgent):
Args:
intermediate_steps: Steps the LLM has taken to date,
along with observations
along with the observations.
callbacks: Callbacks to run.
**kwargs: User inputs.
@ -377,7 +389,7 @@ class LLMSingleActionAgent(BaseSingleActionAgent):
class Agent(BaseSingleActionAgent):
"""Class responsible for calling the language model and deciding the action.
"""Agent that calls the language model and deciding the action.
This is driven by an LLMChain. The prompt in the LLMChain MUST include
a variable called "agent_scratchpad" where the agent can put its
@ -599,8 +611,12 @@ class Agent(BaseSingleActionAgent):
class ExceptionTool(BaseTool):
"""Tool that just returns the query."""
name = "_Exception"
"""Name of the tool."""
description = "Exception tool"
"""Description of the tool."""
def _run(
self,
@ -618,7 +634,7 @@ class ExceptionTool(BaseTool):
class AgentExecutor(Chain):
"""Consists of an agent using tools."""
"""Agent that is using tools."""
agent: Union[BaseSingleActionAgent, BaseMultiActionAgent]
"""The agent to run for creating a plan and determining actions

View File

@ -8,7 +8,7 @@ from langchain.tools import BaseTool
class BaseToolkit(BaseModel, ABC):
"""Class representing a collection of related tools."""
"""Base Toolkit representing a collection of related tools."""
@abstractmethod
def get_tools(self) -> List[BaseTool]:

View File

@ -1,4 +1,3 @@
"""Agent for working with csv files."""
from typing import Any, List, Optional, Union
from langchain.agents.agent import AgentExecutor

View File

@ -1,4 +1,3 @@
"""Toolkit for interacting with the local filesystem."""
from __future__ import annotations
from typing import List, Optional

View File

@ -1,4 +1,3 @@
"""Jira Toolkit."""
from typing import List
from langchain.agents.agent_toolkits.base import BaseToolkit

View File

@ -1,4 +1,3 @@
"""Toolkit for interacting with a JSON spec."""
from __future__ import annotations
from typing import List

View File

@ -1,4 +1,4 @@
"""Tool for interacting with a single API with natural language efinition."""
"""Tool for interacting with a single API with natural language definition."""
from typing import Any, Optional

View File

@ -1,4 +1,3 @@
"""Toolkit for interacting with API's using natural language."""
from __future__ import annotations
from typing import Any, List, Optional, Sequence
@ -15,7 +14,7 @@ from langchain.tools.plugin import AIPlugin
class NLAToolkit(BaseToolkit):
"""Natural Language API Toolkit Definition."""
"""Natural Language API Toolkit."""
nla_tools: Sequence[NLATool] = Field(...)
"""List of API Endpoint Tools."""

View File

@ -18,7 +18,7 @@ if TYPE_CHECKING:
class O365Toolkit(BaseToolkit):
"""Toolkit for interacting with Office365."""
"""Toolkit for interacting with Office 365."""
account: Account = Field(default_factory=authenticate)

View File

@ -30,7 +30,7 @@ def create_openapi_agent(
agent_executor_kwargs: Optional[Dict[str, Any]] = None,
**kwargs: Dict[str, Any],
) -> AgentExecutor:
"""Construct a json agent from an LLM and tools."""
"""Construct an OpenAPI agent from an LLM and tools."""
tools = toolkit.get_tools()
prompt = ZeroShotAgent.create_prompt(
tools,

View File

@ -46,6 +46,7 @@ from langchain.tools.requests.tool import BaseRequestsTool
# information in the response.
# However, the goal for now is to have only a single inference step.
MAX_RESPONSE_LENGTH = 5000
"""Maximum length of the response to be returned."""
def _get_default_llm_chain(prompt: BasePromptTemplate) -> LLMChain:
@ -63,12 +64,18 @@ def _get_default_llm_chain_factory(
class RequestsGetToolWithParsing(BaseRequestsTool, BaseTool):
"""Requests GET tool with LLM-instructed extraction of truncated responses."""
name = "requests_get"
"""Tool name."""
description = REQUESTS_GET_TOOL_DESCRIPTION
"""Tool description."""
response_length: Optional[int] = MAX_RESPONSE_LENGTH
"""Maximum length of the response to be returned."""
llm_chain: LLMChain = Field(
default_factory=_get_default_llm_chain_factory(PARSING_GET_PROMPT)
)
"""LLMChain used to extract the response."""
def _run(self, text: str) -> str:
try:
@ -87,13 +94,18 @@ class RequestsGetToolWithParsing(BaseRequestsTool, BaseTool):
class RequestsPostToolWithParsing(BaseRequestsTool, BaseTool):
name = "requests_post"
description = REQUESTS_POST_TOOL_DESCRIPTION
"""Requests POST tool with LLM-instructed extraction of truncated responses."""
name = "requests_post"
"""Tool name."""
description = REQUESTS_POST_TOOL_DESCRIPTION
"""Tool description."""
response_length: Optional[int] = MAX_RESPONSE_LENGTH
"""Maximum length of the response to be returned."""
llm_chain: LLMChain = Field(
default_factory=_get_default_llm_chain_factory(PARSING_POST_PROMPT)
)
"""LLMChain used to extract the response."""
def _run(self, text: str) -> str:
try:
@ -111,13 +123,18 @@ class RequestsPostToolWithParsing(BaseRequestsTool, BaseTool):
class RequestsPatchToolWithParsing(BaseRequestsTool, BaseTool):
name = "requests_patch"
description = REQUESTS_PATCH_TOOL_DESCRIPTION
"""Requests PATCH tool with LLM-instructed extraction of truncated responses."""
name = "requests_patch"
"""Tool name."""
description = REQUESTS_PATCH_TOOL_DESCRIPTION
"""Tool description."""
response_length: Optional[int] = MAX_RESPONSE_LENGTH
"""Maximum length of the response to be returned."""
llm_chain: LLMChain = Field(
default_factory=_get_default_llm_chain_factory(PARSING_PATCH_PROMPT)
)
"""LLMChain used to extract the response."""
def _run(self, text: str) -> str:
try:
@ -135,13 +152,19 @@ class RequestsPatchToolWithParsing(BaseRequestsTool, BaseTool):
class RequestsDeleteToolWithParsing(BaseRequestsTool, BaseTool):
"""A tool that sends a DELETE request and parses the response."""
name = "requests_delete"
"""The name of the tool."""
description = REQUESTS_DELETE_TOOL_DESCRIPTION
"""The description of the tool."""
response_length: Optional[int] = MAX_RESPONSE_LENGTH
"""The maximum length of the response."""
llm_chain: LLMChain = Field(
default_factory=_get_default_llm_chain_factory(PARSING_DELETE_PROMPT)
)
"""The LLM chain used to parse the response."""
def _run(self, text: str) -> str:
try:
@ -265,7 +288,7 @@ def create_openapi_agent(
agent_executor_kwargs: Optional[Dict[str, Any]] = None,
**kwargs: Dict[str, Any],
) -> AgentExecutor:
"""Instantiate API planner and controller for a given spec.
"""Instantiate OpenAI API planner and controller for a given spec.
Inject credentials via requests_wrapper.

View File

@ -23,7 +23,7 @@ from langchain.tools.requests.tool import (
class RequestsToolkit(BaseToolkit):
"""Toolkit for making requests."""
"""Toolkit for making REST requests."""
requests_wrapper: TextRequestsWrapper

View File

@ -32,7 +32,7 @@ else:
class PlayWrightBrowserToolkit(BaseToolkit):
"""Toolkit for web browser tools."""
"""Toolkit for PlayWright browser tools."""
sync_browser: Optional["SyncBrowser"] = None
async_browser: Optional["AsyncBrowser"] = None

View File

@ -30,7 +30,7 @@ def create_pbi_agent(
agent_executor_kwargs: Optional[Dict[str, Any]] = None,
**kwargs: Dict[str, Any],
) -> AgentExecutor:
"""Construct a pbi agent from an LLM and tools."""
"""Construct a Power BI agent from an LLM and tools."""
if toolkit is None:
if powerbi is None:
raise ValueError("Must provide either a toolkit or powerbi dataset")

View File

@ -29,7 +29,7 @@ from langchain.utilities.powerbi import PowerBIDataset
class PowerBIToolkit(BaseToolkit):
"""Toolkit for interacting with PowerBI dataset."""
"""Toolkit for interacting with Power BI dataset."""
powerbi: PowerBIDataset = Field(exclude=True)
llm: Union[BaseLanguageModel, BaseChatModel] = Field(exclude=True)

View File

@ -43,7 +43,7 @@ def create_spark_dataframe_agent(
agent_executor_kwargs: Optional[Dict[str, Any]] = None,
**kwargs: Dict[str, Any],
) -> AgentExecutor:
"""Construct a spark agent from an LLM and dataframe."""
"""Construct a Spark agent from an LLM and dataframe."""
if not _validate_spark_df(df) and not _validate_spark_connect_df(df):
raise ValueError("Spark is not installed. run `pip install pyspark`.")

View File

@ -27,7 +27,7 @@ def create_spark_sql_agent(
agent_executor_kwargs: Optional[Dict[str, Any]] = None,
**kwargs: Dict[str, Any],
) -> AgentExecutor:
"""Construct a sql agent from an LLM and tools."""
"""Construct a Spark SQL agent from an LLM and tools."""
tools = toolkit.get_tools()
prefix = prefix.format(top_k=top_k)
prompt = ZeroShotAgent.create_prompt(

View File

@ -40,7 +40,7 @@ def create_sql_agent(
agent_executor_kwargs: Optional[Dict[str, Any]] = None,
**kwargs: Dict[str, Any],
) -> AgentExecutor:
"""Construct a sql agent from an LLM and tools."""
"""Construct an SQL agent from an LLM and tools."""
tools = toolkit.get_tools()
prefix = prefix.format(dialect=toolkit.dialect, top_k=top_k)
agent: BaseSingleActionAgent

View File

@ -1,4 +1,4 @@
"""Toolkit for interacting with a SQL database."""
"""Toolkit for interacting with an SQL database."""
from typing import List
from pydantic import Field
@ -23,7 +23,7 @@ class SQLDatabaseToolkit(BaseToolkit):
@property
def dialect(self) -> str:
"""Return string representation of dialect to use."""
"""Return string representation of SQL dialect to use."""
return self.db.dialect
class Config:

View File

@ -22,7 +22,7 @@ def create_vectorstore_agent(
agent_executor_kwargs: Optional[Dict[str, Any]] = None,
**kwargs: Dict[str, Any],
) -> AgentExecutor:
"""Construct a vectorstore agent from an LLM and tools."""
"""Construct a VectorStore agent from an LLM and tools."""
tools = toolkit.get_tools()
prompt = ZeroShotAgent.create_prompt(tools, prefix=prefix)
llm_chain = LLMChain(
@ -50,7 +50,7 @@ def create_vectorstore_router_agent(
agent_executor_kwargs: Optional[Dict[str, Any]] = None,
**kwargs: Dict[str, Any],
) -> AgentExecutor:
"""Construct a vectorstore router agent from an LLM and tools."""
"""Construct a VectorStore router agent from an LLM and tools."""
tools = toolkit.get_tools()
prompt = ZeroShotAgent.create_prompt(tools, prefix=prefix)
llm_chain = LLMChain(

View File

@ -15,7 +15,7 @@ from langchain.vectorstores.base import VectorStore
class VectorStoreInfo(BaseModel):
"""Information about a vectorstore."""
"""Information about a VectorStore."""
vectorstore: VectorStore = Field(exclude=True)
name: str
@ -28,7 +28,7 @@ class VectorStoreInfo(BaseModel):
class VectorStoreToolkit(BaseToolkit):
"""Toolkit for interacting with a vector store."""
"""Toolkit for interacting with a Vector Store."""
vectorstore_info: VectorStoreInfo = Field(exclude=True)
llm: BaseLanguageModel = Field(default_factory=lambda: OpenAI(temperature=0))
@ -62,7 +62,7 @@ class VectorStoreToolkit(BaseToolkit):
class VectorStoreRouterToolkit(BaseToolkit):
"""Toolkit for routing between vector stores."""
"""Toolkit for routing between Vector Stores."""
vectorstores: List[VectorStoreInfo] = Field(exclude=True)
llm: BaseLanguageModel = Field(default_factory=lambda: OpenAI(temperature=0))

View File

@ -35,7 +35,7 @@ def create_xorbits_agent(
from xorbits import numpy as np
from xorbits import pandas as pd
except ImportError:
raise ValueError(
raise ImportError(
"Xorbits package not installed, please install with `pip install xorbits`"
)

View File

@ -24,7 +24,10 @@ from langchain.tools.base import BaseTool
class ChatAgent(Agent):
"""Chat Agent."""
output_parser: AgentOutputParser = Field(default_factory=ChatOutputParser)
"""Output parser for the agent."""
@property
def observation_prefix(self) -> str:

View File

@ -10,7 +10,10 @@ FINAL_ANSWER_ACTION = "Final Answer:"
class ChatOutputParser(AgentOutputParser):
"""Output parser for the chat agent."""
pattern = re.compile(r"^.*?`{3}(?:json)?\n(.*?)`{3}.*?$", re.DOTALL)
"""Regex pattern to parse the output."""
def get_format_instructions(self) -> str:
return FORMAT_INSTRUCTIONS

View File

@ -18,10 +18,12 @@ from langchain.tools.base import BaseTool
class ConversationalAgent(Agent):
"""An agent designed to hold a conversation in addition to using tools."""
"""An agent that holds a conversation in addition to using tools."""
ai_prefix: str = "AI"
"""Prefix to use before AI output."""
output_parser: AgentOutputParser = Field(default_factory=ConvoOutputParser)
"""Output parser for the agent."""
@classmethod
def _get_default_output_parser(
@ -55,7 +57,7 @@ class ConversationalAgent(Agent):
human_prefix: str = "Human",
input_variables: Optional[List[str]] = None,
) -> PromptTemplate:
"""Create prompt in the style of the zero shot agent.
"""Create prompt in the style of the zero-shot agent.
Args:
tools: List of tools the agent will have access to, used to format the

View File

@ -7,7 +7,10 @@ from langchain.schema import AgentAction, AgentFinish, OutputParserException
class ConvoOutputParser(AgentOutputParser):
"""Output parser for the conversational agent."""
ai_prefix: str = "AI"
"""Prefix to use before AI output."""
def get_format_instructions(self) -> str:
return FORMAT_INSTRUCTIONS

View File

@ -9,6 +9,8 @@ from langchain.schema import AgentAction, AgentFinish, OutputParserException
class ConvoOutputParser(AgentOutputParser):
"""Output parser for the conversational agent."""
def get_format_instructions(self) -> str:
return FORMAT_INSTRUCTIONS

View File

@ -421,7 +421,7 @@ def load_tools(
Args:
tool_names: name of tools to load.
llm: Optional language model, may be needed to initialize certain tools.
llm: An optional language model, may be needed to initialize certain tools.
callbacks: Optional callback manager or list of callback handlers.
If not provided, default global callback manager will be used.

View File

@ -36,7 +36,17 @@ def load_agent_from_config(
tools: Optional[List[Tool]] = None,
**kwargs: Any,
) -> Union[BaseSingleActionAgent, BaseMultiActionAgent]:
"""Load agent from Config Dict."""
"""Load agent from Config Dict.
Args:
config: Config dict to load agent from.
llm: Language model to use as the agent.
tools: List of tools this agent has access to.
**kwargs: Additional key word arguments passed to the agent executor.
Returns:
An agent executor.
"""
if "_type" not in config:
raise ValueError("Must specify an agent Type in config")
load_from_tools = config.pop("load_from_llm_and_tools", False)
@ -78,7 +88,15 @@ def load_agent_from_config(
def load_agent(
path: Union[str, Path], **kwargs: Any
) -> Union[BaseSingleActionAgent, BaseMultiActionAgent]:
"""Unified method for loading a agent from LangChainHub or local fs."""
"""Unified method for loading an agent from LangChainHub or local fs.
Args:
path: Path to the agent file.
**kwargs: Additional key word arguments passed to the agent executor.
Returns:
An agent executor.
"""
if hub_result := try_load_from_hub(
path, _load_agent_from_file, "agents", {"json", "yaml"}
):

View File

@ -9,6 +9,8 @@ FINAL_ANSWER_ACTION = "Final Answer:"
class MRKLOutputParser(AgentOutputParser):
"""MRKL Output parser for the chat agent."""
def get_format_instructions(self) -> str:
return FORMAT_INSTRUCTIONS

View File

@ -28,7 +28,7 @@ class ReActDocstoreAgent(Agent):
@property
def _agent_type(self) -> str:
"""Return Identifier of agent type."""
"""Return Identifier of an agent type."""
return AgentType.REACT_DOCSTORE
@classmethod

View File

@ -6,6 +6,8 @@ from langchain.schema import AgentAction, AgentFinish, OutputParserException
class ReActOutputParser(AgentOutputParser):
"""Output parser for the ReAct agent."""
def parse(self, text: str) -> Union[AgentAction, AgentFinish]:
action_prefix = "Action: "
if not text.strip().split("\n")[-1].startswith(action_prefix):

View File

@ -5,6 +5,8 @@ from langchain.schema import AgentAction
class AgentScratchPadChatPromptTemplate(ChatPromptTemplate):
"""Chat prompt template for the agent scratchpad."""
def _construct_agent_scratchpad(
self, intermediate_steps: List[Tuple[AgentAction, str]]
) -> str:

View File

@ -27,7 +27,7 @@ class SelfAskWithSearchAgent(Agent):
@property
def _agent_type(self) -> str:
"""Return Identifier of agent type."""
"""Return Identifier of an agent type."""
return AgentType.SELF_ASK_WITH_SEARCH
@classmethod
@ -75,7 +75,7 @@ class SelfAskWithSearchChain(AgentExecutor):
search_chain: Union[GoogleSerperAPIWrapper, SerpAPIWrapper],
**kwargs: Any,
):
"""Initialize with just an LLM and a search chain."""
"""Initialize only with an LLM and a search chain."""
search_tool = Tool(
name="Intermediate Answer",
func=search_chain.run,

View File

@ -5,6 +5,8 @@ from langchain.schema import AgentAction, AgentFinish, OutputParserException
class SelfAskOutputParser(AgentOutputParser):
"""Output parser for the self-ask agent."""
followups: Sequence[str] = ("Follow up:", "Followup:")
finish_string: str = "So the final answer is: "

View File

@ -23,9 +23,12 @@ HUMAN_MESSAGE_TEMPLATE = "{input}\n\n{agent_scratchpad}"
class StructuredChatAgent(Agent):
"""Structured Chat Agent."""
output_parser: AgentOutputParser = Field(
default_factory=StructuredChatOutputParserWithRetries
)
"""Output parser for the agent."""
@property
def observation_prefix(self) -> str:

View File

@ -17,6 +17,8 @@ logger = logging.getLogger(__name__)
class StructuredChatOutputParser(AgentOutputParser):
"""Output parser for the structured chat agent."""
def get_format_instructions(self) -> str:
return FORMAT_INSTRUCTIONS
@ -46,8 +48,12 @@ class StructuredChatOutputParser(AgentOutputParser):
class StructuredChatOutputParserWithRetries(AgentOutputParser):
"""Output parser with retries for the structured chat agent."""
base_parser: AgentOutputParser = Field(default_factory=StructuredChatOutputParser)
"""The base parser to use."""
output_fixing_parser: Optional[OutputFixingParser] = None
"""The output fixing parser to use."""
def get_format_instructions(self) -> str:
return FORMAT_INSTRUCTIONS

View File

@ -12,7 +12,9 @@ class InvalidTool(BaseTool):
"""Tool that is run when invalid tool name is encountered by agent."""
name = "invalid_tool"
"""Name of the tool."""
description = "Called when tool name is invalid."
"""Description of the tool."""
def _run(
self, tool_name: str, run_manager: Optional[CallbackManagerForToolRun] = None

View File

@ -27,7 +27,7 @@ class BaseLLMOutputParser(Serializable, ABC, Generic[T]):
class BaseOutputParser(BaseLLMOutputParser, ABC, Generic[T]):
"""Class to parse the output of an LLM call.
"""Base class to parse the output of an LLM call.
Output parsers help structure language model responses.