core: Cleanup Pydantic models and handle deprecation warnings (#30799)

* Simplified Pydantic handling since Pydantic v1 is not supported
anymore.
* Replace use of deprecated v1 methods by corresponding v2 methods.
* Remove use of other deprecated methods.
* Activate mypy errors on deprecated methods use.

---------

Co-authored-by: Eugene Yurtsev <eyurtsev@gmail.com>
This commit is contained in:
Christophe Bornet 2025-06-20 16:42:52 +02:00 committed by GitHub
parent 29e17fbd6b
commit 7e046ea848
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
17 changed files with 134 additions and 413 deletions

View File

@ -23,6 +23,8 @@ from typing import (
cast,
)
from pydantic.fields import FieldInfo
from pydantic.v1.fields import FieldInfo as FieldInfoV1
from typing_extensions import ParamSpec
from langchain_core._api.internal import is_caller_internal
@ -152,10 +154,6 @@ def deprecated(
_package: str = package,
) -> T:
"""Implementation of the decorator returned by `deprecated`."""
from langchain_core.utils.pydantic import ( # type: ignore[attr-defined]
FieldInfoV1,
FieldInfoV2,
)
def emit_warning() -> None:
"""Emit the warning."""
@ -249,7 +247,7 @@ def deprecated(
),
)
elif isinstance(obj, FieldInfoV2):
elif isinstance(obj, FieldInfo):
wrapped = None
if not _obj_type:
_obj_type = "attribute"
@ -261,7 +259,7 @@ def deprecated(
def finalize(wrapper: Callable[..., Any], new_doc: str) -> T: # noqa: ARG001
return cast(
"T",
FieldInfoV2(
FieldInfo(
default=obj.default,
default_factory=obj.default_factory,
description=new_doc,

View File

@ -326,7 +326,7 @@ class BaseOutputParser(
def dict(self, **kwargs: Any) -> dict:
"""Return dictionary representation of output parser."""
output_parser_dict = super().dict(**kwargs)
output_parser_dict = super().model_dump(**kwargs)
with contextlib.suppress(NotImplementedError):
output_parser_dict["_type"] = self._type
return output_parser_dict

View File

@ -9,6 +9,7 @@ from typing import Annotated, Any, Optional, TypeVar, Union
import jsonpatch # type: ignore[import-untyped]
import pydantic
from pydantic import SkipValidation
from pydantic.v1 import BaseModel
from typing_extensions import override
from langchain_core.exceptions import OutputParserException
@ -20,16 +21,9 @@ from langchain_core.utils.json import (
parse_json_markdown,
parse_partial_json,
)
from langchain_core.utils.pydantic import IS_PYDANTIC_V1
if IS_PYDANTIC_V1:
PydanticBaseModel = pydantic.BaseModel
else:
from pydantic.v1 import BaseModel
# Union type needs to be last assignment to PydanticBaseModel to make mypy happy.
PydanticBaseModel = Union[BaseModel, pydantic.BaseModel] # type: ignore[assignment,misc]
# Union type needs to be last assignment to PydanticBaseModel to make mypy happy.
PydanticBaseModel = Union[BaseModel, pydantic.BaseModel]
TBaseModel = TypeVar("TBaseModel", bound=PydanticBaseModel)

View File

@ -7,6 +7,7 @@ from typing import Any, Optional, Union
import jsonpatch # type: ignore[import-untyped]
from pydantic import BaseModel, model_validator
from pydantic.v1 import BaseModel as BaseModelV1
from typing_extensions import override
from langchain_core.exceptions import OutputParserException
@ -275,10 +276,13 @@ class PydanticOutputFunctionsParser(OutputFunctionsParser):
pydantic_schema = self.pydantic_schema[fn_name]
else:
pydantic_schema = self.pydantic_schema
if hasattr(pydantic_schema, "model_validate_json"):
if issubclass(pydantic_schema, BaseModel):
pydantic_args = pydantic_schema.model_validate_json(_args)
else:
elif issubclass(pydantic_schema, BaseModelV1):
pydantic_args = pydantic_schema.parse_raw(_args)
else:
msg = f"Unsupported pydantic schema: {pydantic_schema}"
raise ValueError(msg)
return pydantic_args

View File

@ -11,7 +11,6 @@ from langchain_core.exceptions import OutputParserException
from langchain_core.output_parsers import JsonOutputParser
from langchain_core.outputs import Generation
from langchain_core.utils.pydantic import (
IS_PYDANTIC_V2,
PydanticBaseModel,
TBaseModel,
)
@ -24,22 +23,16 @@ class PydanticOutputParser(JsonOutputParser, Generic[TBaseModel]):
"""The pydantic model to parse."""
def _parse_obj(self, obj: dict) -> TBaseModel:
if IS_PYDANTIC_V2:
try:
if issubclass(self.pydantic_object, pydantic.BaseModel):
return self.pydantic_object.model_validate(obj)
if issubclass(self.pydantic_object, pydantic.v1.BaseModel):
return self.pydantic_object.parse_obj(obj)
msg = f"Unsupported model version for PydanticOutputParser: \
{self.pydantic_object.__class__}"
raise OutputParserException(msg)
except (pydantic.ValidationError, pydantic.v1.ValidationError) as e:
raise self._parser_exception(e, obj) from e
else: # pydantic v1
try:
try:
if issubclass(self.pydantic_object, pydantic.BaseModel):
return self.pydantic_object.model_validate(obj)
if issubclass(self.pydantic_object, pydantic.v1.BaseModel):
return self.pydantic_object.parse_obj(obj)
except pydantic.ValidationError as e:
raise self._parser_exception(e, obj) from e
msg = f"Unsupported model version for PydanticOutputParser: \
{self.pydantic_object.__class__}"
raise OutputParserException(msg)
except (pydantic.ValidationError, pydantic.v1.ValidationError) as e:
raise self._parser_exception(e, obj) from e
def _parser_exception(
self, e: Exception, json_object: dict

View File

@ -134,7 +134,7 @@ class BaseCumulativeTransformOutputParser(BaseTransformOutputParser[T]):
chunk_gen = ChatGenerationChunk(message=chunk)
elif isinstance(chunk, BaseMessage):
chunk_gen = ChatGenerationChunk(
message=BaseMessageChunk(**chunk.dict())
message=BaseMessageChunk(**chunk.model_dump())
)
else:
chunk_gen = GenerationChunk(text=chunk)
@ -161,7 +161,7 @@ class BaseCumulativeTransformOutputParser(BaseTransformOutputParser[T]):
chunk_gen = ChatGenerationChunk(message=chunk)
elif isinstance(chunk, BaseMessage):
chunk_gen = ChatGenerationChunk(
message=BaseMessageChunk(**chunk.dict())
message=BaseMessageChunk(**chunk.model_dump())
)
else:
chunk_gen = GenerationChunk(text=chunk)

View File

@ -2,25 +2,10 @@
from importlib import metadata
from pydantic.v1 import * # noqa: F403
from langchain_core._api.deprecation import warn_deprecated
# Create namespaces for pydantic v1 and v2.
# This code must stay at the top of the file before other modules may
# attempt to import pydantic since it adds pydantic_v1 and pydantic_v2 to sys.modules.
#
# This hack is done for the following reasons:
# * Langchain will attempt to remain compatible with both pydantic v1 and v2 since
# both dependencies and dependents may be stuck on either version of v1 or v2.
# * Creating namespaces for pydantic v1 and v2 should allow us to write code that
# unambiguously uses either v1 or v2 API.
# * This change is easier to roll out and roll back.
try:
from pydantic.v1 import * # noqa: F403
except ImportError:
from pydantic import * # type: ignore[assignment,no-redef] # noqa: F403
try:
_PYDANTIC_MAJOR_VERSION: int = int(metadata.version("pydantic").split(".")[0])
except metadata.PackageNotFoundError:

View File

@ -1,11 +1,8 @@
"""Pydantic v1 compatibility shim."""
from langchain_core._api import warn_deprecated
from pydantic.v1.dataclasses import * # noqa: F403
try:
from pydantic.v1.dataclasses import * # noqa: F403
except ImportError:
from pydantic.dataclasses import * # type: ignore[no-redef] # noqa: F403
from langchain_core._api import warn_deprecated
warn_deprecated(
"0.3.0",

View File

@ -1,11 +1,8 @@
"""Pydantic v1 compatibility shim."""
from langchain_core._api import warn_deprecated
from pydantic.v1.main import * # noqa: F403
try:
from pydantic.v1.main import * # noqa: F403
except ImportError:
from pydantic.main import * # type: ignore[assignment,no-redef] # noqa: F403
from langchain_core._api import warn_deprecated
warn_deprecated(
"0.3.0",

View File

@ -540,10 +540,13 @@ class ChildTool(BaseTool):
)
raise ValueError(msg)
key_ = next(iter(get_fields(input_args).keys()))
if hasattr(input_args, "model_validate"):
if issubclass(input_args, BaseModel):
input_args.model_validate({key_: tool_input})
else:
elif issubclass(input_args, BaseModelV1):
input_args.parse_obj({key_: tool_input})
else:
msg = f"args_schema must be a Pydantic BaseModel, got {input_args}"
raise TypeError(msg)
return tool_input
if input_args is not None:
if isinstance(input_args, dict):

View File

@ -2,8 +2,8 @@
from __future__ import annotations
import datetime
import warnings
from datetime import datetime, timezone
from typing import Any, Optional
from uuid import UUID
@ -32,7 +32,7 @@ def RunTypeEnum() -> type[RunTypeEnumDep]: # noqa: N802
class TracerSessionV1Base(BaseModelV1):
"""Base class for TracerSessionV1."""
start_time: datetime.datetime = FieldV1(default_factory=datetime.datetime.utcnow)
start_time: datetime = FieldV1(default_factory=lambda: datetime.now(timezone.utc))
name: Optional[str] = None
extra: Optional[dict[str, Any]] = None
@ -69,8 +69,8 @@ class BaseRun(BaseModelV1):
uuid: str
parent_uuid: Optional[str] = None
start_time: datetime.datetime = FieldV1(default_factory=datetime.datetime.utcnow)
end_time: datetime.datetime = FieldV1(default_factory=datetime.datetime.utcnow)
start_time: datetime = FieldV1(default_factory=lambda: datetime.now(timezone.utc))
end_time: datetime = FieldV1(default_factory=lambda: datetime.now(timezone.utc))
extra: Optional[dict[str, Any]] = None
execution_order: int
child_execution_order: int

View File

@ -21,9 +21,12 @@ from typing import (
import pydantic
from packaging import version
from pydantic import (
# root_validator is deprecated but we need it for backward compatibility of @pre_init
from pydantic import ( # type: ignore[deprecated]
BaseModel,
ConfigDict,
Field,
PydanticDeprecationWarning,
RootModel,
root_validator,
@ -38,29 +41,23 @@ from pydantic.json_schema import (
JsonSchemaMode,
JsonSchemaValue,
)
from typing_extensions import override
from pydantic.v1 import BaseModel as BaseModelV1
from pydantic.v1 import create_model as create_model_v1
from pydantic.v1.fields import ModelField
from typing_extensions import deprecated, override
if TYPE_CHECKING:
from pydantic_core import core_schema
try:
import pydantic
PYDANTIC_VERSION = version.parse(pydantic.__version__)
except ImportError:
PYDANTIC_VERSION = version.parse("0.0.0")
PYDANTIC_VERSION = version.parse(pydantic.__version__)
@deprecated("Use PYDANTIC_VERSION.major instead.")
def get_pydantic_major_version() -> int:
"""DEPRECATED - Get the major version of Pydantic.
Use PYDANTIC_VERSION.major instead.
"""
warnings.warn(
"get_pydantic_major_version is deprecated. Use PYDANTIC_VERSION.major instead.",
DeprecationWarning,
stacklevel=2,
)
return PYDANTIC_VERSION.major
@ -70,43 +67,20 @@ PYDANTIC_MINOR_VERSION = PYDANTIC_VERSION.minor
IS_PYDANTIC_V1 = PYDANTIC_VERSION.major == 1
IS_PYDANTIC_V2 = PYDANTIC_VERSION.major == 2
if IS_PYDANTIC_V1:
from pydantic.fields import FieldInfo as FieldInfoV1
PydanticBaseModel = pydantic.BaseModel
TypeBaseModel = type[BaseModel]
elif IS_PYDANTIC_V2:
from pydantic.v1.fields import FieldInfo as FieldInfoV1 # type: ignore[assignment]
from pydantic.v1.fields import ModelField
# Union type needs to be last assignment to PydanticBaseModel to make mypy happy.
PydanticBaseModel = Union[BaseModel, pydantic.BaseModel] # type: ignore[assignment,misc]
TypeBaseModel = Union[type[BaseModel], type[pydantic.BaseModel]] # type: ignore[misc]
else:
msg = f"Unsupported Pydantic version: {PYDANTIC_VERSION.major}"
raise ValueError(msg)
PydanticBaseModel = BaseModel
TypeBaseModel = type[BaseModel]
TBaseModel = TypeVar("TBaseModel", bound=PydanticBaseModel)
def is_pydantic_v1_subclass(cls: type) -> bool:
"""Check if the installed Pydantic version is 1.x-like."""
if IS_PYDANTIC_V1:
return True
if IS_PYDANTIC_V2:
from pydantic.v1 import BaseModel as BaseModelV1
if issubclass(cls, BaseModelV1):
return True
return False
return issubclass(cls, BaseModelV1)
def is_pydantic_v2_subclass(cls: type) -> bool:
"""Check if the installed Pydantic version is 1.x-like."""
from pydantic import BaseModel
return IS_PYDANTIC_V2 and issubclass(cls, BaseModel)
return issubclass(cls, BaseModel)
def is_basemodel_subclass(cls: type) -> bool:
@ -114,7 +88,6 @@ def is_basemodel_subclass(cls: type) -> bool:
Check if the given class is a subclass of any of the following:
* pydantic.BaseModel in Pydantic 1.x
* pydantic.BaseModel in Pydantic 2.x
* pydantic.v1.BaseModel in Pydantic 2.x
"""
@ -122,24 +95,7 @@ def is_basemodel_subclass(cls: type) -> bool:
if not inspect.isclass(cls) or isinstance(cls, GenericAlias):
return False
if IS_PYDANTIC_V1:
from pydantic import BaseModel as BaseModelV1Proper
if issubclass(cls, BaseModelV1Proper):
return True
elif IS_PYDANTIC_V2:
from pydantic import BaseModel as BaseModelV2
from pydantic.v1 import BaseModel as BaseModelV1
if issubclass(cls, BaseModelV2):
return True
if issubclass(cls, BaseModelV1):
return True
else:
msg = f"Unsupported Pydantic version: {PYDANTIC_VERSION.major}"
raise ValueError(msg)
return False
return issubclass(cls, (BaseModel, BaseModelV1))
def is_basemodel_instance(obj: Any) -> bool:
@ -147,28 +103,10 @@ def is_basemodel_instance(obj: Any) -> bool:
Check if the given class is an instance of any of the following:
* pydantic.BaseModel in Pydantic 1.x
* pydantic.BaseModel in Pydantic 2.x
* pydantic.v1.BaseModel in Pydantic 2.x
"""
if IS_PYDANTIC_V1:
from pydantic import BaseModel as BaseModelV1Proper
if isinstance(obj, BaseModelV1Proper):
return True
elif IS_PYDANTIC_V2:
from pydantic import BaseModel as BaseModelV2
from pydantic.v1 import BaseModel as BaseModelV1
if isinstance(obj, BaseModelV2):
return True
if isinstance(obj, BaseModelV1):
return True
else:
msg = f"Unsupported Pydantic version: {PYDANTIC_VERSION.major}"
raise ValueError(msg)
return False
return isinstance(obj, (BaseModel, BaseModelV1))
# How to type hint this?
@ -184,6 +122,9 @@ def pre_init(func: Callable) -> Any:
with warnings.catch_warnings():
warnings.filterwarnings(action="ignore", category=PydanticDeprecationWarning)
# Ideally we would use @model_validator(mode="before") but this would change the
# order of the validators. See https://github.com/pydantic/pydantic/discussions/7434.
# So we keep root_validator for backward compatibility.
@root_validator(pre=True)
@wraps(func)
def wrapper(cls: type[BaseModel], values: dict[str, Any]) -> dict[str, Any]:
@ -244,26 +185,18 @@ class _IgnoreUnserializable(GenerateJsonSchema):
def _create_subset_model_v1(
name: str,
model: type[BaseModel],
model: type[BaseModelV1],
field_names: list,
*,
descriptions: Optional[dict] = None,
fn_description: Optional[str] = None,
) -> type[BaseModel]:
"""Create a pydantic model with only a subset of model's fields."""
if IS_PYDANTIC_V1:
from pydantic import create_model
elif IS_PYDANTIC_V2:
from pydantic.v1 import create_model # type: ignore[no-redef]
else:
msg = f"Unsupported pydantic version: {PYDANTIC_VERSION.major}"
raise NotImplementedError(msg)
fields = {}
for field_name in field_names:
# Using pydantic v1 so can access __fields__ as a dict.
field = model.__fields__[field_name] # type: ignore[index]
field = model.__fields__[field_name]
t = (
# this isn't perfect but should work for most functions
field.outer_type_
@ -274,34 +207,31 @@ def _create_subset_model_v1(
field.field_info.description = descriptions[field_name]
fields[field_name] = (t, field.field_info)
rtn = create_model(name, **fields) # type: ignore[call-overload]
rtn = create_model_v1(name, **fields) # type: ignore[call-overload]
rtn.__doc__ = textwrap.dedent(fn_description or model.__doc__ or "")
return rtn
def _create_subset_model_v2(
name: str,
model: type[pydantic.BaseModel],
model: type[BaseModel],
field_names: list[str],
*,
descriptions: Optional[dict] = None,
fn_description: Optional[str] = None,
) -> type[pydantic.BaseModel]:
) -> type[BaseModel]:
"""Create a pydantic model with a subset of the model fields."""
from pydantic import create_model
from pydantic.fields import FieldInfo
descriptions_ = descriptions or {}
fields = {}
for field_name in field_names:
field = model.model_fields[field_name]
description = descriptions_.get(field_name, field.description)
field_info = FieldInfo(description=description, default=field.default)
field_info = FieldInfoV2(description=description, default=field.default)
if field.metadata:
field_info.metadata = field.metadata
fields[field_name] = (field.annotation, field_info)
rtn = create_model( # type: ignore[call-overload]
rtn = _create_model_base( # type: ignore[call-overload]
name, **fields, __config__=ConfigDict(arbitrary_types_allowed=True)
)
@ -322,7 +252,7 @@ def _create_subset_model_v2(
# Private functionality to create a subset model that's compatible across
# different versions of pydantic.
# Handles pydantic versions 1.x and 2.x. including v1 of pydantic in 2.x.
# Handles pydantic versions 2.x. including v1 of pydantic in 2.x.
# However, can't find a way to type hint this.
def _create_subset_model(
name: str,
@ -333,7 +263,7 @@ def _create_subset_model(
fn_description: Optional[str] = None,
) -> type[BaseModel]:
"""Create subset model using the same pydantic version as the input model."""
if IS_PYDANTIC_V1:
if issubclass(model, BaseModelV1):
return _create_subset_model_v1(
name,
model,
@ -341,68 +271,43 @@ def _create_subset_model(
descriptions=descriptions,
fn_description=fn_description,
)
if IS_PYDANTIC_V2:
from pydantic.v1 import BaseModel as BaseModelV1
if issubclass(model, BaseModelV1):
return _create_subset_model_v1(
name,
model,
field_names,
descriptions=descriptions,
fn_description=fn_description,
)
return _create_subset_model_v2(
name,
model,
field_names,
descriptions=descriptions,
fn_description=fn_description,
)
msg = f"Unsupported pydantic version: {PYDANTIC_VERSION.major}"
raise NotImplementedError(msg)
return _create_subset_model_v2(
name,
model,
field_names,
descriptions=descriptions,
fn_description=fn_description,
)
if IS_PYDANTIC_V2:
from pydantic import BaseModel as BaseModelV2
from pydantic.v1 import BaseModel as BaseModelV1
@overload
def get_fields(model: type[BaseModel]) -> dict[str, FieldInfoV2]: ...
@overload
def get_fields(model: type[BaseModelV2]) -> dict[str, FieldInfoV2]: ...
@overload
def get_fields(model: BaseModelV2) -> dict[str, FieldInfoV2]: ...
@overload
def get_fields(model: BaseModel) -> dict[str, FieldInfoV2]: ...
@overload
def get_fields(model: type[BaseModelV1]) -> dict[str, ModelField]: ...
@overload
def get_fields(model: BaseModelV1) -> dict[str, ModelField]: ...
@overload
def get_fields(model: type[BaseModelV1]) -> dict[str, ModelField]: ...
def get_fields(
model: Union[type[Union[BaseModelV2, BaseModelV1]], BaseModelV2, BaseModelV1],
) -> Union[dict[str, FieldInfoV2], dict[str, ModelField]]:
"""Get the field names of a Pydantic model."""
if hasattr(model, "model_fields"):
return model.model_fields
if hasattr(model, "__fields__"):
return model.__fields__
msg = f"Expected a Pydantic model. Got {type(model)}"
raise TypeError(msg)
@overload
def get_fields(model: BaseModelV1) -> dict[str, ModelField]: ...
elif IS_PYDANTIC_V1:
from pydantic import BaseModel as BaseModelV1_
def get_fields( # type: ignore[no-redef]
model: Union[type[BaseModelV1_], BaseModelV1_],
) -> dict[str, FieldInfoV1]:
"""Get the field names of a Pydantic model."""
return model.__fields__ # type: ignore[return-value]
def get_fields(
model: Union[type[Union[BaseModel, BaseModelV1]], BaseModel, BaseModelV1],
) -> Union[dict[str, FieldInfoV2], dict[str, ModelField]]:
"""Get the field names of a Pydantic model."""
if hasattr(model, "model_fields"):
return model.model_fields
if hasattr(model, "__fields__"):
return model.__fields__
msg = f"Expected a Pydantic model. Got {type(model)}"
raise TypeError(msg)
else:
msg = f"Unsupported Pydantic version: {PYDANTIC_VERSION.major}"
raise ValueError(msg)
_SchemaConfig = ConfigDict(
arbitrary_types_allowed=True, frozen=True, protected_namespaces=()
@ -546,14 +451,11 @@ _RESERVED_NAMES = {key for key in dir(BaseModel) if not key.startswith("_")}
def _remap_field_definitions(field_definitions: dict[str, Any]) -> dict[str, Any]:
"""This remaps fields to avoid colliding with internal pydantic fields."""
from pydantic import Field
from pydantic.fields import FieldInfo
remapped = {}
for key, value in field_definitions.items():
if key.startswith("_") or key in _RESERVED_NAMES:
# Let's add a prefix to avoid colliding with internal pydantic fields
if isinstance(value, FieldInfo):
if isinstance(value, FieldInfoV2):
msg = (
f"Remapping for fields starting with '_' or fields with a name "
f"matching a reserved name {_RESERVED_NAMES} is not supported if "

View File

@ -69,7 +69,6 @@ langchain-text-splitters = { path = "../text-splitters" }
strict = "True"
strict_bytes = "True"
enable_error_code = "deprecated"
report_deprecated_as_note = "True"
# TODO: activate for 'strict' checking
disallow_any_generics = "False"

View File

@ -16,10 +16,6 @@ from langchain_core.output_parsers.openai_tools import (
PydanticToolsParser,
)
from langchain_core.outputs import ChatGeneration
from langchain_core.utils.pydantic import (
IS_PYDANTIC_V1,
IS_PYDANTIC_V2,
)
STREAMED_MESSAGES: list = [
AIMessageChunk(content=""),
@ -532,7 +528,6 @@ async def test_partial_pydantic_output_parser_async() -> None:
assert actual == EXPECTED_STREAMED_PYDANTIC
@pytest.mark.skipif(not IS_PYDANTIC_V2, reason="This test is for pydantic 2")
def test_parse_with_different_pydantic_2_v1() -> None:
"""Test with pydantic.v1.BaseModel from pydantic 2."""
import pydantic
@ -567,7 +562,6 @@ def test_parse_with_different_pydantic_2_v1() -> None:
]
@pytest.mark.skipif(not IS_PYDANTIC_V2, reason="This test is for pydantic 2")
def test_parse_with_different_pydantic_2_proper() -> None:
"""Test with pydantic.BaseModel from pydantic 2."""
import pydantic
@ -602,41 +596,6 @@ def test_parse_with_different_pydantic_2_proper() -> None:
]
@pytest.mark.skipif(not IS_PYDANTIC_V1, reason="This test is for pydantic 1")
def test_parse_with_different_pydantic_1_proper() -> None:
"""Test with pydantic.BaseModel from pydantic 1."""
import pydantic
class Forecast(pydantic.BaseModel):
temperature: int
forecast: str
# Can't get pydantic to work here due to the odd typing of tryig to support
# both v1 and v2 in the same codebase.
parser = PydanticToolsParser(tools=[Forecast])
message = AIMessage(
content="",
tool_calls=[
{
"id": "call_OwL7f5PE",
"name": "Forecast",
"args": {"temperature": 20, "forecast": "Sunny"},
}
],
)
generation = ChatGeneration(
message=message,
)
assert parser.parse_result([generation]) == [
Forecast(
temperature=20,
forecast="Sunny",
)
]
def test_max_tokens_error(caplog: Any) -> None:
parser = PydanticToolsParser(tools=[NameCollector], first_tool_only=True)
message = AIMessage(

View File

@ -65,8 +65,6 @@ from langchain_core.utils.function_calling import (
convert_to_openai_tool,
)
from langchain_core.utils.pydantic import (
IS_PYDANTIC_V1,
IS_PYDANTIC_V2,
_create_subset_model,
create_model_v2,
)
@ -79,9 +77,11 @@ def _get_tool_call_json_schema(tool: BaseTool) -> dict:
if isinstance(tool_schema, dict):
return tool_schema
if hasattr(tool_schema, "model_json_schema"):
if issubclass(tool_schema, BaseModel):
return tool_schema.model_json_schema()
return tool_schema.schema()
if issubclass(tool_schema, BaseModelV1):
return tool_schema.schema()
return {}
def test_unnamed_decorator() -> None:
@ -1853,11 +1853,14 @@ def test_args_schema_as_pydantic(pydantic_model: Any) -> None:
)
input_schema = tool.get_input_schema()
input_json_schema = (
input_schema.model_json_schema()
if hasattr(input_schema, "model_json_schema")
else input_schema.schema()
)
if issubclass(input_schema, BaseModel):
input_json_schema = input_schema.model_json_schema()
elif issubclass(input_schema, BaseModelV1):
input_json_schema = input_schema.schema()
else:
msg = "Unknown input schema type"
raise TypeError(msg)
assert input_json_schema == {
"properties": {
"a": {"title": "A", "type": "integer"},
@ -1943,12 +1946,14 @@ def test_structured_tool_with_different_pydantic_versions(pydantic_model: Any) -
assert foo_tool.invoke({"a": 5, "b": "hello"}) == "foo"
args_schema = cast("BaseModel", foo_tool.args_schema)
args_json_schema = (
args_schema.model_json_schema()
if hasattr(args_schema, "model_json_schema")
else args_schema.schema()
)
args_schema = cast("type[BaseModel]", foo_tool.args_schema)
if issubclass(args_schema, BaseModel):
args_json_schema = args_schema.model_json_schema()
elif issubclass(args_schema, BaseModelV1):
args_json_schema = args_schema.schema()
else:
msg = "Unknown input schema type"
raise TypeError(msg)
assert args_json_schema == {
"properties": {
"a": {"title": "A", "type": "integer"},
@ -1960,11 +1965,13 @@ def test_structured_tool_with_different_pydantic_versions(pydantic_model: Any) -
}
input_schema = foo_tool.get_input_schema()
input_json_schema = (
input_schema.model_json_schema()
if hasattr(input_schema, "model_json_schema")
else input_schema.schema()
)
if issubclass(input_schema, BaseModel):
input_json_schema = input_schema.model_json_schema()
elif issubclass(input_schema, BaseModelV1):
input_json_schema = input_schema.schema()
else:
msg = "Unknown input schema type"
raise TypeError(msg)
assert input_json_schema == {
"properties": {
"a": {"title": "A", "type": "integer"},
@ -2020,7 +2027,6 @@ def test__is_message_content_type(obj: Any, *, expected: bool) -> None:
assert _is_message_content_type(obj) is expected
@pytest.mark.skipif(not IS_PYDANTIC_V2, reason="Testing pydantic v2.")
@pytest.mark.parametrize("use_v1_namespace", [True, False])
def test__get_all_basemodel_annotations_v2(*, use_v1_namespace: bool) -> None:
A = TypeVar("A")
@ -2089,63 +2095,6 @@ def test__get_all_basemodel_annotations_v2(*, use_v1_namespace: bool) -> None:
assert actual == expected
@pytest.mark.skipif(not IS_PYDANTIC_V1, reason="Testing pydantic v1.")
def test__get_all_basemodel_annotations_v1() -> None:
A = TypeVar("A")
class ModelA(BaseModel, Generic[A], extra="allow"):
a: A
class ModelB(ModelA[str]):
b: Annotated[ModelA[dict[str, Any]], "foo"]
class Mixin:
def foo(self) -> str:
return "foo"
class ModelC(Mixin, ModelB):
c: dict
expected = {"a": str, "b": Annotated[ModelA[dict[str, Any]], "foo"], "c": dict}
actual = get_all_basemodel_annotations(ModelC)
assert actual == expected
expected = {"a": str, "b": Annotated[ModelA[dict[str, Any]], "foo"]}
actual = get_all_basemodel_annotations(ModelB)
assert actual == expected
expected = {"a": Any}
actual = get_all_basemodel_annotations(ModelA)
assert actual == expected
expected = {"a": int}
actual = get_all_basemodel_annotations(ModelA[int])
assert actual == expected
D = TypeVar("D", bound=Union[str, int])
class ModelD(ModelC, Generic[D]):
d: Optional[D]
expected = {
"a": str,
"b": Annotated[ModelA[dict[str, Any]], "foo"],
"c": dict,
"d": Union[str, int, None],
}
actual = get_all_basemodel_annotations(ModelD)
assert actual == expected
expected = {
"a": str,
"b": Annotated[ModelA[dict[str, Any]], "foo"],
"c": dict,
"d": Union[int, None],
}
actual = get_all_basemodel_annotations(ModelD[int])
assert actual == expected
def test_get_all_basemodel_annotations_aliases() -> None:
class CalculatorInput(BaseModel):
a: int = Field(description="first number", alias="A")
@ -2226,7 +2175,6 @@ def test_create_retriever_tool() -> None:
)
@pytest.mark.skipif(not IS_PYDANTIC_V2, reason="Testing pydantic v2.")
def test_tool_args_schema_pydantic_v2_with_metadata() -> None:
from pydantic import BaseModel as BaseModelV2
from pydantic import Field as FieldV2

View File

@ -3,13 +3,9 @@
import warnings
from typing import Any, Optional
import pytest
from pydantic import ConfigDict
from langchain_core.utils.pydantic import (
IS_PYDANTIC_V1,
IS_PYDANTIC_V2,
PYDANTIC_VERSION,
_create_subset_model_v2,
create_model_v2,
get_fields,
@ -96,50 +92,29 @@ def test_with_aliases() -> None:
def test_is_basemodel_subclass() -> None:
"""Test pydantic."""
if IS_PYDANTIC_V1:
from pydantic import BaseModel as BaseModelV1Proper
from pydantic import BaseModel as BaseModelV2
from pydantic.v1 import BaseModel as BaseModelV1
assert is_basemodel_subclass(BaseModelV1Proper)
elif IS_PYDANTIC_V2:
from pydantic import BaseModel as BaseModelV2
from pydantic.v1 import BaseModel as BaseModelV1
assert is_basemodel_subclass(BaseModelV2)
assert is_basemodel_subclass(BaseModelV1)
else:
msg = f"Unsupported Pydantic version: {PYDANTIC_VERSION.major}"
raise ValueError(msg)
assert is_basemodel_subclass(BaseModelV2)
assert is_basemodel_subclass(BaseModelV1)
def test_is_basemodel_instance() -> None:
"""Test pydantic."""
if IS_PYDANTIC_V1:
from pydantic import BaseModel as BaseModelV1Proper
from pydantic import BaseModel as BaseModelV2
from pydantic.v1 import BaseModel as BaseModelV1
class FooV1(BaseModelV1Proper):
x: int
class Foo(BaseModelV2):
x: int
assert is_basemodel_instance(FooV1(x=5))
elif IS_PYDANTIC_V2:
from pydantic import BaseModel as BaseModelV2
from pydantic.v1 import BaseModel as BaseModelV1
assert is_basemodel_instance(Foo(x=5))
class Foo(BaseModelV2):
x: int
class Bar(BaseModelV1):
x: int
assert is_basemodel_instance(Foo(x=5))
class Bar(BaseModelV1):
x: int
assert is_basemodel_instance(Bar(x=5))
else:
msg = f"Unsupported Pydantic version: {PYDANTIC_VERSION.major}"
raise ValueError(msg)
assert is_basemodel_instance(Bar(x=5))
@pytest.mark.skipif(not IS_PYDANTIC_V2, reason="Only tests Pydantic v2")
def test_with_field_metadata() -> None:
"""Test pydantic with field metadata."""
from pydantic import BaseModel as BaseModelV2
@ -168,18 +143,6 @@ def test_with_field_metadata() -> None:
}
@pytest.mark.skipif(not IS_PYDANTIC_V1, reason="Only tests Pydantic v1")
def test_fields_pydantic_v1() -> None:
from pydantic import BaseModel
class Foo(BaseModel):
x: int
fields = get_fields(Foo)
assert fields == {"x": Foo.model_fields["x"]}
@pytest.mark.skipif(not IS_PYDANTIC_V2, reason="Only tests Pydantic v2")
def test_fields_pydantic_v2_proper() -> None:
from pydantic import BaseModel
@ -190,7 +153,6 @@ def test_fields_pydantic_v2_proper() -> None:
assert fields == {"x": Foo.model_fields["x"]}
@pytest.mark.skipif(not IS_PYDANTIC_V2, reason="Only tests Pydantic v2")
def test_fields_pydantic_v1_from_2() -> None:
from pydantic.v1 import BaseModel

View File

@ -16,10 +16,6 @@ from langchain_core.utils import (
guard_import,
)
from langchain_core.utils._merge import merge_dicts
from langchain_core.utils.pydantic import (
IS_PYDANTIC_V1,
IS_PYDANTIC_V2,
)
from langchain_core.utils.utils import secret_from_env
@ -214,7 +210,6 @@ def test_guard_import_failure(
guard_import(module_name, pip_name=pip_name, package=package)
@pytest.mark.skipif(not IS_PYDANTIC_V2, reason="Requires pydantic 2")
def test_get_pydantic_field_names_v1_in_2() -> None:
from pydantic.v1 import BaseModel as PydanticV1BaseModel
from pydantic.v1 import Field
@ -229,7 +224,6 @@ def test_get_pydantic_field_names_v1_in_2() -> None:
assert result == expected
@pytest.mark.skipif(not IS_PYDANTIC_V2, reason="Requires pydantic 2")
def test_get_pydantic_field_names_v2_in_2() -> None:
from pydantic import BaseModel, Field
@ -243,20 +237,6 @@ def test_get_pydantic_field_names_v2_in_2() -> None:
assert result == expected
@pytest.mark.skipif(not IS_PYDANTIC_V1, reason="Requires pydantic 1")
def test_get_pydantic_field_names_v1() -> None:
from pydantic import BaseModel, Field
class PydanticModel(BaseModel):
field1: str
field2: int
alias_field: int = Field(alias="aliased_field")
result = get_pydantic_field_names(PydanticModel)
expected = {"field1", "field2", "aliased_field", "alias_field"}
assert result == expected
def test_from_env_with_env_variable() -> None:
key = "TEST_KEY"
value = "test_value"