core[patch]: make get_all_basemodel_annotations public (#27762)

This commit is contained in:
Bagatur 2024-10-30 14:43:23 -07:00 committed by GitHub
parent 807314661d
commit 5e3cee6c98
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 17 additions and 17 deletions

View File

@ -362,7 +362,7 @@ class ChildTool(BaseTool):
def tool_call_schema(self) -> Type[BaseModel]: def tool_call_schema(self) -> Type[BaseModel]:
full_schema = self.get_input_schema() full_schema = self.get_input_schema()
fields = [] fields = []
for name, type_ in _get_all_basemodel_annotations(full_schema).items(): for name, type_ in get_all_basemodel_annotations(full_schema).items():
if not _is_injected_arg_type(type_): if not _is_injected_arg_type(type_):
fields.append(name) fields.append(name)
return _create_subset_model( return _create_subset_model(
@ -858,7 +858,7 @@ def _is_injected_arg_type(type_: Type) -> bool:
) )
def _get_all_basemodel_annotations( def get_all_basemodel_annotations(
cls: Union[TypeBaseModel, Any], *, default_to_bound: bool = True cls: Union[TypeBaseModel, Any], *, default_to_bound: bool = True
) -> Dict[str, Type]: ) -> Dict[str, Type]:
# cls has no subscript: cls = FooBar # cls has no subscript: cls = FooBar
@ -876,7 +876,7 @@ def _get_all_basemodel_annotations(
orig_bases: Tuple = getattr(cls, "__orig_bases__", tuple()) orig_bases: Tuple = getattr(cls, "__orig_bases__", tuple())
# cls has subscript: cls = FooBar[int] # cls has subscript: cls = FooBar[int]
else: else:
annotations = _get_all_basemodel_annotations( annotations = get_all_basemodel_annotations(
get_origin(cls), default_to_bound=False get_origin(cls), default_to_bound=False
) )
orig_bases = (cls,) orig_bases = (cls,)
@ -890,7 +890,7 @@ def _get_all_basemodel_annotations(
# if class = FooBar inherits from Baz, parent = Baz # if class = FooBar inherits from Baz, parent = Baz
if isinstance(parent, type) and is_pydantic_v1_subclass(parent): if isinstance(parent, type) and is_pydantic_v1_subclass(parent):
annotations.update( annotations.update(
_get_all_basemodel_annotations(parent, default_to_bound=False) get_all_basemodel_annotations(parent, default_to_bound=False)
) )
continue continue

View File

@ -48,9 +48,9 @@ from langchain_core.tools import (
from langchain_core.tools.base import ( from langchain_core.tools.base import (
InjectedToolArg, InjectedToolArg,
SchemaAnnotationError, SchemaAnnotationError,
_get_all_basemodel_annotations,
_is_message_content_block, _is_message_content_block,
_is_message_content_type, _is_message_content_type,
get_all_basemodel_annotations,
) )
from langchain_core.utils.function_calling import convert_to_openai_function from langchain_core.utils.function_calling import convert_to_openai_function
from langchain_core.utils.pydantic import PYDANTIC_MAJOR_VERSION, _create_subset_model from langchain_core.utils.pydantic import PYDANTIC_MAJOR_VERSION, _create_subset_model
@ -1773,19 +1773,19 @@ def test__get_all_basemodel_annotations_v2(use_v1_namespace: bool) -> None:
c: dict c: dict
expected = {"a": str, "b": Annotated[ModelA[Dict[str, Any]], "foo"], "c": dict} expected = {"a": str, "b": Annotated[ModelA[Dict[str, Any]], "foo"], "c": dict}
actual = _get_all_basemodel_annotations(ModelC) actual = get_all_basemodel_annotations(ModelC)
assert actual == expected assert actual == expected
expected = {"a": str, "b": Annotated[ModelA[Dict[str, Any]], "foo"]} expected = {"a": str, "b": Annotated[ModelA[Dict[str, Any]], "foo"]}
actual = _get_all_basemodel_annotations(ModelB) actual = get_all_basemodel_annotations(ModelB)
assert actual == expected assert actual == expected
expected = {"a": Any} expected = {"a": Any}
actual = _get_all_basemodel_annotations(ModelA) actual = get_all_basemodel_annotations(ModelA)
assert actual == expected assert actual == expected
expected = {"a": int} expected = {"a": int}
actual = _get_all_basemodel_annotations(ModelA[int]) actual = get_all_basemodel_annotations(ModelA[int])
assert actual == expected assert actual == expected
D = TypeVar("D", bound=Union[str, int]) D = TypeVar("D", bound=Union[str, int])
@ -1799,7 +1799,7 @@ def test__get_all_basemodel_annotations_v2(use_v1_namespace: bool) -> None:
"c": dict, "c": dict,
"d": Union[str, int, None], "d": Union[str, int, None],
} }
actual = _get_all_basemodel_annotations(ModelD) actual = get_all_basemodel_annotations(ModelD)
assert actual == expected assert actual == expected
expected = { expected = {
@ -1808,7 +1808,7 @@ def test__get_all_basemodel_annotations_v2(use_v1_namespace: bool) -> None:
"c": dict, "c": dict,
"d": Union[int, None], "d": Union[int, None],
} }
actual = _get_all_basemodel_annotations(ModelD[int]) actual = get_all_basemodel_annotations(ModelD[int])
assert actual == expected assert actual == expected
@ -1830,19 +1830,19 @@ def test__get_all_basemodel_annotations_v1() -> None:
c: dict c: dict
expected = {"a": str, "b": Annotated[ModelA[Dict[str, Any]], "foo"], "c": dict} expected = {"a": str, "b": Annotated[ModelA[Dict[str, Any]], "foo"], "c": dict}
actual = _get_all_basemodel_annotations(ModelC) actual = get_all_basemodel_annotations(ModelC)
assert actual == expected assert actual == expected
expected = {"a": str, "b": Annotated[ModelA[Dict[str, Any]], "foo"]} expected = {"a": str, "b": Annotated[ModelA[Dict[str, Any]], "foo"]}
actual = _get_all_basemodel_annotations(ModelB) actual = get_all_basemodel_annotations(ModelB)
assert actual == expected assert actual == expected
expected = {"a": Any} expected = {"a": Any}
actual = _get_all_basemodel_annotations(ModelA) actual = get_all_basemodel_annotations(ModelA)
assert actual == expected assert actual == expected
expected = {"a": int} expected = {"a": int}
actual = _get_all_basemodel_annotations(ModelA[int]) actual = get_all_basemodel_annotations(ModelA[int])
assert actual == expected assert actual == expected
D = TypeVar("D", bound=Union[str, int]) D = TypeVar("D", bound=Union[str, int])
@ -1856,7 +1856,7 @@ def test__get_all_basemodel_annotations_v1() -> None:
"c": dict, "c": dict,
"d": Union[str, int, None], "d": Union[str, int, None],
} }
actual = _get_all_basemodel_annotations(ModelD) actual = get_all_basemodel_annotations(ModelD)
assert actual == expected assert actual == expected
expected = { expected = {
@ -1865,7 +1865,7 @@ def test__get_all_basemodel_annotations_v1() -> None:
"c": dict, "c": dict,
"d": Union[int, None], "d": Union[int, None],
} }
actual = _get_all_basemodel_annotations(ModelD[int]) actual = get_all_basemodel_annotations(ModelD[int])
assert actual == expected assert actual == expected