core[patch]: Fix tool args schema inherited field parsing (#24936)

Fix #24925
This commit is contained in:
Bagatur 2024-08-01 18:36:33 -07:00 committed by GitHub
parent fba65ba04f
commit 199e9c5ae0
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 289 additions and 7 deletions

View File

@ -45,7 +45,7 @@ from typing import (
get_type_hints,
)
from typing_extensions import Annotated, cast, get_args, get_origin
from typing_extensions import Annotated, TypeVar, cast, get_args, get_origin
from langchain_core._api import deprecated
from langchain_core.callbacks import (
@ -88,11 +88,16 @@ from langchain_core.runnables.config import (
run_in_executor,
)
from langchain_core.runnables.utils import accepts_context
from langchain_core.utils.function_calling import _parse_google_docstring
from langchain_core.utils.function_calling import (
_parse_google_docstring,
_py_38_safe_origin,
)
from langchain_core.utils.pydantic import (
TypeBaseModel,
_create_subset_model,
is_basemodel_subclass,
is_pydantic_v1_subclass,
is_pydantic_v2_subclass,
)
FILTERED_ARGS = ("run_manager", "callbacks")
@ -387,7 +392,7 @@ class ChildTool(BaseTool):
def tool_call_schema(self) -> Type[BaseModel]:
full_schema = self.get_input_schema()
fields = []
for name, type_ in full_schema.__annotations__.items():
for name, type_ in _get_all_basemodel_annotations(full_schema).items():
if not _is_injected_arg_type(type_):
fields.append(name)
return _create_subset_model(
@ -1650,3 +1655,80 @@ def _filter_schema_args(func: Callable) -> List[str]:
filter_args.append(config_param)
# filter_args.extend(_get_non_model_params(type_hints))
return filter_args
def _get_all_basemodel_annotations(
    cls: Union[TypeBaseModel, Any], *, default_to_bound: bool = True
) -> Dict[str, Type]:
    """Collect parameter annotations for a model class, resolving generics.

    Annotations are read from the constructor signature of ``cls``. For
    anything other than a plain Pydantic v2 class, type variables declared
    on generic parents are then substituted with the concrete arguments the
    parents were subscripted with.

    Args:
        cls: Either a class (``FooBar``) or a subscripted generic alias of a
            class (``FooBar[int]``).
        default_to_bound: Forwarded to ``_replace_type_vars``. Recursive calls
            pass ``False`` so that type variables survive until the outermost
            call has had a chance to substitute them.

    Returns:
        A mapping from parameter name to its (resolved) annotation.
    """
    # cls has no subscript: cls = FooBar
    if isinstance(cls, type):
        annotations: Dict[str, Type] = {
            name: param.annotation
            for name, param in inspect.signature(cls).parameters.items()
        }
        orig_bases: Tuple = getattr(cls, "__orig_bases__", tuple())
    # cls has subscript: cls = FooBar[int]
    else:
        annotations = _get_all_basemodel_annotations(
            get_origin(cls), default_to_bound=False
        )
        orig_bases = (cls,)

    # Pydantic v2 automatically resolves inherited generics, Pydantic v1 does not.
    if not (isinstance(cls, type) and is_pydantic_v2_subclass(cls)):
        # if cls = FooBar inherits from Baz[str], orig_bases will contain Baz[str]
        # if cls = FooBar inherits from Baz, orig_bases will contain Baz
        # if cls = FooBar[int], orig_bases will contain FooBar[int]
        for parent in orig_bases:
            # if class = FooBar inherits from Baz, parent = Baz
            if isinstance(parent, type) and is_pydantic_v1_subclass(parent):
                annotations.update(
                    _get_all_basemodel_annotations(parent, default_to_bound=False)
                )
                continue

            parent_origin = get_origin(parent)

            # if class = FooBar inherits from non-pydantic class
            if not parent_origin:
                continue

            # if class = FooBar inherits from Baz[str]:
            # parent = Baz[str],
            # parent_origin = Baz,
            # generic_type_vars = (type vars in Baz)
            # generic_map = {type var in Baz: str}
            generic_type_vars: Tuple = getattr(parent_origin, "__parameters__", tuple())
            generic_map = dict(zip(generic_type_vars, get_args(parent)))
            for field in getattr(parent_origin, "__annotations__", dict()):
                # `annotations` is keyed by constructor parameters, while
                # `__annotations__` lists every annotated name on the parent.
                # A name that is not an init parameter (for example a
                # ClassVar) has nothing to substitute, so skip it instead of
                # raising KeyError.
                if field not in annotations:
                    continue
                annotations[field] = _replace_type_vars(
                    annotations[field], generic_map, default_to_bound
                )

    return {
        k: _replace_type_vars(v, default_to_bound=default_to_bound)
        for k, v in annotations.items()
    }
def _replace_type_vars(
    type_: Type,
    generic_map: Optional[Dict[TypeVar, Type]] = None,
    default_to_bound: bool = True,
) -> Type:
    """Substitute type variables inside a type annotation.

    Args:
        type_: The annotation to process. May be a bare type variable, a
            subscripted generic containing type variables, or any other type.
        generic_map: Known substitutions from type variable to concrete type.
        default_to_bound: When a type variable has no entry in
            ``generic_map``, return its bound (or ``Any`` if it has none)
            instead of the type variable itself.

    Returns:
        The annotation with type variables replaced where possible.
    """
    substitutions = generic_map or {}

    if isinstance(type_, TypeVar):
        if type_ in substitutions:
            return substitutions[type_]
        if not default_to_bound:
            return type_
        return type_.__bound__ or Any

    # Only subscripted generics that actually carry arguments are rebuilt;
    # everything else is returned untouched.
    origin = get_origin(type_)
    if not origin:
        return type_
    args = get_args(type_)
    if not args:
        return type_

    resolved_args = tuple(
        _replace_type_vars(arg, substitutions, default_to_bound) for arg in args
    )
    return _py_38_safe_origin(origin)[resolved_args]

View File

@ -53,6 +53,13 @@ def is_pydantic_v1_subclass(cls: Type) -> bool:
return False
def is_pydantic_v2_subclass(cls: Type) -> bool:
    """Check if the given class is a Pydantic 2.x BaseModel subclass.

    Returns True only when ``PYDANTIC_MAJOR_VERSION`` is 2 and ``cls`` is a
    subclass of ``pydantic.BaseModel``.
    """
    from pydantic import BaseModel

    if PYDANTIC_MAJOR_VERSION != 2:
        return False
    return issubclass(cls, BaseModel)
def is_basemodel_subclass(cls: Type) -> bool:
"""Check if the given class is a subclass of Pydantic BaseModel.

View File

@ -8,10 +8,22 @@ import threading
from datetime import datetime
from enum import Enum
from functools import partial
from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Type, Union
from typing import (
Any,
Callable,
Dict,
Generic,
List,
Literal,
Optional,
Tuple,
Type,
Union,
)
import pytest
from typing_extensions import Annotated, TypedDict
from pydantic import BaseModel as BaseModelProper # pydantic: ignore
from typing_extensions import Annotated, TypedDict, TypeVar
from langchain_core.callbacks import (
AsyncCallbackManagerForToolRun,
@ -32,12 +44,13 @@ from langchain_core.tools import (
StructuredTool,
Tool,
ToolException,
_get_all_basemodel_annotations,
_is_message_content_block,
_is_message_content_type,
tool,
)
from langchain_core.utils.function_calling import convert_to_openai_function
from langchain_core.utils.pydantic import _create_subset_model
from langchain_core.utils.pydantic import PYDANTIC_MAJOR_VERSION, _create_subset_model
from tests.unit_tests.fake.callbacks import FakeCallbackHandler
from tests.unit_tests.pydantic_utils import _schema
@ -1452,6 +1465,66 @@ def test_tool_injected_arg_with_schema(tool_: BaseTool) -> None:
}
def test_tool_inherited_injected_arg() -> None:
    """Injected args inherited from a parent schema are hidden from the model.

    ``y`` is declared as an injected arg on the parent schema. It must still
    appear in the full input schema, but must be dropped from the tool-call
    schema and from the OpenAI function spec.
    """

    class barSchema(BaseModel):
        """bar."""

        y: Annotated[str, "foobar comment", InjectedToolArg()] = Field(
            ..., description="123"
        )

    class fooSchema(barSchema):
        """foo."""

        x: int = Field(..., description="abc")

    class InheritedInjectedArgTool(BaseTool):
        name: str = "foo"
        description: str = "foo."
        args_schema: Type[BaseModel] = fooSchema

        def _run(self, x: int, y: str) -> Any:
            return y

    tool_ = InheritedInjectedArgTool()

    # The full input schema keeps the inherited, injected field.
    assert tool_.get_input_schema().schema() == {
        "title": "fooSchema",
        "description": "foo.",
        "type": "object",
        "properties": {
            "x": {"description": "abc", "title": "X", "type": "integer"},
            "y": {"description": "123", "title": "Y", "type": "string"},
        },
        "required": ["y", "x"],
    }

    # The schema exposed for tool calls omits the injected field.
    assert tool_.tool_call_schema.schema() == {
        "title": "foo",
        "description": "foo.",
        "type": "object",
        "properties": {"x": {"description": "abc", "title": "X", "type": "integer"}},
        "required": ["x"],
    }

    assert tool_.invoke({"x": 5, "y": "bar"}) == "bar"
    assert tool_.invoke(
        {"name": "foo", "args": {"x": 5, "y": "bar"}, "id": "123", "type": "tool_call"}
    ) == ToolMessage("bar", tool_call_id="123", name="foo")

    # The tool under test is always the schema-backed class defined above, so
    # a missing injected arg is expected to surface as a ValidationError.
    with pytest.raises(ValidationError):
        tool_.invoke({"x": 5})

    assert convert_to_openai_function(tool_) == {
        "name": "foo",
        "description": "foo.",
        "parameters": {
            "type": "object",
            "properties": {"x": {"type": "integer", "description": "abc"}},
            "required": ["x"],
        },
    }
def _get_parametrized_tools() -> list:
def my_tool(x: int, y: str, some_tool: Annotated[Any, InjectedToolArg]) -> str:
"""my_tool."""
@ -1484,7 +1557,6 @@ def test_fn_injected_arg_with_schema(tool_: Callable) -> None:
def generate_models() -> List[Any]:
"""Generate a list of base models depending on the pydantic version."""
from pydantic import BaseModel as BaseModelProper # pydantic: ignore
class FooProper(BaseModelProper):
a: int
@ -1670,3 +1742,124 @@ def test__is_message_content_block(obj: Any, expected: bool) -> None:
)
def test__is_message_content_type(obj: Any, expected: bool) -> None:
assert _is_message_content_type(obj) is expected
@pytest.mark.skipif(PYDANTIC_MAJOR_VERSION != 2, reason="Testing pydantic v2.")
@pytest.mark.parametrize("use_v1_namespace", [True, False])
def test__get_all_basemodel_annotations_v2(use_v1_namespace: bool) -> None:
    """Check annotation resolution through generic inheritance (pydantic 2).

    The generic root model is built from either ``BaseModel`` or
    ``BaseModelProper`` depending on ``use_v1_namespace``; the expected
    annotations are the same in both cases.
    """
    A = TypeVar("A")

    if use_v1_namespace:

        class ModelA(BaseModel, Generic[A]):
            a: A

    else:

        class ModelA(BaseModelProper, Generic[A]):  # type: ignore[no-redef]
            a: A

    class ModelB(ModelA[str]):
        b: Annotated[ModelA[Dict[str, Any]], "foo"]

    class Mixin:
        def foo(self) -> str:
            return "foo"

    class ModelC(Mixin, ModelB):
        c: dict

    b_annotation = Annotated[ModelA[Dict[str, Any]], "foo"]

    # Non-generic subclasses see the parent's type variable resolved to str.
    assert _get_all_basemodel_annotations(ModelC) == {
        "a": str,
        "b": b_annotation,
        "c": dict,
    }
    assert _get_all_basemodel_annotations(ModelB) == {"a": str, "b": b_annotation}

    # An unbound, unsubscripted type variable falls back to Any.
    assert _get_all_basemodel_annotations(ModelA) == {"a": Any}
    assert _get_all_basemodel_annotations(ModelA[int]) == {"a": int}

    D = TypeVar("D", bound=Union[str, int])

    class ModelD(ModelC, Generic[D]):
        d: Optional[D]

    # Unsubscripted: the bound of D is used.
    assert _get_all_basemodel_annotations(ModelD) == {
        "a": str,
        "b": b_annotation,
        "c": dict,
        "d": Union[str, int, None],
    }

    # Subscripted: the supplied argument replaces D.
    assert _get_all_basemodel_annotations(ModelD[int]) == {
        "a": str,
        "b": b_annotation,
        "c": dict,
        "d": Union[int, None],
    }
@pytest.mark.skipif(PYDANTIC_MAJOR_VERSION != 1, reason="Testing pydantic v1.")
def test__get_all_basemodel_annotations_v1() -> None:
    """Check annotation resolution through generic inheritance (pydantic 1)."""
    A = TypeVar("A")

    class ModelA(BaseModel, Generic[A]):
        a: A

    class ModelB(ModelA[str]):
        b: Annotated[ModelA[Dict[str, Any]], "foo"]

    class Mixin:
        def foo(self) -> str:
            return "foo"

    class ModelC(Mixin, ModelB):
        c: dict

    b_annotation = Annotated[ModelA[Dict[str, Any]], "foo"]

    # Non-generic subclasses see the parent's type variable resolved to str.
    assert _get_all_basemodel_annotations(ModelC) == {
        "a": str,
        "b": b_annotation,
        "c": dict,
    }
    assert _get_all_basemodel_annotations(ModelB) == {"a": str, "b": b_annotation}

    # An unbound, unsubscripted type variable falls back to Any.
    assert _get_all_basemodel_annotations(ModelA) == {"a": Any}
    assert _get_all_basemodel_annotations(ModelA[int]) == {"a": int}

    D = TypeVar("D", bound=Union[str, int])

    class ModelD(ModelC, Generic[D]):
        d: Optional[D]

    # Unsubscripted: the bound of D is used.
    assert _get_all_basemodel_annotations(ModelD) == {
        "a": str,
        "b": b_annotation,
        "c": dict,
        "d": Union[str, int, None],
    }

    # Subscripted: the supplied argument replaces D.
    assert _get_all_basemodel_annotations(ModelD[int]) == {
        "a": str,
        "b": b_annotation,
        "c": dict,
        "d": Union[int, None],
    }