Compare commits

...

9 Commits

Author SHA1 Message Date
vowelparrot
d7e9b72380 Merge branch 'master' into vwp/chatregtests 2023-04-25 17:15:55 -07:00
vowelparrot
203e97d789 update 2023-04-25 16:14:26 -07:00
vowelparrot
f99348fb12 foo 2023-04-25 15:43:46 -07:00
vowelparrot
6bc9700863 undo 2023-04-25 15:41:10 -07:00
vowelparrot
4abeea3d42 don't like 2023-04-25 14:36:55 -07:00
vowelparrot
e79e003218 Hm 2023-04-25 14:10:33 -07:00
vowelparrot
37fca4ae75 foo 2023-04-25 11:49:34 -07:00
vowelparrot
c90ce64757 hm 2023-04-25 10:14:33 -07:00
vowelparrot
4cd6a2223a MITRT 2023-04-25 10:14:13 -07:00
13 changed files with 722 additions and 267 deletions

View File

@@ -0,0 +1,145 @@
"""Chain that takes in an input and produces an action and action input."""
from __future__ import annotations
import json
import logging
from abc import abstractmethod
from pathlib import Path
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
import yaml
from pydantic import BaseModel
from langchain.callbacks.base import BaseCallbackManager
from langchain.schema import (
StructuredAgentAction,
AgentFinish,
BaseLanguageModel,
)
from langchain.tools.base import BaseTool
logger = logging.getLogger(__name__)
class BaseSingleActionAgent(BaseModel):
    """Base class for agents that decide on one action at a time.

    Subclasses must implement ``plan``, ``aplan`` and ``input_keys``. The
    serialization helpers (``dict`` / ``save``) additionally rely on
    ``_agent_type`` being overridden.
    """

    @property
    def return_values(self) -> List[str]:
        """Return values of the agent."""
        return ["output"]

    def get_allowed_tools(self) -> Optional[List[str]]:
        """Return the allowed tool names; this base implementation returns None."""
        return None

    @abstractmethod
    def plan(
        self, intermediate_steps: List[Tuple[StructuredAgentAction, str]], **kwargs: Any
    ) -> Union[StructuredAgentAction, AgentFinish]:
        """Given input, decide what to do.

        Args:
            intermediate_steps: Steps the LLM has taken to date,
                along with observations
            **kwargs: User inputs.

        Returns:
            Action specifying what tool to use.
        """

    @abstractmethod
    async def aplan(
        self, intermediate_steps: List[Tuple[StructuredAgentAction, str]], **kwargs: Any
    ) -> Union[StructuredAgentAction, AgentFinish]:
        """Given input, decide what to do (asynchronous variant of ``plan``).

        Args:
            intermediate_steps: Steps the LLM has taken to date,
                along with observations
            **kwargs: User inputs.

        Returns:
            Action specifying what tool to use.
        """

    @property
    @abstractmethod
    def input_keys(self) -> List[str]:
        """Return the input keys.

        :meta private:
        """

    def return_stopped_response(
        self,
        early_stopping_method: str,
        intermediate_steps: List[Tuple[StructuredAgentAction, str]],
        **kwargs: Any,
    ) -> AgentFinish:
        """Return response when agent has been stopped due to max iterations.

        Args:
            early_stopping_method: Only ``"force"`` is supported here.
            intermediate_steps: Steps taken so far (unused by this base
                implementation).
            **kwargs: User inputs (unused by this base implementation).

        Raises:
            ValueError: If ``early_stopping_method`` is not ``"force"``.
        """
        if early_stopping_method != "force":
            raise ValueError(
                f"Got unsupported early_stopping_method `{early_stopping_method}`"
            )
        # `force` just returns a constant string
        return AgentFinish(
            {"output": "Agent stopped due to iteration limit or time limit."}, ""
        )

    @classmethod
    def from_llm_and_tools(
        cls,
        llm: BaseLanguageModel,
        tools: Sequence[BaseTool],
        callback_manager: Optional[BaseCallbackManager] = None,
        **kwargs: Any,
    ) -> BaseSingleActionAgent:
        """Construct an agent from an LLM and tools; not implemented in the base."""
        raise NotImplementedError

    @property
    def _agent_type(self) -> str:
        """Return Identifier of agent type."""
        raise NotImplementedError

    def dict(self, **kwargs: Any) -> Dict:
        """Return dictionary representation of agent.

        Args:
            **kwargs: Forwarded to pydantic's ``BaseModel.dict`` (previously
                they were accepted but silently discarded).

        Returns:
            The model's dict with an extra ``"_type"`` entry holding the
            agent type identifier.
        """
        _dict = super().dict(**kwargs)
        _dict["_type"] = str(self._agent_type)
        return _dict

    def save(self, file_path: Union[Path, str]) -> None:
        """Save the agent.

        Args:
            file_path: Path to file to save the agent to. The suffix selects
                the format: ``.json`` or ``.yaml``.

        Raises:
            ValueError: If the suffix is neither ``.json`` nor ``.yaml``.

        Example:
        .. code-block:: python

            # If working with agent executor
            agent.agent.save(file_path="path/agent.yaml")
        """
        # Normalise once and use the normalised path everywhere below.
        save_path = Path(file_path)
        save_path.parent.mkdir(parents=True, exist_ok=True)

        # Fetch dictionary to save
        agent_dict = self.dict()

        if save_path.suffix == ".json":
            with open(save_path, "w") as f:
                json.dump(agent_dict, f, indent=4)
        elif save_path.suffix == ".yaml":
            with open(save_path, "w") as f:
                yaml.dump(agent_dict, f, default_flow_style=False)
        else:
            raise ValueError(f"{save_path} must be json or yaml")

    def tool_run_logging_kwargs(self) -> Dict:
        """Return extra keyword arguments for tool-run logging (none by default)."""
        return {}

