core[patch]: Fix tool args schema inherited field parsing (#24936)

Fix #24925
This commit is contained in:
Bagatur 2024-08-01 18:36:33 -07:00 committed by GitHub
parent fba65ba04f
commit 199e9c5ae0
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 289 additions and 7 deletions

View File

@ -45,7 +45,7 @@ from typing import (
get_type_hints, get_type_hints,
) )
from typing_extensions import Annotated, cast, get_args, get_origin from typing_extensions import Annotated, TypeVar, cast, get_args, get_origin
from langchain_core._api import deprecated from langchain_core._api import deprecated
from langchain_core.callbacks import ( from langchain_core.callbacks import (
@ -88,11 +88,16 @@ from langchain_core.runnables.config import (
run_in_executor, run_in_executor,
) )
from langchain_core.runnables.utils import accepts_context from langchain_core.runnables.utils import accepts_context
from langchain_core.utils.function_calling import _parse_google_docstring from langchain_core.utils.function_calling import (
_parse_google_docstring,
_py_38_safe_origin,
)
from langchain_core.utils.pydantic import ( from langchain_core.utils.pydantic import (
TypeBaseModel, TypeBaseModel,
_create_subset_model, _create_subset_model,
is_basemodel_subclass, is_basemodel_subclass,
is_pydantic_v1_subclass,
is_pydantic_v2_subclass,
) )
FILTERED_ARGS = ("run_manager", "callbacks") FILTERED_ARGS = ("run_manager", "callbacks")
@ -387,7 +392,7 @@ class ChildTool(BaseTool):
def tool_call_schema(self) -> Type[BaseModel]: def tool_call_schema(self) -> Type[BaseModel]:
full_schema = self.get_input_schema() full_schema = self.get_input_schema()
fields = [] fields = []
for name, type_ in full_schema.__annotations__.items(): for name, type_ in _get_all_basemodel_annotations(full_schema).items():
if not _is_injected_arg_type(type_): if not _is_injected_arg_type(type_):
fields.append(name) fields.append(name)
return _create_subset_model( return _create_subset_model(
@ -1650,3 +1655,80 @@ def _filter_schema_args(func: Callable) -> List[str]:
filter_args.append(config_param) filter_args.append(config_param)
# filter_args.extend(_get_non_model_params(type_hints)) # filter_args.extend(_get_non_model_params(type_hints))
return filter_args return filter_args
def _get_all_basemodel_annotations(
cls: Union[TypeBaseModel, Any], *, default_to_bound: bool = True
) -> Dict[str, Type]:
# cls has no subscript: cls = FooBar
if isinstance(cls, type):
annotations: Dict[str, Type] = {}
for name, param in inspect.signature(cls).parameters.items():
annotations[name] = param.annotation
orig_bases: Tuple = getattr(cls, "__orig_bases__", tuple())
# cls has subscript: cls = FooBar[int]
else:
annotations = _get_all_basemodel_annotations(
get_origin(cls), default_to_bound=False
)
orig_bases = (cls,)
# Pydantic v2 automatically resolves inherited generics, Pydantic v1 does not.
if not (isinstance(cls, type) and is_pydantic_v2_subclass(cls)):
# if cls = FooBar inherits from Baz[str], orig_bases will contain Baz[str]
# if cls = FooBar inherits from Baz, orig_bases will contain Baz
# if cls = FooBar[int], orig_bases will contain FooBar[int]
for parent in orig_bases:
# if class = FooBar inherits from Baz, parent = Baz
if isinstance(parent, type) and is_pydantic_v1_subclass(parent):
annotations.update(
_get_all_basemodel_annotations(parent, default_to_bound=False)
)
continue
parent_origin = get_origin(parent)
# if class = FooBar inherits from non-pydantic class
if not parent_origin:
continue
# if class = FooBar inherits from Baz[str]:
# parent = Baz[str],
# parent_origin = Baz,
# generic_type_vars = (type vars in Baz)
# generic_map = {type var in Baz: str}
generic_type_vars: Tuple = getattr(parent_origin, "__parameters__", tuple())
generic_map = {
type_var: t for type_var, t in zip(generic_type_vars, get_args(parent))
}
for field in getattr(parent_origin, "__annotations__", dict()):
annotations[field] = _replace_type_vars(
annotations[field], generic_map, default_to_bound
)
return {
k: _replace_type_vars(v, default_to_bound=default_to_bound)
for k, v in annotations.items()
}
def _replace_type_vars(
type_: Type,
generic_map: Optional[Dict[TypeVar, Type]] = None,
default_to_bound: bool = True,
) -> Type:
generic_map = generic_map or {}
if isinstance(type_, TypeVar):
if type_ in generic_map:
return generic_map[type_]
elif default_to_bound:
return type_.__bound__ or Any
else:
return type_
elif (origin := get_origin(type_)) and (args := get_args(type_)):
new_args = tuple(
_replace_type_vars(arg, generic_map, default_to_bound) for arg in args
)
return _py_38_safe_origin(origin)[new_args]
else:
return type_

View File

@ -53,6 +53,13 @@ def is_pydantic_v1_subclass(cls: Type) -> bool:
return False return False
def is_pydantic_v2_subclass(cls: Type) -> bool:
"""Check if the installed Pydantic version is 1.x-like."""
from pydantic import BaseModel
return PYDANTIC_MAJOR_VERSION == 2 and issubclass(cls, BaseModel)
def is_basemodel_subclass(cls: Type) -> bool: def is_basemodel_subclass(cls: Type) -> bool:
"""Check if the given class is a subclass of Pydantic BaseModel. """Check if the given class is a subclass of Pydantic BaseModel.

View File

@ -8,10 +8,22 @@ import threading
from datetime import datetime from datetime import datetime
from enum import Enum from enum import Enum
from functools import partial from functools import partial
from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Type, Union from typing import (
Any,
Callable,
Dict,
Generic,
List,
Literal,
Optional,
Tuple,
Type,
Union,
)
import pytest import pytest
from typing_extensions import Annotated, TypedDict from pydantic import BaseModel as BaseModelProper # pydantic: ignore
from typing_extensions import Annotated, TypedDict, TypeVar
from langchain_core.callbacks import ( from langchain_core.callbacks import (
AsyncCallbackManagerForToolRun, AsyncCallbackManagerForToolRun,
@ -32,12 +44,13 @@ from langchain_core.tools import (
StructuredTool, StructuredTool,
Tool, Tool,
ToolException, ToolException,
_get_all_basemodel_annotations,
_is_message_content_block, _is_message_content_block,
_is_message_content_type, _is_message_content_type,
tool, tool,
) )
from langchain_core.utils.function_calling import convert_to_openai_function from langchain_core.utils.function_calling import convert_to_openai_function
from langchain_core.utils.pydantic import _create_subset_model from langchain_core.utils.pydantic import PYDANTIC_MAJOR_VERSION, _create_subset_model
from tests.unit_tests.fake.callbacks import FakeCallbackHandler from tests.unit_tests.fake.callbacks import FakeCallbackHandler
from tests.unit_tests.pydantic_utils import _schema from tests.unit_tests.pydantic_utils import _schema
@ -1452,6 +1465,66 @@ def test_tool_injected_arg_with_schema(tool_: BaseTool) -> None:
} }
def test_tool_inherited_injected_arg() -> None:
class barSchema(BaseModel):
"""bar."""
y: Annotated[str, "foobar comment", InjectedToolArg()] = Field(
..., description="123"
)
class fooSchema(barSchema):
"""foo."""
x: int = Field(..., description="abc")
class InheritedInjectedArgTool(BaseTool):
name: str = "foo"
description: str = "foo."
args_schema: Type[BaseModel] = fooSchema
def _run(self, x: int, y: str) -> Any:
return y
tool_ = InheritedInjectedArgTool()
assert tool_.get_input_schema().schema() == {
"title": "fooSchema",
"description": "foo.",
"type": "object",
"properties": {
"x": {"description": "abc", "title": "X", "type": "integer"},
"y": {"description": "123", "title": "Y", "type": "string"},
},
"required": ["y", "x"],
}
assert tool_.tool_call_schema.schema() == {
"title": "foo",
"description": "foo.",
"type": "object",
"properties": {"x": {"description": "abc", "title": "X", "type": "integer"}},
"required": ["x"],
}
assert tool_.invoke({"x": 5, "y": "bar"}) == "bar"
assert tool_.invoke(
{"name": "foo", "args": {"x": 5, "y": "bar"}, "id": "123", "type": "tool_call"}
) == ToolMessage("bar", tool_call_id="123", name="foo")
expected_error = (
ValidationError if not isinstance(tool_, InjectedTool) else TypeError
)
with pytest.raises(expected_error):
tool_.invoke({"x": 5})
assert convert_to_openai_function(tool_) == {
"name": "foo",
"description": "foo.",
"parameters": {
"type": "object",
"properties": {"x": {"type": "integer", "description": "abc"}},
"required": ["x"],
},
}
def _get_parametrized_tools() -> list: def _get_parametrized_tools() -> list:
def my_tool(x: int, y: str, some_tool: Annotated[Any, InjectedToolArg]) -> str: def my_tool(x: int, y: str, some_tool: Annotated[Any, InjectedToolArg]) -> str:
"""my_tool.""" """my_tool."""
@ -1484,7 +1557,6 @@ def test_fn_injected_arg_with_schema(tool_: Callable) -> None:
def generate_models() -> List[Any]: def generate_models() -> List[Any]:
"""Generate a list of base models depending on the pydantic version.""" """Generate a list of base models depending on the pydantic version."""
from pydantic import BaseModel as BaseModelProper # pydantic: ignore
class FooProper(BaseModelProper): class FooProper(BaseModelProper):
a: int a: int
@ -1670,3 +1742,124 @@ def test__is_message_content_block(obj: Any, expected: bool) -> None:
) )
def test__is_message_content_type(obj: Any, expected: bool) -> None: def test__is_message_content_type(obj: Any, expected: bool) -> None:
assert _is_message_content_type(obj) is expected assert _is_message_content_type(obj) is expected
@pytest.mark.skipif(PYDANTIC_MAJOR_VERSION != 2, reason="Testing pydantic v2.")
@pytest.mark.parametrize("use_v1_namespace", [True, False])
def test__get_all_basemodel_annotations_v2(use_v1_namespace: bool) -> None:
A = TypeVar("A")
if use_v1_namespace:
class ModelA(BaseModel, Generic[A]):
a: A
else:
class ModelA(BaseModelProper, Generic[A]): # type: ignore[no-redef]
a: A
class ModelB(ModelA[str]):
b: Annotated[ModelA[Dict[str, Any]], "foo"]
class Mixin(object):
def foo(self) -> str:
return "foo"
class ModelC(Mixin, ModelB):
c: dict
expected = {"a": str, "b": Annotated[ModelA[Dict[str, Any]], "foo"], "c": dict}
actual = _get_all_basemodel_annotations(ModelC)
assert actual == expected
expected = {"a": str, "b": Annotated[ModelA[Dict[str, Any]], "foo"]}
actual = _get_all_basemodel_annotations(ModelB)
assert actual == expected
expected = {"a": Any}
actual = _get_all_basemodel_annotations(ModelA)
assert actual == expected
expected = {"a": int}
actual = _get_all_basemodel_annotations(ModelA[int])
assert actual == expected
D = TypeVar("D", bound=Union[str, int])
class ModelD(ModelC, Generic[D]):
d: Optional[D]
expected = {
"a": str,
"b": Annotated[ModelA[Dict[str, Any]], "foo"],
"c": dict,
"d": Union[str, int, None],
}
actual = _get_all_basemodel_annotations(ModelD)
assert actual == expected
expected = {
"a": str,
"b": Annotated[ModelA[Dict[str, Any]], "foo"],
"c": dict,
"d": Union[int, None],
}
actual = _get_all_basemodel_annotations(ModelD[int])
assert actual == expected
@pytest.mark.skipif(PYDANTIC_MAJOR_VERSION != 1, reason="Testing pydantic v1.")
def test__get_all_basemodel_annotations_v1() -> None:
A = TypeVar("A")
class ModelA(BaseModel, Generic[A]):
a: A
class ModelB(ModelA[str]):
b: Annotated[ModelA[Dict[str, Any]], "foo"]
class Mixin(object):
def foo(self) -> str:
return "foo"
class ModelC(Mixin, ModelB):
c: dict
expected = {"a": str, "b": Annotated[ModelA[Dict[str, Any]], "foo"], "c": dict}
actual = _get_all_basemodel_annotations(ModelC)
assert actual == expected
expected = {"a": str, "b": Annotated[ModelA[Dict[str, Any]], "foo"]}
actual = _get_all_basemodel_annotations(ModelB)
assert actual == expected
expected = {"a": Any}
actual = _get_all_basemodel_annotations(ModelA)
assert actual == expected
expected = {"a": int}
actual = _get_all_basemodel_annotations(ModelA[int])
assert actual == expected
D = TypeVar("D", bound=Union[str, int])
class ModelD(ModelC, Generic[D]):
d: Optional[D]
expected = {
"a": str,
"b": Annotated[ModelA[Dict[str, Any]], "foo"],
"c": dict,
"d": Union[str, int, None],
}
actual = _get_all_basemodel_annotations(ModelD)
assert actual == expected
expected = {
"a": str,
"b": Annotated[ModelA[Dict[str, Any]], "foo"],
"c": dict,
"d": Union[int, None],
}
actual = _get_all_basemodel_annotations(ModelD[int])
assert actual == expected