From 199e9c5ae0f642b0421b317bd33c9448e289e9ed Mon Sep 17 00:00:00 2001 From: Bagatur <22008038+baskaryan@users.noreply.github.com> Date: Thu, 1 Aug 2024 18:36:33 -0700 Subject: [PATCH] core[patch]: Fix tool args schema inherited field parsing (#24936) Fix #24925 --- libs/core/langchain_core/tools.py | 88 ++++++++- libs/core/langchain_core/utils/pydantic.py | 7 + libs/core/tests/unit_tests/test_tools.py | 201 ++++++++++++++++++++- 3 files changed, 289 insertions(+), 7 deletions(-) diff --git a/libs/core/langchain_core/tools.py b/libs/core/langchain_core/tools.py index 4cf12f050e2..5ff3a72cd33 100644 --- a/libs/core/langchain_core/tools.py +++ b/libs/core/langchain_core/tools.py @@ -45,7 +45,7 @@ from typing import ( get_type_hints, ) -from typing_extensions import Annotated, cast, get_args, get_origin +from typing_extensions import Annotated, TypeVar, cast, get_args, get_origin from langchain_core._api import deprecated from langchain_core.callbacks import ( @@ -88,11 +88,16 @@ from langchain_core.runnables.config import ( run_in_executor, ) from langchain_core.runnables.utils import accepts_context -from langchain_core.utils.function_calling import _parse_google_docstring +from langchain_core.utils.function_calling import ( + _parse_google_docstring, + _py_38_safe_origin, +) from langchain_core.utils.pydantic import ( TypeBaseModel, _create_subset_model, is_basemodel_subclass, + is_pydantic_v1_subclass, + is_pydantic_v2_subclass, ) FILTERED_ARGS = ("run_manager", "callbacks") @@ -387,7 +392,7 @@ class ChildTool(BaseTool): def tool_call_schema(self) -> Type[BaseModel]: full_schema = self.get_input_schema() fields = [] - for name, type_ in full_schema.__annotations__.items(): + for name, type_ in _get_all_basemodel_annotations(full_schema).items(): if not _is_injected_arg_type(type_): fields.append(name) return _create_subset_model( @@ -1650,3 +1655,80 @@ def _filter_schema_args(func: Callable) -> List[str]: filter_args.append(config_param) # 
filter_args.extend(_get_non_model_params(type_hints)) return filter_args + + +def _get_all_basemodel_annotations( + cls: Union[TypeBaseModel, Any], *, default_to_bound: bool = True +) -> Dict[str, Type]: + # cls has no subscript: cls = FooBar + if isinstance(cls, type): + annotations: Dict[str, Type] = {} + for name, param in inspect.signature(cls).parameters.items(): + annotations[name] = param.annotation + orig_bases: Tuple = getattr(cls, "__orig_bases__", tuple()) + # cls has subscript: cls = FooBar[int] + else: + annotations = _get_all_basemodel_annotations( + get_origin(cls), default_to_bound=False + ) + orig_bases = (cls,) + + # Pydantic v2 automatically resolves inherited generics, Pydantic v1 does not. + if not (isinstance(cls, type) and is_pydantic_v2_subclass(cls)): + # if cls = FooBar inherits from Baz[str], orig_bases will contain Baz[str] + # if cls = FooBar inherits from Baz, orig_bases will contain Baz + # if cls = FooBar[int], orig_bases will contain FooBar[int] + for parent in orig_bases: + # if class = FooBar inherits from Baz, parent = Baz + if isinstance(parent, type) and is_pydantic_v1_subclass(parent): + annotations.update( + _get_all_basemodel_annotations(parent, default_to_bound=False) + ) + continue + + parent_origin = get_origin(parent) + + # if class = FooBar inherits from non-pydantic class + if not parent_origin: + continue + + # if class = FooBar inherits from Baz[str]: + # parent = Baz[str], + # parent_origin = Baz, + # generic_type_vars = (type vars in Baz) + # generic_map = {type var in Baz: str} + generic_type_vars: Tuple = getattr(parent_origin, "__parameters__", tuple()) + generic_map = { + type_var: t for type_var, t in zip(generic_type_vars, get_args(parent)) + } + for field in getattr(parent_origin, "__annotations__", dict()): + annotations[field] = _replace_type_vars( + annotations[field], generic_map, default_to_bound + ) + + return { + k: _replace_type_vars(v, default_to_bound=default_to_bound) + for k, v in 
annotations.items() + } + + +def _replace_type_vars( + type_: Type, + generic_map: Optional[Dict[TypeVar, Type]] = None, + default_to_bound: bool = True, +) -> Type: + generic_map = generic_map or {} + if isinstance(type_, TypeVar): + if type_ in generic_map: + return generic_map[type_] + elif default_to_bound: + return type_.__bound__ or Any + else: + return type_ + elif (origin := get_origin(type_)) and (args := get_args(type_)): + new_args = tuple( + _replace_type_vars(arg, generic_map, default_to_bound) for arg in args + ) + return _py_38_safe_origin(origin)[new_args] + else: + return type_ diff --git a/libs/core/langchain_core/utils/pydantic.py b/libs/core/langchain_core/utils/pydantic.py index 00c17676775..e1ed53d1ed4 100644 --- a/libs/core/langchain_core/utils/pydantic.py +++ b/libs/core/langchain_core/utils/pydantic.py @@ -53,6 +53,13 @@ def is_pydantic_v1_subclass(cls: Type) -> bool: return False +def is_pydantic_v2_subclass(cls: Type) -> bool: + """Check if the installed Pydantic version is 2.x-like.""" + from pydantic import BaseModel + + return PYDANTIC_MAJOR_VERSION == 2 and issubclass(cls, BaseModel) + + def is_basemodel_subclass(cls: Type) -> bool: """Check if the given class is a subclass of Pydantic BaseModel. 
diff --git a/libs/core/tests/unit_tests/test_tools.py b/libs/core/tests/unit_tests/test_tools.py index 6249ae15023..a222bc01500 100644 --- a/libs/core/tests/unit_tests/test_tools.py +++ b/libs/core/tests/unit_tests/test_tools.py @@ -8,10 +8,22 @@ import threading from datetime import datetime from enum import Enum from functools import partial -from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Type, Union +from typing import ( + Any, + Callable, + Dict, + Generic, + List, + Literal, + Optional, + Tuple, + Type, + Union, +) import pytest -from typing_extensions import Annotated, TypedDict +from pydantic import BaseModel as BaseModelProper # pydantic: ignore +from typing_extensions import Annotated, TypedDict, TypeVar from langchain_core.callbacks import ( AsyncCallbackManagerForToolRun, @@ -32,12 +44,13 @@ from langchain_core.tools import ( StructuredTool, Tool, ToolException, + _get_all_basemodel_annotations, _is_message_content_block, _is_message_content_type, tool, ) from langchain_core.utils.function_calling import convert_to_openai_function -from langchain_core.utils.pydantic import _create_subset_model +from langchain_core.utils.pydantic import PYDANTIC_MAJOR_VERSION, _create_subset_model from tests.unit_tests.fake.callbacks import FakeCallbackHandler from tests.unit_tests.pydantic_utils import _schema @@ -1452,6 +1465,66 @@ def test_tool_injected_arg_with_schema(tool_: BaseTool) -> None: } +def test_tool_inherited_injected_arg() -> None: + class barSchema(BaseModel): + """bar.""" + + y: Annotated[str, "foobar comment", InjectedToolArg()] = Field( + ..., description="123" + ) + + class fooSchema(barSchema): + """foo.""" + + x: int = Field(..., description="abc") + + class InheritedInjectedArgTool(BaseTool): + name: str = "foo" + description: str = "foo." 
+ args_schema: Type[BaseModel] = fooSchema + + def _run(self, x: int, y: str) -> Any: + return y + + tool_ = InheritedInjectedArgTool() + assert tool_.get_input_schema().schema() == { + "title": "fooSchema", + "description": "foo.", + "type": "object", + "properties": { + "x": {"description": "abc", "title": "X", "type": "integer"}, + "y": {"description": "123", "title": "Y", "type": "string"}, + }, + "required": ["y", "x"], + } + assert tool_.tool_call_schema.schema() == { + "title": "foo", + "description": "foo.", + "type": "object", + "properties": {"x": {"description": "abc", "title": "X", "type": "integer"}}, + "required": ["x"], + } + assert tool_.invoke({"x": 5, "y": "bar"}) == "bar" + assert tool_.invoke( + {"name": "foo", "args": {"x": 5, "y": "bar"}, "id": "123", "type": "tool_call"} + ) == ToolMessage("bar", tool_call_id="123", name="foo") + expected_error = ( + ValidationError if not isinstance(tool_, InjectedTool) else TypeError + ) + with pytest.raises(expected_error): + tool_.invoke({"x": 5}) + + assert convert_to_openai_function(tool_) == { + "name": "foo", + "description": "foo.", + "parameters": { + "type": "object", + "properties": {"x": {"type": "integer", "description": "abc"}}, + "required": ["x"], + }, + } + + def _get_parametrized_tools() -> list: def my_tool(x: int, y: str, some_tool: Annotated[Any, InjectedToolArg]) -> str: """my_tool.""" @@ -1484,7 +1557,6 @@ def test_fn_injected_arg_with_schema(tool_: Callable) -> None: def generate_models() -> List[Any]: """Generate a list of base models depending on the pydantic version.""" - from pydantic import BaseModel as BaseModelProper # pydantic: ignore class FooProper(BaseModelProper): a: int @@ -1670,3 +1742,124 @@ def test__is_message_content_block(obj: Any, expected: bool) -> None: ) def test__is_message_content_type(obj: Any, expected: bool) -> None: assert _is_message_content_type(obj) is expected + + +@pytest.mark.skipif(PYDANTIC_MAJOR_VERSION != 2, reason="Testing pydantic v2.") 
+@pytest.mark.parametrize("use_v1_namespace", [True, False]) +def test__get_all_basemodel_annotations_v2(use_v1_namespace: bool) -> None: + A = TypeVar("A") + + if use_v1_namespace: + + class ModelA(BaseModel, Generic[A]): + a: A + else: + + class ModelA(BaseModelProper, Generic[A]): # type: ignore[no-redef] + a: A + + class ModelB(ModelA[str]): + b: Annotated[ModelA[Dict[str, Any]], "foo"] + + class Mixin(object): + def foo(self) -> str: + return "foo" + + class ModelC(Mixin, ModelB): + c: dict + + expected = {"a": str, "b": Annotated[ModelA[Dict[str, Any]], "foo"], "c": dict} + actual = _get_all_basemodel_annotations(ModelC) + assert actual == expected + + expected = {"a": str, "b": Annotated[ModelA[Dict[str, Any]], "foo"]} + actual = _get_all_basemodel_annotations(ModelB) + assert actual == expected + + expected = {"a": Any} + actual = _get_all_basemodel_annotations(ModelA) + assert actual == expected + + expected = {"a": int} + actual = _get_all_basemodel_annotations(ModelA[int]) + assert actual == expected + + D = TypeVar("D", bound=Union[str, int]) + + class ModelD(ModelC, Generic[D]): + d: Optional[D] + + expected = { + "a": str, + "b": Annotated[ModelA[Dict[str, Any]], "foo"], + "c": dict, + "d": Union[str, int, None], + } + actual = _get_all_basemodel_annotations(ModelD) + assert actual == expected + + expected = { + "a": str, + "b": Annotated[ModelA[Dict[str, Any]], "foo"], + "c": dict, + "d": Union[int, None], + } + actual = _get_all_basemodel_annotations(ModelD[int]) + assert actual == expected + + +@pytest.mark.skipif(PYDANTIC_MAJOR_VERSION != 1, reason="Testing pydantic v1.") +def test__get_all_basemodel_annotations_v1() -> None: + A = TypeVar("A") + + class ModelA(BaseModel, Generic[A]): + a: A + + class ModelB(ModelA[str]): + b: Annotated[ModelA[Dict[str, Any]], "foo"] + + class Mixin(object): + def foo(self) -> str: + return "foo" + + class ModelC(Mixin, ModelB): + c: dict + + expected = {"a": str, "b": Annotated[ModelA[Dict[str, Any]], "foo"], 
"c": dict} + actual = _get_all_basemodel_annotations(ModelC) + assert actual == expected + + expected = {"a": str, "b": Annotated[ModelA[Dict[str, Any]], "foo"]} + actual = _get_all_basemodel_annotations(ModelB) + assert actual == expected + + expected = {"a": Any} + actual = _get_all_basemodel_annotations(ModelA) + assert actual == expected + + expected = {"a": int} + actual = _get_all_basemodel_annotations(ModelA[int]) + assert actual == expected + + D = TypeVar("D", bound=Union[str, int]) + + class ModelD(ModelC, Generic[D]): + d: Optional[D] + + expected = { + "a": str, + "b": Annotated[ModelA[Dict[str, Any]], "foo"], + "c": dict, + "d": Union[str, int, None], + } + actual = _get_all_basemodel_annotations(ModelD) + assert actual == expected + + expected = { + "a": str, + "b": Annotated[ModelA[Dict[str, Any]], "foo"], + "c": dict, + "d": Union[int, None], + } + actual = _get_all_basemodel_annotations(ModelD[int]) + assert actual == expected