View File

@@ -1,14 +1,12 @@
"""Interface for tools."""
from functools import partial
from inspect import signature
from typing import Any, Awaitable, Callable, Optional, Type, Union
from typing import Any, Awaitable, Callable, Optional, Union
from pydantic import BaseModel, validate_arguments, validator
from pydantic import validator
from langchain.tools.base import (
BaseTool,
create_schema_from_function,
get_filtered_args,
)
@@ -28,22 +26,14 @@ class Tool(BaseTool):
raise ValueError("Partial functions not yet supported in tools.")
return func
@property
def args(self) -> dict:
if self.args_schema is not None:
return self.args_schema.schema()["properties"]
else:
inferred_model = validate_arguments(self.func).model # type: ignore
return get_filtered_args(inferred_model, self.func)
def _run(self, *args: Any, **kwargs: Any) -> str:
def _run(self, tool_input: str) -> str:
"""Use the tool."""
return self.func(*args, **kwargs)
return self.func(tool_input)
async def _arun(self, *args: Any, **kwargs: Any) -> str:
async def _arun(self, tool_input: str) -> str:
"""Use the tool asynchronously."""
if self.coroutine:
return await self.coroutine(*args, **kwargs)
return await self.coroutine(tool_input)
raise NotImplementedError("Tool does not support async")
# TODO: this is for backwards compatibility, remove in future
@@ -74,8 +64,6 @@ class InvalidTool(BaseTool):
def tool(
*args: Union[str, Callable],
return_direct: bool = False,
args_schema: Optional[Type[BaseModel]] = None,
infer_schema: bool = True,
) -> Callable:
"""Make tools out of functions, can be used with or without arguments.
@@ -83,10 +71,6 @@ def tool(
*args: The arguments to the tool.
return_direct: Whether to return directly from the tool rather
than continuing the agent loop.
args_schema: optional argument schema for user to specify
infer_schema: Whether to infer the schema of the arguments from
the function's signature. This also makes the resultant tool
accept a dictionary input to its `run()` function.
Requires:
- Function must be of type (str) -> str
@@ -112,13 +96,9 @@ def tool(
# Description example:
# search_api(query: str) - Searches the API for the query.
description = f"{tool_name}{signature(func)} - {func.__doc__.strip()}"
_args_schema = args_schema
if _args_schema is None and infer_schema:
_args_schema = create_schema_from_function(f"{tool_name}Schema", func)
tool_ = Tool(
name=tool_name,
func=func,
args_schema=_args_schema,
description=description,
return_direct=return_direct,
)

View File

@@ -41,10 +41,21 @@ class AgentAction(NamedTuple):
"""Agent's action to take."""
tool: str
tool_input: Union[str, dict]
tool_input: str
log: str
class StructuredAgentAction(NamedTuple):
"""Agent's action to take."""
tool: str
tool_input: dict
log: str
def to_agent_action(self) -> AgentAction:
return AgentAction(self.tool, str(self.tool_input), self.log)
class AgentFinish(NamedTuple):
"""Agent's return value."""

View File

