Compare commits

...

13 Commits

Author SHA1 Message Date
vowelparrot
f04575757d change 2023-04-26 07:51:23 -07:00
vowelparrot
e236e64cd4 Upstream 2023-04-25 23:25:13 -07:00
vowelparrot
f4b4d52e58 Another option 2023-04-25 22:26:50 -07:00
vowelparrot
417ce72df2 come back here if you still want copys 2023-04-25 17:50:08 -07:00
vowelparrot
d7e9b72380 Merge branch 'master' into vwp/chatregtests 2023-04-25 17:15:55 -07:00
vowelparrot
203e97d789 update 2023-04-25 16:14:26 -07:00
vowelparrot
f99348fb12 foo 2023-04-25 15:43:46 -07:00
vowelparrot
6bc9700863 undo 2023-04-25 15:41:10 -07:00
vowelparrot
4abeea3d42 don't like 2023-04-25 14:36:55 -07:00
vowelparrot
e79e003218 Hm 2023-04-25 14:10:33 -07:00
vowelparrot
37fca4ae75 foo 2023-04-25 11:49:34 -07:00
vowelparrot
c90ce64757 hm 2023-04-25 10:14:33 -07:00
vowelparrot
4cd6a2223a MITRT 2023-04-25 10:14:13 -07:00
20 changed files with 770 additions and 551 deletions

View File

@@ -28,11 +28,24 @@ from langchain.schema import (
BaseOutputParser,
)
from langchain.tools.base import BaseTool
from langchain.tools.structured import BaseStructuredTool
from langchain.utilities.asyncio import asyncio_timeout
logger = logging.getLogger(__name__)
def validate_all_instance_of_base_tool(
    class_name: str, tools: Sequence[BaseStructuredTool]
) -> None:
    """Validate that every tool is an instance of ``BaseTool``.

    The parameter is annotated with the broader ``BaseStructuredTool`` type,
    but the check itself is against ``BaseTool``.

    Args:
        class_name: Name of the agent class; used only in the error message.
        tools: The tools to check.

    Raises:
        TypeError: If any tool is not an instance of ``BaseTool``.
    """
    for tool in tools:
        if not isinstance(tool, BaseTool):
            # Report the class by name; interpolating the class object itself
            # would embed its full ``<class '...'>`` repr in the message.
            raise TypeError(
                f"Agent {class_name} only supports tools of type"
                f" {BaseTool.__name__}."
                f" {tool.name} is of type {type(tool).__name__}"
            )
