This commit is contained in:
vowelparrot
2023-04-25 15:43:46 -07:00
parent 6bc9700863
commit f99348fb12
2 changed files with 146 additions and 297 deletions

View File

@@ -1,161 +1,21 @@
"""Base implementation for tools or skills."""
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import Any, Generic, Optional, TypeVar
from abc import ABC
from typing import Type
from pydantic import (
Extra,
Field,
validator,
BaseModel,
)
from pydantic.generics import GenericModel
from langchain.callbacks import get_callback_manager
from langchain.callbacks.base import BaseCallbackManager
from langchain.utilities.async_utils import async_or_sync_call
from langchain.tools.structured import BaseStructuredTool
IN_T = TypeVar("IN_T")
OUT_T = TypeVar("OUT_T")
class ToolMixin(GenericModel, Generic[IN_T, OUT_T]):
class BaseTool(ABC, BaseStructuredTool[str, str, str]):
"""Interface LangChain tools must implement."""
name: str
description: str
return_direct: bool = False
verbose: bool = False
callback_manager: BaseCallbackManager = Field(default_factory=get_callback_manager)
args_schema: Type[str] = str # :meta private:
class Config:
"""Configuration for this pydantic object."""
extra = Extra.forbid
arbitrary_types_allowed = True
@validator("callback_manager", pre=True, always=True)
def set_callback_manager(
cls, callback_manager: Optional[BaseCallbackManager]
) -> BaseCallbackManager:
"""If callback manager is None, set it.
This allows users to pass in None as callback manager, which is a nice UX.
"""
return callback_manager or get_callback_manager()
def _get_verbosity(
self,
verbose: Optional[bool] = None,
) -> bool:
if not self.verbose and verbose is not None:
verbose_ = verbose
else:
verbose_ = self.verbose
return verbose_
@abstractmethod
def run(
self,
tool_input: IN_T,
verbose: Optional[bool] = None,
start_color: Optional[str] = "green",
color: Optional[str] = "green",
**kwargs: Any,
) -> OUT_T:
"""Use the tool."""
@abstractmethod
async def arun(
self,
tool_input: IN_T,
verbose: Optional[bool] = None,
start_color: Optional[str] = "green",
color: Optional[str] = "green",
**kwargs: Any,
) -> OUT_T:
"""Use the tool asynchronously."""
def __call__(self, tool_input: IN_T) -> OUT_T:
"""Make tool callable."""
return self.run(tool_input)
class BaseTool(ABC, ToolMixin[str, str]):
"""Interface LangChain tools must implement."""
@abstractmethod
def _run(self, tool_input: str) -> str:
"""Use the tool."""
@abstractmethod
async def _arun(self, tool_input: str) -> str:
"""Use the tool asynchronously."""
def run(
self,
tool_input: str,
verbose: Optional[bool] = None,
start_color: Optional[str] = "green",
color: Optional[str] = "green",
**kwargs: Any,
) -> str:
"""Run the tool."""
verbose_ = self._get_verbosity(verbose)
self.callback_manager.on_tool_start(
{"name": self.name, "description": self.description},
tool_input,
verbose=verbose_,
color=start_color,
**kwargs,
)
try:
observation = self._run(tool_input)
except (Exception, KeyboardInterrupt) as e:
self.callback_manager.on_tool_error(e, verbose=verbose_)
raise e
self.callback_manager.on_tool_end(
observation, verbose=verbose_, color=color, name=self.name, **kwargs
)
return observation
async def arun(
self,
tool_input: str,
verbose: Optional[bool] = None,
start_color: Optional[str] = "green",
color: Optional[str] = "green",
**kwargs: Any,
) -> str:
"""Run the tool asynchronously."""
verbose_ = self._get_verbosity(verbose)
await async_or_sync_call(
self.callback_manager.on_tool_start,
{"name": self.name, "description": self.description},
tool_input,
verbose=verbose_,
color=start_color,
is_async=self.callback_manager.is_async,
**kwargs,
)
try:
# We then call the tool on the tool input to get an observation
observation = await self._arun(tool_input)
except (Exception, KeyboardInterrupt) as e:
await async_or_sync_call(
self.callback_manager.on_tool_error,
e,
verbose=verbose_,
is_async=self.callback_manager.is_async,
)
raise e
await async_or_sync_call(
self.callback_manager.on_tool_end,
observation,
verbose=verbose_,
color=color,
is_async=self.callback_manager.is_async,
**kwargs,
)
return observation
def _parse_input(self, tool_input: str) -> str:
"""Load the tool's input into a pydantic model."""
return tool_input

View File

