From b2ba4f4072c80163a858276205cc5d1a35a79296 Mon Sep 17 00:00:00 2001
From: Bagatur <22008038+baskaryan@users.noreply.github.com>
Date: Fri, 6 Sep 2024 17:14:17 -0400
Subject: [PATCH] core[patch]: fix deprecated pydantic code (#26161)

---
 .../core/langchain_core/indexing/in_memory.py |  2 +-
 .../output_parsers/openai_functions.py        | 10 ++-
 libs/core/langchain_core/prompts/base.py      |  2 +-
 libs/core/langchain_core/runnables/utils.py   | 11 +++-
 libs/core/langchain_core/tools/base.py        |  5 +-
 libs/core/langchain_core/tracers/core.py      |  2 +-
 libs/core/langchain_core/tracers/langchain.py | 24 +++++--
 libs/core/langchain_core/tracers/schemas.py   | 12 +++-
 libs/core/langchain_core/utils/pydantic.py    | 66 ++++++++++---------
 libs/core/pyproject.toml                      |  3 +
 libs/core/tests/unit_tests/test_tools.py      | 34 ++++++++--
 .../tests/unit_tests/utils/test_pydantic.py   |  2 +-
 12 files changed, 118 insertions(+), 55 deletions(-)

diff --git a/libs/core/langchain_core/indexing/in_memory.py b/libs/core/langchain_core/indexing/in_memory.py
index 25ac4c0463e..acc4d3f9584 100644
--- a/libs/core/langchain_core/indexing/in_memory.py
+++ b/libs/core/langchain_core/indexing/in_memory.py
@@ -32,7 +32,7 @@ class InMemoryDocumentIndex(DocumentIndex):
         for item in items:
             if item.id is None:
                 id_ = str(uuid.uuid4())
-                item_ = item.copy()
+                item_ = item.model_copy()
                 item_.id = id_
             else:
                 item_ = item
diff --git a/libs/core/langchain_core/output_parsers/openai_functions.py b/libs/core/langchain_core/output_parsers/openai_functions.py
index ab706acfd80..4324eac47d9 100644
--- a/libs/core/langchain_core/output_parsers/openai_functions.py
+++ b/libs/core/langchain_core/output_parsers/openai_functions.py
@@ -275,10 +275,14 @@ class PydanticOutputFunctionsParser(OutputFunctionsParser):
         else:
             fn_name = _result["name"]
             _args = _result["arguments"]
-        if hasattr(self.pydantic_schema, "model_validate_json"):
-            pydantic_args = self.pydantic_schema[fn_name].model_validate_json(_args)  # type: ignore
+        if isinstance(self.pydantic_schema, dict):
+            pydantic_schema = self.pydantic_schema[fn_name]
         else:
-            pydantic_args = self.pydantic_schema[fn_name].parse_raw(_args)  # type: ignore
+            pydantic_schema = self.pydantic_schema
+        if hasattr(pydantic_schema, "model_validate_json"):
+            pydantic_args = pydantic_schema.model_validate_json(_args)  # type: ignore
+        else:
+            pydantic_args = pydantic_schema.parse_raw(_args)  # type: ignore
         return pydantic_args


diff --git a/libs/core/langchain_core/prompts/base.py b/libs/core/langchain_core/prompts/base.py
index d565efb1ffe..b8473fcd09b 100644
--- a/libs/core/langchain_core/prompts/base.py
+++ b/libs/core/langchain_core/prompts/base.py
@@ -310,7 +310,7 @@ class BasePromptTemplate(
         Raises:
             NotImplementedError: If the prompt type is not implemented.
         """
-        prompt_dict = super().dict(**kwargs)
+        prompt_dict = super().model_dump(**kwargs)
         try:
             prompt_dict["_type"] = self._prompt_type
         except NotImplementedError:
diff --git a/libs/core/langchain_core/runnables/utils.py b/libs/core/langchain_core/runnables/utils.py
index 018ef1d674e..e46a45103da 100644
--- a/libs/core/langchain_core/runnables/utils.py
+++ b/libs/core/langchain_core/runnables/utils.py
@@ -6,6 +6,7 @@ import ast
 import asyncio
 import inspect
 import textwrap
+import warnings
 from functools import lru_cache
 from inspect import signature
 from itertools import groupby
@@ -31,13 +32,14 @@ from typing import (
     cast,
 )

-from pydantic import BaseModel, ConfigDict, RootModel
+from pydantic import BaseModel, ConfigDict, PydanticDeprecationWarning, RootModel
 from pydantic import create_model as _create_model_base  # pydantic :ignore
 from pydantic.json_schema import (
     DEFAULT_REF_TEMPLATE,
     GenerateJsonSchema,
     JsonSchemaMode,
 )
+from pydantic.v1 import BaseModel as BaseModelV1
 from typing_extensions import TypeGuard

 from langchain_core.runnables.schema import StreamEvent
@@ -754,7 +756,12 @@ def create_base_class(
     if default_ is not NO_DEFAULT:
         base_class_attributes["root"] = default_

-    custom_root_type = type(name, (RootModel,), base_class_attributes)
+    with warnings.catch_warnings():
+        if isinstance(type_, type) and issubclass(type_, BaseModelV1):
+            warnings.filterwarnings(
+                action="ignore", category=PydanticDeprecationWarning
+            )
+        custom_root_type = type(name, (RootModel,), base_class_attributes)
     return cast(Type[BaseModel], custom_root_type)


diff --git a/libs/core/langchain_core/tools/base.py b/libs/core/langchain_core/tools/base.py
index 0bde16338e2..1929a391501 100644
--- a/libs/core/langchain_core/tools/base.py
+++ b/libs/core/langchain_core/tools/base.py
@@ -476,7 +476,10 @@ class ChildTool(BaseTool):
         if isinstance(tool_input, str):
             if input_args is not None:
                 key_ = next(iter(get_fields(input_args).keys()))
-                input_args.validate({key_: tool_input})
+                if hasattr(input_args, "model_validate"):
+                    input_args.model_validate({key_: tool_input})
+                else:
+                    input_args.parse_obj({key_: tool_input})
             return tool_input
         else:
             if input_args is not None:
diff --git a/libs/core/langchain_core/tracers/core.py b/libs/core/langchain_core/tracers/core.py
index b2a809f4d89..15e23023bb0 100644
--- a/libs/core/langchain_core/tracers/core.py
+++ b/libs/core/langchain_core/tracers/core.py
@@ -283,7 +283,7 @@ class _TracerCore(ABC):

     def _complete_llm_run(self, response: LLMResult, run_id: UUID) -> Run:
         llm_run = self._get_run(run_id, run_type={"llm", "chat_model"})
-        llm_run.outputs = response.dict()
+        llm_run.outputs = response.model_dump()
         for i, generations in enumerate(response.generations):
             for j, generation in enumerate(generations):
                 output_generation = llm_run.outputs["generations"][i][j]
diff --git a/libs/core/langchain_core/tracers/langchain.py b/libs/core/langchain_core/tracers/langchain.py
index a8ef48fd013..78b731a8252 100644
--- a/libs/core/langchain_core/tracers/langchain.py
+++ b/libs/core/langchain_core/tracers/langchain.py
@@ -3,6 +3,7 @@ from __future__ import annotations

 import logging
+import warnings
 from concurrent.futures import ThreadPoolExecutor
 from datetime import datetime, timezone
 from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union

@@ -10,6 +11,7 @@ from uuid import UUID

 from langsmith import Client
 from langsmith import utils as ls_utils
+from pydantic import PydanticDeprecationWarning
 from tenacity import (
     Retrying,
     retry_if_exception_type,
@@ -69,11 +71,16 @@ def _get_executor() -> ThreadPoolExecutor:


 def _run_to_dict(run: Run) -> dict:
-    return {
-        **run.dict(exclude={"child_runs", "inputs", "outputs"}),
-        "inputs": run.inputs.copy() if run.inputs is not None else None,
-        "outputs": run.outputs.copy() if run.outputs is not None else None,
-    }
+    # TODO: Update once langsmith moves to Pydantic V2 and we can swap run.dict for
+    # run.model_dump
+    with warnings.catch_warnings():
+        warnings.simplefilter("ignore", category=PydanticDeprecationWarning)
+
+        return {
+            **run.dict(exclude={"child_runs", "inputs", "outputs"}),
+            "inputs": run.inputs.copy() if run.inputs is not None else None,
+            "outputs": run.outputs.copy() if run.outputs is not None else None,
+        }


 class LangChainTracer(BaseTracer):
@@ -152,7 +159,12 @@ class LangChainTracer(BaseTracer):
         return chat_model_run

     def _persist_run(self, run: Run) -> None:
-        run_ = run.copy()
+        # TODO: Update once langsmith moves to Pydantic V2 and we can swap run.copy for
+        # run.model_copy
+        with warnings.catch_warnings():
+            warnings.simplefilter("ignore", category=PydanticDeprecationWarning)
+
+            run_ = run.copy()
         run_.reference_example_id = self.example_id
         self.latest_run = run_

diff --git a/libs/core/langchain_core/tracers/schemas.py b/libs/core/langchain_core/tracers/schemas.py
index 2725a3a8b53..4a54588357c 100644
--- a/libs/core/langchain_core/tracers/schemas.py
+++ b/libs/core/langchain_core/tracers/schemas.py
@@ -9,6 +9,7 @@ from uuid import UUID

 from langsmith.schemas import RunBase as BaseRunV2
 from langsmith.schemas import RunTypeEnum as RunTypeEnumDep
+from pydantic import PydanticDeprecationWarning

 from langchain_core._api import deprecated
 from langchain_core.pydantic_v1 import BaseModel, Field, root_validator
@@ -142,9 +143,14 @@ class Run(BaseRunV2):
         return values


-ChainRun.update_forward_refs()
-ToolRun.update_forward_refs()
-Run.update_forward_refs()
+# TODO: Update once langsmith moves to Pydantic V2 and we can swap Run.model_rebuild
+# for Run.update_forward_refs
+with warnings.catch_warnings():
+    warnings.simplefilter("ignore", category=PydanticDeprecationWarning)
+
+    ChainRun.update_forward_refs()
+    ToolRun.update_forward_refs()
+    Run.update_forward_refs()

 __all__ = [
     "BaseRun",
diff --git a/libs/core/langchain_core/utils/pydantic.py b/libs/core/langchain_core/utils/pydantic.py
index 25a1f06c757..e739532b42b 100644
--- a/libs/core/langchain_core/utils/pydantic.py
+++ b/libs/core/langchain_core/utils/pydantic.py
@@ -4,11 +4,12 @@ from __future__ import annotations

 import inspect
 import textwrap
+import warnings
 from functools import wraps
 from typing import Any, Callable, Dict, List, Optional, Type, TypeVar, Union, overload

 import pydantic
-from pydantic import BaseModel, root_validator
+from pydantic import BaseModel, PydanticDeprecationWarning, root_validator
 from pydantic.json_schema import GenerateJsonSchema, JsonSchemaValue
 from pydantic_core import core_schema

@@ -134,42 +135,45 @@ def pre_init(func: Callable) -> Any:
         Any: The decorated function.
     """

-    @root_validator(pre=True)
-    @wraps(func)
-    def wrapper(cls: Type[BaseModel], values: Dict[str, Any]) -> Dict[str, Any]:
-        """Decorator to run a function before model initialization.
+    with warnings.catch_warnings():
+        warnings.filterwarnings(action="ignore", category=PydanticDeprecationWarning)

-        Args:
-            cls (Type[BaseModel]): The model class.
-            values (Dict[str, Any]): The values to initialize the model with.
+        @root_validator(pre=True)
+        @wraps(func)
+        def wrapper(cls: Type[BaseModel], values: Dict[str, Any]) -> Dict[str, Any]:
+            """Decorator to run a function before model initialization.

-        Returns:
-            Dict[str, Any]: The values to initialize the model with.
-        """
-        # Insert default values
-        fields = cls.model_fields
-        for name, field_info in fields.items():
-            # Check if allow_population_by_field_name is enabled
-            # If yes, then set the field name to the alias
-            if hasattr(cls, "Config"):
-                if hasattr(cls.Config, "allow_population_by_field_name"):
-                    if cls.Config.allow_population_by_field_name:
+            Args:
+                cls (Type[BaseModel]): The model class.
+                values (Dict[str, Any]): The values to initialize the model with.
+
+            Returns:
+                Dict[str, Any]: The values to initialize the model with.
+            """
+            # Insert default values
+            fields = cls.model_fields
+            for name, field_info in fields.items():
+                # Check if allow_population_by_field_name is enabled
+                # If yes, then set the field name to the alias
+                if hasattr(cls, "Config"):
+                    if hasattr(cls.Config, "allow_population_by_field_name"):
+                        if cls.Config.allow_population_by_field_name:
+                            if field_info.alias in values:
+                                values[name] = values.pop(field_info.alias)
+                if hasattr(cls, "model_config"):
+                    if cls.model_config.get("populate_by_name"):
                         if field_info.alias in values:
                             values[name] = values.pop(field_info.alias)
-            if hasattr(cls, "model_config"):
-                if cls.model_config.get("populate_by_name"):
-                    if field_info.alias in values:
-                        values[name] = values.pop(field_info.alias)
-            if name not in values or values[name] is None:
-                if not field_info.is_required():
-                    if field_info.default_factory is not None:
-                        values[name] = field_info.default_factory()
-                    else:
-                        values[name] = field_info.default
+                if name not in values or values[name] is None:
+                    if not field_info.is_required():
+                        if field_info.default_factory is not None:
+                            values[name] = field_info.default_factory()
+                        else:
+                            values[name] = field_info.default

-        # Call the decorated function
-        return func(cls, values)
+            # Call the decorated function
+            return func(cls, values)

     return wrapper


diff --git a/libs/core/pyproject.toml b/libs/core/pyproject.toml
index 2b52da4b514..ef366375a17 100644
--- a/libs/core/pyproject.toml
+++ b/libs/core/pyproject.toml
@@ -56,6 +56,9 @@ markers = [
     "compile: mark placeholder test used to compile integration tests without running them",
 ]
 asyncio_mode = "auto"
+filterwarnings = [
+    "ignore::langchain_core._api.beta_decorator.LangChainBetaWarning",
+]

 [tool.poetry.group.lint]
 optional = true
diff --git a/libs/core/tests/unit_tests/test_tools.py b/libs/core/tests/unit_tests/test_tools.py
index 2eff4dd7916..c1d07762cf9 100644
--- a/libs/core/tests/unit_tests/test_tools.py
+++ b/libs/core/tests/unit_tests/test_tools.py
@@ -293,7 +293,7 @@ def test_structured_tool_types_parsed() -> None:
     assert isinstance(structured_tool, StructuredTool)
     args = {
         "some_enum": SomeEnum.A.value,
-        "some_base_model": SomeBaseModel(foo="bar").dict(),
+        "some_base_model": SomeBaseModel(foo="bar").model_dump(),
     }
     result = structured_tool.run(json.loads(json.dumps(args)))
     expected = {
@@ -1600,7 +1600,13 @@ def test_args_schema_as_pydantic(pydantic_model: Any) -> None:
         name="some_tool", description="some description", args_schema=pydantic_model
     )

-    assert tool.get_input_schema().schema() == {
+    input_schema = tool.get_input_schema()
+    input_json_schema = (
+        input_schema.model_json_schema()
+        if hasattr(input_schema, "model_json_schema")
+        else input_schema.schema()
+    )
+    assert input_json_schema == {
         "properties": {
             "a": {"title": "A", "type": "integer"},
             "b": {"title": "B", "type": "string"},
@@ -1610,7 +1616,13 @@ def test_args_schema_as_pydantic(pydantic_model: Any) -> None:
         "type": "object",
     }

-    assert tool.tool_call_schema.schema() == {
+    tool_schema = tool.tool_call_schema
+    tool_json_schema = (
+        tool_schema.model_json_schema()
+        if hasattr(tool_schema, "model_json_schema")
+        else tool_schema.schema()
+    )
+    assert tool_json_schema == {
         "description": "some description",
         "properties": {
             "a": {"title": "A", "type": "integer"},
@@ -1684,7 +1696,13 @@ def test_structured_tool_with_different_pydantic_versions(pydantic_model: Any) -

     assert foo_tool.invoke({"a": 5, "b": "hello"}) == "foo"

-    assert foo_tool.args_schema.schema() == {
+    args_schema = foo_tool.args_schema
+    args_json_schema = (
+        args_schema.model_json_schema()
+        if hasattr(args_schema, "model_json_schema")
+        else args_schema.schema()
+    )
+    assert args_json_schema == {
         "properties": {
             "a": {"title": "A", "type": "integer"},
             "b": {"title": "B", "type": "string"},
@@ -1694,7 +1712,13 @@ def test_structured_tool_with_different_pydantic_versions(pydantic_model: Any) -
         "type": "object",
     }

-    assert foo_tool.get_input_schema().schema() == {
+    input_schema = foo_tool.get_input_schema()
+    input_json_schema = (
+        input_schema.model_json_schema()
+        if hasattr(input_schema, "model_json_schema")
+        else input_schema.schema()
+    )
+    assert input_json_schema == {
         "properties": {
             "a": {"title": "A", "type": "integer"},
             "b": {"title": "B", "type": "string"},
diff --git a/libs/core/tests/unit_tests/utils/test_pydantic.py b/libs/core/tests/unit_tests/utils/test_pydantic.py
index c579c105b22..c49c5050a7b 100644
--- a/libs/core/tests/unit_tests/utils/test_pydantic.py
+++ b/libs/core/tests/unit_tests/utils/test_pydantic.py
@@ -171,7 +171,7 @@ def test_fields_pydantic_v1() -> None:
         x: int

     fields = get_fields(Foo)
-    assert fields == {"x": Foo.__fields__["x"]}  # type: ignore[index]
+    assert fields == {"x": Foo.model_fields["x"]}  # type: ignore[index]


 @pytest.mark.skipif(PYDANTIC_MAJOR_VERSION != 2, reason="Only tests Pydantic v2")