class BaseSingleActionAgent(BaseModel):
"""Base Agent class."""
@@ -103,7 +116,7 @@ class BaseSingleActionAgent(BaseModel):
def from_llm_and_tools(
cls,
llm: BaseLanguageModel,
tools: Sequence[BaseTool],
tools: Sequence[BaseStructuredTool],
callback_manager: Optional[BaseCallbackManager] = None,
**kwargs: Any,
) -> BaseSingleActionAgent:
@@ -448,13 +461,12 @@ class Agent(BaseSingleActionAgent):
@classmethod
@abstractmethod
def create_prompt(cls, tools: Sequence[BaseTool]) -> BasePromptTemplate:
def create_prompt(cls, tools: Sequence[BaseStructuredTool]) -> BasePromptTemplate:
"""Create a prompt for this class."""
@classmethod
def _validate_tools(cls, tools: Sequence[BaseTool]) -> None:
def _validate_tools(cls, tools: Sequence[BaseStructuredTool]) -> None:
"""Validate that appropriate tools are passed in."""
pass
@classmethod
@abstractmethod
@@ -465,7 +477,7 @@ class Agent(BaseSingleActionAgent):
def from_llm_and_tools(
cls,
llm: BaseLanguageModel,
tools: Sequence[BaseTool],
tools: Sequence[BaseStructuredTool],
callback_manager: Optional[BaseCallbackManager] = None,
output_parser: Optional[AgentOutputParser] = None,
**kwargs: Any,
@@ -539,7 +551,7 @@ class AgentExecutor(Chain):
"""Consists of an agent using tools."""
agent: Union[BaseSingleActionAgent, BaseMultiActionAgent]
tools: Sequence[BaseTool]
tools: Sequence[BaseStructuredTool]
return_intermediate_steps: bool = False
max_iterations: Optional[int] = 15
max_execution_time: Optional[float] = None
@@ -549,7 +561,7 @@ class AgentExecutor(Chain):
def from_agent_and_tools(
cls,
agent: Union[BaseSingleActionAgent, BaseMultiActionAgent],
tools: Sequence[BaseTool],
tools: Sequence[BaseStructuredTool],
callback_manager: Optional[BaseCallbackManager] = None,
**kwargs: Any,
) -> AgentExecutor:
@@ -617,7 +629,7 @@ class AgentExecutor(Chain):
else:
return self.agent.return_values
def lookup_tool(self, name: str) -> BaseTool:
def lookup_tool(self, name: str) -> BaseStructuredTool:
"""Lookup tool by name."""
return {tool.name: tool for tool in self.tools}[name]
@@ -657,9 +669,51 @@ class AgentExecutor(Chain):
final_output["intermediate_steps"] = intermediate_steps
return final_output
def _run_tool(
self,
tool: BaseStructuredTool,
agent_action: AgentAction,
color: str,
**tool_run_kwargs: Any,
) -> Any:
tool_input = agent_action.tool_input
if isinstance(tool_input, str):
if not isinstance(tool, BaseTool):
return (
f"Error: tool {tool.name} could not be "
f"run with input {agent_action.tool_input}"
)
return tool.run(
tool_input,
verbose=self.verbose,
color=color,
**tool_run_kwargs,
)
async def _aruntool(
self,
tool: BaseStructuredTool,
agent_action: AgentAction,
color: str,
**tool_run_kwargs: Any,
) -> Any:
tool_input = agent_action.tool_input
if isinstance(tool_input, str):
if not isinstance(tool, BaseTool):
return (
f"Error: tool {tool.name} could not be "
f"run with input {agent_action.tool_input}"
)
return await tool.arun(
agent_action.tool_input,
verbose=self.verbose,
color=color,
**tool_run_kwargs,
)
def _take_next_step(
self,
name_to_tool_map: Dict[str, BaseTool],
name_to_tool_map: Dict[str, BaseStructuredTool],
color_mapping: Dict[str, str],
inputs: Dict[str, str],
intermediate_steps: List[Tuple[AgentAction, str]],
@@ -692,11 +746,8 @@ class AgentExecutor(Chain):
if return_direct:
tool_run_kwargs["llm_prefix"] = ""
# We then call the tool on the tool input to get an observation
observation = tool.run(
agent_action.tool_input,
verbose=self.verbose,
color=color,
**tool_run_kwargs,
observation = self._run_tool(
tool, agent_action, color, run_async=True, **tool_run_kwargs
)
else:
tool_run_kwargs = self.agent.tool_run_logging_kwargs()
@@ -711,7 +762,7 @@ class AgentExecutor(Chain):
async def _atake_next_step(
self,
name_to_tool_map: Dict[str, BaseTool],
name_to_tool_map: Dict[str, BaseStructuredTool],
color_mapping: Dict[str, str],
inputs: Dict[str, str],
intermediate_steps: List[Tuple[AgentAction, str]],
@@ -751,11 +802,8 @@ class AgentExecutor(Chain):
if return_direct:
tool_run_kwargs["llm_prefix"] = ""
# We then call the tool on the tool input to get an observation
observation = await tool.arun(
agent_action.tool_input,
verbose=self.verbose,
color=color,
**tool_run_kwargs,
observation = await self._aruntool(
tool, agent_action, color, run_async=True, **tool_run_kwargs
)
else:
tool_run_kwargs = self.agent.tool_run_logging_kwargs()

View File

@@ -2,7 +2,11 @@ from typing import Any, List, Optional, Sequence, Tuple
from pydantic import Field
from langchain.agents.agent import Agent, AgentOutputParser
from langchain.agents.agent import (
Agent,
AgentOutputParser,
validate_all_instance_of_base_tool,
)
from langchain.agents.chat.output_parser import ChatOutputParser
from langchain.agents.chat.prompt import FORMAT_INSTRUCTIONS, PREFIX, SUFFIX
from langchain.callbacks.base import BaseCallbackManager
@@ -14,7 +18,7 @@ from langchain.prompts.chat import (
SystemMessagePromptTemplate,
)
from langchain.schema import AgentAction, BaseLanguageModel
from langchain.tools import BaseTool
from langchain.tools.structured import BaseStructuredTool
class ChatAgent(Agent):
@@ -56,7 +60,7 @@ class ChatAgent(Agent):
@classmethod
def create_prompt(
cls,
tools: Sequence[BaseTool],
tools: Sequence[BaseStructuredTool],
prefix: str = PREFIX,
suffix: str = SUFFIX,
format_instructions: str = FORMAT_INSTRUCTIONS,
@@ -74,11 +78,16 @@ class ChatAgent(Agent):
input_variables = ["input", "agent_scratchpad"]
return ChatPromptTemplate(input_variables=input_variables, messages=messages)
@classmethod
def _validate_tools(cls, tools: Sequence[BaseStructuredTool]) -> None:
super()._validate_tools(tools)
validate_all_instance_of_base_tool(cls.__name__, tools)
@classmethod
def from_llm_and_tools(
cls,
llm: BaseLanguageModel,
tools: Sequence[BaseTool],
tools: Sequence[BaseStructuredTool],
callback_manager: Optional[BaseCallbackManager] = None,
output_parser: Optional[AgentOutputParser] = None,
prefix: str = PREFIX,

View File

@@ -20,7 +20,10 @@ class ChatOutputParser(AgentOutputParser):
try:
action = text.split("```")[1]
response = json.loads(action.strip())
return AgentAction(response["action"], response["action_input"], text)
action_input = response["action_input"]
if isinstance(action_input, dict):
action_input = json.dumps(action_input)
return AgentAction(response["action"], action_input, text)
except Exception:
raise OutputParserException(f"Could not parse LLM output: {text}")

View File

@@ -5,7 +5,11 @@ from typing import Any, List, Optional, Sequence
from pydantic import Field
from langchain.agents.agent import Agent, AgentOutputParser
from langchain.agents.agent import (
Agent,
AgentOutputParser,
validate_all_instance_of_base_tool,
)
from langchain.agents.agent_types import AgentType
from langchain.agents.conversational.output_parser import ConvoOutputParser
from langchain.agents.conversational.prompt import FORMAT_INSTRUCTIONS, PREFIX, SUFFIX
@@ -13,7 +17,7 @@ from langchain.callbacks.base import BaseCallbackManager
from langchain.chains import LLMChain
from langchain.prompts import PromptTemplate
from langchain.schema import BaseLanguageModel
from langchain.tools.base import BaseTool
from langchain.tools.structured import BaseStructuredTool
class ConversationalAgent(Agent):
@@ -46,7 +50,7 @@ class ConversationalAgent(Agent):
@classmethod
def create_prompt(
cls,
tools: Sequence[BaseTool],
tools: Sequence[BaseStructuredTool],
prefix: str = PREFIX,
suffix: str = SUFFIX,
format_instructions: str = FORMAT_INSTRUCTIONS,
@@ -80,11 +84,16 @@ class ConversationalAgent(Agent):
input_variables = ["input", "chat_history", "agent_scratchpad"]
return PromptTemplate(template=template, input_variables=input_variables)
@classmethod
def _validate_tools(cls, tools: Sequence[BaseStructuredTool]) -> None:
super()._validate_tools(tools)
validate_all_instance_of_base_tool(cls.__name__, tools)
@classmethod
def from_llm_and_tools(
cls,
llm: BaseLanguageModel,
tools: Sequence[BaseTool],
tools: Sequence[BaseStructuredTool],
callback_manager: Optional[BaseCallbackManager] = None,
output_parser: Optional[AgentOutputParser] = None,
prefix: str = PREFIX,

View File

@@ -5,7 +5,11 @@ from typing import Any, List, Optional, Sequence, Tuple
from pydantic import Field
from langchain.agents.agent import Agent, AgentOutputParser
from langchain.agents.agent import (
Agent,
AgentOutputParser,
validate_all_instance_of_base_tool,
)
from langchain.agents.conversational_chat.output_parser import ConvoOutputParser
from langchain.agents.conversational_chat.prompt import (
PREFIX,
@@ -29,7 +33,7 @@ from langchain.schema import (
BaseOutputParser,
HumanMessage,
)
from langchain.tools.base import BaseTool
from langchain.tools.structured import BaseStructuredTool
class ConversationalChatAgent(Agent):
@@ -58,7 +62,7 @@ class ConversationalChatAgent(Agent):
@classmethod
def create_prompt(
cls,
tools: Sequence[BaseTool],
tools: Sequence[BaseStructuredTool],
system_message: str = PREFIX,
human_message: str = SUFFIX,
input_variables: Optional[List[str]] = None,
@@ -98,11 +102,16 @@ class ConversationalChatAgent(Agent):
thoughts.append(human_message)
return thoughts
@classmethod
def _validate_tools(cls, tools: Sequence[BaseStructuredTool]) -> None:
super()._validate_tools(tools)
validate_all_instance_of_base_tool(cls.__name__, tools)
@classmethod
def from_llm_and_tools(
cls,
llm: BaseLanguageModel,
tools: Sequence[BaseTool],
tools: Sequence[BaseStructuredTool],
callback_manager: Optional[BaseCallbackManager] = None,
output_parser: Optional[AgentOutputParser] = None,
system_message: str = PREFIX,

View File

@@ -5,7 +5,12 @@ from typing import Any, Callable, List, NamedTuple, Optional, Sequence
from pydantic import Field
from langchain.agents.agent import Agent, AgentExecutor, AgentOutputParser
from langchain.agents.agent import (
Agent,
AgentExecutor,
AgentOutputParser,
validate_all_instance_of_base_tool,
)
from langchain.agents.agent_types import AgentType
from langchain.agents.mrkl.output_parser import MRKLOutputParser
from langchain.agents.mrkl.prompt import FORMAT_INSTRUCTIONS, PREFIX, SUFFIX
@@ -14,7 +19,7 @@ from langchain.callbacks.base import BaseCallbackManager
from langchain.chains import LLMChain
from langchain.prompts import PromptTemplate
from langchain.schema import BaseLanguageModel
from langchain.tools.base import BaseTool
from langchain.tools.structured import BaseStructuredTool
class ChainConfig(NamedTuple):
@@ -58,7 +63,7 @@ class ZeroShotAgent(Agent):
@classmethod
def create_prompt(
cls,
tools: Sequence[BaseTool],
tools: Sequence[BaseStructuredTool],
prefix: str = PREFIX,
suffix: str = SUFFIX,
format_instructions: str = FORMAT_INSTRUCTIONS,
@@ -88,7 +93,7 @@ class ZeroShotAgent(Agent):
def from_llm_and_tools(
cls,
llm: BaseLanguageModel,
tools: Sequence[BaseTool],
tools: Sequence[BaseStructuredTool],
callback_manager: Optional[BaseCallbackManager] = None,
output_parser: Optional[AgentOutputParser] = None,
prefix: str = PREFIX,
@@ -121,13 +126,15 @@ class ZeroShotAgent(Agent):
)
@classmethod
def _validate_tools(cls, tools: Sequence[BaseTool]) -> None:
def _validate_tools(cls, tools: Sequence[BaseStructuredTool]) -> None:
super()._validate_tools(tools)
for tool in tools:
if tool.description is None:
raise ValueError(
f"Got a tool {tool.name} without a description. For this agent, "
f"a description must always be provided."
)
validate_all_instance_of_base_tool(cls.__name__, tools)
class MRKLChain(AgentExecutor):

View File

@@ -3,7 +3,12 @@ from typing import Any, List, Optional, Sequence
from pydantic import Field
from langchain.agents.agent import Agent, AgentExecutor, AgentOutputParser
from langchain.agents.agent import (
Agent,
AgentExecutor,
AgentOutputParser,
validate_all_instance_of_base_tool,
)
from langchain.agents.agent_types import AgentType
from langchain.agents.react.output_parser import ReActOutputParser
from langchain.agents.react.textworld_prompt import TEXTWORLD_PROMPT
@@ -13,7 +18,7 @@ from langchain.docstore.base import Docstore
from langchain.docstore.document import Document
from langchain.llms.base import BaseLLM
from langchain.prompts.base import BasePromptTemplate
from langchain.tools.base import BaseTool
from langchain.tools.structured import BaseStructuredTool
class ReActDocstoreAgent(Agent):
@@ -31,12 +36,12 @@ class ReActDocstoreAgent(Agent):
return AgentType.REACT_DOCSTORE
@classmethod
def create_prompt(cls, tools: Sequence[BaseTool]) -> BasePromptTemplate:
def create_prompt(cls, tools: Sequence[BaseStructuredTool]) -> BasePromptTemplate:
"""Return default prompt."""
return WIKI_PROMPT
@classmethod
def _validate_tools(cls, tools: Sequence[BaseTool]) -> None:
def _validate_tools(cls, tools: Sequence[BaseStructuredTool]) -> None:
if len(tools) != 2:
raise ValueError(f"Exactly two tools must be specified, but got {tools}")
tool_names = {tool.name for tool in tools}
@@ -44,6 +49,7 @@ class ReActDocstoreAgent(Agent):
raise ValueError(
f"Tool names should be Lookup and Search, got {tool_names}"
)
validate_all_instance_of_base_tool(cls.__name__, tools)
@property
def observation_prefix(self) -> str:
@@ -113,17 +119,18 @@ class ReActTextWorldAgent(ReActDocstoreAgent):
"""Agent for the ReAct TextWorld chain."""
@classmethod
def create_prompt(cls, tools: Sequence[BaseTool]) -> BasePromptTemplate:
def create_prompt(cls, tools: Sequence[BaseStructuredTool]) -> BasePromptTemplate:
"""Return default prompt."""
return TEXTWORLD_PROMPT
@classmethod
def _validate_tools(cls, tools: Sequence[BaseTool]) -> None:
def _validate_tools(cls, tools: Sequence[BaseStructuredTool]) -> None:
if len(tools) != 1:
raise ValueError(f"Exactly one tool must be specified, but got {tools}")
tool_names = {tool.name for tool in tools}
if tool_names != {"Play"}:
raise ValueError(f"Tool name should be Play, got {tool_names}")
validate_all_instance_of_base_tool(cls.__name__, tools)
class ReActChain(AgentExecutor):

View File

@@ -10,7 +10,7 @@ from langchain.agents.self_ask_with_search.prompt import PROMPT
from langchain.agents.tools import Tool
from langchain.llms.base import BaseLLM
from langchain.prompts.base import BasePromptTemplate
from langchain.tools.base import BaseTool
from langchain.tools.structured import BaseStructuredTool
from langchain.utilities.google_serper import GoogleSerperAPIWrapper
from langchain.utilities.serpapi import SerpAPIWrapper
@@ -30,12 +30,12 @@ class SelfAskWithSearchAgent(Agent):
return AgentType.SELF_ASK_WITH_SEARCH
@classmethod
def create_prompt(cls, tools: Sequence[BaseTool]) -> BasePromptTemplate:
def create_prompt(cls, tools: Sequence[BaseStructuredTool]) -> BasePromptTemplate:
"""Prompt does not depend on tools."""
return PROMPT
@classmethod
def _validate_tools(cls, tools: Sequence[BaseTool]) -> None:
def _validate_tools(cls, tools: Sequence[BaseStructuredTool]) -> None:
if len(tools) != 1:
raise ValueError(f"Exactly one tool must be specified, but got {tools}")
tool_names = {tool.name for tool in tools}

View File

@@ -3,12 +3,11 @@ from functools import partial
from inspect import signature
from typing import Any, Awaitable, Callable, Optional, Type, Union
from pydantic import BaseModel, validate_arguments, validator
from pydantic import validator
from langchain.tools.base import (
BaseTool,
create_schema_from_function,
get_filtered_args,
StringSchema,
)
@@ -28,22 +27,14 @@ class Tool(BaseTool):
raise ValueError("Partial functions not yet supported in tools.")
return func
@property
def args(self) -> dict:
if self.args_schema is not None:
return self.args_schema.schema()["properties"]
else:
inferred_model = validate_arguments(self.func).model # type: ignore
return get_filtered_args(inferred_model, self.func)
def _run(self, *args: Any, **kwargs: Any) -> str:
def _run(self, tool_input: str) -> str:
"""Use the tool."""
return self.func(*args, **kwargs)
return self.func(tool_input)
async def _arun(self, *args: Any, **kwargs: Any) -> str:
async def _arun(self, tool_input: str) -> str:
"""Use the tool asynchronously."""
if self.coroutine:
return await self.coroutine(*args, **kwargs)
return await self.coroutine(tool_input)
raise NotImplementedError("Tool does not support async")
# TODO: this is for backwards compatibility, remove in future
@@ -74,8 +65,7 @@ class InvalidTool(BaseTool):
def tool(
*args: Union[str, Callable],
return_direct: bool = False,
args_schema: Optional[Type[BaseModel]] = None,
infer_schema: bool = True,
args_schema: Optional[Type[StringSchema]] = None,
) -> Callable:
"""Make tools out of functions, can be used with or without arguments.
@@ -83,10 +73,7 @@ def tool(
*args: The arguments to the tool.
return_direct: Whether to return directly from the tool rather
than continuing the agent loop.
args_schema: optional argument schema for user to specify
infer_schema: Whether to infer the schema of the arguments from
the function's signature. This also makes the resultant tool
accept a dictionary input to its `run()` function.
args_schema: The schema for the arguments used to validate input.
Requires:
- Function must be of type (str) -> str
@@ -112,15 +99,13 @@ def tool(
# Description example:
# search_api(query: str) - Searches the API for the query.
description = f"{tool_name}{signature(func)} - {func.__doc__.strip()}"
_args_schema = args_schema
if _args_schema is None and infer_schema:
_args_schema = create_schema_from_function(f"{tool_name}Schema", func)
tool_kwargs = {} if args_schema is None else {"args_schema": args_schema}
tool_ = Tool(
name=tool_name,
func=func,
args_schema=_args_schema,
description=description,
return_direct=return_direct,
**tool_kwargs,
)
return tool_

View File

@@ -2,17 +2,7 @@
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import (
Any,
Dict,
Generic,
List,
NamedTuple,
Optional,
Sequence,
TypeVar,
Union,
)
from typing import Any, Dict, Generic, List, NamedTuple, Optional, Sequence, TypeVar
from pydantic import BaseModel, Extra, Field, root_validator
@@ -41,7 +31,7 @@ class AgentAction(NamedTuple):
"""Agent's action to take."""
tool: str
tool_input: Union[str, dict]
tool_input: str
log: str
@@ -401,8 +391,6 @@ class OutputParserException(Exception):
errors will be raised.
"""
pass
class BaseDocumentTransformer(ABC):
"""Base interface for transforming documents."""

View File

@@ -2,241 +2,68 @@
from __future__ import annotations
from abc import ABC, abstractmethod
from inspect import signature
from typing import Any, Callable, Dict, Optional, Sequence, Tuple, Type, Union
from typing import Any, Dict, Sequence, Tuple, Type, Union
from pydantic import (
BaseModel,
Extra,
Field,
create_model,
validate_arguments,
validator,
)
from pydantic.main import ModelMetaclass
from pydantic import BaseModel
from langchain.callbacks import get_callback_manager
from langchain.callbacks.base import BaseCallbackManager
from langchain.tools.structured import BaseStructuredTool
def _to_args_and_kwargs(run_input: Union[str, Dict]) -> Tuple[Sequence, dict]:
# For backwards compatability, if run_input is a string,
# pass as a positional argument.
if isinstance(run_input, str):
return (run_input,), {}
else:
return [], run_input
class StringSchema(BaseModel):
"""Schema for a tool with string input."""
# Child tools can add additional validation by
# subclassing this schema.
tool_input: str
class SchemaAnnotationError(TypeError):
    """Error for an 'args_schema' whose type annotation is absent or wrong."""
class ToolMetaclass(ModelMetaclass):
"""Metaclass for BaseTool to ensure the provided args_schema
doesn't silently ignored."""
def __new__(
cls: Type[ToolMetaclass], name: str, bases: Tuple[Type, ...], dct: dict
) -> ToolMetaclass:
"""Create the definition of the new tool class."""
schema_type: Optional[Type[BaseModel]] = dct.get("args_schema")
if schema_type is not None:
schema_annotations = dct.get("__annotations__", {})
args_schema_type = schema_annotations.get("args_schema", None)
if args_schema_type is None or args_schema_type == BaseModel:
# Throw errors for common mis-annotations.
# TODO: Use get_args / get_origin and fully
# specify valid annotations.
typehint_mandate = """
class ChildTool(BaseTool):
...
args_schema: Type[BaseModel] = SchemaClass
..."""
raise SchemaAnnotationError(
f"Tool definition for {name} must include valid type annotations"
f" for argument 'args_schema' to behave as expected.\n"
f"Expected annotation of 'Type[BaseModel]'"
f" but got '{args_schema_type}'.\n"
f"Expected class looks like:\n"
f"{typehint_mandate}"
)
# Pass through to Pydantic's metaclass
return super().__new__(cls, name, bases, dct)
def _create_subset_model(
    name: str, model: BaseModel, field_names: list
) -> Type[BaseModel]:
    """Build a model called ``name`` holding only some of ``model``'s fields.

    Names in ``field_names`` that are not fields of ``model`` are ignored.
    Each kept field is declared with the type and default recorded on
    ``model``.
    """
    known_fields = model.__fields__
    selected = {}
    for field_name in field_names:
        if field_name not in known_fields:
            continue
        source_field = known_fields[field_name]
        selected[field_name] = (source_field.type_, source_field.default)
    return create_model(name, **selected)  # type: ignore
def get_filtered_args(inferred_model: Type[BaseModel], func: Callable) -> dict:
    """Return the schema properties that match ``func``'s parameters.

    The result maps each parameter name in ``func``'s signature, in signature
    order, to its entry under ``"properties"`` in ``inferred_model.schema()``.
    A parameter with no such entry raises ``KeyError``.
    """
    properties = inferred_model.schema()["properties"]
    filtered = {}
    for param_name in signature(func).parameters:
        filtered[param_name] = properties[param_name]
    return filtered
def create_schema_from_function(model_name: str, func: Callable) -> Type[BaseModel]:
    """Build a pydantic schema class from ``func``'s signature.

    The generated model is named ``model_name`` with a ``Schema`` suffix
    appended.
    """
    full_model = validate_arguments(func).model  # type: ignore
    # Pydantic adds placeholder virtual fields; keep only the names that
    # appear in the function's own signature.
    real_field_names = list(get_filtered_args(full_model, func))
    schema_name = f"{model_name}Schema"
    return _create_subset_model(schema_name, full_model, real_field_names)
class BaseTool(ABC, BaseModel, metaclass=ToolMetaclass):
class BaseTool(ABC, BaseStructuredTool[str]):
"""Interface LangChain tools must implement."""
name: str
description: str
args_schema: Optional[Type[BaseModel]] = None
"""Pydantic model class to validate and parse the tool's input arguments."""
return_direct: bool = False
verbose: bool = False
callback_manager: BaseCallbackManager = Field(default_factory=get_callback_manager)
args_schema: Type[StringSchema] = StringSchema # :meta private:
class Config:
"""Configuration for this pydantic object."""
def _wrap_input(self, tool_input: Union[str, Dict]) -> Dict:
"""Wrap the tool's input into a pydantic model."""
if isinstance(tool_input, dict):
return tool_input
return {"tool_input": tool_input}
extra = Extra.forbid
arbitrary_types_allowed = True
@property
def args(self) -> dict:
if self.args_schema is not None:
return self.args_schema.schema()["properties"]
else:
inferred_model = validate_arguments(self._run).model # type: ignore
return get_filtered_args(inferred_model, self._run)
def _parse_input(
self,
tool_input: Union[str, Dict],
) -> None:
"""Convert tool input to pydantic model."""
input_args = self.args_schema
if isinstance(tool_input, str):
if input_args is not None:
key_ = next(iter(input_args.__fields__.keys()))
input_args.validate({key_: tool_input})
else:
if input_args is not None:
input_args.validate(tool_input)
@validator("callback_manager", pre=True, always=True)
def set_callback_manager(
cls, callback_manager: Optional[BaseCallbackManager]
) -> BaseCallbackManager:
"""If callback manager is None, set it.
This allows users to pass in None as callback manager, which is a nice UX.
"""
return callback_manager or get_callback_manager()
def _prepare_input(self, input_: dict) -> Tuple[Sequence, Dict]:
"""Prepare the args and kwargs for the tool."""
# We expect a single string input
return tuple(input_.values()), {}
@abstractmethod
def _run(self, *args: Any, **kwargs: Any) -> str:
def _run(self, tool_input: str) -> str:
"""Use the tool."""
@abstractmethod
async def _arun(self, *args: Any, **kwargs: Any) -> str:
async def _arun(self, tool_input: str) -> str:
"""Use the tool asynchronously."""
def run(
self,
tool_input: Union[str, Dict],
verbose: Optional[bool] = None,
start_color: Optional[str] = "green",
color: Optional[str] = "green",
tool_input: Union[str, dict],
verbose: bool | None = None,
start_color: str | None = "green",
color: str | None = "green",
**kwargs: Any,
) -> str:
"""Run the tool."""
self._parse_input(tool_input)
if not self.verbose and verbose is not None:
verbose_ = verbose
else:
verbose_ = self.verbose
self.callback_manager.on_tool_start(
{"name": self.name, "description": self.description},
tool_input if isinstance(tool_input, str) else str(tool_input),
verbose=verbose_,
color=start_color,
**kwargs,
)
try:
tool_args, tool_kwargs = _to_args_and_kwargs(tool_input)
observation = self._run(*tool_args, **tool_kwargs)
except (Exception, KeyboardInterrupt) as e:
self.callback_manager.on_tool_error(e, verbose=verbose_)
raise e
self.callback_manager.on_tool_end(
observation, verbose=verbose_, color=color, name=self.name, **kwargs
)
return observation
"""Use the tool."""
wrapped_input = self._wrap_input(tool_input)
return super().run(wrapped_input, verbose, start_color, color, **kwargs)
async def arun(
self,
tool_input: Union[str, Dict],
verbose: Optional[bool] = None,
start_color: Optional[str] = "green",
color: Optional[str] = "green",
tool_input: Union[str, dict],
verbose: bool | None = None,
start_color: str | None = "green",
color: str | None = "green",
**kwargs: Any,
) -> str:
"""Run the tool asynchronously."""
self._parse_input(tool_input)
if not self.verbose and verbose is not None:
verbose_ = verbose
else:
verbose_ = self.verbose
if self.callback_manager.is_async:
await self.callback_manager.on_tool_start(
{"name": self.name, "description": self.description},
tool_input if isinstance(tool_input, str) else str(tool_input),
verbose=verbose_,
color=start_color,
**kwargs,
)
else:
self.callback_manager.on_tool_start(
{"name": self.name, "description": self.description},
tool_input if isinstance(tool_input, str) else str(tool_input),
verbose=verbose_,
color=start_color,
**kwargs,
)
try:
# We then call the tool on the tool input to get an observation
args, kwargs = _to_args_and_kwargs(tool_input)
observation = await self._arun(*args, **kwargs)
except (Exception, KeyboardInterrupt) as e:
if self.callback_manager.is_async:
await self.callback_manager.on_tool_error(e, verbose=verbose_)
else:
self.callback_manager.on_tool_error(e, verbose=verbose_)
raise e
if self.callback_manager.is_async:
await self.callback_manager.on_tool_end(
observation, verbose=verbose_, color=color, name=self.name, **kwargs
)
else:
self.callback_manager.on_tool_end(
observation, verbose=verbose_, color=color, name=self.name, **kwargs
)
return observation
"""Use the tool asynchronously."""
wrapped_input = self._wrap_input(tool_input)
return await super().arun(wrapped_input, verbose, start_color, color, **kwargs)
def __call__(self, tool_input: str) -> str:
"""Make tool callable."""
def __call__(self, tool_input: Union[dict, str]) -> str:
return self.run(tool_input)

View File

@@ -1,22 +1,22 @@
from pathlib import Path
from typing import Optional, Type
from pydantic import BaseModel, Field
from pydantic import Field
from langchain.tools.base import BaseTool
from langchain.tools.base import BaseTool, StringSchema
from langchain.tools.file_management.utils import get_validated_relative_path
class ReadFileInput(BaseModel):
class ReadFileInput(StringSchema):
"""Input for ReadFileTool."""
file_path: str = Field(..., description="name of file")
tool_input: str = Field(..., description="name of file", alias="file_path")
class ReadFileTool(BaseTool):
name: str = "read_file"
args_schema: Type[BaseModel] = ReadFileInput
description: str = "Read file from disk"
args_schema: Type[ReadFileInput] = ReadFileInput
description: str = "(file_path: str) -> str, Read file from disk."
root_dir: Optional[str] = None
"""Directory to read file from.
@@ -35,6 +35,6 @@ class ReadFileTool(BaseTool):
except Exception as e:
return "Error: " + str(e)
async def _arun(self, tool_input: str) -> str:
async def _arun(self, file_path: str) -> str:
# TODO: Add aiofiles method
raise NotImplementedError

View File

@@ -3,8 +3,8 @@ from typing import Optional, Type
from pydantic import BaseModel, Field
from langchain.tools.base import BaseTool
from langchain.tools.file_management.utils import get_validated_relative_path
from langchain.tools.structured import BaseStructuredTool
class WriteFileInput(BaseModel):
@@ -14,7 +14,7 @@ class WriteFileInput(BaseModel):
text: str = Field(..., description="text to write to file")
class WriteFileTool(BaseTool):
class WriteFileTool(BaseStructuredTool[str]):
name: str = "write_file"
args_schema: Type[BaseModel] = WriteFileInput
description: str = "Write file to disk"

View File

@@ -127,11 +127,11 @@ class ListPowerBITool(BaseTool):
arbitrary_types_allowed = True
def _run(self, *args: Any, **kwargs: Any) -> str:
def _run(self, tool_input: str = "") -> str:
"""Get the names of the tables."""
return ", ".join(self.powerbi.get_table_names())
async def _arun(self, *args: Any, **kwargs: Any) -> str:
async def _arun(self, tool_input: str = "") -> str:
"""Get the names of the tables."""
return ", ".join(self.powerbi.get_table_names())

View File

@@ -0,0 +1,332 @@
from __future__ import annotations
import logging
from abc import abstractmethod
from inspect import Parameter, signature
from typing import (
Any,
Awaitable,
Callable,
Dict,
Generic,
Optional,
Sequence,
Tuple,
Type,
TypeVar,
Union,
)
from pydantic import BaseModel, Extra, Field, create_model, validate_arguments
from pydantic.generics import GenericModel
from langchain.callbacks import get_callback_manager
from langchain.callbacks.base import BaseCallbackManager
from langchain.utilities.async_utils import async_or_sync_call
logger = logging.getLogger(__name__)
OUTPUT_T = TypeVar("OUTPUT_T")
class BaseStructuredTool(
GenericModel,
Generic[OUTPUT_T],
BaseModel,
):
"""Parent class for all structured tools."""
name: str
description: str
return_direct: bool = False
verbose: bool = False
callback_manager: BaseCallbackManager = Field(default_factory=get_callback_manager)
args_schema: Type[BaseModel] # :meta private:
class Config:
"""Configuration for this pydantic object."""
extra = Extra.forbid
arbitrary_types_allowed = True
@property
def args(self) -> Dict:
    """Return the ``"properties"`` entry of ``args_schema.schema()``."""
    schema = self.args_schema.schema()
    return schema["properties"]
def _get_verbosity(
self,
verbose: Optional[bool] = None,
) -> bool:
if not self.verbose and verbose is not None:
verbose_ = verbose
else:
verbose_ = self.verbose
return verbose_
def _prepare_input(self, input_: dict) -> Tuple[Sequence, Dict]:
"""Prepare the args and kwargs for the tool."""
return (), input_
@abstractmethod
def _run(self, *args: Any, **kwargs: Any) -> OUTPUT_T:
"""Use the tool."""
@abstractmethod
async def _arun(self, *args: Any, **kwargs: Any) -> OUTPUT_T:
"""Use the tool asynchronously."""
def _parse_input(self, tool_input: dict) -> dict:
parsed_input = self.args_schema.parse_obj(tool_input)
parsed_dict = parsed_input.dict()
return {k: getattr(parsed_input, k) for k in parsed_dict.keys()}
def run(
self,
tool_input: dict,
verbose: Optional[bool] = None,
start_color: Optional[str] = "green",
color: Optional[str] = "green",
**kwargs: Any,
) -> OUTPUT_T:
"""Run the tool."""
tool_input = self._parse_input(tool_input)
verbose_ = self._get_verbosity(verbose)
self.callback_manager.on_tool_start(
{"name": self.name, "description": self.description},
str(tool_input),
verbose=verbose_,
color=start_color,
**kwargs,
)
try:
args, kwargs = self._prepare_input(tool_input)
observation = self._run(*args, **kwargs)
except (Exception, KeyboardInterrupt) as e:
self.callback_manager.on_tool_error(e, verbose=verbose_)
raise e
self.callback_manager.on_tool_end(
str(observation), verbose=verbose_, color=color, name=self.name, **kwargs
)
return observation
async def arun(
self,
tool_input: dict,
verbose: Optional[bool] = None,
start_color: Optional[str] = "green",
color: Optional[str] = "green",
**kwargs: Any,
) -> OUTPUT_T:
"""Run the tool asynchronously."""
tool_input = self._parse_input(tool_input)
verbose_ = self._get_verbosity(verbose)
await async_or_sync_call(
self.callback_manager.on_tool_start,
{"name": self.name, "description": self.description},
str(tool_input),
verbose=verbose_,
color=start_color,
is_async=self.callback_manager.is_async,
**kwargs,
)
try:
args, kwargs = self._prepare_input(tool_input)
observation = await self._arun(*args, **kwargs)
except (Exception, KeyboardInterrupt) as e:
await async_or_sync_call(
self.callback_manager.on_tool_error,
e,
verbose=verbose_,
is_async=self.callback_manager.is_async,
)
raise e
await async_or_sync_call(
self.callback_manager.on_tool_end,
str(observation),
verbose=verbose_,
color=color,
is_async=self.callback_manager.is_async,
**kwargs,
)
return observation
def __call__(self, tool_input: dict) -> OUTPUT_T:
"""Make tool callable."""
return self.run(tool_input)
def _create_subset_model(
    name: str, model: Type[BaseModel], field_names: list
) -> Type[BaseModel]:
    """Create a pydantic model with only a subset of model's fields.

    Args:
        name: Name of the model to create.
        model: Model class whose fields are copied.
        field_names: Names of the fields to keep; names that are not fields
            of ``model`` are ignored.

    Returns:
        A new model class containing only the requested fields.
    """
    fields = {}
    for field_name in field_names:
        field = model.__fields__.get(field_name)
        if field is None:
            continue
        # ``outer_type_`` keeps container types such as ``List[int]`` (``type_``
        # would collapse them to ``int``); ``field_info`` keeps the original
        # default *and* required-ness (``field.default`` is ``None`` for
        # required fields, which would silently make them optional).
        fields[field_name] = (field.outer_type_, field.field_info)
    return create_model(name, **fields)  # type: ignore
def get_filtered_args(inferred_model: Type[BaseModel], func: Callable) -> dict:
    """Return the schema properties that correspond to ``func``'s parameters.

    The result is keyed by parameter name, in signature order; every
    parameter of ``func`` must be present in the model's schema properties.
    """
    properties = inferred_model.schema()["properties"]
    selected = {}
    for param_name in signature(func).parameters:
        selected[param_name] = properties[param_name]
    return selected
def _warn_args_kwargs(func: Callable) -> None:
# Check if the function has *args or **kwargs
sig = signature(func)
for param in sig.parameters.values():
if param.kind == Parameter.VAR_POSITIONAL:
logger.warning(f"{func.__name__} uses *args, which are not well supported.")
elif param.kind == Parameter.VAR_KEYWORD:
logger.warning(
f"{func.__name__} uses **kwargs, which are not well supported."
)
def create_schema_from_function(model_name: str, func: Callable) -> Type[BaseModel]:
    """Build a pydantic model describing the parameters of ``func``.

    The resulting model is named ``f"{model_name}Schema"``.
    """
    _warn_args_kwargs(func)
    signature_model = validate_arguments(func).model  # type: ignore
    # Pydantic adds placeholder virtual fields we need to strip, so keep only
    # the names that really are parameters of ``func``.
    parameter_names = list(get_filtered_args(signature_model, func))
    schema_name = f"{model_name}Schema"
    return _create_subset_model(schema_name, signature_model, parameter_names)
class StructuredTool(BaseStructuredTool[Any]):
    """StructuredTool that takes in function or coroutine directly."""

    func: Callable[..., Any]
    """The function to run when the tool is called."""
    coroutine: Optional[Callable[..., Awaitable[Any]]] = None
    """The asynchronous version of the function."""
    args_schema: Type[BaseModel]  # :meta private:

    @property
    def args(self) -> dict:
        """The JSON Schema arguments for the tool."""
        return self.args_schema.schema()["properties"]

    def _run(self, *args: Any, **kwargs: Any) -> Any:
        """Use the tool."""
        return self.func(*args, **kwargs)

    async def _arun(self, *args: Any, **kwargs: Any) -> Any:
        """Use the tool asynchronously.

        Raises:
            NotImplementedError: if no ``coroutine`` was provided.
        """
        if self.coroutine:
            return await self.coroutine(*args, **kwargs)
        raise NotImplementedError(f"StructuredTool {self.name} does not support async")

    @classmethod
    def from_function(
        cls,
        func: Callable[..., Any],
        coroutine: Optional[Callable[..., Awaitable[Any]]] = None,
        return_direct: bool = False,
        args_schema: Optional[Type[BaseModel]] = None,
        infer_schema: bool = True,
        name: Optional[str] = None,
        description: Optional[str] = None,
    ) -> "StructuredTool":
        """Create a StructuredTool from a function.

        Args:
            func: The function to run when the tool is called.
            coroutine: The asynchronous version of the function.
            return_direct: Whether to return directly from the tool rather
                than continuing the agent loop.
            args_schema: optional argument schema for user to specify
            infer_schema: Whether to infer the schema of the arguments from
                the function's signature when ``args_schema`` is not given.
            name: The name of the tool. Defaults to the function name.
            description: The description of the tool. Defaults to the function
                docstring.

        Returns:
            The tool; its description is prefixed with the tool name and the
            signature of ``func``.

        Raises:
            ValueError: if neither ``description`` nor a docstring provides a
                non-blank description.
        """
        # An explicit description takes precedence; the docstring is only the
        # default (previously the docstring silently overrode the argument).
        description = description or func.__doc__
        if description is None or not description.strip():
            # ``functools.partial`` objects and other callables may not have
            # ``__name__``; fall back to their repr for the error message.
            func_label = getattr(func, "__name__", repr(func))
            raise ValueError(
                f"Function {func_label} must have a docstring, or set description."
            )
        name = name or func.__name__
        _args_schema = args_schema
        if _args_schema is None and infer_schema:
            _args_schema = create_schema_from_function(f"{name}Schema", func)
        # Description example:
        #   search_api(query: str) - Searches the API for the query.
        description = f"{name}{signature(func)} - {description}"
        return cls(
            name=name,
            func=func,
            coroutine=coroutine,
            return_direct=return_direct,
            args_schema=_args_schema,
            description=description,
        )
def structured_tool(
    *args: Union[str, Callable],
    return_direct: bool = False,
    args_schema: Optional[Type[BaseModel]] = None,
) -> Callable:
    """Make structured tools out of functions, can be used with or without arguments.

    Args:
        *args: Nothing, the tool name (a string), or the decorated function
            itself when used as a bare ``@structured_tool``.
        return_direct: Whether to return directly from the tool rather
            than continuing the agent loop.
        args_schema: Optional argument schema for user to specify. If
            none, will infer the schema from the function's signature.

    Requires:
        - Function must have a docstring

    Raises:
        ValueError: if the positional arguments match none of the supported
            forms.

    Examples:
        .. code-block:: python

            @structured_tool
            def search_api(query: str, limit: int = 10) -> str:
                '''Searches the API for the query.'''
                return

            @structured_tool("search", return_direct=True)
            def search_api(query: str, limit: int = 10) -> str:
                '''Searches the API for the query.'''
                return
    """

    def _make_with_name(tool_name: str) -> Callable:
        def _make_tool(func: Callable) -> StructuredTool:
            return StructuredTool.from_function(
                name=tool_name,
                func=func,
                args_schema=args_schema,
                return_direct=return_direct,
            )

        return _make_tool

    if len(args) == 1 and isinstance(args[0], str):
        # if the argument is a string, then we use the string as the tool name
        # Example usage: @structured_tool("search", return_direct=True)
        return _make_with_name(args[0])
    elif len(args) == 1 and callable(args[0]):
        # if the argument is a function, then we use the function name as the tool name
        # Example usage: @structured_tool
        return _make_with_name(args[0].__name__)(args[0])
    elif len(args) == 0:
        # if there are no arguments, then we use the function name as the tool name
        # Example usage: @structured_tool(return_direct=True)
        def _partial(func: Callable) -> StructuredTool:
            return _make_with_name(func.__name__)(func)

        return _partial
    else:
        raise ValueError("Too many arguments for tool decorator")

View File

@@ -0,0 +1,12 @@
"""Async utilities."""
from typing import Any, Callable
async def async_or_sync_call(
    method: Callable, *args: Any, is_async: bool, **kwargs: Any
) -> Any:
    """Invoke ``method`` and, when ``is_async`` is true, await what it returns.

    Lets async code drive callback-manager methods that may be either plain
    functions or coroutine functions.
    """
    outcome = method(*args, **kwargs)
    if is_async:
        outcome = await outcome
    return outcome

View File

@@ -1,9 +1,8 @@
"""Util that calls OpenWeatherMap using PyOWM."""
from typing import Any, Dict, Optional
from pydantic import Extra, root_validator
from pydantic import BaseModel, Extra, root_validator
from langchain.tools.base import BaseModel
from langchain.utils import get_from_dict_or_env

View File

@@ -1,14 +1,11 @@
"""Test tool utils."""
from datetime import datetime
from functools import partial
from typing import Optional, Type, Union
import pydantic
from typing import Optional
import pytest
from pydantic import BaseModel
from langchain.agents.tools import Tool, tool
from langchain.tools.base import BaseTool, SchemaAnnotationError
from langchain.tools.base import BaseTool, StringSchema
def test_unnamed_decorator() -> None:
@@ -23,169 +20,7 @@ def test_unnamed_decorator() -> None:
assert search_api.name == "search_api"
assert not search_api.return_direct
assert search_api("test") == "API result"
class _MockSchema(BaseModel):
arg1: int
arg2: bool
arg3: Optional[dict] = None
class _MockStructuredTool(BaseTool):
name = "structured_api"
args_schema: Type[BaseModel] = _MockSchema
description = "A Structured Tool"
def _run(self, arg1: int, arg2: bool, arg3: Optional[dict] = None) -> str:
return f"{arg1} {arg2} {arg3}"
async def _arun(self, arg1: int, arg2: bool, arg3: Optional[dict] = None) -> str:
raise NotImplementedError
def test_structured_args() -> None:
"""Test functionality with structured arguments."""
structured_api = _MockStructuredTool()
assert isinstance(structured_api, BaseTool)
assert structured_api.name == "structured_api"
expected_result = "1 True {'foo': 'bar'}"
args = {"arg1": 1, "arg2": True, "arg3": {"foo": "bar"}}
assert structured_api.run(args) == expected_result
def test_unannotated_base_tool_raises_error() -> None:
"""Test that a BaseTool without type hints raises an exception.""" ""
with pytest.raises(SchemaAnnotationError):
class _UnAnnotatedTool(BaseTool):
name = "structured_api"
# This would silently be ignored without the custom metaclass
args_schema = _MockSchema
description = "A Structured Tool"
def _run(self, arg1: int, arg2: bool, arg3: Optional[dict] = None) -> str:
return f"{arg1} {arg2} {arg3}"
async def _arun(
self, arg1: int, arg2: bool, arg3: Optional[dict] = None
) -> str:
raise NotImplementedError
def test_misannotated_base_tool_raises_error() -> None:
"""Test that a BaseTool with the incorrrect typehint raises an exception.""" ""
with pytest.raises(SchemaAnnotationError):
class _MisAnnotatedTool(BaseTool):
name = "structured_api"
# This would silently be ignored without the custom metaclass
args_schema: BaseModel = _MockSchema # type: ignore
description = "A Structured Tool"
def _run(self, arg1: int, arg2: bool, arg3: Optional[dict] = None) -> str:
return f"{arg1} {arg2} {arg3}"
async def _arun(
self, arg1: int, arg2: bool, arg3: Optional[dict] = None
) -> str:
raise NotImplementedError
def test_forward_ref_annotated_base_tool_accepted() -> None:
"""Test that a using forward ref annotation syntax is accepted.""" ""
class _ForwardRefAnnotatedTool(BaseTool):
name = "structured_api"
args_schema: "Type[BaseModel]" = _MockSchema
description = "A Structured Tool"
def _run(self, arg1: int, arg2: bool, arg3: Optional[dict] = None) -> str:
return f"{arg1} {arg2} {arg3}"
async def _arun(
self, arg1: int, arg2: bool, arg3: Optional[dict] = None
) -> str:
raise NotImplementedError
def test_subclass_annotated_base_tool_accepted() -> None:
"""Test BaseTool child w/ custom schema isn't overwritten."""
class _ForwardRefAnnotatedTool(BaseTool):
name = "structured_api"
args_schema: Type[_MockSchema] = _MockSchema
description = "A Structured Tool"
def _run(self, arg1: int, arg2: bool, arg3: Optional[dict] = None) -> str:
return f"{arg1} {arg2} {arg3}"
async def _arun(
self, arg1: int, arg2: bool, arg3: Optional[dict] = None
) -> str:
raise NotImplementedError
assert issubclass(_ForwardRefAnnotatedTool, BaseTool)
tool = _ForwardRefAnnotatedTool()
assert tool.args_schema == _MockSchema
def test_decorator_with_specified_schema() -> None:
"""Test that manually specified schemata are passed through to the tool."""
@tool(args_schema=_MockSchema)
def tool_func(arg1: int, arg2: bool, arg3: Optional[dict] = None) -> str:
"""Return the arguments directly."""
return f"{arg1} {arg2} {arg3}"
assert isinstance(tool_func, Tool)
assert tool_func.args_schema == _MockSchema
def test_decorated_function_schema_equivalent() -> None:
"""Test that a BaseTool without a schema meets expectations."""
@tool
def structured_tool_input(
arg1: int, arg2: bool, arg3: Optional[dict] = None
) -> str:
"""Return the arguments directly."""
return f"{arg1} {arg2} {arg3}"
assert isinstance(structured_tool_input, Tool)
assert (
structured_tool_input.args_schema.schema()["properties"]
== _MockSchema.schema()["properties"]
== structured_tool_input.args
)
def test_structured_args_decorator_no_infer_schema() -> None:
"""Test functionality with structured arguments parsed as a decorator."""
@tool(infer_schema=False)
def structured_tool_input(
arg1: int, arg2: Union[float, datetime], opt_arg: Optional[dict] = None
) -> str:
"""Return the arguments directly."""
return f"{arg1}, {arg2}, {opt_arg}"
assert isinstance(structured_tool_input, Tool)
assert structured_tool_input.name == "structured_tool_input"
args = {"arg1": 1, "arg2": 0.001, "opt_arg": {"foo": "bar"}}
expected_result = "1, 0.001, {'foo': 'bar'}"
assert structured_tool_input.run(args) == expected_result
def test_structured_single_str_decorator_no_infer_schema() -> None:
"""Test functionality with structured arguments parsed as a decorator."""
@tool(infer_schema=False)
def unstructured_tool_input(tool_input: str) -> str:
"""Return the arguments directly."""
return f"{tool_input}"
assert isinstance(unstructured_tool_input, Tool)
assert unstructured_tool_input.args_schema is None
assert search_api.args_schema == StringSchema
def test_base_tool_inheritance_base_schema() -> None:
@@ -202,7 +37,7 @@ def test_base_tool_inheritance_base_schema() -> None:
raise NotImplementedError
simple_tool = _MockSimpleTool()
assert simple_tool.args_schema is None
assert simple_tool.args_schema == StringSchema
expected_args = {"tool_input": {"title": "Tool Input", "type": "string"}}
assert simple_tool.args == expected_args
@@ -215,56 +50,11 @@ def test_tool_lambda_args_schema() -> None:
description="A tool",
func=lambda tool_input: tool_input,
)
assert tool.args_schema is None
expected_args = {"tool_input": {"title": "Tool Input"}}
assert tool.args_schema == StringSchema
expected_args = {"tool_input": {"title": "Tool Input", "type": "string"}}
assert tool.args == expected_args
def test_tool_lambda_multi_args_schema() -> None:
"""Test args schema inference when the tool argument is a lambda function."""
tool = Tool(
name="tool",
description="A tool",
func=lambda tool_input, other_arg: f"{tool_input}{other_arg}", # type: ignore
)
assert tool.args_schema is None
expected_args = {
"tool_input": {"title": "Tool Input"},
"other_arg": {"title": "Other Arg"},
}
assert tool.args == expected_args
def test_tool_partial_function_args_schema() -> None:
"""Test args schema inference when the tool argument is a partial function."""
def func(tool_input: str, other_arg: str) -> str:
return tool_input + other_arg
with pytest.raises(pydantic.error_wrappers.ValidationError):
# We don't yet support args_schema inference for partial functions
# so want to make sure we proactively raise an error
Tool(
name="tool",
description="A tool",
func=partial(func, other_arg="foo"),
)
def test_empty_args_decorator() -> None:
"""Test inferred schema of decorated fn with no args."""
@tool
def empty_tool_input() -> str:
"""Return a constant."""
return "the empty result"
assert isinstance(empty_tool_input, Tool)
assert empty_tool_input.name == "empty_tool_input"
assert empty_tool_input.args == {}
assert empty_tool_input.run({}) == "the empty result"
def test_named_tool_decorator() -> None:
"""Test functionality when arguments are provided as input to decorator."""
@@ -304,32 +94,17 @@ def test_unnamed_tool_decorator_return_direct() -> None:
assert search_api.return_direct
def test_tool_with_kwargs() -> None:
"""Test functionality when only return direct is provided."""
def test_base_tool_decorator_multiple_args() -> None:
"""Test the schema that's generated is still a simple string."""
@tool(return_direct=True)
def search_api(
arg_1: float,
ping: str = "hi",
) -> str:
def some_tool(query: str, foo: int = 3, bar: Optional[dict] = None) -> str:
"""Search the API for the query."""
return f"arg_1={arg_1}, ping={ping}"
return f"{query} {foo} {bar}"
assert isinstance(search_api, Tool)
result = search_api.run(
tool_input={
"arg_1": 3.2,
"ping": "pong",
}
)
assert result == "arg_1=3.2, ping=pong"
result = search_api.run(
tool_input={
"arg_1": 3.2,
}
)
assert result == "arg_1=3.2, ping=hi"
assert isinstance(some_tool, Tool)
assert some_tool.name == "some_tool"
assert some_tool.run("foo") == "foo 3 None"
def test_missing_docstring() -> None:

View File

@@ -12,10 +12,10 @@ def test_read_file_with_root_dir() -> None:
with (Path(temp_dir) / "file.txt").open("w") as f:
f.write("Hello, world!")
tool = ReadFileTool(root_dir=temp_dir)
result = tool.run("file.txt")
result = tool.run({"file_path": "file.txt"})
assert result == "Hello, world!"
# Check absolute files can still be passed if they lie within the root dir.
result = tool.run(str(Path(temp_dir) / "file.txt"))
result = tool.run({"file_path": str(Path(temp_dir) / "file.txt")})
assert result == "Hello, world!"
@@ -25,5 +25,5 @@ def test_read_file() -> None:
with (Path(temp_dir) / "file.txt").open("w") as f:
f.write("Hello, world!")
tool = ReadFileTool()
result = tool.run(str(Path(temp_dir) / "file.txt"))
result = tool.run({"file_path": str(Path(temp_dir) / "file.txt")})
assert result == "Hello, world!"

View File

@@ -0,0 +1,209 @@
import logging
from functools import partial
from typing import Any, Optional, Type
import pydantic
import pytest
from pydantic import BaseModel
from langchain.tools.structured import (
BaseStructuredTool,
StructuredTool,
structured_tool,
)
# Argument schema shared by the tests in this module: two required fields and
# one optional dict. (A "#" comment is used instead of a docstring so the
# generated pydantic schema is unaffected.)
class _MockSchema(BaseModel):
    arg1: int
    arg2: bool
    arg3: Optional[dict] = None
# Minimal concrete structured tool used by the tests; its input is validated
# against _MockSchema.
class _MockStructuredTool(BaseStructuredTool):
    name = "structured_api"
    args_schema: Type[BaseModel] = _MockSchema
    description = "A Structured Tool"

    def _run(self, arg1: int, arg2: bool, arg3: Optional[dict] = None) -> str:
        # Echo the received arguments so tests can assert on them.
        return f"{arg1} {arg2} {arg3}"

    async def _arun(self, arg1: int, arg2: bool, arg3: Optional[dict] = None) -> str:
        # The async path is not implemented for this mock.
        raise NotImplementedError
def test_structured_args() -> None:
    """Check that a dict input is validated and passed to ``_run`` by keyword."""
    tool = _MockStructuredTool()
    assert isinstance(tool, BaseStructuredTool)
    assert tool.name == "structured_api"
    payload = {"arg1": 1, "arg2": True, "arg3": {"foo": "bar"}}
    assert tool.run(payload) == "1 True {'foo': 'bar'}"
def test_subclass_annotated_base_tool_accepted() -> None:
    """Check that a narrower ``Type[_MockSchema]`` annotation keeps the schema."""

    class _SubclassAnnotatedTool(BaseStructuredTool):
        name = "structured_api"
        args_schema: Type[_MockSchema] = _MockSchema
        description = "A Structured Tool"

        def _run(self, arg1: int, arg2: bool, arg3: Optional[dict] = None) -> str:
            return f"{arg1} {arg2} {arg3}"

        async def _arun(
            self, arg1: int, arg2: bool, arg3: Optional[dict] = None
        ) -> str:
            raise NotImplementedError

    assert issubclass(_SubclassAnnotatedTool, BaseStructuredTool)
    instance = _SubclassAnnotatedTool()
    assert instance.args_schema == _MockSchema
def test_decorator_with_specified_schema() -> None:
    """Check that an explicit ``args_schema`` is passed through unchanged."""

    @structured_tool(args_schema=_MockSchema)
    def tool_func(arg1: int, arg2: bool, arg3: Optional[dict] = None) -> str:
        """Return the arguments directly."""
        return f"{arg1} {arg2} {arg3}"

    decorated = tool_func
    assert isinstance(decorated, StructuredTool)
    assert decorated.args_schema == _MockSchema
def test_decorated_function_schema_equivalent() -> None:
    """Check that the inferred schema matches the hand-written ``_MockSchema``."""

    @structured_tool
    def structured_tool_input(
        arg1: int, arg2: bool, arg3: Optional[dict] = None
    ) -> str:
        """Return the arguments directly."""
        return f"{arg1} {arg2} {arg3}"

    assert isinstance(structured_tool_input, StructuredTool)
    inferred_properties = structured_tool_input.args_schema.schema()["properties"]
    reference_properties = _MockSchema.schema()["properties"]
    assert inferred_properties == reference_properties
    assert reference_properties == structured_tool_input.args
def test_tool_lambda_multi_args_schema() -> None:
    """Check schema inference for an unannotated two-argument lambda."""
    lambda_tool = StructuredTool.from_function(
        func=lambda tool_input, other_arg: f"{tool_input}{other_arg}",  # type: ignore
        name="tool",
        description="A tool",
    )
    property_names = set(lambda_tool.args_schema.schema()["properties"])
    assert property_names == {"tool_input", "other_arg"}
    # Unannotated parameters carry a title but no "type" entry.
    assert lambda_tool.args == {
        "tool_input": {"title": "Tool Input"},
        "other_arg": {"title": "Other Arg"},
    }
def test_tool_partial_function_args_schema() -> None:
    """Check that building a tool around a partial function is rejected."""

    def func(tool_input: str, other_arg: str) -> str:
        return tool_input + other_arg

    bound_func = partial(func, other_arg="foo")
    # We don't yet support args_schema inference for partial functions
    # so want to make sure we proactively raise an error
    # NOTE(review): no ``args_schema`` is passed here, and that field is
    # required, so this would presumably raise for any callable - confirm the
    # test really exercises the partial-function case.
    with pytest.raises(pydantic.error_wrappers.ValidationError):
        StructuredTool(
            name="tool",
            description="A tool",
            func=bound_func,
        )
def test_tool_with_kwargs() -> None:
    """Check keyword dispatch and default handling with ``return_direct`` set."""

    @structured_tool(return_direct=True)
    def search_api(
        arg_1: float,
        ping: str = "hi",
    ) -> str:
        """Search the API for the query."""
        return f"arg_1={arg_1}, ping={ping}"

    assert isinstance(search_api, StructuredTool)

    # Both arguments supplied explicitly.
    with_ping = search_api.run(tool_input={"arg_1": 3.2, "ping": "pong"})
    assert with_ping == "arg_1=3.2, ping=pong"

    # Omitted optional argument falls back to the function default.
    without_ping = search_api.run(tool_input={"arg_1": 3.2})
    assert without_ping == "arg_1=3.2, ping=hi"
def test_empty_args_decorator() -> None:
    """Check that a zero-argument function yields an empty schema and still runs."""

    @structured_tool
    def empty_tool_input() -> str:
        """Return a constant."""
        return "the empty result"

    no_arg_tool = empty_tool_input
    assert isinstance(no_arg_tool, StructuredTool)
    assert no_arg_tool.name == "empty_tool_input"
    assert no_arg_tool.args == {}
    assert no_arg_tool.run({}) == "the empty result"
def test_nested_pydantic_args() -> None:
    """Check that nested pydantic arguments reach the function as models."""

    # This is a pattern that is common with FastAPI methods.
    # If we only parse a dict input but pass the dict
    # to the function, we are limited only to primitive types
    # in general.
    class SomeNestedInput(BaseModel):
        arg2: str

    class SomeInput(BaseModel):
        arg1: int
        arg2: SomeNestedInput

    @structured_tool
    def nested_tool(some_input: SomeInput) -> dict:
        """Return a constant."""
        # ``.dict()`` only exists if ``some_input`` arrived as a model.
        return some_input.dict()

    assert isinstance(nested_tool, StructuredTool)
    assert nested_tool.name == "nested_tool"
    payload = {"some_input": {"arg1": 1, "arg2": {"arg2": "foo"}}}
    assert nested_tool.run(payload) == payload["some_input"]
def test_warning_on_args_kwargs(caplog: pytest.LogCaptureFixture) -> None:
    """Check that decorating a ``*args``/``**kwargs`` function logs warnings."""
    with caplog.at_level(logging.WARNING):

        @structured_tool
        def anything_goes(*foo: Any, **bar: Any) -> str:
            """Return a constant."""
            return str(foo) + "|" + str(bar)

    # Check if the expected warning message was logged
    logged = [record.message for record in caplog.records]
    assert any("anything_goes uses *args" in message for message in logged)
    assert any("anything_goes uses **kwargs" in message for message in logged)