@@ -1,242 +1,57 @@
"""Base implementation for tools or skills."""
from __future__ import annotations
from abc import ABC, abstractmethod
from inspect import signature
from typing import Any, Callable, Dict, Optional, Sequence, Tuple, Type, Union
from pydantic import (
BaseModel,
Extra,
Field,
create_model,
validate_arguments,
validator,
)
from pydantic.main import ModelMetaclass
from langchain.callbacks import get_callback_manager
from langchain.callbacks.base import BaseCallbackManager
from abc import ABC
from typing import Any, Dict, Type, Union
def _to_args_and_kwargs(run_input: Union[str, Dict]) -> Tuple[Sequence, dict]:
# For backwards compatability, if run_input is a string,
# pass as a positional argument.
if isinstance(run_input, str):
return (run_input,), {}
else:
return [], run_input
from langchain.tools.structured import BaseStructuredTool
class SchemaAnnotationError(TypeError):
"""Raised when 'args_schema' is missing or has an incorrect type annotation."""
class ToolMetaclass(ModelMetaclass):
"""Metaclass for BaseTool to ensure the provided args_schema
doesn't silently ignored."""
def __new__(
cls: Type[ToolMetaclass], name: str, bases: Tuple[Type, ...], dct: dict
) -> ToolMetaclass:
"""Create the definition of the new tool class."""
schema_type: Optional[Type[BaseModel]] = dct.get("args_schema")
if schema_type is not None:
schema_annotations = dct.get("__annotations__", {})
args_schema_type = schema_annotations.get("args_schema", None)
if args_schema_type is None or args_schema_type == BaseModel:
# Throw errors for common mis-annotations.
# TODO: Use get_args / get_origin and fully
# specify valid annotations.
typehint_mandate = """
class ChildTool(BaseTool):
...
args_schema: Type[BaseModel] = SchemaClass
..."""
raise SchemaAnnotationError(
f"Tool definition for {name} must include valid type annotations"
f" for argument 'args_schema' to behave as expected.\n"
f"Expected annotation of 'Type[BaseModel]'"
f" but got '{args_schema_type}'.\n"
f"Expected class looks like:\n"
f"{typehint_mandate}"
)
# Pass through to Pydantic's metaclass
return super().__new__(cls, name, bases, dct)
def _create_subset_model(
name: str, model: BaseModel, field_names: list
) -> Type[BaseModel]:
"""Create a pydantic model with only a subset of model's fields."""
fields = {
field_name: (
model.__fields__[field_name].type_,
model.__fields__[field_name].default,
)
for field_name in field_names
if field_name in model.__fields__
}
return create_model(name, **fields) # type: ignore
def get_filtered_args(inferred_model: Type[BaseModel], func: Callable) -> dict:
"""Get the arguments from a function's signature."""
schema = inferred_model.schema()["properties"]
valid_keys = signature(func).parameters
return {k: schema[k] for k in valid_keys}
def create_schema_from_function(model_name: str, func: Callable) -> Type[BaseModel]:
"""Create a pydantic schema from a function's signature."""
inferred_model = validate_arguments(func).model # type: ignore
# Pydantic adds placeholder virtual fields we need to strip
filtered_args = get_filtered_args(inferred_model, func)
return _create_subset_model(
f"{model_name}Schema", inferred_model, list(filtered_args)
)
class BaseTool(ABC, BaseModel, metaclass=ToolMetaclass):
class BaseTool(ABC, BaseStructuredTool[str, str]):
"""Interface LangChain tools must implement."""
name: str
description: str
args_schema: Optional[Type[BaseModel]] = None
"""Pydantic model class to validate and parse the tool's input arguments."""
return_direct: bool = False
verbose: bool = False
callback_manager: BaseCallbackManager = Field(default_factory=get_callback_manager)
args_schema: Type[str] = str # :meta private:
class Config:
"""Configuration for this pydantic object."""
def _parse_input(self, tool_input: Dict) -> str:
"""Load the tool's input into a pydantic model."""
if len(tool_input) == 1:
# Make base tools more forwards compatible
result = next(iter(tool_input.values()))
if not isinstance(result, str):
raise ValueError(
f"Tool input {tool_input} must be a single string or dict."
)
return result
raise ValueError(f"Tool input {tool_input} must be a single string or dict.")
extra = Extra.forbid
arbitrary_types_allowed = True
@property
def args(self) -> dict:
if self.args_schema is not None:
return self.args_schema.schema()["properties"]
else:
inferred_model = validate_arguments(self._run).model # type: ignore
return get_filtered_args(inferred_model, self._run)
def _parse_input(
self,
tool_input: Union[str, Dict],
) -> None:
"""Convert tool input to pydantic model."""
input_args = self.args_schema
def _wrap_input(self, tool_input: Union[str, Dict]) -> Dict:
"""Wrap the tool's input into a pydantic model."""
if isinstance(tool_input, str):
if input_args is not None:
key_ = next(iter(input_args.__fields__.keys()))
input_args.validate({key_: tool_input})
return {"tool_input": tool_input}
else:
if input_args is not None:
input_args.validate(tool_input)
@validator("callback_manager", pre=True, always=True)
def set_callback_manager(
cls, callback_manager: Optional[BaseCallbackManager]
) -> BaseCallbackManager:
"""If callback manager is None, set it.
This allows users to pass in None as callback manager, which is a nice UX.
"""
return callback_manager or get_callback_manager()
@abstractmethod
def _run(self, *args: Any, **kwargs: Any) -> str:
"""Use the tool."""
@abstractmethod
async def _arun(self, *args: Any, **kwargs: Any) -> str:
"""Use the tool asynchronously."""
return tool_input
def run(
self,
tool_input: Union[str, Dict],
verbose: Optional[bool] = None,
start_color: Optional[str] = "green",
color: Optional[str] = "green",
verbose: bool | None = None,
start_color: str | None = "green",
color: str | None = "green",
**kwargs: Any,
) -> str:
"""Run the tool."""
self._parse_input(tool_input)
if not self.verbose and verbose is not None:
verbose_ = verbose
else:
verbose_ = self.verbose
self.callback_manager.on_tool_start(
{"name": self.name, "description": self.description},
tool_input if isinstance(tool_input, str) else str(tool_input),
verbose=verbose_,
color=start_color,
**kwargs,
)
try:
tool_args, tool_kwargs = _to_args_and_kwargs(tool_input)
observation = self._run(*tool_args, **tool_kwargs)
except (Exception, KeyboardInterrupt) as e:
self.callback_manager.on_tool_error(e, verbose=verbose_)
raise e
self.callback_manager.on_tool_end(
observation, verbose=verbose_, color=color, name=self.name, **kwargs
)
return observation
"""Use the tool."""
wrapped_input = self._wrap_input(tool_input)
return super().run(wrapped_input, verbose, start_color, color, **kwargs)
async def arun(
self,
tool_input: Union[str, Dict],
verbose: Optional[bool] = None,
start_color: Optional[str] = "green",
color: Optional[str] = "green",
verbose: bool | None = None,
start_color: str | None = "green",
color: str | None = "green",
**kwargs: Any,
) -> str:
"""Run the tool asynchronously."""
self._parse_input(tool_input)
if not self.verbose and verbose is not None:
verbose_ = verbose
else:
verbose_ = self.verbose
if self.callback_manager.is_async:
await self.callback_manager.on_tool_start(
{"name": self.name, "description": self.description},
tool_input if isinstance(tool_input, str) else str(tool_input),
verbose=verbose_,
color=start_color,
**kwargs,
)
else:
self.callback_manager.on_tool_start(
{"name": self.name, "description": self.description},
tool_input if isinstance(tool_input, str) else str(tool_input),
verbose=verbose_,
color=start_color,
**kwargs,
)
try:
# We then call the tool on the tool input to get an observation
args, kwargs = _to_args_and_kwargs(tool_input)
observation = await self._arun(*args, **kwargs)
except (Exception, KeyboardInterrupt) as e:
if self.callback_manager.is_async:
await self.callback_manager.on_tool_error(e, verbose=verbose_)
else:
self.callback_manager.on_tool_error(e, verbose=verbose_)
raise e
if self.callback_manager.is_async:
await self.callback_manager.on_tool_end(
observation, verbose=verbose_, color=color, name=self.name, **kwargs
)
else:
self.callback_manager.on_tool_end(
observation, verbose=verbose_, color=color, name=self.name, **kwargs
)
return observation
def __call__(self, tool_input: str) -> str:
"""Make tool callable."""
return self.run(tool_input)
"""Use the tool asynchronously."""
wrapped_input = self._wrap_input(tool_input)
return await super().arun(wrapped_input, verbose, start_color, color, **kwargs)

