diff --git a/libs/core/langchain_core/output_parsers/openai_tools.py b/libs/core/langchain_core/output_parsers/openai_tools.py index 4122cbe63a6..11bb5518cbe 100644 --- a/libs/core/langchain_core/output_parsers/openai_tools.py +++ b/libs/core/langchain_core/output_parsers/openai_tools.py @@ -1,7 +1,7 @@ import copy import json from json import JSONDecodeError -from typing import Any, Dict, List, Optional, Type +from typing import Any, Dict, List, Optional from langchain_core.exceptions import OutputParserException from langchain_core.messages import AIMessage, InvalidToolCall @@ -13,8 +13,9 @@ from langchain_core.messages.tool import ( ) from langchain_core.output_parsers.transform import BaseCumulativeTransformOutputParser from langchain_core.outputs import ChatGeneration, Generation -from langchain_core.pydantic_v1 import BaseModel, ValidationError +from langchain_core.pydantic_v1 import ValidationError from langchain_core.utils.json import parse_partial_json +from langchain_core.utils.pydantic import TypeBaseModel def parse_tool_call( @@ -255,7 +256,7 @@ class JsonOutputKeyToolsParser(JsonOutputToolsParser): class PydanticToolsParser(JsonOutputToolsParser): """Parse tools from OpenAI response.""" - tools: List[Type[BaseModel]] + tools: List[TypeBaseModel] """The tools to parse.""" # TODO: Support more granular streaming of objects. Currently only streams once all diff --git a/libs/core/langchain_core/output_parsers/pydantic.py b/libs/core/langchain_core/output_parsers/pydantic.py index 1c2debcb6b1..b48dca9d28a 100644 --- a/libs/core/langchain_core/output_parsers/pydantic.py +++ b/libs/core/langchain_core/output_parsers/pydantic.py @@ -1,23 +1,16 @@ import json -from typing import Generic, List, Type, TypeVar, Union +from typing import Generic, List, Type import pydantic # pydantic: ignore from langchain_core.exceptions import OutputParserException from langchain_core.output_parsers import JsonOutputParser from langchain_core.outputs import Generation -from langchain_core.utils.pydantic import PYDANTIC_MAJOR_VERSION - -if PYDANTIC_MAJOR_VERSION < 2: - PydanticBaseModel = pydantic.BaseModel - -else: - from pydantic.v1 import BaseModel # pydantic: ignore - - # Union type needs to be last assignment to PydanticBaseModel to make mypy happy. - PydanticBaseModel = Union[BaseModel, pydantic.BaseModel] # type: ignore - -TBaseModel = TypeVar("TBaseModel", bound=PydanticBaseModel) +from langchain_core.utils.pydantic import ( + PYDANTIC_MAJOR_VERSION, + PydanticBaseModel, + TBaseModel, +) class PydanticOutputParser(JsonOutputParser, Generic[TBaseModel]): @@ -122,3 +115,10 @@ Here is the output schema: ``` {schema} ```""" # noqa: E501 + +# Re-exporting types for backwards compatibility +__all__ = [ + "PydanticBaseModel", + "PydanticOutputParser", + "TBaseModel", +] diff --git a/libs/core/langchain_core/tools.py b/libs/core/langchain_core/tools.py index 2c7a434663c..85db1d3b92e 100644 --- a/libs/core/langchain_core/tools.py +++ b/libs/core/langchain_core/tools.py @@ -89,6 +89,7 @@ from langchain_core.runnables.config import ( ) from langchain_core.runnables.utils import accepts_context from langchain_core.utils.pydantic import ( + TypeBaseModel, _create_subset_model, is_basemodel_subclass, ) @@ -332,8 +333,15 @@ class ChildTool(BaseTool): You can provide few-shot examples as a part of the description. """ - args_schema: Optional[Type[BaseModel]] = None - """Pydantic model class to validate and parse the tool's input arguments.""" + args_schema: Optional[TypeBaseModel] = None + """Pydantic model class to validate and parse the tool's input arguments. + + Args schema should be either: + + - A subclass of pydantic.BaseModel. + or + - A subclass of pydantic.v1.BaseModel if accessing v1 namespace in pydantic 2 + """ return_direct: bool = False """Whether to return the tool's output directly. @@ -891,7 +899,7 @@ class StructuredTool(BaseTool): """Tool that can operate on any number of inputs.""" description: str = "" - args_schema: Type[BaseModel] = Field(..., description="The tool schema.") + args_schema: TypeBaseModel = Field(..., description="The tool schema.") """The input arguments' schema.""" func: Optional[Callable[..., Any]] """The function to run when the tool is called.""" diff --git a/libs/core/langchain_core/utils/pydantic.py b/libs/core/langchain_core/utils/pydantic.py index 87586302215..8c5dbdbbb11 100644 --- a/libs/core/langchain_core/utils/pydantic.py +++ b/libs/core/langchain_core/utils/pydantic.py @@ -5,9 +5,14 @@ from __future__ import annotations import inspect import textwrap from functools import wraps -from typing import Any, Callable, Dict, List, Optional, Type +from typing import Any, Callable, Dict, List, Optional, Type, TypeVar, Union -from langchain_core.pydantic_v1 import BaseModel, root_validator +import pydantic # pydantic: ignore + +from langchain_core.pydantic_v1 import ( + BaseModel, + root_validator, +) def get_pydantic_major_version() -> int: @@ -23,6 +28,22 @@ def get_pydantic_major_version() -> int: PYDANTIC_MAJOR_VERSION = get_pydantic_major_version() +if PYDANTIC_MAJOR_VERSION == 1: + PydanticBaseModel = pydantic.BaseModel + TypeBaseModel = Type[BaseModel] +elif PYDANTIC_MAJOR_VERSION == 2: + from pydantic.v1 import BaseModel # pydantic: ignore + + # Union type needs to be last assignment to PydanticBaseModel to make mypy happy. + PydanticBaseModel = Union[BaseModel, pydantic.BaseModel] # type: ignore + TypeBaseModel = Union[Type[BaseModel], Type[pydantic.BaseModel]] # type: ignore +else: + raise ValueError(f"Unsupported Pydantic version: {PYDANTIC_MAJOR_VERSION}") + + +TBaseModel = TypeVar("TBaseModel", bound=PydanticBaseModel) + + def is_basemodel_subclass(cls: Type) -> bool: """Check if the given class is a subclass of Pydantic BaseModel. @@ -37,13 +58,13 @@ def is_basemodel_subclass(cls: Type) -> bool: return False if PYDANTIC_MAJOR_VERSION == 1: - from pydantic import BaseModel as BaseModelV1Proper + from pydantic import BaseModel as BaseModelV1Proper # pydantic: ignore if issubclass(cls, BaseModelV1Proper): return True elif PYDANTIC_MAJOR_VERSION == 2: - from pydantic import BaseModel as BaseModelV2 - from pydantic.v1 import BaseModel as BaseModelV1 + from pydantic import BaseModel as BaseModelV2 # pydantic: ignore + from pydantic.v1 import BaseModel as BaseModelV1 # pydantic: ignore if issubclass(cls, BaseModelV2): return True @@ -65,13 +86,13 @@ def is_basemodel_instance(obj: Any) -> bool: * pydantic.v1.BaseModel in Pydantic 2.x """ if PYDANTIC_MAJOR_VERSION == 1: - from pydantic import BaseModel as BaseModelV1Proper + from pydantic import BaseModel as BaseModelV1Proper # pydantic: ignore if isinstance(obj, BaseModelV1Proper): return True elif PYDANTIC_MAJOR_VERSION == 2: - from pydantic import BaseModel as BaseModelV2 - from pydantic.v1 import BaseModel as BaseModelV1 + from pydantic import BaseModel as BaseModelV2 # pydantic: ignore + from pydantic.v1 import BaseModel as BaseModelV1 # pydantic: ignore if isinstance(obj, BaseModelV2): return True diff --git a/libs/core/tests/unit_tests/output_parsers/test_openai_tools.py b/libs/core/tests/unit_tests/output_parsers/test_openai_tools.py index cd7f9f52dd4..0371a97962d 100644 --- a/libs/core/tests/unit_tests/output_parsers/test_openai_tools.py +++ b/libs/core/tests/unit_tests/output_parsers/test_openai_tools.py @@ -1,12 +1,21 @@ from typing import Any, AsyncIterator, Iterator, List -from langchain_core.messages import AIMessageChunk, BaseMessage, ToolCallChunk +import pytest + +from langchain_core.messages import ( + AIMessage, + AIMessageChunk, + BaseMessage, + ToolCallChunk, +) from langchain_core.output_parsers.openai_tools import ( JsonOutputKeyToolsParser, JsonOutputToolsParser, PydanticToolsParser, ) +from langchain_core.outputs import ChatGeneration from langchain_core.pydantic_v1 import BaseModel, Field +from langchain_core.utils.pydantic import PYDANTIC_MAJOR_VERSION STREAMED_MESSAGES: list = [ AIMessageChunk(content=""), @@ -518,3 +527,108 @@ async def test_partial_pydantic_output_parser_async() -> None: actual = [p async for p in chain.astream(None)] assert actual == EXPECTED_STREAMED_PYDANTIC + + +@pytest.mark.skipif(PYDANTIC_MAJOR_VERSION != 2, reason="This test is for pydantic 2") +def test_parse_with_different_pydantic_2_v1() -> None: + """Test with pydantic.v1.BaseModel from pydantic 2.""" + import pydantic # pydantic: ignore + + class Forecast(pydantic.v1.BaseModel): + temperature: int + forecast: str + + # Can't get pydantic to work here due to the odd typing of tryig to support + # both v1 and v2 in the same codebase. + parser = PydanticToolsParser(tools=[Forecast]) # type: ignore[list-item] + message = AIMessage( + content="", + tool_calls=[ + { + "id": "call_OwL7f5PE", + "name": "Forecast", + "args": {"temperature": 20, "forecast": "Sunny"}, + } + ], + ) + + generation = ChatGeneration( + message=message, + ) + + assert parser.parse_result([generation]) == [ + Forecast( + temperature=20, + forecast="Sunny", + ) + ] + + +@pytest.mark.skipif(PYDANTIC_MAJOR_VERSION != 2, reason="This test is for pydantic 2") +def test_parse_with_different_pydantic_2_proper() -> None: + """Test with pydantic.BaseModel from pydantic 2.""" + import pydantic # pydantic: ignore + + class Forecast(pydantic.BaseModel): + temperature: int + forecast: str + + # Can't get pydantic to work here due to the odd typing of tryig to support + # both v1 and v2 in the same codebase. + parser = PydanticToolsParser(tools=[Forecast]) # type: ignore[list-item] + message = AIMessage( + content="", + tool_calls=[ + { + "id": "call_OwL7f5PE", + "name": "Forecast", + "args": {"temperature": 20, "forecast": "Sunny"}, + } + ], + ) + + generation = ChatGeneration( + message=message, + ) + + assert parser.parse_result([generation]) == [ + Forecast( + temperature=20, + forecast="Sunny", + ) + ] + + +@pytest.mark.skipif(PYDANTIC_MAJOR_VERSION != 1, reason="This test is for pydantic 1") +def test_parse_with_different_pydantic_1_proper() -> None: + """Test with pydantic.BaseModel from pydantic 1.""" + import pydantic # pydantic: ignore + + class Forecast(pydantic.BaseModel): + temperature: int + forecast: str + + # Can't get pydantic to work here due to the odd typing of tryig to support + # both v1 and v2 in the same codebase. + parser = PydanticToolsParser(tools=[Forecast]) # type: ignore[list-item] + message = AIMessage( + content="", + tool_calls=[ + { + "id": "call_OwL7f5PE", + "name": "Forecast", + "args": {"temperature": 20, "forecast": "Sunny"}, + } + ], + ) + + generation = ChatGeneration( + message=message, + ) + + assert parser.parse_result([generation]) == [ + Forecast( + temperature=20, + forecast="Sunny", + ) + ] diff --git a/libs/core/tests/unit_tests/output_parsers/test_pydantic_parser.py b/libs/core/tests/unit_tests/output_parsers/test_pydantic_parser.py index 9d5601f0124..eb45a65e6ed 100644 --- a/libs/core/tests/unit_tests/output_parsers/test_pydantic_parser.py +++ b/libs/core/tests/unit_tests/output_parsers/test_pydantic_parser.py @@ -6,9 +6,9 @@ import pytest from langchain_core.exceptions import OutputParserException from langchain_core.language_models import ParrotFakeChatModel from langchain_core.output_parsers.json import JsonOutputParser -from langchain_core.output_parsers.pydantic import PydanticOutputParser, TBaseModel +from langchain_core.output_parsers.pydantic import PydanticOutputParser from langchain_core.prompts.prompt import PromptTemplate -from langchain_core.utils.pydantic import PYDANTIC_MAJOR_VERSION +from langchain_core.utils.pydantic import PYDANTIC_MAJOR_VERSION, TBaseModel V1BaseModel = pydantic.BaseModel if PYDANTIC_MAJOR_VERSION == 2: diff --git a/libs/core/tests/unit_tests/test_tools.py b/libs/core/tests/unit_tests/test_tools.py index 4c5c073a28e..7f0fccbd050 100644 --- a/libs/core/tests/unit_tests/test_tools.py +++ b/libs/core/tests/unit_tests/test_tools.py @@ -1526,3 +1526,40 @@ def test_args_schema_explicitly_typed() -> None: "title": "some_tool", "type": "object", } + + +@pytest.mark.parametrize("pydantic_model", TEST_MODELS) +def test_structured_tool_with_different_pydantic_versions(pydantic_model: Any) -> None: + """This should test that one can type the args schema as a pydantic model.""" + from langchain_core.tools import StructuredTool + + def foo(a: int, b: str) -> str: + """Hahaha""" + return "foo" + + foo_tool = StructuredTool.from_function( + func=foo, + args_schema=pydantic_model, + ) + + assert foo_tool.invoke({"a": 5, "b": "hello"}) == "foo" + + assert foo_tool.args_schema.schema() == { + "properties": { + "a": {"title": "A", "type": "integer"}, + "b": {"title": "B", "type": "string"}, + }, + "required": ["a", "b"], + "title": pydantic_model.__name__, + "type": "object", + } + + assert foo_tool.get_input_schema().schema() == { + "properties": { + "a": {"title": "A", "type": "integer"}, + "b": {"title": "B", "type": "string"}, + }, + "required": ["a", "b"], + "title": pydantic_model.__name__, + "type": "object", + }