def __init__(self, **kwargs: Any) -> None:
    """Validate ``args_schema`` (when supplied) and initialize the tool.

    Args:
        **kwargs: Keyword arguments forwarded unchanged to the parent
            initializer.

    Raises:
        TypeError: If ``args_schema`` is present, is not ``None`` and is
            rejected by ``is_basemodel_subclass``.
    """
    schema = kwargs.get("args_schema")
    # A missing key and an explicit ``None`` both bypass the check.
    if schema is not None and not is_basemodel_subclass(schema):
        raise TypeError(
            f"args_schema must be a subclass of pydantic BaseModel. "
            f"Got: {schema}."
        )
    super().__init__(**kwargs)
def _get_basemodel_classes() -> tuple[type, ...]:
    """Return the pydantic ``BaseModel`` classes valid for the installed version.

    Returns:
        ``(pydantic.BaseModel,)`` under Pydantic 1.x, or
        ``(pydantic.BaseModel, pydantic.v1.BaseModel)`` under Pydantic 2.x.

    Raises:
        ValueError: If ``PYDANTIC_MAJOR_VERSION`` is neither 1 nor 2.
    """
    if PYDANTIC_MAJOR_VERSION == 1:
        from pydantic import BaseModel as BaseModelV1Proper

        return (BaseModelV1Proper,)
    if PYDANTIC_MAJOR_VERSION == 2:
        from pydantic import BaseModel as BaseModelV2
        from pydantic.v1 import BaseModel as BaseModelV1

        return (BaseModelV2, BaseModelV1)
    raise ValueError(f"Unsupported Pydantic version: {PYDANTIC_MAJOR_VERSION}")


def is_basemodel_subclass(cls: Type) -> bool:
    """Check if the given class is a subclass of Pydantic BaseModel.

    Check if the given class is a subclass of any of the following:

    * pydantic.BaseModel in Pydantic 1.x
    * pydantic.BaseModel in Pydantic 2.x
    * pydantic.v1.BaseModel in Pydantic 2.x

    Args:
        cls: The candidate to test. Anything that is not a class yields
            ``False``.

    Returns:
        ``True`` if ``cls`` is a class deriving from one of the bases above.

    Raises:
        ValueError: If ``cls`` is a class and the pydantic major version is
            neither 1 nor 2.
    """
    # issubclass() raises TypeError on non-classes, so filter those out first.
    if not inspect.isclass(cls):
        return False
    return issubclass(cls, _get_basemodel_classes())


def is_basemodel_instance(obj: Any) -> bool:
    """Check if the given object is an instance of Pydantic BaseModel.

    Check if the given object is an instance of any of the following:

    * pydantic.BaseModel in Pydantic 1.x
    * pydantic.BaseModel in Pydantic 2.x
    * pydantic.v1.BaseModel in Pydantic 2.x

    Args:
        obj: The object to test.

    Returns:
        ``True`` if ``obj`` is an instance of one of the bases above.

    Raises:
        ValueError: If the pydantic major version is neither 1 nor 2.
    """
    return isinstance(obj, _get_basemodel_classes())
def _create_subset_model_v1(
    name: str,
    model: Type[BaseModel],
    field_names: list,
    *,
    descriptions: Optional[dict] = None,
    fn_description: Optional[str] = None,
) -> Type[BaseModel]:
    """Create a pydantic v1 model with only a subset of ``model``'s fields.

    Args:
        name: Name given to the created model.
        model: Source model whose ``__fields__`` are looked up.
        field_names: Names of the fields to carry over.
        descriptions: Optional mapping of field name to a description that
            replaces the field's own description in the created model.
        fn_description: Text used as the created model's ``__doc__``; falls
            back to ``model.__doc__`` and then to an empty string.

    Returns:
        The newly created model class.

    Raises:
        KeyError: If an entry of ``field_names`` is not a field of ``model``.
    """
    import copy

    from langchain_core.pydantic_v1 import create_model

    descriptions_ = descriptions or {}
    fields = {}

    for field_name in field_names:
        field = model.__fields__[field_name]
        t = (
            # this isn't perfect but should work for most functions
            field.outer_type_
            if field.required and not field.allow_none
            else Optional[field.outer_type_]
        )
        field_info = field.field_info
        if field_name in descriptions_:
            # Override the description on a shallow copy so the FieldInfo
            # owned by the source model is left untouched.
            field_info = copy.copy(field_info)
            field_info.description = descriptions_[field_name]
        fields[field_name] = (t, field_info)

    rtn = create_model(name, **fields)  # type: ignore
    rtn.__doc__ = textwrap.dedent(fn_description or model.__doc__ or "")
    return rtn