View File

@@ -3,30 +3,30 @@ from typing import Optional, Type
from pydantic import BaseModel, Field
from langchain.tools.base import BaseTool
from langchain.tools.file_management.utils import get_validated_relative_path
from langchain.tools.structured import BaseStructuredTool
class ReadFileInput(BaseModel):
"""Input for ReadFileTool."""
file_path: str = Field(..., description="name of file")
file_path: Path = Field(..., description="name of file")
class ReadFileTool(BaseTool):
class ReadFileTool(BaseStructuredTool[ReadFileInput, str]):
name: str = "read_file"
args_schema: Type[BaseModel] = ReadFileInput
args_schema: Type[ReadFileInput] = ReadFileInput
description: str = "Read file from disk"
root_dir: Optional[str] = None
"""Directory to read file from.
If specified, raises an error for file_paths oustide root_dir."""
def _run(self, file_path: str) -> str:
def _run(self, tool_input: ReadFileInput) -> str:
read_path = (
get_validated_relative_path(Path(self.root_dir), file_path)
get_validated_relative_path(Path(self.root_dir), tool_input.file_path)
if self.root_dir
else Path(file_path)
else tool_input.file_path
)
try:
with read_path.open("r", encoding="utf-8") as f:
@@ -35,6 +35,6 @@ class ReadFileTool(BaseTool):
except Exception as e:
return "Error: " + str(e)
async def _arun(self, tool_input: str) -> str:
async def _arun(self, tool_input: ReadFileInput) -> str:
# TODO: Add aiofiles method
raise NotImplementedError

View File

@@ -1,5 +1,6 @@
import sys
from pathlib import Path
from typing import Union
def is_relative_to(path: Path, root: Path) -> bool:
@@ -14,7 +15,7 @@ def is_relative_to(path: Path, root: Path) -> bool:
return False
def get_validated_relative_path(root: Path, user_path: str) -> Path:
def get_validated_relative_path(root: Path, user_path: Union[str, Path]) -> Path:
"""Resolve a relative path, raising an error if not within the root directory."""
# Note, this still permits symlinks from outside that point within the root.
# Further validation would be needed if those are to be disallowed.

View File

@@ -3,40 +3,40 @@ from typing import Optional, Type
from pydantic import BaseModel, Field
from langchain.tools.base import BaseTool
from langchain.tools.file_management.utils import get_validated_relative_path
from langchain.tools.structured import BaseStructuredTool
class WriteFileInput(BaseModel):
"""Input for WriteFileTool."""
file_path: str = Field(..., description="name of file")
file_path: Path = Field(..., description="name of file")
text: str = Field(..., description="text to write to file")
class WriteFileTool(BaseTool):
class WriteFileTool(BaseStructuredTool[WriteFileInput, str]):
name: str = "write_file"
args_schema: Type[BaseModel] = WriteFileInput
args_schema: Type[WriteFileInput] = WriteFileInput
description: str = "Write file to disk"
root_dir: Optional[str] = None
"""Directory to write file to.
If specified, raises an error for file_paths oustide root_dir."""
def _run(self, file_path: str, text: str) -> str:
def _run(self, tool_input: WriteFileInput) -> str:
write_path = (
get_validated_relative_path(Path(self.root_dir), file_path)
get_validated_relative_path(Path(self.root_dir), tool_input.file_path)
if self.root_dir
else Path(file_path)
else tool_input.file_path
)
try:
write_path.parent.mkdir(exist_ok=True, parents=False)
with write_path.open("w", encoding="utf-8") as f:
f.write(text)
return f"File written successfully to {file_path}."
f.write(tool_input.text)
return f"File written successfully to {tool_input.file_path}."
except Exception as e:
return "Error: " + str(e)
async def _arun(self, file_path: str, text: str) -> str:
async def _arun(self, tool_input: WriteFileInput) -> str:
# TODO: Add aiofiles method
raise NotImplementedError

View File

