From ffd3d08a9a88b55b785559bcaea9a1e786318491 Mon Sep 17 00:00:00 2001 From: "vincent.min" Date: Wed, 26 Feb 2025 14:02:58 +0100 Subject: [PATCH 1/4] fix: make pydantic validator work with BaseTool --- libs/core/langchain_core/tools/base.py | 1 - libs/core/tests/unit_tests/test_tools.py | 2 +- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/libs/core/langchain_core/tools/base.py b/libs/core/langchain_core/tools/base.py index 5a618f6a160..9417ee875f4 100644 --- a/libs/core/langchain_core/tools/base.py +++ b/libs/core/langchain_core/tools/base.py @@ -588,7 +588,6 @@ class ChildTool(BaseTool): return { k: getattr(result, k) for k, v in result_dict.items() - if k in tool_input } return tool_input diff --git a/libs/core/tests/unit_tests/test_tools.py b/libs/core/tests/unit_tests/test_tools.py index d83c386fd1a..5c19aae74db 100644 --- a/libs/core/tests/unit_tests/test_tools.py +++ b/libs/core/tests/unit_tests/test_tools.py @@ -628,7 +628,7 @@ def test_named_tool_decorator_return_direct() -> None: """Test functionality when arguments and return direct are provided as input.""" @tool("search", return_direct=True) - def search_api(query: str, *args: Any) -> str: + def search_api(query: str) -> str: """Search the API for the query.""" return "API result" From dfeba105ff18eb57c9a0106f0ee06c109be10de9 Mon Sep 17 00:00:00 2001 From: "vincent.min" Date: Wed, 26 Feb 2025 14:26:58 +0100 Subject: [PATCH 2/4] add test with pydantic validator --- libs/core/tests/unit_tests/test_tools.py | 29 +++++++++++++++++++++++- 1 file changed, 28 insertions(+), 1 deletion(-) diff --git a/libs/core/tests/unit_tests/test_tools.py b/libs/core/tests/unit_tests/test_tools.py index 5c19aae74db..996b6e017e1 100644 --- a/libs/core/tests/unit_tests/test_tools.py +++ b/libs/core/tests/unit_tests/test_tools.py @@ -21,7 +21,7 @@ from typing import ( ) import pytest -from pydantic import BaseModel, Field, ValidationError +from pydantic import BaseModel, Field, ValidationError, model_validator from pydantic.v1 import BaseModel as BaseModelV1 from pydantic.v1 import ValidationError as ValidationErrorV1 from typing_extensions import TypedDict @@ -2560,3 +2560,30 @@ def test_tool_decorator_description() -> None: ] == "description" ) + +def test_tool_args_schema_with_pydantic_validator() -> None: + class NestedArgsSchema(BaseModel): + y: int + + class ArgsSchema(BaseModel): + x: NestedArgsSchema + + @model_validator(mode="before") + def wrap_in_x(cls, data: dict) -> dict: + if "x" not in data: + return {"x": data} + return data + + @tool(args_schema=ArgsSchema) + def foo(**args) -> ArgsSchema: + """Bar.""" + return ArgsSchema.model_validate(args) + + + # Test case where validator is identity function + valid_inputs = {"x": {"y": 5}} + assert foo.invoke(valid_inputs) == ArgsSchema.model_validate(valid_inputs) + + # Test case where validator wraps input in "x" + invalid_inputs = {"y": 5} + assert foo.invoke(invalid_inputs) == ArgsSchema.model_validate({"x": invalid_inputs}) From 096cb660d43a3651f9b4ce71e2b6f6ce8e09d881 Mon Sep 17 00:00:00 2001 From: Eugene Yurtsev Date: Fri, 28 Mar 2025 15:27:51 -0400 Subject: [PATCH 3/4] Update libs/core/tests/unit_tests/test_tools.py --- libs/core/tests/unit_tests/test_tools.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/libs/core/tests/unit_tests/test_tools.py b/libs/core/tests/unit_tests/test_tools.py index 996b6e017e1..e8e1257d52b 100644 --- a/libs/core/tests/unit_tests/test_tools.py +++ b/libs/core/tests/unit_tests/test_tools.py @@ -628,7 +628,7 @@ def test_named_tool_decorator_return_direct() -> None: """Test functionality when arguments and return direct are provided as input.""" @tool("search", return_direct=True) - def search_api(query: str) -> str: + def search_api(query: str, *args: Any) -> str: """Search the API for the query.""" return "API result" From 300299d0d097319ec4481fb90a8f7f15599667cb Mon Sep 17 00:00:00 2001 From: Mason Daugherty Date: Wed, 16 Jul 2025 10:54:27 -0400 Subject: [PATCH 4/4] format + typing --- libs/core/langchain_core/tools/base.py | 4 +--- libs/core/tests/unit_tests/test_tools.py | 14 +++++++++----- 2 files changed, 10 insertions(+), 8 deletions(-) diff --git a/libs/core/langchain_core/tools/base.py b/libs/core/langchain_core/tools/base.py index 34bf4c4582b..4df4a98b203 100644 --- a/libs/core/langchain_core/tools/base.py +++ b/libs/core/langchain_core/tools/base.py @@ -689,9 +689,7 @@ class ChildTool(BaseTool): f"args_schema must be a Pydantic BaseModel, got {self.args_schema}" ) raise NotImplementedError(msg) - return { - k: getattr(result, k) for k, v in result_dict.items() - } + return {k: getattr(result, k) for k, v in result_dict.items()} return tool_input @model_validator(mode="before") diff --git a/libs/core/tests/unit_tests/test_tools.py b/libs/core/tests/unit_tests/test_tools.py index 3e0ac61cbcd..934cf5c6364 100644 --- a/libs/core/tests/unit_tests/test_tools.py +++ b/libs/core/tests/unit_tests/test_tools.py @@ -634,7 +634,7 @@ def test_named_tool_decorator_return_direct() -> None: """Test functionality when arguments and return direct are provided as input.""" @tool("search", return_direct=True) - def search_api(query: str, *args: Any) -> str: + def search_api(query: str) -> str: """Search the API for the query.""" return "API result" @@ -2549,24 +2549,28 @@ def test_tool_args_schema_with_pydantic_validator() -> None: x: NestedArgsSchema @model_validator(mode="before") - def wrap_in_x(cls, data: dict) -> dict: + def wrap_in_x(cls, data: Any) -> Any: # noqa: N805 + if not isinstance(data, dict): + return {"x": data} + if "x" not in data: return {"x": data} return data @tool(args_schema=ArgsSchema) - def foo(**args) -> ArgsSchema: + def foo(**args: Any) -> ArgsSchema: """Bar.""" return ArgsSchema.model_validate(args) - # Test case where validator is identity function valid_inputs = {"x": {"y": 5}} assert foo.invoke(valid_inputs) == ArgsSchema.model_validate(valid_inputs) # Test case where validator wraps input in "x" invalid_inputs = {"y": 5} - assert foo.invoke(invalid_inputs) == ArgsSchema.model_validate({"x": invalid_inputs}) + assert foo.invoke(invalid_inputs) == ArgsSchema.model_validate( + {"x": invalid_inputs} + ) def test_title_property_preserved() -> None: