mirror of
https://github.com/hwchase17/langchain.git
synced 2026-04-20 05:04:50 +00:00
Compare commits
9 Commits
jacob/patc
...
vwp/chatre
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
d7e9b72380 | ||
|
|
203e97d789 | ||
|
|
f99348fb12 | ||
|
|
6bc9700863 | ||
|
|
4abeea3d42 | ||
|
|
e79e003218 | ||
|
|
37fca4ae75 | ||
|
|
c90ce64757 | ||
|
|
4cd6a2223a |
145
langchain/agents/structured_agent.py
Normal file
145
langchain/agents/structured_agent.py
Normal file
@@ -0,0 +1,145 @@
|
|||||||
|
"""Chain that takes in an input and produces an action and action input."""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
from abc import abstractmethod
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
|
||||||
|
|
||||||
|
import yaml
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
from langchain.callbacks.base import BaseCallbackManager
|
||||||
|
from langchain.schema import (
|
||||||
|
StructuredAgentAction,
|
||||||
|
AgentFinish,
|
||||||
|
BaseLanguageModel,
|
||||||
|
)
|
||||||
|
from langchain.tools.base import BaseTool
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class BaseSingleActionAgent(BaseModel):
|
||||||
|
"""Base Agent class."""
|
||||||
|
|
||||||
|
@property
|
||||||
|
def return_values(self) -> List[str]:
|
||||||
|
"""Return values of the agent."""
|
||||||
|
return ["output"]
|
||||||
|
|
||||||
|
def get_allowed_tools(self) -> Optional[List[str]]:
|
||||||
|
return None
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def plan(
|
||||||
|
self, intermediate_steps: List[Tuple[StructuredAgentAction, str]], **kwargs: Any
|
||||||
|
) -> Union[StructuredAgentAction, AgentFinish]:
|
||||||
|
"""Given input, decided what to do.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
intermediate_steps: Steps the LLM has taken to date,
|
||||||
|
along with observations
|
||||||
|
**kwargs: User inputs.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Action specifying what tool to use.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def aplan(
|
||||||
|
self, intermediate_steps: List[Tuple[StructuredAgentAction, str]], **kwargs: Any
|
||||||
|
) -> Union[StructuredAgentAction, AgentFinish]:
|
||||||
|
"""Given input, decided what to do.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
intermediate_steps: Steps the LLM has taken to date,
|
||||||
|
along with observations
|
||||||
|
**kwargs: User inputs.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Action specifying what tool to use.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@property
|
||||||
|
@abstractmethod
|
||||||
|
def input_keys(self) -> List[str]:
|
||||||
|
"""Return the input keys.
|
||||||
|
|
||||||
|
:meta private:
|
||||||
|
"""
|
||||||
|
|
||||||
|
def return_stopped_response(
|
||||||
|
self,
|
||||||
|
early_stopping_method: str,
|
||||||
|
intermediate_steps: List[Tuple[StructuredAgentAction, str]],
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> AgentFinish:
|
||||||
|
"""Return response when agent has been stopped due to max iterations."""
|
||||||
|
if early_stopping_method == "force":
|
||||||
|
# `force` just returns a constant string
|
||||||
|
return AgentFinish(
|
||||||
|
{"output": "Agent stopped due to iteration limit or time limit."}, ""
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
f"Got unsupported early_stopping_method `{early_stopping_method}`"
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_llm_and_tools(
|
||||||
|
cls,
|
||||||
|
llm: BaseLanguageModel,
|
||||||
|
tools: Sequence[BaseTool],
|
||||||
|
callback_manager: Optional[BaseCallbackManager] = None,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> BaseSingleActionAgent:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@property
|
||||||
|
def _agent_type(self) -> str:
|
||||||
|
"""Return Identifier of agent type."""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def dict(self, **kwargs: Any) -> Dict:
|
||||||
|
"""Return dictionary representation of agent."""
|
||||||
|
_dict = super().dict()
|
||||||
|
_dict["_type"] = str(self._agent_type)
|
||||||
|
return _dict
|
||||||
|
|
||||||
|
def save(self, file_path: Union[Path, str]) -> None:
|
||||||
|
"""Save the agent.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
file_path: Path to file to save the agent to.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
# If working with agent executor
|
||||||
|
agent.agent.save(file_path="path/agent.yaml")
|
||||||
|
"""
|
||||||
|
# Convert file to Path object.
|
||||||
|
if isinstance(file_path, str):
|
||||||
|
save_path = Path(file_path)
|
||||||
|
else:
|
||||||
|
save_path = file_path
|
||||||
|
|
||||||
|
directory_path = save_path.parent
|
||||||
|
directory_path.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
# Fetch dictionary to save
|
||||||
|
agent_dict = self.dict()
|
||||||
|
|
||||||
|
if save_path.suffix == ".json":
|
||||||
|
with open(file_path, "w") as f:
|
||||||
|
json.dump(agent_dict, f, indent=4)
|
||||||
|
elif save_path.suffix == ".yaml":
|
||||||
|
with open(file_path, "w") as f:
|
||||||
|
yaml.dump(agent_dict, f, default_flow_style=False)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"{save_path} must be json or yaml")
|
||||||
|
|
||||||
|
def tool_run_logging_kwargs(self) -> Dict:
|
||||||
|
return {}
|
||||||
@@ -1,14 +1,12 @@
|
|||||||
"""Interface for tools."""
|
"""Interface for tools."""
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from inspect import signature
|
from inspect import signature
|
||||||
from typing import Any, Awaitable, Callable, Optional, Type, Union
|
from typing import Any, Awaitable, Callable, Optional, Union
|
||||||
|
|
||||||
from pydantic import BaseModel, validate_arguments, validator
|
from pydantic import validator
|
||||||
|
|
||||||
from langchain.tools.base import (
|
from langchain.tools.base import (
|
||||||
BaseTool,
|
BaseTool,
|
||||||
create_schema_from_function,
|
|
||||||
get_filtered_args,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -28,22 +26,14 @@ class Tool(BaseTool):
|
|||||||
raise ValueError("Partial functions not yet supported in tools.")
|
raise ValueError("Partial functions not yet supported in tools.")
|
||||||
return func
|
return func
|
||||||
|
|
||||||
@property
|
def _run(self, tool_input: str) -> str:
|
||||||
def args(self) -> dict:
|
|
||||||
if self.args_schema is not None:
|
|
||||||
return self.args_schema.schema()["properties"]
|
|
||||||
else:
|
|
||||||
inferred_model = validate_arguments(self.func).model # type: ignore
|
|
||||||
return get_filtered_args(inferred_model, self.func)
|
|
||||||
|
|
||||||
def _run(self, *args: Any, **kwargs: Any) -> str:
|
|
||||||
"""Use the tool."""
|
"""Use the tool."""
|
||||||
return self.func(*args, **kwargs)
|
return self.func(tool_input)
|
||||||
|
|
||||||
async def _arun(self, *args: Any, **kwargs: Any) -> str:
|
async def _arun(self, tool_input: str) -> str:
|
||||||
"""Use the tool asynchronously."""
|
"""Use the tool asynchronously."""
|
||||||
if self.coroutine:
|
if self.coroutine:
|
||||||
return await self.coroutine(*args, **kwargs)
|
return await self.coroutine(tool_input)
|
||||||
raise NotImplementedError("Tool does not support async")
|
raise NotImplementedError("Tool does not support async")
|
||||||
|
|
||||||
# TODO: this is for backwards compatibility, remove in future
|
# TODO: this is for backwards compatibility, remove in future
|
||||||
@@ -74,8 +64,6 @@ class InvalidTool(BaseTool):
|
|||||||
def tool(
|
def tool(
|
||||||
*args: Union[str, Callable],
|
*args: Union[str, Callable],
|
||||||
return_direct: bool = False,
|
return_direct: bool = False,
|
||||||
args_schema: Optional[Type[BaseModel]] = None,
|
|
||||||
infer_schema: bool = True,
|
|
||||||
) -> Callable:
|
) -> Callable:
|
||||||
"""Make tools out of functions, can be used with or without arguments.
|
"""Make tools out of functions, can be used with or without arguments.
|
||||||
|
|
||||||
@@ -83,10 +71,6 @@ def tool(
|
|||||||
*args: The arguments to the tool.
|
*args: The arguments to the tool.
|
||||||
return_direct: Whether to return directly from the tool rather
|
return_direct: Whether to return directly from the tool rather
|
||||||
than continuing the agent loop.
|
than continuing the agent loop.
|
||||||
args_schema: optional argument schema for user to specify
|
|
||||||
infer_schema: Whether to infer the schema of the arguments from
|
|
||||||
the function's signature. This also makes the resultant tool
|
|
||||||
accept a dictionary input to its `run()` function.
|
|
||||||
|
|
||||||
Requires:
|
Requires:
|
||||||
- Function must be of type (str) -> str
|
- Function must be of type (str) -> str
|
||||||
@@ -112,13 +96,9 @@ def tool(
|
|||||||
# Description example:
|
# Description example:
|
||||||
# search_api(query: str) - Searches the API for the query.
|
# search_api(query: str) - Searches the API for the query.
|
||||||
description = f"{tool_name}{signature(func)} - {func.__doc__.strip()}"
|
description = f"{tool_name}{signature(func)} - {func.__doc__.strip()}"
|
||||||
_args_schema = args_schema
|
|
||||||
if _args_schema is None and infer_schema:
|
|
||||||
_args_schema = create_schema_from_function(f"{tool_name}Schema", func)
|
|
||||||
tool_ = Tool(
|
tool_ = Tool(
|
||||||
name=tool_name,
|
name=tool_name,
|
||||||
func=func,
|
func=func,
|
||||||
args_schema=_args_schema,
|
|
||||||
description=description,
|
description=description,
|
||||||
return_direct=return_direct,
|
return_direct=return_direct,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -41,10 +41,21 @@ class AgentAction(NamedTuple):
|
|||||||
"""Agent's action to take."""
|
"""Agent's action to take."""
|
||||||
|
|
||||||
tool: str
|
tool: str
|
||||||
tool_input: Union[str, dict]
|
tool_input: str
|
||||||
log: str
|
log: str
|
||||||
|
|
||||||
|
|
||||||
|
class StructuredAgentAction(NamedTuple):
|
||||||
|
"""Agent's action to take."""
|
||||||
|
|
||||||
|
tool: str
|
||||||
|
tool_input: dict
|
||||||
|
log: str
|
||||||
|
|
||||||
|
def to_agent_action(self) -> AgentAction:
|
||||||
|
return AgentAction(self.tool, str(self.tool_input), self.log)
|
||||||
|
|
||||||
|
|
||||||
class AgentFinish(NamedTuple):
|
class AgentFinish(NamedTuple):
|
||||||
"""Agent's return value."""
|
"""Agent's return value."""
|
||||||
|
|
||||||
|
|||||||
@@ -1,242 +1,57 @@
|
|||||||
"""Base implementation for tools or skills."""
|
"""Base implementation for tools or skills."""
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC
|
||||||
from inspect import signature
|
from typing import Any, Dict, Type, Union
|
||||||
from typing import Any, Callable, Dict, Optional, Sequence, Tuple, Type, Union
|
|
||||||
|
|
||||||
from pydantic import (
|
|
||||||
BaseModel,
|
|
||||||
Extra,
|
|
||||||
Field,
|
|
||||||
create_model,
|
|
||||||
validate_arguments,
|
|
||||||
validator,
|
|
||||||
)
|
|
||||||
from pydantic.main import ModelMetaclass
|
|
||||||
|
|
||||||
from langchain.callbacks import get_callback_manager
|
|
||||||
from langchain.callbacks.base import BaseCallbackManager
|
|
||||||
|
|
||||||
|
|
||||||
def _to_args_and_kwargs(run_input: Union[str, Dict]) -> Tuple[Sequence, dict]:
|
from langchain.tools.structured import BaseStructuredTool
|
||||||
# For backwards compatability, if run_input is a string,
|
|
||||||
# pass as a positional argument.
|
|
||||||
if isinstance(run_input, str):
|
|
||||||
return (run_input,), {}
|
|
||||||
else:
|
|
||||||
return [], run_input
|
|
||||||
|
|
||||||
|
|
||||||
class SchemaAnnotationError(TypeError):
|
class BaseTool(ABC, BaseStructuredTool[str, str]):
|
||||||
"""Raised when 'args_schema' is missing or has an incorrect type annotation."""
|
|
||||||
|
|
||||||
|
|
||||||
class ToolMetaclass(ModelMetaclass):
|
|
||||||
"""Metaclass for BaseTool to ensure the provided args_schema
|
|
||||||
|
|
||||||
doesn't silently ignored."""
|
|
||||||
|
|
||||||
def __new__(
|
|
||||||
cls: Type[ToolMetaclass], name: str, bases: Tuple[Type, ...], dct: dict
|
|
||||||
) -> ToolMetaclass:
|
|
||||||
"""Create the definition of the new tool class."""
|
|
||||||
schema_type: Optional[Type[BaseModel]] = dct.get("args_schema")
|
|
||||||
if schema_type is not None:
|
|
||||||
schema_annotations = dct.get("__annotations__", {})
|
|
||||||
args_schema_type = schema_annotations.get("args_schema", None)
|
|
||||||
if args_schema_type is None or args_schema_type == BaseModel:
|
|
||||||
# Throw errors for common mis-annotations.
|
|
||||||
# TODO: Use get_args / get_origin and fully
|
|
||||||
# specify valid annotations.
|
|
||||||
typehint_mandate = """
|
|
||||||
class ChildTool(BaseTool):
|
|
||||||
...
|
|
||||||
args_schema: Type[BaseModel] = SchemaClass
|
|
||||||
..."""
|
|
||||||
raise SchemaAnnotationError(
|
|
||||||
f"Tool definition for {name} must include valid type annotations"
|
|
||||||
f" for argument 'args_schema' to behave as expected.\n"
|
|
||||||
f"Expected annotation of 'Type[BaseModel]'"
|
|
||||||
f" but got '{args_schema_type}'.\n"
|
|
||||||
f"Expected class looks like:\n"
|
|
||||||
f"{typehint_mandate}"
|
|
||||||
)
|
|
||||||
# Pass through to Pydantic's metaclass
|
|
||||||
return super().__new__(cls, name, bases, dct)
|
|
||||||
|
|
||||||
|
|
||||||
def _create_subset_model(
|
|
||||||
name: str, model: BaseModel, field_names: list
|
|
||||||
) -> Type[BaseModel]:
|
|
||||||
"""Create a pydantic model with only a subset of model's fields."""
|
|
||||||
fields = {
|
|
||||||
field_name: (
|
|
||||||
model.__fields__[field_name].type_,
|
|
||||||
model.__fields__[field_name].default,
|
|
||||||
)
|
|
||||||
for field_name in field_names
|
|
||||||
if field_name in model.__fields__
|
|
||||||
}
|
|
||||||
return create_model(name, **fields) # type: ignore
|
|
||||||
|
|
||||||
|
|
||||||
def get_filtered_args(inferred_model: Type[BaseModel], func: Callable) -> dict:
|
|
||||||
"""Get the arguments from a function's signature."""
|
|
||||||
schema = inferred_model.schema()["properties"]
|
|
||||||
valid_keys = signature(func).parameters
|
|
||||||
return {k: schema[k] for k in valid_keys}
|
|
||||||
|
|
||||||
|
|
||||||
def create_schema_from_function(model_name: str, func: Callable) -> Type[BaseModel]:
|
|
||||||
"""Create a pydantic schema from a function's signature."""
|
|
||||||
inferred_model = validate_arguments(func).model # type: ignore
|
|
||||||
# Pydantic adds placeholder virtual fields we need to strip
|
|
||||||
filtered_args = get_filtered_args(inferred_model, func)
|
|
||||||
return _create_subset_model(
|
|
||||||
f"{model_name}Schema", inferred_model, list(filtered_args)
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class BaseTool(ABC, BaseModel, metaclass=ToolMetaclass):
|
|
||||||
"""Interface LangChain tools must implement."""
|
"""Interface LangChain tools must implement."""
|
||||||
|
|
||||||
name: str
|
args_schema: Type[str] = str # :meta private:
|
||||||
description: str
|
|
||||||
args_schema: Optional[Type[BaseModel]] = None
|
|
||||||
"""Pydantic model class to validate and parse the tool's input arguments."""
|
|
||||||
return_direct: bool = False
|
|
||||||
verbose: bool = False
|
|
||||||
callback_manager: BaseCallbackManager = Field(default_factory=get_callback_manager)
|
|
||||||
|
|
||||||
class Config:
|
def _parse_input(self, tool_input: Dict) -> str:
|
||||||
"""Configuration for this pydantic object."""
|
"""Load the tool's input into a pydantic model."""
|
||||||
|
if len(tool_input) == 1:
|
||||||
|
# Make base tools more forwards compatible
|
||||||
|
result = next(iter(tool_input.values()))
|
||||||
|
if not isinstance(result, str):
|
||||||
|
raise ValueError(
|
||||||
|
f"Tool input {tool_input} must be a single string or dict."
|
||||||
|
)
|
||||||
|
return result
|
||||||
|
raise ValueError(f"Tool input {tool_input} must be a single string or dict.")
|
||||||
|
|
||||||
extra = Extra.forbid
|
def _wrap_input(self, tool_input: Union[str, Dict]) -> Dict:
|
||||||
arbitrary_types_allowed = True
|
"""Wrap the tool's input into a pydantic model."""
|
||||||
|
|
||||||
@property
|
|
||||||
def args(self) -> dict:
|
|
||||||
if self.args_schema is not None:
|
|
||||||
return self.args_schema.schema()["properties"]
|
|
||||||
else:
|
|
||||||
inferred_model = validate_arguments(self._run).model # type: ignore
|
|
||||||
return get_filtered_args(inferred_model, self._run)
|
|
||||||
|
|
||||||
def _parse_input(
|
|
||||||
self,
|
|
||||||
tool_input: Union[str, Dict],
|
|
||||||
) -> None:
|
|
||||||
"""Convert tool input to pydantic model."""
|
|
||||||
input_args = self.args_schema
|
|
||||||
if isinstance(tool_input, str):
|
if isinstance(tool_input, str):
|
||||||
if input_args is not None:
|
return {"tool_input": tool_input}
|
||||||
key_ = next(iter(input_args.__fields__.keys()))
|
|
||||||
input_args.validate({key_: tool_input})
|
|
||||||
else:
|
else:
|
||||||
if input_args is not None:
|
return tool_input
|
||||||
input_args.validate(tool_input)
|
|
||||||
|
|
||||||
@validator("callback_manager", pre=True, always=True)
|
|
||||||
def set_callback_manager(
|
|
||||||
cls, callback_manager: Optional[BaseCallbackManager]
|
|
||||||
) -> BaseCallbackManager:
|
|
||||||
"""If callback manager is None, set it.
|
|
||||||
|
|
||||||
This allows users to pass in None as callback manager, which is a nice UX.
|
|
||||||
"""
|
|
||||||
return callback_manager or get_callback_manager()
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def _run(self, *args: Any, **kwargs: Any) -> str:
|
|
||||||
"""Use the tool."""
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
async def _arun(self, *args: Any, **kwargs: Any) -> str:
|
|
||||||
"""Use the tool asynchronously."""
|
|
||||||
|
|
||||||
def run(
|
def run(
|
||||||
self,
|
self,
|
||||||
tool_input: Union[str, Dict],
|
tool_input: Union[str, Dict],
|
||||||
verbose: Optional[bool] = None,
|
verbose: bool | None = None,
|
||||||
start_color: Optional[str] = "green",
|
start_color: str | None = "green",
|
||||||
color: Optional[str] = "green",
|
color: str | None = "green",
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> str:
|
) -> str:
|
||||||
"""Run the tool."""
|
"""Use the tool."""
|
||||||
self._parse_input(tool_input)
|
wrapped_input = self._wrap_input(tool_input)
|
||||||
if not self.verbose and verbose is not None:
|
return super().run(wrapped_input, verbose, start_color, color, **kwargs)
|
||||||
verbose_ = verbose
|
|
||||||
else:
|
|
||||||
verbose_ = self.verbose
|
|
||||||
self.callback_manager.on_tool_start(
|
|
||||||
{"name": self.name, "description": self.description},
|
|
||||||
tool_input if isinstance(tool_input, str) else str(tool_input),
|
|
||||||
verbose=verbose_,
|
|
||||||
color=start_color,
|
|
||||||
**kwargs,
|
|
||||||
)
|
|
||||||
try:
|
|
||||||
tool_args, tool_kwargs = _to_args_and_kwargs(tool_input)
|
|
||||||
observation = self._run(*tool_args, **tool_kwargs)
|
|
||||||
except (Exception, KeyboardInterrupt) as e:
|
|
||||||
self.callback_manager.on_tool_error(e, verbose=verbose_)
|
|
||||||
raise e
|
|
||||||
self.callback_manager.on_tool_end(
|
|
||||||
observation, verbose=verbose_, color=color, name=self.name, **kwargs
|
|
||||||
)
|
|
||||||
return observation
|
|
||||||
|
|
||||||
async def arun(
|
async def arun(
|
||||||
self,
|
self,
|
||||||
tool_input: Union[str, Dict],
|
tool_input: Union[str, Dict],
|
||||||
verbose: Optional[bool] = None,
|
verbose: bool | None = None,
|
||||||
start_color: Optional[str] = "green",
|
start_color: str | None = "green",
|
||||||
color: Optional[str] = "green",
|
color: str | None = "green",
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> str:
|
) -> str:
|
||||||
"""Run the tool asynchronously."""
|
"""Use the tool asynchronously."""
|
||||||
self._parse_input(tool_input)
|
wrapped_input = self._wrap_input(tool_input)
|
||||||
if not self.verbose and verbose is not None:
|
return await super().arun(wrapped_input, verbose, start_color, color, **kwargs)
|
||||||
verbose_ = verbose
|
|
||||||
else:
|
|
||||||
verbose_ = self.verbose
|
|
||||||
if self.callback_manager.is_async:
|
|
||||||
await self.callback_manager.on_tool_start(
|
|
||||||
{"name": self.name, "description": self.description},
|
|
||||||
tool_input if isinstance(tool_input, str) else str(tool_input),
|
|
||||||
verbose=verbose_,
|
|
||||||
color=start_color,
|
|
||||||
**kwargs,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
self.callback_manager.on_tool_start(
|
|
||||||
{"name": self.name, "description": self.description},
|
|
||||||
tool_input if isinstance(tool_input, str) else str(tool_input),
|
|
||||||
verbose=verbose_,
|
|
||||||
color=start_color,
|
|
||||||
**kwargs,
|
|
||||||
)
|
|
||||||
try:
|
|
||||||
# We then call the tool on the tool input to get an observation
|
|
||||||
args, kwargs = _to_args_and_kwargs(tool_input)
|
|
||||||
observation = await self._arun(*args, **kwargs)
|
|
||||||
except (Exception, KeyboardInterrupt) as e:
|
|
||||||
if self.callback_manager.is_async:
|
|
||||||
await self.callback_manager.on_tool_error(e, verbose=verbose_)
|
|
||||||
else:
|
|
||||||
self.callback_manager.on_tool_error(e, verbose=verbose_)
|
|
||||||
raise e
|
|
||||||
if self.callback_manager.is_async:
|
|
||||||
await self.callback_manager.on_tool_end(
|
|
||||||
observation, verbose=verbose_, color=color, name=self.name, **kwargs
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
self.callback_manager.on_tool_end(
|
|
||||||
observation, verbose=verbose_, color=color, name=self.name, **kwargs
|
|
||||||
)
|
|
||||||
return observation
|
|
||||||
|
|
||||||
def __call__(self, tool_input: str) -> str:
|
|
||||||
"""Make tool callable."""
|
|
||||||
return self.run(tool_input)
|
|
||||||
|
|||||||
@@ -3,30 +3,30 @@ from typing import Optional, Type
|
|||||||
|
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
from langchain.tools.base import BaseTool
|
|
||||||
from langchain.tools.file_management.utils import get_validated_relative_path
|
from langchain.tools.file_management.utils import get_validated_relative_path
|
||||||
|
from langchain.tools.structured import BaseStructuredTool
|
||||||
|
|
||||||
|
|
||||||
class ReadFileInput(BaseModel):
|
class ReadFileInput(BaseModel):
|
||||||
"""Input for ReadFileTool."""
|
"""Input for ReadFileTool."""
|
||||||
|
|
||||||
file_path: str = Field(..., description="name of file")
|
file_path: Path = Field(..., description="name of file")
|
||||||
|
|
||||||
|
|
||||||
class ReadFileTool(BaseTool):
|
class ReadFileTool(BaseStructuredTool[ReadFileInput, str]):
|
||||||
name: str = "read_file"
|
name: str = "read_file"
|
||||||
args_schema: Type[BaseModel] = ReadFileInput
|
args_schema: Type[ReadFileInput] = ReadFileInput
|
||||||
description: str = "Read file from disk"
|
description: str = "Read file from disk"
|
||||||
root_dir: Optional[str] = None
|
root_dir: Optional[str] = None
|
||||||
"""Directory to read file from.
|
"""Directory to read file from.
|
||||||
|
|
||||||
If specified, raises an error for file_paths oustide root_dir."""
|
If specified, raises an error for file_paths oustide root_dir."""
|
||||||
|
|
||||||
def _run(self, file_path: str) -> str:
|
def _run(self, tool_input: ReadFileInput) -> str:
|
||||||
read_path = (
|
read_path = (
|
||||||
get_validated_relative_path(Path(self.root_dir), file_path)
|
get_validated_relative_path(Path(self.root_dir), tool_input.file_path)
|
||||||
if self.root_dir
|
if self.root_dir
|
||||||
else Path(file_path)
|
else tool_input.file_path
|
||||||
)
|
)
|
||||||
try:
|
try:
|
||||||
with read_path.open("r", encoding="utf-8") as f:
|
with read_path.open("r", encoding="utf-8") as f:
|
||||||
@@ -35,6 +35,6 @@ class ReadFileTool(BaseTool):
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
return "Error: " + str(e)
|
return "Error: " + str(e)
|
||||||
|
|
||||||
async def _arun(self, tool_input: str) -> str:
|
async def _arun(self, tool_input: ReadFileInput) -> str:
|
||||||
# TODO: Add aiofiles method
|
# TODO: Add aiofiles method
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
import sys
|
import sys
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
from typing import Union
|
||||||
|
|
||||||
|
|
||||||
def is_relative_to(path: Path, root: Path) -> bool:
|
def is_relative_to(path: Path, root: Path) -> bool:
|
||||||
@@ -14,7 +15,7 @@ def is_relative_to(path: Path, root: Path) -> bool:
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
def get_validated_relative_path(root: Path, user_path: str) -> Path:
|
def get_validated_relative_path(root: Path, user_path: Union[str, Path]) -> Path:
|
||||||
"""Resolve a relative path, raising an error if not within the root directory."""
|
"""Resolve a relative path, raising an error if not within the root directory."""
|
||||||
# Note, this still permits symlinks from outside that point within the root.
|
# Note, this still permits symlinks from outside that point within the root.
|
||||||
# Further validation would be needed if those are to be disallowed.
|
# Further validation would be needed if those are to be disallowed.
|
||||||
|
|||||||
@@ -3,40 +3,40 @@ from typing import Optional, Type
|
|||||||
|
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
from langchain.tools.base import BaseTool
|
|
||||||
from langchain.tools.file_management.utils import get_validated_relative_path
|
from langchain.tools.file_management.utils import get_validated_relative_path
|
||||||
|
from langchain.tools.structured import BaseStructuredTool
|
||||||
|
|
||||||
|
|
||||||
class WriteFileInput(BaseModel):
|
class WriteFileInput(BaseModel):
|
||||||
"""Input for WriteFileTool."""
|
"""Input for WriteFileTool."""
|
||||||
|
|
||||||
file_path: str = Field(..., description="name of file")
|
file_path: Path = Field(..., description="name of file")
|
||||||
text: str = Field(..., description="text to write to file")
|
text: str = Field(..., description="text to write to file")
|
||||||
|
|
||||||
|
|
||||||
class WriteFileTool(BaseTool):
|
class WriteFileTool(BaseStructuredTool[WriteFileInput, str]):
|
||||||
name: str = "write_file"
|
name: str = "write_file"
|
||||||
args_schema: Type[BaseModel] = WriteFileInput
|
args_schema: Type[WriteFileInput] = WriteFileInput
|
||||||
description: str = "Write file to disk"
|
description: str = "Write file to disk"
|
||||||
root_dir: Optional[str] = None
|
root_dir: Optional[str] = None
|
||||||
"""Directory to write file to.
|
"""Directory to write file to.
|
||||||
|
|
||||||
If specified, raises an error for file_paths oustide root_dir."""
|
If specified, raises an error for file_paths oustide root_dir."""
|
||||||
|
|
||||||
def _run(self, file_path: str, text: str) -> str:
|
def _run(self, tool_input: WriteFileInput) -> str:
|
||||||
write_path = (
|
write_path = (
|
||||||
get_validated_relative_path(Path(self.root_dir), file_path)
|
get_validated_relative_path(Path(self.root_dir), tool_input.file_path)
|
||||||
if self.root_dir
|
if self.root_dir
|
||||||
else Path(file_path)
|
else tool_input.file_path
|
||||||
)
|
)
|
||||||
try:
|
try:
|
||||||
write_path.parent.mkdir(exist_ok=True, parents=False)
|
write_path.parent.mkdir(exist_ok=True, parents=False)
|
||||||
with write_path.open("w", encoding="utf-8") as f:
|
with write_path.open("w", encoding="utf-8") as f:
|
||||||
f.write(text)
|
f.write(tool_input.text)
|
||||||
return f"File written successfully to {file_path}."
|
return f"File written successfully to {tool_input.file_path}."
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
return "Error: " + str(e)
|
return "Error: " + str(e)
|
||||||
|
|
||||||
async def _arun(self, file_path: str, text: str) -> str:
|
async def _arun(self, tool_input: WriteFileInput) -> str:
|
||||||
# TODO: Add aiofiles method
|
# TODO: Add aiofiles method
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|||||||
@@ -127,11 +127,11 @@ class ListPowerBITool(BaseTool):
|
|||||||
|
|
||||||
arbitrary_types_allowed = True
|
arbitrary_types_allowed = True
|
||||||
|
|
||||||
def _run(self, *args: Any, **kwargs: Any) -> str:
|
def _run(self, tool_input: str = "") -> str:
|
||||||
"""Get the names of the tables."""
|
"""Get the names of the tables."""
|
||||||
return ", ".join(self.powerbi.get_table_names())
|
return ", ".join(self.powerbi.get_table_names())
|
||||||
|
|
||||||
async def _arun(self, *args: Any, **kwargs: Any) -> str:
|
async def _arun(self, tool_input: str = "") -> str:
|
||||||
"""Get the names of the tables."""
|
"""Get the names of the tables."""
|
||||||
return ", ".join(self.powerbi.get_table_names())
|
return ", ".join(self.powerbi.get_table_names())
|
||||||
|
|
||||||
|
|||||||
344
langchain/tools/structured.py
Normal file
344
langchain/tools/structured.py
Normal file
@@ -0,0 +1,344 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from abc import abstractmethod
|
||||||
|
from functools import partial
|
||||||
|
from inspect import signature
|
||||||
|
from typing import (
|
||||||
|
Any,
|
||||||
|
Awaitable,
|
||||||
|
Callable,
|
||||||
|
Dict,
|
||||||
|
Generic,
|
||||||
|
Optional,
|
||||||
|
Type,
|
||||||
|
TypeVar,
|
||||||
|
Union,
|
||||||
|
)
|
||||||
|
|
||||||
|
from pydantic import (
|
||||||
|
BaseModel,
|
||||||
|
Extra,
|
||||||
|
Field,
|
||||||
|
create_model,
|
||||||
|
validate_arguments,
|
||||||
|
validator,
|
||||||
|
)
|
||||||
|
from pydantic.generics import GenericModel
|
||||||
|
from langchain.callbacks import get_callback_manager
|
||||||
|
from langchain.callbacks.base import BaseCallbackManager
|
||||||
|
|
||||||
|
from langchain.utilities.async_utils import async_or_sync_call
|
||||||
|
|
||||||
|
|
||||||
|
class SchemaAnnotationError(TypeError):
|
||||||
|
"""Raised when 'args_schema' is missing or has an incorrect type annotation."""
|
||||||
|
|
||||||
|
|
||||||
|
SCHEMA_T = TypeVar("SCHEMA_T", bound=Union[str, BaseModel])
|
||||||
|
OUTPUT_T = TypeVar("OUTPUT_T")
|
||||||
|
|
||||||
|
|
||||||
|
class BaseStructuredTool(
|
||||||
|
GenericModel,
|
||||||
|
Generic[SCHEMA_T, OUTPUT_T],
|
||||||
|
BaseModel,
|
||||||
|
):
|
||||||
|
"""Parent class for all structured tools."""
|
||||||
|
|
||||||
|
name: str
|
||||||
|
description: str
|
||||||
|
return_direct: bool = False
|
||||||
|
verbose: bool = False
|
||||||
|
callback_manager: BaseCallbackManager = Field(default_factory=get_callback_manager)
|
||||||
|
args_schema: Type[SCHEMA_T] # :meta private:
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
"""Configuration for this pydantic object."""
|
||||||
|
|
||||||
|
extra = Extra.forbid
|
||||||
|
arbitrary_types_allowed = True
|
||||||
|
|
||||||
|
@property
|
||||||
|
def args(self) -> Dict:
|
||||||
|
if isinstance(self.args_schema, BaseModel):
|
||||||
|
return self.args_schema.schema()["properties"]
|
||||||
|
else:
|
||||||
|
return {"tool_input": "str"}
|
||||||
|
|
||||||
|
def _parse_input(self, tool_input: Dict) -> SCHEMA_T:
|
||||||
|
"""Load the tool's input into a pydantic model."""
|
||||||
|
if not issubclass(self.args_schema, BaseModel):
|
||||||
|
raise ValueError(
|
||||||
|
f"Tool with args_schema of type {self.args_schema} must overwrite _parse_input."
|
||||||
|
)
|
||||||
|
# Ignore type because mypy doesn't connect the subclass to the generic SCHEMA_T
|
||||||
|
return self.args_schema.parse_obj(tool_input) # type: ignore
|
||||||
|
|
||||||
|
def _get_verbosity(
|
||||||
|
self,
|
||||||
|
verbose: Optional[bool] = None,
|
||||||
|
) -> bool:
|
||||||
|
if not self.verbose and verbose is not None:
|
||||||
|
verbose_ = verbose
|
||||||
|
else:
|
||||||
|
verbose_ = self.verbose
|
||||||
|
return verbose_
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def _run(self, input_: SCHEMA_T) -> OUTPUT_T:
|
||||||
|
"""Use the tool."""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def _arun(self, input_: SCHEMA_T) -> OUTPUT_T:
|
||||||
|
"""Use the tool asynchronously."""
|
||||||
|
|
||||||
|
def run(
|
||||||
|
self,
|
||||||
|
tool_input: dict,
|
||||||
|
verbose: Optional[bool] = None,
|
||||||
|
start_color: Optional[str] = "green",
|
||||||
|
color: Optional[str] = "green",
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> OUTPUT_T:
|
||||||
|
"""Run the tool."""
|
||||||
|
parsed_input = self._parse_input(tool_input)
|
||||||
|
verbose_ = self._get_verbosity(verbose)
|
||||||
|
self.callback_manager.on_tool_start(
|
||||||
|
{"name": self.name, "description": self.description},
|
||||||
|
str(tool_input),
|
||||||
|
verbose=verbose_,
|
||||||
|
color=start_color,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
observation = self._run(parsed_input)
|
||||||
|
except (Exception, KeyboardInterrupt) as e:
|
||||||
|
self.callback_manager.on_tool_error(e, verbose=verbose_)
|
||||||
|
raise e
|
||||||
|
self.callback_manager.on_tool_end(
|
||||||
|
str(observation), verbose=verbose_, color=color, name=self.name, **kwargs
|
||||||
|
)
|
||||||
|
return observation
|
||||||
|
|
||||||
|
async def arun(
|
||||||
|
self,
|
||||||
|
tool_input: dict,
|
||||||
|
verbose: Optional[bool] = None,
|
||||||
|
start_color: Optional[str] = "green",
|
||||||
|
color: Optional[str] = "green",
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> OUTPUT_T:
|
||||||
|
"""Run the tool asynchronously."""
|
||||||
|
parsed_input = self._parse_input(tool_input)
|
||||||
|
verbose_ = self._get_verbosity(verbose)
|
||||||
|
await async_or_sync_call(
|
||||||
|
self.callback_manager.on_tool_start,
|
||||||
|
{"name": self.name, "description": self.description},
|
||||||
|
str(parsed_input),
|
||||||
|
verbose=verbose_,
|
||||||
|
color=start_color,
|
||||||
|
is_async=self.callback_manager.is_async,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
# We then call the tool on the tool input to get an observation
|
||||||
|
observation = await self._arun(parsed_input)
|
||||||
|
except (Exception, KeyboardInterrupt) as e:
|
||||||
|
await async_or_sync_call(
|
||||||
|
self.callback_manager.on_tool_error,
|
||||||
|
e,
|
||||||
|
verbose=verbose_,
|
||||||
|
is_async=self.callback_manager.is_async,
|
||||||
|
)
|
||||||
|
raise e
|
||||||
|
await async_or_sync_call(
|
||||||
|
self.callback_manager.on_tool_end,
|
||||||
|
str(observation),
|
||||||
|
verbose=verbose_,
|
||||||
|
color=color,
|
||||||
|
is_async=self.callback_manager.is_async,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
return observation
|
||||||
|
|
||||||
|
def __call__(self, tool_input: dict) -> OUTPUT_T:
|
||||||
|
"""Make tool callable."""
|
||||||
|
return self.run(tool_input)
|
||||||
|
|
||||||
|
|
||||||
|
def _create_subset_model(
|
||||||
|
name: str, model: BaseModel, field_names: list
|
||||||
|
) -> Type[BaseModel]:
|
||||||
|
"""Create a pydantic model with only a subset of model's fields."""
|
||||||
|
fields = {
|
||||||
|
field_name: (
|
||||||
|
model.__fields__[field_name].type_,
|
||||||
|
model.__fields__[field_name].default,
|
||||||
|
)
|
||||||
|
for field_name in field_names
|
||||||
|
if field_name in model.__fields__
|
||||||
|
}
|
||||||
|
return create_model(name, **fields) # type: ignore
|
||||||
|
|
||||||
|
|
||||||
|
def get_filtered_args(inferred_model: Type[BaseModel], func: Callable) -> dict:
|
||||||
|
"""Get the arguments from a function's signature."""
|
||||||
|
schema = inferred_model.schema()["properties"]
|
||||||
|
valid_keys = signature(func).parameters
|
||||||
|
return {k: schema[k] for k in valid_keys}
|
||||||
|
|
||||||
|
|
||||||
|
def create_schema_from_function(model_name: str, func: Callable) -> Type[BaseModel]:
|
||||||
|
"""Create a pydantic schema from a function's signature."""
|
||||||
|
inferred_model = validate_arguments(func).model # type: ignore
|
||||||
|
# Pydantic adds placeholder virtual fields we need to strip
|
||||||
|
filtered_args = get_filtered_args(inferred_model, func)
|
||||||
|
return _create_subset_model(
|
||||||
|
f"{model_name}Schema", inferred_model, list(filtered_args)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class StructuredTool(BaseStructuredTool[BaseModel, Any]):
|
||||||
|
"""StructuredTool that takes in function or coroutine directly."""
|
||||||
|
|
||||||
|
func: Callable[..., Any]
|
||||||
|
"""The function to run when the tool is called."""
|
||||||
|
coroutine: Optional[Callable[..., Awaitable[Any]]] = None
|
||||||
|
"""The asynchronous version of the function."""
|
||||||
|
args_schema: Type[BaseModel] # :meta private:
|
||||||
|
|
||||||
|
@validator("func", pre=True, always=True)
|
||||||
|
def validate_func_not_partial(cls, func: Callable) -> Callable:
|
||||||
|
"""Check that the function is not a partial."""
|
||||||
|
if isinstance(func, partial):
|
||||||
|
raise ValueError("Partial functions not yet supported in structured tools.")
|
||||||
|
return func
|
||||||
|
|
||||||
|
@property
|
||||||
|
def args(self) -> dict:
|
||||||
|
if self.args_schema is not None:
|
||||||
|
return self.args_schema.schema()["properties"]
|
||||||
|
else:
|
||||||
|
inferred_model = validate_arguments(self.func).model # type: ignore
|
||||||
|
return get_filtered_args(inferred_model, self.func)
|
||||||
|
|
||||||
|
def _run(self, tool_input: BaseModel) -> Any:
|
||||||
|
"""Use the tool."""
|
||||||
|
return self.func(**tool_input.dict())
|
||||||
|
|
||||||
|
async def _arun(self, tool_input: BaseModel) -> Any:
|
||||||
|
"""Use the tool asynchronously."""
|
||||||
|
if self.coroutine:
|
||||||
|
return await self.coroutine(**tool_input.dict())
|
||||||
|
raise NotImplementedError(f"StructuredTool {self.name} does not support async")
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_function(
|
||||||
|
cls,
|
||||||
|
func: Callable[..., Any],
|
||||||
|
coroutine: Optional[Callable[..., Awaitable[Any]]] = None,
|
||||||
|
return_direct: bool = False,
|
||||||
|
args_schema: Optional[Type[BaseModel]] = None,
|
||||||
|
infer_schema: bool = True,
|
||||||
|
name: Optional[str] = None,
|
||||||
|
description: Optional[str] = None,
|
||||||
|
) -> "StructuredTool":
|
||||||
|
"""Make tools out of functions, can be used with or without arguments.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
func: The function to run when the tool is called.
|
||||||
|
coroutine: The asynchronous version of the function.
|
||||||
|
return_direct: Whether to return directly from the tool rather
|
||||||
|
than continuing the agent loop.
|
||||||
|
args_schema: optional argument schema for user to specify
|
||||||
|
infer_schema: Whether to infer the schema of the arguments from
|
||||||
|
the function's signature. This also makes the resultant tool
|
||||||
|
accept a dictionary input to its `run()` function.
|
||||||
|
name: The name of the tool. Defaults to the function name.
|
||||||
|
description: The description of the tool. Defaults to the function
|
||||||
|
docstring.
|
||||||
|
"""
|
||||||
|
description = func.__doc__ or description
|
||||||
|
if description is None or not description.strip():
|
||||||
|
raise ValueError(
|
||||||
|
f"Function {func.__name__} must have a docstring, or set description."
|
||||||
|
)
|
||||||
|
name = name or func.__name__
|
||||||
|
_args_schema = args_schema
|
||||||
|
if _args_schema is None and infer_schema:
|
||||||
|
_args_schema = create_schema_from_function(f"{name}Schema", func)
|
||||||
|
description = f"{name}{signature(func)} - {description}"
|
||||||
|
return cls(
|
||||||
|
name=name,
|
||||||
|
func=func,
|
||||||
|
coroutine=coroutine,
|
||||||
|
return_direct=return_direct,
|
||||||
|
args_schema=_args_schema,
|
||||||
|
description=description,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def structured_tool(
|
||||||
|
*args: Union[str, Callable],
|
||||||
|
return_direct: bool = False,
|
||||||
|
args_schema: Optional[Type[BaseModel]] = None,
|
||||||
|
infer_schema: bool = True,
|
||||||
|
) -> Callable:
|
||||||
|
"""Make tools out of functions, can be used with or without arguments.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
*args: The arguments to the tool.
|
||||||
|
return_direct: Whether to return directly from the tool rather
|
||||||
|
than continuing the agent loop.
|
||||||
|
args_schema: optional argument schema for user to specify
|
||||||
|
infer_schema: Whether to infer the schema of the arguments from
|
||||||
|
the function's signature. This also makes the resultant tool
|
||||||
|
accept a dictionary input to its `run()` function.
|
||||||
|
|
||||||
|
Requires:
|
||||||
|
- Function must be of type (str) -> str
|
||||||
|
- Function must have a docstring
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
@tool
|
||||||
|
def search_api(query: str) -> str:
|
||||||
|
# Searches the API for the query.
|
||||||
|
return
|
||||||
|
|
||||||
|
@tool("search", return_direct=True)
|
||||||
|
def search_api(query: str) -> str:
|
||||||
|
# Searches the API for the query.
|
||||||
|
return
|
||||||
|
"""
|
||||||
|
|
||||||
|
def _make_with_name(tool_name: str) -> Callable:
|
||||||
|
def _make_tool(func: Callable) -> StructuredTool:
|
||||||
|
return StructuredTool.from_function(
|
||||||
|
name=tool_name,
|
||||||
|
func=func,
|
||||||
|
args_schema=args_schema,
|
||||||
|
return_direct=return_direct,
|
||||||
|
infer_schema=infer_schema,
|
||||||
|
)
|
||||||
|
|
||||||
|
return _make_tool
|
||||||
|
|
||||||
|
if len(args) == 1 and isinstance(args[0], str):
|
||||||
|
# if the argument is a string, then we use the string as the tool name
|
||||||
|
# Example usage: @tool("search", return_direct=True)
|
||||||
|
return _make_with_name(args[0])
|
||||||
|
elif len(args) == 1 and callable(args[0]):
|
||||||
|
# if the argument is a function, then we use the function name as the tool name
|
||||||
|
# Example usage: @tool
|
||||||
|
return _make_with_name(args[0].__name__)(args[0])
|
||||||
|
elif len(args) == 0:
|
||||||
|
# if there are no arguments, then we use the function name as the tool name
|
||||||
|
# Example usage: @tool(return_direct=True)
|
||||||
|
def _partial(func: Callable[[str], str]) -> BaseStructuredTool:
|
||||||
|
return _make_with_name(func.__name__)(func)
|
||||||
|
|
||||||
|
return _partial
|
||||||
|
else:
|
||||||
|
raise ValueError("Too many arguments for tool decorator")
|
||||||
12
langchain/utilities/async_utils.py
Normal file
12
langchain/utilities/async_utils.py
Normal file
@@ -0,0 +1,12 @@
|
|||||||
|
"""Async utilities."""
|
||||||
|
from typing import Any, Callable
|
||||||
|
|
||||||
|
|
||||||
|
async def async_or_sync_call(
|
||||||
|
method: Callable, *args: Any, is_async: bool, **kwargs: Any
|
||||||
|
) -> Any:
|
||||||
|
"""Run the callback manager method asynchronously or synchronously."""
|
||||||
|
if is_async:
|
||||||
|
return await method(*args, **kwargs)
|
||||||
|
else:
|
||||||
|
return method(*args, **kwargs)
|
||||||
@@ -1,9 +1,8 @@
|
|||||||
"""Util that calls OpenWeatherMap using PyOWM."""
|
"""Util that calls OpenWeatherMap using PyOWM."""
|
||||||
from typing import Any, Dict, Optional
|
from typing import Any, Dict, Optional
|
||||||
|
|
||||||
from pydantic import Extra, root_validator
|
from pydantic import BaseModel, Extra, root_validator
|
||||||
|
|
||||||
from langchain.tools.base import BaseModel
|
|
||||||
from langchain.utils import get_from_dict_or_env
|
from langchain.utils import get_from_dict_or_env
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
0
tests/regression_tests/__init__.py
Normal file
0
tests/regression_tests/__init__.py
Normal file
148
tests/regression_tests/multi_input_tools.py
Normal file
148
tests/regression_tests/multi_input_tools.py
Normal file
@@ -0,0 +1,148 @@
|
|||||||
|
"""Test the BaseOutputParser class and its sub-classes."""
|
||||||
|
|
||||||
|
from collections import defaultdict
|
||||||
|
import json
|
||||||
|
from copy import deepcopy
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import List, Tuple
|
||||||
|
from pydantic import ValidationError
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from langchain.agents import initialize_agent
|
||||||
|
from langchain.agents.agent_toolkits.json.toolkit import JsonToolkit
|
||||||
|
from langchain.agents.agent_toolkits.nla.toolkit import NLAToolkit
|
||||||
|
from langchain.agents.agent_toolkits.openapi.toolkit import RequestsToolkit
|
||||||
|
from langchain.agents.agent_types import AgentType
|
||||||
|
from langchain.llms.openai import OpenAI
|
||||||
|
from langchain.memory.buffer import ConversationBufferMemory
|
||||||
|
from langchain.requests import TextRequestsWrapper
|
||||||
|
from langchain.schema import BaseLanguageModel
|
||||||
|
from langchain.tools.base import BaseTool
|
||||||
|
from langchain.tools.json.tool import JsonSpec
|
||||||
|
|
||||||
|
|
||||||
|
def _get_requests_tools_and_questions(**kwargs) -> List[Tuple[BaseTool, List[str]]]:
|
||||||
|
requests_wrapper = TextRequestsWrapper()
|
||||||
|
requests_toolkit = RequestsToolkit(requests_wrapper=requests_wrapper)
|
||||||
|
tools = requests_toolkit.get_tools()
|
||||||
|
tools_dict = {tool.name: tool for tool in tools}
|
||||||
|
method_to_questions = {
|
||||||
|
# "get": ["Get the header of google.com"],
|
||||||
|
"post": ["Post data {'key': 'value'} to google.com"],
|
||||||
|
"patch": ["Patch data {'key': 'value'} to google.com"],
|
||||||
|
"put": ["Put data {'key': 'value'} to google.com"],
|
||||||
|
"delete": ["Delete data with ID 1234abc from google.com"],
|
||||||
|
}
|
||||||
|
results = []
|
||||||
|
for method, qs in method_to_questions.items():
|
||||||
|
results.append((tools_dict[f"requests_{method}"], qs))
|
||||||
|
return results
|
||||||
|
|
||||||
|
|
||||||
|
def _get_json_tools_and_questions(**kwargs) -> List[Tuple[BaseTool, List[str]]]:
|
||||||
|
spec = JsonSpec.from_file(
|
||||||
|
Path("tests/unit_tests/tools/openapi/test_specs/apis-guru/apispec.json")
|
||||||
|
)
|
||||||
|
json_toolkit = JsonToolkit(spec=spec)
|
||||||
|
list_keys, get_value = json_toolkit.get_tools()
|
||||||
|
return [
|
||||||
|
(list_keys, "What keys are in the JSON spec?"),
|
||||||
|
(get_value, "What's in the info.description?"),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def _get_nla_tools_nad_questions(
|
||||||
|
*,
|
||||||
|
llm: BaseLanguageModel,
|
||||||
|
) -> List[Tuple[BaseTool, List[str]]]:
|
||||||
|
speak_toolkit = NLAToolkit.from_llm_and_url(
|
||||||
|
llm, "https://api.speak.com/openapi.yaml"
|
||||||
|
)
|
||||||
|
# TODO: make more pointed questions
|
||||||
|
speak_tools_and_questions = [
|
||||||
|
(tool, ["Could you help me learn something new in Spanish?"])
|
||||||
|
for tool in speak_toolkit.get_tools()
|
||||||
|
]
|
||||||
|
klarna_toolkit = NLAToolkit.from_llm_and_url(
|
||||||
|
llm, "https://www.klarna.com/us/shopping/public/openai/v0/api-docs/"
|
||||||
|
)
|
||||||
|
klarna_tools_and_questions = [
|
||||||
|
(tool, ["I want to buy some cheap shoes"])
|
||||||
|
for tool in klarna_toolkit.get_tools()
|
||||||
|
]
|
||||||
|
return speak_tools_and_questions + klarna_tools_and_questions
|
||||||
|
|
||||||
|
|
||||||
|
def generate_tuples() -> (
|
||||||
|
List[Tuple[BaseTool, List[str], BaseLanguageModel, AgentType, bool]]
|
||||||
|
):
|
||||||
|
"""Grid test."""
|
||||||
|
llms = [
|
||||||
|
# ChatOpenAI(),
|
||||||
|
OpenAI(),
|
||||||
|
]
|
||||||
|
generators = [
|
||||||
|
# _get_nla_tools_nad_questions,
|
||||||
|
# _get_json_tools_and_questions,
|
||||||
|
_get_requests_tools_and_questions,
|
||||||
|
]
|
||||||
|
# These types don't really support arbitrary single tools...
|
||||||
|
# excluded_types = (AgentType.SELF_ASK_WITH_SEARCH, AgentType.REACT_DOCSTORE)
|
||||||
|
|
||||||
|
# agent_types = [
|
||||||
|
# agent_type for agent_type in AgentType if agent_type not in excluded_types
|
||||||
|
# ]
|
||||||
|
agent_types = [
|
||||||
|
# AgentType.CHAT_CONVERSATIONAL_REACT_DESCRIPTION,
|
||||||
|
AgentType.CHAT_ZERO_SHOT_REACT_DESCRIPTION,
|
||||||
|
]
|
||||||
|
results = []
|
||||||
|
for llm in llms:
|
||||||
|
for agent_type in agent_types:
|
||||||
|
for generator in generators:
|
||||||
|
tools_and_queries = generator(llm=llm)
|
||||||
|
for tool, queries in tools_and_queries:
|
||||||
|
results.append((tool, queries, llm, agent_type))
|
||||||
|
return results
|
||||||
|
|
||||||
|
|
||||||
|
_AGGREGATE_AXES = ["tool", "llm", "agent_type"]
|
||||||
|
_FAILURE_COUNT = {k: defaultdict(int) for k in _AGGREGATE_AXES}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("tool, queries, llm, agent_type", generate_tuples())
|
||||||
|
def test_run_tool(
|
||||||
|
tool: BaseTool,
|
||||||
|
queries: List[str],
|
||||||
|
llm: BaseLanguageModel,
|
||||||
|
agent_type: AgentType,
|
||||||
|
) -> None:
|
||||||
|
global _FAILURE_COUNT
|
||||||
|
tool = deepcopy(tool)
|
||||||
|
memory = ConversationBufferMemory(memory_key="chat_history", return_messages=True)
|
||||||
|
agent = initialize_agent(
|
||||||
|
llm=llm,
|
||||||
|
tools=[tool],
|
||||||
|
agent=agent_type,
|
||||||
|
memory=memory,
|
||||||
|
verbose=True,
|
||||||
|
)
|
||||||
|
results = []
|
||||||
|
for query in queries:
|
||||||
|
try:
|
||||||
|
result = agent(query)
|
||||||
|
results.append(result)
|
||||||
|
except Exception as e:
|
||||||
|
results.append(e)
|
||||||
|
|
||||||
|
type_errors = [r for r in results if isinstance(r, TypeError)]
|
||||||
|
if type_errors:
|
||||||
|
print(f"{str(llm)}: {tool.name} failed with: {type_errors}")
|
||||||
|
_FAILURE_COUNT["tool"][tool.name] += 1
|
||||||
|
_FAILURE_COUNT["llm"][str(llm)] += 1
|
||||||
|
_FAILURE_COUNT["agent_type"][str(agent_type)] += 1
|
||||||
|
|
||||||
|
assert not type_errors, type_errors
|
||||||
|
validation_errors = [r for r in results if isinstance(r, ValidationError)]
|
||||||
|
assert not validation_errors, validation_errors
|
||||||
Reference in New Issue
Block a user