community[patch]: fix agent_toolkits mypy (#17050)

Related to #17048
This commit is contained in:
Bagatur 2024-02-05 11:56:24 -08:00 committed by GitHub
parent 6ffd5b15bc
commit e7b3290d30
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 53 additions and 40 deletions

View File

@@ -1,6 +1,6 @@
from __future__ import annotations
from typing import List, Optional
from typing import Dict, List, Optional, Type
from langchain_core.pydantic_v1 import root_validator
@@ -14,18 +14,17 @@ from langchain_community.tools.file_management.move import MoveFileTool
from langchain_community.tools.file_management.read import ReadFileTool
from langchain_community.tools.file_management.write import WriteFileTool
_FILE_TOOLS = {
# "Type[Runnable[Any, Any]]" has no attribute "__fields__" [attr-defined]
tool_cls.__fields__["name"].default: tool_cls # type: ignore[attr-defined]
for tool_cls in [
CopyFileTool,
DeleteFileTool,
FileSearchTool,
MoveFileTool,
ReadFileTool,
WriteFileTool,
ListDirectoryTool,
]
_FILE_TOOLS: List[Type[BaseTool]] = [
CopyFileTool,
DeleteFileTool,
FileSearchTool,
MoveFileTool,
ReadFileTool,
WriteFileTool,
ListDirectoryTool,
]
_FILE_TOOLS_MAP: Dict[str, Type[BaseTool]] = {
tool_cls.__fields__["name"].default: tool_cls for tool_cls in _FILE_TOOLS
}
@@ -61,20 +60,20 @@ class FileManagementToolkit(BaseToolkit):
def validate_tools(cls, values: dict) -> dict:
selected_tools = values.get("selected_tools") or []
for tool_name in selected_tools:
if tool_name not in _FILE_TOOLS:
if tool_name not in _FILE_TOOLS_MAP:
raise ValueError(
f"File Tool of name {tool_name} not supported."
f" Permitted tools: {list(_FILE_TOOLS)}"
f" Permitted tools: {list(_FILE_TOOLS_MAP)}"
)
return values
def get_tools(self) -> List[BaseTool]:
"""Get the tools in the toolkit."""
allowed_tools = self.selected_tools or _FILE_TOOLS.keys()
allowed_tools = self.selected_tools or _FILE_TOOLS_MAP
tools: List[BaseTool] = []
for tool in allowed_tools:
tool_cls = _FILE_TOOLS[tool]
tools.append(tool_cls(root_dir=self.root_dir)) # type: ignore
tool_cls = _FILE_TOOLS_MAP[tool]
tools.append(tool_cls(root_dir=self.root_dir))
return tools

View File

@@ -2,7 +2,7 @@
import json
import re
from functools import partial
from typing import Any, Callable, Dict, List, Optional
from typing import Any, Callable, Dict, List, Optional, cast
import yaml
from langchain_core.callbacks import BaseCallbackManager
@@ -68,7 +68,7 @@ class RequestsGetToolWithParsing(BaseRequestsTool, BaseTool):
"""Tool name."""
description = REQUESTS_GET_TOOL_DESCRIPTION
"""Tool description."""
response_length: Optional[int] = MAX_RESPONSE_LENGTH
response_length: int = MAX_RESPONSE_LENGTH
"""Maximum length of the response to be returned."""
llm_chain: Any = Field(
default_factory=_get_default_llm_chain_factory(PARSING_GET_PROMPT)
@@ -83,8 +83,10 @@ class RequestsGetToolWithParsing(BaseRequestsTool, BaseTool):
except json.JSONDecodeError as e:
raise e
data_params = data.get("params")
response = self.requests_wrapper.get(data["url"], params=data_params)
response = response[: self.response_length] # type: ignore[index]
response: str = cast(
str, self.requests_wrapper.get(data["url"], params=data_params)
)
response = response[: self.response_length]
return self.llm_chain.predict(
response=response, instructions=data["output_instructions"]
).strip()
@@ -100,7 +102,7 @@ class RequestsPostToolWithParsing(BaseRequestsTool, BaseTool):
"""Tool name."""
description = REQUESTS_POST_TOOL_DESCRIPTION
"""Tool description."""
response_length: Optional[int] = MAX_RESPONSE_LENGTH
response_length: int = MAX_RESPONSE_LENGTH
"""Maximum length of the response to be returned."""
llm_chain: Any = Field(
default_factory=_get_default_llm_chain_factory(PARSING_POST_PROMPT)
@@ -114,8 +116,8 @@ class RequestsPostToolWithParsing(BaseRequestsTool, BaseTool):
data = parse_json_markdown(text)
except json.JSONDecodeError as e:
raise e
response = self.requests_wrapper.post(data["url"], data["data"])
response = response[: self.response_length] # type: ignore[index]
response: str = cast(str, self.requests_wrapper.post(data["url"], data["data"]))
response = response[: self.response_length]
return self.llm_chain.predict(
response=response, instructions=data["output_instructions"]
).strip()
@@ -131,7 +133,7 @@ class RequestsPatchToolWithParsing(BaseRequestsTool, BaseTool):
"""Tool name."""
description = REQUESTS_PATCH_TOOL_DESCRIPTION
"""Tool description."""
response_length: Optional[int] = MAX_RESPONSE_LENGTH
response_length: int = MAX_RESPONSE_LENGTH
"""Maximum length of the response to be returned."""
llm_chain: Any = Field(
default_factory=_get_default_llm_chain_factory(PARSING_PATCH_PROMPT)
@@ -145,8 +147,10 @@ class RequestsPatchToolWithParsing(BaseRequestsTool, BaseTool):
data = parse_json_markdown(text)
except json.JSONDecodeError as e:
raise e
response = self.requests_wrapper.patch(data["url"], data["data"])
response = response[: self.response_length] # type: ignore[index]
response: str = cast(
str, self.requests_wrapper.patch(data["url"], data["data"])
)
response = response[: self.response_length]
return self.llm_chain.predict(
response=response, instructions=data["output_instructions"]
).strip()
@@ -162,7 +166,7 @@ class RequestsPutToolWithParsing(BaseRequestsTool, BaseTool):
"""Tool name."""
description = REQUESTS_PUT_TOOL_DESCRIPTION
"""Tool description."""
response_length: Optional[int] = MAX_RESPONSE_LENGTH
response_length: int = MAX_RESPONSE_LENGTH
"""Maximum length of the response to be returned."""
llm_chain: Any = Field(
default_factory=_get_default_llm_chain_factory(PARSING_PUT_PROMPT)
@@ -176,8 +180,8 @@ class RequestsPutToolWithParsing(BaseRequestsTool, BaseTool):
data = parse_json_markdown(text)
except json.JSONDecodeError as e:
raise e
response = self.requests_wrapper.put(data["url"], data["data"])
response = response[: self.response_length] # type: ignore[index]
response: str = cast(str, self.requests_wrapper.put(data["url"], data["data"]))
response = response[: self.response_length]
return self.llm_chain.predict(
response=response, instructions=data["output_instructions"]
).strip()
@@ -208,8 +212,8 @@ class RequestsDeleteToolWithParsing(BaseRequestsTool, BaseTool):
data = parse_json_markdown(text)
except json.JSONDecodeError as e:
raise e
response = self.requests_wrapper.delete(data["url"])
response = response[: self.response_length] # type: ignore[index]
response: str = cast(str, self.requests_wrapper.delete(data["url"]))
response = response[: self.response_length]
return self.llm_chain.predict(
response=response, instructions=data["output_instructions"]
).strip()

View File

@@ -58,7 +58,7 @@ def create_pbi_agent(
input_variables=input_variables,
**prompt_params,
),
callback_manager=callback_manager, # type: ignore
callback_manager=callback_manager,
verbose=verbose,
),
allowed_tools=[tool.name for tool in tools],

View File

@@ -2,7 +2,17 @@
from __future__ import annotations
import warnings
from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Sequence, Union
from typing import (
TYPE_CHECKING,
Any,
Dict,
List,
Literal,
Optional,
Sequence,
Union,
cast,
)
from langchain_core.messages import AIMessage, SystemMessage
from langchain_core.prompts import BasePromptTemplate, PromptTemplate
@@ -176,13 +186,13 @@ def create_sql_agent(
elif agent_type == AgentType.OPENAI_FUNCTIONS:
if prompt is None:
messages = [
SystemMessage(content=prefix), # type: ignore[arg-type]
messages: List = [
SystemMessage(content=cast(str, prefix)),
HumanMessagePromptTemplate.from_template("{input}"),
AIMessage(content=suffix or SQL_FUNCTIONS_SUFFIX),
MessagesPlaceholder(variable_name="agent_scratchpad"),
]
prompt = ChatPromptTemplate.from_messages(messages) # type: ignore[arg-type]
prompt = ChatPromptTemplate.from_messages(messages)
agent = RunnableAgent(
runnable=create_openai_functions_agent(llm, tools, prompt),
input_keys_arg=["input"],
@@ -191,12 +201,12 @@ def create_sql_agent(
elif agent_type == "openai-tools":
if prompt is None:
messages = [
SystemMessage(content=prefix), # type: ignore[arg-type]
SystemMessage(content=cast(str, prefix)),
HumanMessagePromptTemplate.from_template("{input}"),
AIMessage(content=suffix or SQL_FUNCTIONS_SUFFIX),
MessagesPlaceholder(variable_name="agent_scratchpad"),
]
prompt = ChatPromptTemplate.from_messages(messages) # type: ignore[arg-type]
prompt = ChatPromptTemplate.from_messages(messages)
agent = RunnableMultiActionAgent(
runnable=create_openai_tools_agent(llm, tools, prompt),
input_keys_arg=["input"],