mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-01 02:43:37 +00:00
core: Cleanup Pydantic models and handle deprecation warnings (#30799)
* Simplified Pydantic handling since Pydantic v1 is not supported anymore. * Replace use of deprecated v1 methods by corresponding v2 methods. * Remove use of other deprecated methods. * Activate mypy errors on deprecated methods use. --------- Co-authored-by: Eugene Yurtsev <eyurtsev@gmail.com>
This commit is contained in:
parent
29e17fbd6b
commit
7e046ea848
@ -23,6 +23,8 @@ from typing import (
|
||||
cast,
|
||||
)
|
||||
|
||||
from pydantic.fields import FieldInfo
|
||||
from pydantic.v1.fields import FieldInfo as FieldInfoV1
|
||||
from typing_extensions import ParamSpec
|
||||
|
||||
from langchain_core._api.internal import is_caller_internal
|
||||
@ -152,10 +154,6 @@ def deprecated(
|
||||
_package: str = package,
|
||||
) -> T:
|
||||
"""Implementation of the decorator returned by `deprecated`."""
|
||||
from langchain_core.utils.pydantic import ( # type: ignore[attr-defined]
|
||||
FieldInfoV1,
|
||||
FieldInfoV2,
|
||||
)
|
||||
|
||||
def emit_warning() -> None:
|
||||
"""Emit the warning."""
|
||||
@ -249,7 +247,7 @@ def deprecated(
|
||||
),
|
||||
)
|
||||
|
||||
elif isinstance(obj, FieldInfoV2):
|
||||
elif isinstance(obj, FieldInfo):
|
||||
wrapped = None
|
||||
if not _obj_type:
|
||||
_obj_type = "attribute"
|
||||
@ -261,7 +259,7 @@ def deprecated(
|
||||
def finalize(wrapper: Callable[..., Any], new_doc: str) -> T: # noqa: ARG001
|
||||
return cast(
|
||||
"T",
|
||||
FieldInfoV2(
|
||||
FieldInfo(
|
||||
default=obj.default,
|
||||
default_factory=obj.default_factory,
|
||||
description=new_doc,
|
||||
|
@ -326,7 +326,7 @@ class BaseOutputParser(
|
||||
|
||||
def dict(self, **kwargs: Any) -> dict:
|
||||
"""Return dictionary representation of output parser."""
|
||||
output_parser_dict = super().dict(**kwargs)
|
||||
output_parser_dict = super().model_dump(**kwargs)
|
||||
with contextlib.suppress(NotImplementedError):
|
||||
output_parser_dict["_type"] = self._type
|
||||
return output_parser_dict
|
||||
|
@ -9,6 +9,7 @@ from typing import Annotated, Any, Optional, TypeVar, Union
|
||||
import jsonpatch # type: ignore[import-untyped]
|
||||
import pydantic
|
||||
from pydantic import SkipValidation
|
||||
from pydantic.v1 import BaseModel
|
||||
from typing_extensions import override
|
||||
|
||||
from langchain_core.exceptions import OutputParserException
|
||||
@ -20,16 +21,9 @@ from langchain_core.utils.json import (
|
||||
parse_json_markdown,
|
||||
parse_partial_json,
|
||||
)
|
||||
from langchain_core.utils.pydantic import IS_PYDANTIC_V1
|
||||
|
||||
if IS_PYDANTIC_V1:
|
||||
PydanticBaseModel = pydantic.BaseModel
|
||||
|
||||
else:
|
||||
from pydantic.v1 import BaseModel
|
||||
|
||||
# Union type needs to be last assignment to PydanticBaseModel to make mypy happy.
|
||||
PydanticBaseModel = Union[BaseModel, pydantic.BaseModel] # type: ignore[assignment,misc]
|
||||
# Union type needs to be last assignment to PydanticBaseModel to make mypy happy.
|
||||
PydanticBaseModel = Union[BaseModel, pydantic.BaseModel]
|
||||
|
||||
TBaseModel = TypeVar("TBaseModel", bound=PydanticBaseModel)
|
||||
|
||||
|
@ -7,6 +7,7 @@ from typing import Any, Optional, Union
|
||||
|
||||
import jsonpatch # type: ignore[import-untyped]
|
||||
from pydantic import BaseModel, model_validator
|
||||
from pydantic.v1 import BaseModel as BaseModelV1
|
||||
from typing_extensions import override
|
||||
|
||||
from langchain_core.exceptions import OutputParserException
|
||||
@ -275,10 +276,13 @@ class PydanticOutputFunctionsParser(OutputFunctionsParser):
|
||||
pydantic_schema = self.pydantic_schema[fn_name]
|
||||
else:
|
||||
pydantic_schema = self.pydantic_schema
|
||||
if hasattr(pydantic_schema, "model_validate_json"):
|
||||
if issubclass(pydantic_schema, BaseModel):
|
||||
pydantic_args = pydantic_schema.model_validate_json(_args)
|
||||
else:
|
||||
elif issubclass(pydantic_schema, BaseModelV1):
|
||||
pydantic_args = pydantic_schema.parse_raw(_args)
|
||||
else:
|
||||
msg = f"Unsupported pydantic schema: {pydantic_schema}"
|
||||
raise ValueError(msg)
|
||||
return pydantic_args
|
||||
|
||||
|
||||
|
@ -11,7 +11,6 @@ from langchain_core.exceptions import OutputParserException
|
||||
from langchain_core.output_parsers import JsonOutputParser
|
||||
from langchain_core.outputs import Generation
|
||||
from langchain_core.utils.pydantic import (
|
||||
IS_PYDANTIC_V2,
|
||||
PydanticBaseModel,
|
||||
TBaseModel,
|
||||
)
|
||||
@ -24,22 +23,16 @@ class PydanticOutputParser(JsonOutputParser, Generic[TBaseModel]):
|
||||
"""The pydantic model to parse."""
|
||||
|
||||
def _parse_obj(self, obj: dict) -> TBaseModel:
|
||||
if IS_PYDANTIC_V2:
|
||||
try:
|
||||
if issubclass(self.pydantic_object, pydantic.BaseModel):
|
||||
return self.pydantic_object.model_validate(obj)
|
||||
if issubclass(self.pydantic_object, pydantic.v1.BaseModel):
|
||||
return self.pydantic_object.parse_obj(obj)
|
||||
msg = f"Unsupported model version for PydanticOutputParser: \
|
||||
{self.pydantic_object.__class__}"
|
||||
raise OutputParserException(msg)
|
||||
except (pydantic.ValidationError, pydantic.v1.ValidationError) as e:
|
||||
raise self._parser_exception(e, obj) from e
|
||||
else: # pydantic v1
|
||||
try:
|
||||
try:
|
||||
if issubclass(self.pydantic_object, pydantic.BaseModel):
|
||||
return self.pydantic_object.model_validate(obj)
|
||||
if issubclass(self.pydantic_object, pydantic.v1.BaseModel):
|
||||
return self.pydantic_object.parse_obj(obj)
|
||||
except pydantic.ValidationError as e:
|
||||
raise self._parser_exception(e, obj) from e
|
||||
msg = f"Unsupported model version for PydanticOutputParser: \
|
||||
{self.pydantic_object.__class__}"
|
||||
raise OutputParserException(msg)
|
||||
except (pydantic.ValidationError, pydantic.v1.ValidationError) as e:
|
||||
raise self._parser_exception(e, obj) from e
|
||||
|
||||
def _parser_exception(
|
||||
self, e: Exception, json_object: dict
|
||||
|
@ -134,7 +134,7 @@ class BaseCumulativeTransformOutputParser(BaseTransformOutputParser[T]):
|
||||
chunk_gen = ChatGenerationChunk(message=chunk)
|
||||
elif isinstance(chunk, BaseMessage):
|
||||
chunk_gen = ChatGenerationChunk(
|
||||
message=BaseMessageChunk(**chunk.dict())
|
||||
message=BaseMessageChunk(**chunk.model_dump())
|
||||
)
|
||||
else:
|
||||
chunk_gen = GenerationChunk(text=chunk)
|
||||
@ -161,7 +161,7 @@ class BaseCumulativeTransformOutputParser(BaseTransformOutputParser[T]):
|
||||
chunk_gen = ChatGenerationChunk(message=chunk)
|
||||
elif isinstance(chunk, BaseMessage):
|
||||
chunk_gen = ChatGenerationChunk(
|
||||
message=BaseMessageChunk(**chunk.dict())
|
||||
message=BaseMessageChunk(**chunk.model_dump())
|
||||
)
|
||||
else:
|
||||
chunk_gen = GenerationChunk(text=chunk)
|
||||
|
@ -2,25 +2,10 @@
|
||||
|
||||
from importlib import metadata
|
||||
|
||||
from pydantic.v1 import * # noqa: F403
|
||||
|
||||
from langchain_core._api.deprecation import warn_deprecated
|
||||
|
||||
# Create namespaces for pydantic v1 and v2.
|
||||
# This code must stay at the top of the file before other modules may
|
||||
# attempt to import pydantic since it adds pydantic_v1 and pydantic_v2 to sys.modules.
|
||||
#
|
||||
# This hack is done for the following reasons:
|
||||
# * Langchain will attempt to remain compatible with both pydantic v1 and v2 since
|
||||
# both dependencies and dependents may be stuck on either version of v1 or v2.
|
||||
# * Creating namespaces for pydantic v1 and v2 should allow us to write code that
|
||||
# unambiguously uses either v1 or v2 API.
|
||||
# * This change is easier to roll out and roll back.
|
||||
|
||||
try:
|
||||
from pydantic.v1 import * # noqa: F403
|
||||
except ImportError:
|
||||
from pydantic import * # type: ignore[assignment,no-redef] # noqa: F403
|
||||
|
||||
|
||||
try:
|
||||
_PYDANTIC_MAJOR_VERSION: int = int(metadata.version("pydantic").split(".")[0])
|
||||
except metadata.PackageNotFoundError:
|
||||
|
@ -1,11 +1,8 @@
|
||||
"""Pydantic v1 compatibility shim."""
|
||||
|
||||
from langchain_core._api import warn_deprecated
|
||||
from pydantic.v1.dataclasses import * # noqa: F403
|
||||
|
||||
try:
|
||||
from pydantic.v1.dataclasses import * # noqa: F403
|
||||
except ImportError:
|
||||
from pydantic.dataclasses import * # type: ignore[no-redef] # noqa: F403
|
||||
from langchain_core._api import warn_deprecated
|
||||
|
||||
warn_deprecated(
|
||||
"0.3.0",
|
||||
|
@ -1,11 +1,8 @@
|
||||
"""Pydantic v1 compatibility shim."""
|
||||
|
||||
from langchain_core._api import warn_deprecated
|
||||
from pydantic.v1.main import * # noqa: F403
|
||||
|
||||
try:
|
||||
from pydantic.v1.main import * # noqa: F403
|
||||
except ImportError:
|
||||
from pydantic.main import * # type: ignore[assignment,no-redef] # noqa: F403
|
||||
from langchain_core._api import warn_deprecated
|
||||
|
||||
warn_deprecated(
|
||||
"0.3.0",
|
||||
|
@ -540,10 +540,13 @@ class ChildTool(BaseTool):
|
||||
)
|
||||
raise ValueError(msg)
|
||||
key_ = next(iter(get_fields(input_args).keys()))
|
||||
if hasattr(input_args, "model_validate"):
|
||||
if issubclass(input_args, BaseModel):
|
||||
input_args.model_validate({key_: tool_input})
|
||||
else:
|
||||
elif issubclass(input_args, BaseModelV1):
|
||||
input_args.parse_obj({key_: tool_input})
|
||||
else:
|
||||
msg = f"args_schema must be a Pydantic BaseModel, got {input_args}"
|
||||
raise TypeError(msg)
|
||||
return tool_input
|
||||
if input_args is not None:
|
||||
if isinstance(input_args, dict):
|
||||
|
@ -2,8 +2,8 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import datetime
|
||||
import warnings
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any, Optional
|
||||
from uuid import UUID
|
||||
|
||||
@ -32,7 +32,7 @@ def RunTypeEnum() -> type[RunTypeEnumDep]: # noqa: N802
|
||||
class TracerSessionV1Base(BaseModelV1):
|
||||
"""Base class for TracerSessionV1."""
|
||||
|
||||
start_time: datetime.datetime = FieldV1(default_factory=datetime.datetime.utcnow)
|
||||
start_time: datetime = FieldV1(default_factory=lambda: datetime.now(timezone.utc))
|
||||
name: Optional[str] = None
|
||||
extra: Optional[dict[str, Any]] = None
|
||||
|
||||
@ -69,8 +69,8 @@ class BaseRun(BaseModelV1):
|
||||
|
||||
uuid: str
|
||||
parent_uuid: Optional[str] = None
|
||||
start_time: datetime.datetime = FieldV1(default_factory=datetime.datetime.utcnow)
|
||||
end_time: datetime.datetime = FieldV1(default_factory=datetime.datetime.utcnow)
|
||||
start_time: datetime = FieldV1(default_factory=lambda: datetime.now(timezone.utc))
|
||||
end_time: datetime = FieldV1(default_factory=lambda: datetime.now(timezone.utc))
|
||||
extra: Optional[dict[str, Any]] = None
|
||||
execution_order: int
|
||||
child_execution_order: int
|
||||
|
@ -21,9 +21,12 @@ from typing import (
|
||||
|
||||
import pydantic
|
||||
from packaging import version
|
||||
from pydantic import (
|
||||
|
||||
# root_validator is deprecated but we need it for backward compatibility of @pre_init
|
||||
from pydantic import ( # type: ignore[deprecated]
|
||||
BaseModel,
|
||||
ConfigDict,
|
||||
Field,
|
||||
PydanticDeprecationWarning,
|
||||
RootModel,
|
||||
root_validator,
|
||||
@ -38,29 +41,23 @@ from pydantic.json_schema import (
|
||||
JsonSchemaMode,
|
||||
JsonSchemaValue,
|
||||
)
|
||||
from typing_extensions import override
|
||||
from pydantic.v1 import BaseModel as BaseModelV1
|
||||
from pydantic.v1 import create_model as create_model_v1
|
||||
from pydantic.v1.fields import ModelField
|
||||
from typing_extensions import deprecated, override
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pydantic_core import core_schema
|
||||
|
||||
try:
|
||||
import pydantic
|
||||
|
||||
PYDANTIC_VERSION = version.parse(pydantic.__version__)
|
||||
except ImportError:
|
||||
PYDANTIC_VERSION = version.parse("0.0.0")
|
||||
PYDANTIC_VERSION = version.parse(pydantic.__version__)
|
||||
|
||||
|
||||
@deprecated("Use PYDANTIC_VERSION.major instead.")
|
||||
def get_pydantic_major_version() -> int:
|
||||
"""DEPRECATED - Get the major version of Pydantic.
|
||||
|
||||
Use PYDANTIC_VERSION.major instead.
|
||||
"""
|
||||
warnings.warn(
|
||||
"get_pydantic_major_version is deprecated. Use PYDANTIC_VERSION.major instead.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
return PYDANTIC_VERSION.major
|
||||
|
||||
|
||||
@ -70,43 +67,20 @@ PYDANTIC_MINOR_VERSION = PYDANTIC_VERSION.minor
|
||||
IS_PYDANTIC_V1 = PYDANTIC_VERSION.major == 1
|
||||
IS_PYDANTIC_V2 = PYDANTIC_VERSION.major == 2
|
||||
|
||||
if IS_PYDANTIC_V1:
|
||||
from pydantic.fields import FieldInfo as FieldInfoV1
|
||||
|
||||
PydanticBaseModel = pydantic.BaseModel
|
||||
TypeBaseModel = type[BaseModel]
|
||||
elif IS_PYDANTIC_V2:
|
||||
from pydantic.v1.fields import FieldInfo as FieldInfoV1 # type: ignore[assignment]
|
||||
from pydantic.v1.fields import ModelField
|
||||
|
||||
# Union type needs to be last assignment to PydanticBaseModel to make mypy happy.
|
||||
PydanticBaseModel = Union[BaseModel, pydantic.BaseModel] # type: ignore[assignment,misc]
|
||||
TypeBaseModel = Union[type[BaseModel], type[pydantic.BaseModel]] # type: ignore[misc]
|
||||
else:
|
||||
msg = f"Unsupported Pydantic version: {PYDANTIC_VERSION.major}"
|
||||
raise ValueError(msg)
|
||||
|
||||
PydanticBaseModel = BaseModel
|
||||
TypeBaseModel = type[BaseModel]
|
||||
|
||||
TBaseModel = TypeVar("TBaseModel", bound=PydanticBaseModel)
|
||||
|
||||
|
||||
def is_pydantic_v1_subclass(cls: type) -> bool:
|
||||
"""Check if the installed Pydantic version is 1.x-like."""
|
||||
if IS_PYDANTIC_V1:
|
||||
return True
|
||||
if IS_PYDANTIC_V2:
|
||||
from pydantic.v1 import BaseModel as BaseModelV1
|
||||
|
||||
if issubclass(cls, BaseModelV1):
|
||||
return True
|
||||
return False
|
||||
return issubclass(cls, BaseModelV1)
|
||||
|
||||
|
||||
def is_pydantic_v2_subclass(cls: type) -> bool:
|
||||
"""Check if the installed Pydantic version is 1.x-like."""
|
||||
from pydantic import BaseModel
|
||||
|
||||
return IS_PYDANTIC_V2 and issubclass(cls, BaseModel)
|
||||
return issubclass(cls, BaseModel)
|
||||
|
||||
|
||||
def is_basemodel_subclass(cls: type) -> bool:
|
||||
@ -114,7 +88,6 @@ def is_basemodel_subclass(cls: type) -> bool:
|
||||
|
||||
Check if the given class is a subclass of any of the following:
|
||||
|
||||
* pydantic.BaseModel in Pydantic 1.x
|
||||
* pydantic.BaseModel in Pydantic 2.x
|
||||
* pydantic.v1.BaseModel in Pydantic 2.x
|
||||
"""
|
||||
@ -122,24 +95,7 @@ def is_basemodel_subclass(cls: type) -> bool:
|
||||
if not inspect.isclass(cls) or isinstance(cls, GenericAlias):
|
||||
return False
|
||||
|
||||
if IS_PYDANTIC_V1:
|
||||
from pydantic import BaseModel as BaseModelV1Proper
|
||||
|
||||
if issubclass(cls, BaseModelV1Proper):
|
||||
return True
|
||||
elif IS_PYDANTIC_V2:
|
||||
from pydantic import BaseModel as BaseModelV2
|
||||
from pydantic.v1 import BaseModel as BaseModelV1
|
||||
|
||||
if issubclass(cls, BaseModelV2):
|
||||
return True
|
||||
|
||||
if issubclass(cls, BaseModelV1):
|
||||
return True
|
||||
else:
|
||||
msg = f"Unsupported Pydantic version: {PYDANTIC_VERSION.major}"
|
||||
raise ValueError(msg)
|
||||
return False
|
||||
return issubclass(cls, (BaseModel, BaseModelV1))
|
||||
|
||||
|
||||
def is_basemodel_instance(obj: Any) -> bool:
|
||||
@ -147,28 +103,10 @@ def is_basemodel_instance(obj: Any) -> bool:
|
||||
|
||||
Check if the given class is an instance of any of the following:
|
||||
|
||||
* pydantic.BaseModel in Pydantic 1.x
|
||||
* pydantic.BaseModel in Pydantic 2.x
|
||||
* pydantic.v1.BaseModel in Pydantic 2.x
|
||||
"""
|
||||
if IS_PYDANTIC_V1:
|
||||
from pydantic import BaseModel as BaseModelV1Proper
|
||||
|
||||
if isinstance(obj, BaseModelV1Proper):
|
||||
return True
|
||||
elif IS_PYDANTIC_V2:
|
||||
from pydantic import BaseModel as BaseModelV2
|
||||
from pydantic.v1 import BaseModel as BaseModelV1
|
||||
|
||||
if isinstance(obj, BaseModelV2):
|
||||
return True
|
||||
|
||||
if isinstance(obj, BaseModelV1):
|
||||
return True
|
||||
else:
|
||||
msg = f"Unsupported Pydantic version: {PYDANTIC_VERSION.major}"
|
||||
raise ValueError(msg)
|
||||
return False
|
||||
return isinstance(obj, (BaseModel, BaseModelV1))
|
||||
|
||||
|
||||
# How to type hint this?
|
||||
@ -184,6 +122,9 @@ def pre_init(func: Callable) -> Any:
|
||||
with warnings.catch_warnings():
|
||||
warnings.filterwarnings(action="ignore", category=PydanticDeprecationWarning)
|
||||
|
||||
# Ideally we would use @model_validator(mode="before") but this would change the
|
||||
# order of the validators. See https://github.com/pydantic/pydantic/discussions/7434.
|
||||
# So we keep root_validator for backward compatibility.
|
||||
@root_validator(pre=True)
|
||||
@wraps(func)
|
||||
def wrapper(cls: type[BaseModel], values: dict[str, Any]) -> dict[str, Any]:
|
||||
@ -244,26 +185,18 @@ class _IgnoreUnserializable(GenerateJsonSchema):
|
||||
|
||||
def _create_subset_model_v1(
|
||||
name: str,
|
||||
model: type[BaseModel],
|
||||
model: type[BaseModelV1],
|
||||
field_names: list,
|
||||
*,
|
||||
descriptions: Optional[dict] = None,
|
||||
fn_description: Optional[str] = None,
|
||||
) -> type[BaseModel]:
|
||||
"""Create a pydantic model with only a subset of model's fields."""
|
||||
if IS_PYDANTIC_V1:
|
||||
from pydantic import create_model
|
||||
elif IS_PYDANTIC_V2:
|
||||
from pydantic.v1 import create_model # type: ignore[no-redef]
|
||||
else:
|
||||
msg = f"Unsupported pydantic version: {PYDANTIC_VERSION.major}"
|
||||
raise NotImplementedError(msg)
|
||||
|
||||
fields = {}
|
||||
|
||||
for field_name in field_names:
|
||||
# Using pydantic v1 so can access __fields__ as a dict.
|
||||
field = model.__fields__[field_name] # type: ignore[index]
|
||||
field = model.__fields__[field_name]
|
||||
t = (
|
||||
# this isn't perfect but should work for most functions
|
||||
field.outer_type_
|
||||
@ -274,34 +207,31 @@ def _create_subset_model_v1(
|
||||
field.field_info.description = descriptions[field_name]
|
||||
fields[field_name] = (t, field.field_info)
|
||||
|
||||
rtn = create_model(name, **fields) # type: ignore[call-overload]
|
||||
rtn = create_model_v1(name, **fields) # type: ignore[call-overload]
|
||||
rtn.__doc__ = textwrap.dedent(fn_description or model.__doc__ or "")
|
||||
return rtn
|
||||
|
||||
|
||||
def _create_subset_model_v2(
|
||||
name: str,
|
||||
model: type[pydantic.BaseModel],
|
||||
model: type[BaseModel],
|
||||
field_names: list[str],
|
||||
*,
|
||||
descriptions: Optional[dict] = None,
|
||||
fn_description: Optional[str] = None,
|
||||
) -> type[pydantic.BaseModel]:
|
||||
) -> type[BaseModel]:
|
||||
"""Create a pydantic model with a subset of the model fields."""
|
||||
from pydantic import create_model
|
||||
from pydantic.fields import FieldInfo
|
||||
|
||||
descriptions_ = descriptions or {}
|
||||
fields = {}
|
||||
for field_name in field_names:
|
||||
field = model.model_fields[field_name]
|
||||
description = descriptions_.get(field_name, field.description)
|
||||
field_info = FieldInfo(description=description, default=field.default)
|
||||
field_info = FieldInfoV2(description=description, default=field.default)
|
||||
if field.metadata:
|
||||
field_info.metadata = field.metadata
|
||||
fields[field_name] = (field.annotation, field_info)
|
||||
|
||||
rtn = create_model( # type: ignore[call-overload]
|
||||
rtn = _create_model_base( # type: ignore[call-overload]
|
||||
name, **fields, __config__=ConfigDict(arbitrary_types_allowed=True)
|
||||
)
|
||||
|
||||
@ -322,7 +252,7 @@ def _create_subset_model_v2(
|
||||
|
||||
# Private functionality to create a subset model that's compatible across
|
||||
# different versions of pydantic.
|
||||
# Handles pydantic versions 1.x and 2.x. including v1 of pydantic in 2.x.
|
||||
# Handles pydantic versions 2.x. including v1 of pydantic in 2.x.
|
||||
# However, can't find a way to type hint this.
|
||||
def _create_subset_model(
|
||||
name: str,
|
||||
@ -333,7 +263,7 @@ def _create_subset_model(
|
||||
fn_description: Optional[str] = None,
|
||||
) -> type[BaseModel]:
|
||||
"""Create subset model using the same pydantic version as the input model."""
|
||||
if IS_PYDANTIC_V1:
|
||||
if issubclass(model, BaseModelV1):
|
||||
return _create_subset_model_v1(
|
||||
name,
|
||||
model,
|
||||
@ -341,68 +271,43 @@ def _create_subset_model(
|
||||
descriptions=descriptions,
|
||||
fn_description=fn_description,
|
||||
)
|
||||
if IS_PYDANTIC_V2:
|
||||
from pydantic.v1 import BaseModel as BaseModelV1
|
||||
|
||||
if issubclass(model, BaseModelV1):
|
||||
return _create_subset_model_v1(
|
||||
name,
|
||||
model,
|
||||
field_names,
|
||||
descriptions=descriptions,
|
||||
fn_description=fn_description,
|
||||
)
|
||||
return _create_subset_model_v2(
|
||||
name,
|
||||
model,
|
||||
field_names,
|
||||
descriptions=descriptions,
|
||||
fn_description=fn_description,
|
||||
)
|
||||
msg = f"Unsupported pydantic version: {PYDANTIC_VERSION.major}"
|
||||
raise NotImplementedError(msg)
|
||||
return _create_subset_model_v2(
|
||||
name,
|
||||
model,
|
||||
field_names,
|
||||
descriptions=descriptions,
|
||||
fn_description=fn_description,
|
||||
)
|
||||
|
||||
|
||||
if IS_PYDANTIC_V2:
|
||||
from pydantic import BaseModel as BaseModelV2
|
||||
from pydantic.v1 import BaseModel as BaseModelV1
|
||||
@overload
|
||||
def get_fields(model: type[BaseModel]) -> dict[str, FieldInfoV2]: ...
|
||||
|
||||
@overload
|
||||
def get_fields(model: type[BaseModelV2]) -> dict[str, FieldInfoV2]: ...
|
||||
|
||||
@overload
|
||||
def get_fields(model: BaseModelV2) -> dict[str, FieldInfoV2]: ...
|
||||
@overload
|
||||
def get_fields(model: BaseModel) -> dict[str, FieldInfoV2]: ...
|
||||
|
||||
@overload
|
||||
def get_fields(model: type[BaseModelV1]) -> dict[str, ModelField]: ...
|
||||
|
||||
@overload
|
||||
def get_fields(model: BaseModelV1) -> dict[str, ModelField]: ...
|
||||
@overload
|
||||
def get_fields(model: type[BaseModelV1]) -> dict[str, ModelField]: ...
|
||||
|
||||
def get_fields(
|
||||
model: Union[type[Union[BaseModelV2, BaseModelV1]], BaseModelV2, BaseModelV1],
|
||||
) -> Union[dict[str, FieldInfoV2], dict[str, ModelField]]:
|
||||
"""Get the field names of a Pydantic model."""
|
||||
if hasattr(model, "model_fields"):
|
||||
return model.model_fields
|
||||
|
||||
if hasattr(model, "__fields__"):
|
||||
return model.__fields__
|
||||
msg = f"Expected a Pydantic model. Got {type(model)}"
|
||||
raise TypeError(msg)
|
||||
@overload
|
||||
def get_fields(model: BaseModelV1) -> dict[str, ModelField]: ...
|
||||
|
||||
elif IS_PYDANTIC_V1:
|
||||
from pydantic import BaseModel as BaseModelV1_
|
||||
|
||||
def get_fields( # type: ignore[no-redef]
|
||||
model: Union[type[BaseModelV1_], BaseModelV1_],
|
||||
) -> dict[str, FieldInfoV1]:
|
||||
"""Get the field names of a Pydantic model."""
|
||||
return model.__fields__ # type: ignore[return-value]
|
||||
def get_fields(
|
||||
model: Union[type[Union[BaseModel, BaseModelV1]], BaseModel, BaseModelV1],
|
||||
) -> Union[dict[str, FieldInfoV2], dict[str, ModelField]]:
|
||||
"""Get the field names of a Pydantic model."""
|
||||
if hasattr(model, "model_fields"):
|
||||
return model.model_fields
|
||||
|
||||
if hasattr(model, "__fields__"):
|
||||
return model.__fields__
|
||||
msg = f"Expected a Pydantic model. Got {type(model)}"
|
||||
raise TypeError(msg)
|
||||
|
||||
else:
|
||||
msg = f"Unsupported Pydantic version: {PYDANTIC_VERSION.major}"
|
||||
raise ValueError(msg)
|
||||
|
||||
_SchemaConfig = ConfigDict(
|
||||
arbitrary_types_allowed=True, frozen=True, protected_namespaces=()
|
||||
@ -546,14 +451,11 @@ _RESERVED_NAMES = {key for key in dir(BaseModel) if not key.startswith("_")}
|
||||
|
||||
def _remap_field_definitions(field_definitions: dict[str, Any]) -> dict[str, Any]:
|
||||
"""This remaps fields to avoid colliding with internal pydantic fields."""
|
||||
from pydantic import Field
|
||||
from pydantic.fields import FieldInfo
|
||||
|
||||
remapped = {}
|
||||
for key, value in field_definitions.items():
|
||||
if key.startswith("_") or key in _RESERVED_NAMES:
|
||||
# Let's add a prefix to avoid colliding with internal pydantic fields
|
||||
if isinstance(value, FieldInfo):
|
||||
if isinstance(value, FieldInfoV2):
|
||||
msg = (
|
||||
f"Remapping for fields starting with '_' or fields with a name "
|
||||
f"matching a reserved name {_RESERVED_NAMES} is not supported if "
|
||||
|
@ -69,7 +69,6 @@ langchain-text-splitters = { path = "../text-splitters" }
|
||||
strict = "True"
|
||||
strict_bytes = "True"
|
||||
enable_error_code = "deprecated"
|
||||
report_deprecated_as_note = "True"
|
||||
|
||||
# TODO: activate for 'strict' checking
|
||||
disallow_any_generics = "False"
|
||||
|
@ -16,10 +16,6 @@ from langchain_core.output_parsers.openai_tools import (
|
||||
PydanticToolsParser,
|
||||
)
|
||||
from langchain_core.outputs import ChatGeneration
|
||||
from langchain_core.utils.pydantic import (
|
||||
IS_PYDANTIC_V1,
|
||||
IS_PYDANTIC_V2,
|
||||
)
|
||||
|
||||
STREAMED_MESSAGES: list = [
|
||||
AIMessageChunk(content=""),
|
||||
@ -532,7 +528,6 @@ async def test_partial_pydantic_output_parser_async() -> None:
|
||||
assert actual == EXPECTED_STREAMED_PYDANTIC
|
||||
|
||||
|
||||
@pytest.mark.skipif(not IS_PYDANTIC_V2, reason="This test is for pydantic 2")
|
||||
def test_parse_with_different_pydantic_2_v1() -> None:
|
||||
"""Test with pydantic.v1.BaseModel from pydantic 2."""
|
||||
import pydantic
|
||||
@ -567,7 +562,6 @@ def test_parse_with_different_pydantic_2_v1() -> None:
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.skipif(not IS_PYDANTIC_V2, reason="This test is for pydantic 2")
|
||||
def test_parse_with_different_pydantic_2_proper() -> None:
|
||||
"""Test with pydantic.BaseModel from pydantic 2."""
|
||||
import pydantic
|
||||
@ -602,41 +596,6 @@ def test_parse_with_different_pydantic_2_proper() -> None:
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.skipif(not IS_PYDANTIC_V1, reason="This test is for pydantic 1")
|
||||
def test_parse_with_different_pydantic_1_proper() -> None:
|
||||
"""Test with pydantic.BaseModel from pydantic 1."""
|
||||
import pydantic
|
||||
|
||||
class Forecast(pydantic.BaseModel):
|
||||
temperature: int
|
||||
forecast: str
|
||||
|
||||
# Can't get pydantic to work here due to the odd typing of tryig to support
|
||||
# both v1 and v2 in the same codebase.
|
||||
parser = PydanticToolsParser(tools=[Forecast])
|
||||
message = AIMessage(
|
||||
content="",
|
||||
tool_calls=[
|
||||
{
|
||||
"id": "call_OwL7f5PE",
|
||||
"name": "Forecast",
|
||||
"args": {"temperature": 20, "forecast": "Sunny"},
|
||||
}
|
||||
],
|
||||
)
|
||||
|
||||
generation = ChatGeneration(
|
||||
message=message,
|
||||
)
|
||||
|
||||
assert parser.parse_result([generation]) == [
|
||||
Forecast(
|
||||
temperature=20,
|
||||
forecast="Sunny",
|
||||
)
|
||||
]
|
||||
|
||||
|
||||
def test_max_tokens_error(caplog: Any) -> None:
|
||||
parser = PydanticToolsParser(tools=[NameCollector], first_tool_only=True)
|
||||
message = AIMessage(
|
||||
|
@ -65,8 +65,6 @@ from langchain_core.utils.function_calling import (
|
||||
convert_to_openai_tool,
|
||||
)
|
||||
from langchain_core.utils.pydantic import (
|
||||
IS_PYDANTIC_V1,
|
||||
IS_PYDANTIC_V2,
|
||||
_create_subset_model,
|
||||
create_model_v2,
|
||||
)
|
||||
@ -79,9 +77,11 @@ def _get_tool_call_json_schema(tool: BaseTool) -> dict:
|
||||
if isinstance(tool_schema, dict):
|
||||
return tool_schema
|
||||
|
||||
if hasattr(tool_schema, "model_json_schema"):
|
||||
if issubclass(tool_schema, BaseModel):
|
||||
return tool_schema.model_json_schema()
|
||||
return tool_schema.schema()
|
||||
if issubclass(tool_schema, BaseModelV1):
|
||||
return tool_schema.schema()
|
||||
return {}
|
||||
|
||||
|
||||
def test_unnamed_decorator() -> None:
|
||||
@ -1853,11 +1853,14 @@ def test_args_schema_as_pydantic(pydantic_model: Any) -> None:
|
||||
)
|
||||
|
||||
input_schema = tool.get_input_schema()
|
||||
input_json_schema = (
|
||||
input_schema.model_json_schema()
|
||||
if hasattr(input_schema, "model_json_schema")
|
||||
else input_schema.schema()
|
||||
)
|
||||
if issubclass(input_schema, BaseModel):
|
||||
input_json_schema = input_schema.model_json_schema()
|
||||
elif issubclass(input_schema, BaseModelV1):
|
||||
input_json_schema = input_schema.schema()
|
||||
else:
|
||||
msg = "Unknown input schema type"
|
||||
raise TypeError(msg)
|
||||
|
||||
assert input_json_schema == {
|
||||
"properties": {
|
||||
"a": {"title": "A", "type": "integer"},
|
||||
@ -1943,12 +1946,14 @@ def test_structured_tool_with_different_pydantic_versions(pydantic_model: Any) -
|
||||
|
||||
assert foo_tool.invoke({"a": 5, "b": "hello"}) == "foo"
|
||||
|
||||
args_schema = cast("BaseModel", foo_tool.args_schema)
|
||||
args_json_schema = (
|
||||
args_schema.model_json_schema()
|
||||
if hasattr(args_schema, "model_json_schema")
|
||||
else args_schema.schema()
|
||||
)
|
||||
args_schema = cast("type[BaseModel]", foo_tool.args_schema)
|
||||
if issubclass(args_schema, BaseModel):
|
||||
args_json_schema = args_schema.model_json_schema()
|
||||
elif issubclass(args_schema, BaseModelV1):
|
||||
args_json_schema = args_schema.schema()
|
||||
else:
|
||||
msg = "Unknown input schema type"
|
||||
raise TypeError(msg)
|
||||
assert args_json_schema == {
|
||||
"properties": {
|
||||
"a": {"title": "A", "type": "integer"},
|
||||
@ -1960,11 +1965,13 @@ def test_structured_tool_with_different_pydantic_versions(pydantic_model: Any) -
|
||||
}
|
||||
|
||||
input_schema = foo_tool.get_input_schema()
|
||||
input_json_schema = (
|
||||
input_schema.model_json_schema()
|
||||
if hasattr(input_schema, "model_json_schema")
|
||||
else input_schema.schema()
|
||||
)
|
||||
if issubclass(input_schema, BaseModel):
|
||||
input_json_schema = input_schema.model_json_schema()
|
||||
elif issubclass(input_schema, BaseModelV1):
|
||||
input_json_schema = input_schema.schema()
|
||||
else:
|
||||
msg = "Unknown input schema type"
|
||||
raise TypeError(msg)
|
||||
assert input_json_schema == {
|
||||
"properties": {
|
||||
"a": {"title": "A", "type": "integer"},
|
||||
@ -2020,7 +2027,6 @@ def test__is_message_content_type(obj: Any, *, expected: bool) -> None:
|
||||
assert _is_message_content_type(obj) is expected
|
||||
|
||||
|
||||
@pytest.mark.skipif(not IS_PYDANTIC_V2, reason="Testing pydantic v2.")
|
||||
@pytest.mark.parametrize("use_v1_namespace", [True, False])
|
||||
def test__get_all_basemodel_annotations_v2(*, use_v1_namespace: bool) -> None:
|
||||
A = TypeVar("A")
|
||||
@ -2089,63 +2095,6 @@ def test__get_all_basemodel_annotations_v2(*, use_v1_namespace: bool) -> None:
|
||||
assert actual == expected
|
||||
|
||||
|
||||
@pytest.mark.skipif(not IS_PYDANTIC_V1, reason="Testing pydantic v1.")
|
||||
def test__get_all_basemodel_annotations_v1() -> None:
|
||||
A = TypeVar("A")
|
||||
|
||||
class ModelA(BaseModel, Generic[A], extra="allow"):
|
||||
a: A
|
||||
|
||||
class ModelB(ModelA[str]):
|
||||
b: Annotated[ModelA[dict[str, Any]], "foo"]
|
||||
|
||||
class Mixin:
|
||||
def foo(self) -> str:
|
||||
return "foo"
|
||||
|
||||
class ModelC(Mixin, ModelB):
|
||||
c: dict
|
||||
|
||||
expected = {"a": str, "b": Annotated[ModelA[dict[str, Any]], "foo"], "c": dict}
|
||||
actual = get_all_basemodel_annotations(ModelC)
|
||||
assert actual == expected
|
||||
|
||||
expected = {"a": str, "b": Annotated[ModelA[dict[str, Any]], "foo"]}
|
||||
actual = get_all_basemodel_annotations(ModelB)
|
||||
assert actual == expected
|
||||
|
||||
expected = {"a": Any}
|
||||
actual = get_all_basemodel_annotations(ModelA)
|
||||
assert actual == expected
|
||||
|
||||
expected = {"a": int}
|
||||
actual = get_all_basemodel_annotations(ModelA[int])
|
||||
assert actual == expected
|
||||
|
||||
D = TypeVar("D", bound=Union[str, int])
|
||||
|
||||
class ModelD(ModelC, Generic[D]):
|
||||
d: Optional[D]
|
||||
|
||||
expected = {
|
||||
"a": str,
|
||||
"b": Annotated[ModelA[dict[str, Any]], "foo"],
|
||||
"c": dict,
|
||||
"d": Union[str, int, None],
|
||||
}
|
||||
actual = get_all_basemodel_annotations(ModelD)
|
||||
assert actual == expected
|
||||
|
||||
expected = {
|
||||
"a": str,
|
||||
"b": Annotated[ModelA[dict[str, Any]], "foo"],
|
||||
"c": dict,
|
||||
"d": Union[int, None],
|
||||
}
|
||||
actual = get_all_basemodel_annotations(ModelD[int])
|
||||
assert actual == expected
|
||||
|
||||
|
||||
def test_get_all_basemodel_annotations_aliases() -> None:
|
||||
class CalculatorInput(BaseModel):
|
||||
a: int = Field(description="first number", alias="A")
|
||||
@ -2226,7 +2175,6 @@ def test_create_retriever_tool() -> None:
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.skipif(not IS_PYDANTIC_V2, reason="Testing pydantic v2.")
|
||||
def test_tool_args_schema_pydantic_v2_with_metadata() -> None:
|
||||
from pydantic import BaseModel as BaseModelV2
|
||||
from pydantic import Field as FieldV2
|
||||
|
@ -3,13 +3,9 @@
|
||||
import warnings
|
||||
from typing import Any, Optional
|
||||
|
||||
import pytest
|
||||
from pydantic import ConfigDict
|
||||
|
||||
from langchain_core.utils.pydantic import (
|
||||
IS_PYDANTIC_V1,
|
||||
IS_PYDANTIC_V2,
|
||||
PYDANTIC_VERSION,
|
||||
_create_subset_model_v2,
|
||||
create_model_v2,
|
||||
get_fields,
|
||||
@ -96,50 +92,29 @@ def test_with_aliases() -> None:
|
||||
|
||||
def test_is_basemodel_subclass() -> None:
|
||||
"""Test pydantic."""
|
||||
if IS_PYDANTIC_V1:
|
||||
from pydantic import BaseModel as BaseModelV1Proper
|
||||
from pydantic import BaseModel as BaseModelV2
|
||||
from pydantic.v1 import BaseModel as BaseModelV1
|
||||
|
||||
assert is_basemodel_subclass(BaseModelV1Proper)
|
||||
elif IS_PYDANTIC_V2:
|
||||
from pydantic import BaseModel as BaseModelV2
|
||||
from pydantic.v1 import BaseModel as BaseModelV1
|
||||
|
||||
assert is_basemodel_subclass(BaseModelV2)
|
||||
|
||||
assert is_basemodel_subclass(BaseModelV1)
|
||||
else:
|
||||
msg = f"Unsupported Pydantic version: {PYDANTIC_VERSION.major}"
|
||||
raise ValueError(msg)
|
||||
assert is_basemodel_subclass(BaseModelV2)
|
||||
assert is_basemodel_subclass(BaseModelV1)
|
||||
|
||||
|
||||
def test_is_basemodel_instance() -> None:
|
||||
"""Test pydantic."""
|
||||
if IS_PYDANTIC_V1:
|
||||
from pydantic import BaseModel as BaseModelV1Proper
|
||||
from pydantic import BaseModel as BaseModelV2
|
||||
from pydantic.v1 import BaseModel as BaseModelV1
|
||||
|
||||
class FooV1(BaseModelV1Proper):
|
||||
x: int
|
||||
class Foo(BaseModelV2):
|
||||
x: int
|
||||
|
||||
assert is_basemodel_instance(FooV1(x=5))
|
||||
elif IS_PYDANTIC_V2:
|
||||
from pydantic import BaseModel as BaseModelV2
|
||||
from pydantic.v1 import BaseModel as BaseModelV1
|
||||
assert is_basemodel_instance(Foo(x=5))
|
||||
|
||||
class Foo(BaseModelV2):
|
||||
x: int
|
||||
class Bar(BaseModelV1):
|
||||
x: int
|
||||
|
||||
assert is_basemodel_instance(Foo(x=5))
|
||||
|
||||
class Bar(BaseModelV1):
|
||||
x: int
|
||||
|
||||
assert is_basemodel_instance(Bar(x=5))
|
||||
else:
|
||||
msg = f"Unsupported Pydantic version: {PYDANTIC_VERSION.major}"
|
||||
raise ValueError(msg)
|
||||
assert is_basemodel_instance(Bar(x=5))
|
||||
|
||||
|
||||
@pytest.mark.skipif(not IS_PYDANTIC_V2, reason="Only tests Pydantic v2")
|
||||
def test_with_field_metadata() -> None:
|
||||
"""Test pydantic with field metadata."""
|
||||
from pydantic import BaseModel as BaseModelV2
|
||||
@ -168,18 +143,6 @@ def test_with_field_metadata() -> None:
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.skipif(not IS_PYDANTIC_V1, reason="Only tests Pydantic v1")
|
||||
def test_fields_pydantic_v1() -> None:
|
||||
from pydantic import BaseModel
|
||||
|
||||
class Foo(BaseModel):
|
||||
x: int
|
||||
|
||||
fields = get_fields(Foo)
|
||||
assert fields == {"x": Foo.model_fields["x"]}
|
||||
|
||||
|
||||
@pytest.mark.skipif(not IS_PYDANTIC_V2, reason="Only tests Pydantic v2")
|
||||
def test_fields_pydantic_v2_proper() -> None:
|
||||
from pydantic import BaseModel
|
||||
|
||||
@ -190,7 +153,6 @@ def test_fields_pydantic_v2_proper() -> None:
|
||||
assert fields == {"x": Foo.model_fields["x"]}
|
||||
|
||||
|
||||
@pytest.mark.skipif(not IS_PYDANTIC_V2, reason="Only tests Pydantic v2")
|
||||
def test_fields_pydantic_v1_from_2() -> None:
|
||||
from pydantic.v1 import BaseModel
|
||||
|
||||
|
@ -16,10 +16,6 @@ from langchain_core.utils import (
|
||||
guard_import,
|
||||
)
|
||||
from langchain_core.utils._merge import merge_dicts
|
||||
from langchain_core.utils.pydantic import (
|
||||
IS_PYDANTIC_V1,
|
||||
IS_PYDANTIC_V2,
|
||||
)
|
||||
from langchain_core.utils.utils import secret_from_env
|
||||
|
||||
|
||||
@ -214,7 +210,6 @@ def test_guard_import_failure(
|
||||
guard_import(module_name, pip_name=pip_name, package=package)
|
||||
|
||||
|
||||
@pytest.mark.skipif(not IS_PYDANTIC_V2, reason="Requires pydantic 2")
|
||||
def test_get_pydantic_field_names_v1_in_2() -> None:
|
||||
from pydantic.v1 import BaseModel as PydanticV1BaseModel
|
||||
from pydantic.v1 import Field
|
||||
@ -229,7 +224,6 @@ def test_get_pydantic_field_names_v1_in_2() -> None:
|
||||
assert result == expected
|
||||
|
||||
|
||||
@pytest.mark.skipif(not IS_PYDANTIC_V2, reason="Requires pydantic 2")
|
||||
def test_get_pydantic_field_names_v2_in_2() -> None:
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
@ -243,20 +237,6 @@ def test_get_pydantic_field_names_v2_in_2() -> None:
|
||||
assert result == expected
|
||||
|
||||
|
||||
@pytest.mark.skipif(not IS_PYDANTIC_V1, reason="Requires pydantic 1")
|
||||
def test_get_pydantic_field_names_v1() -> None:
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
class PydanticModel(BaseModel):
|
||||
field1: str
|
||||
field2: int
|
||||
alias_field: int = Field(alias="aliased_field")
|
||||
|
||||
result = get_pydantic_field_names(PydanticModel)
|
||||
expected = {"field1", "field2", "aliased_field", "alias_field"}
|
||||
assert result == expected
|
||||
|
||||
|
||||
def test_from_env_with_env_variable() -> None:
|
||||
key = "TEST_KEY"
|
||||
value = "test_value"
|
||||
|
Loading…
Reference in New Issue
Block a user