mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-13 06:40:04 +00:00
core[patch]: Fix tool args schema inherited field parsing (#24936)
Fix #24925
This commit is contained in:
parent
fba65ba04f
commit
199e9c5ae0
@ -45,7 +45,7 @@ from typing import (
|
||||
get_type_hints,
|
||||
)
|
||||
|
||||
from typing_extensions import Annotated, cast, get_args, get_origin
|
||||
from typing_extensions import Annotated, TypeVar, cast, get_args, get_origin
|
||||
|
||||
from langchain_core._api import deprecated
|
||||
from langchain_core.callbacks import (
|
||||
@ -88,11 +88,16 @@ from langchain_core.runnables.config import (
|
||||
run_in_executor,
|
||||
)
|
||||
from langchain_core.runnables.utils import accepts_context
|
||||
from langchain_core.utils.function_calling import _parse_google_docstring
|
||||
from langchain_core.utils.function_calling import (
|
||||
_parse_google_docstring,
|
||||
_py_38_safe_origin,
|
||||
)
|
||||
from langchain_core.utils.pydantic import (
|
||||
TypeBaseModel,
|
||||
_create_subset_model,
|
||||
is_basemodel_subclass,
|
||||
is_pydantic_v1_subclass,
|
||||
is_pydantic_v2_subclass,
|
||||
)
|
||||
|
||||
FILTERED_ARGS = ("run_manager", "callbacks")
|
||||
@ -387,7 +392,7 @@ class ChildTool(BaseTool):
|
||||
def tool_call_schema(self) -> Type[BaseModel]:
|
||||
full_schema = self.get_input_schema()
|
||||
fields = []
|
||||
for name, type_ in full_schema.__annotations__.items():
|
||||
for name, type_ in _get_all_basemodel_annotations(full_schema).items():
|
||||
if not _is_injected_arg_type(type_):
|
||||
fields.append(name)
|
||||
return _create_subset_model(
|
||||
@ -1650,3 +1655,80 @@ def _filter_schema_args(func: Callable) -> List[str]:
|
||||
filter_args.append(config_param)
|
||||
# filter_args.extend(_get_non_model_params(type_hints))
|
||||
return filter_args
|
||||
|
||||
|
||||
def _get_all_basemodel_annotations(
|
||||
cls: Union[TypeBaseModel, Any], *, default_to_bound: bool = True
|
||||
) -> Dict[str, Type]:
|
||||
# cls has no subscript: cls = FooBar
|
||||
if isinstance(cls, type):
|
||||
annotations: Dict[str, Type] = {}
|
||||
for name, param in inspect.signature(cls).parameters.items():
|
||||
annotations[name] = param.annotation
|
||||
orig_bases: Tuple = getattr(cls, "__orig_bases__", tuple())
|
||||
# cls has subscript: cls = FooBar[int]
|
||||
else:
|
||||
annotations = _get_all_basemodel_annotations(
|
||||
get_origin(cls), default_to_bound=False
|
||||
)
|
||||
orig_bases = (cls,)
|
||||
|
||||
# Pydantic v2 automatically resolves inherited generics, Pydantic v1 does not.
|
||||
if not (isinstance(cls, type) and is_pydantic_v2_subclass(cls)):
|
||||
# if cls = FooBar inherits from Baz[str], orig_bases will contain Baz[str]
|
||||
# if cls = FooBar inherits from Baz, orig_bases will contain Baz
|
||||
# if cls = FooBar[int], orig_bases will contain FooBar[int]
|
||||
for parent in orig_bases:
|
||||
# if class = FooBar inherits from Baz, parent = Baz
|
||||
if isinstance(parent, type) and is_pydantic_v1_subclass(parent):
|
||||
annotations.update(
|
||||
_get_all_basemodel_annotations(parent, default_to_bound=False)
|
||||
)
|
||||
continue
|
||||
|
||||
parent_origin = get_origin(parent)
|
||||
|
||||
# if class = FooBar inherits from non-pydantic class
|
||||
if not parent_origin:
|
||||
continue
|
||||
|
||||
# if class = FooBar inherits from Baz[str]:
|
||||
# parent = Baz[str],
|
||||
# parent_origin = Baz,
|
||||
# generic_type_vars = (type vars in Baz)
|
||||
# generic_map = {type var in Baz: str}
|
||||
generic_type_vars: Tuple = getattr(parent_origin, "__parameters__", tuple())
|
||||
generic_map = {
|
||||
type_var: t for type_var, t in zip(generic_type_vars, get_args(parent))
|
||||
}
|
||||
for field in getattr(parent_origin, "__annotations__", dict()):
|
||||
annotations[field] = _replace_type_vars(
|
||||
annotations[field], generic_map, default_to_bound
|
||||
)
|
||||
|
||||
return {
|
||||
k: _replace_type_vars(v, default_to_bound=default_to_bound)
|
||||
for k, v in annotations.items()
|
||||
}
|
||||
|
||||
|
||||
def _replace_type_vars(
|
||||
type_: Type,
|
||||
generic_map: Optional[Dict[TypeVar, Type]] = None,
|
||||
default_to_bound: bool = True,
|
||||
) -> Type:
|
||||
generic_map = generic_map or {}
|
||||
if isinstance(type_, TypeVar):
|
||||
if type_ in generic_map:
|
||||
return generic_map[type_]
|
||||
elif default_to_bound:
|
||||
return type_.__bound__ or Any
|
||||
else:
|
||||
return type_
|
||||
elif (origin := get_origin(type_)) and (args := get_args(type_)):
|
||||
new_args = tuple(
|
||||
_replace_type_vars(arg, generic_map, default_to_bound) for arg in args
|
||||
)
|
||||
return _py_38_safe_origin(origin)[new_args]
|
||||
else:
|
||||
return type_
|
||||
|
@ -53,6 +53,13 @@ def is_pydantic_v1_subclass(cls: Type) -> bool:
|
||||
return False
|
||||
|
||||
|
||||
def is_pydantic_v2_subclass(cls: Type) -> bool:
|
||||
"""Check if the installed Pydantic version is 1.x-like."""
|
||||
from pydantic import BaseModel
|
||||
|
||||
return PYDANTIC_MAJOR_VERSION == 2 and issubclass(cls, BaseModel)
|
||||
|
||||
|
||||
def is_basemodel_subclass(cls: Type) -> bool:
|
||||
"""Check if the given class is a subclass of Pydantic BaseModel.
|
||||
|
||||
|
@ -8,10 +8,22 @@ import threading
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from functools import partial
|
||||
from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Type, Union
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
Dict,
|
||||
Generic,
|
||||
List,
|
||||
Literal,
|
||||
Optional,
|
||||
Tuple,
|
||||
Type,
|
||||
Union,
|
||||
)
|
||||
|
||||
import pytest
|
||||
from typing_extensions import Annotated, TypedDict
|
||||
from pydantic import BaseModel as BaseModelProper # pydantic: ignore
|
||||
from typing_extensions import Annotated, TypedDict, TypeVar
|
||||
|
||||
from langchain_core.callbacks import (
|
||||
AsyncCallbackManagerForToolRun,
|
||||
@ -32,12 +44,13 @@ from langchain_core.tools import (
|
||||
StructuredTool,
|
||||
Tool,
|
||||
ToolException,
|
||||
_get_all_basemodel_annotations,
|
||||
_is_message_content_block,
|
||||
_is_message_content_type,
|
||||
tool,
|
||||
)
|
||||
from langchain_core.utils.function_calling import convert_to_openai_function
|
||||
from langchain_core.utils.pydantic import _create_subset_model
|
||||
from langchain_core.utils.pydantic import PYDANTIC_MAJOR_VERSION, _create_subset_model
|
||||
from tests.unit_tests.fake.callbacks import FakeCallbackHandler
|
||||
from tests.unit_tests.pydantic_utils import _schema
|
||||
|
||||
@ -1452,6 +1465,66 @@ def test_tool_injected_arg_with_schema(tool_: BaseTool) -> None:
|
||||
}
|
||||
|
||||
|
||||
def test_tool_inherited_injected_arg() -> None:
|
||||
class barSchema(BaseModel):
|
||||
"""bar."""
|
||||
|
||||
y: Annotated[str, "foobar comment", InjectedToolArg()] = Field(
|
||||
..., description="123"
|
||||
)
|
||||
|
||||
class fooSchema(barSchema):
|
||||
"""foo."""
|
||||
|
||||
x: int = Field(..., description="abc")
|
||||
|
||||
class InheritedInjectedArgTool(BaseTool):
|
||||
name: str = "foo"
|
||||
description: str = "foo."
|
||||
args_schema: Type[BaseModel] = fooSchema
|
||||
|
||||
def _run(self, x: int, y: str) -> Any:
|
||||
return y
|
||||
|
||||
tool_ = InheritedInjectedArgTool()
|
||||
assert tool_.get_input_schema().schema() == {
|
||||
"title": "fooSchema",
|
||||
"description": "foo.",
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"x": {"description": "abc", "title": "X", "type": "integer"},
|
||||
"y": {"description": "123", "title": "Y", "type": "string"},
|
||||
},
|
||||
"required": ["y", "x"],
|
||||
}
|
||||
assert tool_.tool_call_schema.schema() == {
|
||||
"title": "foo",
|
||||
"description": "foo.",
|
||||
"type": "object",
|
||||
"properties": {"x": {"description": "abc", "title": "X", "type": "integer"}},
|
||||
"required": ["x"],
|
||||
}
|
||||
assert tool_.invoke({"x": 5, "y": "bar"}) == "bar"
|
||||
assert tool_.invoke(
|
||||
{"name": "foo", "args": {"x": 5, "y": "bar"}, "id": "123", "type": "tool_call"}
|
||||
) == ToolMessage("bar", tool_call_id="123", name="foo")
|
||||
expected_error = (
|
||||
ValidationError if not isinstance(tool_, InjectedTool) else TypeError
|
||||
)
|
||||
with pytest.raises(expected_error):
|
||||
tool_.invoke({"x": 5})
|
||||
|
||||
assert convert_to_openai_function(tool_) == {
|
||||
"name": "foo",
|
||||
"description": "foo.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {"x": {"type": "integer", "description": "abc"}},
|
||||
"required": ["x"],
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def _get_parametrized_tools() -> list:
|
||||
def my_tool(x: int, y: str, some_tool: Annotated[Any, InjectedToolArg]) -> str:
|
||||
"""my_tool."""
|
||||
@ -1484,7 +1557,6 @@ def test_fn_injected_arg_with_schema(tool_: Callable) -> None:
|
||||
|
||||
def generate_models() -> List[Any]:
|
||||
"""Generate a list of base models depending on the pydantic version."""
|
||||
from pydantic import BaseModel as BaseModelProper # pydantic: ignore
|
||||
|
||||
class FooProper(BaseModelProper):
|
||||
a: int
|
||||
@ -1670,3 +1742,124 @@ def test__is_message_content_block(obj: Any, expected: bool) -> None:
|
||||
)
|
||||
def test__is_message_content_type(obj: Any, expected: bool) -> None:
|
||||
assert _is_message_content_type(obj) is expected
|
||||
|
||||
|
||||
@pytest.mark.skipif(PYDANTIC_MAJOR_VERSION != 2, reason="Testing pydantic v2.")
|
||||
@pytest.mark.parametrize("use_v1_namespace", [True, False])
|
||||
def test__get_all_basemodel_annotations_v2(use_v1_namespace: bool) -> None:
|
||||
A = TypeVar("A")
|
||||
|
||||
if use_v1_namespace:
|
||||
|
||||
class ModelA(BaseModel, Generic[A]):
|
||||
a: A
|
||||
else:
|
||||
|
||||
class ModelA(BaseModelProper, Generic[A]): # type: ignore[no-redef]
|
||||
a: A
|
||||
|
||||
class ModelB(ModelA[str]):
|
||||
b: Annotated[ModelA[Dict[str, Any]], "foo"]
|
||||
|
||||
class Mixin(object):
|
||||
def foo(self) -> str:
|
||||
return "foo"
|
||||
|
||||
class ModelC(Mixin, ModelB):
|
||||
c: dict
|
||||
|
||||
expected = {"a": str, "b": Annotated[ModelA[Dict[str, Any]], "foo"], "c": dict}
|
||||
actual = _get_all_basemodel_annotations(ModelC)
|
||||
assert actual == expected
|
||||
|
||||
expected = {"a": str, "b": Annotated[ModelA[Dict[str, Any]], "foo"]}
|
||||
actual = _get_all_basemodel_annotations(ModelB)
|
||||
assert actual == expected
|
||||
|
||||
expected = {"a": Any}
|
||||
actual = _get_all_basemodel_annotations(ModelA)
|
||||
assert actual == expected
|
||||
|
||||
expected = {"a": int}
|
||||
actual = _get_all_basemodel_annotations(ModelA[int])
|
||||
assert actual == expected
|
||||
|
||||
D = TypeVar("D", bound=Union[str, int])
|
||||
|
||||
class ModelD(ModelC, Generic[D]):
|
||||
d: Optional[D]
|
||||
|
||||
expected = {
|
||||
"a": str,
|
||||
"b": Annotated[ModelA[Dict[str, Any]], "foo"],
|
||||
"c": dict,
|
||||
"d": Union[str, int, None],
|
||||
}
|
||||
actual = _get_all_basemodel_annotations(ModelD)
|
||||
assert actual == expected
|
||||
|
||||
expected = {
|
||||
"a": str,
|
||||
"b": Annotated[ModelA[Dict[str, Any]], "foo"],
|
||||
"c": dict,
|
||||
"d": Union[int, None],
|
||||
}
|
||||
actual = _get_all_basemodel_annotations(ModelD[int])
|
||||
assert actual == expected
|
||||
|
||||
|
||||
@pytest.mark.skipif(PYDANTIC_MAJOR_VERSION != 1, reason="Testing pydantic v1.")
|
||||
def test__get_all_basemodel_annotations_v1() -> None:
|
||||
A = TypeVar("A")
|
||||
|
||||
class ModelA(BaseModel, Generic[A]):
|
||||
a: A
|
||||
|
||||
class ModelB(ModelA[str]):
|
||||
b: Annotated[ModelA[Dict[str, Any]], "foo"]
|
||||
|
||||
class Mixin(object):
|
||||
def foo(self) -> str:
|
||||
return "foo"
|
||||
|
||||
class ModelC(Mixin, ModelB):
|
||||
c: dict
|
||||
|
||||
expected = {"a": str, "b": Annotated[ModelA[Dict[str, Any]], "foo"], "c": dict}
|
||||
actual = _get_all_basemodel_annotations(ModelC)
|
||||
assert actual == expected
|
||||
|
||||
expected = {"a": str, "b": Annotated[ModelA[Dict[str, Any]], "foo"]}
|
||||
actual = _get_all_basemodel_annotations(ModelB)
|
||||
assert actual == expected
|
||||
|
||||
expected = {"a": Any}
|
||||
actual = _get_all_basemodel_annotations(ModelA)
|
||||
assert actual == expected
|
||||
|
||||
expected = {"a": int}
|
||||
actual = _get_all_basemodel_annotations(ModelA[int])
|
||||
assert actual == expected
|
||||
|
||||
D = TypeVar("D", bound=Union[str, int])
|
||||
|
||||
class ModelD(ModelC, Generic[D]):
|
||||
d: Optional[D]
|
||||
|
||||
expected = {
|
||||
"a": str,
|
||||
"b": Annotated[ModelA[Dict[str, Any]], "foo"],
|
||||
"c": dict,
|
||||
"d": Union[str, int, None],
|
||||
}
|
||||
actual = _get_all_basemodel_annotations(ModelD)
|
||||
assert actual == expected
|
||||
|
||||
expected = {
|
||||
"a": str,
|
||||
"b": Annotated[ModelA[Dict[str, Any]], "foo"],
|
||||
"c": dict,
|
||||
"d": Union[int, None],
|
||||
}
|
||||
actual = _get_all_basemodel_annotations(ModelD[int])
|
||||
assert actual == expected
|
||||
|
Loading…
Reference in New Issue
Block a user