mirror of
https://github.com/hwchase17/langchain.git
synced 2025-05-31 20:19:43 +00:00
parent
6ffd5b15bc
commit
e7b3290d30
@ -1,6 +1,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import List, Optional
|
||||
from typing import Dict, List, Optional, Type
|
||||
|
||||
from langchain_core.pydantic_v1 import root_validator
|
||||
|
||||
@ -14,18 +14,17 @@ from langchain_community.tools.file_management.move import MoveFileTool
|
||||
from langchain_community.tools.file_management.read import ReadFileTool
|
||||
from langchain_community.tools.file_management.write import WriteFileTool
|
||||
|
||||
_FILE_TOOLS = {
|
||||
# "Type[Runnable[Any, Any]]" has no attribute "__fields__" [attr-defined]
|
||||
tool_cls.__fields__["name"].default: tool_cls # type: ignore[attr-defined]
|
||||
for tool_cls in [
|
||||
CopyFileTool,
|
||||
DeleteFileTool,
|
||||
FileSearchTool,
|
||||
MoveFileTool,
|
||||
ReadFileTool,
|
||||
WriteFileTool,
|
||||
ListDirectoryTool,
|
||||
]
|
||||
_FILE_TOOLS: List[Type[BaseTool]] = [
|
||||
CopyFileTool,
|
||||
DeleteFileTool,
|
||||
FileSearchTool,
|
||||
MoveFileTool,
|
||||
ReadFileTool,
|
||||
WriteFileTool,
|
||||
ListDirectoryTool,
|
||||
]
|
||||
_FILE_TOOLS_MAP: Dict[str, Type[BaseTool]] = {
|
||||
tool_cls.__fields__["name"].default: tool_cls for tool_cls in _FILE_TOOLS
|
||||
}
|
||||
|
||||
|
||||
@ -61,20 +60,20 @@ class FileManagementToolkit(BaseToolkit):
|
||||
def validate_tools(cls, values: dict) -> dict:
|
||||
selected_tools = values.get("selected_tools") or []
|
||||
for tool_name in selected_tools:
|
||||
if tool_name not in _FILE_TOOLS:
|
||||
if tool_name not in _FILE_TOOLS_MAP:
|
||||
raise ValueError(
|
||||
f"File Tool of name {tool_name} not supported."
|
||||
f" Permitted tools: {list(_FILE_TOOLS)}"
|
||||
f" Permitted tools: {list(_FILE_TOOLS_MAP)}"
|
||||
)
|
||||
return values
|
||||
|
||||
def get_tools(self) -> List[BaseTool]:
|
||||
"""Get the tools in the toolkit."""
|
||||
allowed_tools = self.selected_tools or _FILE_TOOLS.keys()
|
||||
allowed_tools = self.selected_tools or _FILE_TOOLS_MAP
|
||||
tools: List[BaseTool] = []
|
||||
for tool in allowed_tools:
|
||||
tool_cls = _FILE_TOOLS[tool]
|
||||
tools.append(tool_cls(root_dir=self.root_dir)) # type: ignore
|
||||
tool_cls = _FILE_TOOLS_MAP[tool]
|
||||
tools.append(tool_cls(root_dir=self.root_dir))
|
||||
return tools
|
||||
|
||||
|
||||
|
@ -2,7 +2,7 @@
|
||||
import json
|
||||
import re
|
||||
from functools import partial
|
||||
from typing import Any, Callable, Dict, List, Optional
|
||||
from typing import Any, Callable, Dict, List, Optional, cast
|
||||
|
||||
import yaml
|
||||
from langchain_core.callbacks import BaseCallbackManager
|
||||
@ -68,7 +68,7 @@ class RequestsGetToolWithParsing(BaseRequestsTool, BaseTool):
|
||||
"""Tool name."""
|
||||
description = REQUESTS_GET_TOOL_DESCRIPTION
|
||||
"""Tool description."""
|
||||
response_length: Optional[int] = MAX_RESPONSE_LENGTH
|
||||
response_length: int = MAX_RESPONSE_LENGTH
|
||||
"""Maximum length of the response to be returned."""
|
||||
llm_chain: Any = Field(
|
||||
default_factory=_get_default_llm_chain_factory(PARSING_GET_PROMPT)
|
||||
@ -83,8 +83,10 @@ class RequestsGetToolWithParsing(BaseRequestsTool, BaseTool):
|
||||
except json.JSONDecodeError as e:
|
||||
raise e
|
||||
data_params = data.get("params")
|
||||
response = self.requests_wrapper.get(data["url"], params=data_params)
|
||||
response = response[: self.response_length] # type: ignore[index]
|
||||
response: str = cast(
|
||||
str, self.requests_wrapper.get(data["url"], params=data_params)
|
||||
)
|
||||
response = response[: self.response_length]
|
||||
return self.llm_chain.predict(
|
||||
response=response, instructions=data["output_instructions"]
|
||||
).strip()
|
||||
@ -100,7 +102,7 @@ class RequestsPostToolWithParsing(BaseRequestsTool, BaseTool):
|
||||
"""Tool name."""
|
||||
description = REQUESTS_POST_TOOL_DESCRIPTION
|
||||
"""Tool description."""
|
||||
response_length: Optional[int] = MAX_RESPONSE_LENGTH
|
||||
response_length: int = MAX_RESPONSE_LENGTH
|
||||
"""Maximum length of the response to be returned."""
|
||||
llm_chain: Any = Field(
|
||||
default_factory=_get_default_llm_chain_factory(PARSING_POST_PROMPT)
|
||||
@ -114,8 +116,8 @@ class RequestsPostToolWithParsing(BaseRequestsTool, BaseTool):
|
||||
data = parse_json_markdown(text)
|
||||
except json.JSONDecodeError as e:
|
||||
raise e
|
||||
response = self.requests_wrapper.post(data["url"], data["data"])
|
||||
response = response[: self.response_length] # type: ignore[index]
|
||||
response: str = cast(str, self.requests_wrapper.post(data["url"], data["data"]))
|
||||
response = response[: self.response_length]
|
||||
return self.llm_chain.predict(
|
||||
response=response, instructions=data["output_instructions"]
|
||||
).strip()
|
||||
@ -131,7 +133,7 @@ class RequestsPatchToolWithParsing(BaseRequestsTool, BaseTool):
|
||||
"""Tool name."""
|
||||
description = REQUESTS_PATCH_TOOL_DESCRIPTION
|
||||
"""Tool description."""
|
||||
response_length: Optional[int] = MAX_RESPONSE_LENGTH
|
||||
response_length: int = MAX_RESPONSE_LENGTH
|
||||
"""Maximum length of the response to be returned."""
|
||||
llm_chain: Any = Field(
|
||||
default_factory=_get_default_llm_chain_factory(PARSING_PATCH_PROMPT)
|
||||
@ -145,8 +147,10 @@ class RequestsPatchToolWithParsing(BaseRequestsTool, BaseTool):
|
||||
data = parse_json_markdown(text)
|
||||
except json.JSONDecodeError as e:
|
||||
raise e
|
||||
response = self.requests_wrapper.patch(data["url"], data["data"])
|
||||
response = response[: self.response_length] # type: ignore[index]
|
||||
response: str = cast(
|
||||
str, self.requests_wrapper.patch(data["url"], data["data"])
|
||||
)
|
||||
response = response[: self.response_length]
|
||||
return self.llm_chain.predict(
|
||||
response=response, instructions=data["output_instructions"]
|
||||
).strip()
|
||||
@ -162,7 +166,7 @@ class RequestsPutToolWithParsing(BaseRequestsTool, BaseTool):
|
||||
"""Tool name."""
|
||||
description = REQUESTS_PUT_TOOL_DESCRIPTION
|
||||
"""Tool description."""
|
||||
response_length: Optional[int] = MAX_RESPONSE_LENGTH
|
||||
response_length: int = MAX_RESPONSE_LENGTH
|
||||
"""Maximum length of the response to be returned."""
|
||||
llm_chain: Any = Field(
|
||||
default_factory=_get_default_llm_chain_factory(PARSING_PUT_PROMPT)
|
||||
@ -176,8 +180,8 @@ class RequestsPutToolWithParsing(BaseRequestsTool, BaseTool):
|
||||
data = parse_json_markdown(text)
|
||||
except json.JSONDecodeError as e:
|
||||
raise e
|
||||
response = self.requests_wrapper.put(data["url"], data["data"])
|
||||
response = response[: self.response_length] # type: ignore[index]
|
||||
response: str = cast(str, self.requests_wrapper.put(data["url"], data["data"]))
|
||||
response = response[: self.response_length]
|
||||
return self.llm_chain.predict(
|
||||
response=response, instructions=data["output_instructions"]
|
||||
).strip()
|
||||
@ -208,8 +212,8 @@ class RequestsDeleteToolWithParsing(BaseRequestsTool, BaseTool):
|
||||
data = parse_json_markdown(text)
|
||||
except json.JSONDecodeError as e:
|
||||
raise e
|
||||
response = self.requests_wrapper.delete(data["url"])
|
||||
response = response[: self.response_length] # type: ignore[index]
|
||||
response: str = cast(str, self.requests_wrapper.delete(data["url"]))
|
||||
response = response[: self.response_length]
|
||||
return self.llm_chain.predict(
|
||||
response=response, instructions=data["output_instructions"]
|
||||
).strip()
|
||||
|
@ -58,7 +58,7 @@ def create_pbi_agent(
|
||||
input_variables=input_variables,
|
||||
**prompt_params,
|
||||
),
|
||||
callback_manager=callback_manager, # type: ignore
|
||||
callback_manager=callback_manager,
|
||||
verbose=verbose,
|
||||
),
|
||||
allowed_tools=[tool.name for tool in tools],
|
||||
|
@ -2,7 +2,17 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import warnings
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Sequence, Union
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Dict,
|
||||
List,
|
||||
Literal,
|
||||
Optional,
|
||||
Sequence,
|
||||
Union,
|
||||
cast,
|
||||
)
|
||||
|
||||
from langchain_core.messages import AIMessage, SystemMessage
|
||||
from langchain_core.prompts import BasePromptTemplate, PromptTemplate
|
||||
@ -176,13 +186,13 @@ def create_sql_agent(
|
||||
|
||||
elif agent_type == AgentType.OPENAI_FUNCTIONS:
|
||||
if prompt is None:
|
||||
messages = [
|
||||
SystemMessage(content=prefix), # type: ignore[arg-type]
|
||||
messages: List = [
|
||||
SystemMessage(content=cast(str, prefix)),
|
||||
HumanMessagePromptTemplate.from_template("{input}"),
|
||||
AIMessage(content=suffix or SQL_FUNCTIONS_SUFFIX),
|
||||
MessagesPlaceholder(variable_name="agent_scratchpad"),
|
||||
]
|
||||
prompt = ChatPromptTemplate.from_messages(messages) # type: ignore[arg-type]
|
||||
prompt = ChatPromptTemplate.from_messages(messages)
|
||||
agent = RunnableAgent(
|
||||
runnable=create_openai_functions_agent(llm, tools, prompt),
|
||||
input_keys_arg=["input"],
|
||||
@ -191,12 +201,12 @@ def create_sql_agent(
|
||||
elif agent_type == "openai-tools":
|
||||
if prompt is None:
|
||||
messages = [
|
||||
SystemMessage(content=prefix), # type: ignore[arg-type]
|
||||
SystemMessage(content=cast(str, prefix)),
|
||||
HumanMessagePromptTemplate.from_template("{input}"),
|
||||
AIMessage(content=suffix or SQL_FUNCTIONS_SUFFIX),
|
||||
MessagesPlaceholder(variable_name="agent_scratchpad"),
|
||||
]
|
||||
prompt = ChatPromptTemplate.from_messages(messages) # type: ignore[arg-type]
|
||||
prompt = ChatPromptTemplate.from_messages(messages)
|
||||
agent = RunnableMultiActionAgent(
|
||||
runnable=create_openai_tools_agent(llm, tools, prompt),
|
||||
input_keys_arg=["input"],
|
||||
|
Loading…
Reference in New Issue
Block a user