@@ -3,16 +3,13 @@ from __future__ import annotations
from abc import abstractmethod
from functools import partial
from inspect import signature
import inspect
from typing import (
Any,
Awaitable,
Callable,
Dict,
Generic,
List,
Optional,
Sequence,
Tuple,
Type,
TypeVar,
@@ -21,14 +18,18 @@ from typing import (
from pydantic import (
BaseModel,
Extra,
Field,
create_model,
validate_arguments,
validator,
)
from pydantic.generics import GenericModel
from pydantic.main import ModelMetaclass
from langchain.callbacks import get_callback_manager
from langchain.callbacks.base import BaseCallbackManager
from langchain.tools.base import BaseTool, ToolMixin
from langchain.tools.base import BaseTool
from langchain.utilities.async_utils import async_or_sync_call
@@ -36,38 +37,137 @@ class SchemaAnnotationError(TypeError):
"""Raised when 'args_schema' is missing or has an incorrect type annotation."""
class ToolMetaclass(ModelMetaclass):
"""Metaclass for BaseTool to ensure the provided args_schema
INPUT_T = TypeVar("INPUT_T")
SCHEMA_T = TypeVar("SCHEMA_T", bound=Union[str, BaseModel])
OUTPUT_T = TypeVar("OUTPUT_T")
doesn't silently ignored."""
def __new__(
cls: Type[ToolMetaclass], name: str, bases: Tuple[Type, ...], dct: dict
) -> ToolMetaclass:
"""Create the definition of the new tool class."""
schema_type: Optional[Type[BaseModel]] = dct.get("args_schema")
if schema_type is not None:
schema_annotations = dct.get("__annotations__", {})
args_schema_type = schema_annotations.get("args_schema", None)
if args_schema_type is None or args_schema_type == BaseModel:
# Throw errors for common mis-annotations.
# TODO: Use get_args / get_origin and fully
# specify valid annotations.
typehint_mandate = """
class ChildTool(BaseTool):
...
args_schema: Type[BaseModel] = SchemaClass
..."""
raise SchemaAnnotationError(
f"StructuredTool definition for {name} must include valid type annotations"
f" for argument 'args_schema' to behave as expected.\n"
f"Expected annotation of 'Type[BaseModel]'"
f" but got '{args_schema_type}'.\n"
f"Expected class looks like:\n"
f"{typehint_mandate}"
)
# Pass through to Pydantic's metaclass
return super().__new__(cls, name, bases, dct)
class BaseStructuredTool(
GenericModel,
Generic[INPUT_T, SCHEMA_T, OUTPUT_T],
BaseModel,
):
"""Parent class for all structured tools."""
name: str
description: str
return_direct: bool = False
verbose: bool = False
callback_manager: BaseCallbackManager = Field(default_factory=get_callback_manager)
args_schema: Type[SCHEMA_T] # :meta private:
class Config:
"""Configuration for this pydantic object."""
extra = Extra.forbid
arbitrary_types_allowed = True
@property
def args(self) -> dict:
if isinstance(self.args_schema, BaseModel):
return self.args_schema.schema()["properties"]
else:
return {"tool_input": "str"}
def _parse_input(self, tool_input: Union[INPUT_T, Any]) -> SCHEMA_T:
"""Load the tool's input into a pydantic model."""
if not issubclass(self.args_schema, BaseModel):
raise ValueError(
f"Tool with args_schema of type {self.args_schema} must overwrite _parse_input."
)
# Ignore type because mypy doesn't connect the subclass to the generic SCHEMA_T
return self.args_schema.parse_obj(tool_input) # type: ignore
def _get_verbosity(
self,
verbose: Optional[bool] = None,
) -> bool:
if not self.verbose and verbose is not None:
verbose_ = verbose
else:
verbose_ = self.verbose
return verbose_
@abstractmethod
def _run(self, input_: SCHEMA_T) -> OUTPUT_T:
"""Use the tool."""
@abstractmethod
async def _arun(self, input_: SCHEMA_T) -> OUTPUT_T:
"""Use the tool asynchronously."""
def run(
self,
tool_input: INPUT_T,
verbose: Optional[bool] = None,
start_color: Optional[str] = "green",
color: Optional[str] = "green",
**kwargs: Any,
) -> OUTPUT_T:
"""Run the tool."""
parsed_input = self._parse_input(tool_input)
verbose_ = self._get_verbosity(verbose)
self.callback_manager.on_tool_start(
{"name": self.name, "description": self.description},
str(tool_input),
verbose=verbose_,
color=start_color,
**kwargs,
)
try:
observation = self._run(parsed_input)
except (Exception, KeyboardInterrupt) as e:
self.callback_manager.on_tool_error(e, verbose=verbose_)
raise e
self.callback_manager.on_tool_end(
str(observation), verbose=verbose_, color=color, name=self.name, **kwargs
)
return observation
async def arun(
self,
tool_input: INPUT_T,
verbose: Optional[bool] = None,
start_color: Optional[str] = "green",
color: Optional[str] = "green",
**kwargs: Any,
) -> OUTPUT_T:
"""Run the tool asynchronously."""
parsed_input = self._parse_input(tool_input)
verbose_ = self._get_verbosity(verbose)
await async_or_sync_call(
self.callback_manager.on_tool_start,
{"name": self.name, "description": self.description},
str(parsed_input),
verbose=verbose_,
color=start_color,
is_async=self.callback_manager.is_async,
**kwargs,
)
try:
# We then call the tool on the tool input to get an observation
observation = await self._arun(parsed_input)
except (Exception, KeyboardInterrupt) as e:
await async_or_sync_call(
self.callback_manager.on_tool_error,
e,
verbose=verbose_,
is_async=self.callback_manager.is_async,
)
raise e
await async_or_sync_call(
self.callback_manager.on_tool_end,
str(observation),
verbose=verbose_,
color=color,
is_async=self.callback_manager.is_async,
**kwargs,
)
return observation
def __call__(self, tool_input: INPUT_T) -> OUTPUT_T:
"""Make tool callable."""
return self.run(tool_input)
def _create_subset_model(
@@ -102,122 +202,14 @@ def create_schema_from_function(model_name: str, func: Callable) -> Type[BaseMod
)
class BaseStructuredTool(ToolMixin[Dict, Dict], metaclass=ToolMetaclass):
"""Parent class for all structured tools."""
args_schema: Type[BaseModel] # :meta private:
@property
def args(self) -> dict:
if self.args_schema is not None:
return self.args_schema.schema()["properties"]
else:
inferred_model = validate_arguments(self._run).model # type: ignore
return get_filtered_args(inferred_model, self._run)
def _load_parsed_input(self, tool_input: Union[Dict, Any]) -> BaseModel:
"""Load the tool's input into a pydantic model."""
if not isinstance(tool_input, dict):
# Despite being typed as a Dict, there are cases when the LLM
# will not actually output dict args (e.g., for single arg inputs).
single_field = next(iter(self.args_schema.__fields__))
tool_input = {single_field: tool_input}
return self.args_schema.parse_obj(tool_input)
@abstractmethod
def _run(self, input_: BaseModel) -> Any:
"""Use the tool."""
@abstractmethod
async def _arun(self, input_: BaseModel) -> Any:
"""Use the tool asynchronously."""
def run(
self,
tool_input: Dict,
verbose: Optional[bool] = None,
start_color: Optional[str] = "green",
color: Optional[str] = "green",
**kwargs: Any,
) -> Any:
"""Run the tool."""
parsed_input = self.args_schema.parse_obj(tool_input)
verbose_ = self._get_verbosity(verbose)
self.callback_manager.on_tool_start(
{"name": self.name, "description": self.description},
str(tool_input),
verbose=verbose_,
color=start_color,
**kwargs,
)
try:
observation = self._run(parsed_input)
except (Exception, KeyboardInterrupt) as e:
self.callback_manager.on_tool_error(e, verbose=verbose_)
raise e
self.callback_manager.on_tool_end(
str(observation), verbose=verbose_, color=color, name=self.name, **kwargs
)
if isinstance(observation, BaseModel):
observation = observation.dict()
return str(observation)
async def arun(
self,
tool_input: Union[Dict, Any],
verbose: Optional[bool] = None,
start_color: Optional[str] = "green",
color: Optional[str] = "green",
**kwargs: Any,
) -> Any:
"""Run the tool asynchronously."""
parsed_input = self.args_schema.parse_obj(tool_input)
verbose_ = self._get_verbosity(verbose)
await async_or_sync_call(
self.callback_manager.on_tool_start,
{"name": self.name, "description": self.description},
str(parsed_input.dict()),
verbose=verbose_,
color=start_color,
is_async=self.callback_manager.is_async,
**kwargs,
)
try:
# We then call the tool on the tool input to get an observation
observation = await self._run(parsed_input)
except (Exception, KeyboardInterrupt) as e:
await async_or_sync_call(
self.callback_manager.on_tool_error,
e,
verbose=verbose_,
is_async=self.callback_manager.is_async,
)
raise e
if isinstance(observation, BaseModel):
observation = observation.dict()
await async_or_sync_call(
self.callback_manager.on_tool_end,
str(observation),
verbose=verbose_,
color=color,
is_async=self.callback_manager.is_async,
**kwargs,
)
return observation
def __call__(self, tool_input: Dict) -> Any:
"""Make tool callable."""
return self.run(tool_input)
class StructuredTool(BaseStructuredTool):
class StructuredTool(BaseStructuredTool[Dict, BaseModel, Any]):
"""StructuredTool that takes in function or coroutine directly."""
description: str = ""
func: Callable[..., Any]
"""The function to run when the tool is called."""
coroutine: Optional[Callable[..., Awaitable[Any]]] = None
"""The asynchronous version of the function."""
args_schema: Type[BaseModel] # :meta private:
@validator("func", pre=True, always=True)
def validate_func_not_partial(cls, func: Callable) -> Callable:
@@ -242,7 +234,7 @@ class StructuredTool(BaseStructuredTool):
"""Use the tool asynchronously."""
if self.coroutine:
return await self.coroutine(**tool_input.dict())
raise NotImplementedError("StructuredTool does not support async")
raise NotImplementedError(f"StructuredTool {self.name} does not support async")
@classmethod
def from_function(
@@ -354,6 +346,3 @@ def structured_tool(
return _partial
else:
raise ValueError("Too many arguments for tool decorator")
TOOL_T = TypeVar("TOOL_T", bound=ToolMixin, default=BaseTool)