diff --git a/libs/community/langchain_community/agent_toolkits/file_management/toolkit.py b/libs/community/langchain_community/agent_toolkits/file_management/toolkit.py index 538569755b5..81071ff5610 100644 --- a/libs/community/langchain_community/agent_toolkits/file_management/toolkit.py +++ b/libs/community/langchain_community/agent_toolkits/file_management/toolkit.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import List, Optional +from typing import Dict, List, Optional, Type from langchain_core.pydantic_v1 import root_validator @@ -14,18 +14,17 @@ from langchain_community.tools.file_management.move import MoveFileTool from langchain_community.tools.file_management.read import ReadFileTool from langchain_community.tools.file_management.write import WriteFileTool -_FILE_TOOLS = { - # "Type[Runnable[Any, Any]]" has no attribute "__fields__" [attr-defined] - tool_cls.__fields__["name"].default: tool_cls # type: ignore[attr-defined] - for tool_cls in [ - CopyFileTool, - DeleteFileTool, - FileSearchTool, - MoveFileTool, - ReadFileTool, - WriteFileTool, - ListDirectoryTool, - ] +_FILE_TOOLS: List[Type[BaseTool]] = [ + CopyFileTool, + DeleteFileTool, + FileSearchTool, + MoveFileTool, + ReadFileTool, + WriteFileTool, + ListDirectoryTool, +] +_FILE_TOOLS_MAP: Dict[str, Type[BaseTool]] = { + tool_cls.__fields__["name"].default: tool_cls for tool_cls in _FILE_TOOLS } @@ -61,20 +60,20 @@ class FileManagementToolkit(BaseToolkit): def validate_tools(cls, values: dict) -> dict: selected_tools = values.get("selected_tools") or [] for tool_name in selected_tools: - if tool_name not in _FILE_TOOLS: + if tool_name not in _FILE_TOOLS_MAP: raise ValueError( f"File Tool of name {tool_name} not supported." - f" Permitted tools: {list(_FILE_TOOLS)}" + f" Permitted tools: {list(_FILE_TOOLS_MAP)}" ) return values def get_tools(self) -> List[BaseTool]: """Get the tools in the toolkit.""" - allowed_tools = self.selected_tools or _FILE_TOOLS.keys() + allowed_tools = self.selected_tools or _FILE_TOOLS_MAP tools: List[BaseTool] = [] for tool in allowed_tools: - tool_cls = _FILE_TOOLS[tool] - tools.append(tool_cls(root_dir=self.root_dir)) # type: ignore + tool_cls = _FILE_TOOLS_MAP[tool] + tools.append(tool_cls(root_dir=self.root_dir)) return tools diff --git a/libs/community/langchain_community/agent_toolkits/openapi/planner.py b/libs/community/langchain_community/agent_toolkits/openapi/planner.py index 7b561a93039..e112dfc5150 100644 --- a/libs/community/langchain_community/agent_toolkits/openapi/planner.py +++ b/libs/community/langchain_community/agent_toolkits/openapi/planner.py @@ -2,7 +2,7 @@ import json import re from functools import partial -from typing import Any, Callable, Dict, List, Optional +from typing import Any, Callable, Dict, List, Optional, cast import yaml from langchain_core.callbacks import BaseCallbackManager @@ -68,7 +68,7 @@ class RequestsGetToolWithParsing(BaseRequestsTool, BaseTool): """Tool name.""" description = REQUESTS_GET_TOOL_DESCRIPTION """Tool description.""" - response_length: Optional[int] = MAX_RESPONSE_LENGTH + response_length: int = MAX_RESPONSE_LENGTH """Maximum length of the response to be returned.""" llm_chain: Any = Field( default_factory=_get_default_llm_chain_factory(PARSING_GET_PROMPT) @@ -83,8 +83,10 @@ class RequestsGetToolWithParsing(BaseRequestsTool, BaseTool): except json.JSONDecodeError as e: raise e data_params = data.get("params") - response = self.requests_wrapper.get(data["url"], params=data_params) - response = response[: self.response_length] # type: ignore[index] + response: str = cast( + str, self.requests_wrapper.get(data["url"], params=data_params) + ) + response = response[: self.response_length] return self.llm_chain.predict( response=response, instructions=data["output_instructions"] ).strip() @@ -100,7 +102,7 @@ class RequestsPostToolWithParsing(BaseRequestsTool, BaseTool): """Tool name.""" description = REQUESTS_POST_TOOL_DESCRIPTION """Tool description.""" - response_length: Optional[int] = MAX_RESPONSE_LENGTH + response_length: int = MAX_RESPONSE_LENGTH """Maximum length of the response to be returned.""" llm_chain: Any = Field( default_factory=_get_default_llm_chain_factory(PARSING_POST_PROMPT) @@ -114,8 +116,8 @@ class RequestsPostToolWithParsing(BaseRequestsTool, BaseTool): data = parse_json_markdown(text) except json.JSONDecodeError as e: raise e - response = self.requests_wrapper.post(data["url"], data["data"]) - response = response[: self.response_length] # type: ignore[index] + response: str = cast(str, self.requests_wrapper.post(data["url"], data["data"])) + response = response[: self.response_length] return self.llm_chain.predict( response=response, instructions=data["output_instructions"] ).strip() @@ -131,7 +133,7 @@ class RequestsPatchToolWithParsing(BaseRequestsTool, BaseTool): """Tool name.""" description = REQUESTS_PATCH_TOOL_DESCRIPTION """Tool description.""" - response_length: Optional[int] = MAX_RESPONSE_LENGTH + response_length: int = MAX_RESPONSE_LENGTH """Maximum length of the response to be returned.""" llm_chain: Any = Field( default_factory=_get_default_llm_chain_factory(PARSING_PATCH_PROMPT) @@ -145,8 +147,10 @@ class RequestsPatchToolWithParsing(BaseRequestsTool, BaseTool): data = parse_json_markdown(text) except json.JSONDecodeError as e: raise e - response = self.requests_wrapper.patch(data["url"], data["data"]) - response = response[: self.response_length] # type: ignore[index] + response: str = cast( + str, self.requests_wrapper.patch(data["url"], data["data"]) + ) + response = response[: self.response_length] return self.llm_chain.predict( response=response, instructions=data["output_instructions"] ).strip() @@ -162,7 +166,7 @@ class RequestsPutToolWithParsing(BaseRequestsTool, BaseTool): """Tool name.""" description = REQUESTS_PUT_TOOL_DESCRIPTION """Tool description.""" - response_length: Optional[int] = MAX_RESPONSE_LENGTH + response_length: int = MAX_RESPONSE_LENGTH """Maximum length of the response to be returned.""" llm_chain: Any = Field( default_factory=_get_default_llm_chain_factory(PARSING_PUT_PROMPT) @@ -176,8 +180,8 @@ class RequestsPutToolWithParsing(BaseRequestsTool, BaseTool): data = parse_json_markdown(text) except json.JSONDecodeError as e: raise e - response = self.requests_wrapper.put(data["url"], data["data"]) - response = response[: self.response_length] # type: ignore[index] + response: str = cast(str, self.requests_wrapper.put(data["url"], data["data"])) + response = response[: self.response_length] return self.llm_chain.predict( response=response, instructions=data["output_instructions"] ).strip() @@ -208,8 +212,8 @@ class RequestsDeleteToolWithParsing(BaseRequestsTool, BaseTool): data = parse_json_markdown(text) except json.JSONDecodeError as e: raise e - response = self.requests_wrapper.delete(data["url"]) - response = response[: self.response_length] # type: ignore[index] + response: str = cast(str, self.requests_wrapper.delete(data["url"])) + response = response[: self.response_length] return self.llm_chain.predict( response=response, instructions=data["output_instructions"] ).strip() diff --git a/libs/community/langchain_community/agent_toolkits/powerbi/base.py b/libs/community/langchain_community/agent_toolkits/powerbi/base.py index 06de2b97a7b..486e55f4e7c 100644 --- a/libs/community/langchain_community/agent_toolkits/powerbi/base.py +++ b/libs/community/langchain_community/agent_toolkits/powerbi/base.py @@ -58,7 +58,7 @@ def create_pbi_agent( input_variables=input_variables, **prompt_params, ), - callback_manager=callback_manager, # type: ignore + callback_manager=callback_manager, verbose=verbose, ), allowed_tools=[tool.name for tool in tools], diff --git a/libs/community/langchain_community/agent_toolkits/sql/base.py b/libs/community/langchain_community/agent_toolkits/sql/base.py index 65294cf094b..a4d638b32c2 100644 --- a/libs/community/langchain_community/agent_toolkits/sql/base.py +++ b/libs/community/langchain_community/agent_toolkits/sql/base.py @@ -2,7 +2,17 @@ from __future__ import annotations import warnings -from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Sequence, Union +from typing import ( + TYPE_CHECKING, + Any, + Dict, + List, + Literal, + Optional, + Sequence, + Union, + cast, +) from langchain_core.messages import AIMessage, SystemMessage from langchain_core.prompts import BasePromptTemplate, PromptTemplate @@ -176,13 +186,13 @@ def create_sql_agent( elif agent_type == AgentType.OPENAI_FUNCTIONS: if prompt is None: - messages = [ - SystemMessage(content=prefix), # type: ignore[arg-type] + messages: List = [ + SystemMessage(content=cast(str, prefix)), HumanMessagePromptTemplate.from_template("{input}"), AIMessage(content=suffix or SQL_FUNCTIONS_SUFFIX), MessagesPlaceholder(variable_name="agent_scratchpad"), ] - prompt = ChatPromptTemplate.from_messages(messages) # type: ignore[arg-type] + prompt = ChatPromptTemplate.from_messages(messages) agent = RunnableAgent( runnable=create_openai_functions_agent(llm, tools, prompt), input_keys_arg=["input"], @@ -191,12 +201,12 @@ def create_sql_agent( elif agent_type == "openai-tools": if prompt is None: messages = [ - SystemMessage(content=prefix), # type: ignore[arg-type] + SystemMessage(content=cast(str, prefix)), HumanMessagePromptTemplate.from_template("{input}"), AIMessage(content=suffix or SQL_FUNCTIONS_SUFFIX), MessagesPlaceholder(variable_name="agent_scratchpad"), ] - prompt = ChatPromptTemplate.from_messages(messages) # type: ignore[arg-type] + prompt = ChatPromptTemplate.from_messages(messages) agent = RunnableMultiActionAgent( runnable=create_openai_tools_agent(llm, tools, prompt), input_keys_arg=["input"],