diff --git a/libs/core/langchain_core/tools/base.py b/libs/core/langchain_core/tools/base.py index e54a09709d6..6d0baf148dd 100644 --- a/libs/core/langchain_core/tools/base.py +++ b/libs/core/langchain_core/tools/base.py @@ -687,9 +687,7 @@ class ChildTool(BaseTool): f"args_schema must be a Pydantic BaseModel, got {self.args_schema}" ) raise NotImplementedError(msg) - return { - k: getattr(result, k) for k, v in result_dict.items() if k in tool_input - } + return {k: getattr(result, k) for k, v in result_dict.items()} return tool_input @model_validator(mode="before") diff --git a/libs/core/tests/unit_tests/test_tools.py b/libs/core/tests/unit_tests/test_tools.py index 57b4573d70d..5278b641f75 100644 --- a/libs/core/tests/unit_tests/test_tools.py +++ b/libs/core/tests/unit_tests/test_tools.py @@ -21,7 +21,7 @@ from typing import ( ) import pytest -from pydantic import BaseModel, Field, ValidationError +from pydantic import BaseModel, Field, ValidationError, model_validator from pydantic.v1 import BaseModel as BaseModelV1 from pydantic.v1 import ValidationError as ValidationErrorV1 from typing_extensions import TypedDict @@ -634,7 +634,7 @@ def test_named_tool_decorator_return_direct() -> None: """Test functionality when arguments and return direct are provided as input.""" @tool("search", return_direct=True) - def search_api(query: str, *args: Any) -> str: + def search_api(query: str) -> str: """Search the API for the query.""" return "API result" @@ -2547,6 +2547,38 @@ def test_tool_decorator_description() -> None: ) +def test_tool_args_schema_with_pydantic_validator() -> None: + class NestedArgsSchema(BaseModel): + y: int + + class ArgsSchema(BaseModel): + x: NestedArgsSchema + + @model_validator(mode="before") + def wrap_in_x(cls, data: Any) -> Any: # noqa: N805 + if not isinstance(data, dict): + return {"x": data} + + if "x" not in data: + return {"x": data} + return data + + @tool(args_schema=ArgsSchema) + def foo(**args: Any) -> ArgsSchema: + """Bar.""" + return ArgsSchema.model_validate(args) + + # Test case where validator is identity function + valid_inputs = {"x": {"y": 5}} + assert foo.invoke(valid_inputs) == ArgsSchema.model_validate(valid_inputs) + + # Test case where validator wraps input in "x" + invalid_inputs = {"y": 5} + assert foo.invoke(invalid_inputs) == ArgsSchema.model_validate( + {"x": invalid_inputs} + ) + + def test_title_property_preserved() -> None: """Test that the title property is preserved when generating schema.