Compare commits

...

2 Commits

Author SHA1 Message Date
vowelparrot
02fc6f6309 Other option 2023-04-17 22:59:26 -07:00
vowelparrot
766aeed53e Update to not generate a model on each call 2023-04-17 22:53:15 -07:00
3 changed files with 127 additions and 112 deletions

View File

@@ -1,10 +1,23 @@
"""Interface for tools."""
import inspect
from inspect import signature
from typing import Any, Awaitable, Callable, Optional, Type, Union
from typing import (
Any,
Awaitable,
Callable,
Dict,
Optional,
Sequence,
Tuple,
Type,
Union,
)
from pydantic import BaseModel
from pydantic import BaseModel, Field, create_model
from langchain.tools.base import BaseTool, create_args_schema_model_from_signature
from langchain.tools.base import (
BaseTool,
)
class Tool(BaseTool):
@@ -16,23 +29,22 @@ class Tool(BaseTool):
coroutine: Optional[Callable[..., Awaitable[str]]] = None
"""The asynchronous version of the function."""
@property
def args(self) -> Type[BaseModel]:
"""Generate an input pydantic model."""
if self.args_schema is not None:
return self.args_schema
# Infer the schema directly from the function to add more structured
# arguments.
return create_args_schema_model_from_signature(self.func)
def _run(self, *args: Any, **kwargs: Any) -> str:
def _run(self, tool_input: Union[str, BaseModel]) -> str:
"""Use the tool."""
return self.func(*args, **kwargs)
if isinstance(tool_input, str):
return self.func(tool_input)
else:
args, kwargs = _to_args_and_kwargs(tool_input)
return self.func(*args, **kwargs)
async def _arun(self, *args: Any, **kwargs: Any) -> str:
async def _arun(self, tool_input: Union[str, BaseModel]) -> str:
"""Use the tool asynchronously."""
if self.coroutine:
return await self.coroutine(*args, **kwargs)
if isinstance(tool_input, str):
return await self.coroutine(tool_input)
else:
args, kwargs = _to_args_and_kwargs(tool_input)
return await self.coroutine(*args, **kwargs)
raise NotImplementedError("Tool does not support async")
# TODO: this is for backwards compatibility, remove in future
@@ -51,13 +63,89 @@ class InvalidTool(BaseTool):
name = "invalid_tool"
description = "Called when tool name is invalid."
def _run(self, tool_name: str) -> str:
def _run(self, tool_name: Union[str, BaseModel]) -> str:
"""Use the tool."""
return f"{tool_name} is not a valid tool, try another one."
return f"{str(tool_name)} is not a valid tool, try another one."
async def _arun(self, tool_name: str) -> str:
async def _arun(self, tool_name: Union[str, BaseModel]) -> str:
"""Use the tool asynchronously."""
return f"{tool_name} is not a valid tool, try another one."
return f"{str(tool_name)} is not a valid tool, try another one."
def _to_args_and_kwargs(model: BaseModel) -> Tuple[Sequence, dict]:
"""Convert pydantic model to args and kwargs."""
args = []
kwargs = {}
for name, field in model.__fields__.items():
value = getattr(model, name)
# Handle *args in the function signature
if field.field_info.extra.get("extra", {}).get("is_var_positional"):
if isinstance(value, str):
# Base case for backwards compatability
args.append(value)
elif value is not None:
args.extend(value)
# Handle **kwargs in the function signature
elif field.field_info.extra.get("extra", {}).get("is_var_keyword"):
if value is not None:
kwargs.update(value)
elif field.field_info.extra.get("extra", {}).get("is_keyword_only"):
kwargs[name] = value
else:
args.append(value)
return tuple(args), kwargs
def _create_args_schema_model_from_signature(run_func: Callable) -> Type[BaseModel]:
"""Create a pydantic model type from a function's signature."""
signature_ = inspect.signature(run_func)
field_definitions: Dict[str, Any] = {}
for name, param in signature_.parameters.items():
if name == "self":
continue
default_value = (
param.default if param.default != inspect.Parameter.empty else None
)
annotation = (
param.annotation if param.annotation != inspect.Parameter.empty else Any
)
# Handle functions with *args in the signature
if param.kind == inspect.Parameter.VAR_POSITIONAL:
field_definitions[name] = (
Any,
Field(default=None, extra={"is_var_positional": True}),
)
# handle functions with **kwargs in the signature
elif param.kind == inspect.Parameter.VAR_KEYWORD:
field_definitions[name] = (
Any,
Field(default=None, extra={"is_var_keyword": True}),
)
# Handle all other named parameters
else:
is_keyword_only = param.kind == inspect.Parameter.KEYWORD_ONLY
field_definitions[name] = (
annotation,
Field(
default=default_value, extra={"is_keyword_only": is_keyword_only}
),
)
return create_model("ArgsModel", **field_definitions) # type: ignore
def _create_schema_if_multiarg(
func: Callable,
) -> Optional[Type[BaseModel]]:
signature_ = inspect.signature(func)
parameters = signature_.parameters
if len(parameters) == 1 and next(iter(parameters.values())).annotation == str:
# Default tools take a single string as input and don't need a dynamic
# schema validation
return None
else:
return _create_args_schema_model_from_signature(func)
def tool(*args: Union[str, Callable], return_direct: bool = False) -> Callable:
@@ -87,7 +175,7 @@ def tool(*args: Union[str, Callable], return_direct: bool = False) -> Callable:
# Description example:
# search_api(query: str) - Searches the API for the query.
description = f"{tool_name}{signature(func)} - {func.__doc__.strip()}"
args_schema = create_args_schema_model_from_signature(func)
args_schema = _create_schema_if_multiarg(func)
tool_ = Tool(
name=tool_name,
func=func,

View File

@@ -1,77 +1,14 @@
"""Base implementation for tools or skills."""
import inspect
from abc import ABC, abstractmethod
from typing import Any, Callable, Dict, Optional, Sequence, Tuple, Type, Union
from typing import Any, Dict, Optional, Type, Union
from pydantic import BaseModel, Extra, Field, create_model, validator
from pydantic import BaseModel, Extra, Field, validator
from langchain.callbacks import get_callback_manager
from langchain.callbacks.base import BaseCallbackManager
def create_args_schema_model_from_signature(run_func: Callable) -> Type[BaseModel]:
"""Create a pydantic model type from a function's signature."""
signature_ = inspect.signature(run_func)
field_definitions: Dict[str, Any] = {}
for name, param in signature_.parameters.items():
if name == "self":
continue
default_value = (
param.default if param.default != inspect.Parameter.empty else None
)
annotation = (
param.annotation if param.annotation != inspect.Parameter.empty else Any
)
# Handle functions with *args in the signature
if param.kind == inspect.Parameter.VAR_POSITIONAL:
field_definitions[name] = (
Any,
Field(default=None, extra={"is_var_positional": True}),
)
# handle functions with **kwargs in the signature
elif param.kind == inspect.Parameter.VAR_KEYWORD:
field_definitions[name] = (
Any,
Field(default=None, extra={"is_var_keyword": True}),
)
# Handle all other named parameters
else:
is_keyword_only = param.kind == inspect.Parameter.KEYWORD_ONLY
field_definitions[name] = (
annotation,
Field(
default=default_value, extra={"is_keyword_only": is_keyword_only}
),
)
return create_model("ArgsModel", **field_definitions) # type: ignore
def _to_args_and_kwargs(model: BaseModel) -> Tuple[Sequence, dict]:
args = []
kwargs = {}
for name, field in model.__fields__.items():
value = getattr(model, name)
# Handle *args in the function signature
if field.field_info.extra.get("extra", {}).get("is_var_positional"):
if isinstance(value, str):
# Base case for backwards compatability
args.append(value)
elif value is not None:
args.extend(value)
# Handle **kwargs in the function signature
elif field.field_info.extra.get("extra", {}).get("is_var_keyword"):
if value is not None:
kwargs.update(value)
elif field.field_info.extra.get("extra", {}).get("is_keyword_only"):
kwargs[name] = value
else:
args.append(value)
return tuple(args), kwargs
class BaseTool(ABC, BaseModel):
"""Interface LangChain tools must implement."""
@@ -89,27 +26,15 @@ class BaseTool(ABC, BaseModel):
extra = Extra.forbid
arbitrary_types_allowed = True
@property
def args(self) -> Type[BaseModel]:
"""Generate an input pydantic model."""
if self.args_schema is not None:
return self.args_schema
return create_args_schema_model_from_signature(self._run)
def _parse_input(
self,
tool_input: Union[str, Dict],
) -> BaseModel:
) -> Union[str, BaseModel]:
"""Convert tool input to pydantic model."""
pydantic_input_type = self.args
if isinstance(tool_input, str):
# For backwards compatibility, a tool that only takes
# a single string input will be converted to a dict.
# to be validated.
field_name = next(iter(pydantic_input_type.__fields__))
tool_input = {field_name: tool_input}
if pydantic_input_type is not None:
return pydantic_input_type.parse_obj(tool_input)
return tool_input
if self.args_schema is not None:
return self.args_schema.parse_obj(tool_input)
else:
raise ValueError(
f"args_schema required for tool {self.name} in order to"
@@ -127,11 +52,11 @@ class BaseTool(ABC, BaseModel):
return callback_manager or get_callback_manager()
@abstractmethod
def _run(self, *args: Any, **kwargs: Any) -> str:
def _run(self, tool_input: Union[str, BaseModel]) -> str:
"""Use the tool."""
@abstractmethod
async def _arun(self, *args: Any, **kwargs: Any) -> str:
async def _arun(self, tool_input: Union[str, BaseModel]) -> str:
"""Use the tool asynchronously."""
def run(
@@ -156,8 +81,7 @@ class BaseTool(ABC, BaseModel):
**kwargs,
)
try:
args, kwargs = _to_args_and_kwargs(run_input)
observation = self._run(*args, **kwargs)
observation = self._run(run_input)
except (Exception, KeyboardInterrupt) as e:
self.callback_manager.on_tool_error(e, verbose=verbose_)
raise e
@@ -183,7 +107,7 @@ class BaseTool(ABC, BaseModel):
if self.callback_manager.is_async:
await self.callback_manager.on_tool_start(
{"name": self.name, "description": self.description},
str(run_input.dict()),
str(run_input),
verbose=verbose_,
color=start_color,
**kwargs,
@@ -191,15 +115,14 @@ class BaseTool(ABC, BaseModel):
else:
self.callback_manager.on_tool_start(
{"name": self.name, "description": self.description},
str(run_input.dict()),
str(run_input),
verbose=verbose_,
color=start_color,
**kwargs,
)
try:
# We then call the tool on the tool input to get an observation
args, kwargs = _to_args_and_kwargs(run_input)
observation = await self._arun(*args, **kwargs)
observation = await self._arun(run_input)
except (Exception, KeyboardInterrupt) as e:
if self.callback_manager.is_async:
await self.callback_manager.on_tool_error(e, verbose=verbose_)
@@ -219,3 +142,7 @@ class BaseTool(ABC, BaseModel):
def __call__(self, tool_input: str) -> str:
"""Make tool callable."""
return self.run(tool_input)
def foo(tool_input: str) -> str:
return tool_input

View File

@@ -34,8 +34,8 @@ class _MockStructuredTool(BaseTool):
args_schema: Type[BaseModel] = _MockSchema
description = "A Structured Tool"
def _run(self, arg1: int, arg2: bool, arg3: Optional[dict] = None) -> str:
return f"{arg1} {arg2} {arg3}"
def _run(self, schema: BaseModel) -> str:
return f"{schema.arg1} {schema.arg2} {schema.arg3}"
async def _arun(self, arg1: int, arg2: bool, arg3: Optional[dict] = None) -> str:
raise NotImplementedError