def _create_subset_model_v2(
    name: str,
    model: Type[BaseModel],
    field_names: List[str],
    *,
    descriptions: Optional[dict] = None,
    fn_description: Optional[str] = None,
) -> Type[BaseModel]:
    """Create a pydantic v2 model with a subset of the model fields.

    Args:
        name: Name given to the created model.
        model: Source model whose ``model_fields`` are looked up.
        field_names: Names of the fields to carry over.
        descriptions: Optional mapping of field name to a description that
            replaces the field's own description in the created model.
        fn_description: Text used as the created model's ``__doc__``; falls
            back to ``model.__doc__`` and then to an empty string.

    Returns:
        The newly created model class.

    Raises:
        KeyError: If an entry of ``field_names`` is not a field of ``model``.
    """
    from pydantic import create_model  # pydantic: ignore
    from pydantic.fields import FieldInfo  # pydantic: ignore

    descriptions_ = descriptions or {}
    fields = {}
    for field_name in field_names:
        field = model.model_fields[field_name]  # type: ignore
        description = descriptions_.get(field_name, field.description)
        if field.default_factory is not None:
            # Carry the factory over; passing only ``field.default`` here
            # would drop it and turn the field into a required one.
            field_info = FieldInfo(
                description=description, default_factory=field.default_factory
            )
        else:
            field_info = FieldInfo(description=description, default=field.default)
        fields[field_name] = (field.annotation, field_info)
    rtn = create_model(name, **fields)  # type: ignore

    rtn.__doc__ = textwrap.dedent(fn_description or model.__doc__ or "")
    return rtn
def _create_subset_model(
    name: str,
    model: Type[BaseModel],
    field_names: List[str],
    *,
    descriptions: Optional[dict] = None,
    fn_description: Optional[str] = None,
) -> Type[BaseModel]:
    """Create subset model using the same pydantic version as the input model.

    Private helper that works across pydantic 1.x and 2.x, including models
    taken from the ``pydantic.v1`` namespace of pydantic 2.x.

    Args:
        name: Name given to the created model.
        model: Source model to take fields from.
        field_names: Names of the fields to carry over.
        descriptions: Optional per-field description overrides.
        fn_description: Optional text for the created model's ``__doc__``.

    Returns:
        The model class produced by the version-specific builder.

    Raises:
        NotImplementedError: If the pydantic major version is neither 1 nor 2.
    """
    if PYDANTIC_MAJOR_VERSION not in (1, 2):
        raise NotImplementedError(
            f"Unsupported pydantic version: {PYDANTIC_MAJOR_VERSION}"
        )

    # The v1 builder serves pydantic 1.x and v1-namespace models under 2.x.
    build = _create_subset_model_v1
    if PYDANTIC_MAJOR_VERSION == 2:
        from pydantic.v1 import BaseModel as BaseModelV1  # pydantic: ignore

        if not issubclass(model, BaseModelV1):
            build = _create_subset_model_v2

    return build(
        name,
        model,
        field_names,
        descriptions=descriptions,
        fn_description=fn_description,
    )


def generate_models() -> List[Any]:
    """Generate a list of base models depending on the pydantic version."""
    from pydantic import BaseModel as BaseModelProper  # pydantic: ignore

    class FooProper(BaseModelProper):
        a: int
        b: str

    models: List[Any] = [FooProper]
    return models
def generate_backwards_compatible_v1() -> List[Any]:
    """Generate a model with pydantic 2 from the v1 namespace."""
    from pydantic.v1 import BaseModel as BaseModelV1  # pydantic: ignore

    class FooV1Namespace(BaseModelV1):
        a: int
        b: str

    models: List[Any] = [FooV1Namespace]
    return models


# This generates a list of models that can be used for testing that our APIs
# behave well with either pydantic 1 proper,
# pydantic v1 from pydantic 2,
# or pydantic 2 proper.
TEST_MODELS = generate_models() + generate_backwards_compatible_v1()


