From 5f98975be0d2005cfc98977cfe35f387c4ab6fbd Mon Sep 17 00:00:00 2001 From: Eugene Yurtsev Date: Tue, 3 Sep 2024 21:53:50 -0400 Subject: [PATCH] core[patch]: Fix injected args in tool signature (#25991) - Fix injected args in tool signature - Fix another unit test that was using the wrong namespace import in pydantic --- libs/core/langchain_core/utils/pydantic.py | 11 +++++++++ libs/core/tests/unit_tests/test_tools.py | 28 ++++++++++++++++++++-- 2 files changed, 37 insertions(+), 2 deletions(-) diff --git a/libs/core/langchain_core/utils/pydantic.py b/libs/core/langchain_core/utils/pydantic.py index fe28ee5ce15..dd6b9aba597 100644 --- a/libs/core/langchain_core/utils/pydantic.py +++ b/libs/core/langchain_core/utils/pydantic.py @@ -273,6 +273,17 @@ def _create_subset_model_v2( fields[field_name] = (field.annotation, field_info) rtn = create_model(name, **fields) # type: ignore + # TODO(0.3): Determine if there is a more "pydantic" way to preserve annotations. + # This is done to preserve __annotations__ when working with pydantic 2.x + # and using the Annotated type with TypedDict. + # Comment out the following line, to trigger the relevant test case. + selected_annotations = [ + (name, annotation) + for name, annotation in model.__annotations__.items() + if name in field_names + ] + + rtn.__annotations__ = dict(selected_annotations) rtn.__doc__ = textwrap.dedent(fn_description or model.__doc__ or "") return rtn diff --git a/libs/core/tests/unit_tests/test_tools.py b/libs/core/tests/unit_tests/test_tools.py index bcbcc668737..f1afba566aa 100644 --- a/libs/core/tests/unit_tests/test_tools.py +++ b/libs/core/tests/unit_tests/test_tools.py @@ -1756,13 +1756,17 @@ def test__get_all_basemodel_annotations_v2(use_v1_namespace: bool) -> None: A = TypeVar("A") if use_v1_namespace: + from pydantic.v1 import BaseModel as BM1 - class ModelA(BaseModel, Generic[A], extra="allow"): + class ModelA(BM1, Generic[A], extra="allow"): a: A else: + from pydantic import BaseModel as BM2 + from pydantic import ConfigDict - class ModelA(BaseModelProper, Generic[A], extra="allow"): # type: ignore[no-redef] + class ModelA(BM2, Generic[A], extra="allow"): # type: ignore[no-redef] a: A + model_config = ConfigDict(arbitrary_types_allowed=True) class ModelB(ModelA[str]): b: Annotated[ModelA[Dict[str, Any]], "foo"] @@ -1871,6 +1875,26 @@ def test__get_all_basemodel_annotations_v1() -> None: assert actual == expected +def test_tool_annotations_preserved() -> None: + """Test that annotations are preserved when creating a tool.""" + + @tool + def my_tool(val: int, other_val: Annotated[dict, "my annotation"]) -> str: + """Tool docstring.""" + return "foo" + + schema = my_tool.get_input_schema() # type: ignore[attr-defined] + + func = my_tool.func # type: ignore[attr-defined] + + expected_type_hints = { + name: hint + for name, hint in func.__annotations__.items() + if name in inspect.signature(func).parameters + } + assert schema.__annotations__ == expected_type_hints + + @pytest.mark.skipif(PYDANTIC_MAJOR_VERSION != 2, reason="Testing pydantic v2.") def test_tool_args_schema_pydantic_v2_with_metadata() -> None: from pydantic import BaseModel as BaseModelV2