Structured Tool Bugfixes (#3324)

- Proactively raise error if a tool subclasses BaseTool, defines its
own schema, but fails to add the type-hints
- fix the auto-inferred schema of the decorator to strip the
unneeded virtual kwargs from the schema dict

Helps avoid silent instances of #3297
This commit is contained in:
Zander Chase
2023-04-24 09:58:29 -07:00
committed by GitHub
parent f22b9d0e57
commit 49122a96e7
3 changed files with 287 additions and 16 deletions

View File

@@ -1,10 +1,19 @@
"""Base implementation for tools or skills."""
from __future__ import annotations
from abc import ABC, abstractmethod
from inspect import signature
from typing import Any, Dict, Optional, Sequence, Tuple, Type, Union
from typing import Any, Callable, Dict, Optional, Sequence, Tuple, Type, Union
from pydantic import BaseModel, Extra, Field, validate_arguments, validator
from pydantic import (
BaseModel,
Extra,
Field,
create_model,
validate_arguments,
validator,
)
from pydantic.main import ModelMetaclass
from langchain.callbacks import get_callback_manager
from langchain.callbacks.base import BaseCallbackManager
@@ -19,7 +28,77 @@ def _to_args_and_kwargs(run_input: Union[str, Dict]) -> Tuple[Sequence, dict]:
return [], run_input
class BaseTool(ABC, BaseModel):
class SchemaAnnotationError(TypeError):
"""Raised when 'args_schema' is missing or has an incorrect type annotation."""
class ToolMetaclass(ModelMetaclass):
"""Metaclass for BaseTool to ensure the provided args_schema
doesn't silently ignored."""
def __new__(
cls: Type[ToolMetaclass], name: str, bases: Tuple[Type, ...], dct: dict
) -> ToolMetaclass:
"""Create the definition of the new tool class."""
schema_type: Optional[Type[BaseModel]] = dct.get("args_schema")
if schema_type is not None:
schema_annotations = dct.get("__annotations__", {})
args_schema_type = schema_annotations.get("args_schema", None)
if args_schema_type is None or args_schema_type == BaseModel:
# Throw errors for common mis-annotations.
# TODO: Use get_args / get_origin and fully
# specify valid annotations.
typehint_mandate = """
class ChildTool(BaseTool):
...
args_schema: Type[BaseModel] = SchemaClass
..."""
raise SchemaAnnotationError(
f"Tool definition for {name} must include valid type annotations"
f" for argument 'args_schema' to behave as expected.\n"
f"Expected annotation of 'Type[BaseModel]'"
f" but got '{args_schema_type}'.\n"
f"Expected class looks like:\n"
f"{typehint_mandate}"
)
# Pass through to Pydantic's metaclass
return super().__new__(cls, name, bases, dct)
def _create_subset_model(
name: str, model: BaseModel, field_names: list
) -> Type[BaseModel]:
"""Create a pydantic model with only a subset of model's fields."""
fields = {
field_name: (
model.__fields__[field_name].type_,
model.__fields__[field_name].default,
)
for field_name in field_names
if field_name in model.__fields__
}
return create_model(name, **fields) # type: ignore
def get_filtered_args(inferred_model: Type[BaseModel], func: Callable) -> dict:
"""Get the arguments from a function's signature."""
schema = inferred_model.schema()["properties"]
valid_keys = signature(func).parameters
return {k: schema[k] for k in valid_keys}
def create_schema_from_function(model_name: str, func: Callable) -> Type[BaseModel]:
"""Create a pydantic schema from a function's signature."""
inferred_model = validate_arguments(func).model # type: ignore
# Pydantic adds placeholder virtual fields we need to strip
filtered_args = get_filtered_args(inferred_model, func)
return _create_subset_model(
f"{model_name}Schema", inferred_model, list(filtered_args)
)
class BaseTool(ABC, BaseModel, metaclass=ToolMetaclass):
"""Interface LangChain tools must implement."""
name: str
@@ -42,9 +121,7 @@ class BaseTool(ABC, BaseModel):
return self.args_schema.schema()["properties"]
else:
inferred_model = validate_arguments(self._run).model # type: ignore
schema = inferred_model.schema()["properties"]
valid_keys = signature(self._run).parameters
return {k: schema[k] for k in valid_keys}
return get_filtered_args(inferred_model, self._run)
def _parse_input(
self,