community[patch]: fix agent_toolkits mypy (#17050)

Related to #17048
This commit is contained in:
Bagatur 2024-02-05 11:56:24 -08:00 committed by GitHub
parent 6ffd5b15bc
commit e7b3290d30
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 53 additions and 40 deletions

View File

@@ -1,6 +1,6 @@
from __future__ import annotations from __future__ import annotations
from typing import List, Optional from typing import Dict, List, Optional, Type
from langchain_core.pydantic_v1 import root_validator from langchain_core.pydantic_v1 import root_validator
@@ -14,10 +14,7 @@ from langchain_community.tools.file_management.move import MoveFileTool
from langchain_community.tools.file_management.read import ReadFileTool from langchain_community.tools.file_management.read import ReadFileTool
from langchain_community.tools.file_management.write import WriteFileTool from langchain_community.tools.file_management.write import WriteFileTool
_FILE_TOOLS = { _FILE_TOOLS: List[Type[BaseTool]] = [
# "Type[Runnable[Any, Any]]" has no attribute "__fields__" [attr-defined]
tool_cls.__fields__["name"].default: tool_cls # type: ignore[attr-defined]
for tool_cls in [
CopyFileTool, CopyFileTool,
DeleteFileTool, DeleteFileTool,
FileSearchTool, FileSearchTool,
@@ -25,7 +22,9 @@ _FILE_TOOLS = {
ReadFileTool, ReadFileTool,
WriteFileTool, WriteFileTool,
ListDirectoryTool, ListDirectoryTool,
] ]
_FILE_TOOLS_MAP: Dict[str, Type[BaseTool]] = {
tool_cls.__fields__["name"].default: tool_cls for tool_cls in _FILE_TOOLS
} }
@@ -61,20 +60,20 @@ class FileManagementToolkit(BaseToolkit):
def validate_tools(cls, values: dict) -> dict: def validate_tools(cls, values: dict) -> dict:
selected_tools = values.get("selected_tools") or [] selected_tools = values.get("selected_tools") or []
for tool_name in selected_tools: for tool_name in selected_tools:
if tool_name not in _FILE_TOOLS: if tool_name not in _FILE_TOOLS_MAP:
raise ValueError( raise ValueError(
f"File Tool of name {tool_name} not supported." f"File Tool of name {tool_name} not supported."
f" Permitted tools: {list(_FILE_TOOLS)}" f" Permitted tools: {list(_FILE_TOOLS_MAP)}"
) )
return values return values
def get_tools(self) -> List[BaseTool]: def get_tools(self) -> List[BaseTool]:
"""Get the tools in the toolkit.""" """Get the tools in the toolkit."""
allowed_tools = self.selected_tools or _FILE_TOOLS.keys() allowed_tools = self.selected_tools or _FILE_TOOLS_MAP
tools: List[BaseTool] = [] tools: List[BaseTool] = []
for tool in allowed_tools: for tool in allowed_tools:
tool_cls = _FILE_TOOLS[tool] tool_cls = _FILE_TOOLS_MAP[tool]
tools.append(tool_cls(root_dir=self.root_dir)) # type: ignore tools.append(tool_cls(root_dir=self.root_dir))
return tools return tools

View File

@@ -2,7 +2,7 @@
import json import json
import re import re
from functools import partial from functools import partial
from typing import Any, Callable, Dict, List, Optional from typing import Any, Callable, Dict, List, Optional, cast
import yaml import yaml
from langchain_core.callbacks import BaseCallbackManager from langchain_core.callbacks import BaseCallbackManager
@@ -68,7 +68,7 @@ class RequestsGetToolWithParsing(BaseRequestsTool, BaseTool):
"""Tool name.""" """Tool name."""
description = REQUESTS_GET_TOOL_DESCRIPTION description = REQUESTS_GET_TOOL_DESCRIPTION
"""Tool description.""" """Tool description."""
response_length: Optional[int] = MAX_RESPONSE_LENGTH response_length: int = MAX_RESPONSE_LENGTH
"""Maximum length of the response to be returned.""" """Maximum length of the response to be returned."""
llm_chain: Any = Field( llm_chain: Any = Field(
default_factory=_get_default_llm_chain_factory(PARSING_GET_PROMPT) default_factory=_get_default_llm_chain_factory(PARSING_GET_PROMPT)
@@ -83,8 +83,10 @@ class RequestsGetToolWithParsing(BaseRequestsTool, BaseTool):
except json.JSONDecodeError as e: except json.JSONDecodeError as e:
raise e raise e
data_params = data.get("params") data_params = data.get("params")
response = self.requests_wrapper.get(data["url"], params=data_params) response: str = cast(
response = response[: self.response_length] # type: ignore[index] str, self.requests_wrapper.get(data["url"], params=data_params)
)
response = response[: self.response_length]
return self.llm_chain.predict( return self.llm_chain.predict(
response=response, instructions=data["output_instructions"] response=response, instructions=data["output_instructions"]
).strip() ).strip()
@@ -100,7 +102,7 @@ class RequestsPostToolWithParsing(BaseRequestsTool, BaseTool):
"""Tool name.""" """Tool name."""
description = REQUESTS_POST_TOOL_DESCRIPTION description = REQUESTS_POST_TOOL_DESCRIPTION
"""Tool description.""" """Tool description."""
response_length: Optional[int] = MAX_RESPONSE_LENGTH response_length: int = MAX_RESPONSE_LENGTH
"""Maximum length of the response to be returned.""" """Maximum length of the response to be returned."""
llm_chain: Any = Field( llm_chain: Any = Field(
default_factory=_get_default_llm_chain_factory(PARSING_POST_PROMPT) default_factory=_get_default_llm_chain_factory(PARSING_POST_PROMPT)
@@ -114,8 +116,8 @@ class RequestsPostToolWithParsing(BaseRequestsTool, BaseTool):
data = parse_json_markdown(text) data = parse_json_markdown(text)
except json.JSONDecodeError as e: except json.JSONDecodeError as e:
raise e raise e
response = self.requests_wrapper.post(data["url"], data["data"]) response: str = cast(str, self.requests_wrapper.post(data["url"], data["data"]))
response = response[: self.response_length] # type: ignore[index] response = response[: self.response_length]
return self.llm_chain.predict( return self.llm_chain.predict(
response=response, instructions=data["output_instructions"] response=response, instructions=data["output_instructions"]
).strip() ).strip()
@@ -131,7 +133,7 @@ class RequestsPatchToolWithParsing(BaseRequestsTool, BaseTool):
"""Tool name.""" """Tool name."""
description = REQUESTS_PATCH_TOOL_DESCRIPTION description = REQUESTS_PATCH_TOOL_DESCRIPTION
"""Tool description.""" """Tool description."""
response_length: Optional[int] = MAX_RESPONSE_LENGTH response_length: int = MAX_RESPONSE_LENGTH
"""Maximum length of the response to be returned.""" """Maximum length of the response to be returned."""
llm_chain: Any = Field( llm_chain: Any = Field(
default_factory=_get_default_llm_chain_factory(PARSING_PATCH_PROMPT) default_factory=_get_default_llm_chain_factory(PARSING_PATCH_PROMPT)
@@ -145,8 +147,10 @@ class RequestsPatchToolWithParsing(BaseRequestsTool, BaseTool):
data = parse_json_markdown(text) data = parse_json_markdown(text)
except json.JSONDecodeError as e: except json.JSONDecodeError as e:
raise e raise e
response = self.requests_wrapper.patch(data["url"], data["data"]) response: str = cast(
response = response[: self.response_length] # type: ignore[index] str, self.requests_wrapper.patch(data["url"], data["data"])
)
response = response[: self.response_length]
return self.llm_chain.predict( return self.llm_chain.predict(
response=response, instructions=data["output_instructions"] response=response, instructions=data["output_instructions"]
).strip() ).strip()
@@ -162,7 +166,7 @@ class RequestsPutToolWithParsing(BaseRequestsTool, BaseTool):
"""Tool name.""" """Tool name."""
description = REQUESTS_PUT_TOOL_DESCRIPTION description = REQUESTS_PUT_TOOL_DESCRIPTION
"""Tool description.""" """Tool description."""
response_length: Optional[int] = MAX_RESPONSE_LENGTH response_length: int = MAX_RESPONSE_LENGTH
"""Maximum length of the response to be returned.""" """Maximum length of the response to be returned."""
llm_chain: Any = Field( llm_chain: Any = Field(
default_factory=_get_default_llm_chain_factory(PARSING_PUT_PROMPT) default_factory=_get_default_llm_chain_factory(PARSING_PUT_PROMPT)
@@ -176,8 +180,8 @@ class RequestsPutToolWithParsing(BaseRequestsTool, BaseTool):
data = parse_json_markdown(text) data = parse_json_markdown(text)
except json.JSONDecodeError as e: except json.JSONDecodeError as e:
raise e raise e
response = self.requests_wrapper.put(data["url"], data["data"]) response: str = cast(str, self.requests_wrapper.put(data["url"], data["data"]))
response = response[: self.response_length] # type: ignore[index] response = response[: self.response_length]
return self.llm_chain.predict( return self.llm_chain.predict(
response=response, instructions=data["output_instructions"] response=response, instructions=data["output_instructions"]
).strip() ).strip()
@@ -208,8 +212,8 @@ class RequestsDeleteToolWithParsing(BaseRequestsTool, BaseTool):
data = parse_json_markdown(text) data = parse_json_markdown(text)
except json.JSONDecodeError as e: except json.JSONDecodeError as e:
raise e raise e
response = self.requests_wrapper.delete(data["url"]) response: str = cast(str, self.requests_wrapper.delete(data["url"]))
response = response[: self.response_length] # type: ignore[index] response = response[: self.response_length]
return self.llm_chain.predict( return self.llm_chain.predict(
response=response, instructions=data["output_instructions"] response=response, instructions=data["output_instructions"]
).strip() ).strip()

View File

@@ -58,7 +58,7 @@ def create_pbi_agent(
input_variables=input_variables, input_variables=input_variables,
**prompt_params, **prompt_params,
), ),
callback_manager=callback_manager, # type: ignore callback_manager=callback_manager,
verbose=verbose, verbose=verbose,
), ),
allowed_tools=[tool.name for tool in tools], allowed_tools=[tool.name for tool in tools],

View File

@@ -2,7 +2,17 @@
from __future__ import annotations from __future__ import annotations
import warnings import warnings
from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Sequence, Union from typing import (
TYPE_CHECKING,
Any,
Dict,
List,
Literal,
Optional,
Sequence,
Union,
cast,
)
from langchain_core.messages import AIMessage, SystemMessage from langchain_core.messages import AIMessage, SystemMessage
from langchain_core.prompts import BasePromptTemplate, PromptTemplate from langchain_core.prompts import BasePromptTemplate, PromptTemplate
@@ -176,13 +186,13 @@ def create_sql_agent(
elif agent_type == AgentType.OPENAI_FUNCTIONS: elif agent_type == AgentType.OPENAI_FUNCTIONS:
if prompt is None: if prompt is None:
messages = [ messages: List = [
SystemMessage(content=prefix), # type: ignore[arg-type] SystemMessage(content=cast(str, prefix)),
HumanMessagePromptTemplate.from_template("{input}"), HumanMessagePromptTemplate.from_template("{input}"),
AIMessage(content=suffix or SQL_FUNCTIONS_SUFFIX), AIMessage(content=suffix or SQL_FUNCTIONS_SUFFIX),
MessagesPlaceholder(variable_name="agent_scratchpad"), MessagesPlaceholder(variable_name="agent_scratchpad"),
] ]
prompt = ChatPromptTemplate.from_messages(messages) # type: ignore[arg-type] prompt = ChatPromptTemplate.from_messages(messages)
agent = RunnableAgent( agent = RunnableAgent(
runnable=create_openai_functions_agent(llm, tools, prompt), runnable=create_openai_functions_agent(llm, tools, prompt),
input_keys_arg=["input"], input_keys_arg=["input"],
@ -191,12 +201,12 @@ def create_sql_agent(
elif agent_type == "openai-tools": elif agent_type == "openai-tools":
if prompt is None: if prompt is None:
messages = [ messages = [
SystemMessage(content=prefix), # type: ignore[arg-type] SystemMessage(content=cast(str, prefix)),
HumanMessagePromptTemplate.from_template("{input}"), HumanMessagePromptTemplate.from_template("{input}"),
AIMessage(content=suffix or SQL_FUNCTIONS_SUFFIX), AIMessage(content=suffix or SQL_FUNCTIONS_SUFFIX),
MessagesPlaceholder(variable_name="agent_scratchpad"), MessagesPlaceholder(variable_name="agent_scratchpad"),
] ]
prompt = ChatPromptTemplate.from_messages(messages) # type: ignore[arg-type] prompt = ChatPromptTemplate.from_messages(messages)
agent = RunnableMultiActionAgent( agent = RunnableMultiActionAgent(
runnable=create_openai_tools_agent(llm, tools, prompt), runnable=create_openai_tools_agent(llm, tools, prompt),
input_keys_arg=["input"], input_keys_arg=["input"],