mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-22 06:39:52 +00:00
parent
6ffd5b15bc
commit
e7b3290d30
@ -1,6 +1,6 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from typing import List, Optional
|
from typing import Dict, List, Optional, Type
|
||||||
|
|
||||||
from langchain_core.pydantic_v1 import root_validator
|
from langchain_core.pydantic_v1 import root_validator
|
||||||
|
|
||||||
@ -14,18 +14,17 @@ from langchain_community.tools.file_management.move import MoveFileTool
|
|||||||
from langchain_community.tools.file_management.read import ReadFileTool
|
from langchain_community.tools.file_management.read import ReadFileTool
|
||||||
from langchain_community.tools.file_management.write import WriteFileTool
|
from langchain_community.tools.file_management.write import WriteFileTool
|
||||||
|
|
||||||
_FILE_TOOLS = {
|
_FILE_TOOLS: List[Type[BaseTool]] = [
|
||||||
# "Type[Runnable[Any, Any]]" has no attribute "__fields__" [attr-defined]
|
CopyFileTool,
|
||||||
tool_cls.__fields__["name"].default: tool_cls # type: ignore[attr-defined]
|
DeleteFileTool,
|
||||||
for tool_cls in [
|
FileSearchTool,
|
||||||
CopyFileTool,
|
MoveFileTool,
|
||||||
DeleteFileTool,
|
ReadFileTool,
|
||||||
FileSearchTool,
|
WriteFileTool,
|
||||||
MoveFileTool,
|
ListDirectoryTool,
|
||||||
ReadFileTool,
|
]
|
||||||
WriteFileTool,
|
_FILE_TOOLS_MAP: Dict[str, Type[BaseTool]] = {
|
||||||
ListDirectoryTool,
|
tool_cls.__fields__["name"].default: tool_cls for tool_cls in _FILE_TOOLS
|
||||||
]
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -61,20 +60,20 @@ class FileManagementToolkit(BaseToolkit):
|
|||||||
def validate_tools(cls, values: dict) -> dict:
|
def validate_tools(cls, values: dict) -> dict:
|
||||||
selected_tools = values.get("selected_tools") or []
|
selected_tools = values.get("selected_tools") or []
|
||||||
for tool_name in selected_tools:
|
for tool_name in selected_tools:
|
||||||
if tool_name not in _FILE_TOOLS:
|
if tool_name not in _FILE_TOOLS_MAP:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"File Tool of name {tool_name} not supported."
|
f"File Tool of name {tool_name} not supported."
|
||||||
f" Permitted tools: {list(_FILE_TOOLS)}"
|
f" Permitted tools: {list(_FILE_TOOLS_MAP)}"
|
||||||
)
|
)
|
||||||
return values
|
return values
|
||||||
|
|
||||||
def get_tools(self) -> List[BaseTool]:
|
def get_tools(self) -> List[BaseTool]:
|
||||||
"""Get the tools in the toolkit."""
|
"""Get the tools in the toolkit."""
|
||||||
allowed_tools = self.selected_tools or _FILE_TOOLS.keys()
|
allowed_tools = self.selected_tools or _FILE_TOOLS_MAP
|
||||||
tools: List[BaseTool] = []
|
tools: List[BaseTool] = []
|
||||||
for tool in allowed_tools:
|
for tool in allowed_tools:
|
||||||
tool_cls = _FILE_TOOLS[tool]
|
tool_cls = _FILE_TOOLS_MAP[tool]
|
||||||
tools.append(tool_cls(root_dir=self.root_dir)) # type: ignore
|
tools.append(tool_cls(root_dir=self.root_dir))
|
||||||
return tools
|
return tools
|
||||||
|
|
||||||
|
|
||||||
|
@ -2,7 +2,7 @@
|
|||||||
import json
|
import json
|
||||||
import re
|
import re
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from typing import Any, Callable, Dict, List, Optional
|
from typing import Any, Callable, Dict, List, Optional, cast
|
||||||
|
|
||||||
import yaml
|
import yaml
|
||||||
from langchain_core.callbacks import BaseCallbackManager
|
from langchain_core.callbacks import BaseCallbackManager
|
||||||
@ -68,7 +68,7 @@ class RequestsGetToolWithParsing(BaseRequestsTool, BaseTool):
|
|||||||
"""Tool name."""
|
"""Tool name."""
|
||||||
description = REQUESTS_GET_TOOL_DESCRIPTION
|
description = REQUESTS_GET_TOOL_DESCRIPTION
|
||||||
"""Tool description."""
|
"""Tool description."""
|
||||||
response_length: Optional[int] = MAX_RESPONSE_LENGTH
|
response_length: int = MAX_RESPONSE_LENGTH
|
||||||
"""Maximum length of the response to be returned."""
|
"""Maximum length of the response to be returned."""
|
||||||
llm_chain: Any = Field(
|
llm_chain: Any = Field(
|
||||||
default_factory=_get_default_llm_chain_factory(PARSING_GET_PROMPT)
|
default_factory=_get_default_llm_chain_factory(PARSING_GET_PROMPT)
|
||||||
@ -83,8 +83,10 @@ class RequestsGetToolWithParsing(BaseRequestsTool, BaseTool):
|
|||||||
except json.JSONDecodeError as e:
|
except json.JSONDecodeError as e:
|
||||||
raise e
|
raise e
|
||||||
data_params = data.get("params")
|
data_params = data.get("params")
|
||||||
response = self.requests_wrapper.get(data["url"], params=data_params)
|
response: str = cast(
|
||||||
response = response[: self.response_length] # type: ignore[index]
|
str, self.requests_wrapper.get(data["url"], params=data_params)
|
||||||
|
)
|
||||||
|
response = response[: self.response_length]
|
||||||
return self.llm_chain.predict(
|
return self.llm_chain.predict(
|
||||||
response=response, instructions=data["output_instructions"]
|
response=response, instructions=data["output_instructions"]
|
||||||
).strip()
|
).strip()
|
||||||
@ -100,7 +102,7 @@ class RequestsPostToolWithParsing(BaseRequestsTool, BaseTool):
|
|||||||
"""Tool name."""
|
"""Tool name."""
|
||||||
description = REQUESTS_POST_TOOL_DESCRIPTION
|
description = REQUESTS_POST_TOOL_DESCRIPTION
|
||||||
"""Tool description."""
|
"""Tool description."""
|
||||||
response_length: Optional[int] = MAX_RESPONSE_LENGTH
|
response_length: int = MAX_RESPONSE_LENGTH
|
||||||
"""Maximum length of the response to be returned."""
|
"""Maximum length of the response to be returned."""
|
||||||
llm_chain: Any = Field(
|
llm_chain: Any = Field(
|
||||||
default_factory=_get_default_llm_chain_factory(PARSING_POST_PROMPT)
|
default_factory=_get_default_llm_chain_factory(PARSING_POST_PROMPT)
|
||||||
@ -114,8 +116,8 @@ class RequestsPostToolWithParsing(BaseRequestsTool, BaseTool):
|
|||||||
data = parse_json_markdown(text)
|
data = parse_json_markdown(text)
|
||||||
except json.JSONDecodeError as e:
|
except json.JSONDecodeError as e:
|
||||||
raise e
|
raise e
|
||||||
response = self.requests_wrapper.post(data["url"], data["data"])
|
response: str = cast(str, self.requests_wrapper.post(data["url"], data["data"]))
|
||||||
response = response[: self.response_length] # type: ignore[index]
|
response = response[: self.response_length]
|
||||||
return self.llm_chain.predict(
|
return self.llm_chain.predict(
|
||||||
response=response, instructions=data["output_instructions"]
|
response=response, instructions=data["output_instructions"]
|
||||||
).strip()
|
).strip()
|
||||||
@ -131,7 +133,7 @@ class RequestsPatchToolWithParsing(BaseRequestsTool, BaseTool):
|
|||||||
"""Tool name."""
|
"""Tool name."""
|
||||||
description = REQUESTS_PATCH_TOOL_DESCRIPTION
|
description = REQUESTS_PATCH_TOOL_DESCRIPTION
|
||||||
"""Tool description."""
|
"""Tool description."""
|
||||||
response_length: Optional[int] = MAX_RESPONSE_LENGTH
|
response_length: int = MAX_RESPONSE_LENGTH
|
||||||
"""Maximum length of the response to be returned."""
|
"""Maximum length of the response to be returned."""
|
||||||
llm_chain: Any = Field(
|
llm_chain: Any = Field(
|
||||||
default_factory=_get_default_llm_chain_factory(PARSING_PATCH_PROMPT)
|
default_factory=_get_default_llm_chain_factory(PARSING_PATCH_PROMPT)
|
||||||
@ -145,8 +147,10 @@ class RequestsPatchToolWithParsing(BaseRequestsTool, BaseTool):
|
|||||||
data = parse_json_markdown(text)
|
data = parse_json_markdown(text)
|
||||||
except json.JSONDecodeError as e:
|
except json.JSONDecodeError as e:
|
||||||
raise e
|
raise e
|
||||||
response = self.requests_wrapper.patch(data["url"], data["data"])
|
response: str = cast(
|
||||||
response = response[: self.response_length] # type: ignore[index]
|
str, self.requests_wrapper.patch(data["url"], data["data"])
|
||||||
|
)
|
||||||
|
response = response[: self.response_length]
|
||||||
return self.llm_chain.predict(
|
return self.llm_chain.predict(
|
||||||
response=response, instructions=data["output_instructions"]
|
response=response, instructions=data["output_instructions"]
|
||||||
).strip()
|
).strip()
|
||||||
@ -162,7 +166,7 @@ class RequestsPutToolWithParsing(BaseRequestsTool, BaseTool):
|
|||||||
"""Tool name."""
|
"""Tool name."""
|
||||||
description = REQUESTS_PUT_TOOL_DESCRIPTION
|
description = REQUESTS_PUT_TOOL_DESCRIPTION
|
||||||
"""Tool description."""
|
"""Tool description."""
|
||||||
response_length: Optional[int] = MAX_RESPONSE_LENGTH
|
response_length: int = MAX_RESPONSE_LENGTH
|
||||||
"""Maximum length of the response to be returned."""
|
"""Maximum length of the response to be returned."""
|
||||||
llm_chain: Any = Field(
|
llm_chain: Any = Field(
|
||||||
default_factory=_get_default_llm_chain_factory(PARSING_PUT_PROMPT)
|
default_factory=_get_default_llm_chain_factory(PARSING_PUT_PROMPT)
|
||||||
@ -176,8 +180,8 @@ class RequestsPutToolWithParsing(BaseRequestsTool, BaseTool):
|
|||||||
data = parse_json_markdown(text)
|
data = parse_json_markdown(text)
|
||||||
except json.JSONDecodeError as e:
|
except json.JSONDecodeError as e:
|
||||||
raise e
|
raise e
|
||||||
response = self.requests_wrapper.put(data["url"], data["data"])
|
response: str = cast(str, self.requests_wrapper.put(data["url"], data["data"]))
|
||||||
response = response[: self.response_length] # type: ignore[index]
|
response = response[: self.response_length]
|
||||||
return self.llm_chain.predict(
|
return self.llm_chain.predict(
|
||||||
response=response, instructions=data["output_instructions"]
|
response=response, instructions=data["output_instructions"]
|
||||||
).strip()
|
).strip()
|
||||||
@ -208,8 +212,8 @@ class RequestsDeleteToolWithParsing(BaseRequestsTool, BaseTool):
|
|||||||
data = parse_json_markdown(text)
|
data = parse_json_markdown(text)
|
||||||
except json.JSONDecodeError as e:
|
except json.JSONDecodeError as e:
|
||||||
raise e
|
raise e
|
||||||
response = self.requests_wrapper.delete(data["url"])
|
response: str = cast(str, self.requests_wrapper.delete(data["url"]))
|
||||||
response = response[: self.response_length] # type: ignore[index]
|
response = response[: self.response_length]
|
||||||
return self.llm_chain.predict(
|
return self.llm_chain.predict(
|
||||||
response=response, instructions=data["output_instructions"]
|
response=response, instructions=data["output_instructions"]
|
||||||
).strip()
|
).strip()
|
||||||
|
@ -58,7 +58,7 @@ def create_pbi_agent(
|
|||||||
input_variables=input_variables,
|
input_variables=input_variables,
|
||||||
**prompt_params,
|
**prompt_params,
|
||||||
),
|
),
|
||||||
callback_manager=callback_manager, # type: ignore
|
callback_manager=callback_manager,
|
||||||
verbose=verbose,
|
verbose=verbose,
|
||||||
),
|
),
|
||||||
allowed_tools=[tool.name for tool in tools],
|
allowed_tools=[tool.name for tool in tools],
|
||||||
|
@ -2,7 +2,17 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import warnings
|
import warnings
|
||||||
from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Sequence, Union
|
from typing import (
|
||||||
|
TYPE_CHECKING,
|
||||||
|
Any,
|
||||||
|
Dict,
|
||||||
|
List,
|
||||||
|
Literal,
|
||||||
|
Optional,
|
||||||
|
Sequence,
|
||||||
|
Union,
|
||||||
|
cast,
|
||||||
|
)
|
||||||
|
|
||||||
from langchain_core.messages import AIMessage, SystemMessage
|
from langchain_core.messages import AIMessage, SystemMessage
|
||||||
from langchain_core.prompts import BasePromptTemplate, PromptTemplate
|
from langchain_core.prompts import BasePromptTemplate, PromptTemplate
|
||||||
@ -176,13 +186,13 @@ def create_sql_agent(
|
|||||||
|
|
||||||
elif agent_type == AgentType.OPENAI_FUNCTIONS:
|
elif agent_type == AgentType.OPENAI_FUNCTIONS:
|
||||||
if prompt is None:
|
if prompt is None:
|
||||||
messages = [
|
messages: List = [
|
||||||
SystemMessage(content=prefix), # type: ignore[arg-type]
|
SystemMessage(content=cast(str, prefix)),
|
||||||
HumanMessagePromptTemplate.from_template("{input}"),
|
HumanMessagePromptTemplate.from_template("{input}"),
|
||||||
AIMessage(content=suffix or SQL_FUNCTIONS_SUFFIX),
|
AIMessage(content=suffix or SQL_FUNCTIONS_SUFFIX),
|
||||||
MessagesPlaceholder(variable_name="agent_scratchpad"),
|
MessagesPlaceholder(variable_name="agent_scratchpad"),
|
||||||
]
|
]
|
||||||
prompt = ChatPromptTemplate.from_messages(messages) # type: ignore[arg-type]
|
prompt = ChatPromptTemplate.from_messages(messages)
|
||||||
agent = RunnableAgent(
|
agent = RunnableAgent(
|
||||||
runnable=create_openai_functions_agent(llm, tools, prompt),
|
runnable=create_openai_functions_agent(llm, tools, prompt),
|
||||||
input_keys_arg=["input"],
|
input_keys_arg=["input"],
|
||||||
@ -191,12 +201,12 @@ def create_sql_agent(
|
|||||||
elif agent_type == "openai-tools":
|
elif agent_type == "openai-tools":
|
||||||
if prompt is None:
|
if prompt is None:
|
||||||
messages = [
|
messages = [
|
||||||
SystemMessage(content=prefix), # type: ignore[arg-type]
|
SystemMessage(content=cast(str, prefix)),
|
||||||
HumanMessagePromptTemplate.from_template("{input}"),
|
HumanMessagePromptTemplate.from_template("{input}"),
|
||||||
AIMessage(content=suffix or SQL_FUNCTIONS_SUFFIX),
|
AIMessage(content=suffix or SQL_FUNCTIONS_SUFFIX),
|
||||||
MessagesPlaceholder(variable_name="agent_scratchpad"),
|
MessagesPlaceholder(variable_name="agent_scratchpad"),
|
||||||
]
|
]
|
||||||
prompt = ChatPromptTemplate.from_messages(messages) # type: ignore[arg-type]
|
prompt = ChatPromptTemplate.from_messages(messages)
|
||||||
agent = RunnableMultiActionAgent(
|
agent = RunnableMultiActionAgent(
|
||||||
runnable=create_openai_tools_agent(llm, tools, prompt),
|
runnable=create_openai_tools_agent(llm, tools, prompt),
|
||||||
input_keys_arg=["input"],
|
input_keys_arg=["input"],
|
||||||
|
Loading…
Reference in New Issue
Block a user