diff --git a/libs/core/langchain_core/tools/base.py b/libs/core/langchain_core/tools/base.py index 1929a391501..a300c27b8e1 100644 --- a/libs/core/langchain_core/tools/base.py +++ b/libs/core/langchain_core/tools/base.py @@ -38,6 +38,7 @@ from pydantic import ( validate_arguments, ) from pydantic.v1 import BaseModel as BaseModelV1 +from pydantic.v1 import validate_arguments as validate_arguments_v1 from typing_extensions import Annotated from langchain_core._api import deprecated @@ -164,6 +165,35 @@ def _infer_arg_descriptions( return description, arg_descriptions +def _is_pydantic_annotation(annotation: Any, pydantic_version: str = "v2") -> bool: + """Determine if a type annotation is a Pydantic model.""" + base_model_class = BaseModelV1 if pydantic_version == "v1" else BaseModel + try: + return issubclass(annotation, base_model_class) + except TypeError: + return False + + +def _function_annotations_are_pydantic_v1( + signature: inspect.Signature, func: Callable +) -> bool: + """Determine if all Pydantic annotations in a function signature are from V1.""" + any_v1_annotations = any( + _is_pydantic_annotation(parameter.annotation, pydantic_version="v1") + for parameter in signature.parameters.values() + ) + any_v2_annotations = any( + _is_pydantic_annotation(parameter.annotation, pydantic_version="v2") + for parameter in signature.parameters.values() + ) + if any_v1_annotations and any_v2_annotations: + raise NotImplementedError( + f"Function {func} contains a mix of Pydantic v1 and v2 annotations. " + "Only one version of Pydantic annotations per function is supported." + ) + return any_v1_annotations and not any_v2_annotations + + class _SchemaConfig: """Configuration for the pydantic model. @@ -208,16 +238,19 @@ def create_schema_from_function( Returns: A pydantic model with the same arguments as the function. """ - # https://docs.pydantic.dev/latest/usage/validation_decorator/ - with warnings.catch_warnings(): - # We are using deprecated functionality here. - # This code should be re-written to simply construct a pydantic model - # using inspect.signature and create_model. - warnings.simplefilter("ignore", category=PydanticDeprecationWarning) - validated = validate_arguments(func, config=_SchemaConfig) # type: ignore - sig = inspect.signature(func) + if _function_annotations_are_pydantic_v1(sig, func): + validated = validate_arguments_v1(func, config=_SchemaConfig) # type: ignore + else: + # https://docs.pydantic.dev/latest/usage/validation_decorator/ + with warnings.catch_warnings(): + # We are using deprecated functionality here. + # This code should be re-written to simply construct a pydantic model + # using inspect.signature and create_model. + warnings.simplefilter("ignore", category=PydanticDeprecationWarning) + validated = validate_arguments(func, config=_SchemaConfig) # type: ignore + # Let's ignore `self` and `cls` arguments for class and instance methods if func.__qualname__ and "." in func.__qualname__: # Then it likely belongs in a class namespace diff --git a/libs/core/tests/unit_tests/test_tools.py b/libs/core/tests/unit_tests/test_tools.py index e2dabc1c125..0c3593c939b 100644 --- a/libs/core/tests/unit_tests/test_tools.py +++ b/libs/core/tests/unit_tests/test_tools.py @@ -23,7 +23,7 @@ from typing import ( import pytest from pydantic import BaseModel, Field, ValidationError -from pydantic import BaseModel as BaseModelProper +from pydantic.v1 import BaseModel as BaseModelV1 from typing_extensions import Annotated, TypedDict, TypeVar from langchain_core import tools @@ -80,6 +80,14 @@ class _MockSchema(BaseModel): arg3: Optional[dict] = None +class _MockSchemaV1(BaseModelV1): + """Return the arguments directly.""" + + arg1: int + arg2: bool + arg3: Optional[dict] = None + + class _MockStructuredTool(BaseTool): name: str = "structured_api" args_schema: Type[BaseModel] = _MockSchema @@ -169,6 +177,13 @@ def test_decorator_with_specified_schema() -> None: assert isinstance(tool_func, BaseTool) assert tool_func.args_schema == _MockSchema + @tool(args_schema=_MockSchemaV1) + def tool_func_v1(arg1: int, arg2: bool, arg3: Optional[dict] = None) -> str: + return f"{arg1} {arg2} {arg3}" + + assert isinstance(tool_func_v1, BaseTool) + assert tool_func_v1.args_schema == _MockSchemaV1 + def test_decorated_function_schema_equivalent() -> None: """Test that a BaseTool without a schema meets expectations.""" @@ -302,6 +317,51 @@ def test_structured_tool_types_parsed() -> None: assert result == expected +def test_structured_tool_types_parsed_pydantic_v1() -> None: + """Test the non-primitive types are correctly passed to structured tools.""" + + class SomeBaseModel(BaseModelV1): + foo: str + + class AnotherBaseModel(BaseModelV1): + bar: str + + @tool + def structured_tool(some_base_model: SomeBaseModel) -> AnotherBaseModel: + """Return the arguments directly.""" + return AnotherBaseModel(bar=some_base_model.foo) + + assert isinstance(structured_tool, StructuredTool) + + expected = AnotherBaseModel(bar="baz") + for arg in [ + SomeBaseModel(foo="baz"), + SomeBaseModel(foo="baz").dict(), + ]: + args = {"some_base_model": arg} + result = structured_tool.run(args) + assert result == expected + + +def test_structured_tool_types_parsed_pydantic_mixed() -> None: + """Test handling of tool with mixed Pydantic version arguments.""" + + class SomeBaseModel(BaseModelV1): + foo: str + + class AnotherBaseModel(BaseModel): + bar: str + + with pytest.raises(NotImplementedError): + + @tool + def structured_tool( + some_base_model: SomeBaseModel, another_base_model: AnotherBaseModel + ) -> None: + """Return the arguments directly.""" + pass + + def test_base_tool_inheritance_base_schema() -> None: """Test schema is correctly inferred when inheriting from BaseTool.""" @@ -1562,7 +1622,7 @@ def test_fn_injected_arg_with_schema(tool_: Callable) -> None: def generate_models() -> List[Any]: """Generate a list of base models depending on the pydantic version.""" - class FooProper(BaseModelProper): + class FooProper(BaseModel): a: int b: str