mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-19 00:58:32 +00:00
Structured Tool Bugfixes (#3324)
- Proactively raise error if a tool subclasses BaseTool, defines its own schema, but fails to add the type-hints - fix the auto-inferred schema of the decorator to strip the unneeded virtual kwargs from the schema dict Helps avoid silent instances of #3297
This commit is contained in:
@@ -1,10 +1,19 @@
|
||||
"""Base implementation for tools or skills."""
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from inspect import signature
|
||||
from typing import Any, Dict, Optional, Sequence, Tuple, Type, Union
|
||||
from typing import Any, Callable, Dict, Optional, Sequence, Tuple, Type, Union
|
||||
|
||||
from pydantic import BaseModel, Extra, Field, validate_arguments, validator
|
||||
from pydantic import (
|
||||
BaseModel,
|
||||
Extra,
|
||||
Field,
|
||||
create_model,
|
||||
validate_arguments,
|
||||
validator,
|
||||
)
|
||||
from pydantic.main import ModelMetaclass
|
||||
|
||||
from langchain.callbacks import get_callback_manager
|
||||
from langchain.callbacks.base import BaseCallbackManager
|
||||
@@ -19,7 +28,77 @@ def _to_args_and_kwargs(run_input: Union[str, Dict]) -> Tuple[Sequence, dict]:
|
||||
return [], run_input
|
||||
|
||||
|
||||
class BaseTool(ABC, BaseModel):
|
||||
class SchemaAnnotationError(TypeError):
|
||||
"""Raised when 'args_schema' is missing or has an incorrect type annotation."""
|
||||
|
||||
|
||||
class ToolMetaclass(ModelMetaclass):
|
||||
"""Metaclass for BaseTool to ensure the provided args_schema
|
||||
|
||||
doesn't silently ignored."""
|
||||
|
||||
def __new__(
|
||||
cls: Type[ToolMetaclass], name: str, bases: Tuple[Type, ...], dct: dict
|
||||
) -> ToolMetaclass:
|
||||
"""Create the definition of the new tool class."""
|
||||
schema_type: Optional[Type[BaseModel]] = dct.get("args_schema")
|
||||
if schema_type is not None:
|
||||
schema_annotations = dct.get("__annotations__", {})
|
||||
args_schema_type = schema_annotations.get("args_schema", None)
|
||||
if args_schema_type is None or args_schema_type == BaseModel:
|
||||
# Throw errors for common mis-annotations.
|
||||
# TODO: Use get_args / get_origin and fully
|
||||
# specify valid annotations.
|
||||
typehint_mandate = """
|
||||
class ChildTool(BaseTool):
|
||||
...
|
||||
args_schema: Type[BaseModel] = SchemaClass
|
||||
..."""
|
||||
raise SchemaAnnotationError(
|
||||
f"Tool definition for {name} must include valid type annotations"
|
||||
f" for argument 'args_schema' to behave as expected.\n"
|
||||
f"Expected annotation of 'Type[BaseModel]'"
|
||||
f" but got '{args_schema_type}'.\n"
|
||||
f"Expected class looks like:\n"
|
||||
f"{typehint_mandate}"
|
||||
)
|
||||
# Pass through to Pydantic's metaclass
|
||||
return super().__new__(cls, name, bases, dct)
|
||||
|
||||
|
||||
def _create_subset_model(
|
||||
name: str, model: BaseModel, field_names: list
|
||||
) -> Type[BaseModel]:
|
||||
"""Create a pydantic model with only a subset of model's fields."""
|
||||
fields = {
|
||||
field_name: (
|
||||
model.__fields__[field_name].type_,
|
||||
model.__fields__[field_name].default,
|
||||
)
|
||||
for field_name in field_names
|
||||
if field_name in model.__fields__
|
||||
}
|
||||
return create_model(name, **fields) # type: ignore
|
||||
|
||||
|
||||
def get_filtered_args(inferred_model: Type[BaseModel], func: Callable) -> dict:
|
||||
"""Get the arguments from a function's signature."""
|
||||
schema = inferred_model.schema()["properties"]
|
||||
valid_keys = signature(func).parameters
|
||||
return {k: schema[k] for k in valid_keys}
|
||||
|
||||
|
||||
def create_schema_from_function(model_name: str, func: Callable) -> Type[BaseModel]:
|
||||
"""Create a pydantic schema from a function's signature."""
|
||||
inferred_model = validate_arguments(func).model # type: ignore
|
||||
# Pydantic adds placeholder virtual fields we need to strip
|
||||
filtered_args = get_filtered_args(inferred_model, func)
|
||||
return _create_subset_model(
|
||||
f"{model_name}Schema", inferred_model, list(filtered_args)
|
||||
)
|
||||
|
||||
|
||||
class BaseTool(ABC, BaseModel, metaclass=ToolMetaclass):
|
||||
"""Interface LangChain tools must implement."""
|
||||
|
||||
name: str
|
||||
@@ -42,9 +121,7 @@ class BaseTool(ABC, BaseModel):
|
||||
return self.args_schema.schema()["properties"]
|
||||
else:
|
||||
inferred_model = validate_arguments(self._run).model # type: ignore
|
||||
schema = inferred_model.schema()["properties"]
|
||||
valid_keys = signature(self._run).parameters
|
||||
return {k: schema[k] for k in valid_keys}
|
||||
return get_filtered_args(inferred_model, self._run)
|
||||
|
||||
def _parse_input(
|
||||
self,
|
||||
|
Reference in New Issue
Block a user