@@ -127,11 +127,11 @@ class ListPowerBITool(BaseTool):
arbitrary_types_allowed = True
def _run(self, *args: Any, **kwargs: Any) -> str:
def _run(self, tool_input: str = "") -> str:
"""Get the names of the tables."""
return ", ".join(self.powerbi.get_table_names())
async def _arun(self, *args: Any, **kwargs: Any) -> str:
async def _arun(self, tool_input: str = "") -> str:
"""Get the names of the tables."""
return ", ".join(self.powerbi.get_table_names())

View File

@@ -0,0 +1,344 @@
from __future__ import annotations
from abc import abstractmethod
from functools import partial
from inspect import signature
from typing import (
Any,
Awaitable,
Callable,
Dict,
Generic,
Optional,
Type,
TypeVar,
Union,
)
from pydantic import (
BaseModel,
Extra,
Field,
create_model,
validate_arguments,
validator,
)
from pydantic.generics import GenericModel
from langchain.callbacks import get_callback_manager
from langchain.callbacks.base import BaseCallbackManager
from langchain.utilities.async_utils import async_or_sync_call
class SchemaAnnotationError(TypeError):
    """Raised when 'args_schema' is missing or has an incorrect type annotation."""
    # NOTE(review): nothing in the visible part of this module raises this
    # error -- confirm it is still needed here.
# Type of the parsed input handed to ``_run`` / ``_arun``: either a plain
# ``str`` or a pydantic model.
SCHEMA_T = TypeVar("SCHEMA_T", bound=Union[str, BaseModel])
# Type of the value a tool returns from ``_run`` / ``_arun``.
OUTPUT_T = TypeVar("OUTPUT_T")
class BaseStructuredTool(
    GenericModel,
    Generic[SCHEMA_T, OUTPUT_T],
    BaseModel,
):
    """Parent class for all structured tools.

    ``run`` / ``arun`` take the tool input as a dict, parse it with
    ``_parse_input`` into a ``SCHEMA_T`` value, pass that to ``_run`` /
    ``_arun`` and report start, end and errors to ``callback_manager``.
    """

    name: str
    description: str
    return_direct: bool = False
    verbose: bool = False
    callback_manager: BaseCallbackManager = Field(default_factory=get_callback_manager)
    args_schema: Type[SCHEMA_T]  # :meta private:

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid
        arbitrary_types_allowed = True

    @property
    def args(self) -> Dict:
        """Return the JSON-schema properties describing the tool's arguments.

        For a pydantic ``args_schema`` these are the model's schema
        properties; otherwise a single ``tool_input`` string is reported.
        """
        schema = self.args_schema
        # ``args_schema`` holds a class, not an instance, so the check has to
        # be ``issubclass``; ``isinstance(..., BaseModel)`` was always False.
        if isinstance(schema, type) and issubclass(schema, BaseModel):
            return schema.schema()["properties"]
        return {"tool_input": "str"}

    def _parse_input(self, tool_input: Dict) -> SCHEMA_T:
        """Load the tool's input into a pydantic model.

        Raises:
            ValueError: If ``args_schema`` is not a pydantic model class; such
                tools have to override this method.
        """
        if not issubclass(self.args_schema, BaseModel):
            raise ValueError(
                f"Tool with args_schema of type {self.args_schema}"
                " must overwrite _parse_input."
            )
        # Ignore type because mypy doesn't connect the subclass to the generic SCHEMA_T
        return self.args_schema.parse_obj(tool_input)  # type: ignore

    def _get_verbosity(
        self,
        verbose: Optional[bool] = None,
    ) -> bool:
        """Resolve the effective verbosity for one call.

        The per-call ``verbose`` value is only used when the tool itself is
        not verbose and a value was actually given.
        """
        if not self.verbose and verbose is not None:
            return verbose
        return self.verbose

    @abstractmethod
    def _run(self, input_: SCHEMA_T) -> OUTPUT_T:
        """Use the tool."""

    @abstractmethod
    async def _arun(self, input_: SCHEMA_T) -> OUTPUT_T:
        """Use the tool asynchronously."""

    def run(
        self,
        tool_input: dict,
        verbose: Optional[bool] = None,
        start_color: Optional[str] = "green",
        color: Optional[str] = "green",
        **kwargs: Any,
    ) -> OUTPUT_T:
        """Run the tool.

        Args:
            tool_input: Raw input, parsed via ``_parse_input`` before use.
            verbose: Per-call verbosity, see ``_get_verbosity``.
            start_color: Color passed to the ``on_tool_start`` callback.
            color: Color passed to the ``on_tool_end`` callback.
            **kwargs: Forwarded to the start and end callbacks.

        Returns:
            Whatever ``_run`` returns.
        """
        parsed_input = self._parse_input(tool_input)
        verbose_ = self._get_verbosity(verbose)
        self.callback_manager.on_tool_start(
            {"name": self.name, "description": self.description},
            str(tool_input),
            verbose=verbose_,
            color=start_color,
            **kwargs,
        )
        try:
            observation = self._run(parsed_input)
        except (Exception, KeyboardInterrupt) as e:
            self.callback_manager.on_tool_error(e, verbose=verbose_)
            # Bare raise keeps the original traceback untouched.
            raise
        self.callback_manager.on_tool_end(
            str(observation), verbose=verbose_, color=color, name=self.name, **kwargs
        )
        return observation

    async def arun(
        self,
        tool_input: dict,
        verbose: Optional[bool] = None,
        start_color: Optional[str] = "green",
        color: Optional[str] = "green",
        **kwargs: Any,
    ) -> OUTPUT_T:
        """Run the tool asynchronously.

        Mirrors ``run``; callbacks are awaited when the callback manager is
        asynchronous and called directly otherwise.
        """
        parsed_input = self._parse_input(tool_input)
        verbose_ = self._get_verbosity(verbose)
        await async_or_sync_call(
            self.callback_manager.on_tool_start,
            {"name": self.name, "description": self.description},
            # Report the raw input, exactly as ``run`` does, so sync and async
            # runs produce the same callback payload.
            str(tool_input),
            verbose=verbose_,
            color=start_color,
            is_async=self.callback_manager.is_async,
            **kwargs,
        )
        try:
            # We then call the tool on the tool input to get an observation
            observation = await self._arun(parsed_input)
        except (Exception, KeyboardInterrupt) as e:
            await async_or_sync_call(
                self.callback_manager.on_tool_error,
                e,
                verbose=verbose_,
                is_async=self.callback_manager.is_async,
            )
            raise
        await async_or_sync_call(
            self.callback_manager.on_tool_end,
            str(observation),
            verbose=verbose_,
            color=color,
            is_async=self.callback_manager.is_async,
            **kwargs,
        )
        return observation

    def __call__(self, tool_input: dict) -> OUTPUT_T:
        """Make tool callable."""
        return self.run(tool_input)