@pytest.mark.parametrize("pydantic_model", TEST_MODELS)
def test_args_schema_as_pydantic(pydantic_model: Any) -> None:
    """Check both schemas of a tool built with ``pydantic_model`` as args schema."""

    class SomeTool(BaseTool):
        args_schema: Type[pydantic_model] = pydantic_model

        def _run(self, *args: Any, **kwargs: Any) -> str:
            return "foo"

    some_tool = SomeTool(
        name="some_tool", description="some description", args_schema=pydantic_model
    )

    expected_properties = {
        "a": {"title": "A", "type": "integer"},
        "b": {"title": "B", "type": "string"},
    }
    expected_input_schema = {
        "properties": expected_properties,
        "required": ["a", "b"],
        "title": pydantic_model.__name__,
        "type": "object",
    }
    expected_call_schema = {
        "description": "some description",
        "properties": expected_properties,
        "required": ["a", "b"],
        "title": "some_tool",
        "type": "object",
    }

    assert some_tool.get_input_schema().schema() == expected_input_schema
    assert some_tool.tool_call_schema.schema() == expected_call_schema


def test_args_schema_explicitly_typed() -> None:
    """This should test that one can type the args schema as a pydantic model.

    Please note that this will test using pydantic 2 even though BaseTool
    is a pydantic 1 model!
    """
    # Check with whatever pydantic model is passed in and not via v1 namespace
    from pydantic import BaseModel  # pydantic: ignore

    class Foo(BaseModel):
        a: int
        b: str

    class SomeTool(BaseTool):
        # type ignoring here since we're allowing overriding a type
        # signature of pydantic.v1.BaseModel with pydantic.BaseModel
        # for pydantic 2!
        args_schema: Type[BaseModel] = Foo  # type: ignore[assignment]

        def _run(self, *args: Any, **kwargs: Any) -> str:
            return "foo"

    some_tool = SomeTool(name="some_tool", description="some description")

    expected_properties = {
        "a": {"title": "A", "type": "integer"},
        "b": {"title": "B", "type": "string"},
    }
    expected_input_schema = {
        "properties": expected_properties,
        "required": ["a", "b"],
        "title": "Foo",
        "type": "object",
    }
    expected_call_schema = {
        "description": "some description",
        "properties": expected_properties,
        "required": ["a", "b"],
        "title": "some_tool",
        "type": "object",
    }

    assert some_tool.get_input_schema().schema() == expected_input_schema
    assert some_tool.tool_call_schema.schema() == expected_call_schema
def test_is_basemodel_subclass() -> None:
    """Test pydantic."""
    if PYDANTIC_MAJOR_VERSION not in (1, 2):
        raise ValueError(f"Unsupported Pydantic version: {PYDANTIC_MAJOR_VERSION}")

    if PYDANTIC_MAJOR_VERSION == 1:
        from pydantic import BaseModel as BaseModelV1Proper  # pydantic: ignore

        base_models = [BaseModelV1Proper]
    else:
        from pydantic import BaseModel as BaseModelV2  # pydantic: ignore
        from pydantic.v1 import BaseModel as BaseModelV1  # pydantic: ignore

        base_models = [BaseModelV2, BaseModelV1]

    # Each base class available for this pydantic version must be recognised.
    for base_model in base_models:
        assert is_basemodel_subclass(base_model)
def test_is_basemodel_instance() -> None:
    """Test pydantic."""
    if PYDANTIC_MAJOR_VERSION not in (1, 2):
        raise ValueError(f"Unsupported Pydantic version: {PYDANTIC_MAJOR_VERSION}")

    if PYDANTIC_MAJOR_VERSION == 1:
        from pydantic import BaseModel as BaseModelV1Proper  # pydantic: ignore

        class FooV1(BaseModelV1Proper):
            x: int

        instances = [FooV1(x=5)]
    else:
        from pydantic import BaseModel as BaseModelV2  # pydantic: ignore
        from pydantic.v1 import BaseModel as BaseModelV1  # pydantic: ignore

        class Foo(BaseModelV2):
            x: int

        class Bar(BaseModelV1):
            x: int

        instances = [Foo(x=5), Bar(x=5)]

    # Every instance built for this pydantic version must be recognised.
    for instance in instances:
        assert is_basemodel_instance(instance)