mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-14 15:16:21 +00:00
core[patch]: Fix tool args schema inherited field parsing (#24936)
Fix #24925
This commit is contained in:
parent
fba65ba04f
commit
199e9c5ae0
@ -45,7 +45,7 @@ from typing import (
|
|||||||
get_type_hints,
|
get_type_hints,
|
||||||
)
|
)
|
||||||
|
|
||||||
from typing_extensions import Annotated, cast, get_args, get_origin
|
from typing_extensions import Annotated, TypeVar, cast, get_args, get_origin
|
||||||
|
|
||||||
from langchain_core._api import deprecated
|
from langchain_core._api import deprecated
|
||||||
from langchain_core.callbacks import (
|
from langchain_core.callbacks import (
|
||||||
@ -88,11 +88,16 @@ from langchain_core.runnables.config import (
|
|||||||
run_in_executor,
|
run_in_executor,
|
||||||
)
|
)
|
||||||
from langchain_core.runnables.utils import accepts_context
|
from langchain_core.runnables.utils import accepts_context
|
||||||
from langchain_core.utils.function_calling import _parse_google_docstring
|
from langchain_core.utils.function_calling import (
|
||||||
|
_parse_google_docstring,
|
||||||
|
_py_38_safe_origin,
|
||||||
|
)
|
||||||
from langchain_core.utils.pydantic import (
|
from langchain_core.utils.pydantic import (
|
||||||
TypeBaseModel,
|
TypeBaseModel,
|
||||||
_create_subset_model,
|
_create_subset_model,
|
||||||
is_basemodel_subclass,
|
is_basemodel_subclass,
|
||||||
|
is_pydantic_v1_subclass,
|
||||||
|
is_pydantic_v2_subclass,
|
||||||
)
|
)
|
||||||
|
|
||||||
FILTERED_ARGS = ("run_manager", "callbacks")
|
FILTERED_ARGS = ("run_manager", "callbacks")
|
||||||
@ -387,7 +392,7 @@ class ChildTool(BaseTool):
|
|||||||
def tool_call_schema(self) -> Type[BaseModel]:
|
def tool_call_schema(self) -> Type[BaseModel]:
|
||||||
full_schema = self.get_input_schema()
|
full_schema = self.get_input_schema()
|
||||||
fields = []
|
fields = []
|
||||||
for name, type_ in full_schema.__annotations__.items():
|
for name, type_ in _get_all_basemodel_annotations(full_schema).items():
|
||||||
if not _is_injected_arg_type(type_):
|
if not _is_injected_arg_type(type_):
|
||||||
fields.append(name)
|
fields.append(name)
|
||||||
return _create_subset_model(
|
return _create_subset_model(
|
||||||
@ -1650,3 +1655,80 @@ def _filter_schema_args(func: Callable) -> List[str]:
|
|||||||
filter_args.append(config_param)
|
filter_args.append(config_param)
|
||||||
# filter_args.extend(_get_non_model_params(type_hints))
|
# filter_args.extend(_get_non_model_params(type_hints))
|
||||||
return filter_args
|
return filter_args
|
||||||
|
|
||||||
|
|
||||||
|
def _get_all_basemodel_annotations(
|
||||||
|
cls: Union[TypeBaseModel, Any], *, default_to_bound: bool = True
|
||||||
|
) -> Dict[str, Type]:
|
||||||
|
# cls has no subscript: cls = FooBar
|
||||||
|
if isinstance(cls, type):
|
||||||
|
annotations: Dict[str, Type] = {}
|
||||||
|
for name, param in inspect.signature(cls).parameters.items():
|
||||||
|
annotations[name] = param.annotation
|
||||||
|
orig_bases: Tuple = getattr(cls, "__orig_bases__", tuple())
|
||||||
|
# cls has subscript: cls = FooBar[int]
|
||||||
|
else:
|
||||||
|
annotations = _get_all_basemodel_annotations(
|
||||||
|
get_origin(cls), default_to_bound=False
|
||||||
|
)
|
||||||
|
orig_bases = (cls,)
|
||||||
|
|
||||||
|
# Pydantic v2 automatically resolves inherited generics, Pydantic v1 does not.
|
||||||
|
if not (isinstance(cls, type) and is_pydantic_v2_subclass(cls)):
|
||||||
|
# if cls = FooBar inherits from Baz[str], orig_bases will contain Baz[str]
|
||||||
|
# if cls = FooBar inherits from Baz, orig_bases will contain Baz
|
||||||
|
# if cls = FooBar[int], orig_bases will contain FooBar[int]
|
||||||
|
for parent in orig_bases:
|
||||||
|
# if class = FooBar inherits from Baz, parent = Baz
|
||||||
|
if isinstance(parent, type) and is_pydantic_v1_subclass(parent):
|
||||||
|
annotations.update(
|
||||||
|
_get_all_basemodel_annotations(parent, default_to_bound=False)
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
|
||||||
|
parent_origin = get_origin(parent)
|
||||||
|
|
||||||
|
# if class = FooBar inherits from non-pydantic class
|
||||||
|
if not parent_origin:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# if class = FooBar inherits from Baz[str]:
|
||||||
|
# parent = Baz[str],
|
||||||
|
# parent_origin = Baz,
|
||||||
|
# generic_type_vars = (type vars in Baz)
|
||||||
|
# generic_map = {type var in Baz: str}
|
||||||
|
generic_type_vars: Tuple = getattr(parent_origin, "__parameters__", tuple())
|
||||||
|
generic_map = {
|
||||||
|
type_var: t for type_var, t in zip(generic_type_vars, get_args(parent))
|
||||||
|
}
|
||||||
|
for field in getattr(parent_origin, "__annotations__", dict()):
|
||||||
|
annotations[field] = _replace_type_vars(
|
||||||
|
annotations[field], generic_map, default_to_bound
|
||||||
|
)
|
||||||
|
|
||||||
|
return {
|
||||||
|
k: _replace_type_vars(v, default_to_bound=default_to_bound)
|
||||||
|
for k, v in annotations.items()
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _replace_type_vars(
|
||||||
|
type_: Type,
|
||||||
|
generic_map: Optional[Dict[TypeVar, Type]] = None,
|
||||||
|
default_to_bound: bool = True,
|
||||||
|
) -> Type:
|
||||||
|
generic_map = generic_map or {}
|
||||||
|
if isinstance(type_, TypeVar):
|
||||||
|
if type_ in generic_map:
|
||||||
|
return generic_map[type_]
|
||||||
|
elif default_to_bound:
|
||||||
|
return type_.__bound__ or Any
|
||||||
|
else:
|
||||||
|
return type_
|
||||||
|
elif (origin := get_origin(type_)) and (args := get_args(type_)):
|
||||||
|
new_args = tuple(
|
||||||
|
_replace_type_vars(arg, generic_map, default_to_bound) for arg in args
|
||||||
|
)
|
||||||
|
return _py_38_safe_origin(origin)[new_args]
|
||||||
|
else:
|
||||||
|
return type_
|
||||||
|
@ -53,6 +53,13 @@ def is_pydantic_v1_subclass(cls: Type) -> bool:
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def is_pydantic_v2_subclass(cls: Type) -> bool:
|
||||||
|
"""Check if the installed Pydantic version is 1.x-like."""
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
return PYDANTIC_MAJOR_VERSION == 2 and issubclass(cls, BaseModel)
|
||||||
|
|
||||||
|
|
||||||
def is_basemodel_subclass(cls: Type) -> bool:
|
def is_basemodel_subclass(cls: Type) -> bool:
|
||||||
"""Check if the given class is a subclass of Pydantic BaseModel.
|
"""Check if the given class is a subclass of Pydantic BaseModel.
|
||||||
|
|
||||||
|
@ -8,10 +8,22 @@ import threading
|
|||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Type, Union
|
from typing import (
|
||||||
|
Any,
|
||||||
|
Callable,
|
||||||
|
Dict,
|
||||||
|
Generic,
|
||||||
|
List,
|
||||||
|
Literal,
|
||||||
|
Optional,
|
||||||
|
Tuple,
|
||||||
|
Type,
|
||||||
|
Union,
|
||||||
|
)
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from typing_extensions import Annotated, TypedDict
|
from pydantic import BaseModel as BaseModelProper # pydantic: ignore
|
||||||
|
from typing_extensions import Annotated, TypedDict, TypeVar
|
||||||
|
|
||||||
from langchain_core.callbacks import (
|
from langchain_core.callbacks import (
|
||||||
AsyncCallbackManagerForToolRun,
|
AsyncCallbackManagerForToolRun,
|
||||||
@ -32,12 +44,13 @@ from langchain_core.tools import (
|
|||||||
StructuredTool,
|
StructuredTool,
|
||||||
Tool,
|
Tool,
|
||||||
ToolException,
|
ToolException,
|
||||||
|
_get_all_basemodel_annotations,
|
||||||
_is_message_content_block,
|
_is_message_content_block,
|
||||||
_is_message_content_type,
|
_is_message_content_type,
|
||||||
tool,
|
tool,
|
||||||
)
|
)
|
||||||
from langchain_core.utils.function_calling import convert_to_openai_function
|
from langchain_core.utils.function_calling import convert_to_openai_function
|
||||||
from langchain_core.utils.pydantic import _create_subset_model
|
from langchain_core.utils.pydantic import PYDANTIC_MAJOR_VERSION, _create_subset_model
|
||||||
from tests.unit_tests.fake.callbacks import FakeCallbackHandler
|
from tests.unit_tests.fake.callbacks import FakeCallbackHandler
|
||||||
from tests.unit_tests.pydantic_utils import _schema
|
from tests.unit_tests.pydantic_utils import _schema
|
||||||
|
|
||||||
@ -1452,6 +1465,66 @@ def test_tool_injected_arg_with_schema(tool_: BaseTool) -> None:
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def test_tool_inherited_injected_arg() -> None:
|
||||||
|
class barSchema(BaseModel):
|
||||||
|
"""bar."""
|
||||||
|
|
||||||
|
y: Annotated[str, "foobar comment", InjectedToolArg()] = Field(
|
||||||
|
..., description="123"
|
||||||
|
)
|
||||||
|
|
||||||
|
class fooSchema(barSchema):
|
||||||
|
"""foo."""
|
||||||
|
|
||||||
|
x: int = Field(..., description="abc")
|
||||||
|
|
||||||
|
class InheritedInjectedArgTool(BaseTool):
|
||||||
|
name: str = "foo"
|
||||||
|
description: str = "foo."
|
||||||
|
args_schema: Type[BaseModel] = fooSchema
|
||||||
|
|
||||||
|
def _run(self, x: int, y: str) -> Any:
|
||||||
|
return y
|
||||||
|
|
||||||
|
tool_ = InheritedInjectedArgTool()
|
||||||
|
assert tool_.get_input_schema().schema() == {
|
||||||
|
"title": "fooSchema",
|
||||||
|
"description": "foo.",
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"x": {"description": "abc", "title": "X", "type": "integer"},
|
||||||
|
"y": {"description": "123", "title": "Y", "type": "string"},
|
||||||
|
},
|
||||||
|
"required": ["y", "x"],
|
||||||
|
}
|
||||||
|
assert tool_.tool_call_schema.schema() == {
|
||||||
|
"title": "foo",
|
||||||
|
"description": "foo.",
|
||||||
|
"type": "object",
|
||||||
|
"properties": {"x": {"description": "abc", "title": "X", "type": "integer"}},
|
||||||
|
"required": ["x"],
|
||||||
|
}
|
||||||
|
assert tool_.invoke({"x": 5, "y": "bar"}) == "bar"
|
||||||
|
assert tool_.invoke(
|
||||||
|
{"name": "foo", "args": {"x": 5, "y": "bar"}, "id": "123", "type": "tool_call"}
|
||||||
|
) == ToolMessage("bar", tool_call_id="123", name="foo")
|
||||||
|
expected_error = (
|
||||||
|
ValidationError if not isinstance(tool_, InjectedTool) else TypeError
|
||||||
|
)
|
||||||
|
with pytest.raises(expected_error):
|
||||||
|
tool_.invoke({"x": 5})
|
||||||
|
|
||||||
|
assert convert_to_openai_function(tool_) == {
|
||||||
|
"name": "foo",
|
||||||
|
"description": "foo.",
|
||||||
|
"parameters": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {"x": {"type": "integer", "description": "abc"}},
|
||||||
|
"required": ["x"],
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
def _get_parametrized_tools() -> list:
|
def _get_parametrized_tools() -> list:
|
||||||
def my_tool(x: int, y: str, some_tool: Annotated[Any, InjectedToolArg]) -> str:
|
def my_tool(x: int, y: str, some_tool: Annotated[Any, InjectedToolArg]) -> str:
|
||||||
"""my_tool."""
|
"""my_tool."""
|
||||||
@ -1484,7 +1557,6 @@ def test_fn_injected_arg_with_schema(tool_: Callable) -> None:
|
|||||||
|
|
||||||
def generate_models() -> List[Any]:
|
def generate_models() -> List[Any]:
|
||||||
"""Generate a list of base models depending on the pydantic version."""
|
"""Generate a list of base models depending on the pydantic version."""
|
||||||
from pydantic import BaseModel as BaseModelProper # pydantic: ignore
|
|
||||||
|
|
||||||
class FooProper(BaseModelProper):
|
class FooProper(BaseModelProper):
|
||||||
a: int
|
a: int
|
||||||
@ -1670,3 +1742,124 @@ def test__is_message_content_block(obj: Any, expected: bool) -> None:
|
|||||||
)
|
)
|
||||||
def test__is_message_content_type(obj: Any, expected: bool) -> None:
|
def test__is_message_content_type(obj: Any, expected: bool) -> None:
|
||||||
assert _is_message_content_type(obj) is expected
|
assert _is_message_content_type(obj) is expected
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif(PYDANTIC_MAJOR_VERSION != 2, reason="Testing pydantic v2.")
|
||||||
|
@pytest.mark.parametrize("use_v1_namespace", [True, False])
|
||||||
|
def test__get_all_basemodel_annotations_v2(use_v1_namespace: bool) -> None:
|
||||||
|
A = TypeVar("A")
|
||||||
|
|
||||||
|
if use_v1_namespace:
|
||||||
|
|
||||||
|
class ModelA(BaseModel, Generic[A]):
|
||||||
|
a: A
|
||||||
|
else:
|
||||||
|
|
||||||
|
class ModelA(BaseModelProper, Generic[A]): # type: ignore[no-redef]
|
||||||
|
a: A
|
||||||
|
|
||||||
|
class ModelB(ModelA[str]):
|
||||||
|
b: Annotated[ModelA[Dict[str, Any]], "foo"]
|
||||||
|
|
||||||
|
class Mixin(object):
|
||||||
|
def foo(self) -> str:
|
||||||
|
return "foo"
|
||||||
|
|
||||||
|
class ModelC(Mixin, ModelB):
|
||||||
|
c: dict
|
||||||
|
|
||||||
|
expected = {"a": str, "b": Annotated[ModelA[Dict[str, Any]], "foo"], "c": dict}
|
||||||
|
actual = _get_all_basemodel_annotations(ModelC)
|
||||||
|
assert actual == expected
|
||||||
|
|
||||||
|
expected = {"a": str, "b": Annotated[ModelA[Dict[str, Any]], "foo"]}
|
||||||
|
actual = _get_all_basemodel_annotations(ModelB)
|
||||||
|
assert actual == expected
|
||||||
|
|
||||||
|
expected = {"a": Any}
|
||||||
|
actual = _get_all_basemodel_annotations(ModelA)
|
||||||
|
assert actual == expected
|
||||||
|
|
||||||
|
expected = {"a": int}
|
||||||
|
actual = _get_all_basemodel_annotations(ModelA[int])
|
||||||
|
assert actual == expected
|
||||||
|
|
||||||
|
D = TypeVar("D", bound=Union[str, int])
|
||||||
|
|
||||||
|
class ModelD(ModelC, Generic[D]):
|
||||||
|
d: Optional[D]
|
||||||
|
|
||||||
|
expected = {
|
||||||
|
"a": str,
|
||||||
|
"b": Annotated[ModelA[Dict[str, Any]], "foo"],
|
||||||
|
"c": dict,
|
||||||
|
"d": Union[str, int, None],
|
||||||
|
}
|
||||||
|
actual = _get_all_basemodel_annotations(ModelD)
|
||||||
|
assert actual == expected
|
||||||
|
|
||||||
|
expected = {
|
||||||
|
"a": str,
|
||||||
|
"b": Annotated[ModelA[Dict[str, Any]], "foo"],
|
||||||
|
"c": dict,
|
||||||
|
"d": Union[int, None],
|
||||||
|
}
|
||||||
|
actual = _get_all_basemodel_annotations(ModelD[int])
|
||||||
|
assert actual == expected
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif(PYDANTIC_MAJOR_VERSION != 1, reason="Testing pydantic v1.")
|
||||||
|
def test__get_all_basemodel_annotations_v1() -> None:
|
||||||
|
A = TypeVar("A")
|
||||||
|
|
||||||
|
class ModelA(BaseModel, Generic[A]):
|
||||||
|
a: A
|
||||||
|
|
||||||
|
class ModelB(ModelA[str]):
|
||||||
|
b: Annotated[ModelA[Dict[str, Any]], "foo"]
|
||||||
|
|
||||||
|
class Mixin(object):
|
||||||
|
def foo(self) -> str:
|
||||||
|
return "foo"
|
||||||
|
|
||||||
|
class ModelC(Mixin, ModelB):
|
||||||
|
c: dict
|
||||||
|
|
||||||
|
expected = {"a": str, "b": Annotated[ModelA[Dict[str, Any]], "foo"], "c": dict}
|
||||||
|
actual = _get_all_basemodel_annotations(ModelC)
|
||||||
|
assert actual == expected
|
||||||
|
|
||||||
|
expected = {"a": str, "b": Annotated[ModelA[Dict[str, Any]], "foo"]}
|
||||||
|
actual = _get_all_basemodel_annotations(ModelB)
|
||||||
|
assert actual == expected
|
||||||
|
|
||||||
|
expected = {"a": Any}
|
||||||
|
actual = _get_all_basemodel_annotations(ModelA)
|
||||||
|
assert actual == expected
|
||||||
|
|
||||||
|
expected = {"a": int}
|
||||||
|
actual = _get_all_basemodel_annotations(ModelA[int])
|
||||||
|
assert actual == expected
|
||||||
|
|
||||||
|
D = TypeVar("D", bound=Union[str, int])
|
||||||
|
|
||||||
|
class ModelD(ModelC, Generic[D]):
|
||||||
|
d: Optional[D]
|
||||||
|
|
||||||
|
expected = {
|
||||||
|
"a": str,
|
||||||
|
"b": Annotated[ModelA[Dict[str, Any]], "foo"],
|
||||||
|
"c": dict,
|
||||||
|
"d": Union[str, int, None],
|
||||||
|
}
|
||||||
|
actual = _get_all_basemodel_annotations(ModelD)
|
||||||
|
assert actual == expected
|
||||||
|
|
||||||
|
expected = {
|
||||||
|
"a": str,
|
||||||
|
"b": Annotated[ModelA[Dict[str, Any]], "foo"],
|
||||||
|
"c": dict,
|
||||||
|
"d": Union[int, None],
|
||||||
|
}
|
||||||
|
actual = _get_all_basemodel_annotations(ModelD[int])
|
||||||
|
assert actual == expected
|
||||||
|
Loading…
Reference in New Issue
Block a user