def _create_subset_model(
    name: str, model: BaseModel, field_names: list
) -> Type[BaseModel]:
    """Create a pydantic model with only a subset of model's fields.

    Args:
        name: Name of the model to create.
        model: Model whose fields are copied.
        field_names: Names of the fields to keep; names that ``model`` does
            not define are skipped.

    Returns:
        A new model class holding the selected fields.
    """
    fields = {}
    for field_name in field_names:
        field = model.__fields__.get(field_name)
        if field is None:
            continue
        # ``outer_type_`` keeps container annotations such as ``List[int]``
        # (``type_`` is only the inner item type), and ``field_info`` carries
        # the original default, so required fields stay required instead of
        # silently defaulting to ``None``.
        fields[field_name] = (field.outer_type_, field.field_info)
    return create_model(name, **fields)  # type: ignore
def get_filtered_args(inferred_model: Type[BaseModel], func: Callable) -> dict:
    """Return the schema properties of ``inferred_model`` for ``func``'s parameters.

    Only entries named after a parameter in ``func``'s signature are kept, in
    signature order. A parameter without a matching schema property raises
    ``KeyError``.
    """
    properties = inferred_model.schema()["properties"]
    selected: dict = {}
    for param_name in signature(func).parameters:
        selected[param_name] = properties[param_name]
    return selected
def create_schema_from_function(model_name: str, func: Callable) -> Type[BaseModel]:
    """Build a pydantic schema class from the signature of ``func``.

    The resulting model is named ``f"{model_name}Schema"`` and contains one
    field per parameter of ``func``.
    """
    validated_model = validate_arguments(func).model  # type: ignore
    # The model pydantic infers carries extra placeholder fields; keep only
    # the names that are real parameters of ``func``.
    parameter_names = list(get_filtered_args(validated_model, func))
    schema_name = f"{model_name}Schema"
    return _create_subset_model(schema_name, validated_model, parameter_names)
class StructuredTool(BaseStructuredTool[BaseModel, Any]):
    """StructuredTool that takes in function or coroutine directly."""

    func: Callable[..., Any]
    """The function to run when the tool is called."""
    coroutine: Optional[Callable[..., Awaitable[Any]]] = None
    """The asynchronous version of the function."""
    args_schema: Type[BaseModel]  # :meta private:

    @validator("func", pre=True, always=True)
    def validate_func_not_partial(cls, func: Callable) -> Callable:
        """Check that the function is not a partial."""
        if isinstance(func, partial):
            raise ValueError("Partial functions not yet supported in structured tools.")
        return func

    @property
    def args(self) -> dict:
        """Return the schema properties describing the tool's arguments."""
        if self.args_schema is not None:
            return self.args_schema.schema()["properties"]
        # Fallback: infer the arguments from the wrapped function's signature.
        inferred_model = validate_arguments(self.func).model  # type: ignore
        return get_filtered_args(inferred_model, self.func)

    def _run(self, tool_input: BaseModel) -> Any:
        """Use the tool: call ``func`` with the parsed fields as keyword arguments."""
        return self.func(**tool_input.dict())

    async def _arun(self, tool_input: BaseModel) -> Any:
        """Use the tool asynchronously.

        Raises:
            NotImplementedError: If no ``coroutine`` was provided.
        """
        if self.coroutine:
            return await self.coroutine(**tool_input.dict())
        raise NotImplementedError(f"StructuredTool {self.name} does not support async")

    @classmethod
    def from_function(
        cls,
        func: Callable[..., Any],
        coroutine: Optional[Callable[..., Awaitable[Any]]] = None,
        return_direct: bool = False,
        args_schema: Optional[Type[BaseModel]] = None,
        infer_schema: bool = True,
        name: Optional[str] = None,
        description: Optional[str] = None,
    ) -> "StructuredTool":
        """Make a structured tool out of a function.

        Args:
            func: The function to run when the tool is called.
            coroutine: The asynchronous version of the function.
            return_direct: Whether to return directly from the tool rather
                than continuing the agent loop.
            args_schema: optional argument schema for user to specify
            infer_schema: Whether to infer the schema of the arguments from
                the function's signature when ``args_schema`` is not given.
            name: The name of the tool. Defaults to the function name.
            description: The description of the tool. Defaults to the function
                docstring.

        Returns:
            The constructed tool. Its description is prefixed with the tool
            name and the function signature.

        Raises:
            ValueError: If neither ``description`` nor a function docstring
                provides a non-blank description.
        """
        # An explicit description wins; the docstring is only the default.
        # (The operands used to be the other way round, so ``description``
        # was ignored whenever the function had a docstring.)
        description = description or func.__doc__
        if description is None or not description.strip():
            raise ValueError(
                f"Function {func.__name__} must have a docstring, or set description."
            )
        name = name or func.__name__
        _args_schema = args_schema
        if _args_schema is None and infer_schema:
            # ``create_schema_from_function`` appends the "Schema" suffix
            # itself; adding it here as well produced "<name>SchemaSchema".
            _args_schema = create_schema_from_function(name, func)
        # Strip the surrounding whitespace docstrings usually carry, as the
        # legacy ``tool`` decorator does.
        description = f"{name}{signature(func)} - {description.strip()}"
        return cls(
            name=name,
            func=func,
            coroutine=coroutine,
            return_direct=return_direct,
            args_schema=_args_schema,
            description=description,
        )
