Compare commits

...

2 Commits

Author SHA1 Message Date
vowelparrot
02fc6f6309 Other option 2023-04-17 22:59:26 -07:00
vowelparrot
766aeed53e Update to not generate a model on each call 2023-04-17 22:53:15 -07:00
3 changed files with 127 additions and 112 deletions

View File

@@ -1,10 +1,23 @@
"""Interface for tools.""" """Interface for tools."""
import inspect
from inspect import signature from inspect import signature
from typing import Any, Awaitable, Callable, Optional, Type, Union from typing import (
Any,
Awaitable,
Callable,
Dict,
Optional,
Sequence,
Tuple,
Type,
Union,
)
from pydantic import BaseModel from pydantic import BaseModel, Field, create_model
from langchain.tools.base import BaseTool, create_args_schema_model_from_signature from langchain.tools.base import (
BaseTool,
)
class Tool(BaseTool): class Tool(BaseTool):
@@ -16,23 +29,22 @@ class Tool(BaseTool):
coroutine: Optional[Callable[..., Awaitable[str]]] = None coroutine: Optional[Callable[..., Awaitable[str]]] = None
"""The asynchronous version of the function.""" """The asynchronous version of the function."""
@property def _run(self, tool_input: Union[str, BaseModel]) -> str:
def args(self) -> Type[BaseModel]:
"""Generate an input pydantic model."""
if self.args_schema is not None:
return self.args_schema
# Infer the schema directly from the function to add more structured
# arguments.
return create_args_schema_model_from_signature(self.func)
def _run(self, *args: Any, **kwargs: Any) -> str:
"""Use the tool.""" """Use the tool."""
return self.func(*args, **kwargs) if isinstance(tool_input, str):
return self.func(tool_input)
else:
args, kwargs = _to_args_and_kwargs(tool_input)
return self.func(*args, **kwargs)
async def _arun(self, *args: Any, **kwargs: Any) -> str: async def _arun(self, tool_input: Union[str, BaseModel]) -> str:
"""Use the tool asynchronously.""" """Use the tool asynchronously."""
if self.coroutine: if self.coroutine:
return await self.coroutine(*args, **kwargs) if isinstance(tool_input, str):
return await self.coroutine(tool_input)
else:
args, kwargs = _to_args_and_kwargs(tool_input)
return await self.coroutine(*args, **kwargs)
raise NotImplementedError("Tool does not support async") raise NotImplementedError("Tool does not support async")
# TODO: this is for backwards compatibility, remove in future # TODO: this is for backwards compatibility, remove in future
@@ -51,13 +63,89 @@ class InvalidTool(BaseTool):
name = "invalid_tool" name = "invalid_tool"
description = "Called when tool name is invalid." description = "Called when tool name is invalid."
def _run(self, tool_name: str) -> str: def _run(self, tool_name: Union[str, BaseModel]) -> str:
"""Use the tool.""" """Use the tool."""
return f"{tool_name} is not a valid tool, try another one." return f"{str(tool_name)} is not a valid tool, try another one."
async def _arun(self, tool_name: str) -> str: async def _arun(self, tool_name: Union[str, BaseModel]) -> str:
"""Use the tool asynchronously.""" """Use the tool asynchronously."""
return f"{tool_name} is not a valid tool, try another one." return f"{str(tool_name)} is not a valid tool, try another one."
def _to_args_and_kwargs(model: BaseModel) -> Tuple[Sequence, dict]:
"""Convert pydantic model to args and kwargs."""
args = []
kwargs = {}
for name, field in model.__fields__.items():
value = getattr(model, name)
# Handle *args in the function signature
if field.field_info.extra.get("extra", {}).get("is_var_positional"):
if isinstance(value, str):
# Base case for backwards compatability
args.append(value)
elif value is not None:
args.extend(value)
# Handle **kwargs in the function signature
elif field.field_info.extra.get("extra", {}).get("is_var_keyword"):
if value is not None:
kwargs.update(value)
elif field.field_info.extra.get("extra", {}).get("is_keyword_only"):
kwargs[name] = value
else:
args.append(value)
return tuple(args), kwargs
def _create_args_schema_model_from_signature(run_func: Callable) -> Type[BaseModel]:
"""Create a pydantic model type from a function's signature."""
signature_ = inspect.signature(run_func)
field_definitions: Dict[str, Any] = {}
for name, param in signature_.parameters.items():
if name == "self":
continue
default_value = (
param.default if param.default != inspect.Parameter.empty else None
)
annotation = (
param.annotation if param.annotation != inspect.Parameter.empty else Any
)
# Handle functions with *args in the signature
if param.kind == inspect.Parameter.VAR_POSITIONAL:
field_definitions[name] = (
Any,
Field(default=None, extra={"is_var_positional": True}),
)
# handle functions with **kwargs in the signature
elif param.kind == inspect.Parameter.VAR_KEYWORD:
field_definitions[name] = (
Any,
Field(default=None, extra={"is_var_keyword": True}),
)
# Handle all other named parameters
else:
is_keyword_only = param.kind == inspect.Parameter.KEYWORD_ONLY
field_definitions[name] = (
annotation,
Field(
default=default_value, extra={"is_keyword_only": is_keyword_only}
),
)
return create_model("ArgsModel", **field_definitions) # type: ignore
def _create_schema_if_multiarg(
func: Callable,
) -> Optional[Type[BaseModel]]:
signature_ = inspect.signature(func)
parameters = signature_.parameters
if len(parameters) == 1 and next(iter(parameters.values())).annotation == str:
# Default tools take a single string as input and don't need a dynamic
# schema validation
return None
else:
return _create_args_schema_model_from_signature(func)
def tool(*args: Union[str, Callable], return_direct: bool = False) -> Callable: def tool(*args: Union[str, Callable], return_direct: bool = False) -> Callable:
@@ -87,7 +175,7 @@ def tool(*args: Union[str, Callable], return_direct: bool = False) -> Callable:
# Description example: # Description example:
# search_api(query: str) - Searches the API for the query. # search_api(query: str) - Searches the API for the query.
description = f"{tool_name}{signature(func)} - {func.__doc__.strip()}" description = f"{tool_name}{signature(func)} - {func.__doc__.strip()}"
args_schema = create_args_schema_model_from_signature(func) args_schema = _create_schema_if_multiarg(func)
tool_ = Tool( tool_ = Tool(
name=tool_name, name=tool_name,
func=func, func=func,

View File

@@ -1,77 +1,14 @@
"""Base implementation for tools or skills.""" """Base implementation for tools or skills."""
import inspect
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Any, Callable, Dict, Optional, Sequence, Tuple, Type, Union from typing import Any, Dict, Optional, Type, Union
from pydantic import BaseModel, Extra, Field, create_model, validator from pydantic import BaseModel, Extra, Field, validator
from langchain.callbacks import get_callback_manager from langchain.callbacks import get_callback_manager
from langchain.callbacks.base import BaseCallbackManager from langchain.callbacks.base import BaseCallbackManager
def create_args_schema_model_from_signature(run_func: Callable) -> Type[BaseModel]:
"""Create a pydantic model type from a function's signature."""
signature_ = inspect.signature(run_func)
field_definitions: Dict[str, Any] = {}
for name, param in signature_.parameters.items():
if name == "self":
continue
default_value = (
param.default if param.default != inspect.Parameter.empty else None
)
annotation = (
param.annotation if param.annotation != inspect.Parameter.empty else Any
)
# Handle functions with *args in the signature
if param.kind == inspect.Parameter.VAR_POSITIONAL:
field_definitions[name] = (
Any,
Field(default=None, extra={"is_var_positional": True}),
)
# handle functions with **kwargs in the signature
elif param.kind == inspect.Parameter.VAR_KEYWORD:
field_definitions[name] = (
Any,
Field(default=None, extra={"is_var_keyword": True}),
)
# Handle all other named parameters
else:
is_keyword_only = param.kind == inspect.Parameter.KEYWORD_ONLY
field_definitions[name] = (
annotation,
Field(
default=default_value, extra={"is_keyword_only": is_keyword_only}
),
)
return create_model("ArgsModel", **field_definitions) # type: ignore
def _to_args_and_kwargs(model: BaseModel) -> Tuple[Sequence, dict]:
args = []
kwargs = {}
for name, field in model.__fields__.items():
value = getattr(model, name)
# Handle *args in the function signature
if field.field_info.extra.get("extra", {}).get("is_var_positional"):
if isinstance(value, str):
# Base case for backwards compatability
args.append(value)
elif value is not None:
args.extend(value)
# Handle **kwargs in the function signature
elif field.field_info.extra.get("extra", {}).get("is_var_keyword"):
if value is not None:
kwargs.update(value)
elif field.field_info.extra.get("extra", {}).get("is_keyword_only"):
kwargs[name] = value
else:
args.append(value)
return tuple(args), kwargs
class BaseTool(ABC, BaseModel): class BaseTool(ABC, BaseModel):
"""Interface LangChain tools must implement.""" """Interface LangChain tools must implement."""
@@ -89,27 +26,15 @@ class BaseTool(ABC, BaseModel):
extra = Extra.forbid extra = Extra.forbid
arbitrary_types_allowed = True arbitrary_types_allowed = True
@property
def args(self) -> Type[BaseModel]:
"""Generate an input pydantic model."""
if self.args_schema is not None:
return self.args_schema
return create_args_schema_model_from_signature(self._run)
def _parse_input( def _parse_input(
self, self,
tool_input: Union[str, Dict], tool_input: Union[str, Dict],
) -> BaseModel: ) -> Union[str, BaseModel]:
"""Convert tool input to pydantic model.""" """Convert tool input to pydantic model."""
pydantic_input_type = self.args
if isinstance(tool_input, str): if isinstance(tool_input, str):
# For backwards compatibility, a tool that only takes return tool_input
# a single string input will be converted to a dict. if self.args_schema is not None:
# to be validated. return self.args_schema.parse_obj(tool_input)
field_name = next(iter(pydantic_input_type.__fields__))
tool_input = {field_name: tool_input}
if pydantic_input_type is not None:
return pydantic_input_type.parse_obj(tool_input)
else: else:
raise ValueError( raise ValueError(
f"args_schema required for tool {self.name} in order to" f"args_schema required for tool {self.name} in order to"
@@ -127,11 +52,11 @@ class BaseTool(ABC, BaseModel):
return callback_manager or get_callback_manager() return callback_manager or get_callback_manager()
@abstractmethod @abstractmethod
def _run(self, *args: Any, **kwargs: Any) -> str: def _run(self, tool_input: Union[str, BaseModel]) -> str:
"""Use the tool.""" """Use the tool."""
@abstractmethod @abstractmethod
async def _arun(self, *args: Any, **kwargs: Any) -> str: async def _arun(self, tool_input: Union[str, BaseModel]) -> str:
"""Use the tool asynchronously.""" """Use the tool asynchronously."""
def run( def run(
@@ -156,8 +81,7 @@ class BaseTool(ABC, BaseModel):
**kwargs, **kwargs,
) )
try: try:
args, kwargs = _to_args_and_kwargs(run_input) observation = self._run(run_input)
observation = self._run(*args, **kwargs)
except (Exception, KeyboardInterrupt) as e: except (Exception, KeyboardInterrupt) as e:
self.callback_manager.on_tool_error(e, verbose=verbose_) self.callback_manager.on_tool_error(e, verbose=verbose_)
raise e raise e
@@ -183,7 +107,7 @@ class BaseTool(ABC, BaseModel):
if self.callback_manager.is_async: if self.callback_manager.is_async:
await self.callback_manager.on_tool_start( await self.callback_manager.on_tool_start(
{"name": self.name, "description": self.description}, {"name": self.name, "description": self.description},
str(run_input.dict()), str(run_input),
verbose=verbose_, verbose=verbose_,
color=start_color, color=start_color,
**kwargs, **kwargs,
@@ -191,15 +115,14 @@ class BaseTool(ABC, BaseModel):
else: else:
self.callback_manager.on_tool_start( self.callback_manager.on_tool_start(
{"name": self.name, "description": self.description}, {"name": self.name, "description": self.description},
str(run_input.dict()), str(run_input),
verbose=verbose_, verbose=verbose_,
color=start_color, color=start_color,
**kwargs, **kwargs,
) )
try: try:
# We then call the tool on the tool input to get an observation # We then call the tool on the tool input to get an observation
args, kwargs = _to_args_and_kwargs(run_input) observation = await self._arun(run_input)
observation = await self._arun(*args, **kwargs)
except (Exception, KeyboardInterrupt) as e: except (Exception, KeyboardInterrupt) as e:
if self.callback_manager.is_async: if self.callback_manager.is_async:
await self.callback_manager.on_tool_error(e, verbose=verbose_) await self.callback_manager.on_tool_error(e, verbose=verbose_)
@@ -219,3 +142,7 @@ class BaseTool(ABC, BaseModel):
def __call__(self, tool_input: str) -> str: def __call__(self, tool_input: str) -> str:
"""Make tool callable.""" """Make tool callable."""
return self.run(tool_input) return self.run(tool_input)
def foo(tool_input: str) -> str:
return tool_input

View File

@@ -34,8 +34,8 @@ class _MockStructuredTool(BaseTool):
args_schema: Type[BaseModel] = _MockSchema args_schema: Type[BaseModel] = _MockSchema
description = "A Structured Tool" description = "A Structured Tool"
def _run(self, arg1: int, arg2: bool, arg3: Optional[dict] = None) -> str: def _run(self, schema: BaseModel) -> str:
return f"{arg1} {arg2} {arg3}" return f"{schema.arg1} {schema.arg2} {schema.arg3}"
async def _arun(self, arg1: int, arg2: bool, arg3: Optional[dict] = None) -> str: async def _arun(self, arg1: int, arg2: bool, arg3: Optional[dict] = None) -> str:
raise NotImplementedError raise NotImplementedError