mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-02 11:39:18 +00:00
core[minor]: Add support for pydantic 2 to utility to get fields (#24899)
Add compatibility for pydantic 2 for a utility function. This will help push some small changes to master, so they don't have to be kept track of on a separate branch.
This commit is contained in:
@@ -6,8 +6,13 @@ from unittest.mock import patch
|
||||
import pytest
|
||||
|
||||
from langchain_core import utils
|
||||
from langchain_core.utils import check_package_version, guard_import
|
||||
from langchain_core.utils import (
|
||||
check_package_version,
|
||||
get_pydantic_field_names,
|
||||
guard_import,
|
||||
)
|
||||
from langchain_core.utils._merge import merge_dicts
|
||||
from langchain_core.utils.pydantic import PYDANTIC_MAJOR_VERSION
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
@@ -171,3 +176,46 @@ def test_guard_import_failure(
|
||||
f"Please install it with `pip install {pip_name}`."
|
||||
)
|
||||
assert exc_info.value.msg == err_msg
|
||||
|
||||
|
||||
@pytest.mark.skipif(PYDANTIC_MAJOR_VERSION != 2, reason="Requires pydantic 2")
|
||||
def test_get_pydantic_field_names_v1_in_2() -> None:
|
||||
from pydantic.v1 import BaseModel as PydanticV1BaseModel # pydantic: ignore
|
||||
from pydantic.v1 import Field # pydantic: ignore
|
||||
|
||||
class PydanticV1Model(PydanticV1BaseModel):
|
||||
field1: str
|
||||
field2: int
|
||||
alias_field: int = Field(alias="aliased_field")
|
||||
|
||||
result = get_pydantic_field_names(PydanticV1Model)
|
||||
expected = {"field1", "field2", "aliased_field", "alias_field"}
|
||||
assert result == expected
|
||||
|
||||
|
||||
@pytest.mark.skipif(PYDANTIC_MAJOR_VERSION != 2, reason="Requires pydantic 2")
|
||||
def test_get_pydantic_field_names_v2_in_2() -> None:
|
||||
from pydantic import BaseModel, Field # pydantic: ignore
|
||||
|
||||
class PydanticModel(BaseModel):
|
||||
field1: str
|
||||
field2: int
|
||||
alias_field: int = Field(alias="aliased_field")
|
||||
|
||||
result = get_pydantic_field_names(PydanticModel)
|
||||
expected = {"field1", "field2", "aliased_field", "alias_field"}
|
||||
assert result == expected
|
||||
|
||||
|
||||
@pytest.mark.skipif(PYDANTIC_MAJOR_VERSION != 1, reason="Requires pydantic 1")
|
||||
def test_get_pydantic_field_names_v1() -> None:
|
||||
from pydantic import BaseModel, Field # pydantic: ignore
|
||||
|
||||
class PydanticModel(BaseModel):
|
||||
field1: str
|
||||
field2: int
|
||||
alias_field: int = Field(alias="aliased_field")
|
||||
|
||||
result = get_pydantic_field_names(PydanticModel)
|
||||
expected = {"field1", "field2", "aliased_field", "alias_field"}
|
||||
assert result == expected
|
||||
|
Reference in New Issue
Block a user