mirror of https://github.com/hwchase17/langchain.git
synced 2026-01-23 21:31:02 +00:00

Compare commits: vwp/bold_h ... vwp/inheri (13 commits)

| SHA1 |
|---|
| f04575757d |
| e236e64cd4 |
| f4b4d52e58 |
| 417ce72df2 |
| d7e9b72380 |
| 203e97d789 |
| f99348fb12 |
| 6bc9700863 |
| 4abeea3d42 |
| e79e003218 |
| 37fca4ae75 |
| c90ce64757 |
| 4cd6a2223a |
@@ -28,11 +28,24 @@ from langchain.schema import (
    BaseOutputParser,
)
from langchain.tools.base import BaseTool
from langchain.tools.structured import BaseStructuredTool
from langchain.utilities.asyncio import asyncio_timeout

logger = logging.getLogger(__name__)


def validate_all_instance_of_base_tool(
    class_name: str, tools: Sequence[BaseStructuredTool]
) -> None:
    """Validate that all tools are of type BaseTool."""
    for tool in tools:
        if not isinstance(tool, BaseTool):
            raise TypeError(
                f"Agent {class_name} only supports tools of type {BaseTool}."
                f" {tool.name} is of type {type(tool).__name__}"
            )


class BaseSingleActionAgent(BaseModel):
    """Base Agent class."""
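For orientation, here is a minimal sketch of how the new helper is meant to be used; `EchoTool` is a hypothetical tool, not part of this changeset:

```python
# Hypothetical example; assumes the reworked string-input BaseTool from this diff.
from langchain.tools.base import BaseTool


class EchoTool(BaseTool):
    name = "echo"
    description = "Echo the input back."

    def _run(self, tool_input: str) -> str:
        return tool_input

    async def _arun(self, tool_input: str) -> str:
        return tool_input


validate_all_instance_of_base_tool("ZeroShotAgent", [EchoTool()])  # passes
# A BaseStructuredTool that is not also a BaseTool would raise TypeError here.
```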
@@ -103,7 +116,7 @@ class BaseSingleActionAgent(BaseModel):
    def from_llm_and_tools(
        cls,
        llm: BaseLanguageModel,
        tools: Sequence[BaseTool],
        tools: Sequence[BaseStructuredTool],
        callback_manager: Optional[BaseCallbackManager] = None,
        **kwargs: Any,
    ) -> BaseSingleActionAgent:
@@ -448,13 +461,12 @@ class Agent(BaseSingleActionAgent):

    @classmethod
    @abstractmethod
    def create_prompt(cls, tools: Sequence[BaseTool]) -> BasePromptTemplate:
    def create_prompt(cls, tools: Sequence[BaseStructuredTool]) -> BasePromptTemplate:
        """Create a prompt for this class."""

    @classmethod
    def _validate_tools(cls, tools: Sequence[BaseTool]) -> None:
    def _validate_tools(cls, tools: Sequence[BaseStructuredTool]) -> None:
        """Validate that appropriate tools are passed in."""
        pass

    @classmethod
    @abstractmethod
@@ -465,7 +477,7 @@ class Agent(BaseSingleActionAgent):
    def from_llm_and_tools(
        cls,
        llm: BaseLanguageModel,
        tools: Sequence[BaseTool],
        tools: Sequence[BaseStructuredTool],
        callback_manager: Optional[BaseCallbackManager] = None,
        output_parser: Optional[AgentOutputParser] = None,
        **kwargs: Any,
@@ -539,7 +551,7 @@ class AgentExecutor(Chain):
    """Consists of an agent using tools."""

    agent: Union[BaseSingleActionAgent, BaseMultiActionAgent]
    tools: Sequence[BaseTool]
    tools: Sequence[BaseStructuredTool]
    return_intermediate_steps: bool = False
    max_iterations: Optional[int] = 15
    max_execution_time: Optional[float] = None
@@ -549,7 +561,7 @@ class AgentExecutor(Chain):
    def from_agent_and_tools(
        cls,
        agent: Union[BaseSingleActionAgent, BaseMultiActionAgent],
        tools: Sequence[BaseTool],
        tools: Sequence[BaseStructuredTool],
        callback_manager: Optional[BaseCallbackManager] = None,
        **kwargs: Any,
    ) -> AgentExecutor:
@@ -617,7 +629,7 @@ class AgentExecutor(Chain):
        else:
            return self.agent.return_values

    def lookup_tool(self, name: str) -> BaseTool:
    def lookup_tool(self, name: str) -> BaseStructuredTool:
        """Lookup tool by name."""
        return {tool.name: tool for tool in self.tools}[name]

@@ -657,9 +669,51 @@ class AgentExecutor(Chain):
            final_output["intermediate_steps"] = intermediate_steps
        return final_output

    def _run_tool(
        self,
        tool: BaseStructuredTool,
        agent_action: AgentAction,
        color: str,
        **tool_run_kwargs: Any,
    ) -> Any:
        tool_input = agent_action.tool_input
        if isinstance(tool_input, str):
            if not isinstance(tool, BaseTool):
                return (
                    f"Error: tool {tool.name} could not be "
                    f"run with input {agent_action.tool_input}"
                )
        return tool.run(
            tool_input,
            verbose=self.verbose,
            color=color,
            **tool_run_kwargs,
        )

    async def _aruntool(
        self,
        tool: BaseStructuredTool,
        agent_action: AgentAction,
        color: str,
        **tool_run_kwargs: Any,
    ) -> Any:
        tool_input = agent_action.tool_input
        if isinstance(tool_input, str):
            if not isinstance(tool, BaseTool):
                return (
                    f"Error: tool {tool.name} could not be "
                    f"run with input {agent_action.tool_input}"
                )
        return await tool.arun(
            agent_action.tool_input,
            verbose=self.verbose,
            color=color,
            **tool_run_kwargs,
        )

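Roughly, the new `_run_tool`/`_aruntool` helpers centralize one rule: a plain string `tool_input` can only be dispatched to a `BaseTool`; any other structured tool gets an error string back as the observation instead of an exception, so the agent loop can recover. A hedged sketch of what that looks like (the tool name is invented):

```python
# Illustration only; "search" stands in for a hypothetical structured-only tool.
from langchain.schema import AgentAction

action = AgentAction(tool="search", tool_input="langchain", log="")
# executor._run_tool(structured_search_tool, action, color="green") would return
#   "Error: tool search could not be run with input langchain"
# and that string becomes the observation fed back to the agent.
```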
    def _take_next_step(
        self,
        name_to_tool_map: Dict[str, BaseTool],
        name_to_tool_map: Dict[str, BaseStructuredTool],
        color_mapping: Dict[str, str],
        inputs: Dict[str, str],
        intermediate_steps: List[Tuple[AgentAction, str]],
@@ -692,11 +746,8 @@ class AgentExecutor(Chain):
            if return_direct:
                tool_run_kwargs["llm_prefix"] = ""
            # We then call the tool on the tool input to get an observation
            observation = tool.run(
                agent_action.tool_input,
                verbose=self.verbose,
                color=color,
                **tool_run_kwargs,
            observation = self._run_tool(
                tool, agent_action, color, run_async=True, **tool_run_kwargs
            )
        else:
            tool_run_kwargs = self.agent.tool_run_logging_kwargs()
@@ -711,7 +762,7 @@ class AgentExecutor(Chain):

    async def _atake_next_step(
        self,
        name_to_tool_map: Dict[str, BaseTool],
        name_to_tool_map: Dict[str, BaseStructuredTool],
        color_mapping: Dict[str, str],
        inputs: Dict[str, str],
        intermediate_steps: List[Tuple[AgentAction, str]],
@@ -751,11 +802,8 @@ class AgentExecutor(Chain):
            if return_direct:
                tool_run_kwargs["llm_prefix"] = ""
            # We then call the tool on the tool input to get an observation
            observation = await tool.arun(
                agent_action.tool_input,
                verbose=self.verbose,
                color=color,
                **tool_run_kwargs,
            observation = await self._aruntool(
                tool, agent_action, color, run_async=True, **tool_run_kwargs
            )
        else:
            tool_run_kwargs = self.agent.tool_run_logging_kwargs()

@@ -2,7 +2,11 @@ from typing import Any, List, Optional, Sequence, Tuple

from pydantic import Field

from langchain.agents.agent import Agent, AgentOutputParser
from langchain.agents.agent import (
    Agent,
    AgentOutputParser,
    validate_all_instance_of_base_tool,
)
from langchain.agents.chat.output_parser import ChatOutputParser
from langchain.agents.chat.prompt import FORMAT_INSTRUCTIONS, PREFIX, SUFFIX
from langchain.callbacks.base import BaseCallbackManager
@@ -14,7 +18,7 @@ from langchain.prompts.chat import (
    SystemMessagePromptTemplate,
)
from langchain.schema import AgentAction, BaseLanguageModel
from langchain.tools import BaseTool
from langchain.tools.structured import BaseStructuredTool


class ChatAgent(Agent):
@@ -56,7 +60,7 @@ class ChatAgent(Agent):
    @classmethod
    def create_prompt(
        cls,
        tools: Sequence[BaseTool],
        tools: Sequence[BaseStructuredTool],
        prefix: str = PREFIX,
        suffix: str = SUFFIX,
        format_instructions: str = FORMAT_INSTRUCTIONS,
@@ -74,11 +78,16 @@ class ChatAgent(Agent):
        input_variables = ["input", "agent_scratchpad"]
        return ChatPromptTemplate(input_variables=input_variables, messages=messages)

    @classmethod
    def _validate_tools(cls, tools: Sequence[BaseStructuredTool]) -> None:
        super()._validate_tools(tools)
        validate_all_instance_of_base_tool(cls.__name__, tools)

    @classmethod
    def from_llm_and_tools(
        cls,
        llm: BaseLanguageModel,
        tools: Sequence[BaseTool],
        tools: Sequence[BaseStructuredTool],
        callback_manager: Optional[BaseCallbackManager] = None,
        output_parser: Optional[AgentOutputParser] = None,
        prefix: str = PREFIX,

@@ -20,7 +20,10 @@ class ChatOutputParser(AgentOutputParser):
        try:
            action = text.split("```")[1]
            response = json.loads(action.strip())
            return AgentAction(response["action"], response["action_input"], text)
            action_input = response["action_input"]
            if isinstance(action_input, dict):
                action_input = json.dumps(action_input)
            return AgentAction(response["action"], action_input, text)

        except Exception:
            raise OutputParserException(f"Could not parse LLM output: {text}")
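The added branch means a dict-valued `action_input` is serialized back to JSON before it is stored on the `AgentAction`. A small sketch of that parsing path with made-up input:

```python
# Sketch of the behavior added above; the LLM text is invented.
import json

text = '```\n{"action": "search", "action_input": {"query": "langchain"}}\n```'
response = json.loads(text.split("```")[1].strip())
action_input = response["action_input"]
if isinstance(action_input, dict):
    action_input = json.dumps(action_input)
# action_input is now the string '{"query": "langchain"}'
```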
@@ -5,7 +5,11 @@ from typing import Any, List, Optional, Sequence
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from langchain.agents.agent import Agent, AgentOutputParser
|
||||
from langchain.agents.agent import (
|
||||
Agent,
|
||||
AgentOutputParser,
|
||||
validate_all_instance_of_base_tool,
|
||||
)
|
||||
from langchain.agents.agent_types import AgentType
|
||||
from langchain.agents.conversational.output_parser import ConvoOutputParser
|
||||
from langchain.agents.conversational.prompt import FORMAT_INSTRUCTIONS, PREFIX, SUFFIX
|
||||
@@ -13,7 +17,7 @@ from langchain.callbacks.base import BaseCallbackManager
|
||||
from langchain.chains import LLMChain
|
||||
from langchain.prompts import PromptTemplate
|
||||
from langchain.schema import BaseLanguageModel
|
||||
from langchain.tools.base import BaseTool
|
||||
from langchain.tools.structured import BaseStructuredTool
|
||||
|
||||
|
||||
class ConversationalAgent(Agent):
|
||||
@@ -46,7 +50,7 @@ class ConversationalAgent(Agent):
|
||||
@classmethod
|
||||
def create_prompt(
|
||||
cls,
|
||||
tools: Sequence[BaseTool],
|
||||
tools: Sequence[BaseStructuredTool],
|
||||
prefix: str = PREFIX,
|
||||
suffix: str = SUFFIX,
|
||||
format_instructions: str = FORMAT_INSTRUCTIONS,
|
||||
@@ -80,11 +84,16 @@ class ConversationalAgent(Agent):
|
||||
input_variables = ["input", "chat_history", "agent_scratchpad"]
|
||||
return PromptTemplate(template=template, input_variables=input_variables)
|
||||
|
||||
@classmethod
|
||||
def _validate_tools(cls, tools: Sequence[BaseStructuredTool]) -> None:
|
||||
super()._validate_tools(tools)
|
||||
validate_all_instance_of_base_tool(cls.__name__, tools)
|
||||
|
||||
@classmethod
|
||||
def from_llm_and_tools(
|
||||
cls,
|
||||
llm: BaseLanguageModel,
|
||||
tools: Sequence[BaseTool],
|
||||
tools: Sequence[BaseStructuredTool],
|
||||
callback_manager: Optional[BaseCallbackManager] = None,
|
||||
output_parser: Optional[AgentOutputParser] = None,
|
||||
prefix: str = PREFIX,
|
||||
|
||||
@@ -5,7 +5,11 @@ from typing import Any, List, Optional, Sequence, Tuple
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from langchain.agents.agent import Agent, AgentOutputParser
|
||||
from langchain.agents.agent import (
|
||||
Agent,
|
||||
AgentOutputParser,
|
||||
validate_all_instance_of_base_tool,
|
||||
)
|
||||
from langchain.agents.conversational_chat.output_parser import ConvoOutputParser
|
||||
from langchain.agents.conversational_chat.prompt import (
|
||||
PREFIX,
|
||||
@@ -29,7 +33,7 @@ from langchain.schema import (
|
||||
BaseOutputParser,
|
||||
HumanMessage,
|
||||
)
|
||||
from langchain.tools.base import BaseTool
|
||||
from langchain.tools.structured import BaseStructuredTool
|
||||
|
||||
|
||||
class ConversationalChatAgent(Agent):
|
||||
@@ -58,7 +62,7 @@ class ConversationalChatAgent(Agent):
|
||||
@classmethod
|
||||
def create_prompt(
|
||||
cls,
|
||||
tools: Sequence[BaseTool],
|
||||
tools: Sequence[BaseStructuredTool],
|
||||
system_message: str = PREFIX,
|
||||
human_message: str = SUFFIX,
|
||||
input_variables: Optional[List[str]] = None,
|
||||
@@ -98,11 +102,16 @@ class ConversationalChatAgent(Agent):
|
||||
thoughts.append(human_message)
|
||||
return thoughts
|
||||
|
||||
@classmethod
|
||||
def _validate_tools(cls, tools: Sequence[BaseStructuredTool]) -> None:
|
||||
super()._validate_tools(tools)
|
||||
validate_all_instance_of_base_tool(cls.__name__, tools)
|
||||
|
||||
@classmethod
|
||||
def from_llm_and_tools(
|
||||
cls,
|
||||
llm: BaseLanguageModel,
|
||||
tools: Sequence[BaseTool],
|
||||
tools: Sequence[BaseStructuredTool],
|
||||
callback_manager: Optional[BaseCallbackManager] = None,
|
||||
output_parser: Optional[AgentOutputParser] = None,
|
||||
system_message: str = PREFIX,
|
||||
|
||||
@@ -5,7 +5,12 @@ from typing import Any, Callable, List, NamedTuple, Optional, Sequence
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from langchain.agents.agent import Agent, AgentExecutor, AgentOutputParser
|
||||
from langchain.agents.agent import (
|
||||
Agent,
|
||||
AgentExecutor,
|
||||
AgentOutputParser,
|
||||
validate_all_instance_of_base_tool,
|
||||
)
|
||||
from langchain.agents.agent_types import AgentType
|
||||
from langchain.agents.mrkl.output_parser import MRKLOutputParser
|
||||
from langchain.agents.mrkl.prompt import FORMAT_INSTRUCTIONS, PREFIX, SUFFIX
|
||||
@@ -14,7 +19,7 @@ from langchain.callbacks.base import BaseCallbackManager
|
||||
from langchain.chains import LLMChain
|
||||
from langchain.prompts import PromptTemplate
|
||||
from langchain.schema import BaseLanguageModel
|
||||
from langchain.tools.base import BaseTool
|
||||
from langchain.tools.structured import BaseStructuredTool
|
||||
|
||||
|
||||
class ChainConfig(NamedTuple):
|
||||
@@ -58,7 +63,7 @@ class ZeroShotAgent(Agent):
|
||||
@classmethod
|
||||
def create_prompt(
|
||||
cls,
|
||||
tools: Sequence[BaseTool],
|
||||
tools: Sequence[BaseStructuredTool],
|
||||
prefix: str = PREFIX,
|
||||
suffix: str = SUFFIX,
|
||||
format_instructions: str = FORMAT_INSTRUCTIONS,
|
||||
@@ -88,7 +93,7 @@ class ZeroShotAgent(Agent):
|
||||
def from_llm_and_tools(
|
||||
cls,
|
||||
llm: BaseLanguageModel,
|
||||
tools: Sequence[BaseTool],
|
||||
tools: Sequence[BaseStructuredTool],
|
||||
callback_manager: Optional[BaseCallbackManager] = None,
|
||||
output_parser: Optional[AgentOutputParser] = None,
|
||||
prefix: str = PREFIX,
|
||||
@@ -121,13 +126,15 @@ class ZeroShotAgent(Agent):
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _validate_tools(cls, tools: Sequence[BaseTool]) -> None:
|
||||
def _validate_tools(cls, tools: Sequence[BaseStructuredTool]) -> None:
|
||||
super()._validate_tools(tools)
|
||||
for tool in tools:
|
||||
if tool.description is None:
|
||||
raise ValueError(
|
||||
f"Got a tool {tool.name} without a description. For this agent, "
|
||||
f"a description must always be provided."
|
||||
)
|
||||
validate_all_instance_of_base_tool(cls.__name__, tools)
|
||||
|
||||
|
||||
class MRKLChain(AgentExecutor):
|
||||
|
||||
@@ -3,7 +3,12 @@ from typing import Any, List, Optional, Sequence
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from langchain.agents.agent import Agent, AgentExecutor, AgentOutputParser
|
||||
from langchain.agents.agent import (
|
||||
Agent,
|
||||
AgentExecutor,
|
||||
AgentOutputParser,
|
||||
validate_all_instance_of_base_tool,
|
||||
)
|
||||
from langchain.agents.agent_types import AgentType
|
||||
from langchain.agents.react.output_parser import ReActOutputParser
|
||||
from langchain.agents.react.textworld_prompt import TEXTWORLD_PROMPT
|
||||
@@ -13,7 +18,7 @@ from langchain.docstore.base import Docstore
|
||||
from langchain.docstore.document import Document
|
||||
from langchain.llms.base import BaseLLM
|
||||
from langchain.prompts.base import BasePromptTemplate
|
||||
from langchain.tools.base import BaseTool
|
||||
from langchain.tools.structured import BaseStructuredTool
|
||||
|
||||
|
||||
class ReActDocstoreAgent(Agent):
|
||||
@@ -31,12 +36,12 @@ class ReActDocstoreAgent(Agent):
|
||||
return AgentType.REACT_DOCSTORE
|
||||
|
||||
@classmethod
|
||||
def create_prompt(cls, tools: Sequence[BaseTool]) -> BasePromptTemplate:
|
||||
def create_prompt(cls, tools: Sequence[BaseStructuredTool]) -> BasePromptTemplate:
|
||||
"""Return default prompt."""
|
||||
return WIKI_PROMPT
|
||||
|
||||
@classmethod
|
||||
def _validate_tools(cls, tools: Sequence[BaseTool]) -> None:
|
||||
def _validate_tools(cls, tools: Sequence[BaseStructuredTool]) -> None:
|
||||
if len(tools) != 2:
|
||||
raise ValueError(f"Exactly two tools must be specified, but got {tools}")
|
||||
tool_names = {tool.name for tool in tools}
|
||||
@@ -44,6 +49,7 @@ class ReActDocstoreAgent(Agent):
|
||||
raise ValueError(
|
||||
f"Tool names should be Lookup and Search, got {tool_names}"
|
||||
)
|
||||
validate_all_instance_of_base_tool(cls.__name__, tools)
|
||||
|
||||
@property
|
||||
def observation_prefix(self) -> str:
|
||||
@@ -113,17 +119,18 @@ class ReActTextWorldAgent(ReActDocstoreAgent):
|
||||
"""Agent for the ReAct TextWorld chain."""
|
||||
|
||||
@classmethod
|
||||
def create_prompt(cls, tools: Sequence[BaseTool]) -> BasePromptTemplate:
|
||||
def create_prompt(cls, tools: Sequence[BaseStructuredTool]) -> BasePromptTemplate:
|
||||
"""Return default prompt."""
|
||||
return TEXTWORLD_PROMPT
|
||||
|
||||
@classmethod
|
||||
def _validate_tools(cls, tools: Sequence[BaseTool]) -> None:
|
||||
def _validate_tools(cls, tools: Sequence[BaseStructuredTool]) -> None:
|
||||
if len(tools) != 1:
|
||||
raise ValueError(f"Exactly one tool must be specified, but got {tools}")
|
||||
tool_names = {tool.name for tool in tools}
|
||||
if tool_names != {"Play"}:
|
||||
raise ValueError(f"Tool name should be Play, got {tool_names}")
|
||||
validate_all_instance_of_base_tool(cls.__name__, tools)
|
||||
|
||||
|
||||
class ReActChain(AgentExecutor):
|
||||
|
||||
@@ -10,7 +10,7 @@ from langchain.agents.self_ask_with_search.prompt import PROMPT
|
||||
from langchain.agents.tools import Tool
|
||||
from langchain.llms.base import BaseLLM
|
||||
from langchain.prompts.base import BasePromptTemplate
|
||||
from langchain.tools.base import BaseTool
|
||||
from langchain.tools.structured import BaseStructuredTool
|
||||
from langchain.utilities.google_serper import GoogleSerperAPIWrapper
|
||||
from langchain.utilities.serpapi import SerpAPIWrapper
|
||||
|
||||
@@ -30,12 +30,12 @@ class SelfAskWithSearchAgent(Agent):
|
||||
return AgentType.SELF_ASK_WITH_SEARCH
|
||||
|
||||
@classmethod
|
||||
def create_prompt(cls, tools: Sequence[BaseTool]) -> BasePromptTemplate:
|
||||
def create_prompt(cls, tools: Sequence[BaseStructuredTool]) -> BasePromptTemplate:
|
||||
"""Prompt does not depend on tools."""
|
||||
return PROMPT
|
||||
|
||||
@classmethod
|
||||
def _validate_tools(cls, tools: Sequence[BaseTool]) -> None:
|
||||
def _validate_tools(cls, tools: Sequence[BaseStructuredTool]) -> None:
|
||||
if len(tools) != 1:
|
||||
raise ValueError(f"Exactly one tool must be specified, but got {tools}")
|
||||
tool_names = {tool.name for tool in tools}
|
||||
|
||||
@@ -3,12 +3,11 @@ from functools import partial
|
||||
from inspect import signature
|
||||
from typing import Any, Awaitable, Callable, Optional, Type, Union
|
||||
|
||||
from pydantic import BaseModel, validate_arguments, validator
|
||||
from pydantic import validator
|
||||
|
||||
from langchain.tools.base import (
|
||||
BaseTool,
|
||||
create_schema_from_function,
|
||||
get_filtered_args,
|
||||
StringSchema,
|
||||
)
|
||||
|
||||
|
||||
@@ -28,22 +27,14 @@ class Tool(BaseTool):
|
||||
raise ValueError("Partial functions not yet supported in tools.")
|
||||
return func
|
||||
|
||||
@property
|
||||
def args(self) -> dict:
|
||||
if self.args_schema is not None:
|
||||
return self.args_schema.schema()["properties"]
|
||||
else:
|
||||
inferred_model = validate_arguments(self.func).model # type: ignore
|
||||
return get_filtered_args(inferred_model, self.func)
|
||||
|
||||
def _run(self, *args: Any, **kwargs: Any) -> str:
|
||||
def _run(self, tool_input: str) -> str:
|
||||
"""Use the tool."""
|
||||
return self.func(*args, **kwargs)
|
||||
return self.func(tool_input)
|
||||
|
||||
async def _arun(self, *args: Any, **kwargs: Any) -> str:
|
||||
async def _arun(self, tool_input: str) -> str:
|
||||
"""Use the tool asynchronously."""
|
||||
if self.coroutine:
|
||||
return await self.coroutine(*args, **kwargs)
|
||||
return await self.coroutine(tool_input)
|
||||
raise NotImplementedError("Tool does not support async")
|
||||
|
||||
# TODO: this is for backwards compatibility, remove in future
|
||||
@@ -74,8 +65,7 @@ class InvalidTool(BaseTool):
|
||||
def tool(
|
||||
*args: Union[str, Callable],
|
||||
return_direct: bool = False,
|
||||
args_schema: Optional[Type[BaseModel]] = None,
|
||||
infer_schema: bool = True,
|
||||
args_schema: Optional[Type[StringSchema]] = None,
|
||||
) -> Callable:
|
||||
"""Make tools out of functions, can be used with or without arguments.
|
||||
|
||||
@@ -83,10 +73,7 @@ def tool(
|
||||
*args: The arguments to the tool.
|
||||
return_direct: Whether to return directly from the tool rather
|
||||
than continuing the agent loop.
|
||||
args_schema: optional argument schema for user to specify
|
||||
infer_schema: Whether to infer the schema of the arguments from
|
||||
the function's signature. This also makes the resultant tool
|
||||
accept a dictionary input to its `run()` function.
|
||||
args_schema: The schema for the arguments used to validate input.
|
||||
|
||||
Requires:
|
||||
- Function must be of type (str) -> str
|
||||
@@ -112,15 +99,13 @@ def tool(
|
||||
# Description example:
|
||||
# search_api(query: str) - Searches the API for the query.
|
||||
description = f"{tool_name}{signature(func)} - {func.__doc__.strip()}"
|
||||
_args_schema = args_schema
|
||||
if _args_schema is None and infer_schema:
|
||||
_args_schema = create_schema_from_function(f"{tool_name}Schema", func)
|
||||
tool_kwargs = {} if args_schema is None else {"args_schema": args_schema}
|
||||
tool_ = Tool(
|
||||
name=tool_name,
|
||||
func=func,
|
||||
args_schema=_args_schema,
|
||||
description=description,
|
||||
return_direct=return_direct,
|
||||
**tool_kwargs,
|
||||
)
|
||||
return tool_
|
||||
|
||||
|
||||
@@ -2,17 +2,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import (
|
||||
Any,
|
||||
Dict,
|
||||
Generic,
|
||||
List,
|
||||
NamedTuple,
|
||||
Optional,
|
||||
Sequence,
|
||||
TypeVar,
|
||||
Union,
|
||||
)
|
||||
from typing import Any, Dict, Generic, List, NamedTuple, Optional, Sequence, TypeVar
|
||||
|
||||
from pydantic import BaseModel, Extra, Field, root_validator
|
||||
|
||||
@@ -41,7 +31,7 @@ class AgentAction(NamedTuple):
|
||||
"""Agent's action to take."""
|
||||
|
||||
tool: str
|
||||
tool_input: Union[str, dict]
|
||||
tool_input: str
|
||||
log: str
|
||||
|
||||
|
||||
@@ -401,8 +391,6 @@ class OutputParserException(Exception):
|
||||
errors will be raised.
|
||||
"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class BaseDocumentTransformer(ABC):
|
||||
"""Base interface for transforming documents."""
|
||||
|
||||
@@ -2,241 +2,68 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from inspect import signature
|
||||
from typing import Any, Callable, Dict, Optional, Sequence, Tuple, Type, Union
|
||||
from typing import Any, Dict, Sequence, Tuple, Type, Union
|
||||
|
||||
from pydantic import (
|
||||
BaseModel,
|
||||
Extra,
|
||||
Field,
|
||||
create_model,
|
||||
validate_arguments,
|
||||
validator,
|
||||
)
|
||||
from pydantic.main import ModelMetaclass
|
||||
from pydantic import BaseModel
|
||||
|
||||
from langchain.callbacks import get_callback_manager
|
||||
from langchain.callbacks.base import BaseCallbackManager
|
||||
from langchain.tools.structured import BaseStructuredTool
|
||||
|
||||
|
||||
def _to_args_and_kwargs(run_input: Union[str, Dict]) -> Tuple[Sequence, dict]:
|
||||
# For backwards compatability, if run_input is a string,
|
||||
# pass as a positional argument.
|
||||
if isinstance(run_input, str):
|
||||
return (run_input,), {}
|
||||
else:
|
||||
return [], run_input
|
||||
class StringSchema(BaseModel):
|
||||
"""Schema for a tool with string input."""
|
||||
|
||||
# Child tools can add additional validation by
|
||||
# subclassing this schema.
|
||||
tool_input: str
|
||||
|
||||
|
||||
class SchemaAnnotationError(TypeError):
|
||||
"""Raised when 'args_schema' is missing or has an incorrect type annotation."""
|
||||
|
||||
|
||||
class ToolMetaclass(ModelMetaclass):
|
||||
"""Metaclass for BaseTool to ensure the provided args_schema
|
||||
|
||||
is not silently ignored."""
|
||||
|
||||
def __new__(
|
||||
cls: Type[ToolMetaclass], name: str, bases: Tuple[Type, ...], dct: dict
|
||||
) -> ToolMetaclass:
|
||||
"""Create the definition of the new tool class."""
|
||||
schema_type: Optional[Type[BaseModel]] = dct.get("args_schema")
|
||||
if schema_type is not None:
|
||||
schema_annotations = dct.get("__annotations__", {})
|
||||
args_schema_type = schema_annotations.get("args_schema", None)
|
||||
if args_schema_type is None or args_schema_type == BaseModel:
|
||||
# Throw errors for common mis-annotations.
|
||||
# TODO: Use get_args / get_origin and fully
|
||||
# specify valid annotations.
|
||||
typehint_mandate = """
|
||||
class ChildTool(BaseTool):
|
||||
...
|
||||
args_schema: Type[BaseModel] = SchemaClass
|
||||
..."""
|
||||
raise SchemaAnnotationError(
|
||||
f"Tool definition for {name} must include valid type annotations"
|
||||
f" for argument 'args_schema' to behave as expected.\n"
|
||||
f"Expected annotation of 'Type[BaseModel]'"
|
||||
f" but got '{args_schema_type}'.\n"
|
||||
f"Expected class looks like:\n"
|
||||
f"{typehint_mandate}"
|
||||
)
|
||||
# Pass through to Pydantic's metaclass
|
||||
return super().__new__(cls, name, bases, dct)
|
||||
|
||||
|
||||
def _create_subset_model(
|
||||
name: str, model: BaseModel, field_names: list
|
||||
) -> Type[BaseModel]:
|
||||
"""Create a pydantic model with only a subset of model's fields."""
|
||||
fields = {
|
||||
field_name: (
|
||||
model.__fields__[field_name].type_,
|
||||
model.__fields__[field_name].default,
|
||||
)
|
||||
for field_name in field_names
|
||||
if field_name in model.__fields__
|
||||
}
|
||||
return create_model(name, **fields) # type: ignore
|
||||
|
||||
|
||||
def get_filtered_args(inferred_model: Type[BaseModel], func: Callable) -> dict:
|
||||
"""Get the arguments from a function's signature."""
|
||||
schema = inferred_model.schema()["properties"]
|
||||
valid_keys = signature(func).parameters
|
||||
return {k: schema[k] for k in valid_keys}
|
||||
|
||||
|
||||
def create_schema_from_function(model_name: str, func: Callable) -> Type[BaseModel]:
|
||||
"""Create a pydantic schema from a function's signature."""
|
||||
inferred_model = validate_arguments(func).model # type: ignore
|
||||
# Pydantic adds placeholder virtual fields we need to strip
|
||||
filtered_args = get_filtered_args(inferred_model, func)
|
||||
return _create_subset_model(
|
||||
f"{model_name}Schema", inferred_model, list(filtered_args)
|
||||
)
|
||||
|
||||
|
||||
class BaseTool(ABC, BaseModel, metaclass=ToolMetaclass):
|
||||
class BaseTool(ABC, BaseStructuredTool[str]):
|
||||
"""Interface LangChain tools must implement."""
|
||||
|
||||
name: str
|
||||
description: str
|
||||
args_schema: Optional[Type[BaseModel]] = None
|
||||
"""Pydantic model class to validate and parse the tool's input arguments."""
|
||||
return_direct: bool = False
|
||||
verbose: bool = False
|
||||
callback_manager: BaseCallbackManager = Field(default_factory=get_callback_manager)
|
||||
args_schema: Type[StringSchema] = StringSchema # :meta private:
|
||||
|
||||
class Config:
|
||||
"""Configuration for this pydantic object."""
|
||||
def _wrap_input(self, tool_input: Union[str, Dict]) -> Dict:
|
||||
"""Wrap the tool's input into a pydantic model."""
|
||||
if isinstance(tool_input, dict):
|
||||
return tool_input
|
||||
return {"tool_input": tool_input}
|
||||
|
||||
extra = Extra.forbid
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
@property
|
||||
def args(self) -> dict:
|
||||
if self.args_schema is not None:
|
||||
return self.args_schema.schema()["properties"]
|
||||
else:
|
||||
inferred_model = validate_arguments(self._run).model # type: ignore
|
||||
return get_filtered_args(inferred_model, self._run)
|
||||
|
||||
def _parse_input(
|
||||
self,
|
||||
tool_input: Union[str, Dict],
|
||||
) -> None:
|
||||
"""Convert tool input to pydantic model."""
|
||||
input_args = self.args_schema
|
||||
if isinstance(tool_input, str):
|
||||
if input_args is not None:
|
||||
key_ = next(iter(input_args.__fields__.keys()))
|
||||
input_args.validate({key_: tool_input})
|
||||
else:
|
||||
if input_args is not None:
|
||||
input_args.validate(tool_input)
|
||||
|
||||
@validator("callback_manager", pre=True, always=True)
|
||||
def set_callback_manager(
|
||||
cls, callback_manager: Optional[BaseCallbackManager]
|
||||
) -> BaseCallbackManager:
|
||||
"""If callback manager is None, set it.
|
||||
|
||||
This allows users to pass in None as callback manager, which is a nice UX.
|
||||
"""
|
||||
return callback_manager or get_callback_manager()
|
||||
def _prepare_input(self, input_: dict) -> Tuple[Sequence, Dict]:
|
||||
"""Prepare the args and kwargs for the tool."""
|
||||
# We expect a single string input
|
||||
return tuple(input_.values()), {}
|
||||
|
||||
@abstractmethod
|
||||
def _run(self, *args: Any, **kwargs: Any) -> str:
|
||||
def _run(self, tool_input: str) -> str:
|
||||
"""Use the tool."""
|
||||
|
||||
@abstractmethod
|
||||
async def _arun(self, *args: Any, **kwargs: Any) -> str:
|
||||
async def _arun(self, tool_input: str) -> str:
|
||||
"""Use the tool asynchronously."""
|
||||
|
||||
def run(
|
||||
self,
|
||||
tool_input: Union[str, Dict],
|
||||
verbose: Optional[bool] = None,
|
||||
start_color: Optional[str] = "green",
|
||||
color: Optional[str] = "green",
|
||||
tool_input: Union[str, dict],
|
||||
verbose: bool | None = None,
|
||||
start_color: str | None = "green",
|
||||
color: str | None = "green",
|
||||
**kwargs: Any,
|
||||
) -> str:
|
||||
"""Run the tool."""
|
||||
self._parse_input(tool_input)
|
||||
if not self.verbose and verbose is not None:
|
||||
verbose_ = verbose
|
||||
else:
|
||||
verbose_ = self.verbose
|
||||
self.callback_manager.on_tool_start(
|
||||
{"name": self.name, "description": self.description},
|
||||
tool_input if isinstance(tool_input, str) else str(tool_input),
|
||||
verbose=verbose_,
|
||||
color=start_color,
|
||||
**kwargs,
|
||||
)
|
||||
try:
|
||||
tool_args, tool_kwargs = _to_args_and_kwargs(tool_input)
|
||||
observation = self._run(*tool_args, **tool_kwargs)
|
||||
except (Exception, KeyboardInterrupt) as e:
|
||||
self.callback_manager.on_tool_error(e, verbose=verbose_)
|
||||
raise e
|
||||
self.callback_manager.on_tool_end(
|
||||
observation, verbose=verbose_, color=color, name=self.name, **kwargs
|
||||
)
|
||||
return observation
|
||||
"""Use the tool."""
|
||||
wrapped_input = self._wrap_input(tool_input)
|
||||
return super().run(wrapped_input, verbose, start_color, color, **kwargs)
|
||||
|
||||
async def arun(
|
||||
self,
|
||||
tool_input: Union[str, Dict],
|
||||
verbose: Optional[bool] = None,
|
||||
start_color: Optional[str] = "green",
|
||||
color: Optional[str] = "green",
|
||||
tool_input: Union[str, dict],
|
||||
verbose: bool | None = None,
|
||||
start_color: str | None = "green",
|
||||
color: str | None = "green",
|
||||
**kwargs: Any,
|
||||
) -> str:
|
||||
"""Run the tool asynchronously."""
|
||||
self._parse_input(tool_input)
|
||||
if not self.verbose and verbose is not None:
|
||||
verbose_ = verbose
|
||||
else:
|
||||
verbose_ = self.verbose
|
||||
if self.callback_manager.is_async:
|
||||
await self.callback_manager.on_tool_start(
|
||||
{"name": self.name, "description": self.description},
|
||||
tool_input if isinstance(tool_input, str) else str(tool_input),
|
||||
verbose=verbose_,
|
||||
color=start_color,
|
||||
**kwargs,
|
||||
)
|
||||
else:
|
||||
self.callback_manager.on_tool_start(
|
||||
{"name": self.name, "description": self.description},
|
||||
tool_input if isinstance(tool_input, str) else str(tool_input),
|
||||
verbose=verbose_,
|
||||
color=start_color,
|
||||
**kwargs,
|
||||
)
|
||||
try:
|
||||
# We then call the tool on the tool input to get an observation
|
||||
args, kwargs = _to_args_and_kwargs(tool_input)
|
||||
observation = await self._arun(*args, **kwargs)
|
||||
except (Exception, KeyboardInterrupt) as e:
|
||||
if self.callback_manager.is_async:
|
||||
await self.callback_manager.on_tool_error(e, verbose=verbose_)
|
||||
else:
|
||||
self.callback_manager.on_tool_error(e, verbose=verbose_)
|
||||
raise e
|
||||
if self.callback_manager.is_async:
|
||||
await self.callback_manager.on_tool_end(
|
||||
observation, verbose=verbose_, color=color, name=self.name, **kwargs
|
||||
)
|
||||
else:
|
||||
self.callback_manager.on_tool_end(
|
||||
observation, verbose=verbose_, color=color, name=self.name, **kwargs
|
||||
)
|
||||
return observation
|
||||
"""Use the tool asynchronously."""
|
||||
wrapped_input = self._wrap_input(tool_input)
|
||||
return await super().arun(wrapped_input, verbose, start_color, color, **kwargs)
|
||||
|
||||
    def __call__(self, tool_input: str) -> str:
        """Make tool callable."""
    def __call__(self, tool_input: Union[dict, str]) -> str:
        return self.run(tool_input)

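Under the reworked `BaseTool`, a plain string input is wrapped into `{"tool_input": ...}` and validated against `StringSchema` before `_run` is called, so both calling conventions below should behave the same. A minimal sketch, assuming the API shown in this diff; the tool itself is hypothetical:

```python
from langchain.tools.base import BaseTool


class ShoutTool(BaseTool):
    name = "shout"
    description = "Upper-case the input."

    def _run(self, tool_input: str) -> str:
        return tool_input.upper()

    async def _arun(self, tool_input: str) -> str:
        return tool_input.upper()


shout = ShoutTool()
shout.run("hello")                  # wrapped into {"tool_input": "hello"}
shout.run({"tool_input": "hello"})  # equivalent dict form
```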
@@ -1,22 +1,22 @@
|
||||
from pathlib import Path
|
||||
from typing import Optional, Type
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import Field
|
||||
|
||||
from langchain.tools.base import BaseTool
|
||||
from langchain.tools.base import BaseTool, StringSchema
|
||||
from langchain.tools.file_management.utils import get_validated_relative_path
|
||||
|
||||
|
||||
class ReadFileInput(BaseModel):
|
||||
class ReadFileInput(StringSchema):
|
||||
"""Input for ReadFileTool."""
|
||||
|
||||
file_path: str = Field(..., description="name of file")
|
||||
tool_input: str = Field(..., description="name of file", alias="file_path")
|
||||
|
||||
|
||||
class ReadFileTool(BaseTool):
|
||||
name: str = "read_file"
|
||||
args_schema: Type[BaseModel] = ReadFileInput
|
||||
description: str = "Read file from disk"
|
||||
args_schema: Type[ReadFileInput] = ReadFileInput
|
||||
description: str = "(file_path: str) -> str, Read file from disk."
|
||||
root_dir: Optional[str] = None
|
||||
"""Directory to read file from.
|
||||
|
||||
@@ -35,6 +35,6 @@ class ReadFileTool(BaseTool):
|
||||
except Exception as e:
|
||||
return "Error: " + str(e)
|
||||
|
||||
async def _arun(self, tool_input: str) -> str:
|
||||
async def _arun(self, file_path: str) -> str:
|
||||
# TODO: Add aiofiles method
|
||||
raise NotImplementedError
|
||||
|
||||
@@ -3,8 +3,8 @@ from typing import Optional, Type
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from langchain.tools.base import BaseTool
|
||||
from langchain.tools.file_management.utils import get_validated_relative_path
|
||||
from langchain.tools.structured import BaseStructuredTool
|
||||
|
||||
|
||||
class WriteFileInput(BaseModel):
|
||||
@@ -14,7 +14,7 @@ class WriteFileInput(BaseModel):
|
||||
text: str = Field(..., description="text to write to file")
|
||||
|
||||
|
||||
class WriteFileTool(BaseTool):
|
||||
class WriteFileTool(BaseStructuredTool[str]):
|
||||
name: str = "write_file"
|
||||
args_schema: Type[BaseModel] = WriteFileInput
|
||||
description: str = "Write file to disk"
|
||||
|
||||
@@ -127,11 +127,11 @@ class ListPowerBITool(BaseTool):
|
||||
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
def _run(self, *args: Any, **kwargs: Any) -> str:
|
||||
def _run(self, tool_input: str = "") -> str:
|
||||
"""Get the names of the tables."""
|
||||
return ", ".join(self.powerbi.get_table_names())
|
||||
|
||||
async def _arun(self, *args: Any, **kwargs: Any) -> str:
|
||||
async def _arun(self, tool_input: str = "") -> str:
|
||||
"""Get the names of the tables."""
|
||||
return ", ".join(self.powerbi.get_table_names())
|
||||
|
||||
|
||||
langchain/tools/structured.py (new file, 332 lines)
@@ -0,0 +1,332 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from abc import abstractmethod
|
||||
from inspect import Parameter, signature
|
||||
from typing import (
|
||||
Any,
|
||||
Awaitable,
|
||||
Callable,
|
||||
Dict,
|
||||
Generic,
|
||||
Optional,
|
||||
Sequence,
|
||||
Tuple,
|
||||
Type,
|
||||
TypeVar,
|
||||
Union,
|
||||
)
|
||||
|
||||
from pydantic import BaseModel, Extra, Field, create_model, validate_arguments
|
||||
from pydantic.generics import GenericModel
|
||||
|
||||
from langchain.callbacks import get_callback_manager
|
||||
from langchain.callbacks.base import BaseCallbackManager
|
||||
from langchain.utilities.async_utils import async_or_sync_call
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
OUTPUT_T = TypeVar("OUTPUT_T")
|
||||
|
||||
|
||||
class BaseStructuredTool(
|
||||
GenericModel,
|
||||
Generic[OUTPUT_T],
|
||||
BaseModel,
|
||||
):
|
||||
"""Parent class for all structured tools."""
|
||||
|
||||
name: str
|
||||
description: str
|
||||
return_direct: bool = False
|
||||
verbose: bool = False
|
||||
callback_manager: BaseCallbackManager = Field(default_factory=get_callback_manager)
|
||||
args_schema: Type[BaseModel] # :meta private:
|
||||
|
||||
class Config:
|
||||
"""Configuration for this pydantic object."""
|
||||
|
||||
extra = Extra.forbid
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
@property
|
||||
def args(self) -> Dict:
|
||||
return self.args_schema.schema()["properties"]
|
||||
|
||||
def _get_verbosity(
|
||||
self,
|
||||
verbose: Optional[bool] = None,
|
||||
) -> bool:
|
||||
if not self.verbose and verbose is not None:
|
||||
verbose_ = verbose
|
||||
else:
|
||||
verbose_ = self.verbose
|
||||
return verbose_
|
||||
|
||||
def _prepare_input(self, input_: dict) -> Tuple[Sequence, Dict]:
|
||||
"""Prepare the args and kwargs for the tool."""
|
||||
return (), input_
|
||||
|
||||
@abstractmethod
|
||||
def _run(self, *args: Any, **kwargs: Any) -> OUTPUT_T:
|
||||
"""Use the tool."""
|
||||
|
||||
@abstractmethod
|
||||
async def _arun(self, *args: Any, **kwargs: Any) -> OUTPUT_T:
|
||||
"""Use the tool asynchronously."""
|
||||
|
||||
def _parse_input(self, tool_input: dict) -> dict:
|
||||
parsed_input = self.args_schema.parse_obj(tool_input)
|
||||
parsed_dict = parsed_input.dict()
|
||||
return {k: getattr(parsed_input, k) for k in parsed_dict.keys()}
|
||||
|
||||
def run(
|
||||
self,
|
||||
tool_input: dict,
|
||||
verbose: Optional[bool] = None,
|
||||
start_color: Optional[str] = "green",
|
||||
color: Optional[str] = "green",
|
||||
**kwargs: Any,
|
||||
) -> OUTPUT_T:
|
||||
"""Run the tool."""
|
||||
tool_input = self._parse_input(tool_input)
|
||||
verbose_ = self._get_verbosity(verbose)
|
||||
self.callback_manager.on_tool_start(
|
||||
{"name": self.name, "description": self.description},
|
||||
str(tool_input),
|
||||
verbose=verbose_,
|
||||
color=start_color,
|
||||
**kwargs,
|
||||
)
|
||||
try:
|
||||
args, kwargs = self._prepare_input(tool_input)
|
||||
observation = self._run(*args, **kwargs)
|
||||
except (Exception, KeyboardInterrupt) as e:
|
||||
self.callback_manager.on_tool_error(e, verbose=verbose_)
|
||||
raise e
|
||||
self.callback_manager.on_tool_end(
|
||||
str(observation), verbose=verbose_, color=color, name=self.name, **kwargs
|
||||
)
|
||||
return observation
|
||||
|
||||
async def arun(
|
||||
self,
|
||||
tool_input: dict,
|
||||
verbose: Optional[bool] = None,
|
||||
start_color: Optional[str] = "green",
|
||||
color: Optional[str] = "green",
|
||||
**kwargs: Any,
|
||||
) -> OUTPUT_T:
|
||||
"""Run the tool asynchronously."""
|
||||
tool_input = self._parse_input(tool_input)
|
||||
verbose_ = self._get_verbosity(verbose)
|
||||
await async_or_sync_call(
|
||||
self.callback_manager.on_tool_start,
|
||||
{"name": self.name, "description": self.description},
|
||||
str(tool_input),
|
||||
verbose=verbose_,
|
||||
color=start_color,
|
||||
is_async=self.callback_manager.is_async,
|
||||
**kwargs,
|
||||
)
|
||||
try:
|
||||
args, kwargs = self._prepare_input(tool_input)
|
||||
observation = await self._arun(*args, **kwargs)
|
||||
except (Exception, KeyboardInterrupt) as e:
|
||||
await async_or_sync_call(
|
||||
self.callback_manager.on_tool_error,
|
||||
e,
|
||||
verbose=verbose_,
|
||||
is_async=self.callback_manager.is_async,
|
||||
)
|
||||
raise e
|
||||
await async_or_sync_call(
|
||||
self.callback_manager.on_tool_end,
|
||||
str(observation),
|
||||
verbose=verbose_,
|
||||
color=color,
|
||||
is_async=self.callback_manager.is_async,
|
||||
**kwargs,
|
||||
)
|
||||
return observation
|
||||
|
||||
def __call__(self, tool_input: dict) -> OUTPUT_T:
|
||||
"""Make tool callable."""
|
||||
return self.run(tool_input)
|
||||
|
||||
|
||||
def _create_subset_model(
|
||||
name: str, model: BaseModel, field_names: list
|
||||
) -> Type[BaseModel]:
|
||||
"""Create a pydantic model with only a subset of model's fields."""
|
||||
fields = {
|
||||
field_name: (
|
||||
model.__fields__[field_name].type_,
|
||||
model.__fields__[field_name].default,
|
||||
)
|
||||
for field_name in field_names
|
||||
if field_name in model.__fields__
|
||||
}
|
||||
return create_model(name, **fields) # type: ignore
|
||||
|
||||
|
||||
def get_filtered_args(inferred_model: Type[BaseModel], func: Callable) -> dict:
|
||||
"""Get the arguments from a function's signature."""
|
||||
schema = inferred_model.schema()["properties"]
|
||||
valid_keys = signature(func).parameters
|
||||
return {k: schema[k] for k in valid_keys}
|
||||
|
||||
|
||||
def _warn_args_kwargs(func: Callable) -> None:
|
||||
# Check if the function has *args or **kwargs
|
||||
sig = signature(func)
|
||||
for param in sig.parameters.values():
|
||||
if param.kind == Parameter.VAR_POSITIONAL:
|
||||
logger.warning(f"{func.__name__} uses *args, which are not well supported.")
|
||||
elif param.kind == Parameter.VAR_KEYWORD:
|
||||
logger.warning(
|
||||
f"{func.__name__} uses **kwargs, which are not well supported."
|
||||
)
|
||||
|
||||
|
||||
def create_schema_from_function(model_name: str, func: Callable) -> Type[BaseModel]:
|
||||
"""Create a pydantic schema from a function's signature."""
|
||||
_warn_args_kwargs(func)
|
||||
inferred_model = validate_arguments(func).model # type: ignore
|
||||
# Pydantic adds placeholder virtual fields we need to strip
|
||||
filtered_args = get_filtered_args(inferred_model, func)
|
||||
return _create_subset_model(
|
||||
f"{model_name}Schema", inferred_model, list(filtered_args)
|
||||
)
|
||||
|
||||
|
||||
class StructuredTool(BaseStructuredTool[Any]):
|
||||
"""StructuredTool that takes in function or coroutine directly."""
|
||||
|
||||
func: Callable[..., Any]
|
||||
"""The function to run when the tool is called."""
|
||||
coroutine: Optional[Callable[..., Awaitable[Any]]] = None
|
||||
"""The asynchronous version of the function."""
|
||||
args_schema: Type[BaseModel] # :meta private:
|
||||
|
||||
@property
|
||||
def args(self) -> dict:
|
||||
"""The JSON Schema arguments for the tool."""
|
||||
return self.args_schema.schema()["properties"]
|
||||
|
||||
def _run(self, *args: Any, **kwargs: Any) -> Any:
|
||||
"""Use the tool."""
|
||||
return self.func(*args, **kwargs)
|
||||
|
||||
async def _arun(self, *args: Any, **kwargs: Any) -> Any:
|
||||
"""Use the tool asynchronously."""
|
||||
if self.coroutine:
|
||||
return await self.coroutine(*args, **kwargs)
|
||||
raise NotImplementedError(f"StructuredTool {self.name} does not support async")
|
||||
|
||||
@classmethod
|
||||
def from_function(
|
||||
cls,
|
||||
func: Callable[..., Any],
|
||||
coroutine: Optional[Callable[..., Awaitable[Any]]] = None,
|
||||
return_direct: bool = False,
|
||||
args_schema: Optional[Type[BaseModel]] = None,
|
||||
infer_schema: bool = True,
|
||||
name: Optional[str] = None,
|
||||
description: Optional[str] = None,
|
||||
) -> "StructuredTool":
|
||||
"""Make tools out of functions, can be used with or without arguments.
|
||||
|
||||
Args:
|
||||
func: The function to run when the tool is called.
|
||||
coroutine: The asynchronous version of the function.
|
||||
return_direct: Whether to return directly from the tool rather
|
||||
than continuing the agent loop.
|
||||
args_schema: optional argument schema for user to specify
|
||||
infer_schema: Whether to infer the schema of the arguments from
|
||||
the function's signature. This also makes the resultant tool
|
||||
accept a dictionary input to its `run()` function.
|
||||
name: The name of the tool. Defaults to the function name.
|
||||
description: The description of the tool. Defaults to the function
|
||||
docstring.
|
||||
"""
|
||||
description = func.__doc__ or description
|
||||
if description is None or not description.strip():
|
||||
raise ValueError(
|
||||
f"Function {func.__name__} must have a docstring, or set description."
|
||||
)
|
||||
name = name or func.__name__
|
||||
_args_schema = args_schema
|
||||
if _args_schema is None and infer_schema:
|
||||
_args_schema = create_schema_from_function(f"{name}Schema", func)
|
||||
description = f"{name}{signature(func)} - {description}"
|
||||
return cls(
|
||||
name=name,
|
||||
func=func,
|
||||
coroutine=coroutine,
|
||||
return_direct=return_direct,
|
||||
args_schema=_args_schema,
|
||||
description=description,
|
||||
)
|
||||
|
||||
|
||||
def structured_tool(
|
||||
*args: Union[str, Callable],
|
||||
return_direct: bool = False,
|
||||
args_schema: Optional[Type[BaseModel]] = None,
|
||||
) -> Callable:
|
||||
"""Make tools out of functions, can be used with or without arguments.
|
||||
|
||||
Args:
|
||||
*args: The arguments to the tool.
|
||||
return_direct: Whether to return directly from the tool rather
|
||||
than continuing the agent loop.
|
||||
args_schema: Optional argument schema for user to specify. If
|
||||
none, will infer the schema from the function's signature.
|
||||
|
||||
Requires:
|
||||
- Function must be of type (str) -> str
|
||||
- Function must have a docstring
|
||||
|
||||
Examples:
|
||||
.. code-block:: python
|
||||
|
||||
@tool
|
||||
def search_api(query: str) -> str:
|
||||
# Searches the API for the query.
|
||||
return
|
||||
|
||||
@tool("search", return_direct=True)
|
||||
def search_api(query: str) -> str:
|
||||
# Searches the API for the query.
|
||||
return
|
||||
"""
|
||||
|
||||
def _make_with_name(tool_name: str) -> Callable:
|
||||
def _make_tool(func: Callable) -> StructuredTool:
|
||||
return StructuredTool.from_function(
|
||||
name=tool_name,
|
||||
func=func,
|
||||
args_schema=args_schema,
|
||||
return_direct=return_direct,
|
||||
)
|
||||
|
||||
return _make_tool
|
||||
|
||||
if len(args) == 1 and isinstance(args[0], str):
|
||||
# if the argument is a string, then we use the string as the tool name
|
||||
# Example usage: @tool("search", return_direct=True)
|
||||
return _make_with_name(args[0])
|
||||
elif len(args) == 1 and callable(args[0]):
|
||||
# if the argument is a function, then we use the function name as the tool name
|
||||
# Example usage: @tool
|
||||
return _make_with_name(args[0].__name__)(args[0])
|
||||
    elif len(args) == 0:
        # if there are no arguments, then we use the function name as the tool name
        # Example usage: @tool(return_direct=True)
        def _partial(func: Callable[[str], str]) -> BaseStructuredTool:
            return _make_with_name(func.__name__)(func)

        return _partial
    else:
        raise ValueError("Too many arguments for tool decorator")
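A rough usage sketch of the new structured-tool API introduced in this file; the function and its arguments are invented for illustration:

```python
from langchain.tools.structured import StructuredTool


def lookup_order(order_id: int, include_items: bool = False) -> str:
    """Look up an order by id."""
    return f"order {order_id}, items={include_items}"


order_tool = StructuredTool.from_function(lookup_order)
order_tool.args  # schema inferred from the function signature
order_tool.run({"order_id": 42, "include_items": True})
# -> "order 42, items=True"
```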
langchain/utilities/async_utils.py (new file, 12 lines)
@@ -0,0 +1,12 @@
"""Async utilities."""
from typing import Any, Callable


async def async_or_sync_call(
    method: Callable, *args: Any, is_async: bool, **kwargs: Any
) -> Any:
    """Run the callback manager method asynchronously or synchronously."""
    if is_async:
        return await method(*args, **kwargs)
    else:
        return method(*args, **kwargs)
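For reference, a small sketch of how `async_or_sync_call` bridges sync and async callback managers; the logging functions are placeholders, not part of the diff:

```python
import asyncio

from langchain.utilities.async_utils import async_or_sync_call


def sync_log(msg: str) -> None:
    print(msg)


async def async_log(msg: str) -> None:
    print(msg)


async def demo() -> None:
    await async_or_sync_call(sync_log, "sync path", is_async=False)
    await async_or_sync_call(async_log, "async path", is_async=True)


asyncio.run(demo())
```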
@@ -1,9 +1,8 @@
|
||||
"""Util that calls OpenWeatherMap using PyOWM."""
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from pydantic import Extra, root_validator
|
||||
from pydantic import BaseModel, Extra, root_validator
|
||||
|
||||
from langchain.tools.base import BaseModel
|
||||
from langchain.utils import get_from_dict_or_env
|
||||
|
||||
|
||||
|
||||
@@ -1,14 +1,11 @@
|
||||
"""Test tool utils."""
|
||||
from datetime import datetime
|
||||
from functools import partial
|
||||
from typing import Optional, Type, Union
|
||||
|
||||
import pydantic
|
||||
from typing import Optional
|
||||
|
||||
import pytest
|
||||
from pydantic import BaseModel
|
||||
|
||||
from langchain.agents.tools import Tool, tool
|
||||
from langchain.tools.base import BaseTool, SchemaAnnotationError
|
||||
from langchain.tools.base import BaseTool, StringSchema
|
||||
|
||||
|
||||
def test_unnamed_decorator() -> None:
|
||||
@@ -23,169 +20,7 @@ def test_unnamed_decorator() -> None:
|
||||
assert search_api.name == "search_api"
|
||||
assert not search_api.return_direct
|
||||
assert search_api("test") == "API result"
|
||||
|
||||
|
||||
class _MockSchema(BaseModel):
|
||||
arg1: int
|
||||
arg2: bool
|
||||
arg3: Optional[dict] = None
|
||||
|
||||
|
||||
class _MockStructuredTool(BaseTool):
|
||||
name = "structured_api"
|
||||
args_schema: Type[BaseModel] = _MockSchema
|
||||
description = "A Structured Tool"
|
||||
|
||||
def _run(self, arg1: int, arg2: bool, arg3: Optional[dict] = None) -> str:
|
||||
return f"{arg1} {arg2} {arg3}"
|
||||
|
||||
async def _arun(self, arg1: int, arg2: bool, arg3: Optional[dict] = None) -> str:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
def test_structured_args() -> None:
|
||||
"""Test functionality with structured arguments."""
|
||||
structured_api = _MockStructuredTool()
|
||||
assert isinstance(structured_api, BaseTool)
|
||||
assert structured_api.name == "structured_api"
|
||||
expected_result = "1 True {'foo': 'bar'}"
|
||||
args = {"arg1": 1, "arg2": True, "arg3": {"foo": "bar"}}
|
||||
assert structured_api.run(args) == expected_result
|
||||
|
||||
|
||||
def test_unannotated_base_tool_raises_error() -> None:
|
||||
"""Test that a BaseTool without type hints raises an exception.""" ""
|
||||
with pytest.raises(SchemaAnnotationError):
|
||||
|
||||
class _UnAnnotatedTool(BaseTool):
|
||||
name = "structured_api"
|
||||
# This would silently be ignored without the custom metaclass
|
||||
args_schema = _MockSchema
|
||||
description = "A Structured Tool"
|
||||
|
||||
def _run(self, arg1: int, arg2: bool, arg3: Optional[dict] = None) -> str:
|
||||
return f"{arg1} {arg2} {arg3}"
|
||||
|
||||
async def _arun(
|
||||
self, arg1: int, arg2: bool, arg3: Optional[dict] = None
|
||||
) -> str:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
def test_misannotated_base_tool_raises_error() -> None:
|
||||
"""Test that a BaseTool with the incorrrect typehint raises an exception.""" ""
|
||||
with pytest.raises(SchemaAnnotationError):
|
||||
|
||||
class _MisAnnotatedTool(BaseTool):
|
||||
name = "structured_api"
|
||||
# This would silently be ignored without the custom metaclass
|
||||
args_schema: BaseModel = _MockSchema # type: ignore
|
||||
description = "A Structured Tool"
|
||||
|
||||
def _run(self, arg1: int, arg2: bool, arg3: Optional[dict] = None) -> str:
|
||||
return f"{arg1} {arg2} {arg3}"
|
||||
|
||||
async def _arun(
|
||||
self, arg1: int, arg2: bool, arg3: Optional[dict] = None
|
||||
) -> str:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
def test_forward_ref_annotated_base_tool_accepted() -> None:
|
||||
"""Test that a using forward ref annotation syntax is accepted.""" ""
|
||||
|
||||
class _ForwardRefAnnotatedTool(BaseTool):
|
||||
name = "structured_api"
|
||||
args_schema: "Type[BaseModel]" = _MockSchema
|
||||
description = "A Structured Tool"
|
||||
|
||||
def _run(self, arg1: int, arg2: bool, arg3: Optional[dict] = None) -> str:
|
||||
return f"{arg1} {arg2} {arg3}"
|
||||
|
||||
async def _arun(
|
||||
self, arg1: int, arg2: bool, arg3: Optional[dict] = None
|
||||
) -> str:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
def test_subclass_annotated_base_tool_accepted() -> None:
|
||||
"""Test BaseTool child w/ custom schema isn't overwritten."""
|
||||
|
||||
class _ForwardRefAnnotatedTool(BaseTool):
|
||||
name = "structured_api"
|
||||
args_schema: Type[_MockSchema] = _MockSchema
|
||||
description = "A Structured Tool"
|
||||
|
||||
def _run(self, arg1: int, arg2: bool, arg3: Optional[dict] = None) -> str:
|
||||
return f"{arg1} {arg2} {arg3}"
|
||||
|
||||
async def _arun(
|
||||
self, arg1: int, arg2: bool, arg3: Optional[dict] = None
|
||||
) -> str:
|
||||
raise NotImplementedError
|
||||
|
||||
assert issubclass(_ForwardRefAnnotatedTool, BaseTool)
|
||||
tool = _ForwardRefAnnotatedTool()
|
||||
assert tool.args_schema == _MockSchema
|
||||
|
||||
|
||||
def test_decorator_with_specified_schema() -> None:
|
||||
"""Test that manually specified schemata are passed through to the tool."""
|
||||
|
||||
@tool(args_schema=_MockSchema)
|
||||
def tool_func(arg1: int, arg2: bool, arg3: Optional[dict] = None) -> str:
|
||||
"""Return the arguments directly."""
|
||||
return f"{arg1} {arg2} {arg3}"
|
||||
|
||||
assert isinstance(tool_func, Tool)
|
||||
assert tool_func.args_schema == _MockSchema
|
||||
|
||||
|
||||
def test_decorated_function_schema_equivalent() -> None:
|
||||
"""Test that a BaseTool without a schema meets expectations."""
|
||||
|
||||
@tool
|
||||
def structured_tool_input(
|
||||
arg1: int, arg2: bool, arg3: Optional[dict] = None
|
||||
) -> str:
|
||||
"""Return the arguments directly."""
|
||||
return f"{arg1} {arg2} {arg3}"
|
||||
|
||||
assert isinstance(structured_tool_input, Tool)
|
||||
assert (
|
||||
structured_tool_input.args_schema.schema()["properties"]
|
||||
== _MockSchema.schema()["properties"]
|
||||
== structured_tool_input.args
|
||||
)
|
||||
|
||||
|
||||
def test_structured_args_decorator_no_infer_schema() -> None:
|
||||
"""Test functionality with structured arguments parsed as a decorator."""
|
||||
|
||||
@tool(infer_schema=False)
|
||||
def structured_tool_input(
|
||||
arg1: int, arg2: Union[float, datetime], opt_arg: Optional[dict] = None
|
||||
) -> str:
|
||||
"""Return the arguments directly."""
|
||||
return f"{arg1}, {arg2}, {opt_arg}"
|
||||
|
||||
assert isinstance(structured_tool_input, Tool)
|
||||
assert structured_tool_input.name == "structured_tool_input"
|
||||
args = {"arg1": 1, "arg2": 0.001, "opt_arg": {"foo": "bar"}}
|
||||
expected_result = "1, 0.001, {'foo': 'bar'}"
|
||||
assert structured_tool_input.run(args) == expected_result
|
||||
|
||||
|
||||
def test_structured_single_str_decorator_no_infer_schema() -> None:
|
||||
"""Test functionality with structured arguments parsed as a decorator."""
|
||||
|
||||
@tool(infer_schema=False)
|
||||
def unstructured_tool_input(tool_input: str) -> str:
|
||||
"""Return the arguments directly."""
|
||||
return f"{tool_input}"
|
||||
|
||||
assert isinstance(unstructured_tool_input, Tool)
|
||||
assert unstructured_tool_input.args_schema is None
|
||||
assert unstructured_tool_input.args_schema == StringSchema
|
||||
|
||||
|
||||
def test_base_tool_inheritance_base_schema() -> None:
|
||||
@@ -202,7 +37,7 @@ def test_base_tool_inheritance_base_schema() -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
simple_tool = _MockSimpleTool()
|
||||
assert simple_tool.args_schema is None
|
||||
assert simple_tool.args_schema == StringSchema
|
||||
expected_args = {"tool_input": {"title": "Tool Input", "type": "string"}}
|
||||
assert simple_tool.args == expected_args
|
||||
|
||||
@@ -215,56 +50,11 @@ def test_tool_lambda_args_schema() -> None:
|
||||
description="A tool",
|
||||
func=lambda tool_input: tool_input,
|
||||
)
|
||||
assert tool.args_schema is None
|
||||
expected_args = {"tool_input": {"title": "Tool Input"}}
|
||||
assert tool.args_schema == StringSchema
|
||||
expected_args = {"tool_input": {"title": "Tool Input", "type": "string"}}
|
||||
assert tool.args == expected_args
|
||||
|
||||
|
||||
def test_tool_lambda_multi_args_schema() -> None:
|
||||
"""Test args schema inference when the tool argument is a lambda function."""
|
||||
tool = Tool(
|
||||
name="tool",
|
||||
description="A tool",
|
||||
func=lambda tool_input, other_arg: f"{tool_input}{other_arg}", # type: ignore
|
||||
)
|
||||
assert tool.args_schema is None
|
||||
expected_args = {
|
||||
"tool_input": {"title": "Tool Input"},
|
||||
"other_arg": {"title": "Other Arg"},
|
||||
}
|
||||
assert tool.args == expected_args
|
||||
|
||||
|
||||
def test_tool_partial_function_args_schema() -> None:
|
||||
"""Test args schema inference when the tool argument is a partial function."""
|
||||
|
||||
def func(tool_input: str, other_arg: str) -> str:
|
||||
return tool_input + other_arg
|
||||
|
||||
with pytest.raises(pydantic.error_wrappers.ValidationError):
|
||||
# We don't yet support args_schema inference for partial functions
|
||||
# so want to make sure we proactively raise an error
|
||||
Tool(
|
||||
name="tool",
|
||||
description="A tool",
|
||||
func=partial(func, other_arg="foo"),
|
||||
)
|
||||
|
||||
|
||||
def test_empty_args_decorator() -> None:
|
||||
"""Test inferred schema of decorated fn with no args."""
|
||||
|
||||
@tool
|
||||
def empty_tool_input() -> str:
|
||||
"""Return a constant."""
|
||||
return "the empty result"
|
||||
|
||||
assert isinstance(empty_tool_input, Tool)
|
||||
assert empty_tool_input.name == "empty_tool_input"
|
||||
assert empty_tool_input.args == {}
|
||||
assert empty_tool_input.run({}) == "the empty result"
|
||||
|
||||
|
||||
def test_named_tool_decorator() -> None:
|
||||
"""Test functionality when arguments are provided as input to decorator."""
|
||||
|
||||
@@ -304,32 +94,17 @@ def test_unnamed_tool_decorator_return_direct() -> None:
|
||||
assert search_api.return_direct
|
||||
|
||||
|
||||
def test_tool_with_kwargs() -> None:
|
||||
"""Test functionality when only return direct is provided."""
|
||||
def test_base_tool_decorator_multiple_args() -> None:
|
||||
"""Test the schema that's generated is still a simple string."""
|
||||
|
||||
@tool(return_direct=True)
|
||||
def search_api(
|
||||
arg_1: float,
|
||||
ping: str = "hi",
|
||||
) -> str:
|
||||
def some_tool(query: str, foo: int = 3, bar: Optional[dict] = None) -> str:
|
||||
"""Search the API for the query."""
|
||||
return f"arg_1={arg_1}, ping={ping}"
|
||||
return f"{query} {foo} {bar}"
|
||||
|
||||
assert isinstance(search_api, Tool)
|
||||
result = search_api.run(
|
||||
tool_input={
|
||||
"arg_1": 3.2,
|
||||
"ping": "pong",
|
||||
}
|
||||
)
|
||||
assert result == "arg_1=3.2, ping=pong"
|
||||
|
||||
result = search_api.run(
|
||||
tool_input={
|
||||
"arg_1": 3.2,
|
||||
}
|
||||
)
|
||||
assert result == "arg_1=3.2, ping=hi"
|
||||
assert isinstance(some_tool, Tool)
|
||||
assert some_tool.name == "some_tool"
|
||||
assert some_tool.run("foo") == "foo 3 None"
|
||||
|
||||
|
||||
def test_missing_docstring() -> None:
|
||||
|
||||
@@ -12,10 +12,10 @@ def test_read_file_with_root_dir() -> None:
        with (Path(temp_dir) / "file.txt").open("w") as f:
            f.write("Hello, world!")
        tool = ReadFileTool(root_dir=temp_dir)
        result = tool.run("file.txt")
        result = tool.run({"file_path": "file.txt"})
        assert result == "Hello, world!"
        # Check absolute files can still be passed if they lie within the root dir.
        result = tool.run(str(Path(temp_dir) / "file.txt"))
        result = tool.run({"file_path": str(Path(temp_dir) / "file.txt")})
        assert result == "Hello, world!"

@@ -25,5 +25,5 @@ def test_read_file() -> None:
        with (Path(temp_dir) / "file.txt").open("w") as f:
            f.write("Hello, world!")
        tool = ReadFileTool()
        result = tool.run(str(Path(temp_dir) / "file.txt"))
        result = tool.run({"file_path": str(Path(temp_dir) / "file.txt")})
        assert result == "Hello, world!"
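The two hunks above move ReadFileTool calls from a bare string input to a dict keyed by argument name. A minimal sketch of the new calling convention, mirroring these tests (the import path is an assumption, since the file header is not shown in this diff):

import tempfile
from pathlib import Path

from langchain.tools.file_management import ReadFileTool  # import path assumed

with tempfile.TemporaryDirectory() as temp_dir:
    (Path(temp_dir) / "file.txt").write_text("Hello, world!")
    tool = ReadFileTool(root_dir=temp_dir)
    # Arguments are now passed as {"<parameter name>": value} rather than a plain string.
    assert tool.run({"file_path": "file.txt"}) == "Hello, world!"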
209  tests/unit_tests/tools/test_structured.py  Normal file
@@ -0,0 +1,209 @@
import logging
from functools import partial
from typing import Any, Optional, Type

import pydantic
import pytest
from pydantic import BaseModel

from langchain.tools.structured import (
    BaseStructuredTool,
    StructuredTool,
    structured_tool,
)


class _MockSchema(BaseModel):
    arg1: int
    arg2: bool
    arg3: Optional[dict] = None


class _MockStructuredTool(BaseStructuredTool):
    name = "structured_api"
    args_schema: Type[BaseModel] = _MockSchema
    description = "A Structured Tool"

    def _run(self, arg1: int, arg2: bool, arg3: Optional[dict] = None) -> str:
        return f"{arg1} {arg2} {arg3}"

    async def _arun(self, arg1: int, arg2: bool, arg3: Optional[dict] = None) -> str:
        raise NotImplementedError


def test_structured_args() -> None:
    """Test functionality with structured arguments."""
    structured_api = _MockStructuredTool()
    assert isinstance(structured_api, BaseStructuredTool)
    assert structured_api.name == "structured_api"
    expected_result = "1 True {'foo': 'bar'}"
    args = {"arg1": 1, "arg2": True, "arg3": {"foo": "bar"}}
    assert structured_api.run(args) == expected_result


def test_subclass_annotated_base_tool_accepted() -> None:
    """Test BaseTool child w/ custom schema isn't overwritten."""

    class _ForwardRefAnnotatedTool(BaseStructuredTool):
        name = "structured_api"
        args_schema: Type[_MockSchema] = _MockSchema
        description = "A Structured Tool"

        def _run(self, arg1: int, arg2: bool, arg3: Optional[dict] = None) -> str:
            return f"{arg1} {arg2} {arg3}"

        async def _arun(
            self, arg1: int, arg2: bool, arg3: Optional[dict] = None
        ) -> str:
            raise NotImplementedError

    assert issubclass(_ForwardRefAnnotatedTool, BaseStructuredTool)
    tool = _ForwardRefAnnotatedTool()
    assert tool.args_schema == _MockSchema


def test_decorator_with_specified_schema() -> None:
    """Test that manually specified schemata are passed through to the tool."""

    @structured_tool(args_schema=_MockSchema)
    def tool_func(arg1: int, arg2: bool, arg3: Optional[dict] = None) -> str:
        """Return the arguments directly."""
        return f"{arg1} {arg2} {arg3}"

    assert isinstance(tool_func, StructuredTool)
    assert tool_func.args_schema == _MockSchema


def test_decorated_function_schema_equivalent() -> None:
    """Test that a BaseTool without a schema meets expectations."""

    @structured_tool
    def structured_tool_input(
        arg1: int, arg2: bool, arg3: Optional[dict] = None
    ) -> str:
        """Return the arguments directly."""
        return f"{arg1} {arg2} {arg3}"

    assert isinstance(structured_tool_input, StructuredTool)
    assert (
        structured_tool_input.args_schema.schema()["properties"]
        == _MockSchema.schema()["properties"]
        == structured_tool_input.args
    )


def test_tool_lambda_multi_args_schema() -> None:
    """Test args schema inference when the tool argument is a lambda function."""
    tool = StructuredTool.from_function(
        func=lambda tool_input, other_arg: f"{tool_input}{other_arg}",  # type: ignore
        name="tool",
        description="A tool",
    )
    assert set(tool.args_schema.schema()["properties"]) == {"tool_input", "other_arg"}
    expected_args = {
        "tool_input": {"title": "Tool Input"},
        "other_arg": {"title": "Other Arg"},
    }
    assert tool.args == expected_args


def test_tool_partial_function_args_schema() -> None:
    """Test args schema inference when the tool argument is a partial function."""

    def func(tool_input: str, other_arg: str) -> str:
        return tool_input + other_arg

    with pytest.raises(pydantic.error_wrappers.ValidationError):
        # We don't yet support args_schema inference for partial functions
        # so want to make sure we proactively raise an error
        StructuredTool(
            name="tool",
            description="A tool",
            func=partial(func, other_arg="foo"),
        )


def test_tool_with_kwargs() -> None:
    """Test functionality when only return direct is provided."""

    @structured_tool(return_direct=True)
    def search_api(
        arg_1: float,
        ping: str = "hi",
    ) -> str:
        """Search the API for the query."""
        return f"arg_1={arg_1}, ping={ping}"

    assert isinstance(search_api, StructuredTool)
    result = search_api.run(
        tool_input={
            "arg_1": 3.2,
            "ping": "pong",
        }
    )
    assert result == "arg_1=3.2, ping=pong"

    result = search_api.run(
        tool_input={
            "arg_1": 3.2,
        }
    )
    assert result == "arg_1=3.2, ping=hi"


def test_empty_args_decorator() -> None:
    """Test inferred schema of decorated fn with no args."""

    @structured_tool
    def empty_tool_input() -> str:
        """Return a constant."""
        return "the empty result"

    assert isinstance(empty_tool_input, StructuredTool)
    assert empty_tool_input.name == "empty_tool_input"
    assert empty_tool_input.args == {}
    assert empty_tool_input.run({}) == "the empty result"


def test_nested_pydantic_args() -> None:
    """Test inferred schema when args are nested pydantic models."""
    # This is a pattern that is common with FastAPI methods.
    # If we only parse a dict input but pass the dict
    # to the function, we are limited only to primitive types
    # in general.

    class SomeNestedInput(BaseModel):
        arg2: str

    class SomeInput(BaseModel):
        arg1: int
        arg2: SomeNestedInput

    @structured_tool
    def nested_tool(some_input: SomeInput) -> dict:
        """Return a constant."""
        return some_input.dict()

    assert isinstance(nested_tool, StructuredTool)
    assert nested_tool.name == "nested_tool"
    input_ = {"some_input": {"arg1": 1, "arg2": {"arg2": "foo"}}}
    assert nested_tool.run(input_) == input_["some_input"]
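# Illustrative aside (hypothetical names, not lines from the diffed file): the
# nested-args test above relies on ordinary pydantic coercion, in which a
# nested dict is parsed into the nested model rather than staying a primitive
# dict. BaseModel is already imported at the top of this test file.
class _SketchNestedInput(BaseModel):
    arg2: str


class _SketchInput(BaseModel):
    arg1: int
    arg2: _SketchNestedInput


_parsed = _SketchInput(**{"arg1": 1, "arg2": {"arg2": "foo"}})
assert isinstance(_parsed.arg2, _SketchNestedInput)
assert _parsed.dict() == {"arg1": 1, "arg2": {"arg2": "foo"}}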
def test_warning_on_args_kwargs(caplog: pytest.LogCaptureFixture) -> None:
    """Test that *args and **kwargs in a decorated function log a warning."""

    with caplog.at_level(logging.WARNING):

        @structured_tool
        def anything_goes(*foo: Any, **bar: Any) -> str:
            """Return a constant."""
            return str(foo) + "|" + str(bar)

    # Check if the expected warning message was logged
    assert any(
        "anything_goes uses *args" in record.message for record in caplog.records
    )
    assert any(
        "anything_goes uses **kwargs" in record.message for record in caplog.records
    )
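# Illustrative aside (hypothetical names, not lines from the diffed file): the
# warning test above guards against *args / **kwargs, which do not map onto
# named fields in an inferred schema. Detecting them needs only the standard
# library.
import inspect
from typing import Any


def _anything_goes(*foo: Any, **bar: Any) -> str:
    return str(foo) + "|" + str(bar)


_params = inspect.signature(_anything_goes).parameters.values()
assert any(p.kind is inspect.Parameter.VAR_POSITIONAL for p in _params)
assert any(p.kind is inspect.Parameter.VAR_KEYWORD for p in _params)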