def structured_tool(
    *args: Union[str, Callable],
    return_direct: bool = False,
    args_schema: Optional[Type[BaseModel]] = None,
    infer_schema: bool = True,
) -> Callable:
    """Decorator turning a function into a ``StructuredTool``.

    Can be applied bare, called with keyword options only, or called with a
    tool name as its single positional argument.

    Args:
        *args: Either nothing, the tool name, or the function to wrap.
        return_direct: Whether to return directly from the tool rather
            than continuing the agent loop.
        args_schema: optional argument schema for user to specify
        infer_schema: Whether to infer the schema of the arguments from
            the function's signature.

    Requires:
        - Function must have a docstring

    Raises:
        ValueError: For any other combination of positional arguments.

    Examples:
        .. code-block:: python

            @structured_tool
            def search_api(query: str) -> str:
                "Searches the API for the query."
                return

            @structured_tool("search", return_direct=True)
            def search_api(query: str) -> str:
                "Searches the API for the query."
                return
    """

    def _build(tool_name: str, func: Callable) -> StructuredTool:
        # Single place where the decorator options reach the tool factory.
        return StructuredTool.from_function(
            name=tool_name,
            func=func,
            args_schema=args_schema,
            return_direct=return_direct,
            infer_schema=infer_schema,
        )

    if not args:
        # Example usage: @structured_tool(return_direct=True)
        # The tool is named after the decorated function.
        def _named_after_function(func: Callable[[str], str]) -> BaseStructuredTool:
            return _build(func.__name__, func)

        return _named_after_function

    if len(args) == 1:
        (target,) = args
        if isinstance(target, str):
            # Example usage: @structured_tool("search", return_direct=True)
            # The string is the tool name.
            def _with_given_name(func: Callable) -> StructuredTool:
                return _build(target, func)

            return _with_given_name
        if callable(target):
            # Example usage: @structured_tool
            # The decorated function itself was passed in.
            return _build(target.__name__, target)

    raise ValueError("Too many arguments for tool decorator")

View File

@@ -0,0 +1,12 @@
"""Async utilities."""
from typing import Any, Callable
async def async_or_sync_call(
    method: Callable, *args: Any, is_async: bool, **kwargs: Any
) -> Any:
    """Call ``method`` and await its result only when ``is_async`` is true.

    Args:
        method: Callable to invoke with ``*args`` and ``**kwargs``.
        is_async: Whether the value ``method`` returns has to be awaited.

    Returns:
        The awaited result when ``is_async`` is true, else the plain return
        value of ``method``.
    """
    outcome = method(*args, **kwargs)
    if is_async:
        outcome = await outcome
    return outcome

View File

@@ -1,9 +1,8 @@
"""Util that calls OpenWeatherMap using PyOWM."""
from typing import Any, Dict, Optional
from pydantic import Extra, root_validator
from pydantic import BaseModel, Extra, root_validator
from langchain.tools.base import BaseModel
from langchain.utils import get_from_dict_or_env

View File

View File

@@ -0,0 +1,148 @@
"""Test the BaseOutputParser class and its sub-classes."""
from collections import defaultdict
import json
from copy import deepcopy
from pathlib import Path
from typing import List, Tuple
from pydantic import ValidationError
import pytest
from langchain.agents import initialize_agent
from langchain.agents.agent_toolkits.json.toolkit import JsonToolkit
from langchain.agents.agent_toolkits.nla.toolkit import NLAToolkit
from langchain.agents.agent_toolkits.openapi.toolkit import RequestsToolkit
from langchain.agents.agent_types import AgentType
from langchain.llms.openai import OpenAI
from langchain.memory.buffer import ConversationBufferMemory
from langchain.requests import TextRequestsWrapper
from langchain.schema import BaseLanguageModel
from langchain.tools.base import BaseTool
from langchain.tools.json.tool import JsonSpec
def _get_requests_tools_and_questions(**kwargs) -> List[Tuple[BaseTool, List[str]]]:
    """Pair each requests tool under test with the questions to ask it.

    ``**kwargs`` is accepted so that every generator can be called the same
    way; it is not used here.
    """
    toolkit = RequestsToolkit(requests_wrapper=TextRequestsWrapper())
    tools_by_name = {}
    for tool in toolkit.get_tools():
        tools_by_name[tool.name] = tool
    questions_by_method = {
        # "get": ["Get the header of google.com"],
        "post": ["Post data {'key': 'value'} to google.com"],
        "patch": ["Patch data {'key': 'value'} to google.com"],
        "put": ["Put data {'key': 'value'} to google.com"],
        "delete": ["Delete data with ID 1234abc from google.com"],
    }
    return [
        (tools_by_name[f"requests_{method}"], questions)
        for method, questions in questions_by_method.items()
    ]
