mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-14 07:07:34 +00:00
core: Rework pydantic version checks (#30653)
This pull request includes various changes to the `langchain_core` library, focusing on improving compatibility with different versions of Pydantic. The primary change involves replacing checks for Pydantic major versions with boolean flags, which simplifies the code and improves readability. This also solves ruff rule checks for [RUF048](https://docs.astral.sh/ruff/rules/map-int-version-parsing/) and [PLR2004](https://docs.astral.sh/ruff/rules/magic-value-comparison/). Key changes include: ### Compatibility Improvements: * [`libs/core/langchain_core/output_parsers/json.py`](diffhunk://#diff-5add0cf7134636ae4198a1e0df49ee332ae0c9123c3a2395101e02687c717646L22-R24): Replaced `PYDANTIC_MAJOR_VERSION` with `IS_PYDANTIC_V1` to check for Pydantic version 1. * [`libs/core/langchain_core/output_parsers/pydantic.py`](diffhunk://#diff-2364b5b4aee01c462aa5dbda5dc3a877dcd20f29df173ad540dc8adf8b192361L14-R14): Updated version checks from `PYDANTIC_MAJOR_VERSION` to `IS_PYDANTIC_V2` in the `PydanticOutputParser` class. [[1]](diffhunk://#diff-2364b5b4aee01c462aa5dbda5dc3a877dcd20f29df173ad540dc8adf8b192361L14-R14) [[2]](diffhunk://#diff-2364b5b4aee01c462aa5dbda5dc3a877dcd20f29df173ad540dc8adf8b192361L27-R27) ### Utility Enhancements: * [`libs/core/langchain_core/utils/pydantic.py`](diffhunk://#diff-ff28020c5f1073a8b63bcd9d8b756a187fd682cb81935295120c63b207071896R23): Introduced `IS_PYDANTIC_V1` and `IS_PYDANTIC_V2` flags and deprecated the `get_pydantic_major_version` function. Updated various functions to use these flags instead of version numbers. [[1]](diffhunk://#diff-ff28020c5f1073a8b63bcd9d8b756a187fd682cb81935295120c63b207071896R23) [[2]](diffhunk://#diff-ff28020c5f1073a8b63bcd9d8b756a187fd682cb81935295120c63b207071896R42-R78) [[3]](diffhunk://#diff-ff28020c5f1073a8b63bcd9d8b756a187fd682cb81935295120c63b207071896L90-R89) [[4]](diffhunk://#diff-ff28020c5f1073a8b63bcd9d8b756a187fd682cb81935295120c63b207071896L104-R101) [[5]](diffhunk://#diff-ff28020c5f1073a8b63bcd9d8b756a187fd682cb81935295120c63b207071896L120-R122) [[6]](diffhunk://#diff-ff28020c5f1073a8b63bcd9d8b756a187fd682cb81935295120c63b207071896L135-R132) [[7]](diffhunk://#diff-ff28020c5f1073a8b63bcd9d8b756a187fd682cb81935295120c63b207071896L149-R151) [[8]](diffhunk://#diff-ff28020c5f1073a8b63bcd9d8b756a187fd682cb81935295120c63b207071896L164-R161) [[9]](diffhunk://#diff-ff28020c5f1073a8b63bcd9d8b756a187fd682cb81935295120c63b207071896L248-R250) [[10]](diffhunk://#diff-ff28020c5f1073a8b63bcd9d8b756a187fd682cb81935295120c63b207071896L330-R335) [[11]](diffhunk://#diff-ff28020c5f1073a8b63bcd9d8b756a187fd682cb81935295120c63b207071896L356-R357) [[12]](diffhunk://#diff-ff28020c5f1073a8b63bcd9d8b756a187fd682cb81935295120c63b207071896L393-R390) [[13]](diffhunk://#diff-ff28020c5f1073a8b63bcd9d8b756a187fd682cb81935295120c63b207071896L403-R400) ### Test Updates: * [`libs/core/tests/unit_tests/output_parsers/test_openai_tools.py`](diffhunk://#diff-694cc0318edbd6bbca34f53304934062ad59ba9f5a788252ce6c5f5452489d67L19-R22): Updated tests to use `IS_PYDANTIC_V1` and `IS_PYDANTIC_V2` for version checks. [[1]](diffhunk://#diff-694cc0318edbd6bbca34f53304934062ad59ba9f5a788252ce6c5f5452489d67L19-R22) [[2]](diffhunk://#diff-694cc0318edbd6bbca34f53304934062ad59ba9f5a788252ce6c5f5452489d67L532-R535) [[3]](diffhunk://#diff-694cc0318edbd6bbca34f53304934062ad59ba9f5a788252ce6c5f5452489d67L567-R570) [[4]](diffhunk://#diff-694cc0318edbd6bbca34f53304934062ad59ba9f5a788252ce6c5f5452489d67L602-R605) * [`libs/core/tests/unit_tests/prompts/test_chat.py`](diffhunk://#diff-3e60e744842086a4f3c4b21bc83e819c3435720eab210078e77e2430fb8c7e84R7): Replaced version tuple checks with `PYDANTIC_VERSION` comparisons. [[1]](diffhunk://#diff-3e60e744842086a4f3c4b21bc83e819c3435720eab210078e77e2430fb8c7e84R7) [[2]](diffhunk://#diff-3e60e744842086a4f3c4b21bc83e819c3435720eab210078e77e2430fb8c7e84L35-R38) [[3]](diffhunk://#diff-3e60e744842086a4f3c4b21bc83e819c3435720eab210078e77e2430fb8c7e84L924-R927) [[4]](diffhunk://#diff-3e60e744842086a4f3c4b21bc83e819c3435720eab210078e77e2430fb8c7e84L935-R938) * [`libs/core/tests/unit_tests/runnables/test_graph.py`](diffhunk://#diff-99a290330ef40103d0ce02e52e21310d6fadea142bfdea13c94d23fc81c0bb5dR3): Simplified version checks using `PYDANTIC_VERSION`. [[1]](diffhunk://#diff-99a290330ef40103d0ce02e52e21310d6fadea142bfdea13c94d23fc81c0bb5dR3) [[2]](diffhunk://#diff-99a290330ef40103d0ce02e52e21310d6fadea142bfdea13c94d23fc81c0bb5dL15-R18) [[3]](diffhunk://#diff-99a290330ef40103d0ce02e52e21310d6fadea142bfdea13c94d23fc81c0bb5dL234-L239) * [`libs/core/tests/unit_tests/runnables/test_runnable.py`](diffhunk://#diff-06bed920c0dad0cfd41d57a8d9e47a7b56832409649c10151061a791860d5bb5L18-R20): Introduced `PYDANTIC_VERSION_AT_LEAST_29` and `PYDANTIC_VERSION_AT_LEAST_210` for more readable version checks. [[1]](diffhunk://#diff-06bed920c0dad0cfd41d57a8d9e47a7b56832409649c10151061a791860d5bb5L18-R20) [[2]](diffhunk://#diff-06bed920c0dad0cfd41d57a8d9e47a7b56832409649c10151061a791860d5bb5L92-R99) [[3]](diffhunk://#diff-06bed920c0dad0cfd41d57a8d9e47a7b56832409649c10151061a791860d5bb5L230-R233) [[4]](diffhunk://#diff-06bed920c0dad0cfd41d57a8d9e47a7b56832409649c10151061a791860d5bb5L652-R655)
This commit is contained in:
parent
43b5dc7191
commit
5e418c2666
@ -19,9 +19,9 @@ from langchain_core.utils.json import (
|
|||||||
parse_json_markdown,
|
parse_json_markdown,
|
||||||
parse_partial_json,
|
parse_partial_json,
|
||||||
)
|
)
|
||||||
from langchain_core.utils.pydantic import PYDANTIC_MAJOR_VERSION
|
from langchain_core.utils.pydantic import IS_PYDANTIC_V1
|
||||||
|
|
||||||
if PYDANTIC_MAJOR_VERSION < 2:
|
if IS_PYDANTIC_V1:
|
||||||
PydanticBaseModel = pydantic.BaseModel
|
PydanticBaseModel = pydantic.BaseModel
|
||||||
|
|
||||||
else:
|
else:
|
||||||
|
@ -11,7 +11,7 @@ from langchain_core.exceptions import OutputParserException
|
|||||||
from langchain_core.output_parsers import JsonOutputParser
|
from langchain_core.output_parsers import JsonOutputParser
|
||||||
from langchain_core.outputs import Generation
|
from langchain_core.outputs import Generation
|
||||||
from langchain_core.utils.pydantic import (
|
from langchain_core.utils.pydantic import (
|
||||||
PYDANTIC_MAJOR_VERSION,
|
IS_PYDANTIC_V2,
|
||||||
PydanticBaseModel,
|
PydanticBaseModel,
|
||||||
TBaseModel,
|
TBaseModel,
|
||||||
)
|
)
|
||||||
@ -24,7 +24,7 @@ class PydanticOutputParser(JsonOutputParser, Generic[TBaseModel]):
|
|||||||
"""The pydantic model to parse."""
|
"""The pydantic model to parse."""
|
||||||
|
|
||||||
def _parse_obj(self, obj: dict) -> TBaseModel:
|
def _parse_obj(self, obj: dict) -> TBaseModel:
|
||||||
if PYDANTIC_MAJOR_VERSION == 2:
|
if IS_PYDANTIC_V2:
|
||||||
try:
|
try:
|
||||||
if issubclass(self.pydantic_object, pydantic.BaseModel):
|
if issubclass(self.pydantic_object, pydantic.BaseModel):
|
||||||
return self.pydantic_object.model_validate(obj)
|
return self.pydantic_object.model_validate(obj)
|
||||||
|
@ -20,6 +20,7 @@ from typing import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
import pydantic
|
import pydantic
|
||||||
|
from packaging import version
|
||||||
from pydantic import (
|
from pydantic import (
|
||||||
BaseModel,
|
BaseModel,
|
||||||
ConfigDict,
|
ConfigDict,
|
||||||
@ -41,44 +42,46 @@ from pydantic.json_schema import (
|
|||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from pydantic_core import core_schema
|
from pydantic_core import core_schema
|
||||||
|
|
||||||
|
try:
|
||||||
|
import pydantic
|
||||||
|
|
||||||
|
PYDANTIC_VERSION = version.parse(pydantic.__version__)
|
||||||
|
except ImportError:
|
||||||
|
PYDANTIC_VERSION = version.parse("0.0.0")
|
||||||
|
|
||||||
|
|
||||||
def get_pydantic_major_version() -> int:
|
def get_pydantic_major_version() -> int:
|
||||||
"""Get the major version of Pydantic."""
|
"""DEPRECATED - Get the major version of Pydantic.
|
||||||
try:
|
|
||||||
import pydantic
|
|
||||||
|
|
||||||
return int(pydantic.__version__.split(".")[0])
|
Use PYDANTIC_VERSION.major instead.
|
||||||
except ImportError:
|
"""
|
||||||
return 0
|
warnings.warn(
|
||||||
|
"get_pydantic_major_version is deprecated. Use PYDANTIC_VERSION.major instead.",
|
||||||
|
DeprecationWarning,
|
||||||
|
stacklevel=2,
|
||||||
|
)
|
||||||
|
return PYDANTIC_VERSION.major
|
||||||
|
|
||||||
|
|
||||||
def _get_pydantic_minor_version() -> int:
|
PYDANTIC_MAJOR_VERSION = PYDANTIC_VERSION.major
|
||||||
"""Get the minor version of Pydantic."""
|
PYDANTIC_MINOR_VERSION = PYDANTIC_VERSION.minor
|
||||||
try:
|
|
||||||
import pydantic
|
|
||||||
|
|
||||||
return int(pydantic.__version__.split(".")[1])
|
IS_PYDANTIC_V1 = PYDANTIC_VERSION.major == 1
|
||||||
except ImportError:
|
IS_PYDANTIC_V2 = PYDANTIC_VERSION.major == 2
|
||||||
return 0
|
|
||||||
|
|
||||||
|
if IS_PYDANTIC_V1:
|
||||||
PYDANTIC_MAJOR_VERSION = get_pydantic_major_version()
|
|
||||||
PYDANTIC_MINOR_VERSION = _get_pydantic_minor_version()
|
|
||||||
|
|
||||||
|
|
||||||
if PYDANTIC_MAJOR_VERSION == 1:
|
|
||||||
from pydantic.fields import FieldInfo as FieldInfoV1
|
from pydantic.fields import FieldInfo as FieldInfoV1
|
||||||
|
|
||||||
PydanticBaseModel = pydantic.BaseModel
|
PydanticBaseModel = pydantic.BaseModel
|
||||||
TypeBaseModel = type[BaseModel]
|
TypeBaseModel = type[BaseModel]
|
||||||
elif PYDANTIC_MAJOR_VERSION == 2:
|
elif IS_PYDANTIC_V2:
|
||||||
from pydantic.v1.fields import FieldInfo as FieldInfoV1 # type: ignore[assignment]
|
from pydantic.v1.fields import FieldInfo as FieldInfoV1 # type: ignore[assignment]
|
||||||
|
|
||||||
# Union type needs to be last assignment to PydanticBaseModel to make mypy happy.
|
# Union type needs to be last assignment to PydanticBaseModel to make mypy happy.
|
||||||
PydanticBaseModel = Union[BaseModel, pydantic.BaseModel] # type: ignore[assignment,misc]
|
PydanticBaseModel = Union[BaseModel, pydantic.BaseModel] # type: ignore[assignment,misc]
|
||||||
TypeBaseModel = Union[type[BaseModel], type[pydantic.BaseModel]] # type: ignore[misc]
|
TypeBaseModel = Union[type[BaseModel], type[pydantic.BaseModel]] # type: ignore[misc]
|
||||||
else:
|
else:
|
||||||
msg = f"Unsupported Pydantic version: {PYDANTIC_MAJOR_VERSION}"
|
msg = f"Unsupported Pydantic version: {PYDANTIC_VERSION.major}"
|
||||||
raise ValueError(msg)
|
raise ValueError(msg)
|
||||||
|
|
||||||
|
|
||||||
@ -87,9 +90,9 @@ TBaseModel = TypeVar("TBaseModel", bound=PydanticBaseModel)
|
|||||||
|
|
||||||
def is_pydantic_v1_subclass(cls: type) -> bool:
|
def is_pydantic_v1_subclass(cls: type) -> bool:
|
||||||
"""Check if the installed Pydantic version is 1.x-like."""
|
"""Check if the installed Pydantic version is 1.x-like."""
|
||||||
if PYDANTIC_MAJOR_VERSION == 1:
|
if IS_PYDANTIC_V1:
|
||||||
return True
|
return True
|
||||||
if PYDANTIC_MAJOR_VERSION == 2:
|
if IS_PYDANTIC_V2:
|
||||||
from pydantic.v1 import BaseModel as BaseModelV1
|
from pydantic.v1 import BaseModel as BaseModelV1
|
||||||
|
|
||||||
if issubclass(cls, BaseModelV1):
|
if issubclass(cls, BaseModelV1):
|
||||||
@ -101,7 +104,7 @@ def is_pydantic_v2_subclass(cls: type) -> bool:
|
|||||||
"""Check if the installed Pydantic version is 1.x-like."""
|
"""Check if the installed Pydantic version is 1.x-like."""
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
return PYDANTIC_MAJOR_VERSION == 2 and issubclass(cls, BaseModel)
|
return IS_PYDANTIC_V2 and issubclass(cls, BaseModel)
|
||||||
|
|
||||||
|
|
||||||
def is_basemodel_subclass(cls: type) -> bool:
|
def is_basemodel_subclass(cls: type) -> bool:
|
||||||
@ -117,12 +120,12 @@ def is_basemodel_subclass(cls: type) -> bool:
|
|||||||
if not inspect.isclass(cls) or isinstance(cls, GenericAlias):
|
if not inspect.isclass(cls) or isinstance(cls, GenericAlias):
|
||||||
return False
|
return False
|
||||||
|
|
||||||
if PYDANTIC_MAJOR_VERSION == 1:
|
if IS_PYDANTIC_V1:
|
||||||
from pydantic import BaseModel as BaseModelV1Proper
|
from pydantic import BaseModel as BaseModelV1Proper
|
||||||
|
|
||||||
if issubclass(cls, BaseModelV1Proper):
|
if issubclass(cls, BaseModelV1Proper):
|
||||||
return True
|
return True
|
||||||
elif PYDANTIC_MAJOR_VERSION == 2:
|
elif IS_PYDANTIC_V2:
|
||||||
from pydantic import BaseModel as BaseModelV2
|
from pydantic import BaseModel as BaseModelV2
|
||||||
from pydantic.v1 import BaseModel as BaseModelV1
|
from pydantic.v1 import BaseModel as BaseModelV1
|
||||||
|
|
||||||
@ -132,7 +135,7 @@ def is_basemodel_subclass(cls: type) -> bool:
|
|||||||
if issubclass(cls, BaseModelV1):
|
if issubclass(cls, BaseModelV1):
|
||||||
return True
|
return True
|
||||||
else:
|
else:
|
||||||
msg = f"Unsupported Pydantic version: {PYDANTIC_MAJOR_VERSION}"
|
msg = f"Unsupported Pydantic version: {PYDANTIC_VERSION.major}"
|
||||||
raise ValueError(msg)
|
raise ValueError(msg)
|
||||||
return False
|
return False
|
||||||
|
|
||||||
@ -146,12 +149,12 @@ def is_basemodel_instance(obj: Any) -> bool:
|
|||||||
* pydantic.BaseModel in Pydantic 2.x
|
* pydantic.BaseModel in Pydantic 2.x
|
||||||
* pydantic.v1.BaseModel in Pydantic 2.x
|
* pydantic.v1.BaseModel in Pydantic 2.x
|
||||||
"""
|
"""
|
||||||
if PYDANTIC_MAJOR_VERSION == 1:
|
if IS_PYDANTIC_V1:
|
||||||
from pydantic import BaseModel as BaseModelV1Proper
|
from pydantic import BaseModel as BaseModelV1Proper
|
||||||
|
|
||||||
if isinstance(obj, BaseModelV1Proper):
|
if isinstance(obj, BaseModelV1Proper):
|
||||||
return True
|
return True
|
||||||
elif PYDANTIC_MAJOR_VERSION == 2:
|
elif IS_PYDANTIC_V2:
|
||||||
from pydantic import BaseModel as BaseModelV2
|
from pydantic import BaseModel as BaseModelV2
|
||||||
from pydantic.v1 import BaseModel as BaseModelV1
|
from pydantic.v1 import BaseModel as BaseModelV1
|
||||||
|
|
||||||
@ -161,7 +164,7 @@ def is_basemodel_instance(obj: Any) -> bool:
|
|||||||
if isinstance(obj, BaseModelV1):
|
if isinstance(obj, BaseModelV1):
|
||||||
return True
|
return True
|
||||||
else:
|
else:
|
||||||
msg = f"Unsupported Pydantic version: {PYDANTIC_MAJOR_VERSION}"
|
msg = f"Unsupported Pydantic version: {PYDANTIC_VERSION.major}"
|
||||||
raise ValueError(msg)
|
raise ValueError(msg)
|
||||||
return False
|
return False
|
||||||
|
|
||||||
@ -245,12 +248,12 @@ def _create_subset_model_v1(
|
|||||||
fn_description: Optional[str] = None,
|
fn_description: Optional[str] = None,
|
||||||
) -> type[BaseModel]:
|
) -> type[BaseModel]:
|
||||||
"""Create a pydantic model with only a subset of model's fields."""
|
"""Create a pydantic model with only a subset of model's fields."""
|
||||||
if PYDANTIC_MAJOR_VERSION == 1:
|
if IS_PYDANTIC_V1:
|
||||||
from pydantic import create_model
|
from pydantic import create_model
|
||||||
elif PYDANTIC_MAJOR_VERSION == 2:
|
elif IS_PYDANTIC_V2:
|
||||||
from pydantic.v1 import create_model # type: ignore
|
from pydantic.v1 import create_model # type: ignore
|
||||||
else:
|
else:
|
||||||
msg = f"Unsupported pydantic version: {PYDANTIC_MAJOR_VERSION}"
|
msg = f"Unsupported pydantic version: {PYDANTIC_VERSION.major}"
|
||||||
raise NotImplementedError(msg)
|
raise NotImplementedError(msg)
|
||||||
|
|
||||||
fields = {}
|
fields = {}
|
||||||
@ -327,7 +330,7 @@ def _create_subset_model(
|
|||||||
fn_description: Optional[str] = None,
|
fn_description: Optional[str] = None,
|
||||||
) -> type[BaseModel]:
|
) -> type[BaseModel]:
|
||||||
"""Create subset model using the same pydantic version as the input model."""
|
"""Create subset model using the same pydantic version as the input model."""
|
||||||
if PYDANTIC_MAJOR_VERSION == 1:
|
if IS_PYDANTIC_V1:
|
||||||
return _create_subset_model_v1(
|
return _create_subset_model_v1(
|
||||||
name,
|
name,
|
||||||
model,
|
model,
|
||||||
@ -335,7 +338,7 @@ def _create_subset_model(
|
|||||||
descriptions=descriptions,
|
descriptions=descriptions,
|
||||||
fn_description=fn_description,
|
fn_description=fn_description,
|
||||||
)
|
)
|
||||||
if PYDANTIC_MAJOR_VERSION == 2:
|
if IS_PYDANTIC_V2:
|
||||||
from pydantic.v1 import BaseModel as BaseModelV1
|
from pydantic.v1 import BaseModel as BaseModelV1
|
||||||
|
|
||||||
if issubclass(model, BaseModelV1):
|
if issubclass(model, BaseModelV1):
|
||||||
@ -353,11 +356,11 @@ def _create_subset_model(
|
|||||||
descriptions=descriptions,
|
descriptions=descriptions,
|
||||||
fn_description=fn_description,
|
fn_description=fn_description,
|
||||||
)
|
)
|
||||||
msg = f"Unsupported pydantic version: {PYDANTIC_MAJOR_VERSION}"
|
msg = f"Unsupported pydantic version: {PYDANTIC_VERSION.major}"
|
||||||
raise NotImplementedError(msg)
|
raise NotImplementedError(msg)
|
||||||
|
|
||||||
|
|
||||||
if PYDANTIC_MAJOR_VERSION == 2:
|
if IS_PYDANTIC_V2:
|
||||||
from pydantic import BaseModel as BaseModelV2
|
from pydantic import BaseModel as BaseModelV2
|
||||||
from pydantic.v1 import BaseModel as BaseModelV1
|
from pydantic.v1 import BaseModel as BaseModelV1
|
||||||
|
|
||||||
@ -390,7 +393,7 @@ if PYDANTIC_MAJOR_VERSION == 2:
|
|||||||
msg = f"Expected a Pydantic model. Got {type(model)}"
|
msg = f"Expected a Pydantic model. Got {type(model)}"
|
||||||
raise TypeError(msg)
|
raise TypeError(msg)
|
||||||
|
|
||||||
elif PYDANTIC_MAJOR_VERSION == 1:
|
elif IS_PYDANTIC_V1:
|
||||||
from pydantic import BaseModel as BaseModelV1_
|
from pydantic import BaseModel as BaseModelV1_
|
||||||
|
|
||||||
def get_fields( # type: ignore[no-redef]
|
def get_fields( # type: ignore[no-redef]
|
||||||
@ -400,7 +403,7 @@ elif PYDANTIC_MAJOR_VERSION == 1:
|
|||||||
return model.__fields__ # type: ignore
|
return model.__fields__ # type: ignore
|
||||||
|
|
||||||
else:
|
else:
|
||||||
msg = f"Unsupported Pydantic version: {PYDANTIC_MAJOR_VERSION}"
|
msg = f"Unsupported Pydantic version: {PYDANTIC_VERSION.major}"
|
||||||
raise ValueError(msg)
|
raise ValueError(msg)
|
||||||
|
|
||||||
_SchemaConfig = ConfigDict(
|
_SchemaConfig = ConfigDict(
|
||||||
|
@ -16,7 +16,10 @@ from langchain_core.output_parsers.openai_tools import (
|
|||||||
PydanticToolsParser,
|
PydanticToolsParser,
|
||||||
)
|
)
|
||||||
from langchain_core.outputs import ChatGeneration
|
from langchain_core.outputs import ChatGeneration
|
||||||
from langchain_core.utils.pydantic import PYDANTIC_MAJOR_VERSION
|
from langchain_core.utils.pydantic import (
|
||||||
|
IS_PYDANTIC_V1,
|
||||||
|
IS_PYDANTIC_V2,
|
||||||
|
)
|
||||||
|
|
||||||
STREAMED_MESSAGES: list = [
|
STREAMED_MESSAGES: list = [
|
||||||
AIMessageChunk(content=""),
|
AIMessageChunk(content=""),
|
||||||
@ -529,7 +532,7 @@ async def test_partial_pydantic_output_parser_async() -> None:
|
|||||||
assert actual == EXPECTED_STREAMED_PYDANTIC
|
assert actual == EXPECTED_STREAMED_PYDANTIC
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(PYDANTIC_MAJOR_VERSION != 2, reason="This test is for pydantic 2")
|
@pytest.mark.skipif(not IS_PYDANTIC_V2, reason="This test is for pydantic 2")
|
||||||
def test_parse_with_different_pydantic_2_v1() -> None:
|
def test_parse_with_different_pydantic_2_v1() -> None:
|
||||||
"""Test with pydantic.v1.BaseModel from pydantic 2."""
|
"""Test with pydantic.v1.BaseModel from pydantic 2."""
|
||||||
import pydantic
|
import pydantic
|
||||||
@ -564,7 +567,7 @@ def test_parse_with_different_pydantic_2_v1() -> None:
|
|||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(PYDANTIC_MAJOR_VERSION != 2, reason="This test is for pydantic 2")
|
@pytest.mark.skipif(not IS_PYDANTIC_V2, reason="This test is for pydantic 2")
|
||||||
def test_parse_with_different_pydantic_2_proper() -> None:
|
def test_parse_with_different_pydantic_2_proper() -> None:
|
||||||
"""Test with pydantic.BaseModel from pydantic 2."""
|
"""Test with pydantic.BaseModel from pydantic 2."""
|
||||||
import pydantic
|
import pydantic
|
||||||
@ -599,7 +602,7 @@ def test_parse_with_different_pydantic_2_proper() -> None:
|
|||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(PYDANTIC_MAJOR_VERSION != 1, reason="This test is for pydantic 1")
|
@pytest.mark.skipif(not IS_PYDANTIC_V1, reason="This test is for pydantic 1")
|
||||||
def test_parse_with_different_pydantic_1_proper() -> None:
|
def test_parse_with_different_pydantic_1_proper() -> None:
|
||||||
"""Test with pydantic.BaseModel from pydantic 1."""
|
"""Test with pydantic.BaseModel from pydantic 1."""
|
||||||
import pydantic
|
import pydantic
|
||||||
|
@ -4,6 +4,7 @@ from pathlib import Path
|
|||||||
from typing import Any, Union, cast
|
from typing import Any, Union, cast
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
from packaging import version
|
||||||
from pydantic import ValidationError
|
from pydantic import ValidationError
|
||||||
from syrupy import SnapshotAssertion
|
from syrupy import SnapshotAssertion
|
||||||
|
|
||||||
@ -32,7 +33,9 @@ from langchain_core.prompts.chat import (
|
|||||||
_convert_to_message,
|
_convert_to_message,
|
||||||
)
|
)
|
||||||
from langchain_core.prompts.string import PromptTemplateFormat
|
from langchain_core.prompts.string import PromptTemplateFormat
|
||||||
from langchain_core.utils.pydantic import PYDANTIC_MAJOR_VERSION, PYDANTIC_MINOR_VERSION
|
from langchain_core.utils.pydantic import (
|
||||||
|
PYDANTIC_VERSION,
|
||||||
|
)
|
||||||
from tests.unit_tests.pydantic_utils import _normalize_schema
|
from tests.unit_tests.pydantic_utils import _normalize_schema
|
||||||
|
|
||||||
|
|
||||||
@ -921,7 +924,7 @@ def test_chat_input_schema(snapshot: SnapshotAssertion) -> None:
|
|||||||
with pytest.raises(ValidationError):
|
with pytest.raises(ValidationError):
|
||||||
prompt_all_required.input_schema(input="")
|
prompt_all_required.input_schema(input="")
|
||||||
|
|
||||||
if (PYDANTIC_MAJOR_VERSION, PYDANTIC_MINOR_VERSION) >= (2, 10):
|
if version.parse("2.10") <= PYDANTIC_VERSION:
|
||||||
assert _normalize_schema(
|
assert _normalize_schema(
|
||||||
prompt_all_required.get_input_jsonschema()
|
prompt_all_required.get_input_jsonschema()
|
||||||
) == snapshot(name="required")
|
) == snapshot(name="required")
|
||||||
@ -932,7 +935,7 @@ def test_chat_input_schema(snapshot: SnapshotAssertion) -> None:
|
|||||||
assert set(prompt_optional.input_variables) == {"input"}
|
assert set(prompt_optional.input_variables) == {"input"}
|
||||||
prompt_optional.input_schema(input="") # won't raise error
|
prompt_optional.input_schema(input="") # won't raise error
|
||||||
|
|
||||||
if (PYDANTIC_MAJOR_VERSION, PYDANTIC_MINOR_VERSION) >= (2, 10):
|
if version.parse("2.10") <= PYDANTIC_VERSION:
|
||||||
assert _normalize_schema(prompt_optional.get_input_jsonschema()) == snapshot(
|
assert _normalize_schema(prompt_optional.get_input_jsonschema()) == snapshot(
|
||||||
name="partial"
|
name="partial"
|
||||||
)
|
)
|
||||||
|
@ -4,16 +4,17 @@ import re
|
|||||||
from typing import Any, Union
|
from typing import Any, Union
|
||||||
from unittest import mock
|
from unittest import mock
|
||||||
|
|
||||||
import pydantic
|
|
||||||
import pytest
|
import pytest
|
||||||
|
from packaging import version
|
||||||
from syrupy import SnapshotAssertion
|
from syrupy import SnapshotAssertion
|
||||||
|
|
||||||
from langchain_core.prompts.prompt import PromptTemplate
|
from langchain_core.prompts.prompt import PromptTemplate
|
||||||
from langchain_core.prompts.string import PromptTemplateFormat
|
from langchain_core.prompts.string import PromptTemplateFormat
|
||||||
from langchain_core.tracers.run_collector import RunCollectorCallbackHandler
|
from langchain_core.tracers.run_collector import RunCollectorCallbackHandler
|
||||||
|
from langchain_core.utils.pydantic import PYDANTIC_VERSION
|
||||||
from tests.unit_tests.pydantic_utils import _normalize_schema
|
from tests.unit_tests.pydantic_utils import _normalize_schema
|
||||||
|
|
||||||
PYDANTIC_VERSION = tuple(map(int, pydantic.__version__.split(".")))
|
PYDANTIC_VERSION_AT_LEAST_29 = version.parse("2.9") <= PYDANTIC_VERSION
|
||||||
|
|
||||||
|
|
||||||
def test_prompt_valid() -> None:
|
def test_prompt_valid() -> None:
|
||||||
@ -117,7 +118,7 @@ def test_mustache_prompt_from_template(snapshot: SnapshotAssertion) -> None:
|
|||||||
"This foo is a bar test baz."
|
"This foo is a bar test baz."
|
||||||
)
|
)
|
||||||
assert prompt.input_variables == ["foo", "obj"]
|
assert prompt.input_variables == ["foo", "obj"]
|
||||||
if PYDANTIC_VERSION >= (2, 9):
|
if PYDANTIC_VERSION_AT_LEAST_29:
|
||||||
assert _normalize_schema(prompt.get_input_jsonschema()) == snapshot(
|
assert _normalize_schema(prompt.get_input_jsonschema()) == snapshot(
|
||||||
name="schema_0"
|
name="schema_0"
|
||||||
)
|
)
|
||||||
@ -144,7 +145,7 @@ def test_mustache_prompt_from_template(snapshot: SnapshotAssertion) -> None:
|
|||||||
is a test."""
|
is a test."""
|
||||||
)
|
)
|
||||||
assert prompt.input_variables == ["foo"]
|
assert prompt.input_variables == ["foo"]
|
||||||
if PYDANTIC_VERSION >= (2, 9):
|
if PYDANTIC_VERSION_AT_LEAST_29:
|
||||||
assert _normalize_schema(prompt.get_input_jsonschema()) == snapshot(
|
assert _normalize_schema(prompt.get_input_jsonschema()) == snapshot(
|
||||||
name="schema_2"
|
name="schema_2"
|
||||||
)
|
)
|
||||||
@ -168,7 +169,7 @@ def test_mustache_prompt_from_template(snapshot: SnapshotAssertion) -> None:
|
|||||||
is a test."""
|
is a test."""
|
||||||
)
|
)
|
||||||
assert prompt.input_variables == ["foo"]
|
assert prompt.input_variables == ["foo"]
|
||||||
if PYDANTIC_VERSION >= (2, 9):
|
if PYDANTIC_VERSION_AT_LEAST_29:
|
||||||
assert _normalize_schema(prompt.get_input_jsonschema()) == snapshot(
|
assert _normalize_schema(prompt.get_input_jsonschema()) == snapshot(
|
||||||
name="schema_3"
|
name="schema_3"
|
||||||
)
|
)
|
||||||
@ -206,7 +207,7 @@ def test_mustache_prompt_from_template(snapshot: SnapshotAssertion) -> None:
|
|||||||
is a test."""
|
is a test."""
|
||||||
)
|
)
|
||||||
assert prompt.input_variables == ["foo"]
|
assert prompt.input_variables == ["foo"]
|
||||||
if PYDANTIC_VERSION >= (2, 9):
|
if PYDANTIC_VERSION_AT_LEAST_29:
|
||||||
assert _normalize_schema(prompt.get_input_jsonschema()) == snapshot(
|
assert _normalize_schema(prompt.get_input_jsonschema()) == snapshot(
|
||||||
name="schema_4"
|
name="schema_4"
|
||||||
)
|
)
|
||||||
@ -224,7 +225,7 @@ def test_mustache_prompt_from_template(snapshot: SnapshotAssertion) -> None:
|
|||||||
is a test.""" # noqa: W293
|
is a test.""" # noqa: W293
|
||||||
)
|
)
|
||||||
assert prompt.input_variables == ["foo"]
|
assert prompt.input_variables == ["foo"]
|
||||||
if PYDANTIC_VERSION >= (2, 9):
|
if PYDANTIC_VERSION_AT_LEAST_29:
|
||||||
assert _normalize_schema(prompt.get_input_jsonschema()) == snapshot(
|
assert _normalize_schema(prompt.get_input_jsonschema()) == snapshot(
|
||||||
name="schema_5"
|
name="schema_5"
|
||||||
)
|
)
|
||||||
|
@ -1,5 +1,6 @@
|
|||||||
from typing import Any, Optional
|
from typing import Any, Optional
|
||||||
|
|
||||||
|
from packaging import version
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from syrupy import SnapshotAssertion
|
from syrupy import SnapshotAssertion
|
||||||
from typing_extensions import override
|
from typing_extensions import override
|
||||||
@ -12,7 +13,9 @@ from langchain_core.prompts.prompt import PromptTemplate
|
|||||||
from langchain_core.runnables.base import Runnable, RunnableConfig
|
from langchain_core.runnables.base import Runnable, RunnableConfig
|
||||||
from langchain_core.runnables.graph import Edge, Graph, Node
|
from langchain_core.runnables.graph import Edge, Graph, Node
|
||||||
from langchain_core.runnables.graph_mermaid import _escape_node_label
|
from langchain_core.runnables.graph_mermaid import _escape_node_label
|
||||||
from langchain_core.utils.pydantic import PYDANTIC_MAJOR_VERSION, PYDANTIC_MINOR_VERSION
|
from langchain_core.utils.pydantic import (
|
||||||
|
PYDANTIC_VERSION,
|
||||||
|
)
|
||||||
from tests.unit_tests.pydantic_utils import _normalize_schema
|
from tests.unit_tests.pydantic_utils import _normalize_schema
|
||||||
|
|
||||||
|
|
||||||
@ -231,12 +234,10 @@ def test_graph_sequence_map(snapshot: SnapshotAssertion) -> None:
|
|||||||
)
|
)
|
||||||
graph = sequence.get_graph()
|
graph = sequence.get_graph()
|
||||||
|
|
||||||
if (PYDANTIC_MAJOR_VERSION, PYDANTIC_MINOR_VERSION) >= (2, 10):
|
if version.parse("2.10") <= PYDANTIC_VERSION:
|
||||||
assert _normalize_schema(graph.to_json(with_schemas=True)) == snapshot(
|
assert _normalize_schema(graph.to_json(with_schemas=True)) == snapshot(
|
||||||
name="graph_with_schema"
|
name="graph_with_schema"
|
||||||
)
|
)
|
||||||
|
|
||||||
if (PYDANTIC_MAJOR_VERSION, PYDANTIC_MINOR_VERSION) >= (2, 10):
|
|
||||||
assert _normalize_schema(graph.to_json()) == snapshot(name="graph_no_schemas")
|
assert _normalize_schema(graph.to_json()) == snapshot(name="graph_no_schemas")
|
||||||
|
|
||||||
assert graph.draw_ascii() == snapshot(name="ascii")
|
assert graph.draw_ascii() == snapshot(name="ascii")
|
||||||
|
@ -2,8 +2,8 @@ import re
|
|||||||
from collections.abc import Sequence
|
from collections.abc import Sequence
|
||||||
from typing import Any, Callable, Optional, Union
|
from typing import Any, Callable, Optional, Union
|
||||||
|
|
||||||
import pydantic
|
|
||||||
import pytest
|
import pytest
|
||||||
|
from packaging import version
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
from langchain_core.callbacks import (
|
from langchain_core.callbacks import (
|
||||||
@ -19,10 +19,9 @@ from langchain_core.runnables.config import RunnableConfig
|
|||||||
from langchain_core.runnables.history import RunnableWithMessageHistory
|
from langchain_core.runnables.history import RunnableWithMessageHistory
|
||||||
from langchain_core.runnables.utils import ConfigurableFieldSpec, Input, Output
|
from langchain_core.runnables.utils import ConfigurableFieldSpec, Input, Output
|
||||||
from langchain_core.tracers import Run
|
from langchain_core.tracers import Run
|
||||||
|
from langchain_core.utils.pydantic import PYDANTIC_VERSION
|
||||||
from tests.unit_tests.pydantic_utils import _schema
|
from tests.unit_tests.pydantic_utils import _schema
|
||||||
|
|
||||||
PYDANTIC_VERSION = tuple(map(int, pydantic.__version__.split(".")))
|
|
||||||
|
|
||||||
|
|
||||||
def test_interfaces() -> None:
|
def test_interfaces() -> None:
|
||||||
history = InMemoryChatMessageHistory()
|
history = InMemoryChatMessageHistory()
|
||||||
@ -492,7 +491,7 @@ def test_get_output_schema() -> None:
|
|||||||
"title": "RunnableWithChatHistoryOutput",
|
"title": "RunnableWithChatHistoryOutput",
|
||||||
"type": "object",
|
"type": "object",
|
||||||
}
|
}
|
||||||
if PYDANTIC_VERSION >= (2, 11):
|
if version.parse("2.11") <= PYDANTIC_VERSION:
|
||||||
expected_schema["additionalProperties"] = True
|
expected_schema["additionalProperties"] = True
|
||||||
assert _schema(output_type) == expected_schema
|
assert _schema(output_type) == expected_schema
|
||||||
|
|
||||||
|
@ -15,9 +15,9 @@ from typing import (
|
|||||||
)
|
)
|
||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
|
|
||||||
import pydantic
|
|
||||||
import pytest
|
import pytest
|
||||||
from freezegun import freeze_time
|
from freezegun import freeze_time
|
||||||
|
from packaging import version
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
from pytest_mock import MockerFixture
|
from pytest_mock import MockerFixture
|
||||||
from syrupy import SnapshotAssertion
|
from syrupy import SnapshotAssertion
|
||||||
@ -89,11 +89,14 @@ from langchain_core.tracers import (
|
|||||||
RunLogPatch,
|
RunLogPatch,
|
||||||
)
|
)
|
||||||
from langchain_core.tracers.context import collect_runs
|
from langchain_core.tracers.context import collect_runs
|
||||||
from langchain_core.utils.pydantic import PYDANTIC_MAJOR_VERSION, PYDANTIC_MINOR_VERSION
|
from langchain_core.utils.pydantic import (
|
||||||
|
PYDANTIC_VERSION,
|
||||||
|
)
|
||||||
from tests.unit_tests.pydantic_utils import _normalize_schema, _schema
|
from tests.unit_tests.pydantic_utils import _normalize_schema, _schema
|
||||||
from tests.unit_tests.stubs import AnyStr, _any_id_ai_message, _any_id_ai_message_chunk
|
from tests.unit_tests.stubs import AnyStr, _any_id_ai_message, _any_id_ai_message_chunk
|
||||||
|
|
||||||
PYDANTIC_VERSION = tuple(map(int, pydantic.__version__.split(".")))
|
PYDANTIC_VERSION_AT_LEAST_29 = version.parse("2.9") <= PYDANTIC_VERSION
|
||||||
|
PYDANTIC_VERSION_AT_LEAST_210 = version.parse("2.10") <= PYDANTIC_VERSION
|
||||||
|
|
||||||
|
|
||||||
class FakeTracer(BaseTracer):
|
class FakeTracer(BaseTracer):
|
||||||
@ -227,7 +230,7 @@ class FakeRetriever(BaseRetriever):
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(
|
@pytest.mark.skipif(
|
||||||
(PYDANTIC_MAJOR_VERSION, PYDANTIC_MINOR_VERSION) >= (2, 10),
|
PYDANTIC_VERSION_AT_LEAST_210,
|
||||||
reason=(
|
reason=(
|
||||||
"Only test with most recent version of pydantic. "
|
"Only test with most recent version of pydantic. "
|
||||||
"Pydantic introduced small fixes to generated JSONSchema on minor versions."
|
"Pydantic introduced small fixes to generated JSONSchema on minor versions."
|
||||||
@ -649,7 +652,7 @@ def test_lambda_schemas(snapshot: SnapshotAssertion) -> None:
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
if PYDANTIC_VERSION >= (2, 9):
|
if PYDANTIC_VERSION_AT_LEAST_29:
|
||||||
assert _normalize_schema(
|
assert _normalize_schema(
|
||||||
RunnableLambda(aget_values_typed).get_output_jsonschema() # type: ignore
|
RunnableLambda(aget_values_typed).get_output_jsonschema() # type: ignore
|
||||||
) == snapshot(name="schema8")
|
) == snapshot(name="schema8")
|
||||||
@ -764,7 +767,7 @@ def test_configurable_fields(snapshot: SnapshotAssertion) -> None:
|
|||||||
|
|
||||||
assert fake_llm_configurable.invoke("...") == "a"
|
assert fake_llm_configurable.invoke("...") == "a"
|
||||||
|
|
||||||
if PYDANTIC_VERSION >= (2, 9):
|
if PYDANTIC_VERSION_AT_LEAST_29:
|
||||||
assert _normalize_schema(
|
assert _normalize_schema(
|
||||||
fake_llm_configurable.get_config_jsonschema()
|
fake_llm_configurable.get_config_jsonschema()
|
||||||
) == snapshot(name="schema2")
|
) == snapshot(name="schema2")
|
||||||
@ -791,7 +794,7 @@ def test_configurable_fields(snapshot: SnapshotAssertion) -> None:
|
|||||||
text="Hello, John!"
|
text="Hello, John!"
|
||||||
)
|
)
|
||||||
|
|
||||||
if PYDANTIC_VERSION >= (2, 9):
|
if PYDANTIC_VERSION_AT_LEAST_29:
|
||||||
assert _normalize_schema(
|
assert _normalize_schema(
|
||||||
prompt_configurable.get_config_jsonschema()
|
prompt_configurable.get_config_jsonschema()
|
||||||
) == snapshot(name="schema3")
|
) == snapshot(name="schema3")
|
||||||
@ -820,7 +823,7 @@ def test_configurable_fields(snapshot: SnapshotAssertion) -> None:
|
|||||||
|
|
||||||
assert chain_configurable.invoke({"name": "John"}) == "a"
|
assert chain_configurable.invoke({"name": "John"}) == "a"
|
||||||
|
|
||||||
if PYDANTIC_VERSION >= (2, 9):
|
if PYDANTIC_VERSION_AT_LEAST_29:
|
||||||
assert _normalize_schema(
|
assert _normalize_schema(
|
||||||
chain_configurable.get_config_jsonschema()
|
chain_configurable.get_config_jsonschema()
|
||||||
) == snapshot(name="schema4")
|
) == snapshot(name="schema4")
|
||||||
@ -865,7 +868,7 @@ def test_configurable_fields(snapshot: SnapshotAssertion) -> None:
|
|||||||
"llm3": "a",
|
"llm3": "a",
|
||||||
}
|
}
|
||||||
|
|
||||||
if PYDANTIC_VERSION >= (2, 9):
|
if PYDANTIC_VERSION_AT_LEAST_29:
|
||||||
assert _normalize_schema(
|
assert _normalize_schema(
|
||||||
chain_with_map_configurable.get_config_jsonschema()
|
chain_with_map_configurable.get_config_jsonschema()
|
||||||
) == snapshot(name="schema5")
|
) == snapshot(name="schema5")
|
||||||
@ -938,7 +941,7 @@ def test_configurable_fields_prefix_keys(snapshot: SnapshotAssertion) -> None:
|
|||||||
|
|
||||||
chain = prompt | fake_llm
|
chain = prompt | fake_llm
|
||||||
|
|
||||||
if PYDANTIC_VERSION >= (2, 9):
|
if PYDANTIC_VERSION_AT_LEAST_29:
|
||||||
assert _normalize_schema(_schema(chain.config_schema())) == snapshot(
|
assert _normalize_schema(_schema(chain.config_schema())) == snapshot(
|
||||||
name="schema6"
|
name="schema6"
|
||||||
)
|
)
|
||||||
@ -990,7 +993,7 @@ def test_configurable_fields_example(snapshot: SnapshotAssertion) -> None:
|
|||||||
|
|
||||||
assert chain_configurable.invoke({"name": "John"}) == "a"
|
assert chain_configurable.invoke({"name": "John"}) == "a"
|
||||||
|
|
||||||
if PYDANTIC_VERSION >= (2, 9):
|
if PYDANTIC_VERSION_AT_LEAST_29:
|
||||||
assert _normalize_schema(
|
assert _normalize_schema(
|
||||||
chain_configurable.get_config_jsonschema()
|
chain_configurable.get_config_jsonschema()
|
||||||
) == snapshot(name="schema7")
|
) == snapshot(name="schema7")
|
||||||
@ -3089,7 +3092,7 @@ def test_seq_prompt_map(mocker: MockerFixture, snapshot: SnapshotAssertion) -> N
|
|||||||
assert chain.middle == [RunnableLambda(passthrough)]
|
assert chain.middle == [RunnableLambda(passthrough)]
|
||||||
assert isinstance(chain.last, RunnableParallel)
|
assert isinstance(chain.last, RunnableParallel)
|
||||||
|
|
||||||
if (PYDANTIC_MAJOR_VERSION, PYDANTIC_MINOR_VERSION) >= (2, 10):
|
if PYDANTIC_VERSION_AT_LEAST_210:
|
||||||
assert dumps(chain, pretty=True) == snapshot
|
assert dumps(chain, pretty=True) == snapshot
|
||||||
|
|
||||||
# Test invoke
|
# Test invoke
|
||||||
|
@ -65,7 +65,8 @@ from langchain_core.utils.function_calling import (
|
|||||||
convert_to_openai_tool,
|
convert_to_openai_tool,
|
||||||
)
|
)
|
||||||
from langchain_core.utils.pydantic import (
|
from langchain_core.utils.pydantic import (
|
||||||
PYDANTIC_MAJOR_VERSION,
|
IS_PYDANTIC_V1,
|
||||||
|
IS_PYDANTIC_V2,
|
||||||
_create_subset_model,
|
_create_subset_model,
|
||||||
create_model_v2,
|
create_model_v2,
|
||||||
)
|
)
|
||||||
@ -2017,7 +2018,7 @@ def test__is_message_content_type(obj: Any, *, expected: bool) -> None:
|
|||||||
assert _is_message_content_type(obj) is expected
|
assert _is_message_content_type(obj) is expected
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(PYDANTIC_MAJOR_VERSION != 2, reason="Testing pydantic v2.")
|
@pytest.mark.skipif(not IS_PYDANTIC_V2, reason="Testing pydantic v2.")
|
||||||
@pytest.mark.parametrize("use_v1_namespace", [True, False])
|
@pytest.mark.parametrize("use_v1_namespace", [True, False])
|
||||||
def test__get_all_basemodel_annotations_v2(*, use_v1_namespace: bool) -> None:
|
def test__get_all_basemodel_annotations_v2(*, use_v1_namespace: bool) -> None:
|
||||||
A = TypeVar("A")
|
A = TypeVar("A")
|
||||||
@ -2086,7 +2087,7 @@ def test__get_all_basemodel_annotations_v2(*, use_v1_namespace: bool) -> None:
|
|||||||
assert actual == expected
|
assert actual == expected
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(PYDANTIC_MAJOR_VERSION != 1, reason="Testing pydantic v1.")
|
@pytest.mark.skipif(not IS_PYDANTIC_V1, reason="Testing pydantic v1.")
|
||||||
def test__get_all_basemodel_annotations_v1() -> None:
|
def test__get_all_basemodel_annotations_v1() -> None:
|
||||||
A = TypeVar("A")
|
A = TypeVar("A")
|
||||||
|
|
||||||
@ -2214,7 +2215,7 @@ def test_create_retriever_tool() -> None:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(PYDANTIC_MAJOR_VERSION != 2, reason="Testing pydantic v2.")
|
@pytest.mark.skipif(not IS_PYDANTIC_V2, reason="Testing pydantic v2.")
|
||||||
def test_tool_args_schema_pydantic_v2_with_metadata() -> None:
|
def test_tool_args_schema_pydantic_v2_with_metadata() -> None:
|
||||||
from pydantic import BaseModel as BaseModelV2
|
from pydantic import BaseModel as BaseModelV2
|
||||||
from pydantic import Field as FieldV2
|
from pydantic import Field as FieldV2
|
||||||
|
@ -7,7 +7,9 @@ import pytest
|
|||||||
from pydantic import ConfigDict
|
from pydantic import ConfigDict
|
||||||
|
|
||||||
from langchain_core.utils.pydantic import (
|
from langchain_core.utils.pydantic import (
|
||||||
PYDANTIC_MAJOR_VERSION,
|
IS_PYDANTIC_V1,
|
||||||
|
IS_PYDANTIC_V2,
|
||||||
|
PYDANTIC_VERSION,
|
||||||
_create_subset_model_v2,
|
_create_subset_model_v2,
|
||||||
create_model_v2,
|
create_model_v2,
|
||||||
get_fields,
|
get_fields,
|
||||||
@ -95,11 +97,11 @@ def test_with_aliases() -> None:
|
|||||||
|
|
||||||
def test_is_basemodel_subclass() -> None:
|
def test_is_basemodel_subclass() -> None:
|
||||||
"""Test pydantic."""
|
"""Test pydantic."""
|
||||||
if PYDANTIC_MAJOR_VERSION == 1:
|
if IS_PYDANTIC_V1:
|
||||||
from pydantic import BaseModel as BaseModelV1Proper
|
from pydantic import BaseModel as BaseModelV1Proper
|
||||||
|
|
||||||
assert is_basemodel_subclass(BaseModelV1Proper)
|
assert is_basemodel_subclass(BaseModelV1Proper)
|
||||||
elif PYDANTIC_MAJOR_VERSION == 2:
|
elif IS_PYDANTIC_V2:
|
||||||
from pydantic import BaseModel as BaseModelV2
|
from pydantic import BaseModel as BaseModelV2
|
||||||
from pydantic.v1 import BaseModel as BaseModelV1
|
from pydantic.v1 import BaseModel as BaseModelV1
|
||||||
|
|
||||||
@ -107,20 +109,20 @@ def test_is_basemodel_subclass() -> None:
|
|||||||
|
|
||||||
assert is_basemodel_subclass(BaseModelV1)
|
assert is_basemodel_subclass(BaseModelV1)
|
||||||
else:
|
else:
|
||||||
msg = f"Unsupported Pydantic version: {PYDANTIC_MAJOR_VERSION}"
|
msg = f"Unsupported Pydantic version: {PYDANTIC_VERSION.major}"
|
||||||
raise ValueError(msg)
|
raise ValueError(msg)
|
||||||
|
|
||||||
|
|
||||||
def test_is_basemodel_instance() -> None:
|
def test_is_basemodel_instance() -> None:
|
||||||
"""Test pydantic."""
|
"""Test pydantic."""
|
||||||
if PYDANTIC_MAJOR_VERSION == 1:
|
if IS_PYDANTIC_V1:
|
||||||
from pydantic import BaseModel as BaseModelV1Proper
|
from pydantic import BaseModel as BaseModelV1Proper
|
||||||
|
|
||||||
class FooV1(BaseModelV1Proper):
|
class FooV1(BaseModelV1Proper):
|
||||||
x: int
|
x: int
|
||||||
|
|
||||||
assert is_basemodel_instance(FooV1(x=5))
|
assert is_basemodel_instance(FooV1(x=5))
|
||||||
elif PYDANTIC_MAJOR_VERSION == 2:
|
elif IS_PYDANTIC_V2:
|
||||||
from pydantic import BaseModel as BaseModelV2
|
from pydantic import BaseModel as BaseModelV2
|
||||||
from pydantic.v1 import BaseModel as BaseModelV1
|
from pydantic.v1 import BaseModel as BaseModelV1
|
||||||
|
|
||||||
@ -134,11 +136,11 @@ def test_is_basemodel_instance() -> None:
|
|||||||
|
|
||||||
assert is_basemodel_instance(Bar(x=5))
|
assert is_basemodel_instance(Bar(x=5))
|
||||||
else:
|
else:
|
||||||
msg = f"Unsupported Pydantic version: {PYDANTIC_MAJOR_VERSION}"
|
msg = f"Unsupported Pydantic version: {PYDANTIC_VERSION.major}"
|
||||||
raise ValueError(msg)
|
raise ValueError(msg)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(PYDANTIC_MAJOR_VERSION != 2, reason="Only tests Pydantic v2")
|
@pytest.mark.skipif(not IS_PYDANTIC_V2, reason="Only tests Pydantic v2")
|
||||||
def test_with_field_metadata() -> None:
|
def test_with_field_metadata() -> None:
|
||||||
"""Test pydantic with field metadata."""
|
"""Test pydantic with field metadata."""
|
||||||
from pydantic import BaseModel as BaseModelV2
|
from pydantic import BaseModel as BaseModelV2
|
||||||
@ -167,7 +169,7 @@ def test_with_field_metadata() -> None:
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(PYDANTIC_MAJOR_VERSION != 1, reason="Only tests Pydantic v1")
|
@pytest.mark.skipif(not IS_PYDANTIC_V1, reason="Only tests Pydantic v1")
|
||||||
def test_fields_pydantic_v1() -> None:
|
def test_fields_pydantic_v1() -> None:
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
@ -178,7 +180,7 @@ def test_fields_pydantic_v1() -> None:
|
|||||||
assert fields == {"x": Foo.model_fields["x"]} # type: ignore[index]
|
assert fields == {"x": Foo.model_fields["x"]} # type: ignore[index]
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(PYDANTIC_MAJOR_VERSION != 2, reason="Only tests Pydantic v2")
|
@pytest.mark.skipif(not IS_PYDANTIC_V2, reason="Only tests Pydantic v2")
|
||||||
def test_fields_pydantic_v2_proper() -> None:
|
def test_fields_pydantic_v2_proper() -> None:
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
@ -189,7 +191,7 @@ def test_fields_pydantic_v2_proper() -> None:
|
|||||||
assert fields == {"x": Foo.model_fields["x"]}
|
assert fields == {"x": Foo.model_fields["x"]}
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(PYDANTIC_MAJOR_VERSION != 2, reason="Only tests Pydantic v2")
|
@pytest.mark.skipif(not IS_PYDANTIC_V2, reason="Only tests Pydantic v2")
|
||||||
def test_fields_pydantic_v1_from_2() -> None:
|
def test_fields_pydantic_v1_from_2() -> None:
|
||||||
from pydantic.v1 import BaseModel
|
from pydantic.v1 import BaseModel
|
||||||
|
|
||||||
|
@ -16,7 +16,10 @@ from langchain_core.utils import (
|
|||||||
guard_import,
|
guard_import,
|
||||||
)
|
)
|
||||||
from langchain_core.utils._merge import merge_dicts
|
from langchain_core.utils._merge import merge_dicts
|
||||||
from langchain_core.utils.pydantic import PYDANTIC_MAJOR_VERSION
|
from langchain_core.utils.pydantic import (
|
||||||
|
IS_PYDANTIC_V1,
|
||||||
|
IS_PYDANTIC_V2,
|
||||||
|
)
|
||||||
from langchain_core.utils.utils import secret_from_env
|
from langchain_core.utils.utils import secret_from_env
|
||||||
|
|
||||||
|
|
||||||
@ -211,7 +214,7 @@ def test_guard_import_failure(
|
|||||||
guard_import(module_name, pip_name=pip_name, package=package)
|
guard_import(module_name, pip_name=pip_name, package=package)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(PYDANTIC_MAJOR_VERSION != 2, reason="Requires pydantic 2")
|
@pytest.mark.skipif(not IS_PYDANTIC_V2, reason="Requires pydantic 2")
|
||||||
def test_get_pydantic_field_names_v1_in_2() -> None:
|
def test_get_pydantic_field_names_v1_in_2() -> None:
|
||||||
from pydantic.v1 import BaseModel as PydanticV1BaseModel
|
from pydantic.v1 import BaseModel as PydanticV1BaseModel
|
||||||
from pydantic.v1 import Field
|
from pydantic.v1 import Field
|
||||||
@ -226,7 +229,7 @@ def test_get_pydantic_field_names_v1_in_2() -> None:
|
|||||||
assert result == expected
|
assert result == expected
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(PYDANTIC_MAJOR_VERSION != 2, reason="Requires pydantic 2")
|
@pytest.mark.skipif(not IS_PYDANTIC_V2, reason="Requires pydantic 2")
|
||||||
def test_get_pydantic_field_names_v2_in_2() -> None:
|
def test_get_pydantic_field_names_v2_in_2() -> None:
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
@ -240,7 +243,7 @@ def test_get_pydantic_field_names_v2_in_2() -> None:
|
|||||||
assert result == expected
|
assert result == expected
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(PYDANTIC_MAJOR_VERSION != 1, reason="Requires pydantic 1")
|
@pytest.mark.skipif(not IS_PYDANTIC_V1, reason="Requires pydantic 1")
|
||||||
def test_get_pydantic_field_names_v1() -> None:
|
def test_get_pydantic_field_names_v1() -> None:
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user