mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-01 10:54:15 +00:00
core: Cleanup Pydantic models and handle deprecation warnings (#30799)
* Simplified Pydantic handling since Pydantic v1 is not supported anymore. * Replace use of deprecated v1 methods by corresponding v2 methods. * Remove use of other deprecated methods. * Activate mypy errors on deprecated methods use. --------- Co-authored-by: Eugene Yurtsev <eyurtsev@gmail.com>
This commit is contained in:
parent
29e17fbd6b
commit
7e046ea848
@ -23,6 +23,8 @@ from typing import (
|
|||||||
cast,
|
cast,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
from pydantic.fields import FieldInfo
|
||||||
|
from pydantic.v1.fields import FieldInfo as FieldInfoV1
|
||||||
from typing_extensions import ParamSpec
|
from typing_extensions import ParamSpec
|
||||||
|
|
||||||
from langchain_core._api.internal import is_caller_internal
|
from langchain_core._api.internal import is_caller_internal
|
||||||
@ -152,10 +154,6 @@ def deprecated(
|
|||||||
_package: str = package,
|
_package: str = package,
|
||||||
) -> T:
|
) -> T:
|
||||||
"""Implementation of the decorator returned by `deprecated`."""
|
"""Implementation of the decorator returned by `deprecated`."""
|
||||||
from langchain_core.utils.pydantic import ( # type: ignore[attr-defined]
|
|
||||||
FieldInfoV1,
|
|
||||||
FieldInfoV2,
|
|
||||||
)
|
|
||||||
|
|
||||||
def emit_warning() -> None:
|
def emit_warning() -> None:
|
||||||
"""Emit the warning."""
|
"""Emit the warning."""
|
||||||
@ -249,7 +247,7 @@ def deprecated(
|
|||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
elif isinstance(obj, FieldInfoV2):
|
elif isinstance(obj, FieldInfo):
|
||||||
wrapped = None
|
wrapped = None
|
||||||
if not _obj_type:
|
if not _obj_type:
|
||||||
_obj_type = "attribute"
|
_obj_type = "attribute"
|
||||||
@ -261,7 +259,7 @@ def deprecated(
|
|||||||
def finalize(wrapper: Callable[..., Any], new_doc: str) -> T: # noqa: ARG001
|
def finalize(wrapper: Callable[..., Any], new_doc: str) -> T: # noqa: ARG001
|
||||||
return cast(
|
return cast(
|
||||||
"T",
|
"T",
|
||||||
FieldInfoV2(
|
FieldInfo(
|
||||||
default=obj.default,
|
default=obj.default,
|
||||||
default_factory=obj.default_factory,
|
default_factory=obj.default_factory,
|
||||||
description=new_doc,
|
description=new_doc,
|
||||||
|
@ -326,7 +326,7 @@ class BaseOutputParser(
|
|||||||
|
|
||||||
def dict(self, **kwargs: Any) -> dict:
|
def dict(self, **kwargs: Any) -> dict:
|
||||||
"""Return dictionary representation of output parser."""
|
"""Return dictionary representation of output parser."""
|
||||||
output_parser_dict = super().dict(**kwargs)
|
output_parser_dict = super().model_dump(**kwargs)
|
||||||
with contextlib.suppress(NotImplementedError):
|
with contextlib.suppress(NotImplementedError):
|
||||||
output_parser_dict["_type"] = self._type
|
output_parser_dict["_type"] = self._type
|
||||||
return output_parser_dict
|
return output_parser_dict
|
||||||
|
@ -9,6 +9,7 @@ from typing import Annotated, Any, Optional, TypeVar, Union
|
|||||||
import jsonpatch # type: ignore[import-untyped]
|
import jsonpatch # type: ignore[import-untyped]
|
||||||
import pydantic
|
import pydantic
|
||||||
from pydantic import SkipValidation
|
from pydantic import SkipValidation
|
||||||
|
from pydantic.v1 import BaseModel
|
||||||
from typing_extensions import override
|
from typing_extensions import override
|
||||||
|
|
||||||
from langchain_core.exceptions import OutputParserException
|
from langchain_core.exceptions import OutputParserException
|
||||||
@ -20,16 +21,9 @@ from langchain_core.utils.json import (
|
|||||||
parse_json_markdown,
|
parse_json_markdown,
|
||||||
parse_partial_json,
|
parse_partial_json,
|
||||||
)
|
)
|
||||||
from langchain_core.utils.pydantic import IS_PYDANTIC_V1
|
|
||||||
|
|
||||||
if IS_PYDANTIC_V1:
|
# Union type needs to be last assignment to PydanticBaseModel to make mypy happy.
|
||||||
PydanticBaseModel = pydantic.BaseModel
|
PydanticBaseModel = Union[BaseModel, pydantic.BaseModel]
|
||||||
|
|
||||||
else:
|
|
||||||
from pydantic.v1 import BaseModel
|
|
||||||
|
|
||||||
# Union type needs to be last assignment to PydanticBaseModel to make mypy happy.
|
|
||||||
PydanticBaseModel = Union[BaseModel, pydantic.BaseModel] # type: ignore[assignment,misc]
|
|
||||||
|
|
||||||
TBaseModel = TypeVar("TBaseModel", bound=PydanticBaseModel)
|
TBaseModel = TypeVar("TBaseModel", bound=PydanticBaseModel)
|
||||||
|
|
||||||
|
@ -7,6 +7,7 @@ from typing import Any, Optional, Union
|
|||||||
|
|
||||||
import jsonpatch # type: ignore[import-untyped]
|
import jsonpatch # type: ignore[import-untyped]
|
||||||
from pydantic import BaseModel, model_validator
|
from pydantic import BaseModel, model_validator
|
||||||
|
from pydantic.v1 import BaseModel as BaseModelV1
|
||||||
from typing_extensions import override
|
from typing_extensions import override
|
||||||
|
|
||||||
from langchain_core.exceptions import OutputParserException
|
from langchain_core.exceptions import OutputParserException
|
||||||
@ -275,10 +276,13 @@ class PydanticOutputFunctionsParser(OutputFunctionsParser):
|
|||||||
pydantic_schema = self.pydantic_schema[fn_name]
|
pydantic_schema = self.pydantic_schema[fn_name]
|
||||||
else:
|
else:
|
||||||
pydantic_schema = self.pydantic_schema
|
pydantic_schema = self.pydantic_schema
|
||||||
if hasattr(pydantic_schema, "model_validate_json"):
|
if issubclass(pydantic_schema, BaseModel):
|
||||||
pydantic_args = pydantic_schema.model_validate_json(_args)
|
pydantic_args = pydantic_schema.model_validate_json(_args)
|
||||||
else:
|
elif issubclass(pydantic_schema, BaseModelV1):
|
||||||
pydantic_args = pydantic_schema.parse_raw(_args)
|
pydantic_args = pydantic_schema.parse_raw(_args)
|
||||||
|
else:
|
||||||
|
msg = f"Unsupported pydantic schema: {pydantic_schema}"
|
||||||
|
raise ValueError(msg)
|
||||||
return pydantic_args
|
return pydantic_args
|
||||||
|
|
||||||
|
|
||||||
|
@ -11,7 +11,6 @@ from langchain_core.exceptions import OutputParserException
|
|||||||
from langchain_core.output_parsers import JsonOutputParser
|
from langchain_core.output_parsers import JsonOutputParser
|
||||||
from langchain_core.outputs import Generation
|
from langchain_core.outputs import Generation
|
||||||
from langchain_core.utils.pydantic import (
|
from langchain_core.utils.pydantic import (
|
||||||
IS_PYDANTIC_V2,
|
|
||||||
PydanticBaseModel,
|
PydanticBaseModel,
|
||||||
TBaseModel,
|
TBaseModel,
|
||||||
)
|
)
|
||||||
@ -24,22 +23,16 @@ class PydanticOutputParser(JsonOutputParser, Generic[TBaseModel]):
|
|||||||
"""The pydantic model to parse."""
|
"""The pydantic model to parse."""
|
||||||
|
|
||||||
def _parse_obj(self, obj: dict) -> TBaseModel:
|
def _parse_obj(self, obj: dict) -> TBaseModel:
|
||||||
if IS_PYDANTIC_V2:
|
try:
|
||||||
try:
|
if issubclass(self.pydantic_object, pydantic.BaseModel):
|
||||||
if issubclass(self.pydantic_object, pydantic.BaseModel):
|
return self.pydantic_object.model_validate(obj)
|
||||||
return self.pydantic_object.model_validate(obj)
|
if issubclass(self.pydantic_object, pydantic.v1.BaseModel):
|
||||||
if issubclass(self.pydantic_object, pydantic.v1.BaseModel):
|
|
||||||
return self.pydantic_object.parse_obj(obj)
|
|
||||||
msg = f"Unsupported model version for PydanticOutputParser: \
|
|
||||||
{self.pydantic_object.__class__}"
|
|
||||||
raise OutputParserException(msg)
|
|
||||||
except (pydantic.ValidationError, pydantic.v1.ValidationError) as e:
|
|
||||||
raise self._parser_exception(e, obj) from e
|
|
||||||
else: # pydantic v1
|
|
||||||
try:
|
|
||||||
return self.pydantic_object.parse_obj(obj)
|
return self.pydantic_object.parse_obj(obj)
|
||||||
except pydantic.ValidationError as e:
|
msg = f"Unsupported model version for PydanticOutputParser: \
|
||||||
raise self._parser_exception(e, obj) from e
|
{self.pydantic_object.__class__}"
|
||||||
|
raise OutputParserException(msg)
|
||||||
|
except (pydantic.ValidationError, pydantic.v1.ValidationError) as e:
|
||||||
|
raise self._parser_exception(e, obj) from e
|
||||||
|
|
||||||
def _parser_exception(
|
def _parser_exception(
|
||||||
self, e: Exception, json_object: dict
|
self, e: Exception, json_object: dict
|
||||||
|
@ -134,7 +134,7 @@ class BaseCumulativeTransformOutputParser(BaseTransformOutputParser[T]):
|
|||||||
chunk_gen = ChatGenerationChunk(message=chunk)
|
chunk_gen = ChatGenerationChunk(message=chunk)
|
||||||
elif isinstance(chunk, BaseMessage):
|
elif isinstance(chunk, BaseMessage):
|
||||||
chunk_gen = ChatGenerationChunk(
|
chunk_gen = ChatGenerationChunk(
|
||||||
message=BaseMessageChunk(**chunk.dict())
|
message=BaseMessageChunk(**chunk.model_dump())
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
chunk_gen = GenerationChunk(text=chunk)
|
chunk_gen = GenerationChunk(text=chunk)
|
||||||
@ -161,7 +161,7 @@ class BaseCumulativeTransformOutputParser(BaseTransformOutputParser[T]):
|
|||||||
chunk_gen = ChatGenerationChunk(message=chunk)
|
chunk_gen = ChatGenerationChunk(message=chunk)
|
||||||
elif isinstance(chunk, BaseMessage):
|
elif isinstance(chunk, BaseMessage):
|
||||||
chunk_gen = ChatGenerationChunk(
|
chunk_gen = ChatGenerationChunk(
|
||||||
message=BaseMessageChunk(**chunk.dict())
|
message=BaseMessageChunk(**chunk.model_dump())
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
chunk_gen = GenerationChunk(text=chunk)
|
chunk_gen = GenerationChunk(text=chunk)
|
||||||
|
@ -2,25 +2,10 @@
|
|||||||
|
|
||||||
from importlib import metadata
|
from importlib import metadata
|
||||||
|
|
||||||
|
from pydantic.v1 import * # noqa: F403
|
||||||
|
|
||||||
from langchain_core._api.deprecation import warn_deprecated
|
from langchain_core._api.deprecation import warn_deprecated
|
||||||
|
|
||||||
# Create namespaces for pydantic v1 and v2.
|
|
||||||
# This code must stay at the top of the file before other modules may
|
|
||||||
# attempt to import pydantic since it adds pydantic_v1 and pydantic_v2 to sys.modules.
|
|
||||||
#
|
|
||||||
# This hack is done for the following reasons:
|
|
||||||
# * Langchain will attempt to remain compatible with both pydantic v1 and v2 since
|
|
||||||
# both dependencies and dependents may be stuck on either version of v1 or v2.
|
|
||||||
# * Creating namespaces for pydantic v1 and v2 should allow us to write code that
|
|
||||||
# unambiguously uses either v1 or v2 API.
|
|
||||||
# * This change is easier to roll out and roll back.
|
|
||||||
|
|
||||||
try:
|
|
||||||
from pydantic.v1 import * # noqa: F403
|
|
||||||
except ImportError:
|
|
||||||
from pydantic import * # type: ignore[assignment,no-redef] # noqa: F403
|
|
||||||
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
_PYDANTIC_MAJOR_VERSION: int = int(metadata.version("pydantic").split(".")[0])
|
_PYDANTIC_MAJOR_VERSION: int = int(metadata.version("pydantic").split(".")[0])
|
||||||
except metadata.PackageNotFoundError:
|
except metadata.PackageNotFoundError:
|
||||||
|
@ -1,11 +1,8 @@
|
|||||||
"""Pydantic v1 compatibility shim."""
|
"""Pydantic v1 compatibility shim."""
|
||||||
|
|
||||||
from langchain_core._api import warn_deprecated
|
from pydantic.v1.dataclasses import * # noqa: F403
|
||||||
|
|
||||||
try:
|
from langchain_core._api import warn_deprecated
|
||||||
from pydantic.v1.dataclasses import * # noqa: F403
|
|
||||||
except ImportError:
|
|
||||||
from pydantic.dataclasses import * # type: ignore[no-redef] # noqa: F403
|
|
||||||
|
|
||||||
warn_deprecated(
|
warn_deprecated(
|
||||||
"0.3.0",
|
"0.3.0",
|
||||||
|
@ -1,11 +1,8 @@
|
|||||||
"""Pydantic v1 compatibility shim."""
|
"""Pydantic v1 compatibility shim."""
|
||||||
|
|
||||||
from langchain_core._api import warn_deprecated
|
from pydantic.v1.main import * # noqa: F403
|
||||||
|
|
||||||
try:
|
from langchain_core._api import warn_deprecated
|
||||||
from pydantic.v1.main import * # noqa: F403
|
|
||||||
except ImportError:
|
|
||||||
from pydantic.main import * # type: ignore[assignment,no-redef] # noqa: F403
|
|
||||||
|
|
||||||
warn_deprecated(
|
warn_deprecated(
|
||||||
"0.3.0",
|
"0.3.0",
|
||||||
|
@ -540,10 +540,13 @@ class ChildTool(BaseTool):
|
|||||||
)
|
)
|
||||||
raise ValueError(msg)
|
raise ValueError(msg)
|
||||||
key_ = next(iter(get_fields(input_args).keys()))
|
key_ = next(iter(get_fields(input_args).keys()))
|
||||||
if hasattr(input_args, "model_validate"):
|
if issubclass(input_args, BaseModel):
|
||||||
input_args.model_validate({key_: tool_input})
|
input_args.model_validate({key_: tool_input})
|
||||||
else:
|
elif issubclass(input_args, BaseModelV1):
|
||||||
input_args.parse_obj({key_: tool_input})
|
input_args.parse_obj({key_: tool_input})
|
||||||
|
else:
|
||||||
|
msg = f"args_schema must be a Pydantic BaseModel, got {input_args}"
|
||||||
|
raise TypeError(msg)
|
||||||
return tool_input
|
return tool_input
|
||||||
if input_args is not None:
|
if input_args is not None:
|
||||||
if isinstance(input_args, dict):
|
if isinstance(input_args, dict):
|
||||||
|
@ -2,8 +2,8 @@
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import datetime
|
|
||||||
import warnings
|
import warnings
|
||||||
|
from datetime import datetime, timezone
|
||||||
from typing import Any, Optional
|
from typing import Any, Optional
|
||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
|
|
||||||
@ -32,7 +32,7 @@ def RunTypeEnum() -> type[RunTypeEnumDep]: # noqa: N802
|
|||||||
class TracerSessionV1Base(BaseModelV1):
|
class TracerSessionV1Base(BaseModelV1):
|
||||||
"""Base class for TracerSessionV1."""
|
"""Base class for TracerSessionV1."""
|
||||||
|
|
||||||
start_time: datetime.datetime = FieldV1(default_factory=datetime.datetime.utcnow)
|
start_time: datetime = FieldV1(default_factory=lambda: datetime.now(timezone.utc))
|
||||||
name: Optional[str] = None
|
name: Optional[str] = None
|
||||||
extra: Optional[dict[str, Any]] = None
|
extra: Optional[dict[str, Any]] = None
|
||||||
|
|
||||||
@ -69,8 +69,8 @@ class BaseRun(BaseModelV1):
|
|||||||
|
|
||||||
uuid: str
|
uuid: str
|
||||||
parent_uuid: Optional[str] = None
|
parent_uuid: Optional[str] = None
|
||||||
start_time: datetime.datetime = FieldV1(default_factory=datetime.datetime.utcnow)
|
start_time: datetime = FieldV1(default_factory=lambda: datetime.now(timezone.utc))
|
||||||
end_time: datetime.datetime = FieldV1(default_factory=datetime.datetime.utcnow)
|
end_time: datetime = FieldV1(default_factory=lambda: datetime.now(timezone.utc))
|
||||||
extra: Optional[dict[str, Any]] = None
|
extra: Optional[dict[str, Any]] = None
|
||||||
execution_order: int
|
execution_order: int
|
||||||
child_execution_order: int
|
child_execution_order: int
|
||||||
|
@ -21,9 +21,12 @@ from typing import (
|
|||||||
|
|
||||||
import pydantic
|
import pydantic
|
||||||
from packaging import version
|
from packaging import version
|
||||||
from pydantic import (
|
|
||||||
|
# root_validator is deprecated but we need it for backward compatibility of @pre_init
|
||||||
|
from pydantic import ( # type: ignore[deprecated]
|
||||||
BaseModel,
|
BaseModel,
|
||||||
ConfigDict,
|
ConfigDict,
|
||||||
|
Field,
|
||||||
PydanticDeprecationWarning,
|
PydanticDeprecationWarning,
|
||||||
RootModel,
|
RootModel,
|
||||||
root_validator,
|
root_validator,
|
||||||
@ -38,29 +41,23 @@ from pydantic.json_schema import (
|
|||||||
JsonSchemaMode,
|
JsonSchemaMode,
|
||||||
JsonSchemaValue,
|
JsonSchemaValue,
|
||||||
)
|
)
|
||||||
from typing_extensions import override
|
from pydantic.v1 import BaseModel as BaseModelV1
|
||||||
|
from pydantic.v1 import create_model as create_model_v1
|
||||||
|
from pydantic.v1.fields import ModelField
|
||||||
|
from typing_extensions import deprecated, override
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from pydantic_core import core_schema
|
from pydantic_core import core_schema
|
||||||
|
|
||||||
try:
|
PYDANTIC_VERSION = version.parse(pydantic.__version__)
|
||||||
import pydantic
|
|
||||||
|
|
||||||
PYDANTIC_VERSION = version.parse(pydantic.__version__)
|
|
||||||
except ImportError:
|
|
||||||
PYDANTIC_VERSION = version.parse("0.0.0")
|
|
||||||
|
|
||||||
|
|
||||||
|
@deprecated("Use PYDANTIC_VERSION.major instead.")
|
||||||
def get_pydantic_major_version() -> int:
|
def get_pydantic_major_version() -> int:
|
||||||
"""DEPRECATED - Get the major version of Pydantic.
|
"""DEPRECATED - Get the major version of Pydantic.
|
||||||
|
|
||||||
Use PYDANTIC_VERSION.major instead.
|
Use PYDANTIC_VERSION.major instead.
|
||||||
"""
|
"""
|
||||||
warnings.warn(
|
|
||||||
"get_pydantic_major_version is deprecated. Use PYDANTIC_VERSION.major instead.",
|
|
||||||
DeprecationWarning,
|
|
||||||
stacklevel=2,
|
|
||||||
)
|
|
||||||
return PYDANTIC_VERSION.major
|
return PYDANTIC_VERSION.major
|
||||||
|
|
||||||
|
|
||||||
@ -70,43 +67,20 @@ PYDANTIC_MINOR_VERSION = PYDANTIC_VERSION.minor
|
|||||||
IS_PYDANTIC_V1 = PYDANTIC_VERSION.major == 1
|
IS_PYDANTIC_V1 = PYDANTIC_VERSION.major == 1
|
||||||
IS_PYDANTIC_V2 = PYDANTIC_VERSION.major == 2
|
IS_PYDANTIC_V2 = PYDANTIC_VERSION.major == 2
|
||||||
|
|
||||||
if IS_PYDANTIC_V1:
|
PydanticBaseModel = BaseModel
|
||||||
from pydantic.fields import FieldInfo as FieldInfoV1
|
TypeBaseModel = type[BaseModel]
|
||||||
|
|
||||||
PydanticBaseModel = pydantic.BaseModel
|
|
||||||
TypeBaseModel = type[BaseModel]
|
|
||||||
elif IS_PYDANTIC_V2:
|
|
||||||
from pydantic.v1.fields import FieldInfo as FieldInfoV1 # type: ignore[assignment]
|
|
||||||
from pydantic.v1.fields import ModelField
|
|
||||||
|
|
||||||
# Union type needs to be last assignment to PydanticBaseModel to make mypy happy.
|
|
||||||
PydanticBaseModel = Union[BaseModel, pydantic.BaseModel] # type: ignore[assignment,misc]
|
|
||||||
TypeBaseModel = Union[type[BaseModel], type[pydantic.BaseModel]] # type: ignore[misc]
|
|
||||||
else:
|
|
||||||
msg = f"Unsupported Pydantic version: {PYDANTIC_VERSION.major}"
|
|
||||||
raise ValueError(msg)
|
|
||||||
|
|
||||||
|
|
||||||
TBaseModel = TypeVar("TBaseModel", bound=PydanticBaseModel)
|
TBaseModel = TypeVar("TBaseModel", bound=PydanticBaseModel)
|
||||||
|
|
||||||
|
|
||||||
def is_pydantic_v1_subclass(cls: type) -> bool:
|
def is_pydantic_v1_subclass(cls: type) -> bool:
|
||||||
"""Check if the installed Pydantic version is 1.x-like."""
|
"""Check if the installed Pydantic version is 1.x-like."""
|
||||||
if IS_PYDANTIC_V1:
|
return issubclass(cls, BaseModelV1)
|
||||||
return True
|
|
||||||
if IS_PYDANTIC_V2:
|
|
||||||
from pydantic.v1 import BaseModel as BaseModelV1
|
|
||||||
|
|
||||||
if issubclass(cls, BaseModelV1):
|
|
||||||
return True
|
|
||||||
return False
|
|
||||||
|
|
||||||
|
|
||||||
def is_pydantic_v2_subclass(cls: type) -> bool:
|
def is_pydantic_v2_subclass(cls: type) -> bool:
|
||||||
"""Check if the installed Pydantic version is 1.x-like."""
|
"""Check if the installed Pydantic version is 1.x-like."""
|
||||||
from pydantic import BaseModel
|
return issubclass(cls, BaseModel)
|
||||||
|
|
||||||
return IS_PYDANTIC_V2 and issubclass(cls, BaseModel)
|
|
||||||
|
|
||||||
|
|
||||||
def is_basemodel_subclass(cls: type) -> bool:
|
def is_basemodel_subclass(cls: type) -> bool:
|
||||||
@ -114,7 +88,6 @@ def is_basemodel_subclass(cls: type) -> bool:
|
|||||||
|
|
||||||
Check if the given class is a subclass of any of the following:
|
Check if the given class is a subclass of any of the following:
|
||||||
|
|
||||||
* pydantic.BaseModel in Pydantic 1.x
|
|
||||||
* pydantic.BaseModel in Pydantic 2.x
|
* pydantic.BaseModel in Pydantic 2.x
|
||||||
* pydantic.v1.BaseModel in Pydantic 2.x
|
* pydantic.v1.BaseModel in Pydantic 2.x
|
||||||
"""
|
"""
|
||||||
@ -122,24 +95,7 @@ def is_basemodel_subclass(cls: type) -> bool:
|
|||||||
if not inspect.isclass(cls) or isinstance(cls, GenericAlias):
|
if not inspect.isclass(cls) or isinstance(cls, GenericAlias):
|
||||||
return False
|
return False
|
||||||
|
|
||||||
if IS_PYDANTIC_V1:
|
return issubclass(cls, (BaseModel, BaseModelV1))
|
||||||
from pydantic import BaseModel as BaseModelV1Proper
|
|
||||||
|
|
||||||
if issubclass(cls, BaseModelV1Proper):
|
|
||||||
return True
|
|
||||||
elif IS_PYDANTIC_V2:
|
|
||||||
from pydantic import BaseModel as BaseModelV2
|
|
||||||
from pydantic.v1 import BaseModel as BaseModelV1
|
|
||||||
|
|
||||||
if issubclass(cls, BaseModelV2):
|
|
||||||
return True
|
|
||||||
|
|
||||||
if issubclass(cls, BaseModelV1):
|
|
||||||
return True
|
|
||||||
else:
|
|
||||||
msg = f"Unsupported Pydantic version: {PYDANTIC_VERSION.major}"
|
|
||||||
raise ValueError(msg)
|
|
||||||
return False
|
|
||||||
|
|
||||||
|
|
||||||
def is_basemodel_instance(obj: Any) -> bool:
|
def is_basemodel_instance(obj: Any) -> bool:
|
||||||
@ -147,28 +103,10 @@ def is_basemodel_instance(obj: Any) -> bool:
|
|||||||
|
|
||||||
Check if the given class is an instance of any of the following:
|
Check if the given class is an instance of any of the following:
|
||||||
|
|
||||||
* pydantic.BaseModel in Pydantic 1.x
|
|
||||||
* pydantic.BaseModel in Pydantic 2.x
|
* pydantic.BaseModel in Pydantic 2.x
|
||||||
* pydantic.v1.BaseModel in Pydantic 2.x
|
* pydantic.v1.BaseModel in Pydantic 2.x
|
||||||
"""
|
"""
|
||||||
if IS_PYDANTIC_V1:
|
return isinstance(obj, (BaseModel, BaseModelV1))
|
||||||
from pydantic import BaseModel as BaseModelV1Proper
|
|
||||||
|
|
||||||
if isinstance(obj, BaseModelV1Proper):
|
|
||||||
return True
|
|
||||||
elif IS_PYDANTIC_V2:
|
|
||||||
from pydantic import BaseModel as BaseModelV2
|
|
||||||
from pydantic.v1 import BaseModel as BaseModelV1
|
|
||||||
|
|
||||||
if isinstance(obj, BaseModelV2):
|
|
||||||
return True
|
|
||||||
|
|
||||||
if isinstance(obj, BaseModelV1):
|
|
||||||
return True
|
|
||||||
else:
|
|
||||||
msg = f"Unsupported Pydantic version: {PYDANTIC_VERSION.major}"
|
|
||||||
raise ValueError(msg)
|
|
||||||
return False
|
|
||||||
|
|
||||||
|
|
||||||
# How to type hint this?
|
# How to type hint this?
|
||||||
@ -184,6 +122,9 @@ def pre_init(func: Callable) -> Any:
|
|||||||
with warnings.catch_warnings():
|
with warnings.catch_warnings():
|
||||||
warnings.filterwarnings(action="ignore", category=PydanticDeprecationWarning)
|
warnings.filterwarnings(action="ignore", category=PydanticDeprecationWarning)
|
||||||
|
|
||||||
|
# Ideally we would use @model_validator(mode="before") but this would change the
|
||||||
|
# order of the validators. See https://github.com/pydantic/pydantic/discussions/7434.
|
||||||
|
# So we keep root_validator for backward compatibility.
|
||||||
@root_validator(pre=True)
|
@root_validator(pre=True)
|
||||||
@wraps(func)
|
@wraps(func)
|
||||||
def wrapper(cls: type[BaseModel], values: dict[str, Any]) -> dict[str, Any]:
|
def wrapper(cls: type[BaseModel], values: dict[str, Any]) -> dict[str, Any]:
|
||||||
@ -244,26 +185,18 @@ class _IgnoreUnserializable(GenerateJsonSchema):
|
|||||||
|
|
||||||
def _create_subset_model_v1(
|
def _create_subset_model_v1(
|
||||||
name: str,
|
name: str,
|
||||||
model: type[BaseModel],
|
model: type[BaseModelV1],
|
||||||
field_names: list,
|
field_names: list,
|
||||||
*,
|
*,
|
||||||
descriptions: Optional[dict] = None,
|
descriptions: Optional[dict] = None,
|
||||||
fn_description: Optional[str] = None,
|
fn_description: Optional[str] = None,
|
||||||
) -> type[BaseModel]:
|
) -> type[BaseModel]:
|
||||||
"""Create a pydantic model with only a subset of model's fields."""
|
"""Create a pydantic model with only a subset of model's fields."""
|
||||||
if IS_PYDANTIC_V1:
|
|
||||||
from pydantic import create_model
|
|
||||||
elif IS_PYDANTIC_V2:
|
|
||||||
from pydantic.v1 import create_model # type: ignore[no-redef]
|
|
||||||
else:
|
|
||||||
msg = f"Unsupported pydantic version: {PYDANTIC_VERSION.major}"
|
|
||||||
raise NotImplementedError(msg)
|
|
||||||
|
|
||||||
fields = {}
|
fields = {}
|
||||||
|
|
||||||
for field_name in field_names:
|
for field_name in field_names:
|
||||||
# Using pydantic v1 so can access __fields__ as a dict.
|
# Using pydantic v1 so can access __fields__ as a dict.
|
||||||
field = model.__fields__[field_name] # type: ignore[index]
|
field = model.__fields__[field_name]
|
||||||
t = (
|
t = (
|
||||||
# this isn't perfect but should work for most functions
|
# this isn't perfect but should work for most functions
|
||||||
field.outer_type_
|
field.outer_type_
|
||||||
@ -274,34 +207,31 @@ def _create_subset_model_v1(
|
|||||||
field.field_info.description = descriptions[field_name]
|
field.field_info.description = descriptions[field_name]
|
||||||
fields[field_name] = (t, field.field_info)
|
fields[field_name] = (t, field.field_info)
|
||||||
|
|
||||||
rtn = create_model(name, **fields) # type: ignore[call-overload]
|
rtn = create_model_v1(name, **fields) # type: ignore[call-overload]
|
||||||
rtn.__doc__ = textwrap.dedent(fn_description or model.__doc__ or "")
|
rtn.__doc__ = textwrap.dedent(fn_description or model.__doc__ or "")
|
||||||
return rtn
|
return rtn
|
||||||
|
|
||||||
|
|
||||||
def _create_subset_model_v2(
|
def _create_subset_model_v2(
|
||||||
name: str,
|
name: str,
|
||||||
model: type[pydantic.BaseModel],
|
model: type[BaseModel],
|
||||||
field_names: list[str],
|
field_names: list[str],
|
||||||
*,
|
*,
|
||||||
descriptions: Optional[dict] = None,
|
descriptions: Optional[dict] = None,
|
||||||
fn_description: Optional[str] = None,
|
fn_description: Optional[str] = None,
|
||||||
) -> type[pydantic.BaseModel]:
|
) -> type[BaseModel]:
|
||||||
"""Create a pydantic model with a subset of the model fields."""
|
"""Create a pydantic model with a subset of the model fields."""
|
||||||
from pydantic import create_model
|
|
||||||
from pydantic.fields import FieldInfo
|
|
||||||
|
|
||||||
descriptions_ = descriptions or {}
|
descriptions_ = descriptions or {}
|
||||||
fields = {}
|
fields = {}
|
||||||
for field_name in field_names:
|
for field_name in field_names:
|
||||||
field = model.model_fields[field_name]
|
field = model.model_fields[field_name]
|
||||||
description = descriptions_.get(field_name, field.description)
|
description = descriptions_.get(field_name, field.description)
|
||||||
field_info = FieldInfo(description=description, default=field.default)
|
field_info = FieldInfoV2(description=description, default=field.default)
|
||||||
if field.metadata:
|
if field.metadata:
|
||||||
field_info.metadata = field.metadata
|
field_info.metadata = field.metadata
|
||||||
fields[field_name] = (field.annotation, field_info)
|
fields[field_name] = (field.annotation, field_info)
|
||||||
|
|
||||||
rtn = create_model( # type: ignore[call-overload]
|
rtn = _create_model_base( # type: ignore[call-overload]
|
||||||
name, **fields, __config__=ConfigDict(arbitrary_types_allowed=True)
|
name, **fields, __config__=ConfigDict(arbitrary_types_allowed=True)
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -322,7 +252,7 @@ def _create_subset_model_v2(
|
|||||||
|
|
||||||
# Private functionality to create a subset model that's compatible across
|
# Private functionality to create a subset model that's compatible across
|
||||||
# different versions of pydantic.
|
# different versions of pydantic.
|
||||||
# Handles pydantic versions 1.x and 2.x. including v1 of pydantic in 2.x.
|
# Handles pydantic versions 2.x. including v1 of pydantic in 2.x.
|
||||||
# However, can't find a way to type hint this.
|
# However, can't find a way to type hint this.
|
||||||
def _create_subset_model(
|
def _create_subset_model(
|
||||||
name: str,
|
name: str,
|
||||||
@ -333,7 +263,7 @@ def _create_subset_model(
|
|||||||
fn_description: Optional[str] = None,
|
fn_description: Optional[str] = None,
|
||||||
) -> type[BaseModel]:
|
) -> type[BaseModel]:
|
||||||
"""Create subset model using the same pydantic version as the input model."""
|
"""Create subset model using the same pydantic version as the input model."""
|
||||||
if IS_PYDANTIC_V1:
|
if issubclass(model, BaseModelV1):
|
||||||
return _create_subset_model_v1(
|
return _create_subset_model_v1(
|
||||||
name,
|
name,
|
||||||
model,
|
model,
|
||||||
@ -341,68 +271,43 @@ def _create_subset_model(
|
|||||||
descriptions=descriptions,
|
descriptions=descriptions,
|
||||||
fn_description=fn_description,
|
fn_description=fn_description,
|
||||||
)
|
)
|
||||||
if IS_PYDANTIC_V2:
|
return _create_subset_model_v2(
|
||||||
from pydantic.v1 import BaseModel as BaseModelV1
|
name,
|
||||||
|
model,
|
||||||
if issubclass(model, BaseModelV1):
|
field_names,
|
||||||
return _create_subset_model_v1(
|
descriptions=descriptions,
|
||||||
name,
|
fn_description=fn_description,
|
||||||
model,
|
)
|
||||||
field_names,
|
|
||||||
descriptions=descriptions,
|
|
||||||
fn_description=fn_description,
|
|
||||||
)
|
|
||||||
return _create_subset_model_v2(
|
|
||||||
name,
|
|
||||||
model,
|
|
||||||
field_names,
|
|
||||||
descriptions=descriptions,
|
|
||||||
fn_description=fn_description,
|
|
||||||
)
|
|
||||||
msg = f"Unsupported pydantic version: {PYDANTIC_VERSION.major}"
|
|
||||||
raise NotImplementedError(msg)
|
|
||||||
|
|
||||||
|
|
||||||
if IS_PYDANTIC_V2:
|
@overload
|
||||||
from pydantic import BaseModel as BaseModelV2
|
def get_fields(model: type[BaseModel]) -> dict[str, FieldInfoV2]: ...
|
||||||
from pydantic.v1 import BaseModel as BaseModelV1
|
|
||||||
|
|
||||||
@overload
|
|
||||||
def get_fields(model: type[BaseModelV2]) -> dict[str, FieldInfoV2]: ...
|
|
||||||
|
|
||||||
@overload
|
@overload
|
||||||
def get_fields(model: BaseModelV2) -> dict[str, FieldInfoV2]: ...
|
def get_fields(model: BaseModel) -> dict[str, FieldInfoV2]: ...
|
||||||
|
|
||||||
@overload
|
|
||||||
def get_fields(model: type[BaseModelV1]) -> dict[str, ModelField]: ...
|
|
||||||
|
|
||||||
@overload
|
@overload
|
||||||
def get_fields(model: BaseModelV1) -> dict[str, ModelField]: ...
|
def get_fields(model: type[BaseModelV1]) -> dict[str, ModelField]: ...
|
||||||
|
|
||||||
def get_fields(
|
|
||||||
model: Union[type[Union[BaseModelV2, BaseModelV1]], BaseModelV2, BaseModelV1],
|
|
||||||
) -> Union[dict[str, FieldInfoV2], dict[str, ModelField]]:
|
|
||||||
"""Get the field names of a Pydantic model."""
|
|
||||||
if hasattr(model, "model_fields"):
|
|
||||||
return model.model_fields
|
|
||||||
|
|
||||||
if hasattr(model, "__fields__"):
|
@overload
|
||||||
return model.__fields__
|
def get_fields(model: BaseModelV1) -> dict[str, ModelField]: ...
|
||||||
msg = f"Expected a Pydantic model. Got {type(model)}"
|
|
||||||
raise TypeError(msg)
|
|
||||||
|
|
||||||
elif IS_PYDANTIC_V1:
|
|
||||||
from pydantic import BaseModel as BaseModelV1_
|
|
||||||
|
|
||||||
def get_fields( # type: ignore[no-redef]
|
def get_fields(
|
||||||
model: Union[type[BaseModelV1_], BaseModelV1_],
|
model: Union[type[Union[BaseModel, BaseModelV1]], BaseModel, BaseModelV1],
|
||||||
) -> dict[str, FieldInfoV1]:
|
) -> Union[dict[str, FieldInfoV2], dict[str, ModelField]]:
|
||||||
"""Get the field names of a Pydantic model."""
|
"""Get the field names of a Pydantic model."""
|
||||||
return model.__fields__ # type: ignore[return-value]
|
if hasattr(model, "model_fields"):
|
||||||
|
return model.model_fields
|
||||||
|
|
||||||
|
if hasattr(model, "__fields__"):
|
||||||
|
return model.__fields__
|
||||||
|
msg = f"Expected a Pydantic model. Got {type(model)}"
|
||||||
|
raise TypeError(msg)
|
||||||
|
|
||||||
else:
|
|
||||||
msg = f"Unsupported Pydantic version: {PYDANTIC_VERSION.major}"
|
|
||||||
raise ValueError(msg)
|
|
||||||
|
|
||||||
_SchemaConfig = ConfigDict(
|
_SchemaConfig = ConfigDict(
|
||||||
arbitrary_types_allowed=True, frozen=True, protected_namespaces=()
|
arbitrary_types_allowed=True, frozen=True, protected_namespaces=()
|
||||||
@ -546,14 +451,11 @@ _RESERVED_NAMES = {key for key in dir(BaseModel) if not key.startswith("_")}
|
|||||||
|
|
||||||
def _remap_field_definitions(field_definitions: dict[str, Any]) -> dict[str, Any]:
|
def _remap_field_definitions(field_definitions: dict[str, Any]) -> dict[str, Any]:
|
||||||
"""This remaps fields to avoid colliding with internal pydantic fields."""
|
"""This remaps fields to avoid colliding with internal pydantic fields."""
|
||||||
from pydantic import Field
|
|
||||||
from pydantic.fields import FieldInfo
|
|
||||||
|
|
||||||
remapped = {}
|
remapped = {}
|
||||||
for key, value in field_definitions.items():
|
for key, value in field_definitions.items():
|
||||||
if key.startswith("_") or key in _RESERVED_NAMES:
|
if key.startswith("_") or key in _RESERVED_NAMES:
|
||||||
# Let's add a prefix to avoid colliding with internal pydantic fields
|
# Let's add a prefix to avoid colliding with internal pydantic fields
|
||||||
if isinstance(value, FieldInfo):
|
if isinstance(value, FieldInfoV2):
|
||||||
msg = (
|
msg = (
|
||||||
f"Remapping for fields starting with '_' or fields with a name "
|
f"Remapping for fields starting with '_' or fields with a name "
|
||||||
f"matching a reserved name {_RESERVED_NAMES} is not supported if "
|
f"matching a reserved name {_RESERVED_NAMES} is not supported if "
|
||||||
|
@ -69,7 +69,6 @@ langchain-text-splitters = { path = "../text-splitters" }
|
|||||||
strict = "True"
|
strict = "True"
|
||||||
strict_bytes = "True"
|
strict_bytes = "True"
|
||||||
enable_error_code = "deprecated"
|
enable_error_code = "deprecated"
|
||||||
report_deprecated_as_note = "True"
|
|
||||||
|
|
||||||
# TODO: activate for 'strict' checking
|
# TODO: activate for 'strict' checking
|
||||||
disallow_any_generics = "False"
|
disallow_any_generics = "False"
|
||||||
|
@ -16,10 +16,6 @@ from langchain_core.output_parsers.openai_tools import (
|
|||||||
PydanticToolsParser,
|
PydanticToolsParser,
|
||||||
)
|
)
|
||||||
from langchain_core.outputs import ChatGeneration
|
from langchain_core.outputs import ChatGeneration
|
||||||
from langchain_core.utils.pydantic import (
|
|
||||||
IS_PYDANTIC_V1,
|
|
||||||
IS_PYDANTIC_V2,
|
|
||||||
)
|
|
||||||
|
|
||||||
STREAMED_MESSAGES: list = [
|
STREAMED_MESSAGES: list = [
|
||||||
AIMessageChunk(content=""),
|
AIMessageChunk(content=""),
|
||||||
@ -532,7 +528,6 @@ async def test_partial_pydantic_output_parser_async() -> None:
|
|||||||
assert actual == EXPECTED_STREAMED_PYDANTIC
|
assert actual == EXPECTED_STREAMED_PYDANTIC
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(not IS_PYDANTIC_V2, reason="This test is for pydantic 2")
|
|
||||||
def test_parse_with_different_pydantic_2_v1() -> None:
|
def test_parse_with_different_pydantic_2_v1() -> None:
|
||||||
"""Test with pydantic.v1.BaseModel from pydantic 2."""
|
"""Test with pydantic.v1.BaseModel from pydantic 2."""
|
||||||
import pydantic
|
import pydantic
|
||||||
@ -567,7 +562,6 @@ def test_parse_with_different_pydantic_2_v1() -> None:
|
|||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(not IS_PYDANTIC_V2, reason="This test is for pydantic 2")
|
|
||||||
def test_parse_with_different_pydantic_2_proper() -> None:
|
def test_parse_with_different_pydantic_2_proper() -> None:
|
||||||
"""Test with pydantic.BaseModel from pydantic 2."""
|
"""Test with pydantic.BaseModel from pydantic 2."""
|
||||||
import pydantic
|
import pydantic
|
||||||
@ -602,41 +596,6 @@ def test_parse_with_different_pydantic_2_proper() -> None:
|
|||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(not IS_PYDANTIC_V1, reason="This test is for pydantic 1")
|
|
||||||
def test_parse_with_different_pydantic_1_proper() -> None:
|
|
||||||
"""Test with pydantic.BaseModel from pydantic 1."""
|
|
||||||
import pydantic
|
|
||||||
|
|
||||||
class Forecast(pydantic.BaseModel):
|
|
||||||
temperature: int
|
|
||||||
forecast: str
|
|
||||||
|
|
||||||
# Can't get pydantic to work here due to the odd typing of tryig to support
|
|
||||||
# both v1 and v2 in the same codebase.
|
|
||||||
parser = PydanticToolsParser(tools=[Forecast])
|
|
||||||
message = AIMessage(
|
|
||||||
content="",
|
|
||||||
tool_calls=[
|
|
||||||
{
|
|
||||||
"id": "call_OwL7f5PE",
|
|
||||||
"name": "Forecast",
|
|
||||||
"args": {"temperature": 20, "forecast": "Sunny"},
|
|
||||||
}
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
generation = ChatGeneration(
|
|
||||||
message=message,
|
|
||||||
)
|
|
||||||
|
|
||||||
assert parser.parse_result([generation]) == [
|
|
||||||
Forecast(
|
|
||||||
temperature=20,
|
|
||||||
forecast="Sunny",
|
|
||||||
)
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
def test_max_tokens_error(caplog: Any) -> None:
|
def test_max_tokens_error(caplog: Any) -> None:
|
||||||
parser = PydanticToolsParser(tools=[NameCollector], first_tool_only=True)
|
parser = PydanticToolsParser(tools=[NameCollector], first_tool_only=True)
|
||||||
message = AIMessage(
|
message = AIMessage(
|
||||||
|
@ -65,8 +65,6 @@ from langchain_core.utils.function_calling import (
|
|||||||
convert_to_openai_tool,
|
convert_to_openai_tool,
|
||||||
)
|
)
|
||||||
from langchain_core.utils.pydantic import (
|
from langchain_core.utils.pydantic import (
|
||||||
IS_PYDANTIC_V1,
|
|
||||||
IS_PYDANTIC_V2,
|
|
||||||
_create_subset_model,
|
_create_subset_model,
|
||||||
create_model_v2,
|
create_model_v2,
|
||||||
)
|
)
|
||||||
@ -79,9 +77,11 @@ def _get_tool_call_json_schema(tool: BaseTool) -> dict:
|
|||||||
if isinstance(tool_schema, dict):
|
if isinstance(tool_schema, dict):
|
||||||
return tool_schema
|
return tool_schema
|
||||||
|
|
||||||
if hasattr(tool_schema, "model_json_schema"):
|
if issubclass(tool_schema, BaseModel):
|
||||||
return tool_schema.model_json_schema()
|
return tool_schema.model_json_schema()
|
||||||
return tool_schema.schema()
|
if issubclass(tool_schema, BaseModelV1):
|
||||||
|
return tool_schema.schema()
|
||||||
|
return {}
|
||||||
|
|
||||||
|
|
||||||
def test_unnamed_decorator() -> None:
|
def test_unnamed_decorator() -> None:
|
||||||
@ -1853,11 +1853,14 @@ def test_args_schema_as_pydantic(pydantic_model: Any) -> None:
|
|||||||
)
|
)
|
||||||
|
|
||||||
input_schema = tool.get_input_schema()
|
input_schema = tool.get_input_schema()
|
||||||
input_json_schema = (
|
if issubclass(input_schema, BaseModel):
|
||||||
input_schema.model_json_schema()
|
input_json_schema = input_schema.model_json_schema()
|
||||||
if hasattr(input_schema, "model_json_schema")
|
elif issubclass(input_schema, BaseModelV1):
|
||||||
else input_schema.schema()
|
input_json_schema = input_schema.schema()
|
||||||
)
|
else:
|
||||||
|
msg = "Unknown input schema type"
|
||||||
|
raise TypeError(msg)
|
||||||
|
|
||||||
assert input_json_schema == {
|
assert input_json_schema == {
|
||||||
"properties": {
|
"properties": {
|
||||||
"a": {"title": "A", "type": "integer"},
|
"a": {"title": "A", "type": "integer"},
|
||||||
@ -1943,12 +1946,14 @@ def test_structured_tool_with_different_pydantic_versions(pydantic_model: Any) -
|
|||||||
|
|
||||||
assert foo_tool.invoke({"a": 5, "b": "hello"}) == "foo"
|
assert foo_tool.invoke({"a": 5, "b": "hello"}) == "foo"
|
||||||
|
|
||||||
args_schema = cast("BaseModel", foo_tool.args_schema)
|
args_schema = cast("type[BaseModel]", foo_tool.args_schema)
|
||||||
args_json_schema = (
|
if issubclass(args_schema, BaseModel):
|
||||||
args_schema.model_json_schema()
|
args_json_schema = args_schema.model_json_schema()
|
||||||
if hasattr(args_schema, "model_json_schema")
|
elif issubclass(args_schema, BaseModelV1):
|
||||||
else args_schema.schema()
|
args_json_schema = args_schema.schema()
|
||||||
)
|
else:
|
||||||
|
msg = "Unknown input schema type"
|
||||||
|
raise TypeError(msg)
|
||||||
assert args_json_schema == {
|
assert args_json_schema == {
|
||||||
"properties": {
|
"properties": {
|
||||||
"a": {"title": "A", "type": "integer"},
|
"a": {"title": "A", "type": "integer"},
|
||||||
@ -1960,11 +1965,13 @@ def test_structured_tool_with_different_pydantic_versions(pydantic_model: Any) -
|
|||||||
}
|
}
|
||||||
|
|
||||||
input_schema = foo_tool.get_input_schema()
|
input_schema = foo_tool.get_input_schema()
|
||||||
input_json_schema = (
|
if issubclass(input_schema, BaseModel):
|
||||||
input_schema.model_json_schema()
|
input_json_schema = input_schema.model_json_schema()
|
||||||
if hasattr(input_schema, "model_json_schema")
|
elif issubclass(input_schema, BaseModelV1):
|
||||||
else input_schema.schema()
|
input_json_schema = input_schema.schema()
|
||||||
)
|
else:
|
||||||
|
msg = "Unknown input schema type"
|
||||||
|
raise TypeError(msg)
|
||||||
assert input_json_schema == {
|
assert input_json_schema == {
|
||||||
"properties": {
|
"properties": {
|
||||||
"a": {"title": "A", "type": "integer"},
|
"a": {"title": "A", "type": "integer"},
|
||||||
@ -2020,7 +2027,6 @@ def test__is_message_content_type(obj: Any, *, expected: bool) -> None:
|
|||||||
assert _is_message_content_type(obj) is expected
|
assert _is_message_content_type(obj) is expected
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(not IS_PYDANTIC_V2, reason="Testing pydantic v2.")
|
|
||||||
@pytest.mark.parametrize("use_v1_namespace", [True, False])
|
@pytest.mark.parametrize("use_v1_namespace", [True, False])
|
||||||
def test__get_all_basemodel_annotations_v2(*, use_v1_namespace: bool) -> None:
|
def test__get_all_basemodel_annotations_v2(*, use_v1_namespace: bool) -> None:
|
||||||
A = TypeVar("A")
|
A = TypeVar("A")
|
||||||
@ -2089,63 +2095,6 @@ def test__get_all_basemodel_annotations_v2(*, use_v1_namespace: bool) -> None:
|
|||||||
assert actual == expected
|
assert actual == expected
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(not IS_PYDANTIC_V1, reason="Testing pydantic v1.")
|
|
||||||
def test__get_all_basemodel_annotations_v1() -> None:
|
|
||||||
A = TypeVar("A")
|
|
||||||
|
|
||||||
class ModelA(BaseModel, Generic[A], extra="allow"):
|
|
||||||
a: A
|
|
||||||
|
|
||||||
class ModelB(ModelA[str]):
|
|
||||||
b: Annotated[ModelA[dict[str, Any]], "foo"]
|
|
||||||
|
|
||||||
class Mixin:
|
|
||||||
def foo(self) -> str:
|
|
||||||
return "foo"
|
|
||||||
|
|
||||||
class ModelC(Mixin, ModelB):
|
|
||||||
c: dict
|
|
||||||
|
|
||||||
expected = {"a": str, "b": Annotated[ModelA[dict[str, Any]], "foo"], "c": dict}
|
|
||||||
actual = get_all_basemodel_annotations(ModelC)
|
|
||||||
assert actual == expected
|
|
||||||
|
|
||||||
expected = {"a": str, "b": Annotated[ModelA[dict[str, Any]], "foo"]}
|
|
||||||
actual = get_all_basemodel_annotations(ModelB)
|
|
||||||
assert actual == expected
|
|
||||||
|
|
||||||
expected = {"a": Any}
|
|
||||||
actual = get_all_basemodel_annotations(ModelA)
|
|
||||||
assert actual == expected
|
|
||||||
|
|
||||||
expected = {"a": int}
|
|
||||||
actual = get_all_basemodel_annotations(ModelA[int])
|
|
||||||
assert actual == expected
|
|
||||||
|
|
||||||
D = TypeVar("D", bound=Union[str, int])
|
|
||||||
|
|
||||||
class ModelD(ModelC, Generic[D]):
|
|
||||||
d: Optional[D]
|
|
||||||
|
|
||||||
expected = {
|
|
||||||
"a": str,
|
|
||||||
"b": Annotated[ModelA[dict[str, Any]], "foo"],
|
|
||||||
"c": dict,
|
|
||||||
"d": Union[str, int, None],
|
|
||||||
}
|
|
||||||
actual = get_all_basemodel_annotations(ModelD)
|
|
||||||
assert actual == expected
|
|
||||||
|
|
||||||
expected = {
|
|
||||||
"a": str,
|
|
||||||
"b": Annotated[ModelA[dict[str, Any]], "foo"],
|
|
||||||
"c": dict,
|
|
||||||
"d": Union[int, None],
|
|
||||||
}
|
|
||||||
actual = get_all_basemodel_annotations(ModelD[int])
|
|
||||||
assert actual == expected
|
|
||||||
|
|
||||||
|
|
||||||
def test_get_all_basemodel_annotations_aliases() -> None:
|
def test_get_all_basemodel_annotations_aliases() -> None:
|
||||||
class CalculatorInput(BaseModel):
|
class CalculatorInput(BaseModel):
|
||||||
a: int = Field(description="first number", alias="A")
|
a: int = Field(description="first number", alias="A")
|
||||||
@ -2226,7 +2175,6 @@ def test_create_retriever_tool() -> None:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(not IS_PYDANTIC_V2, reason="Testing pydantic v2.")
|
|
||||||
def test_tool_args_schema_pydantic_v2_with_metadata() -> None:
|
def test_tool_args_schema_pydantic_v2_with_metadata() -> None:
|
||||||
from pydantic import BaseModel as BaseModelV2
|
from pydantic import BaseModel as BaseModelV2
|
||||||
from pydantic import Field as FieldV2
|
from pydantic import Field as FieldV2
|
||||||
|
@ -3,13 +3,9 @@
|
|||||||
import warnings
|
import warnings
|
||||||
from typing import Any, Optional
|
from typing import Any, Optional
|
||||||
|
|
||||||
import pytest
|
|
||||||
from pydantic import ConfigDict
|
from pydantic import ConfigDict
|
||||||
|
|
||||||
from langchain_core.utils.pydantic import (
|
from langchain_core.utils.pydantic import (
|
||||||
IS_PYDANTIC_V1,
|
|
||||||
IS_PYDANTIC_V2,
|
|
||||||
PYDANTIC_VERSION,
|
|
||||||
_create_subset_model_v2,
|
_create_subset_model_v2,
|
||||||
create_model_v2,
|
create_model_v2,
|
||||||
get_fields,
|
get_fields,
|
||||||
@ -96,50 +92,29 @@ def test_with_aliases() -> None:
|
|||||||
|
|
||||||
def test_is_basemodel_subclass() -> None:
|
def test_is_basemodel_subclass() -> None:
|
||||||
"""Test pydantic."""
|
"""Test pydantic."""
|
||||||
if IS_PYDANTIC_V1:
|
from pydantic import BaseModel as BaseModelV2
|
||||||
from pydantic import BaseModel as BaseModelV1Proper
|
from pydantic.v1 import BaseModel as BaseModelV1
|
||||||
|
|
||||||
assert is_basemodel_subclass(BaseModelV1Proper)
|
assert is_basemodel_subclass(BaseModelV2)
|
||||||
elif IS_PYDANTIC_V2:
|
assert is_basemodel_subclass(BaseModelV1)
|
||||||
from pydantic import BaseModel as BaseModelV2
|
|
||||||
from pydantic.v1 import BaseModel as BaseModelV1
|
|
||||||
|
|
||||||
assert is_basemodel_subclass(BaseModelV2)
|
|
||||||
|
|
||||||
assert is_basemodel_subclass(BaseModelV1)
|
|
||||||
else:
|
|
||||||
msg = f"Unsupported Pydantic version: {PYDANTIC_VERSION.major}"
|
|
||||||
raise ValueError(msg)
|
|
||||||
|
|
||||||
|
|
||||||
def test_is_basemodel_instance() -> None:
|
def test_is_basemodel_instance() -> None:
|
||||||
"""Test pydantic."""
|
"""Test pydantic."""
|
||||||
if IS_PYDANTIC_V1:
|
from pydantic import BaseModel as BaseModelV2
|
||||||
from pydantic import BaseModel as BaseModelV1Proper
|
from pydantic.v1 import BaseModel as BaseModelV1
|
||||||
|
|
||||||
class FooV1(BaseModelV1Proper):
|
class Foo(BaseModelV2):
|
||||||
x: int
|
x: int
|
||||||
|
|
||||||
assert is_basemodel_instance(FooV1(x=5))
|
assert is_basemodel_instance(Foo(x=5))
|
||||||
elif IS_PYDANTIC_V2:
|
|
||||||
from pydantic import BaseModel as BaseModelV2
|
|
||||||
from pydantic.v1 import BaseModel as BaseModelV1
|
|
||||||
|
|
||||||
class Foo(BaseModelV2):
|
class Bar(BaseModelV1):
|
||||||
x: int
|
x: int
|
||||||
|
|
||||||
assert is_basemodel_instance(Foo(x=5))
|
assert is_basemodel_instance(Bar(x=5))
|
||||||
|
|
||||||
class Bar(BaseModelV1):
|
|
||||||
x: int
|
|
||||||
|
|
||||||
assert is_basemodel_instance(Bar(x=5))
|
|
||||||
else:
|
|
||||||
msg = f"Unsupported Pydantic version: {PYDANTIC_VERSION.major}"
|
|
||||||
raise ValueError(msg)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(not IS_PYDANTIC_V2, reason="Only tests Pydantic v2")
|
|
||||||
def test_with_field_metadata() -> None:
|
def test_with_field_metadata() -> None:
|
||||||
"""Test pydantic with field metadata."""
|
"""Test pydantic with field metadata."""
|
||||||
from pydantic import BaseModel as BaseModelV2
|
from pydantic import BaseModel as BaseModelV2
|
||||||
@ -168,18 +143,6 @@ def test_with_field_metadata() -> None:
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(not IS_PYDANTIC_V1, reason="Only tests Pydantic v1")
|
|
||||||
def test_fields_pydantic_v1() -> None:
|
|
||||||
from pydantic import BaseModel
|
|
||||||
|
|
||||||
class Foo(BaseModel):
|
|
||||||
x: int
|
|
||||||
|
|
||||||
fields = get_fields(Foo)
|
|
||||||
assert fields == {"x": Foo.model_fields["x"]}
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(not IS_PYDANTIC_V2, reason="Only tests Pydantic v2")
|
|
||||||
def test_fields_pydantic_v2_proper() -> None:
|
def test_fields_pydantic_v2_proper() -> None:
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
@ -190,7 +153,6 @@ def test_fields_pydantic_v2_proper() -> None:
|
|||||||
assert fields == {"x": Foo.model_fields["x"]}
|
assert fields == {"x": Foo.model_fields["x"]}
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(not IS_PYDANTIC_V2, reason="Only tests Pydantic v2")
|
|
||||||
def test_fields_pydantic_v1_from_2() -> None:
|
def test_fields_pydantic_v1_from_2() -> None:
|
||||||
from pydantic.v1 import BaseModel
|
from pydantic.v1 import BaseModel
|
||||||
|
|
||||||
|
@ -16,10 +16,6 @@ from langchain_core.utils import (
|
|||||||
guard_import,
|
guard_import,
|
||||||
)
|
)
|
||||||
from langchain_core.utils._merge import merge_dicts
|
from langchain_core.utils._merge import merge_dicts
|
||||||
from langchain_core.utils.pydantic import (
|
|
||||||
IS_PYDANTIC_V1,
|
|
||||||
IS_PYDANTIC_V2,
|
|
||||||
)
|
|
||||||
from langchain_core.utils.utils import secret_from_env
|
from langchain_core.utils.utils import secret_from_env
|
||||||
|
|
||||||
|
|
||||||
@ -214,7 +210,6 @@ def test_guard_import_failure(
|
|||||||
guard_import(module_name, pip_name=pip_name, package=package)
|
guard_import(module_name, pip_name=pip_name, package=package)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(not IS_PYDANTIC_V2, reason="Requires pydantic 2")
|
|
||||||
def test_get_pydantic_field_names_v1_in_2() -> None:
|
def test_get_pydantic_field_names_v1_in_2() -> None:
|
||||||
from pydantic.v1 import BaseModel as PydanticV1BaseModel
|
from pydantic.v1 import BaseModel as PydanticV1BaseModel
|
||||||
from pydantic.v1 import Field
|
from pydantic.v1 import Field
|
||||||
@ -229,7 +224,6 @@ def test_get_pydantic_field_names_v1_in_2() -> None:
|
|||||||
assert result == expected
|
assert result == expected
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(not IS_PYDANTIC_V2, reason="Requires pydantic 2")
|
|
||||||
def test_get_pydantic_field_names_v2_in_2() -> None:
|
def test_get_pydantic_field_names_v2_in_2() -> None:
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
@ -243,20 +237,6 @@ def test_get_pydantic_field_names_v2_in_2() -> None:
|
|||||||
assert result == expected
|
assert result == expected
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(not IS_PYDANTIC_V1, reason="Requires pydantic 1")
|
|
||||||
def test_get_pydantic_field_names_v1() -> None:
|
|
||||||
from pydantic import BaseModel, Field
|
|
||||||
|
|
||||||
class PydanticModel(BaseModel):
|
|
||||||
field1: str
|
|
||||||
field2: int
|
|
||||||
alias_field: int = Field(alias="aliased_field")
|
|
||||||
|
|
||||||
result = get_pydantic_field_names(PydanticModel)
|
|
||||||
expected = {"field1", "field2", "aliased_field", "alias_field"}
|
|
||||||
assert result == expected
|
|
||||||
|
|
||||||
|
|
||||||
def test_from_env_with_env_variable() -> None:
|
def test_from_env_with_env_variable() -> None:
|
||||||
key = "TEST_KEY"
|
key = "TEST_KEY"
|
||||||
value = "test_value"
|
value = "test_value"
|
||||||
|
Loading…
Reference in New Issue
Block a user