def _get_json_tools_and_questions(**kwargs) -> List[Tuple[BaseTool, List[str]]]:
    """Pair each JSON toolkit tool with the questions to ask it.

    ``**kwargs`` is accepted so that every generator can be called the same
    way; it is not used here.
    """
    spec = JsonSpec.from_file(
        Path("tests/unit_tests/tools/openapi/test_specs/apis-guru/apispec.json")
    )
    json_toolkit = JsonToolkit(spec=spec)
    list_keys, get_value = json_toolkit.get_tools()
    # Questions must be lists, as the return type declares: the consumer
    # iterates over them, and a bare string would be iterated character by
    # character.
    return [
        (list_keys, ["What keys are in the JSON spec?"]),
        (get_value, ["What's in the info.description?"]),
    ]
def _get_nla_tools_nad_questions(
    *,
    llm: BaseLanguageModel,
) -> List[Tuple[BaseTool, List[str]]]:
    """Pair every tool of the Speak and Klarna NLA toolkits with one question.

    Each toolkit is built from its OpenAPI URL using ``llm``; all tools of a
    toolkit share that toolkit's question. Speak tools come first.
    """
    # TODO: make more pointed questions
    url_and_question = [
        (
            "https://api.speak.com/openapi.yaml",
            "Could you help me learn something new in Spanish?",
        ),
        (
            "https://www.klarna.com/us/shopping/public/openai/v0/api-docs/",
            "I want to buy some cheap shoes",
        ),
    ]
    tools_and_questions = []
    for url, question in url_and_question:
        toolkit = NLAToolkit.from_llm_and_url(llm, url)
        for tool in toolkit.get_tools():
            tools_and_questions.append((tool, [question]))
    return tools_and_questions
def generate_tuples() -> (
    List[Tuple[BaseTool, List[str], BaseLanguageModel, AgentType]]
):
    """Grid test.

    Builds one ``(tool, queries, llm, agent_type)`` tuple for every
    combination of LLM, agent type and tool produced by the enabled
    generators. The tuple shape matches the four names that
    ``test_run_tool`` is parametrized with.
    """
    llms = [
        # ChatOpenAI(),
        OpenAI(),
    ]
    generators = [
        # _get_nla_tools_nad_questions,
        # _get_json_tools_and_questions,
        _get_requests_tools_and_questions,
    ]
    # These types don't really support arbitrary single tools...
    # excluded_types = (AgentType.SELF_ASK_WITH_SEARCH, AgentType.REACT_DOCSTORE)
    # agent_types = [
    #     agent_type for agent_type in AgentType if agent_type not in excluded_types
    # ]
    agent_types = [
        # AgentType.CHAT_CONVERSATIONAL_REACT_DESCRIPTION,
        AgentType.CHAT_ZERO_SHOT_REACT_DESCRIPTION,
    ]
    results = []
    for llm in llms:
        for agent_type in agent_types:
            for generator in generators:
                # Every generator accepts ``llm`` as a keyword argument.
                tools_and_queries = generator(llm=llm)
                for tool, queries in tools_and_queries:
                    results.append((tool, queries, llm, agent_type))
    return results
# Axes along which ``test_run_tool`` tallies test cases that hit a TypeError.
_AGGREGATE_AXES = ["tool", "llm", "agent_type"]
# Per-axis counters: axis name -> {axis value -> number of failing cases}.
_FAILURE_COUNT = {k: defaultdict(int) for k in _AGGREGATE_AXES}
@pytest.mark.parametrize("tool, queries, llm, agent_type", generate_tuples())
def test_run_tool(
    tool: BaseTool,
    queries: List[str],
    llm: BaseLanguageModel,
    agent_type: AgentType,
) -> None:
    """Run an agent holding a single tool on every query.

    Exceptions raised by the agent are collected rather than propagated; the
    test fails only if one of them is a ``TypeError`` or a pydantic
    ``ValidationError``. ``TypeError`` failures are also tallied in
    ``_FAILURE_COUNT``.
    """
    tool = deepcopy(tool)
    agent = initialize_agent(
        llm=llm,
        tools=[tool],
        agent=agent_type,
        memory=ConversationBufferMemory(
            memory_key="chat_history", return_messages=True
        ),
        verbose=True,
    )

    outcomes = []
    for query in queries:
        try:
            outcomes.append(agent(query))
        except Exception as exc:
            # Keep the exception as the outcome so it can be classified below.
            outcomes.append(exc)

    type_errors = [item for item in outcomes if isinstance(item, TypeError)]
    if type_errors:
        print(f"{str(llm)}: {tool.name} failed with: {type_errors}")
        # The counters are only mutated, never rebound, so no ``global`` is needed.
        _FAILURE_COUNT["tool"][tool.name] += 1
        _FAILURE_COUNT["llm"][str(llm)] += 1
        _FAILURE_COUNT["agent_type"][str(agent_type)] += 1
    assert not type_errors, type_errors

    validation_errors = [
        item for item in outcomes if isinstance(item, ValidationError)
    ]
    assert not validation_errors, validation_errors